from datetime import datetime
from typing import Any, Callable, Dict, List, Set
from triad import SerializableRLock
from triad.utils.convert import to_timedelta
from tune.concepts.flow import (
Trial,
TrialDecision,
TrialJudge,
TrialReport,
TrialReportLogger,
)
class TrialReportCollection(TrialReportLogger):
    """A :class:`TrialReportLogger` that keeps every logged report in a list.

    :param new_best_only: forwarded unchanged to ``TrialReportLogger``;
      presumably restricts logging to reports that beat the best so far
      (TODO confirm against ``TrialReportLogger``)
    """

    def __init__(self, new_best_only: bool = False):
        super().__init__(new_best_only=new_best_only)
        # Reports in the order ``log`` received them.
        self._reports: List[TrialReport] = []

    def log(self, report: TrialReport) -> None:
        """Store the value returned by ``report.reset_log_time()``.

        NOTE(review): this appends without acquiring ``self._lock``; it
        presumably relies on the base class calling ``log`` while already
        holding that lock — confirm.
        """
        stamped = report.reset_log_time()
        self._reports.append(stamped)

    @property
    def reports(self) -> List[TrialReport]:
        """Shallow copy of the stored reports, taken under ``self._lock``
        (the lock is expected to be created by the base class)."""
        with self._lock:
            return self._reports.copy()
class NonIterativeStopper(TrialJudge):
    """Base :class:`TrialJudge` that decides when to stop non-iterative trials.

    Reports are bucketed by ``str(trial.keys)``. Each bucket is kept in its
    own :class:`TrialReportCollection`. Subclasses override ``should_stop``.
    Stoppers can be composed with ``&`` and ``|``.

    :param log_best_only: passed as ``new_best_only`` to every per-bucket
      ``TrialReportCollection``
    """

    def __init__(self, log_best_only: bool = False):
        super().__init__()
        # Reentrant lock guarding ``_logs``. Nested acquisition happens in
        # subclasses that call ``get_reports`` while holding it.
        self._lock = SerializableRLock()
        self._stopper_updated = False
        self._log_best_only = log_best_only
        self._logs: Dict[str, TrialReportCollection] = {}

    @property
    def updated(self) -> bool:
        """Whether ``on_report`` has been called on this stopper."""
        return self._stopper_updated

    def should_stop(self, trial: Trial) -> bool:  # pragma: no cover
        """Return whether ``trial``'s bucket should stop. The base never stops."""
        return False

    def on_report(self, report: TrialReport) -> bool:
        """Record ``report`` in its bucket, creating the bucket on first use.

        :return: whatever the bucket's ``TrialReportCollection.on_report``
          returns
        """
        self._stopper_updated = True
        self.monitor.on_report(report)
        bucket_key = str(report.trial.keys)
        with self._lock:
            collection = self._logs.get(bucket_key)
            if collection is None:
                collection = TrialReportCollection(self._log_best_only)
                self._logs[bucket_key] = collection
            return collection.on_report(report)

    def can_accept(self, trial: Trial) -> bool:
        """Accept ``trial`` unless ``should_stop`` says otherwise."""
        return not self.should_stop(trial)

    def judge(self, report: TrialReport) -> TrialDecision:
        """Record ``report`` and return ``TrialDecision(report, 0.0, False)``."""
        self.on_report(report)
        return TrialDecision(report, 0.0, False)

    def get_reports(self, trial: Trial) -> List[TrialReport]:
        """Return the logged reports for ``trial``'s bucket (empty if none)."""
        with self._lock:
            collection = self._logs.get(str(trial.keys))
            return [] if collection is None else collection.reports

    def __and__(self, other: "NonIterativeStopper") -> "NonIterativeStopperCombiner":
        return NonIterativeStopperCombiner(self, other, is_and=True)

    def __or__(self, other: "NonIterativeStopper") -> "NonIterativeStopperCombiner":
        return NonIterativeStopperCombiner(self, other, is_and=False)
class NonIterativeStopperCombiner(NonIterativeStopper):
    """Combine two stoppers into one using logical AND or OR.

    Created by ``left & right`` / ``left | right`` on
    :class:`NonIterativeStopper`. Every report is forwarded to both children.
    A trial stops when both children (AND) or either child (OR) say so.

    :param left: first child stopper; must not have received any report yet
    :param right: second child stopper; must not have received any report yet
    :param is_and: ``True`` to combine with AND, ``False`` to combine with OR
    """

    def __init__(
        self, left: NonIterativeStopper, right: NonIterativeStopper, is_and: bool
    ):
        super().__init__()
        assert not left.updated, "can't reuse updated stopper"
        assert not right.updated, "can't reuse updated stopper"
        self._left = left
        self._right = right
        self._is_and = is_and

    def should_stop(self, trial: Trial) -> bool:  # pragma: no cover
        """Combine the children's ``should_stop`` answers with AND / OR."""
        if self._is_and:
            return self._left.should_stop(trial) and self._right.should_stop(trial)
        else:
            return self._left.should_stop(trial) or self._right.should_stop(trial)

    def on_report(self, report: TrialReport) -> bool:
        """Forward ``report`` to both children.

        :return: True if either child's ``on_report`` returned True
        """
        # Mark this combiner as used, exactly like NonIterativeStopper.on_report.
        # Without this, ``updated`` stays False after use, and the reuse guard
        # in ``__init__`` cannot stop a used combiner from being nested again.
        self._stopper_updated = True
        self.monitor.on_report(report)
        # Evaluate both children before combining. ``or`` must not
        # short-circuit here, or the right child would miss reports.
        left = self._left.on_report(report)
        right = self._right.on_report(report)
        return left or right

    def get_reports(self, trial: Trial) -> List[TrialReport]:  # pragma: no cover
        """Not supported: a combiner keeps no logs of its own."""
        raise NotImplementedError
class SimpleNonIterativeStopper(NonIterativeStopper):
    """Stopper whose per-bucket decision is made by a user-supplied predicate.

    :param partition_should_stop: called as ``f(current, updated, reports)``
      after each report for a bucket that has not stopped yet. ``current`` is
      the new report, ``updated`` is the bucket log's ``on_report`` result,
      and ``reports`` is that bucket's logged history. Returning True stops
      the bucket permanently.
    :param log_best_only: forwarded to :class:`NonIterativeStopper`
    """

    def __init__(
        self,
        partition_should_stop: Callable[[TrialReport, bool, List[TrialReport]], bool],
        log_best_only: bool = False,
    ):
        super().__init__(log_best_only=log_best_only)
        self._partition_should_stop = partition_should_stop
        # Bucket keys (str(trial.keys)) that have been stopped.
        self._stopped: Set[str] = set()

    def should_stop(self, trial: Trial) -> bool:
        """Return whether ``trial``'s bucket has already been stopped."""
        bucket_key = str(trial.keys)
        with self._lock:
            return bucket_key in self._stopped

    def on_report(self, report: TrialReport) -> bool:
        """Log ``report``, then ask the predicate whether to stop its bucket."""
        was_updated = super().on_report(report)
        bucket_key = str(report.trial.keys)
        with self._lock:
            # Stopped buckets stay stopped; skip the predicate for them.
            if bucket_key in self._stopped:
                return was_updated
            history = self.get_reports(report.trial)
            if self._partition_should_stop(report, was_updated, history):
                self._stopped.add(bucket_key)
        return was_updated
def n_samples(n: int) -> SimpleNonIterativeStopper:
    """Stop a bucket once at least ``n`` reports have been logged for it.

    Uses ``log_best_only=False``, so every report counts toward ``n``.
    """

    def _has_enough_samples(
        current: TrialReport, updated: bool, reports: List[TrialReport]
    ) -> bool:
        return len(reports) >= n

    return SimpleNonIterativeStopper(_has_enough_samples, log_best_only=False)
def n_updates(n: int) -> SimpleNonIterativeStopper:
    """Stop a bucket once at least ``n`` reports have been logged for it.

    Uses ``log_best_only=True``, so presumably only reports that improved the
    bucket's best are logged and counted (TODO confirm against
    ``TrialReportLogger``'s ``new_best_only``).
    """

    def _has_enough_updates(
        current: TrialReport, updated: bool, reports: List[TrialReport]
    ) -> bool:
        return len(reports) >= n

    return SimpleNonIterativeStopper(_has_enough_updates, log_best_only=True)
def no_update_period(period: Any) -> SimpleNonIterativeStopper:
    """Stop a bucket when no logged report has arrived for longer than ``period``.

    :param period: anything ``to_timedelta`` accepts
    """
    max_idle = to_timedelta(period)

    def _idle_too_long(
        current: TrialReport, updated: bool, reports: List[TrialReport]
    ) -> bool:
        # A fresh update, or an empty history, never triggers a stop.
        if updated or not reports:
            return False
        # NOTE(review): naive local time; assumes log_time is also naive
        # local time — confirm against TrialReport.reset_log_time.
        idle_for = datetime.now() - reports[-1].log_time
        return idle_for > max_idle

    return SimpleNonIterativeStopper(_idle_too_long, log_best_only=True)
def small_improvement(threshold: float, updates: int) -> SimpleNonIterativeStopper:
    """Stop a bucket when the last ``updates`` logged reports improved
    ``sort_metric`` by less than ``threshold`` in total.

    The improvement is ``baseline.sort_metric - current.sort_metric``, where
    ``baseline`` is the report logged ``updates`` steps before the newest one.
    This presumes lower ``sort_metric`` is better — confirm.

    :param threshold: minimum improvement required to keep going
    :param updates: how many logged reports to look back; must be positive
    """
    assert updates > 0

    def _improvement_too_small(
        current: TrialReport, updated: bool, reports: List[TrialReport]
    ) -> bool:
        # Only judge on a fresh update, and only once the history holds the
        # ``updates + 1`` reports needed to reach the baseline.
        if not updated or len(reports) <= updates:
            return False
        baseline = reports[-(updates + 1)]
        improvement = baseline.sort_metric - current.sort_metric
        return improvement < threshold

    return SimpleNonIterativeStopper(_improvement_too_small, log_best_only=True)