frarch.modules.checkpointer module
- class frarch.modules.checkpointer.Checkpointer(save_path: Union[str, pathlib.Path], modules: torch.nn.modules.container.ModuleDict, save_best_only: bool = False, reference_metric: Optional[str] = None, mode: str = 'min')
Bases: object
Class for managing checkpoints.
- Parameters
save_path (Union[str, Path]) – folder in which to store the checkpoint and training data.
modules (Mapping[str, torch.nn.Module]) – dict-like structure mapping names to torch.nn.Module instances.
save_best_only (bool, optional) – If True, save only the best model according to reference_metric, which must then be specified. If False, save a checkpoint at the end of every epoch. Defaults to False.
reference_metric (str, optional) – Metric used to determine the best model when save_best_only is True. Must be one of the keys of the modules dict-like structure. Defaults to None.
mode (str, optional) – “min” if lower values of the reference metric are better, “max” if higher values are better (e.g. “min” for error, “max” for accuracy). Defaults to “min”.
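The interaction between save_best_only and mode reduces to a single comparison against the best value seen so far. The sketch below illustrates that rule; the helper name is_improvement is hypothetical and is not part of frarch's API, and it assumes the first recorded value always counts as an improvement.

```python
from typing import Optional


def is_improvement(new_value: float, best_value: Optional[float], mode: str = "min") -> bool:
    """Return True if new_value beats best_value under the given mode.

    Hypothetical helper illustrating the documented `mode` semantics;
    not frarch's implementation.
    """
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    # Assumption: with no previous best, any value is an improvement.
    if best_value is None:
        return True
    if mode == "min":
        return new_value < best_value  # lower is better (e.g. error)
    return new_value > best_value  # higher is better (e.g. accuracy)
```

With save_best_only=True, a checkpoint would only be written when a check like this returns True for the reference metric.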
- Raises
ValueError – modules is not a dict or torch.nn.ModuleDict instance.
ValueError – keys in modules are not strings or values are not torch.nn.Module instances.
ValueError – save_path is not a string or Path object.
ValueError – modules contains the key “metadata”, which is reserved for the metadata.json object.
ValueError – mode is not “min” or “max”.
ValueError – save_best_only is True and reference_metric is not defined.
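The Raises list above can be restated as a sequence of argument checks. The following is a torch-free sketch under stated assumptions: the function name validate_checkpointer_args is hypothetical, the check order is a guess, only plain dicts are accepted (the documented class also accepts torch.nn.ModuleDict), and the torch.nn.Module value check is omitted.

```python
from pathlib import Path
from typing import Mapping, Optional, Union

RESERVED_KEY = "metadata"  # key reserved for the metadata.json object


def validate_checkpointer_args(
    save_path: Union[str, Path],
    modules: Mapping[str, object],
    save_best_only: bool = False,
    reference_metric: Optional[str] = None,
    mode: str = "min",
) -> None:
    """Hypothetical restatement of the documented ValueError conditions."""
    # The documented check also accepts torch.nn.ModuleDict; this sketch
    # stays torch-free and accepts plain dicts only.
    if not isinstance(modules, dict):
        raise ValueError("modules must be a dict or torch.nn.ModuleDict")
    # The documented check also requires torch.nn.Module values; omitted here.
    if not all(isinstance(key, str) for key in modules):
        raise ValueError("module keys must be strings")
    if not isinstance(save_path, (str, Path)):
        raise ValueError("path must be a string or Path object")
    if RESERVED_KEY in modules:
        raise ValueError(f"{RESERVED_KEY!r} is reserved for metadata.json")
    if mode not in ("min", "max"):
        raise ValueError("metric mode must be 'min' or 'max'")
    if save_best_only and reference_metric is None:
        raise ValueError("save_best_only=True requires a reference_metric")
```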
- exists_checkpoint() → bool
Check if save_path contains a checkpoint folder.
- Returns
True if it contains a checkpoint folder, False if not.
- Return type
bool
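The on-disk naming of checkpoint folders is not documented here, so the sketch below assumes that every subdirectory of save_path is a checkpoint folder. It is a standalone illustration of the check with pathlib, not frarch's implementation.

```python
from pathlib import Path
from typing import Union


def exists_checkpoint(save_path: Union[str, Path]) -> bool:
    """Return True if save_path contains at least one checkpoint folder.

    Assumption: every subdirectory of save_path is a checkpoint folder.
    """
    path = Path(save_path)
    if not path.is_dir():
        return False  # nothing saved yet (or the folder does not exist)
    return any(child.is_dir() for child in path.iterdir())
```

Plain files (for example logs) in save_path are ignored, so only folders trigger a True result.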
- load(mode='last', **load_kwargs) → bool
Load checkpoint from folder.
- Parameters
mode (str, optional) – “last” to load the most recently stored checkpoint, “best” to load the checkpoint with the best reference metric. Defaults to “last”.
- Raises
ValueError – mode is not “best” or “last”.
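The difference between the “last” and “best” load modes is a choice of selection key. This sketch assumes each stored checkpoint can be summarized by its epoch, step, and reference-metric value; CheckpointRecord and select_checkpoint are hypothetical names, and returning None when nothing is stored is an assumption rather than documented behavior.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class CheckpointRecord:
    """Hypothetical summary of one stored checkpoint."""

    epoch: int
    step: int
    metric: float


def select_checkpoint(
    records: List[CheckpointRecord], mode: str = "last", metric_mode: str = "min"
) -> Optional[CheckpointRecord]:
    """Pick the checkpoint to restore; returns None if there is none."""
    if mode not in ("last", "best"):
        raise ValueError(f"mode must be 'last' or 'best', got {mode!r}")
    if metric_mode not in ("min", "max"):
        raise ValueError(f"metric_mode must be 'min' or 'max', got {metric_mode!r}")
    if not records:
        return None
    if mode == "last":
        # Most recent checkpoint: largest (epoch, step) pair.
        return max(records, key=lambda r: (r.epoch, r.step))
    # "best": lowest metric for "min", highest metric for "max".
    pick = min if metric_mode == "min" else max
    return pick(records, key=lambda r: r.metric)
```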
- save(epoch: int, current_step: int, intra_epoch: bool = False, extra_data: Optional[Dict] = None, **metrics: Any) → None
Save checkpoint.
- Parameters
epoch (int) – current epoch index.
current_step (int) – current batch index.
intra_epoch (bool, optional) – if True, the checkpoint is taken in the middle of an epoch; if False, it is an end-of-epoch checkpoint. Defaults to False.
extra_data (Dict, optional) – extra JSON-serializable metadata to add to the metadata.json file. Defaults to None.
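To show how the save() arguments could end up in metadata.json, here is a minimal sketch assuming a flat JSON layout with the metrics grouped under a "metrics" key. The function write_metadata and the exact layout are hypothetical; the real file format is not documented in this section.

```python
import json
from pathlib import Path
from typing import Any, Dict, Optional, Union


def write_metadata(
    checkpoint_dir: Union[str, Path],
    epoch: int,
    current_step: int,
    intra_epoch: bool = False,
    extra_data: Optional[Dict] = None,
    **metrics: Any,
) -> Path:
    """Write a metadata.json file for one checkpoint (hypothetical layout)."""
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    # extra_data is merged first so the core fields below always win
    # if a key collides; everything must be JSON-serializable.
    metadata: Dict[str, Any] = {
        **(extra_data or {}),
        "epoch": epoch,
        "current_step": current_step,
        "intra_epoch": intra_epoch,
        "metrics": metrics,
    }
    out_path = checkpoint_dir / "metadata.json"
    out_path.write_text(json.dumps(metadata, indent=2))
    return out_path
```

Merging extra_data underneath the core fields keeps epoch and step trustworthy even if a caller passes a colliding key.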