import os
import tempfile
from typing import Any, Dict, List, Tuple

from fs.base import FS as FSBase
from tensorflow import keras
from triad import FileSystem

from tune.concepts.space import to_template, TuningParametersTemplate
class KerasTrainingSpec:
    """Specification of how to build, compile and fit a Keras model.

    Subclasses implement :meth:`get_model`, :meth:`get_compile_params`,
    :meth:`get_fit_params`, :meth:`get_fit_metric` and
    :meth:`generate_sort_metric`. This base class wires them together in
    :meth:`compile_model`, :meth:`fit` and :meth:`compute_sort_metric`, and
    provides weight checkpointing to and from a filesystem.

    :param params: tuning parameters; converted with ``to_template``
    :param dfs: named inputs, stored as given and exposed via :attr:`dfs`
    """

    # File name of the weights inside the checkpoint filesystem. Kept as
    # "model.h5" so that previously written checkpoints remain loadable.
    _CHECKPOINT_FILE = "model.h5"
    # Name of the local temporary weights file. Keras 3 ``save_weights`` only
    # accepts paths ending in ".weights.h5"; this name also still ends in
    # ".h5", which older tf.keras uses to select the HDF5 format.
    _LOCAL_WEIGHTS_FILE = "model.weights.h5"

    def __init__(self, params: Any, dfs: Dict[str, Any]):
        self._params = to_template(params)
        self._dfs = dfs

    @property
    def params(self) -> TuningParametersTemplate:
        """The constructor's ``params`` converted by ``to_template``."""
        return self._params

    @property
    def dfs(self) -> Dict[str, Any]:
        """The ``dfs`` mapping passed to the constructor."""
        return self._dfs

    def finalize(self) -> None:
        """Hook called by :meth:`fit` after ``model.fit`` returns; no-op here.

        NOTE(review): it is not called when ``model.fit`` raises. Confirm
        whether any subclass relies on it for cleanup.
        """

    def generate_sort_metric(self, metric: float) -> float:
        """Convert the value from :meth:`get_fit_metric` into the sort metric.

        :param metric: the metric extracted from the fit history
        :return: the metric used to rank results
        """
        raise NotImplementedError  # pragma: no cover

    def get_fit_metric(self, history: keras.callbacks.History) -> float:
        """Extract a single metric value from the ``History`` of a fit.

        :param history: the object returned by ``model.fit``
        :return: the metric value
        """
        raise NotImplementedError  # pragma: no cover

    def get_fit_params(self) -> Tuple[List[Any], Dict[str, Any]]:
        """Return the positional and keyword arguments for ``model.fit``."""
        raise NotImplementedError  # pragma: no cover

    def get_compile_params(self) -> Dict[str, Any]:
        """Return the keyword arguments for ``model.compile``."""
        raise NotImplementedError  # pragma: no cover

    def get_model(self) -> keras.models.Model:
        """Build the model that :meth:`compile_model` will compile."""
        raise NotImplementedError  # pragma: no cover

    def save_checkpoint(self, fs: FSBase, model: keras.models.Model) -> None:
        """Save ``model``'s weights as ``model.h5`` on ``fs``.

        The weights are first written to a local temporary file (Keras
        ``save_weights`` takes a local path), then uploaded to ``fs``.

        :param fs: destination filesystem
        :param model: the model whose weights are saved
        """
        # Use a temporary *directory* instead of NamedTemporaryFile: Keras must
        # create/reopen the file by name, and an open NamedTemporaryFile
        # cannot be reopened by name on Windows.
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_path = os.path.join(tmp_dir, self._LOCAL_WEIGHTS_FILE)
            model.save_weights(local_path)
            with open(local_path, "rb") as fin:
                fs.writefile(self._CHECKPOINT_FILE, fin)

    def load_checkpoint(self, fs: FSBase, model: keras.models.Model) -> None:
        """Load weights from ``model.h5`` on ``fs`` into ``model``.

        The file is copied to a local temporary file first, because Keras
        ``load_weights`` takes a local path.

        :param fs: filesystem holding the checkpoint
        :param model: the model to load the weights into
        """
        # Same temporary-directory approach as save_checkpoint, for the same
        # cross-platform reason.
        with tempfile.TemporaryDirectory() as tmp_dir:
            local_path = os.path.join(tmp_dir, self._LOCAL_WEIGHTS_FILE)
            with fs.open(self._CHECKPOINT_FILE, "rb") as fin:
                FileSystem().writefile(local_path, fin)
            model.load_weights(local_path)

    def compile_model(self, **add_kwargs: Any) -> keras.models.Model:
        """Build a model with :meth:`get_model` and compile it.

        :param add_kwargs: extra ``model.compile`` arguments; they override
          entries returned by :meth:`get_compile_params`
        :return: the compiled model
        """
        compile_kwargs = dict(self.get_compile_params())
        compile_kwargs.update(add_kwargs)
        model = self.get_model()
        model.compile(**compile_kwargs)
        return model

    def fit(self, **add_kwargs: Any) -> keras.callbacks.History:
        """Compile a fresh model and fit it, then call :meth:`finalize`.

        :param add_kwargs: extra ``model.fit`` keyword arguments; they
          override the keyword arguments from :meth:`get_fit_params`
        :return: the ``History`` returned by ``model.fit``
        """
        args, kwargs = self.get_fit_params()
        fit_kwargs = dict(kwargs)
        fit_kwargs.update(add_kwargs)
        model = self.compile_model()
        history = model.fit(*args, **fit_kwargs)
        self.finalize()
        return history

    def compute_sort_metric(self, **add_kwargs: Any) -> float:
        """Fit the model and turn its fit metric into the sort metric.

        :param add_kwargs: extra ``model.fit`` keyword arguments, see :meth:`fit`
        :return: the value of :meth:`generate_sort_metric`
        """
        metric = self.get_fit_metric(self.fit(**add_kwargs))
        return self.generate_sort_metric(metric)