Source code for tune.noniterative.study

from copy import copy
from typing import Any, Callable, Dict, Iterable, Optional

from fugue import ArrayDataFrame, DataFrame, ExecutionEngine
from triad import assert_or_throw
from tune._utils import run_monitored_process
from tune.concepts.dataset import StudyResult, TuneDataset, _get_trials_from_row
from tune.concepts.flow import RemoteTrialJudge, TrialCallback, TrialJudge, TrialReport
from tune.concepts.flow.judge import Monitor, NoOpTrailJudge
from tune.constants import TUNE_REPORT_ADD_SCHEMA, TUNE_STOPPER_DEFAULT_CHECK_INTERVAL
from tune.exceptions import TuneCompileError, TuneInterrupted
from tune.noniterative.objective import (
    NonIterativeObjectiveFunc,
    NonIterativeObjectiveLocalOptimizer,
)
from tune.noniterative.stopper import NonIterativeStopper


def _make_judge(
    monitor: Optional[Monitor] = None, stopper: Optional[NonIterativeStopper] = None
) -> Optional[TrialJudge]:
    if monitor is None and stopper is None:
        return None
    if stopper is None and monitor is not None:
        return NoOpTrailJudge(monitor)
    if stopper is not None and monitor is None:
        return stopper
    if stopper is not None and monitor is not None:
        stopper.reset_monitor(monitor)
        return stopper
    raise NotImplementedError  # pragma: no cover


[docs]class NonIterativeStudy: def __init__( self, objective: NonIterativeObjectiveFunc, optimizer: NonIterativeObjectiveLocalOptimizer, ): self._objective = objective self._optimizer = optimizer
[docs] def optimize( # noqa: C901 self, dataset: TuneDataset, distributed: Optional[bool] = None, monitor: Optional[Monitor] = None, stopper: Optional[NonIterativeStopper] = None, stop_check_interval: Any = None, logger: Any = None, ) -> StudyResult: _dist = self._get_distributed(distributed) entrypoint: Any = None judge = _make_judge(monitor, stopper) if judge is not None: cb = TrialCallback(judge) entrypoint = cb.entrypoint if stopper is None: _interval: Any = None else: _interval = stop_check_interval or TUNE_STOPPER_DEFAULT_CHECK_INTERVAL def compute_processor(engine: ExecutionEngine, df: DataFrame) -> DataFrame: out_schema = df.schema + TUNE_REPORT_ADD_SCHEMA def get_rows() -> Iterable[Any]: for row in self._compute_transformer( df.as_local().as_dict_iterable(), entrypoint=entrypoint, stop_check_interval=_interval, logger=logger, ): yield [row[k] for k in out_schema.names] # TODO: need to add back execution_engine for engine aware optimizers # t._execution_engine = engine # type:ignore return ArrayDataFrame(get_rows(), out_schema) def preprocess(df: DataFrame) -> DataFrame: if judge is not None: judge.monitor.initialize() return df def postprocess(df: DataFrame) -> None: if judge is not None: judge.monitor.finalize() if not _dist: res = dataset.data.process(preprocess).process(compute_processor) else: res = ( dataset.data.process(preprocess) .per_row() .transform( self._compute_transformer, schema=f"*,{TUNE_REPORT_ADD_SCHEMA}", callback=entrypoint, params=dict(stop_check_interval=_interval, logger=logger), ) ) res.persist().output(postprocess) return StudyResult(dataset=dataset, result=res)
def _get_distributed(self, distributed: Optional[bool]) -> bool: if distributed is None: return self._optimizer.distributable if distributed: assert_or_throw( self._optimizer.distributable, TuneCompileError( f"can't distribute non-distributable optimizer {self._optimizer}" ), ) return True return False def _compute_transformer( self, df: Iterable[Dict[str, Any]], entrypoint: Optional[Callable[[str, Dict[str, Any]], Any]] = None, stop_check_interval: Any = None, logger: Any = None, ) -> Iterable[Dict[str, Any]]: j: Optional[RemoteTrialJudge] = ( None if entrypoint is None else RemoteTrialJudge(entrypoint) ) for row in df: for n, trial in enumerate(_get_trials_from_row(row, with_dfs=False)): if j is not None: if stop_check_interval is None: # monitor only report = self._local_process_trial(row, n, logger) j.judge(report) yield report.fill_dict(dict(row)) elif j.can_accept(trial): try: report = run_monitored_process( self._local_process_trial, [row, n], {"logger": logger}, lambda: not j.can_accept(trial), # type: ignore # noqa stop_check_interval, ) except TuneInterrupted: continue j.judge(report) yield report.fill_dict(dict(row)) else: report = self._local_process_trial(row, n, logger) yield report.fill_dict(dict(row)) def _local_process_trial( self, row: Dict[str, Any], idx: int, logger: Any ) -> TrialReport: trial = list(_get_trials_from_row(row))[idx] # This copy is to prevent statful objective to mess up state # between runs or between threads objective = copy(self._objective) return self._optimizer.run(objective, trial, logger=logger)