Source code for tune.noniterative.objective

from typing import Any, Callable, Optional

from tune._utils import run_monitored_process
from tune.concepts.flow import Trial, TrialReport
from tune.concepts.logger import make_logger
from tune.constants import TUNE_STOPPER_DEFAULT_CHECK_INTERVAL


class NonIterativeObjectiveFunc:
    """Base class for objective functions evaluated in a single pass per trial.

    Subclasses implement :meth:`run`. Callers use :meth:`safe_run`, which
    runs the trial and attaches a sort metric derived from the report's
    ``metric`` via :meth:`generate_sort_metric`.
    """

    def generate_sort_metric(self, value: float) -> float:  # pragma: no cover
        """Translate a raw metric into the value used for sorting.

        The base implementation is the identity; subclasses may override it
        (for example to negate a metric that should be maximized).

        Args:
            value: the raw metric taken from a trial report.

        Returns:
            The sort metric.
        """
        return value

    def run(self, trial: Trial) -> TrialReport:  # pragma: no cover
        """Evaluate ``trial`` and produce its report.

        Raises:
            NotImplementedError: always, in the base class; subclasses
                must override this method.
        """
        raise NotImplementedError

    def safe_run(self, trial: Trial) -> TrialReport:
        """Run ``trial`` and return its report with the sort metric applied.

        Args:
            trial: the trial to evaluate.

        Returns:
            The value of ``report.with_sort_metric(...)``, where ``report`` is
            the result of :meth:`run` and the argument is
            ``generate_sort_metric(report.metric)``.
        """
        raw_report = self.run(trial)
        sort_metric = self.generate_sort_metric(raw_report.metric)
        return raw_report.with_sort_metric(sort_metric)
class NonIterativeObjectiveLocalOptimizer:
    """Runs a :class:`NonIterativeObjectiveFunc` against one trial at a time."""

    @property
    def distributable(self) -> bool:
        """Whether this optimizer is distributable; always ``True`` here."""
        return True

    def run(
        self, func: NonIterativeObjectiveFunc, trial: Trial, logger: Any
    ) -> TrialReport:
        """Evaluate ``trial`` with ``func`` and optionally log the report.

        Args:
            func: the objective; its ``safe_run`` produces the report.
            trial: the trial to evaluate.
            logger: when ``None`` no logging happens. Otherwise it is
                passed to ``make_logger``, a child logger named
                ``<first 5 chars of trial_id>-<parent unique_id>`` is
                created, and the report is logged through it with params,
                extracted metrics and metadata.

        Returns:
            The report returned by ``func.safe_run(trial)``.
        """
        # TODO: how to utilize execution_engine?
        if logger is None:
            return func.safe_run(trial)

        with make_logger(logger) as parent_logger, parent_logger.create_child(
            name=trial.trial_id[:5] + "-" + parent_logger.unique_id,
            description=repr(trial),
        ) as child_logger:
            report = func.safe_run(trial)
            child_logger.log_report(
                report, log_params=True, extract_metrics=True, log_metadata=True
            )
        # Return after both context managers have exited, as before.
        return report

    def run_monitored_process(
        self,
        func: NonIterativeObjectiveFunc,
        trial: Trial,
        stop_checker: Callable[[], bool],
        logger: Any,
        interval: Any = TUNE_STOPPER_DEFAULT_CHECK_INTERVAL,
    ) -> TrialReport:
        """Execute :meth:`run` through ``tune._utils.run_monitored_process``.

        ``self.run`` is handed to the helper together with positional
        arguments ``[func, trial]`` and keyword arguments
        ``{"logger": logger}``. ``stop_checker`` and ``interval`` are
        forwarded unchanged. Presumably the helper polls ``stop_checker``
        every ``interval`` and aborts when it returns True — see
        ``tune._utils`` to confirm.

        Returns:
            Whatever the helper returns (annotated as a ``TrialReport``).
        """
        positional = [func, trial]
        keyword = {"logger": logger}
        return run_monitored_process(
            self.run,
            positional,
            keyword,
            stop_checker=stop_checker,
            interval=interval,
        )
def validate_noniterative_objective(
    func: NonIterativeObjectiveFunc,
    trial: Trial,
    validator: Callable[[TrialReport], None],
    optimizer: Optional[NonIterativeObjectiveLocalOptimizer] = None,
    logger: Any = None,
) -> None:
    """Run ``func`` on ``trial`` once and hand the resulting report to ``validator``.

    Intended as a test/validation helper: the trial is executed through
    ``optimizer.run_monitored_process`` with a stop checker that never
    requests a stop and a check interval of ``"1sec"``.

    Args:
        func: the objective function to evaluate.
        trial: the trial to evaluate.
        validator: callable that receives the report; it is expected to
            raise if the report is not acceptable.
        optimizer: optimizer used to run the trial. When ``None``, a new
            :class:`NonIterativeObjectiveLocalOptimizer` is created.
        logger: forwarded to ``run_monitored_process``.
    """
    # Compare with None explicitly: ``optimizer or ...`` would discard a
    # caller-supplied optimizer whose truth value happens to be False.
    _optimizer = (
        NonIterativeObjectiveLocalOptimizer() if optimizer is None else optimizer
    )
    report = _optimizer.run_monitored_process(
        func, trial, lambda: False, interval="1sec", logger=logger
    )
    validator(report)