snorkel.utils.filter_labels

snorkel.utils.filter_labels(label_dict, filter_dict)

Filter out examples from aligned label arrays based on the label values specified for each label set.

The most common use of this function is to remove examples whose gold label is unknown (marked with -1) or whose prediction is an abstain (also -1) before calculating metrics.
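To see why this matters, the sketch below compares accuracy computed over every position with accuracy computed only where both the gold label and the prediction are known. The data is made up for illustration, and the metric is a hand-rolled NumPy accuracy rather than a Snorkel scorer.

```python
import numpy as np

# Hypothetical gold labels (-1 = unknown) and predictions (-1 = abstain).
golds = np.array([-1, 1, 0, 1, 0, 1])
preds = np.array([1, 1, -1, 0, 0, 1])

# Naive accuracy compares every position, including unknown golds and abstains.
naive_acc = float(np.mean(golds == preds))

# Keep only examples with a known gold label AND a non-abstain prediction.
keep = (golds != -1) & (preds != -1)
filtered_acc = float(np.mean(golds[keep] == preds[keep]))

print(naive_acc, filtered_acc)
```

Here the unknown gold at index 0 and the abstain at index 2 are both counted as errors in the naive score (0.5), while the filtered score over the four usable examples is 0.75.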

NB: If an example matches the filter criteria for any label set, it will be removed from all label sets (so that the returned arrays are of the same size and still aligned).
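The behavior described in the NB amounts to taking the union of one "drop" mask per label set and applying that single mask to every array. A minimal NumPy sketch of that logic follows, assuming 1-D aligned arrays; `filter_labels_sketch` is a hypothetical name, and this is an illustration, not Snorkel's actual implementation.

```python
from typing import Dict, List

import numpy as np


def filter_labels_sketch(
    label_dict: Dict[str, np.ndarray], filter_dict: Dict[str, List[int]]
) -> Dict[str, np.ndarray]:
    """Hypothetical re-implementation for illustration only."""
    if not label_dict:
        return {}
    lengths = {len(arr) for arr in label_dict.values()}
    if len(lengths) != 1:
        raise ValueError("label arrays must be aligned (equal length)")
    n = lengths.pop()

    # drop[i] is True if example i matches a filter value in ANY label set.
    drop = np.zeros(n, dtype=bool)
    for name, values in filter_dict.items():
        if name in label_dict:  # assumption: keys missing from label_dict are ignored
            drop |= np.isin(label_dict[name], values)

    # Apply the same mask to every array so the outputs stay aligned.
    keep = ~drop
    return {name: arr[keep] for name, arr in label_dict.items()}
```

Using `np.isin` lets a label set filter several values at once (for example `[-1, 2]`), and label sets without an entry in `filter_dict` are still trimmed by the shared mask.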

Parameters
  • label_dict (Dict[str, ndarray]) – A mapping from label set name to the array of labels. The arrays in label_dict.values() are assumed to be aligned (same length, with position i referring to the same example in every array)

  • filter_dict (Dict[str, List[int]]) – A mapping from label set name to the labels that should be filtered out for that label set

Returns

A mapping with the same keys as label_dict but with filtered arrays as values

Return type

Dict[str, np.ndarray]

Example

>>> import numpy as np
>>> from snorkel.utils import filter_labels
>>> golds = np.array([-1, 0, 0, 1, 0])
>>> preds = np.array([0, 0, 0, 1, -1])
>>> filtered = filter_labels(
...     label_dict={"golds": golds, "preds": preds},
...     filter_dict={"golds": [-1], "preds": [-1]}
... )
>>> filtered["golds"]
array([0, 0, 1])
>>> filtered["preds"]
array([0, 0, 1])
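The choice of filter_dict changes what a metric measures. Filtering both label sets, as in the example above, excludes abstains entirely; filtering only the golds keeps abstains, so they count as errors. The NumPy sketch below reproduces both cases on the same data with explicit masks, assuming (as the mapping-per-label-set design suggests) that a label set absent from filter_dict contributes no filter of its own.

```python
import numpy as np

golds = np.array([-1, 0, 0, 1, 0])
preds = np.array([0, 0, 0, 1, -1])

# Equivalent of filter_dict={"golds": [-1], "preds": [-1]}: abstains are excluded.
keep_both = ~np.isin(golds, [-1]) & ~np.isin(preds, [-1])
acc_both = float(np.mean(golds[keep_both] == preds[keep_both]))

# Equivalent of filter_dict={"golds": [-1]}: the abstain at index 4 is kept
# and counts as an error.
keep_golds = ~np.isin(golds, [-1])
acc_golds = float(np.mean(golds[keep_golds] == preds[keep_golds]))

print(acc_both, acc_golds)
```

The first score (1.0) describes accuracy on the examples the model chose to label; the second (0.75) also penalizes the model for abstaining.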