from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple
from triad import SerializableRLock, to_uuid
from tune.concepts.flow import (
Monitor,
Trial,
TrialDecision,
TrialJudge,
TrialReport,
TrialReportHeap,
)
[docs]class RungHeap:
def __init__(self, n: int):
self._lock = SerializableRLock()
self._n = n
self._heap = TrialReportHeap(min_heap=False)
self._bests: List[float] = []
def __len__(self) -> int:
with self._lock:
return len(self._heap)
@property
def capacity(self) -> int:
return self._n
@property
def best(self) -> float:
with self._lock:
return self._bests[-1] if len(self._bests) > 0 else float("nan")
@property
def bests(self) -> List[float]:
with self._lock:
return self._bests
@property
def full(self) -> bool:
with self._lock:
return self.capacity <= len(self)
def __contains__(self, tid: str) -> bool:
with self._lock:
return tid in self._heap
[docs] def values(self) -> Iterable[TrialReport]:
return self._heap.values()
[docs] def push(self, report: TrialReport) -> bool:
with self._lock:
if len(self) == 0:
best = report.sort_metric
else:
best = min(self.best, report.sort_metric)
self._heap.push(report)
self._bests.append(best)
return (
len(self._heap) <= self._n
or self._heap.pop().trial_id != report.trial_id
)
class _PerTrial:
def __init__(self, parent: "_PerPartition") -> None:
self._history: List[TrialReport] = []
self._parent = parent
self._active = True
def can_promote(self, report: TrialReport) -> Tuple[bool, str]:
reasons: List[str] = []
if self._active:
can_accept = self._parent.can_accept(report.trial)
early_stop = self._parent._parent._trial_early_stop(
report, self._history, self._parent._rungs
)
self._active = can_accept and not early_stop
if not can_accept:
reasons.append("can't accept new")
if early_stop:
reasons.append("trial early stop")
if self._active:
self._history.append(report)
can_push = self._parent._rungs[report.rung].push(report)
if not can_push:
# data = sorted(
# (x for x in self._parent._rungs[report.rung].values()),
# key=lambda x: x["sort_metric"],
# )
# reasons.append("not best: " + json.dumps(data))
reasons.append("not best")
return can_push, ", ".join(reasons)
return False, ", ".join(reasons)
def judge(self, report: TrialReport) -> TrialDecision:
if report.rung >= len(self._parent._parent.schedule) - 1:
self._history.append(report)
self._parent._rungs[report.rung].push(report)
return TrialDecision(
report, budget=0, should_checkpoint=True, reason="last"
)
promote, reason = self.can_promote(report)
if not promote:
return TrialDecision(
report, budget=0, should_checkpoint=True, reason=reason
)
next_budget = self._parent.get_budget(report.trial, report.rung + 1)
return TrialDecision(
report,
budget=next_budget,
should_checkpoint=next_budget <= 0
or self._parent._parent.always_checkpoint,
reason="" if next_budget > 0 else "budget==0",
)
class _PerPartition:
def __init__(self, parent: "ASHAJudge", keys: List[Any]):
self._keys = keys
self._data: Dict[str, _PerTrial] = {}
self._lock = SerializableRLock()
self._parent = parent
self._rungs: List[RungHeap] = [RungHeap(x[1]) for x in self._parent.schedule]
self._active = True
self._accepted_ids: Set[str] = set()
def can_accept(self, trial: Trial) -> bool:
with self._lock:
if self._active:
self._active = not self._parent._study_early_stop(
self._keys, self._rungs
)
if self._active:
self._accepted_ids.add(trial.trial_id)
return True
# if not active, can only accept existing trials
return trial.trial_id in self._accepted_ids
def get_budget(self, trial: Trial, rung: int) -> float:
if rung >= len(self._parent.schedule) or not self.can_accept(trial):
return 0.0 # pragma: no cover
return self._parent.schedule[rung][0]
def judge(self, report: TrialReport) -> TrialDecision:
return self._get_judge(report.trial).judge(report)
def _get_judge(self, trial: Trial) -> _PerTrial:
key = trial.trial_id
with self._lock:
if key not in self._data:
self._data[key] = _PerTrial(self)
return self._data[key]
[docs]class ASHAJudge(TrialJudge):
def __init__(
self,
schedule: List[Tuple[float, int]],
always_checkpoint: bool = False,
study_early_stop: Optional[Callable[[List[Any], List[RungHeap]], bool]] = None,
trial_early_stop: Optional[
Callable[[TrialReport, List[TrialReport], List[RungHeap]], bool]
] = None,
monitor: Optional[Monitor] = None,
):
super().__init__(monitor=monitor)
self._lock = SerializableRLock()
self._data: Dict[str, _PerPartition] = {}
self._schedule = schedule
self._always_checkpoint = always_checkpoint
self._study_early_stop = study_early_stop or _default_study_early_stop
self._trial_early_stop = trial_early_stop or _default_trial_early_stop
@property
def schedule(self) -> List[Tuple[float, int]]:
return self._schedule
@property
def always_checkpoint(self) -> bool:
return self._always_checkpoint
[docs] def can_accept(self, trial: Trial) -> bool:
return self._get_judge(trial).can_accept(trial)
[docs] def get_budget(self, trial: Trial, rung: int) -> float:
budget = self._get_judge(trial).get_budget(trial, rung)
self.monitor.on_get_budget(trial, rung, budget)
return budget
[docs] def judge(self, report: TrialReport) -> TrialDecision:
self.monitor.on_report(report)
decision = self._get_judge(report.trial).judge(report)
self.monitor.on_judge(decision)
return decision
def _get_judge(self, trial: Trial) -> _PerPartition:
key = to_uuid(trial.keys)
with self._lock:
if key not in self._data:
self._data[key] = _PerPartition(self, trial.keys)
return self._data[key]
def _default_study_early_stop(keys: List[Any], rungs: List["RungHeap"]) -> bool:
return all(r.full for r in rungs)
def _default_trial_early_stop(
report: TrialReport, reports: List[TrialReport], rungs: List["RungHeap"]
) -> bool:
return False