Add slice labels to the dataloader and create new slice tasks (including the base slice).
Each slice gets two slice-specific heads:
- an indicator head that learns to identify when DataPoints are in that slice
- a predictor head that is trained only on members of that slice
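Because the two heads learn different things, each needs its own targets. The following is a minimal sketch in plain Python of how per-slice targets could be derived from the base labels and a 0/1 slice matrix with one column per slice name. The function `make_slice_labels`, the `-1` "ignore in the loss" sentinel, and the `{name}_ind`/`{name}_pred` keys are hypothetical and are not the library's API.

```python
from typing import Dict, List, Sequence

IGNORE = -1  # hypothetical sentinel meaning "exclude this example from the loss"


def make_slice_labels(
    y: Sequence[int],
    slice_matrix: Sequence[Sequence[int]],
    slice_names: List[str],
) -> Dict[str, List[int]]:
    """Build indicator and predictor targets for the base slice and every named slice.

    slice_matrix[i][j] is 1 if data point i belongs to slice j, else 0.
    """
    labels: Dict[str, List[int]] = {}
    # The base slice contains every data point
    columns = [("base", [1] * len(y))]
    for j, name in enumerate(slice_names):
        columns.append((name, [row[j] for row in slice_matrix]))
    for name, membership in columns:
        # Indicator target: 1 if the point is in the slice, 0 otherwise
        labels[f"{name}_ind"] = list(membership)
        # Predictor target: the base label for members, IGNORE for non-members,
        # so the predictor head is only trained on members of its slice
        labels[f"{name}_pred"] = [
            label if member else IGNORE for label, member in zip(y, membership)
        ]
    return labels


y = [0, 1, 1, 0]
S = [[1, 0], [0, 0], [1, 1], [0, 1]]
labels = make_slice_labels(y, S, ["short", "negated"])
print(labels["short_pred"])  # [0, -1, 1, -1]
```

Masking with a sentinel (rather than dropping rows) keeps every task aligned to the same batch of DataPoints, which is one way to let all heads share a dataloader.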
The base task’s head is replaced by a master head that makes predictions from a combination of the predictor heads’ predictions, weighted by the indicator heads’ prediction confidences.
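The master head’s combination step can be sketched as a confidence-weighted average. This is a simplified illustration, not the library’s implementation: it assumes each predictor head outputs a class-probability vector and each indicator head a scalar confidence, and it normalizes the confidences with a softmax (an assumption; normalizing by their sum would also fit the description). The names `softmax` and `master_prediction` are hypothetical.

```python
import math
from typing import List, Sequence


def softmax(xs: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def master_prediction(
    pred_probs: Sequence[Sequence[float]],
    ind_confidences: Sequence[float],
) -> List[float]:
    """Combine per-slice class probabilities into one prediction.

    pred_probs[k] is the probability vector from slice k's predictor head;
    ind_confidences[k] is slice k's indicator confidence for this data point.
    """
    # Turn indicator confidences into weights over slices that sum to 1
    weights = softmax(ind_confidences)
    n_classes = len(pred_probs[0])
    combined = [0.0] * n_classes
    for w, probs in zip(weights, pred_probs):
        for c in range(n_classes):
            combined[c] += w * probs[c]
    return combined


# Equal confidences reduce to a plain average of the predictor outputs
print([round(p, 3) for p in master_prediction([[0.9, 0.1], [0.3, 0.7]], [0.5, 0.5])])
# [0.6, 0.4]
```

When one indicator is much more confident than the others, its predictor dominates the result, which is the intended effect of weighting by indicator confidence.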
WARNING: The current implementation pollutes the module_pool: the indicator task’s module_pool includes predictor modules and vice versa, since both are modified in place. This does not affect the result, because the op sequences dictate which modules get used, and they do not include the extra modules. An alternative would be to make a separate copy of the module pool for each task, but that wastes time and memory on extra copies of (potentially very large) modules that would be merged away in the model a moment later anyway, since they share the same names. We leave resolution of this issue for a future release.
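Why the pollution is harmless can be shown with a toy model. In this sketch, `ToyTask` is a hypothetical stand-in for a task (a plain `dict` plays the role of the module pool and strings stand in for modules); it only illustrates that two tasks sharing one pool by reference see each other’s additions, while each op sequence still selects only its own modules.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class ToyTask:
    """Hypothetical task: a name, a (shared) module pool, and an op sequence."""
    name: str
    module_pool: Dict[str, object]
    op_sequence: List[str] = field(default_factory=list)


shared_pool = {"base_encoder": "encoder"}

# Both tasks reference the same dict, so no module is copied
ind_task = ToyTask("slice_ind", shared_pool, ["base_encoder", "ind_head"])
pred_task = ToyTask("slice_pred", shared_pool, ["base_encoder", "pred_head"])

# Adding each head in place also adds it to the other task's pool
ind_task.module_pool["ind_head"] = "indicator head"
pred_task.module_pool["pred_head"] = "predictor head"

print(sorted(ind_task.module_pool))  # ['base_encoder', 'ind_head', 'pred_head']

# Only modules named in the op sequence are ever executed
used_by_ind = [ind_task.module_pool[name] for name in ind_task.op_sequence]
print(used_by_ind)  # ['encoder', 'indicator head']
```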
- Parameters
  - base_task (Task) – Task for which we are adding slice tasks. As noted in the WARNING, this task’s module_pool will currently be modified in place for efficiency.
  - slice_names (List[str]) – List of slice names corresponding to the columns of the slice matrix.
- Returns
  Contains the original base_task, pred/ind tasks for the base slice, and pred/ind tasks for each of the specified slice_names.
- Return type
  List[Task]
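The size of the returned list follows directly from its description: one base task plus an indicator and a predictor task for the base slice and for each named slice, i.e. 1 + 2 × (len(slice_names) + 1) tasks. The sketch below builds only the task names to show that structure; the `slice_task_names` helper and its `{base}_slice:{slice}_ind`/`_pred` naming scheme are hypothetical, not necessarily the names the library produces.

```python
from typing import List


def slice_task_names(base_task_name: str, slice_names: List[str]) -> List[str]:
    """Return names for the base task plus ind/pred tasks for every slice.

    The naming scheme is hypothetical; only the structure is meant to match.
    """
    names = [base_task_name]
    # The base slice is included alongside the user-specified slices
    for slice_name in ["base"] + list(slice_names):
        names.append(f"{base_task_name}_slice:{slice_name}_ind")
        names.append(f"{base_task_name}_slice:{slice_name}_pred")
    return names


names = slice_task_names("task", ["short", "negated"])
print(len(names))  # 7, i.e. 1 + 2 * (2 + 1)
```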