snorkel.classification.DictDataset

class snorkel.classification.DictDataset(name, split, X_dict, Y_dict)

Bases: torch.utils.data.dataset.Dataset

A dataset in which both the data fields and the labels are stored as dictionaries.

Parameters
  • name (str) – The name of the dataset (e.g., this will be used to report metrics on a per-dataset basis)

  • split (str) – The name of the split that the data in this object represents

  • X_dict (Dict[str, Any]) – A map from field name to values (e.g., {“tokens”: …, “uids”: …})

  • Y_dict (Dict[str, Tensor]) – A map from task name to its corresponding set of labels

Raises

ValueError – If any value in Y_dict is not of type torch.Tensor
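To illustrate the constructor contract described above, here is a minimal sketch of a stand-in class. `MiniDictDataset` is hypothetical and is not Snorkel's implementation. It assumes PyTorch is installed and only mirrors the documented attributes (name, split, X_dict, Y_dict) and the documented ValueError raised when a Y_dict value is not a torch.Tensor. The field names `tokens` and `uids` follow the parameter description; the task name `sentiment` is made up for the example.

```python
from typing import Any, Dict

import torch
from torch import Tensor
from torch.utils.data import Dataset


class MiniDictDataset(Dataset):
    """Hypothetical stand-in mirroring the documented DictDataset constructor."""

    def __init__(
        self, name: str, split: str, X_dict: Dict[str, Any], Y_dict: Dict[str, Tensor]
    ) -> None:
        self.name = name
        self.split = split
        self.X_dict = X_dict
        self.Y_dict = Y_dict
        # Mirror the documented check: every label set must be a torch.Tensor.
        for task_name, labels in Y_dict.items():
            if not isinstance(labels, Tensor):
                raise ValueError(
                    f"Label for task {task_name!r} must be a torch.Tensor, "
                    f"got {type(labels).__name__}"
                )


# X_dict may hold arbitrary fields; Y_dict maps each task name to a label tensor.
X_dict = {"tokens": [["a", "b"], ["c"]], "uids": ["u0", "u1"]}
Y_dict = {"sentiment": torch.tensor([1, 0])}
ds = MiniDictDataset("reviews", "train", X_dict, Y_dict)

# A plain list of labels is rejected, matching the documented ValueError.
try:
    MiniDictDataset("reviews", "train", X_dict, {"sentiment": [1, 0]})
except ValueError as err:
    print("rejected:", err)
```

Because X_dict is typed Dict[str, Any], only the labels are type-checked; the sketch likewise leaves the X_dict values unchecked.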

name

See above

split

See above

X_dict

See above

Y_dict

See above

__init__(name, split, X_dict, Y_dict)

Initialize self. See the class parameters above for the signature.

Return type

None

Methods

__init__(name, split, X_dict, Y_dict)

Initialize self.

from_tensors(X_tensor, Y_tensor, split[, …])

Initialize a DictDataset from PyTorch Tensors.

classmethod from_tensors(X_tensor, Y_tensor, split, input_data_key='input_data', task_name='task', dataset_name='SnorkelDataset')

Initialize a DictDataset from PyTorch Tensors.

Parameters
  • X_tensor (Tensor) – Input data of shape [num_examples, feature_dim]

  • Y_tensor (Tensor) – Labels of shape [num_examples, num_classes]

  • split (str) – Name of data split corresponding to this dataset.

  • input_data_key (str) – Name of data field to initialize in X_dict

  • task_name (str) – Name of task and corresponding label key in Y_dict

  • dataset_name (str) – Name of the DictDataset to be initialized (passed as name; see __init__ above).

Returns

A DictDataset with a single data field (input_data_key) and a single task (task_name) whose labels correspond to the input data

Return type

DictDataset
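The parameter descriptions above suggest a simple mapping: the input tensor becomes the only entry of X_dict under input_data_key, and the label tensor becomes the only entry of Y_dict under task_name. The sketch below spells out that mapping with a hypothetical helper, `from_tensors_kwargs`, which returns the constructor arguments instead of building a dataset. It is an illustration under that assumption, not Snorkel's source, and it assumes PyTorch is installed. The labels are one-hot encoded so that their shape matches the documented [num_examples, num_classes].

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def from_tensors_kwargs(
    X_tensor: Tensor,
    Y_tensor: Tensor,
    split: str,
    input_data_key: str = "input_data",
    task_name: str = "task",
    dataset_name: str = "SnorkelDataset",
) -> dict:
    """Hypothetical helper: the constructor arguments from_tensors is documented to imply."""
    # One data field and one task, keyed by the documented default names.
    return {
        "name": dataset_name,
        "split": split,
        "X_dict": {input_data_key: X_tensor},
        "Y_dict": {task_name: Y_tensor},
    }


X = torch.randn(4, 3)                                     # [num_examples, feature_dim]
Y = F.one_hot(torch.tensor([0, 1, 1, 0]), num_classes=2)  # [num_examples, num_classes]
kwargs = from_tensors_kwargs(X, Y, split="train")
print(sorted(kwargs["X_dict"]), sorted(kwargs["Y_dict"]))  # ['input_data'] ['task']
```

With Snorkel installed, the corresponding call would be DictDataset.from_tensors(X, Y, "train"), which per the signature above uses the same default keys and dataset name.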