import os
import tempfile
from typing import Callable, List, Optional
from uuid import uuid4
import cloudpickle
from fs.base import FS as FSBase
from triad import FileSystem
from tune.concepts.checkpoint import Checkpoint
from tune.concepts.flow import Monitor, Trial, TrialDecision, TrialJudge, TrialReport
class IterativeObjectiveFunc:
    """Base class for objective functions that are evaluated rung by rung.

    Subclasses must implement :meth:`copy` and :meth:`run_single_iteration`.
    The remaining hooks (:meth:`initialize`, :meth:`finalize`,
    :meth:`load_checkpoint`, :meth:`save_checkpoint`) are no-ops by default
    and :meth:`generate_sort_metric` is the identity by default.
    """

    def __init__(self):
        # Index of the rung that will run next; advanced by ``run``.
        self._rung = 0
        # Set by ``run`` once the judge has accepted the trial.
        self._current_trial: Optional[Trial] = None

    def copy(self) -> "IterativeObjectiveFunc":  # pragma: no cover
        """Return a copy of this objective; subclasses must override.

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    @property
    def current_trial(self) -> Trial:
        """The trial most recently accepted by :meth:`run`.

        :raises AssertionError: if no trial has been assigned yet
        """
        if self._current_trial is None:
            # Raised explicitly instead of via ``assert`` so the check is
            # not stripped when Python runs with ``-O``.
            raise AssertionError("no trial has been assigned; call run() first")
        return self._current_trial

    @property
    def rung(self) -> int:
        """The current rung index (starts at 0)."""
        return self._rung

    def generate_sort_metric(self, value: float) -> float:
        """Convert a reported metric into the value used for sorting.

        The default implementation returns ``value`` unchanged.

        :param value: the metric taken from a report
        :return: the sort metric
        """
        return value

    def load_checkpoint(self, fs: FSBase) -> None:  # pragma: no cover
        """Restore state from the checkpoint folder ``fs`` (default: no-op)."""
        return

    def save_checkpoint(self, fs: FSBase) -> None:  # pragma: no cover
        """Persist state into the checkpoint folder ``fs`` (default: no-op)."""
        return

    def initialize(self) -> None:  # pragma: no cover
        """Hook called by :meth:`run` before any rung runs (default: no-op)."""
        return

    def finalize(self) -> None:  # pragma: no cover
        """Hook called by :meth:`run` on exit, even on error (default: no-op)."""
        return

    def run_single_iteration(self) -> TrialReport:  # pragma: no cover
        """Run one iteration and return its report; subclasses must override.

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError

    def run_single_rung(self, budget: float) -> TrialReport:
        """Run iterations until their accumulated cost reaches ``budget``.

        At least one iteration always runs, because the budget is checked
        only after an iteration has completed.

        :param budget: the cost to spend in this rung
        :return: the last iteration's report, with its cost replaced by the
            total cost accumulated over the rung
        """
        used = 0.0
        while True:
            current_report = self.run_single_iteration()
            used += current_report.cost
            if used >= budget:
                return current_report.with_cost(used)

    def run(
        self,
        trial: Trial,
        judge: TrialJudge,
        checkpoint_basedir_fs: FSBase,
    ) -> None:
        """Run this objective on ``trial`` under the control of ``judge``.

        A checkpoint folder named after ``trial.trial_id`` is created under
        ``checkpoint_basedir_fs`` first. If the judge does not accept the
        trial, the method returns without running anything. Otherwise
        :meth:`initialize` is called and, when the folder already holds
        checkpoints, the rung counter resumes from the stored ``__RUNG__``
        value plus one and :meth:`load_checkpoint` is invoked. Rungs are then
        executed for as long as the judge hands out a positive budget.
        :meth:`finalize` is called whenever :meth:`initialize` returned,
        including when a rung raises.

        :param trial: the trial to run
        :param judge: decides acceptance, budgets and checkpointing
        :param checkpoint_basedir_fs: the filesystem under which the
            per-trial checkpoint folder is created
        """
        checkpoint = Checkpoint(
            checkpoint_basedir_fs.makedir(trial.trial_id, recreate=True)
        )
        if not judge.can_accept(trial):
            return
        self._current_trial = trial
        self.initialize()
        try:
            if len(checkpoint) > 0:
                # Resume: continue with the rung after the one recorded in
                # the latest checkpoint.
                self._rung = int(checkpoint.latest.readtext("__RUNG__")) + 1
                self.load_checkpoint(checkpoint.latest)
            budget = judge.get_budget(trial, self.rung)
            while budget > 0:
                report = self.run_single_rung(budget)
                report = report.with_rung(self.rung).with_sort_metric(
                    self.generate_sort_metric(report.metric)
                )
                decision = judge.judge(report)
                if decision.should_checkpoint:
                    with checkpoint.create() as fs:
                        # Record the rung just completed so a later run can
                        # resume from the following one.
                        fs.writetext("__RUNG__", str(self.rung))
                        self.save_checkpoint(fs)
                budget = decision.budget
                self._rung += 1
        finally:
            self.finalize()
def validate_iterative_objective(
    func: IterativeObjectiveFunc,
    trial: Trial,
    budgets: List[float],
    validator: Callable[[List[TrialReport]], None],
    continuous: bool = False,
    checkpoint_path: str = "",
    monitor: Optional[Monitor] = None,
) -> None:
    """Run ``func`` on ``trial`` with fixed budgets and validate the reports.

    In continuous mode a single run is driven through the budgets one rung
    after another. Otherwise one run is started per entry in ``budgets``;
    every run uses the same checkpoint base folder. Each run operates on a
    cloudpickle round-tripped copy of ``func``, never on ``func`` itself.

    :param func: the objective to validate
    :param trial: the trial passed to every run
    :param budgets: the per-rung budgets
    :param validator: called once at the end with all collected reports
    :param continuous: whether a single run goes through all budgets,
        defaults to False
    :param checkpoint_path: the folder under which a uniquely named
        checkpoint subfolder is created; when empty, the system temp folder
        is used
    :param monitor: optional monitor forwarded to the judge
    """
    # An empty (or otherwise falsy) path falls back to the system temp folder.
    path = checkpoint_path or tempfile.gettempdir()
    # A fresh uuid-named subfolder keeps this validation's checkpoints apart
    # from anything else stored under ``path``.
    basefs = FileSystem().makedirs(
        os.path.join(path, str(uuid4())), recreate=True
    )
    judge = _Validator(monitor, budgets, continuous=continuous)
    run_count = 1 if continuous else len(budgets)
    for _ in range(run_count):
        # Serializing and deserializing before copying means an objective
        # that cannot be pickled fails here.
        runner = cloudpickle.loads(cloudpickle.dumps(func)).copy()
        runner.run(trial, judge, checkpoint_basedir_fs=basefs)
    validator(judge.reports)
class _Validator(TrialJudge):
    """Judge that hands out a fixed list of budgets and records every report."""

    def __init__(
        self, monitor: Optional[Monitor], budgets: List[float], continuous: bool
    ):
        super().__init__(monitor)
        self._reports: List[TrialReport] = []
        self._continuous = continuous
        self._budgets = budgets

    @property
    def reports(self) -> List[TrialReport]:
        """All reports passed to :meth:`judge`, in arrival order."""
        return self._reports

    def can_accept(self, trial: Trial) -> bool:
        """Accept every trial."""
        return True

    def get_budget(self, trial: Trial, rung: int) -> float:
        """Return the budget configured for ``rung``, or 0.0 past the end.

        The monitor is notified of the budget before it is returned.
        """
        if rung < len(self._budgets):
            amount = self._budgets[rung]
        else:
            amount = 0.0
        self.monitor.on_get_budget(trial, rung, amount)
        return amount

    def judge(self, report: TrialReport) -> TrialDecision:
        """Record ``report`` and decide the next budget and checkpointing.

        In continuous mode the next budget is looked up for the following
        rung; otherwise the next budget is 0.0 and a checkpoint is always
        requested.
        """
        self.monitor.on_report(report)
        self._reports.append(report)
        if self._continuous:
            next_budget = self.get_budget(report.trial, report.rung + 1)
            # NOTE(review): a checkpoint is requested only once the reported
            # rung index reaches len(budgets) -- confirm this threshold is
            # the intended one.
            checkpoint_now = report.rung >= len(self._budgets)
        else:
            next_budget = 0.0
            checkpoint_now = True
        verdict = TrialDecision(
            report,
            budget=next_budget,
            should_checkpoint=checkpoint_now,
        )
        self.monitor.on_judge(verdict)
        return verdict