torchwrench.nn.functional.powerset module

torchwrench.nn.functional.powerset.build_powerset_mapping(num_classes: int, max_set_size: int, dtype: dtype | None | 'default' | str | DTypeEnum = None, device: device | None | 'default' | 'cuda_if_available' | str | int = None) → Tensor2D

Build powerset mapping matrix of shape (num_powerset_classes, num_classes).
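A minimal sketch of what such a mapping matrix could look like. It assumes each row is a 0/1 indicator of one subset of at most max_set_size classes, ordered with the empty set first and then by subset size, in the order produced by itertools.combinations. build_powerset_mapping_sketch is a hypothetical stand-in, not the library's implementation, and its row order may differ from that of build_powerset_mapping.

```python
import itertools
from typing import Optional

import torch


def build_powerset_mapping_sketch(
    num_classes: int,
    max_set_size: int,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Hypothetical sketch: one 0/1 row per subset of at most max_set_size classes.

    Assumed row order: the empty set first, then subsets grouped by size,
    each group in the lexicographic order of itertools.combinations.
    """
    rows = []
    for set_size in range(max_set_size + 1):
        for subset in itertools.combinations(range(num_classes), set_size):
            row = torch.zeros(num_classes, dtype=dtype, device=device)
            for cls in subset:
                row[cls] = 1  # mark each class that belongs to this subset
            rows.append(row)
    return torch.stack(rows)  # (num_powerset_classes, num_classes)


mapping = build_powerset_mapping_sketch(num_classes=3, max_set_size=2)
print(mapping.shape)  # torch.Size([7, 3])
print(mapping.tolist())
```

With num_classes=3 and max_set_size=2, the rows of this sketch are the empty set, {0}, {1}, {2}, {0, 1}, {0, 2} and {1, 2}.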

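torchwrench.nn.functional.powerset.get_num_powerset_classes(num_classes: int, max_set_size: int) → int

Return num_powerset_classes, the number of rows of the mapping built by build_powerset_mapping with the same num_classes and max_set_size.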
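Assuming the powerset contains every subset of at most max_set_size classes, including the empty set, the count is the sum of the binomial coefficients C(num_classes, k) for k = 0 to max_set_size. The hypothetical num_powerset_classes_sketch below computes that sum; it is not taken from the library source.

```python
from math import comb


def num_powerset_classes_sketch(num_classes: int, max_set_size: int) -> int:
    # Hypothetical sketch: count subsets of size 0..max_set_size
    # (the empty set is included; comb returns 0 when k > num_classes).
    return sum(comb(num_classes, k) for k in range(max_set_size + 1))


print(num_powerset_classes_sketch(3, 2))  # 7  (1 + 3 + 3)
print(num_powerset_classes_sketch(4, 2))  # 11 (1 + 4 + 6)
```

When max_set_size equals num_classes, the sum reaches 2 ** num_classes, the size of the full powerset.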
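torchwrench.nn.functional.powerset.multilabel_to_powerset(multilabel: Tensor, *, num_classes: int, max_set_size: int) → Tensor3D
torchwrench.nn.functional.powerset.multilabel_to_powerset(multilabel: Tensor, *, mapping: Tensor) → Tensor3D

Convert a multilabel tensor to its powerset representation. Pass either num_classes and max_set_size, or a precomputed mapping.
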
Args:

multilabel: Multilabel tensor of shape (batch_size, num_frames, num_classes).

Returns:

powerset: Powerset tensor of shape (batch_size, num_frames, num_powerset_classes).
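A minimal sketch of one way to perform this conversion, assuming binary 0/1 multilabel input and a one-hot powerset output: each frame is compared against every row of the mapping and the index of the exactly matching subset is one-hot encoded. make_mapping and multilabel_to_powerset_sketch are hypothetical helpers (assumed row order: empty set first, then subsets by size), not the library's code.

```python
import itertools

import torch
import torch.nn.functional as F


def make_mapping(num_classes: int, max_set_size: int) -> torch.Tensor:
    # Hypothetical helper; assumed row order: empty set, then subsets by size.
    rows = []
    for size in range(max_set_size + 1):
        for subset in itertools.combinations(range(num_classes), size):
            row = torch.zeros(num_classes)
            for cls in subset:
                row[cls] = 1.0
            rows.append(row)
    return torch.stack(rows)


def multilabel_to_powerset_sketch(
    multilabel: torch.Tensor, mapping: torch.Tensor
) -> torch.Tensor:
    # multilabel: (batch_size, num_frames, num_classes), binary 0/1 entries.
    # mapping: (num_powerset_classes, num_classes).
    # Broadcast (B, T, 1, C) against (P, C) and keep rows that match exactly.
    matches = (multilabel.unsqueeze(-2) == mapping).all(dim=-1)  # (B, T, P)
    # argmax returns the first True index; a frame with more than max_set_size
    # active classes matches no row and silently falls back to index 0.
    indices = matches.long().argmax(dim=-1)  # (B, T)
    return F.one_hot(indices, num_classes=mapping.shape[0])  # (B, T, P)


mapping = make_mapping(num_classes=3, max_set_size=2)
multilabel = torch.tensor([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 1.0]]])
powerset = multilabel_to_powerset_sketch(multilabel, mapping)
print(powerset.shape)  # torch.Size([1, 3, 7])
print(powerset.argmax(dim=-1).tolist())  # [[0, 1, 5]]
```

The silent fallback to index 0 is a choice made for this sketch; a stricter variant could raise when a frame matches no row.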

torchwrench.nn.functional.powerset.powerset_to_multilabel(powerset: Tensor, soft: bool = False, *, num_classes: int, max_set_size: int) → Tensor3D
torchwrench.nn.functional.powerset.powerset_to_multilabel(powerset: Tensor, soft: bool = False, *, mapping: Tensor) → Tensor3D

Convert a powerset tensor back to a multilabel tensor. Pass either num_classes and max_set_size, or a precomputed mapping.

Args:

powerset: Powerset logits, probabilities or one-hot tensor of shape (batch_size, num_frames, num_powerset_classes).

Returns:

multilabel: Multilabel tensor of shape (batch_size, num_frames, num_classes).
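A minimal sketch of the reverse conversion under stated assumptions. In hard mode it takes the argmax over powerset classes, which selects the same subset whether the input holds logits, probabilities or one-hot values, and multiplies the resulting one-hot vector by the mapping. The soft=True branch here assumes the input already holds probabilities and returns, for each class, the summed probability of every subset containing it; this is an assumption, as the exact semantics of soft in torchwrench are not documented above. make_mapping and powerset_to_multilabel_sketch are hypothetical helpers.

```python
import itertools

import torch
import torch.nn.functional as F


def make_mapping(num_classes: int, max_set_size: int) -> torch.Tensor:
    # Hypothetical helper; assumed row order: empty set, then subsets by size.
    rows = []
    for size in range(max_set_size + 1):
        for subset in itertools.combinations(range(num_classes), size):
            row = torch.zeros(num_classes)
            for cls in subset:
                row[cls] = 1.0
            rows.append(row)
    return torch.stack(rows)


def powerset_to_multilabel_sketch(
    powerset: torch.Tensor, mapping: torch.Tensor, soft: bool = False
) -> torch.Tensor:
    # powerset: (batch_size, num_frames, num_powerset_classes)
    # mapping: (num_powerset_classes, num_classes)
    if soft:
        # Assumption: `powerset` already holds probabilities over subsets.
        # Each output entry sums the probabilities of all subsets containing
        # that class, i.e. the marginal probability of the class.
        return torch.matmul(powerset.to(mapping.dtype), mapping)
    # Hard decoding: argmax picks the same subset for logits, probabilities
    # or one-hot input, since softmax preserves the ordering of scores.
    indices = powerset.argmax(dim=-1)  # (B, T)
    one_hot = F.one_hot(indices, num_classes=mapping.shape[0]).to(mapping.dtype)
    return torch.matmul(one_hot, mapping)  # (B, T, num_classes)


mapping = make_mapping(num_classes=3, max_set_size=2)
logits = torch.tensor([[[5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
                        [0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 0.0]]])  # (1, 2, 7)
hard = powerset_to_multilabel_sketch(logits, mapping)
print(hard.tolist())  # [[[0.0, 0.0, 0.0], [1.0, 0.0, 1.0]]]
probs = logits.softmax(dim=-1)
soft = powerset_to_multilabel_sketch(probs, mapping, soft=True)
print(soft.shape)  # torch.Size([1, 2, 3])
```

Because matmul broadcasts a (B, T, P) tensor against a (P, C) matrix, no reshaping is needed in either branch.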