frarch.train.base_trainer module#

Class definition of a base trainer.

Description

Base trainer class.

Authors

victor badenas (victor.badenas@gmail.com)

Version

0.1.0

class frarch.train.base_trainer.BaseTrainer(modules: Mapping[str, torch.nn.modules.module.Module], opt_class: Type[torch.optim.optimizer.Optimizer], hparams: Mapping[str, Any], checkpointer: Optional[frarch.modules.checkpointer.Checkpointer] = None, freeze_layers: Optional[List[str]] = None)[source]#

Bases: object

Abstract class for trainer managers.

Parameters
  • modules (Mapping[str, torch.nn.Module]) – trainable modules in the training.

  • opt_class (Type[torch.optim.Optimizer]) – optimizer class for training.

  • hparams (Mapping[str, Any]) – hparams dict-like structure from hparams file.

  • checkpointer (Optional[Checkpointer], optional) – Checkpointer class for saving the model and the hyperparameters needed. If None, no checkpoints are saved. Defaults to None.

  • freeze_layers (Optional[List[str]], optional) – names of layers to freeze (exclude from optimization) during training. If None, all layers are trainable. Defaults to None.

Raises
  • ValueError – raised if ckpt_interval_minutes is neither None nor > 0.

  • SystemError – raised if the Python version is < 3.7, which is not supported.
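
The two constructor checks described in the Raises section can be sketched as follows. This is a minimal illustration of the documented behaviour, not frarch's actual code; the function name is hypothetical:

```python
import sys
from typing import Optional


def validate_trainer_config(ckpt_interval_minutes: Optional[int]) -> None:
    """Sketch of the constructor validation implied by the Raises section."""
    # ckpt_interval_minutes must be either None (disabled) or a positive value.
    if ckpt_interval_minutes is not None and ckpt_interval_minutes <= 0:
        raise ValueError("ckpt_interval_minutes must be > 0 or None")
    # BaseTrainer requires Python >= 3.7.
    if sys.version_info < (3, 7):
        raise SystemError("Python version must be >= 3.7")
```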

ckpt_interval_minutes: Optional[int] = None#
debug: bool = False#
debug_batches: int = 2#
device: str = 'cpu'#
fit(train_set: Union[torch.utils.data.dataset.Dataset, torch.utils.data.dataloader.DataLoader], valid_set: Optional[Union[torch.utils.data.dataset.Dataset, torch.utils.data.dataloader.DataLoader]] = None, train_loader_kwargs: Optional[dict] = None, valid_loader_kwargs: Optional[dict] = None) None[source]#

Fit the modules to the dataset. Main function of the Trainer class.

Parameters
  • train_set (Union[Dataset, DataLoader]) – dataset for training.

  • valid_set (Optional[Union[Dataset, DataLoader]], optional) – dataset for validation. If not provided, validation will not be performed. Defaults to None.

  • train_loader_kwargs (dict, optional) – optional kwargs for train dataloader. Defaults to None.

  • valid_loader_kwargs (dict, optional) – optional kwargs for valid dataloader. Defaults to None.
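
Because fit accepts either a Dataset or a pre-built DataLoader for each split, it must normalize its inputs before iterating. The sketch below is an assumption about how that step might look, not frarch's implementation; the duck-typed batch_size check stands in for an isinstance check against torch.utils.data.DataLoader so the example stays dependency-free:

```python
def as_loader(data, loader_factory, loader_kwargs=None):
    """Return data unchanged if it is already a loader, else build one (sketch)."""
    loader_kwargs = loader_kwargs or {}
    # DataLoader-like objects expose a batch_size attribute; plain Datasets do
    # not, so they get wrapped using the user-supplied loader kwargs.
    if hasattr(data, "batch_size"):
        return data
    return loader_factory(data, **loader_kwargs)
```

Under this sketch, fit would call something like `as_loader(train_set, DataLoader, train_loader_kwargs)`, and valid_set would only be wrapped when it is provided.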

modules: torch.nn.modules.container.ModuleDict#
nonfinite_patience: int = 3#
noprogressbar: bool = False#
train_interval: int = 10#