# Source code for tune.concepts.checkpoint
import json
from typing import Any, List
from uuid import uuid4
from fs.base import FS as FSBase
from triad import assert_or_throw
_CHECKPOINT_STATE_FILE = "STATE"
class Checkpoint:
    """An abstraction for tuning checkpoint

    The checkpoint history is an ordered list of folder names stored as JSON
    in a ``STATE`` file at the root of ``fs``; each folder holds the data of
    one committed checkpoint.

    :param fs: the file system

    .. attention::

        Normally you don't need to create a checkpoint by yourself,
        please read :ref:`Checkpoint Tutorial </notebooks/checkpoint.ipynb>`
        if you want to understand how it works.
    """

    def __init__(self, fs: FSBase):
        self._fs = fs
        self._iterations: List[str] = []
        # A missing STATE file just means nothing has been committed yet.
        # Any other problem (unreadable or corrupted STATE) is surfaced rather
        # than swallowed: silently starting from an empty history would make
        # the next committed checkpoint rewrite STATE with only its own entry,
        # permanently dropping every earlier checkpoint from the history.
        if fs.exists(_CHECKPOINT_STATE_FILE):
            iterations = json.loads(fs.readtext(_CHECKPOINT_STATE_FILE))
            assert_or_throw(
                isinstance(iterations, list),
                "checkpoint state file is corrupted: expected a JSON list",
            )
            self._iterations = iterations

    def __len__(self) -> int:
        """Number of committed checkpoints in the history"""
        return len(self._iterations)

    @property
    def latest(self) -> FSBase:
        """Folder of the most recently committed checkpoint

        :raises AssertionError: if there was no checkpoint
        """
        assert_or_throw(len(self) > 0, "checkpoint history is empty")
        return self._fs.opendir(self._iterations[-1])

    def create(self) -> "NewCheckpoint":
        """Create a new checkpoint

        Use the returned object as a context manager. The new folder is added
        to the history only if the ``with`` body finishes without an exception.

        :return: a :class:`NewCheckpoint` bound to this checkpoint
        """
        return NewCheckpoint(self)
class NewCheckpoint:
    """A helper class for adding new checkpoints

    Used as a context manager: entering creates a uniquely named folder under
    the parent's file system. Exiting without an exception records that
    folder in the parent's history. Exiting with an exception, or failing to
    persist the history, removes the folder on a best-effort basis.

    :param checkpoint: the parent checkpoint

    .. attention::

        Do not construct this class directly, please read
        :ref:`Checkpoint Tutorial </notebooks/checkpoint.ipynb>`
        for details
    """

    def __init__(self, checkpoint: Checkpoint):
        self._parent = checkpoint
        # Random (uuid4) folder name for this checkpoint.
        self._name = str(uuid4())

    def __enter__(self) -> FSBase:
        # Create the checkpoint folder and hand it to the caller to write into.
        return self._parent._fs.makedir(self._name)

    def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None:
        if exc_type is not None:
            # The ``with`` body failed: discard the partially written folder.
            # Returning None lets the original exception propagate.
            self._discard()
            return
        new_iterations = self._parent._iterations + [self._name]
        try:
            self._parent._fs.writetext(
                _CHECKPOINT_STATE_FILE, json.dumps(new_iterations)
            )
        except Exception:
            # The history could not be persisted, so this folder would never
            # be referenced; clean it up like any failed checkpoint, then
            # re-raise the original error.
            self._discard()
            raise
        # Update the in-memory history only after STATE was written, so the
        # in-memory and persisted histories stay in agreement.
        self._parent._iterations = new_iterations

    def _discard(self) -> None:
        """Best-effort removal of this checkpoint's folder (errors ignored)"""
        try:
            self._parent._fs.removetree(self._name)
        except Exception:  # pragma: no cover
            pass