frarch.train.base_trainer module
Class definition of a base trainer.
- Description
Base trainer class.
- Authors
victor badenas (victor.badenas@gmail.com)
- Version
0.1.0
- class frarch.train.base_trainer.BaseTrainer(modules: Mapping[str, torch.nn.modules.module.Module], opt_class: Type[torch.optim.optimizer.Optimizer], hparams: Mapping[str, Any], checkpointer: Optional[frarch.modules.checkpointer.Checkpointer] = None, freeze_layers: Optional[List[str]] = None)
Bases: object
Abstract class for trainer managers.
- Parameters
modules (Mapping[str, torch.nn.Module]) – trainable modules used during training.
opt_class (Type[torch.optim.Optimizer]) – optimizer class for training.
hparams (Mapping[str, Any]) – dict-like structure of hyperparameters loaded from the hparams file.
checkpointer (Optional[Checkpointer], optional) – Checkpointer instance used to save the model and the required hyperparameters. If None, no checkpoints are saved. Defaults to None.
- Raises
ValueError – if ckpt_interval_minutes is neither None nor > 0.
SystemError – if the Python version is not supported. Python >= 3.7 is required.
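The two documented error conditions can be read as simple pre-flight checks. The sketch below is a hypothetical illustration of those checks, not frarch's implementation: the helper names `is_supported_python`, `is_valid_ckpt_interval` and `check_trainer_config` are invented here, and it assumes `ckpt_interval_minutes` is read from `hparams`, with a missing key treated like None.

```python
import sys
from typing import Any, Mapping, Optional, Tuple

# Minimum version stated in the Raises section above.
MIN_PYTHON: Tuple[int, int] = (3, 7)


def is_supported_python(version: Tuple[int, ...]) -> bool:
    """Return True if `version` (e.g. sys.version_info) is at least 3.7."""
    return tuple(version[:2]) >= MIN_PYTHON


def is_valid_ckpt_interval(minutes: Optional[float]) -> bool:
    """Return True for None (no time-based checkpoints) or a strictly positive value."""
    return minutes is None or minutes > 0


def check_trainer_config(hparams: Mapping[str, Any]) -> None:
    """Hypothetical check mirroring BaseTrainer's documented Raises.

    Assumption: `ckpt_interval_minutes` lives in `hparams`; a missing key
    behaves like None.
    """
    if not is_supported_python(sys.version_info):
        raise SystemError("Python version not supported. Python version must be >= 3.7")
    minutes = hparams.get("ckpt_interval_minutes")
    if not is_valid_ckpt_interval(minutes):
        raise ValueError("ckpt_interval_minutes must be > 0 or None")
```

Under these assumptions, `check_trainer_config({"ckpt_interval_minutes": 0})` raises ValueError, while `{"ckpt_interval_minutes": 15}` or an empty mapping passes.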
- fit(train_set: Union[torch.utils.data.dataset.Dataset, torch.utils.data.dataloader.DataLoader], valid_set: Optional[Union[torch.utils.data.dataset.Dataset, torch.utils.data.dataloader.DataLoader]] = None, train_loader_kwargs: Optional[dict] = None, valid_loader_kwargs: Optional[dict] = None) -> None
Fit the modules to the dataset. This is the main entry point of the Trainer class.
- Parameters
train_set (Union[Dataset, DataLoader]) – dataset for training.
valid_set (Optional[Union[Dataset, DataLoader]], optional) – dataset for validation. If not provided, validation will not be performed. Defaults to None.
train_loader_kwargs (dict, optional) – optional keyword arguments for the training DataLoader. Defaults to None.
valid_loader_kwargs (dict, optional) – optional keyword arguments for the validation DataLoader. Defaults to None.
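Because `fit` accepts either a `Dataset` or a ready-made `DataLoader`, the `*_loader_kwargs` arguments presumably only matter when a bare `Dataset` has to be wrapped. That reading is an assumption inferred from the signature. The sketch below shows one way such normalization could be written. `to_loader` and `to_optional_loader` are hypothetical helpers, and the loader class is passed in as a parameter so the sketch runs without PyTorch. In real use it would be `torch.utils.data.DataLoader`.

```python
from typing import Any, Optional


def to_loader(data: Any, loader_kwargs: Optional[dict], loader_cls: type) -> Any:
    """Hypothetical helper: normalize a Dataset-or-DataLoader argument.

    If `data` is already a `loader_cls` instance it is returned unchanged and
    `loader_kwargs` is ignored; otherwise it is wrapped as
    `loader_cls(data, **loader_kwargs)`.
    """
    if isinstance(data, loader_cls):
        return data
    return loader_cls(data, **(loader_kwargs or {}))


def to_optional_loader(
    data: Optional[Any], loader_kwargs: Optional[dict], loader_cls: type
) -> Optional[Any]:
    """Return None when no validation set is given, so validation is skipped."""
    if data is None:
        return None
    return to_loader(data, loader_kwargs, loader_cls)
```

With PyTorch installed, a call such as `to_loader(train_set, {"batch_size": 32, "shuffle": True}, DataLoader)` would wrap a `Dataset` and pass an existing `DataLoader` through untouched. Passing `valid_set=None` leads to no validation loader at all, which matches the documented behaviour of skipping validation.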
- modules: torch.nn.modules.container.ModuleDict