frarch.modules.checkpointer module

class frarch.modules.checkpointer.Checkpointer(save_path: Union[str, pathlib.Path], modules: torch.nn.modules.container.ModuleDict, save_best_only: bool = False, reference_metric: Optional[str] = None, mode: str = 'min')

Bases: object

Class for managing checkpoints.

Parameters
  • save_path (Union[str, Path]) – folder in which to store the checkpoints and training data.

  • modules (Mapping[str, torch.nn.Module]) – dict-like structure mapping names to modules.

  • save_best_only (bool, optional) – If True, save only the best model according to reference_metric, which must then be specified. If False, save a checkpoint at the end of every epoch. Defaults to False.

  • reference_metric (str, optional) – Metric used to determine the best model when save_best_only is True. Must be one of the keys of the modules dict-like structure. Defaults to None.

  • mode (str, optional) – “min” if lower is better (e.g. error), “max” if higher is better (e.g. accuracy). Defaults to “min”.
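The mode parameter decides which direction counts as an improvement of the reference metric. The following is a minimal sketch of that comparison, not frarch's code: is_improvement and initial_best are hypothetical helpers, and treating ties as non-improvements is an assumption of this sketch.

```python
import math


def is_improvement(new_value: float, best_value: float, mode: str = "min") -> bool:
    """Return True if new_value beats best_value under the given mode.

    Hypothetical helper: "min" treats lower values as better (e.g. error),
    "max" treats higher values as better (e.g. accuracy). Ties are not
    counted as improvements (an assumption of this sketch).
    """
    if mode == "min":
        return new_value < best_value
    if mode == "max":
        return new_value > best_value
    raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")


def initial_best(mode: str = "min") -> float:
    """Starting 'best' value so that the first metric seen always improves on it."""
    if mode == "min":
        return math.inf
    if mode == "max":
        return -math.inf
    raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")


# Usage: track the best validation error across epochs (lower is better).
best = initial_best("min")
for error in [0.40, 0.35, 0.37, 0.30]:
    if is_improvement(error, best, "min"):
        best = error
```

Starting from plus or minus infinity means the first epoch is always saved when only the best model is kept.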

Raises
  • ValueError – modules is not a dict or torch.nn.ModuleDict instance.

  • ValueError – the modules dict-like structure does not have string keys and torch.nn.Module values.

  • ValueError – save_path is not a string or Path object.

  • ValueError – modules contains the key metadata, which is reserved for the metadata.json object.

  • ValueError – mode is not “min” or “max”.

  • ValueError – save_best_only is True and no reference_metric is defined.
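The ValueError conditions listed above amount to a set of argument checks. A minimal sketch of those checks, assuming the reserved key is the literal string "metadata"; validate_checkpointer_args is a hypothetical function, and to run without PyTorch it accepts only a plain dict and skips the torch.nn.ModuleDict and torch.nn.Module type checks.

```python
from pathlib import Path
from typing import Any, Mapping, Optional, Union

# Assumption: the reserved key is the literal string "metadata".
RESERVED_KEY = "metadata"


def validate_checkpointer_args(
    save_path: Union[str, Path],
    modules: Mapping[str, Any],
    save_best_only: bool = False,
    reference_metric: Optional[str] = None,
    mode: str = "min",
) -> None:
    """Hypothetical re-creation of the documented ValueError conditions.

    To run without PyTorch, this sketch accepts only a plain dict and skips
    the torch.nn.ModuleDict / torch.nn.Module type checks.
    """
    if not isinstance(modules, dict):
        raise ValueError("modules must be a dict or torch.nn.ModuleDict")
    if not all(isinstance(key, str) for key in modules):
        raise ValueError("module keys must be strings")
    if not isinstance(save_path, (str, Path)):
        raise ValueError("path must be a string or Path object")
    if RESERVED_KEY in modules:
        raise ValueError(f"{RESERVED_KEY!r} is reserved for metadata.json")
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    if save_best_only and reference_metric is None:
        raise ValueError("save_best_only=True requires a reference_metric")


# Usage: a valid configuration passes silently.
validate_checkpointer_args("runs/exp1", {"model": object()})
```

Validating everything up front means a misconfigured run fails at construction time rather than at the first save.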

property current_epoch: int

exists_checkpoint() → bool

Check if save_path contains a checkpoint folder.

Returns

True if it contains a checkpoint folder, False if not.

Return type

bool
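A minimal sketch of this check using only pathlib. It assumes that any subdirectory of save_path counts as a checkpoint folder; the real Checkpointer may look for a specific folder naming scheme, so exists_checkpoint here is a hypothetical stand-alone function rather than the method itself.

```python
from pathlib import Path
from typing import Union


def exists_checkpoint(save_path: Union[str, Path]) -> bool:
    """Return True if save_path contains at least one subfolder.

    Assumption: any subdirectory of save_path is treated as a checkpoint
    folder; a missing save_path simply means there is no checkpoint.
    """
    path = Path(save_path)
    if not path.is_dir():
        return False
    return any(child.is_dir() for child in path.iterdir())
```

Returning False for a missing folder lets a training script call the check unconditionally before deciding whether to resume.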

load(mode='last', **load_kwargs) → bool

Load checkpoint from folder.

Parameters

mode (str, optional) – “last” to load the most recently stored checkpoint, or “best” to load the model with the best metric. Defaults to “last”.

Raises

ValueError – mode is not best or last
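The two load modes differ only in how the checkpoint to restore is chosen. A minimal sketch of that choice, assuming each stored checkpoint is described by a metadata dict with "epoch", "step" and "metrics" keys (a hypothetical layout, not frarch's format). select_checkpoint is a hypothetical helper, and its metric_mode argument stands in for the constructor's mode (“min”/“max”) to avoid clashing with the load mode.

```python
from typing import Any, Dict, List, Optional


def select_checkpoint(
    checkpoints: List[Dict[str, Any]],
    mode: str,
    reference_metric: str,
    metric_mode: str = "min",
) -> Optional[Dict[str, Any]]:
    """Choose which checkpoint metadata entry to restore.

    Hypothetical layout: each entry has "epoch", "step" and "metrics" keys.
    Returns None when there is nothing to load.
    """
    if mode not in ("last", "best"):
        raise ValueError(f"mode must be 'last' or 'best', got {mode!r}")
    if not checkpoints:
        return None
    if mode == "last":
        # Most recent checkpoint: highest (epoch, step) pair.
        return max(checkpoints, key=lambda c: (c["epoch"], c["step"]))

    def metric(c: Dict[str, Any]) -> float:
        return c["metrics"][reference_metric]

    # "min": lowest metric wins; "max": highest metric wins.
    if metric_mode == "min":
        return min(checkpoints, key=metric)
    return max(checkpoints, key=metric)
```

Ordering by the (epoch, step) pair keeps intra-epoch checkpoints correctly ordered within the same epoch.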

property next_epoch: int

save(epoch: int, current_step: int, intra_epoch: bool = False, extra_data: Optional[Dict] = None, **metrics: Any) → None

Save checkpoint.

Parameters
  • epoch (int) – current epoch index.

  • current_step (int) – current batch index.

  • intra_epoch (bool, optional) – If True, the checkpoint is an intra-epoch checkpoint; if False, it is an end-of-epoch checkpoint. Defaults to False.

  • extra_data (Dict, optional) – extra JSON-serializable metadata to add to the metadata.json file. Defaults to None.
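The save parameters map naturally onto a metadata.json record. A minimal sketch of writing such a record with the standard json module; write_metadata is hypothetical, and the key names ("epoch", "step", "intra_epoch", "metrics", "extra") are assumptions, not frarch's actual file format. It also writes no model weights.

```python
import json
from pathlib import Path
from typing import Any, Dict, Optional, Union


def write_metadata(
    folder: Union[str, Path],
    epoch: int,
    current_step: int,
    intra_epoch: bool = False,
    extra_data: Optional[Dict[str, Any]] = None,
    **metrics: Any,
) -> Path:
    """Write a metadata.json describing one checkpoint and return its path.

    Hypothetical layout: extra_data is nested under "extra" so it cannot
    overwrite the core fields; it must be JSON-serializable.
    """
    metadata: Dict[str, Any] = {
        "epoch": epoch,
        "step": current_step,
        "intra_epoch": intra_epoch,
        "metrics": metrics,
        "extra": extra_data or {},
    }
    folder = Path(folder)
    folder.mkdir(parents=True, exist_ok=True)
    out = folder / "metadata.json"
    out.write_text(json.dumps(metadata, indent=2))
    return out
```

Keeping metrics in their own sub-dictionary makes it easy to look up the reference metric later when choosing the best checkpoint.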

save_initial_weights() → None

Save the weights the model was initialized with.

property step: int