snorkel.augmentation.MeanFieldPolicy

class snorkel.augmentation.MeanFieldPolicy(n_tfs, sequence_length=1, p=None, n_per_original=1, keep_original=True)

Bases: snorkel.augmentation.policy.core.Policy

Sample sequences of transformation functions (TFs) according to a distribution.

Samples fixed-length sequences of TF indices from a user-provided distribution. Such a distribution over TFs can be learned, for example, by a TANDA mean-field model. See https://hazyresearch.github.io/snorkel/blog/tanda.html

Parameters
  • n_tfs (int) – Total number of TFs

  • sequence_length (int) – Number of TFs to run on each data point

  • p (Optional[Sequence[float]]) – Probability distribution from which to sample TF indices. Must have length n_tfs and be a valid distribution (non-negative entries that sum to 1).

  • n_per_original (int) – Number of transformed data points per original

  • keep_original (bool) – Whether to keep the untransformed data point in the augmented data set. Note that even if the applied TFs modify their input in place, the original data point remains unchanged.
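The keep_original note can be illustrated with a minimal sketch. It assumes the original is protected by deep-copying each data point before any TF runs. That mechanism, the helper apply_tf_sequence, and the two toy TFs are hypothetical illustrations, not Snorkel's code.

```python
import copy
from typing import Any, Callable, Dict, List

DataPoint = Dict[str, Any]
TF = Callable[[DataPoint], DataPoint]


def apply_tf_sequence(x: DataPoint, tfs: List[TF], seq: List[int]) -> DataPoint:
    """Apply the TFs selected by `seq`, in order, to a copy of `x`."""
    # Assumption: copying up front is what keeps the original intact.
    x_out = copy.deepcopy(x)
    for idx in seq:
        x_out = tfs[idx](x_out)
    return x_out


def lowercase_in_place(x: DataPoint) -> DataPoint:
    # Hypothetical TF that mutates its argument.
    x["text"] = x["text"].lower()
    return x


def append_bang(x: DataPoint) -> DataPoint:
    # Hypothetical TF that also mutates its argument.
    x["text"] += "!"
    return x


tfs = [lowercase_in_place, append_bang]
original = {"text": "Hello World"}
augmented = apply_tf_sequence(original, tfs, [0, 1])
```

Here augmented holds the transformed text while original still holds "Hello World", even though both toy TFs write to their input.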

n

Total number of TFs

n_per_original

See the n_per_original parameter above.

keep_original

See the keep_original parameter above.

sequence_length

See the sequence_length parameter above.

__init__(n_tfs, sequence_length=1, p=None, n_per_original=1, keep_original=True)

Initialize the policy. The parameters are described above.

Return type

None

Methods

__init__(n_tfs[, sequence_length, p, …])

Initialize the policy.

generate()

Generate a sequence of TF indices by sampling from distribution.

generate_for_example()

Generate all sequences of TF indices for a single example.

generate()

Generate a sequence of TF indices by sampling from distribution.

Returns

Indices of the TFs to run on a data point, in order.

Return type

List[int]
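As a rough illustration of what sampling a TF index sequence from a distribution can look like, here is a standard-library sketch. The function sample_tf_sequence and its rng argument are hypothetical. Treating p=None as a uniform distribution and drawing indices with replacement are assumptions of this sketch, not statements about Snorkel's implementation.

```python
import random
from typing import List, Optional, Sequence


def sample_tf_sequence(
    n_tfs: int,
    sequence_length: int = 1,
    p: Optional[Sequence[float]] = None,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Draw `sequence_length` TF indices from `p` (uniform if `p` is None)."""
    rng = rng or random.Random()
    if p is not None:
        if len(p) != n_tfs:
            raise ValueError("p must have length n_tfs")
        if any(w < 0 for w in p) or abs(sum(p) - 1.0) > 1e-8:
            raise ValueError("p must be non-negative and sum to 1")
    # random.choices samples with replacement; weights=None means uniform.
    return rng.choices(range(n_tfs), weights=p, k=sequence_length)


# All probability mass on TF 1, so every sampled index is 1.
peaked = sample_tf_sequence(3, sequence_length=5, p=[0.0, 1.0, 0.0])

# Uniform sampling over four TFs.
uniform = sample_tf_sequence(4, sequence_length=10)
```

Passing a seeded random.Random instance as rng makes the sampled sequences reproducible.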

generate_for_example()

Generate all sequences of TF indices for a single example.

Generates n_per_original sequences and, if keep_original is True, adds an empty sequence, which represents the untransformed data point.

Returns

Sequences of indices of the TFs to run on a data point, in order.

Return type

List[List[int]]
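The behavior described for generate_for_example can be sketched as follows. The helper sequences_for_example is hypothetical, and placing the empty sequence first is an assumption, since the description only says that one is added.

```python
from typing import Callable, List


def sequences_for_example(
    generate: Callable[[], List[int]],
    n_per_original: int = 1,
    keep_original: bool = True,
) -> List[List[int]]:
    """Collect all TF index sequences for one data point."""
    # An empty sequence applies no TFs, i.e. it stands for the original.
    # Putting it first is an assumption of this sketch.
    sequences: List[List[int]] = [[]] if keep_original else []
    for _ in range(n_per_original):
        sequences.append(generate())
    return sequences


# A fixed stand-in for a sampling function, for illustration only.
result = sequences_for_example(lambda: [2, 0], n_per_original=3)
# result == [[], [2, 0], [2, 0], [2, 0]]
```

With keep_original=False the empty sequence is omitted, so the result contains exactly n_per_original sequences.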