Source code for tune_tensorflow.utils
from typing import Any, Type, Dict
from triad.utils.convert import get_full_type_path, to_type
from tune.concepts.space.parameters import TuningParametersTemplate
from tune_tensorflow.spec import KerasTrainingSpec
from tune import Space
from tune.constants import SPACE_MODEL_NAME
# Module-level registry mapping a spec expression string to its
# KerasTrainingSpec subclass. `keras_space` writes entries into it and
# `to_keras_spec` consults it before falling back to `to_type`.
_TYPE_DICT: Dict[str, Type[KerasTrainingSpec]] = {}
def to_keras_spec(obj: Any) -> Type[KerasTrainingSpec]:
    """Resolve ``obj`` to a :class:`KerasTrainingSpec` type.

    :param obj: a string key or any other object that ``to_type`` accepts
    :return: the type stored in the module registry when ``obj`` is a
        string that has been registered there; otherwise whatever
        ``to_type(obj, KerasTrainingSpec)`` returns
    """
    # Only strings can be registry keys; the isinstance check also keeps
    # unhashable objects away from the membership test.
    is_registered = isinstance(obj, str) and obj in _TYPE_DICT
    if not is_registered:
        return to_type(obj, KerasTrainingSpec)
    return _TYPE_DICT[obj]
def to_keras_spec_expr(spec: Any) -> str:
    """Return the string expression identifying a Keras training spec.

    :param spec: a string (first resolved through ``to_keras_spec``) or an
        object handed to ``get_full_type_path`` unchanged
    :return: the result of ``get_full_type_path`` for the resolved spec
    """
    # Strings are resolved to a type first; anything else is used directly.
    resolved = to_keras_spec(spec) if isinstance(spec, str) else spec
    return get_full_type_path(resolved)
def keras_space(model: Any, **params: Any) -> Space:
    """Build a :class:`Space` whose model entry refers to ``model``.

    The model is converted to its string expression, and the resolved spec
    type is recorded in the module registry under that expression.

    :param model: a spec type or string accepted by ``to_keras_spec_expr``
        and ``to_keras_spec``
    :param params: additional keyword arguments forwarded to ``Space``; a
        keyword equal to ``SPACE_MODEL_NAME`` replaces the model entry
    :return: ``Space`` built from the model entry plus ``params``
    """
    model_expr = to_keras_spec_expr(model)
    # Register the resolved type so later string lookups of the expression
    # can be answered from the registry.
    _TYPE_DICT[model_expr] = to_keras_spec(model)
    # The model entry goes first and ``params`` is unpacked after it, so the
    # precedence matches an update of a dict seeded with the model entry.
    space_kwargs: Dict[str, Any] = {SPACE_MODEL_NAME: model_expr, **params}
    return Space(**space_kwargs)
def extract_keras_spec(
    params: TuningParametersTemplate, type_dict: Dict[str, Any]
) -> Type[KerasTrainingSpec]:
    """Look up the Keras training spec type referenced by ``params``.

    :param params: template whose ``simple_value`` mapping holds the model
        reference under ``SPACE_MODEL_NAME``
    :param type_dict: caller-supplied mapping from string keys to spec
        types, consulted before the module-level resolution
    :return: ``type_dict[...]`` when the model reference is a string key
        present in ``type_dict``; otherwise the result of ``to_keras_spec``
    """
    model_ref = params.simple_value[SPACE_MODEL_NAME]
    # Restrict the membership test to strings so a non-hashable reference
    # never reaches the dict lookup.
    in_caller_dict = isinstance(model_ref, str) and model_ref in type_dict
    return type_dict[model_ref] if in_caller_dict else to_keras_spec(model_ref)