snorkel.analysis.get_label_buckets

snorkel.analysis.get_label_buckets(*y)

Return data point indices bucketed by label combinations.

Parameters

*y – One or more np.ndarray of (int) labels; all arrays must have the same number of elements

Returns

A mapping of each label bucket to a NumPy array of its corresponding indices

Return type

Dict[Tuple[int, ...], np.ndarray]
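To make the bucketing concrete, here is a minimal sketch of the same idea in plain Python and NumPy. The name `label_buckets_sketch` is hypothetical and not part of snorkel, and this is not necessarily how snorkel implements `get_label_buckets`; it assumes each input is array-like with the same number of elements.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

import numpy as np


def label_buckets_sketch(*y: np.ndarray) -> Dict[Tuple[int, ...], np.ndarray]:
    """Hypothetical re-implementation of label bucketing, for illustration only."""
    # Flatten each input so (n, 1) column vectors and (n,) arrays behave alike.
    arrays = [np.asarray(a).flatten() for a in y]
    if len({len(a) for a in arrays}) != 1:
        raise ValueError("All label arrays must have the same number of elements")
    buckets: Dict[Tuple[int, ...], List[int]] = defaultdict(list)
    # zip walks the arrays in lockstep: position i yields one label per array.
    for i, labels in enumerate(zip(*arrays)):
        buckets[tuple(int(label) for label in labels)].append(i)
    # Convert each list of positions into a NumPy index array.
    return {key: np.array(idx) for key, idx in buckets.items()}


Y_gold = np.array([1, 1, 1, 0])
Y_pred = np.array([1, 1, -1, -1])
buckets = label_buckets_sketch(Y_gold, Y_pred)
```

Using a `defaultdict(list)` keeps a single pass over the data, and only label combinations that actually occur ever become keys.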

Example

A common use case is calling buckets = get_label_buckets(Y_gold, Y_pred), where Y_gold is an array of gold (i.e., ground truth) labels and Y_pred is a corresponding array of predicted labels.

>>> import numpy as np
>>> from snorkel.analysis import get_label_buckets
>>> Y_gold = np.array([1, 1, 1, 0])
>>> Y_pred = np.array([1, 1, -1, -1])
>>> buckets = get_label_buckets(Y_gold, Y_pred)

The returned buckets[(i, j)] is a NumPy array of data point indices with true label i and predicted label j.

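More generally, each key holds one label per input array, in the order the arrays were passed as arguments, and the indices in each bucket are positions within those arrays. Only label combinations that actually occur appear as keys.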
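With more than two inputs, the keys simply grow longer. The sketch below groups points by their (gold, LF A, LF B) label triple using `np.unique(..., axis=0, return_inverse=True)`, a vectorized alternative to a Python loop. The data and the names `lf_a`, `lf_b`, and `multi_buckets` are hypothetical, and this is not snorkel's code, though by the semantics described above `get_label_buckets(Y_gold, lf_a, lf_b)` should yield the same keys and indices.

```python
import numpy as np

# Hypothetical gold labels and votes from two labeling functions (-1 = abstain).
Y_gold = np.array([1, 0, 1, 0, 1, 1])
lf_a = np.array([1, 0, -1, 1, 1, -1])
lf_b = np.array([1, -1, -1, 1, 1, -1])

# Shape (n_points, n_arrays): row i holds one label per input, in argument order.
combos = np.stack([Y_gold, lf_a, lf_b], axis=1)

# Distinct label triples, plus which triple each row maps to.
keys, inverse = np.unique(combos, axis=0, return_inverse=True)
inverse = inverse.reshape(-1)  # normalize shape across NumPy versions

# For every distinct triple, collect the positions of the rows that match it.
multi_buckets = {
    tuple(int(v) for v in key): np.flatnonzero(inverse == k)
    for k, key in enumerate(keys)
}
```

Here points 0 and 4 share the key `(1, 1, 1)`, where both labeling functions agree with the gold label, while points 2 and 5 share `(1, -1, -1)`, where both abstain.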

>>> buckets[(1, 1)]  # true positives
array([0, 1])
>>> (1, 0) in buckets  # false negatives
False
>>> (0, 1) in buckets  # false positives
False
>>> (0, 0) in buckets  # true negatives
False
>>> buckets[(1, -1)]  # abstained positives
array([2])
>>> buckets[(0, -1)]  # abstained negatives
array([3])
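Because each bucket is an index array, it can be used directly to pull out the matching data points or to tally confusion-matrix style counts. The sketch below reuses the buckets from this example, written out literally so it runs on its own; the `texts` array is hypothetical data assumed to be aligned with `Y_gold` and `Y_pred`.

```python
import numpy as np

# Buckets from the example above, written out so this snippet is self-contained.
buckets = {
    (1, 1): np.array([0, 1]),
    (1, -1): np.array([2]),
    (0, -1): np.array([3]),
}

# Hypothetical data points, in the same order as Y_gold and Y_pred.
texts = np.array(["great product", "love it", "meh, it works", "broke in a day"])

# Bucket sizes act like confusion-matrix cells; a missing key means a count of zero.
counts = {key: len(idx) for key, idx in buckets.items()}
true_positives = counts.get((1, 1), 0)
false_negatives = counts.get((1, 0), 0)

# Gather every point the model abstained on (predicted label -1), whatever its gold label.
abstained = np.concatenate([idx for (_, pred), idx in buckets.items() if pred == -1])
print(texts[abstained])
```

Using `dict.get` with a default of 0 matters here, since combinations that never occur are absent from the dictionary rather than mapped to empty arrays.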