Source code for tune_notebook.monitors

from datetime import datetime
from typing import Any, Dict, List

import pandas as pd
from triad import SerializableRLock
from triad.utils.convert import to_timedelta
from tune import Monitor, TrialReport, TrialReportLogger, parse_monitor


@parse_monitor.candidate(
    lambda obj: isinstance(obj, str) and obj == "hist",
    priority=0.0,
)
def _hist(obj: str) -> Monitor:  # pragma: no cover
    """Resolve the ``"hist"`` alias to a :class:`NotebookSimpleHist` monitor."""
    monitor = NotebookSimpleHist()
    return monitor


@parse_monitor.candidate(
    lambda obj: isinstance(obj, str) and obj == "rungs",
    priority=0.0,
)
def _rungs(obj: str) -> Monitor:  # pragma: no cover
    """Resolve the ``"rungs"`` alias to a :class:`NotebookSimpleRungs` monitor."""
    monitor = NotebookSimpleRungs()
    return monitor


@parse_monitor.candidate(
    lambda obj: isinstance(obj, str) and obj == "ts",
    priority=0.0,
)
def _ts(obj: str) -> Monitor:  # pragma: no cover
    """Resolve the ``"ts"`` alias to a :class:`NotebookSimpleTimeSeries` monitor."""
    monitor = NotebookSimpleTimeSeries()
    return monitor


@parse_monitor.candidate(
    lambda obj: isinstance(obj, str) and obj == "text",
    priority=0.0,
)
def _text(obj: str) -> Monitor:  # pragma: no cover
    """Resolve the ``"text"`` alias to a :class:`PrintBest` monitor."""
    monitor = PrintBest()
    return monitor


class PrintBest(Monitor):
    """Monitor that prints reports to stdout, grouped by trial keys.

    Each distinct ``str(report.trial.keys)`` gets its own ``_ReportBin``
    created with ``new_best_only=True``. A report is printed only when that
    bin's ``on_report`` returns a truthy value (presumably when the report is a
    new best for those keys; confirm against ``TrialReportLogger``).
    """

    def __init__(self):
        super().__init__()
        self._lock = SerializableRLock()
        # bins keyed by the string form of the trial keys
        self._bins: Dict[str, "_ReportBin"] = {}

    def on_report(self, report: TrialReport) -> None:
        """Route ``report`` to its bin and print it if the bin accepted it.

        :param report: the trial report to process
        """
        with self._lock:
            bin_key = str(report.trial.keys)
            report_bin = self._bins.get(bin_key)
            if report_bin is None:
                # first report for these keys: create the bin lazily
                report_bin = _ReportBin(new_best_only=True)
                self._bins[bin_key] = report_bin
            if report_bin.on_report(report):
                print(report.trial.keys, report.metric, report)
class NotebookSimpleChart(Monitor):
    """Base monitor that redraws a chart of reported metrics in a notebook.

    Reports are grouped into one ``_ReportBin`` per ``str(report.trial.keys)``.
    Redraws are throttled to at most one per ``interval``. Subclasses override
    :meth:`plot` to render the collected records.

    :param interval: minimum time between redraws, in any form accepted by
        ``triad.utils.convert.to_timedelta``; defaults to ``"1sec"``
    :param best_only: forwarded to each bin as ``new_best_only``
    :param always_update: if True, consider redrawing even when the bin's
        ``on_report`` returned a falsy value
    """

    def __init__(
        self,
        interval: Any = "1sec",
        best_only: bool = True,
        always_update: bool = False,
    ):
        super().__init__()
        self._lock = SerializableRLock()
        self._last: Any = None  # time of the last redraw; None before the first
        self._bins: Dict[str, "_ReportBin"] = {}
        self._interval = to_timedelta(interval)
        self._best_only = best_only
        self._always_update = always_update

    def on_report(self, report: TrialReport) -> None:
        """Record ``report`` and redraw if the throttling interval has elapsed.

        :param report: the trial report to process
        """
        now = datetime.now()
        with self._lock:
            key = str(report.trial.keys)
            if key not in self._bins:
                self._bins[key] = _ReportBin(new_best_only=self._best_only)
            rbin = self._bins[key]
        updated = rbin.on_report(report)
        if not updated and not self._always_update:
            return
        with self._lock:
            if self._last is None or now - self._last > self._interval:
                self._redraw()
                # measured after drawing so the interval excludes draw time
                self._last = datetime.now()

    def plot(self, df: pd.DataFrame) -> None:
        """Render the collected records; the base implementation draws nothing.

        :param df: records with columns ``partition``, ``rung``, ``time``,
            ``id``, ``metric`` and ``best_metric``
        """
        return  # pragma: no cover

    def finalize(self) -> None:
        """Redraw once more so that the final state is shown despite throttling."""
        # Hold the lock like on_report does so that the bins are not mutated
        # mid-draw (the lock is reentrant).
        with self._lock:
            self._redraw()

    def _redraw(self) -> None:
        """Clear the cell output, re-plot all bins and print each bin's best."""
        if not self._bins:
            # Nothing has been reported yet; pd.concat([]) would raise ValueError.
            return

        import matplotlib.pyplot as plt
        from IPython.display import clear_output

        df = pd.concat(
            [
                pd.DataFrame(
                    x.records,
                    columns=[
                        "partition",
                        "rung",
                        "time",
                        "id",
                        "metric",
                        "best_metric",
                    ],
                )
                for x in self._bins.values()
            ],
            # Each per-bin frame starts at index 0. A fresh RangeIndex avoids
            # duplicate labels in the combined frame handed to plot().
            ignore_index=True,
        )
        clear_output()
        self.plot(df)
        plt.show()
        for best in (x.best for x in self._bins.values()):
            if best is not None:
                print(best.trial.keys, best.metric, best)
class NotebookSimpleRungs(NotebookSimpleChart):
    """Chart every report's metric against its rung, one line per trial id.

    Every report is recorded (``best_only=False``) and may trigger a redraw
    (``always_update=True``).

    :param interval: minimum time between redraws; defaults to ``"1sec"``
    """

    def __init__(self, interval: Any = "1sec"):
        super().__init__(interval=interval, best_only=False, always_update=True)

    def plot(self, df: pd.DataFrame) -> None:
        """Draw a seaborn line plot of ``metric`` by ``rung``, hued by ``id``."""
        import seaborn as sns

        options = dict(x="rung", y="metric", hue="id", marker="o", legend=False)
        sns.lineplot(data=df, **options)
class NotebookSimpleTimeSeries(NotebookSimpleChart):
    """Chart the running best metric over wall-clock time, per partition.

    Uses the parent's defaults for ``best_only`` and ``always_update``.

    :param interval: minimum time between redraws; defaults to ``"1sec"``
    """

    def __init__(self, interval: Any = "1sec"):
        super().__init__(interval=interval)

    def plot(self, df: pd.DataFrame) -> None:
        """Draw a seaborn line plot of ``best_metric`` by ``time``, hued by ``partition``."""
        import seaborn as sns

        options = dict(x="time", y="best_metric", hue="partition", marker="o")
        sns.lineplot(data=df, **options)
class NotebookSimpleHist(NotebookSimpleChart):
    """Chart a histogram of all reported metrics, colored by partition.

    Every report is recorded (``best_only=False``) and may trigger a redraw
    (``always_update=True``).

    :param interval: minimum time between redraws; defaults to ``"1sec"``
    """

    def __init__(self, interval: Any = "1sec"):
        super().__init__(interval=interval, best_only=False, always_update=True)

    def plot(self, df: pd.DataFrame) -> None:
        """Draw a seaborn histogram of ``metric``, hued by ``partition``."""
        import seaborn as sns

        options = dict(x="metric", hue="partition")
        sns.histplot(data=df, **options)
class _ReportBin(TrialReportLogger):
    """``TrialReportLogger`` that keeps one row per logged report.

    :param new_best_only: forwarded unchanged to ``TrialReportLogger``
    """

    def __init__(self, new_best_only: bool = False):
        super().__init__(new_best_only=new_best_only)
        # rows: [str(trial.keys), rung, log time, trial_id, metric, best metric]
        self._values: List[List[Any]] = []

    def log(self, report: TrialReport) -> None:
        """Append a row describing ``report`` and the current best metric."""
        # assumes self.best has already been set by the base class before log()
        # is invoked; TODO confirm against TrialReportLogger
        row: List[Any] = [
            str(report.trial.keys),
            report.rung,
            datetime.now(),
            report.trial_id,
            report.metric,
            self.best.metric,  # type: ignore
        ]
        self._values.append(row)

    @property
    def records(self) -> List[List[Any]]:
        """Return a shallow copy of the rows recorded so far."""
        # self._lock is presumably provided by TrialReportLogger; verify
        with self._lock:
            snapshot = self._values[:]
        return snapshot