torchwrench.utils.data.split module
- torchwrench.utils.data.split.balanced_monolabel_split(targets_indices: torch.Tensor | List[int], num_classes: int, lengths: Iterable[float], generator: torch.Generator | None | Literal['default'] | int = None, round_fn: Callable[[float], int] = math.floor) -> list[list[int]]
Generate indices for a random dataset split while keeping the same multiclass distribution.
- Args:
  targets_indices: Tensor or list of class indices of size (N,).
  num_classes: Number of classes.
  lengths: Ratios of the target splits. Values should be in the range [0, 1].
  generator: Torch Generator or seed to make this function deterministic. Defaults to None.
  round_fn: Function to round ratios to integer sizes. Defaults to math.floor.
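The idea of a class-preserving split can be illustrated with a small standalone sketch. This is not torchwrench's implementation: `balanced_monolabel_split_sketch` is a hypothetical helper written with the standard library only, it takes an integer `seed` in place of the `generator` argument, and it assumes that each class is shuffled and sliced independently using the same ratios.

```python
import math
import random
from typing import Callable, Iterable, List, Optional, Sequence


def balanced_monolabel_split_sketch(
    targets_indices: Sequence[int],
    num_classes: int,
    lengths: Iterable[float],
    seed: Optional[int] = None,  # assumption: stands in for `generator`
    round_fn: Callable[[float], int] = math.floor,
) -> List[List[int]]:
    """Hypothetical sketch: split sample indices so each subset keeps the class ratios."""
    ratios = list(lengths)
    rng = random.Random(seed)  # local RNG, leaves the global random state untouched

    # Group sample positions by their class index.
    per_class: List[List[int]] = [[] for _ in range(num_classes)]
    for sample_idx, class_idx in enumerate(targets_indices):
        per_class[int(class_idx)].append(sample_idx)

    # Shuffle each class separately, then slice it with the same ratios.
    splits: List[List[int]] = [[] for _ in ratios]
    for class_samples in per_class:
        rng.shuffle(class_samples)
        start = 0
        for split, ratio in zip(splits, ratios):
            size = round_fn(ratio * len(class_samples))
            split.extend(class_samples[start:start + size])
            start += size
    return splits


# 8 samples of class 0 followed by 4 samples of class 1.
targets = [0] * 8 + [1] * 4
train, val = balanced_monolabel_split_sketch(targets, num_classes=2, lengths=[0.5, 0.5], seed=0)
```

With `math.floor` as `round_fn`, samples left over after rounding within a class are not assigned to any subset in this sketch; whether the library handles remainders the same way should be checked against its source.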
- torchwrench.utils.data.split.random_split(num_samples_or_indices: int | List[int] | torch.Tensor, lengths: Iterable[float], generator: torch.Generator | None | Literal['default'] | int = None, round_fn: Callable[[float], int] = math.floor) -> list[list[int]]
Generate indices for a random dataset split.
- Args:
  num_samples_or_indices: Total number of samples, or list of indices to split.
  lengths: Ratios of the target splits. Values should be in the range [0, 1].
  generator: Torch Generator or seed to make this function deterministic. Defaults to None.
  round_fn: Function to round ratios to integer sizes. Defaults to math.floor.
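As a rough illustration of how ratio-based lengths and `round_fn` interact, the following standalone sketch shuffles the indices once and slices them by rounded sizes. It is an assumption-laden approximation rather than the library's code: `random_split_sketch` is a hypothetical name, and an integer `seed` replaces the `generator` argument so that the example needs only the standard library.

```python
import math
import random
from typing import Callable, Iterable, List, Optional, Sequence, Union


def random_split_sketch(
    num_samples_or_indices: Union[int, Sequence[int]],
    lengths: Iterable[float],
    seed: Optional[int] = None,  # assumption: stands in for `generator`
    round_fn: Callable[[float], int] = math.floor,
) -> List[List[int]]:
    """Hypothetical sketch: shuffle indices, then cut them into ratio-sized subsets."""
    # Accept either a sample count or an explicit collection of indices.
    if isinstance(num_samples_or_indices, int):
        indices = list(range(num_samples_or_indices))
    else:
        indices = [int(i) for i in num_samples_or_indices]

    rng = random.Random(seed)  # local RNG, so the same seed gives the same split
    rng.shuffle(indices)

    splits: List[List[int]] = []
    start = 0
    for ratio in lengths:
        size = round_fn(ratio * len(indices))  # ratio -> integer subset size
        splits.append(indices[start:start + size])
        start += size
    return splits


# 10 samples with ratios 0.5 / 0.25 / 0.25 and floor rounding give sizes 5, 2, 2.
train, val, test = random_split_sketch(10, [0.5, 0.25, 0.25], seed=0)
```

In this sketch, flooring 2.5 to 2 for the last two subsets leaves one index unassigned; passing a different `round_fn` such as `round` or `math.ceil` changes the sizes, and ratios that round up too aggressively can ask for more samples than exist.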