snorkel.classification.cross_entropy_with_probs

snorkel.classification.cross_entropy_with_probs(input, target, weight=None, reduction='mean')

Calculate cross-entropy loss when targets are probabilities (floats), not ints.
PyTorch’s F.cross_entropy() method requires integer labels; it does not accept probabilistic labels. We can, however, simulate such functionality with a for loop over the classes: compute the loss each data point would incur if its label were that class, scale it by the probability the target assigns to that class, and accumulate the results. Libraries such as Keras do not require this workaround, as methods like “categorical_crossentropy” accept float labels natively.
Note that the method signature is intentionally very similar to F.cross_entropy() so that it can be used as a drop-in replacement when target labels are changed from a 1D tensor of ints to a 2D tensor of probabilities.
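The per-class accumulation described above can be sketched in plain Python for a single data point. This is an illustration of the idea only, not the library's code: lists stand in for tensors, and the helper names `log_softmax`, `hard_label_loss`, and `soft_label_loss` are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def hard_label_loss(logits, label):
    """Cross-entropy for one data point with an integer class label."""
    return -log_softmax(logits)[label]


def soft_label_loss(logits, probs):
    """Cross-entropy for one data point with a probabilistic label.

    Loops over the classes, scaling the hard-label loss for class y
    by the probability that the target assigns to y.
    """
    total = 0.0
    for y, p in enumerate(probs):
        total += p * hard_label_loss(logits, y)
    return total


logits = [2.0, 0.5, -1.0]

# A one-hot probabilistic label reproduces the integer-label loss.
one_hot_loss = soft_label_loss(logits, [0.0, 1.0, 0.0])

# A genuinely soft label mixes the per-class losses.
soft_loss = soft_label_loss(logits, [0.7, 0.2, 0.1])
```

Because a one-hot target yields the same value as the integer-label loss, switching from int labels to probability rows leaves the loss unchanged whenever the labels are certain, which is what makes a drop-in replacement plausible.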
- Parameters
input (Tensor) – A [num_points, num_classes] tensor of logits
target (Tensor) – A [num_points, num_classes] tensor of probabilistic target labels
weight (Optional[Tensor]) – An optional [num_classes] array of weights to multiply the loss by per class
reduction (str) – One of “none”, “mean”, “sum”, indicating whether to return one loss per data point, the mean loss, or the sum of losses
- Returns
The calculated loss
- Return type
torch.Tensor
- Raises
ValueError – If an invalid reduction keyword is submitted
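To make the parameter semantics concrete, here is a minimal plain-Python stand-in that mirrors the documented signature. It is a sketch under stated assumptions, not the library's implementation: nested lists replace tensors, the name `cross_entropy_with_probs_sketch` is hypothetical, and “mean” is assumed to be a plain average over data points.

```python
import math


def cross_entropy_with_probs_sketch(inputs, targets, weight=None, reduction="mean"):
    """Plain-Python stand-in mirroring the documented signature (hypothetical).

    inputs:  num_points rows of num_classes logits
    targets: num_points rows of num_classes probabilities
    weight:  optional list of num_classes per-class multipliers
    """
    losses = []
    for logits, probs in zip(inputs, targets):
        # Stable log of the softmax normalizer for this data point.
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        loss = 0.0
        for y, p in enumerate(probs):
            class_loss = log_z - logits[y]  # equals -log softmax(logits)[y]
            if weight is not None:
                class_loss *= weight[y]  # per-class weighting
            loss += p * class_loss
        losses.append(loss)

    if reduction == "none":
        return losses  # one loss per data point
    if reduction == "mean":
        return sum(losses) / len(losses)  # assumed: plain average over points
    if reduction == "sum":
        return sum(losses)
    raise ValueError("reduction must be one of 'none', 'mean', 'sum'")


inputs = [[2.0, 0.5], [0.0, 0.0]]
targets = [[0.9, 0.1], [0.5, 0.5]]

per_point = cross_entropy_with_probs_sketch(inputs, targets, reduction="none")
mean_loss = cross_entropy_with_probs_sketch(inputs, targets)
sum_loss = cross_entropy_with_probs_sketch(inputs, targets, reduction="sum")
```

With reduction="none" the result has one entry per data point; the other two modes collapse it to a single number, and any other keyword raises ValueError, matching the Raises entry above.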