snorkel.classification.MultitaskClassifier
class snorkel.classification.MultitaskClassifier(tasks, name=None, **kwargs)
Bases: torch.nn.modules.module.Module
A classifier built from one or more tasks to support advanced workflows.
- Parameters
tasks (List[Task]) – A list of Tasks to build a model from
name (Optional[str]) – The name of the classifier
-
__init__(tasks, name=None, **kwargs)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- Return type
None
Methods
__init__(tasks[, name]) – Initializes internal Module state, shared by both nn.Module and ScriptModule.
add_module(name, module) – Adds a child module to the current module.
add_task(task) – Add a single task to the network.
apply(fn) – Applies fn recursively to every submodule (as returned by .children()) as well as self.
bfloat16() – Casts all floating point parameters and buffers to bfloat16 datatype.
buffers([recurse]) – Returns an iterator over module buffers.
calculate_loss(X_dict, Y_dict) – Calculate the loss for each task and the number of data points contributing.
children() – Returns an iterator over immediate children modules.
cpu() – Moves all model parameters and buffers to the CPU.
cuda([device]) – Moves all model parameters and buffers to the GPU.
double() – Casts all floating point parameters and buffers to double datatype.
eval() – Sets the module in evaluation mode.
extra_repr() – Set the extra representation of the module.
float() – Casts all floating point parameters and buffers to float datatype.
forward(X_dict, task_names) – Do a forward pass through the network for all specified tasks.
half() – Casts all floating point parameters and buffers to half datatype.
load(model_path) – Load a saved model from the provided file path and move it to a device.
load_state_dict(state_dict[, strict]) – Copies parameters and buffers from state_dict into this module and its descendants.
modules() – Returns an iterator over all modules in the network.
named_buffers([prefix, recurse]) – Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
named_children() – Returns an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
named_modules([memo, prefix]) – Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
named_parameters([prefix, recurse]) – Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
parameters([recurse]) – Returns an iterator over module parameters.
predict(dataloader[, return_preds, remap_labels]) – Calculate probabilities, (optionally) predictions, and pull out gold labels.
register_backward_hook(hook) – Registers a backward hook on the module.
register_buffer(name, tensor[, persistent]) – Adds a buffer to the module.
register_forward_hook(hook) – Registers a forward hook on the module.
register_forward_pre_hook(hook) – Registers a forward pre-hook on the module.
register_full_backward_hook(hook) – Registers a backward hook on the module.
register_parameter(name, param) – Adds a parameter to the module.
requires_grad_([requires_grad]) – Change if autograd should record operations on parameters in this module.
save(model_path) – Save the model to the specified file path.
score(dataloaders[, remap_labels, as_dataframe]) – Calculate scores for the provided DictDataLoaders.
share_memory() – Return type: T
state_dict([destination, prefix, keep_vars]) – Returns a dictionary containing a whole state of the module.
to(*args, **kwargs) – Moves and/or casts the parameters and buffers.
train([mode]) – Sets the module in training mode.
type(dst_type) – Casts all parameters and buffers to dst_type.
xpu([device]) – Moves all model parameters and buffers to the XPU.
zero_grad([set_to_none]) – Sets gradients of all model parameters to zero.
Attributes
T_destination
dump_patches
-
add_task(task)
Add a single task to the network.
- Parameters
task (Task) – A Task to add
- Return type
None
-
calculate_loss(X_dict, Y_dict)
Calculate the loss for each task and the number of data points contributing.
- Parameters
X_dict (Dict[str, Any]) – A dict of data fields
Y_dict (Dict[str, Tensor]) – A dict from task names to label sets
- Returns
A dict of losses by task name and seen examples by task name
- Return type
Dict[str, torch.Tensor], Dict[str, float]
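The two returned dicts share task names as keys, so they can be combined after the call. As an illustration only (an assumed usage, not something this page documents Snorkel as doing), the sketch below computes an example-weighted mean loss across tasks. It uses plain floats in place of tensors so that it runs without torch; `loss_dict`, `count_dict`, and `weighted_mean_loss` are hypothetical names.

```python
from typing import Dict

# Hypothetical stand-ins for the two dicts returned by calculate_loss():
# per-task loss and per-task number of contributing examples.
loss_dict: Dict[str, float] = {"task1": 0.50, "task2": 1.00}
count_dict: Dict[str, float] = {"task1": 30.0, "task2": 10.0}


def weighted_mean_loss(losses: Dict[str, float], counts: Dict[str, float]) -> float:
    """Average the per-task losses, weighting each task by its example count."""
    total = sum(counts[task] for task in losses)
    if total == 0:
        raise ValueError("No examples contributed to any task loss")
    return sum(losses[task] * counts[task] for task in losses) / total


print(weighted_mean_loss(loss_dict, count_dict))
```

Weighting by count keeps a task with only a few labeled examples in the batch from dominating the combined value.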
-
forward(X_dict, task_names)
Do a forward pass through the network for all specified tasks.
- Parameters
X_dict (Dict[str, Any]) – A dict of data fields
task_names (Iterable[str]) – The names of the tasks to execute the forward pass for
- Returns
A dict mapping each operation name to its corresponding output
- Return type
OutputDict
- Raises
TypeError – If an Operation input has an invalid type
ValueError – If a specified Operation failed to execute
-
load(model_path)
Load a saved model from the provided file path and move it to a device.
- Parameters
model_path (str) – The path to a saved model
- Return type
None
-
predict(dataloader, return_preds=False, remap_labels={})
Calculate probabilities, (optionally) predictions, and pull out gold labels.
- Parameters
dataloader (DictDataLoader) – A DictDataLoader to make predictions for
return_preds (bool) – If True, include predictions in the return dict (not just probabilities)
remap_labels (Dict[str, Optional[str]]) – A dict specifying which labels in the dataset’s Y_dict (key) to remap to a new task (value)
- Returns
A dictionary mapping label type (‘golds’, ‘probs’, ‘preds’) to values
- Return type
Dict[str, Dict[str, torch.Tensor]]
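The nested return value is keyed first by label type and then, judging by the return type, by a second string key. As a rough sketch of how such a dict might be consumed, the example below computes accuracy per task, assuming the inner dicts are keyed by task (label) name. Plain Python lists stand in for the array values so that the example runs without torch; the `results` values and the `accuracy_by_task` helper are hypothetical.

```python
from typing import Dict, List

# Hypothetical stand-in for the dict returned by predict(..., return_preds=True).
results: Dict[str, Dict[str, List]] = {
    "golds": {"task1": [0, 1, 1, 0]},
    "probs": {"task1": [[0.9, 0.1], [0.2, 0.8], [0.6, 0.4], [0.7, 0.3]]},
    "preds": {"task1": [0, 1, 0, 0]},
}


def accuracy_by_task(results: Dict[str, Dict[str, List]]) -> Dict[str, float]:
    """Compare gold labels with predictions for each task name."""
    accuracies = {}
    for task_name, golds in results["golds"].items():
        preds = results["preds"][task_name]
        correct = sum(int(gold == pred) for gold, pred in zip(golds, preds))
        accuracies[task_name] = correct / len(golds)
    return accuracies


print(accuracy_by_task(results))
```

The "preds" entry is only present when return_preds is True, so a caller that leaves it at the default would need to derive predictions from "probs" first.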
-
save(model_path)
Save the model to the specified file path.
- Parameters
model_path (str) – The path where the model should be saved
- Raises
BaseException – If the torch.save() method fails
- Return type
None
-
score(dataloaders, remap_labels={}, as_dataframe=False)
Calculate scores for the provided DictDataLoaders.
- Parameters
dataloaders (List[DictDataLoader]) – A list of DictDataLoaders to calculate scores for
remap_labels (Dict[str, Optional[str]]) – A dict specifying which labels in the dataset’s Y_dict (key) to remap to a new task (value)
as_dataframe (bool) – A boolean indicating whether to return results as pandas DataFrame (True) or dict (False)
- Returns
A dictionary mapping metric names to corresponding scores. Metric names will be of the form “task/dataset/split/metric”.
- Return type
Dict[str, float]
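Because the returned keys follow the “task/dataset/split/metric” pattern, they can be split back into their parts for grouping or filtering. The following is a minimal sketch, assuming no component itself contains a slash; the `scores` dict, the `MetricKey` tuple, and the `split_metric_name` helper are hypothetical and not part of Snorkel.

```python
from typing import Dict, NamedTuple


class MetricKey(NamedTuple):
    task: str
    dataset: str
    split: str
    metric: str


def split_metric_name(name: str) -> MetricKey:
    """Split a 'task/dataset/split/metric' key into its four parts."""
    parts = name.split("/")
    if len(parts) != 4:
        raise ValueError(f"Expected 4 slash-separated parts, got {name!r}")
    return MetricKey(*parts)


# Hypothetical scores, shaped like the Dict[str, float] that score() returns.
scores: Dict[str, float] = {
    "spam_task/SpamDataset/valid/accuracy": 0.91,
    "spam_task/SpamDataset/test/accuracy": 0.89,
}

# Keep only the metrics computed on the test split.
test_scores = {
    key: value
    for key, value in scores.items()
    if split_metric_name(key).split == "test"
}
print(test_scores)
```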