snorkel.classification.cross_entropy_with_probs

snorkel.classification.cross_entropy_with_probs(input, target, weight=None, reduction='mean')

Calculate cross-entropy loss when targets are probabilities (floats), not ints.

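PyTorch’s F.cross_entropy() method requires integer labels; it does not accept probabilistic labels. (This holds for the PyTorch versions this function was written against; since PyTorch 1.10, F.cross_entropy() also accepts class probabilities as targets.) We can, however, simulate such functionality with a for loop: compute the integer-label loss for each class, scale it by the target probability of that class, and accumulate the results. Libraries such as Keras do not require this workaround, as losses like “categorical_crossentropy” accept float labels natively.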
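The per-class loop described above can be sketched as follows. This is a minimal illustration written from the description, not the library's source: `cross_entropy_with_probs_sketch` is a hypothetical name, and the sketch assumes `input` holds raw logits of shape [num_points, num_classes].

```python
import torch
import torch.nn.functional as F


def cross_entropy_with_probs_sketch(input, target, weight=None, reduction="mean"):
    """Hypothetical per-class-loop version of cross-entropy with probabilistic targets."""
    num_points, num_classes = input.shape
    # Running per-point loss on the same device and dtype as the logits.
    cum_losses = input.new_zeros(num_points)
    for y in range(num_classes):
        # Treat every point as if its hard label were y and use the integer-label loss.
        hard_target = torch.full((num_points,), y, dtype=torch.long, device=input.device)
        y_loss = F.cross_entropy(input, hard_target, reduction="none")
        if weight is not None:
            y_loss = y_loss * weight[y]  # optional per-class weight
        # Scale by the probability mass the target assigns to class y and accumulate.
        cum_losses = cum_losses + target[:, y].to(y_loss.dtype) * y_loss
    if reduction == "none":
        return cum_losses
    if reduction == "mean":
        return cum_losses.mean()
    if reduction == "sum":
        return cum_losses.sum()
    raise ValueError("reduction must be one of 'none', 'mean', 'sum'")


# Usage: two data points, three classes, soft targets.
logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
soft_targets = torch.tensor([[0.7, 0.2, 0.1], [0.0, 0.3, 0.7]])
loss = cross_entropy_with_probs_sketch(logits, soft_targets)
```

The loop makes num_classes calls to F.cross_entropy(); the single vectorized expression -(target * F.log_softmax(input, dim=1)).sum(dim=1) computes the same per-point losses and is a convenient way to check it.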

Note that the method signature is intentionally very similar to F.cross_entropy() so that it can be used as a drop-in replacement when target labels are changed from a 1D tensor of ints to a 2D tensor of probabilities.
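Because the call shape matches F.cross_entropy(), switching from hard to probabilistic labels changes only the target tensor and the function name. The sketch below illustrates that pattern with a hypothetical vectorized stand-in, `soft_cross_entropy`, so it runs without Snorkel installed; it is not the library's implementation. With one-hot targets built from the original integer labels, it should reproduce the integer-label loss.

```python
import torch
import torch.nn.functional as F


def soft_cross_entropy(input, target, reduction="mean"):
    """Hypothetical stand-in with the same (input, target, reduction) call shape."""
    # Per-point loss: -sum_c target[i, c] * log_softmax(input)[i, c]
    per_point = -(target * F.log_softmax(input, dim=1)).sum(dim=1)
    if reduction == "none":
        return per_point
    if reduction == "mean":
        return per_point.mean()
    if reduction == "sum":
        return per_point.sum()
    raise ValueError("reduction must be one of 'none', 'mean', 'sum'")


logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0], [0.0, 1.0, 0.0]])

# Before: a 1D tensor of integer class labels.
int_labels = torch.tensor([0, 2, 1])
hard_loss = F.cross_entropy(logits, int_labels)

# After: a 2D tensor of probabilities; only the target and the function change.
prob_labels = F.one_hot(int_labels, num_classes=3).float()
soft_loss = soft_cross_entropy(logits, prob_labels)
```

With the real library, cross_entropy_with_probs(logits, prob_labels) would take the place of the hypothetical soft_cross_entropy call.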

Parameters
  • input (Tensor) – A [num_points, num_classes] tensor of logits

  • target (Tensor) – A [num_points, num_classes] tensor of probabilistic target labels

  • weight (Optional[Tensor]) – An optional [num_classes] tensor of weights by which to multiply the loss per class

  • reduction (str) – One of “none”, “mean”, or “sum” (default “mean”), indicating whether to return one loss per data point, the mean loss, or the sum of losses
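Read together, the parameter descriptions suggest the loss below. Here N is num_points, C is num_classes, x_{ic} are the logits, p_{ic} are the target probabilities, and w_c is the class weight (taken as w_c = 1 when weight is None). This is a reading of the descriptions, not a statement about the source code.

```latex
\ell_i = -\sum_{c=1}^{C} w_c \, p_{ic} \, \log \frac{\exp(x_{ic})}{\sum_{k=1}^{C} \exp(x_{ik})},
\qquad i = 1, \dots, N

\text{none: } (\ell_1, \dots, \ell_N), \qquad
\text{mean: } \frac{1}{N} \sum_{i=1}^{N} \ell_i, \qquad
\text{sum: } \sum_{i=1}^{N} \ell_i
```

If “mean” is a plain average over data points as written, then with weight set the result can differ from F.cross_entropy(..., weight=...), whose “mean” divides by the summed weights of the target classes rather than by N.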

Returns

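The calculated loss: a tensor of shape [num_points] when reduction is “none”, otherwise a scalar (0-dimensional) tensor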

Return type

torch.Tensor

Raises

ValueError – If reduction is not one of “none”, “mean”, or “sum”
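The Returns and Raises entries together describe a dispatch on reduction. A minimal sketch of that step follows; `reduce_losses` is a hypothetical helper, and the real function's error message may differ.

```python
import torch


def reduce_losses(per_point_losses, reduction="mean"):
    """Hypothetical helper mirroring the documented reduction options."""
    if reduction == "none":
        return per_point_losses  # one loss per data point, shape [num_points]
    if reduction == "mean":
        return per_point_losses.mean()  # 0-dim tensor
    if reduction == "sum":
        return per_point_losses.sum()  # 0-dim tensor
    raise ValueError(
        f"reduction must be one of 'none', 'mean', 'sum', got {reduction!r}"
    )


losses = torch.tensor([0.5, 1.0, 1.5])
print(reduce_losses(losses, "sum"))  # tensor(3.)
try:
    reduce_losses(losses, "average")
except ValueError as err:
    print(err)
```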