snorkel.slicing.SliceCombinerModule

class snorkel.slicing.SliceCombinerModule(slice_ind_key='_ind_head', slice_pred_key='_pred_head', slice_pred_feat_key='_pred_transform', temperature=1.0)

Bases: torch.nn.modules.module.Module

A module for combining the weighted representations learned by slices.

Intended for use with a MultitaskClassifier whose operations include:
  • Indicator operations

  • Prediction operations

  • Prediction transform features

NOTE: This module currently only handles binary labels.

Parameters
  • slice_ind_key (str) – Suffix of operation corresponding to the slice indicator heads

  • slice_pred_key (str) – Suffix of operation corresponding to the slice predictor heads

  • slice_pred_feat_key (str) – Suffix of operation corresponding to the slice predictor feature heads

  • temperature (float) – Temperature constant (tau) for scaling the weighting between indicator predictions and predictor confidences: SoftMax(indicator_pred * predictor_confidence / tau), where the SoftMax is taken across slices
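The effect of the temperature can be illustrated with a small pure-Python sketch of the formula above. It does not call Snorkel; the function name slice_weights and the input values are hypothetical, chosen only to show how a smaller tau concentrates the weight on the slice with the largest indicator_pred * predictor_confidence product.

```python
import math


def slice_weights(indicator_preds, predictor_confidences, temperature=1.0):
    """Compute SoftMax(indicator_pred * predictor_confidence / tau) across slices."""
    scores = [
        p * c / temperature
        for p, c in zip(indicator_preds, predictor_confidences)
    ]
    # Subtract the max score before exponentiating for numerical stability.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical values for three slices (not taken from any real model).
indicator_preds = [0.9, 0.2, 0.6]        # probability the example is in each slice
predictor_confidences = [0.8, 0.9, 0.5]  # confidence of each slice predictor

soft = slice_weights(indicator_preds, predictor_confidences, temperature=1.0)
sharp = slice_weights(indicator_preds, predictor_confidences, temperature=0.1)

print(soft)
print(sharp)
```

With these inputs the first slice has the largest product, so it receives the largest weight at both temperatures, and its share grows as the temperature drops from 1.0 to 0.1.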

slice_ind_key

See above

slice_pred_key

See above

slice_pred_feat_key

See above

__init__(slice_ind_key='_ind_head', slice_pred_key='_pred_head', slice_pred_feat_key='_pred_transform', temperature=1.0)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

Return type

None

Methods

__init__([slice_ind_key, slice_pred_key, …])

Initializes internal Module state, shared by both nn.Module and ScriptModule.

add_module(name, module)

Adds a child module to the current module.

apply(fn)

Applies fn recursively to every submodule (as returned by .children()) as well as self.

buffers([recurse])

Returns an iterator over module buffers.

children()

Returns an iterator over immediate child modules.

cpu()

Moves all model parameters and buffers to the CPU.

cuda([device])

Moves all model parameters and buffers to the GPU.

double()

Casts all floating point parameters and buffers to double datatype.

eval()

Sets the module in evaluation mode.

extra_repr()

Sets the extra representation of the module.

float()

Casts all floating point parameters and buffers to float datatype.

forward(output_dict)

Reweight and combine predictor representations given output dict.

half()

Casts all floating point parameters and buffers to half datatype.

load_state_dict(state_dict[, strict])

Copies parameters and buffers from state_dict into this module and its descendants.

modules()

Returns an iterator over all modules in the network.

named_buffers([prefix, recurse])

Returns an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.

named_children()

Returns an iterator over immediate child modules, yielding both the name of the module as well as the module itself.

named_modules([memo, prefix])

Returns an iterator over all modules in the network, yielding both the name of the module as well as the module itself.

named_parameters([prefix, recurse])

Returns an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.

parameters([recurse])

Returns an iterator over module parameters.

register_backward_hook(hook)

Registers a backward hook on the module.

register_buffer(name, tensor)

Adds a persistent buffer to the module.

register_forward_hook(hook)

Registers a forward hook on the module.

register_forward_pre_hook(hook)

Registers a forward pre-hook on the module.

register_parameter(name, param)

Adds a parameter to the module.

requires_grad_([requires_grad])

Changes whether autograd should record operations on parameters in this module.

share_memory()

Moves the module's parameters and buffers to shared memory.

state_dict([destination, prefix, keep_vars])

Returns a dictionary containing the whole state of the module.

to(*args, **kwargs)

Moves and/or casts the parameters and buffers.

train([mode])

Sets the module in training mode.

type(dst_type)

Casts all parameters and buffers to dst_type.

zero_grad()

Sets gradients of all model parameters to zero.
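As a rough illustration of what forward(output_dict) is documented to do (reweight and combine predictor representations), the sketch below works through a single example in pure Python. It is not Snorkel's implementation. It rests on these assumptions: operation outputs are keyed by operation name and matched by suffix, indicator_pred is the probability of slice membership, and predictor_confidence is the larger of the two class probabilities. The function combine_slices and the operation names are hypothetical.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def combine_slices(output, ind_key="_ind_head", pred_key="_pred_head",
                   feat_key="_pred_transform", temperature=1.0):
    """Combine per-slice features for ONE example (no batching, for clarity)."""
    # Group entries by suffix; dicts keep insertion order, so the three
    # lists line up slice by slice.
    ind_logits = [v for k, v in output.items() if k.endswith(ind_key)]
    pred_logits = [v for k, v in output.items() if k.endswith(pred_key)]
    feats = [v for k, v in output.items() if k.endswith(feat_key)]

    # Assumption: indicator_pred is the probability of class 1 ("in slice").
    ind_probs = [softmax(logits)[1] for logits in ind_logits]
    # Assumption: predictor_confidence is the larger binary class probability.
    confidences = [max(softmax(logits)) for logits in pred_logits]

    # SoftMax(indicator_pred * predictor_confidence / tau) across slices.
    weights = softmax(
        [p * c / temperature for p, c in zip(ind_probs, confidences)]
    )
    # Weighted sum of the per-slice feature vectors.
    dim = len(feats[0])
    combined = [
        sum(w * f[d] for w, f in zip(weights, feats)) for d in range(dim)
    ]
    return weights, combined


# Hypothetical operation names and values for two slices.
output = {
    "task_slice:short_ind_head": [-2.0, 2.0],
    "task_slice:short_pred_head": [0.0, 3.0],
    "task_slice:short_pred_transform": [1.0, 0.0],
    "task_slice:long_ind_head": [2.0, -2.0],
    "task_slice:long_pred_head": [0.5, 0.0],
    "task_slice:long_pred_transform": [0.0, 1.0],
}
weights, combined = combine_slices(output)
print(weights, combined)
```

Here the "short" slice has both a confident indicator and a confident predictor, so its feature vector dominates the combined representation. The real module operates on batched torch tensors rather than Python lists.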

Attributes

dump_patches