snorkel.augmentation.MeanFieldPolicy
class snorkel.augmentation.MeanFieldPolicy(n_tfs, sequence_length=1, p=None, n_per_original=1, keep_original=True)
Bases: snorkel.augmentation.policy.core.Policy
Sample sequences of TFs according to a distribution.
Samples sequences of TF indices, each of a specified length, from a user-provided distribution. Such a distribution over TFs can be learned, for example, by a TANDA mean-field model. See https://hazyresearch.github.io/snorkel/blog/tanda.html
- Parameters
n_tfs (int) – Total number of TFs
sequence_length (int) – Number of TFs to run on each data point
p (Optional[Sequence[float]]) – Probability distribution from which to sample TF indices. Must have length n_tfs and be a valid distribution.
n_per_original (int) – Number of transformed data points per original
keep_original (bool) – Keep untransformed data point in augmented data set? Note that even if in-place modifications are made to the original data point by the TFs being applied, the original data point will remain unchanged.
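To make the constraints on p concrete, here is a minimal standard-library sketch of drawing one sequence of TF indices. It is not the library's implementation: the helper name sample_tf_sequence and its rng argument are hypothetical, and treating p=None as a uniform distribution is an assumption of this sketch.

```python
import math
import random
from typing import List, Optional, Sequence


def sample_tf_sequence(
    n_tfs: int,
    sequence_length: int = 1,
    p: Optional[Sequence[float]] = None,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Hypothetical helper: draw `sequence_length` TF indices according to `p`."""
    rng = rng or random.Random()
    if p is not None:
        # Mirror the documented constraints: length n_tfs and a valid distribution.
        if len(p) != n_tfs:
            raise ValueError("p must have length n_tfs")
        if any(x < 0 for x in p) or not math.isclose(sum(p), 1.0, abs_tol=1e-9):
            raise ValueError("p must be a valid probability distribution")
    # Assumption of this sketch: p=None means uniform sampling over the indices.
    return rng.choices(range(n_tfs), weights=p, k=sequence_length)


# All probability mass on TF index 1, so every sampled index is 1.
seq = sample_tf_sequence(n_tfs=3, sequence_length=4, p=[0.0, 1.0, 0.0])
```

Each returned index identifies one TF in the user's list of TFs, and the order of the indices is the order in which those TFs would be applied to a data point.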
__init__(n_tfs, sequence_length=1, p=None, n_per_original=1, keep_original=True)
Initialize self. See help(type(self)) for accurate signature.
- Return type
None
Methods
__init__(n_tfs[, sequence_length, p, …]) – Initialize self.
generate() – Generate a sequence of TF indices by sampling from distribution.
Generate all sequences of TF indices for a single example.
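The method summaries above suggest how n_per_original and keep_original might combine into the full set of sequences for one data point. The following is a rough sketch under stated assumptions, not the library's code: the function name sequences_for_example is hypothetical, and representing the untransformed original as an empty sequence (no TFs applied) is an assumption made for illustration.

```python
import random
from typing import List, Optional, Sequence


def sequences_for_example(
    n_tfs: int,
    sequence_length: int = 1,
    p: Optional[Sequence[float]] = None,
    n_per_original: int = 1,
    keep_original: bool = True,
    rng: Optional[random.Random] = None,
) -> List[List[int]]:
    """Hypothetical helper: all TF index sequences for a single data point."""
    rng = rng or random.Random()
    sequences: List[List[int]] = []
    if keep_original:
        # Assumption: an empty sequence stands for the untransformed original.
        sequences.append([])
    for _ in range(n_per_original):
        # One independently sampled sequence per transformed copy.
        sequences.append(rng.choices(range(n_tfs), weights=p, k=sequence_length))
    return sequences


result = sequences_for_example(
    n_tfs=3, sequence_length=2, p=[0.0, 0.0, 1.0], n_per_original=2
)
```

Under these assumptions, one original data point yields n_per_original transformed copies, plus one untransformed copy when keep_original is true.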