snorkel.classification.Trainer

class snorkel.classification.Trainer(name=None, **kwargs)

Bases: object

A class for training a MultitaskClassifier.

Parameters
  • name (Optional[str]) – An optional name for this trainer object

  • kwargs (Any) – Settings to be merged into the default Trainer config dict
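A minimal sketch of what "merged into the default Trainer config dict" can mean, assuming the defaults are a nested dict. DEFAULT_CONFIG, its keys, merge_config, and SketchTrainer are all hypothetical names for illustration, not Snorkel's own defaults or helpers. Nested dicts are merged key by key, so overriding one nested setting leaves its siblings intact.

```python
from copy import deepcopy
from typing import Any, Dict, Optional

# Hypothetical defaults for illustration only; not the real Trainer config.
DEFAULT_CONFIG: Dict[str, Any] = {
    "n_epochs": 1,
    "lr": 0.01,
    "optimizer_config": {"momentum": 0.0, "weight_decay": 0.0},
}


def merge_config(base: Dict[str, Any], updates: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of base with updates merged in, recursing into nested dicts."""
    merged = deepcopy(base)  # never mutate the shared defaults
    for key, value in updates.items():
        if key not in merged:
            # This sketch rejects unknown keys so typos fail loudly.
            raise ValueError(f"Unknown config key: {key!r}")
        if isinstance(merged[key], dict) and isinstance(value, dict):
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = value
    return merged


class SketchTrainer:
    """Hypothetical stand-in showing how name and **kwargs could be stored."""

    def __init__(self, name: Optional[str] = None, **kwargs: Any) -> None:
        self.name = name
        self.config = merge_config(DEFAULT_CONFIG, kwargs)


trainer = SketchTrainer(name="demo", n_epochs=5, optimizer_config={"momentum": 0.9})
# trainer.config == {"n_epochs": 5, "lr": 0.01,
#                    "optimizer_config": {"momentum": 0.9, "weight_decay": 0.0}}
```

Rejecting unknown keys is this sketch's own choice, meant to catch misspelled kwargs early; how the real Trainer treats unrecognized settings is not described here.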

name

The optional name given to this trainer object (the name parameter above)

config

The config dict for the Trainer: the default settings with any kwargs passed at construction merged in

checkpointer

Saves the best model seen during training
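One way to keep "the best model seen during training", sketched under the assumption that the model state is a plain dict and that a higher (or, with mode="min", lower) validation metric means better. BestModelCheckpointer and the metric name valid/accuracy are hypothetical. A real checkpointer may also write checkpoints to disk; this sketch only keeps an in-memory deep copy.

```python
import copy
from typing import Any, Dict, Optional


class BestModelCheckpointer:
    """Hypothetical sketch: keep a copy of the model state with the best metric."""

    def __init__(self, metric: str = "valid/accuracy", mode: str = "max") -> None:
        if mode not in ("max", "min"):
            raise ValueError("mode must be 'max' or 'min'")
        self.metric = metric
        self.mode = mode
        self.best_score: Optional[float] = None
        self.best_state: Optional[Dict[str, Any]] = None
        self.best_iteration: Optional[int] = None

    def _is_better(self, score: float) -> bool:
        if self.best_score is None:
            return True  # the first evaluation is always the best so far
        if self.mode == "max":
            return score > self.best_score
        return score < self.best_score

    def checkpoint(self, iteration: int, state: Dict[str, Any], metrics: Dict[str, float]) -> bool:
        """Store a deep copy of state if the tracked metric improved; return whether it did."""
        score = metrics[self.metric]
        if not self._is_better(score):
            return False
        self.best_score = score
        # Deep copy so later in-place weight updates cannot alter the saved best state.
        self.best_state = copy.deepcopy(state)
        self.best_iteration = iteration
        return True
```

The deep copy is the key design choice: training keeps mutating the live weights, so a plain reference would silently track the latest model instead of the best one.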

log_manager

Identifies when it is time to log or to evaluate on the validation (valid) split
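A log manager can be as simple as a counter. The sketch below uses a hypothetical class, EvaluationTrigger, and assumes logging and evaluation are due every N batches; the units and frequencies the Trainer actually uses come from its own config and are not specified in this section.

```python
class EvaluationTrigger:
    """Hypothetical log-manager sketch: signals once every `every_n_batches` batches."""

    def __init__(self, every_n_batches: int) -> None:
        if every_n_batches < 1:
            raise ValueError("every_n_batches must be at least 1")
        self.every_n_batches = every_n_batches
        self.batches_since_trigger = 0
        self.total_batches = 0

    def step(self) -> bool:
        """Record one processed batch; return True when logging/evaluation is due."""
        self.total_batches += 1
        self.batches_since_trigger += 1
        if self.batches_since_trigger >= self.every_n_batches:
            self.batches_since_trigger = 0  # start counting toward the next trigger
            return True
        return False


trigger = EvaluationTrigger(every_n_batches=3)
due = [trigger.step() for _ in range(7)]
# due == [False, False, True, False, False, True, False]
```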

log_writer

Writes training statistics to file or TensorBoard

optimizer

Updates model weights based on the loss
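For intuition about "updates model weights based on the loss", here is plain gradient descent on a one-parameter toy loss, (w - 3)^2. sgd_step and toy_loss_grad are hypothetical helpers; in a PyTorch training loop this role is played by an optimizer from torch.optim, and which optimizer this Trainer uses is not stated in this section.

```python
from typing import List


def sgd_step(weights: List[float], grads: List[float], lr: float) -> List[float]:
    """Hypothetical helper: move each weight one step against its gradient (w <- w - lr * g)."""
    if len(weights) != len(grads):
        raise ValueError("weights and grads must have the same length")
    return [w - lr * g for w, g in zip(weights, grads)]


def toy_loss_grad(w: float) -> float:
    """Gradient of the toy loss (w - 3)^2 with respect to w."""
    return 2.0 * (w - 3.0)


weights = [0.0]
for _ in range(100):
    weights = sgd_step(weights, [toy_loss_grad(weights[0])], lr=0.1)
# weights[0] is now very close to 3.0, the minimizer of the toy loss.
```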

lr_scheduler

Adjusts the learning rate over the course of training
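As one example of adjusting the learning rate over the course of training, a linear decay schedule is sketched below. linear_decay_lr is a hypothetical function, and linear decay is only one common choice; the schedules this Trainer supports are not listed in this section.

```python
def linear_decay_lr(base_lr: float, step: int, total_steps: int, min_lr: float = 0.0) -> float:
    """Hypothetical schedule: go linearly from base_lr at step 0 to min_lr at total_steps."""
    if total_steps <= 0:
        raise ValueError("total_steps must be positive")
    # Clamp so steps before 0 or past total_steps stay at the endpoints.
    frac = min(max(step, 0), total_steps) / total_steps
    return base_lr + (min_lr - base_lr) * frac


schedule = [linear_decay_lr(0.1, s, total_steps=4) for s in range(6)]
# schedule is approximately [0.1, 0.075, 0.05, 0.025, 0.0, 0.0]
```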

batch_scheduler

Returns batches from the DataLoaders in a particular order for training
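The order in which batches come out matters when several DataLoaders (for example, one per task) feed the same model. The sketch below contrasts two plausible policies: visiting loaders one after another, and interleaving all batches in a random order while keeping each loader's internal order. Plain lists stand in for DataLoaders, and both function names are hypothetical.

```python
import random
from typing import Dict, Iterator, List, Optional, Tuple, TypeVar

T = TypeVar("T")


def sequential_batches(loaders: Dict[str, List[T]]) -> Iterator[Tuple[str, T]]:
    """Yield all batches of the first loader, then all of the second, and so on."""
    for name, batches in loaders.items():
        for batch in batches:
            yield name, batch


def shuffled_batches(
    loaders: Dict[str, List[T]], seed: Optional[int] = None
) -> Iterator[Tuple[str, T]]:
    """Yield every batch exactly once, visiting loaders in a random interleaved order."""
    rng = random.Random(seed)
    # One entry per batch, naming the loader it should be drawn from.
    order = [name for name, batches in loaders.items() for _ in batches]
    rng.shuffle(order)
    iterators = {name: iter(batches) for name, batches in loaders.items()}
    for name in order:
        # Each loader's own batches still come out in their original order.
        yield name, next(iterators[name])
```

Interleaving keeps the model from seeing one task's data in a long uninterrupted run, which sequential ordering would produce.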

__init__(name=None, **kwargs)

Initialize the Trainer with an optional name and settings to merge into the default config.

Return type

None

Methods

__init__([name])

Initialize the Trainer.

fit(model, dataloaders)

Train a MultitaskClassifier.

fit(model, dataloaders)

Train a MultitaskClassifier.

Parameters
  • model (MultitaskClassifier) – The model to train

  • dataloaders (List[DictDataLoader]) – A list of DataLoaders. These are divided into train, valid, and test splits according to each DataLoader's split attribute.

Return type

None
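fit receives all DataLoaders in one list and separates them by their split attribute. A minimal sketch of that grouping step follows, using a hypothetical FakeLoader stand-in that carries only a name and a split, and assuming the split names train, valid, and test from the parameter description above. With the real classes the call itself is simply trainer.fit(model, dataloaders).

```python
from dataclasses import dataclass
from typing import Dict, List, Sequence


@dataclass
class FakeLoader:
    """Hypothetical stand-in for a DataLoader: only a name and a split attribute."""

    name: str
    split: str


def group_by_split(
    loaders: List[FakeLoader], splits: Sequence[str] = ("train", "valid", "test")
) -> Dict[str, List[FakeLoader]]:
    """Group loaders by their split attribute, preserving input order within each split."""
    grouped: Dict[str, List[FakeLoader]] = {split: [] for split in splits}
    for loader in loaders:
        if loader.split not in grouped:
            # This sketch's choice: fail loudly on an unexpected split name.
            raise ValueError(f"Loader {loader.name!r} has unexpected split {loader.split!r}")
        grouped[loader.split].append(loader)
    return grouped
```

Raising on an unexpected split name is this sketch's choice; how fit treats other split names is not described here.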