torchwrench.utils.data package
- class torchwrench.utils.data.AdvancedCollateDict(pad_values: dict[str, Any] | None = None, include_keys: str | Pattern | Iterable[str | Pattern] | None = None, exclude_keys: str | Pattern | Iterable[str | Pattern] | None = None, key_mode: Literal['intersect', 'same', 'union'] = 'same')
Bases: object
Advanced collate object for DataLoader. Merges a list of dicts into a single dict of lists. Values such as audio are padded when a fill value for their key is given in pad_values in __init__.
Example:
>>> collate = AdvancedCollateDict({"audio": 0.0})
>>> loader = DataLoader(..., collate_fn=collate)
>>> next(iter(loader))
{"audio": tensor([[...]]), ...}
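The padding behaviour described above can be sketched in plain Python. This is not torchwrench's implementation: pad_collate is a hypothetical helper, Python lists stand in for tensors, and it assumes padding is applied on the right, up to the longest item in the batch, and that all items share the same keys.

```python
from __future__ import annotations

from typing import Any


def pad_collate(batch: list[dict[str, Any]], pad_values: dict[str, Any]) -> dict[str, list]:
    """Hypothetical sketch: merge a list of dicts into a dict of lists and
    right-pad the sequences stored under keys listed in pad_values."""
    keys = list(batch[0])
    # Assumption: every item has the same keys (as with key_mode='same').
    if any(set(item) != set(keys) for item in batch):
        raise ValueError("All items must have the same keys.")
    out: dict[str, list] = {}
    for key in keys:
        values = [item[key] for item in batch]
        if key in pad_values:
            fill = pad_values[key]
            max_len = max(len(v) for v in values)
            # Right-pad each sequence with the fill value up to the longest one.
            values = [list(v) + [fill] * (max_len - len(v)) for v in values]
        out[key] = values
    return out


batch = [
    {"audio": [0.1, 0.2, 0.3], "label": 1},
    {"audio": [0.4], "label": 0},
]
print(pad_collate(batch, {"audio": 0.0}))
# {'audio': [[0.1, 0.2, 0.3], [0.4, 0.0, 0.0]], 'label': [1, 0]}
```

Keys without an entry in pad_values (here "label") are gathered as-is, which is why a fill value is only needed for variable-length fields.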
- class torchwrench.utils.data.BalancedSampler(indices_per_class: Sequence[Sequence[int]], n_max_iterations: int, shuffle: bool = True, seed: Generator | None | Literal['default'] | int = None)
Bases: Sampler
- class torchwrench.utils.data.CollateDict(key_mode: Literal['intersect', 'same', 'union'] = 'same')
Bases: object
Collate object for DataLoader. Merges a list of dicts into a single dict of lists. No padding is applied.
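The page does not document what each key_mode does. The sketch below assumes the names mean what they suggest: 'same' requires identical keys in every item, 'intersect' keeps only shared keys, and 'union' keeps all keys. collate_dict is a hypothetical helper, and filling missing values with None under 'union' is an additional assumption, not documented behaviour.

```python
from __future__ import annotations

from typing import Any, Literal


def collate_dict(
    batch: list[dict[str, Any]],
    key_mode: Literal["intersect", "same", "union"] = "same",
) -> dict[str, list]:
    """Hypothetical sketch: merge a list of dicts into a dict of lists."""
    if not batch:
        return {}
    key_sets = [set(item) for item in batch]
    if key_mode == "same":
        # Assumed: every item must have exactly the same keys.
        if any(ks != key_sets[0] for ks in key_sets):
            raise ValueError("Items have different keys.")
        keys = list(batch[0])
    elif key_mode == "intersect":
        # Assumed: keep only the keys present in every item.
        common = set.intersection(*key_sets)
        keys = [k for k in batch[0] if k in common]
    elif key_mode == "union":
        # Assumed: keep every key seen; items lacking a key contribute None.
        keys = list(dict.fromkeys(k for item in batch for k in item))
    else:
        raise ValueError(f"Invalid key_mode: {key_mode!r}")
    # No padding: values are simply gathered into one list per key.
    return {k: [item.get(k) for item in batch] for k in keys}


batch = [{"a": 1, "b": 2}, {"a": 3, "c": 4}]
print(collate_dict(batch, "intersect"))  # {'a': [1, 3]}
print(collate_dict(batch, "union"))  # {'a': [1, 3], 'b': [2, None], 'c': [None, 4]}
```

With the default 'same', this batch would raise a ValueError because the two items expose different keys.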
- class torchwrench.utils.data.SubsetCycleSampler(indices: Tensor | Iterable[int], n_max_iterations: int | Literal['inf'] = 'inf', shuffle: bool = True, seed: Generator | None | Literal['default'] | int = None)
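The page gives no description of SubsetCycleSampler. Going only by its name and parameters, one plausible reading is a sampler that cycles over the given indices, optionally reshuffling on each pass, until n_max_iterations indices have been yielded ('inf' meaning it never stops). The generator below is a hypothetical sketch of that reading, not the library's behaviour; subset_cycle is an illustrative name, and random.Random stands in for a torch Generator.

```python
from __future__ import annotations

import itertools
import random
from collections.abc import Iterable, Iterator
from typing import Literal


def subset_cycle(
    indices: Iterable[int],
    n_max_iterations: int | Literal["inf"] = "inf",
    shuffle: bool = True,
    seed: int | None = None,
) -> Iterator[int]:
    """Hypothetical: yield indices cycling over the subset, reshuffling each pass."""
    pool = list(indices)
    if not pool:
        return
    rng = random.Random(seed)

    def passes() -> Iterator[int]:
        # Endless stream made of consecutive passes over the subset.
        while True:
            order = pool[:]
            if shuffle:
                rng.shuffle(order)
            yield from order

    stream = passes()
    if n_max_iterations == "inf":
        yield from stream
    else:
        # Stop after exactly n_max_iterations indices.
        yield from itertools.islice(stream, n_max_iterations)


print(list(subset_cycle([1, 2, 3], n_max_iterations=7, shuffle=False)))
# [1, 2, 3, 1, 2, 3, 1]
```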
- torchwrench.utils.data.balanced_monolabel_split(targets_indices: Tensor | List[int], num_classes: int, lengths: Iterable[float], generator: Generator | None | Literal['default'] | int = None, round_fn: Callable[[float], int] = math.floor) -> list[list[int]]
Generate indices for a random dataset split in which each split keeps the class distribution of the targets.
- Args:
    targets_indices: Class index of each sample, of size (N,).
    num_classes: Number of classes.
    lengths: Ratios of the target splits. Values should be in range [0, 1].
    generator: Torch Generator or seed to make this function deterministic. Defaults to None.
    round_fn: Function used to round ratios to integer sizes. Defaults to math.floor.
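A class-preserving (stratified) split can be sketched by grouping sample indices per class, shuffling within each class, and cutting every class by the same ratios. This is a minimal sketch under stated assumptions, not torchwrench's code: stratified_split is a hypothetical name, and a plain integer seed with random.Random replaces the Generator argument.

```python
from __future__ import annotations

import math
import random
from collections.abc import Callable, Iterable


def stratified_split(
    targets_indices: list[int],
    num_classes: int,
    lengths: Iterable[float],
    seed: int | None = None,
    round_fn: Callable[[float], int] = math.floor,
) -> list[list[int]]:
    """Hypothetical sketch: split sample indices so that each split keeps
    roughly the class proportions of targets_indices."""
    ratios = list(lengths)
    rng = random.Random(seed)
    # Group sample indices by their class index.
    per_class: list[list[int]] = [[] for _ in range(num_classes)]
    for idx, cls in enumerate(targets_indices):
        per_class[cls].append(idx)

    splits: list[list[int]] = [[] for _ in ratios]
    for class_indices in per_class:
        rng.shuffle(class_indices)
        start = 0
        for split, ratio in zip(splits, ratios):
            # Each split receives round_fn(ratio * class_size) samples of this class.
            size = round_fn(ratio * len(class_indices))
            split.extend(class_indices[start:start + size])
            start += size
    return splits


targets = [0] * 10 + [1] * 20
first, second = stratified_split(targets, num_classes=2, lengths=[0.5, 0.5], seed=0)
print(len(first), len(second))  # 15 15
```

Because sizes are rounded per class, with math.floor some samples of a class may be left out of every split in this sketch; each split still holds about a third class 0 and two thirds class 1, like the full target list.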
- torchwrench.utils.data.get_auto_num_cpus() -> int
Returns the number of CPUs available to the current process on Linux-based platforms.
On Windows and macOS, returns the number of logical CPUs on the machine instead. Returns 0 if the number of CPUs cannot be detected.
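The documented behaviour can be approximated with the standard library. A minimal sketch, assuming that "available for the current process" means the CPU affinity mask returned by os.sched_getaffinity, which exists on Linux but not on Windows or macOS; count_available_cpus is a hypothetical name, not torchwrench's implementation.

```python
import os


def count_available_cpus() -> int:
    """Hypothetical sketch of the CPU-count behaviour described above."""
    # Where supported (e.g. Linux), the affinity mask lists the CPUs
    # this process is allowed to run on.
    if hasattr(os, "sched_getaffinity"):
        return len(os.sched_getaffinity(0))
    # Elsewhere, fall back to the logical CPU count; os.cpu_count()
    # returns None when it cannot be determined, which maps to 0.
    return os.cpu_count() or 0


print(count_available_cpus())
```

A typical use is choosing a default num_workers for a DataLoader, where the affinity-based count avoids oversubscribing a process that has been pinned to a subset of cores.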
- torchwrench.utils.data.get_auto_num_gpus() -> int
Returns the number of available GPUs.
Note
This API will NOT poison fork if NVML discovery succeeds. See "Poison fork in multiprocessing" in the PyTorch documentation for more details.
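The note above matches the one in PyTorch's documentation for torch.cuda.device_count, which suggests a comparable count can be read from that function. The sketch below is hypothetical: count_gpus is an illustrative name, and returning 0 when PyTorch is not installed is an assumption made so the snippet runs anywhere.

```python
def count_gpus() -> int:
    """Hypothetical sketch: number of visible CUDA devices, 0 if unavailable."""
    try:
        import torch
    except ImportError:
        # Assumption: treat a missing PyTorch install as "no GPUs".
        return 0
    # torch.cuda.device_count() returns 0 when CUDA is not available.
    return torch.cuda.device_count()


print(count_gpus())
```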
- torchwrench.utils.data.random_split(num_samples_or_indices: int | List[int] | Tensor, lengths: Iterable[float], generator: Generator | None | Literal['default'] | int = None, round_fn: Callable[[float], int] = math.floor) -> list[list[int]]
Generate indices for a random dataset split.
- Args:
    num_samples_or_indices: Total number of samples, or the list of indices to split.
    lengths: Ratios of the target splits. Values should be in range [0, 1].
    generator: Torch Generator or seed to make this function deterministic. Defaults to None.
    round_fn: Function used to round ratios to integer sizes. Defaults to math.floor.
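A ratio-based random split can be sketched by shuffling the indices once and cutting consecutive chunks whose sizes come from round_fn(ratio * total). This is a minimal sketch, not torchwrench's implementation: split_indices is a hypothetical name, tensor input is not handled, and an integer seed with random.Random replaces the Generator argument.

```python
from __future__ import annotations

import math
import random
from collections.abc import Callable, Iterable


def split_indices(
    num_samples_or_indices: int | list[int],
    lengths: Iterable[float],
    seed: int | None = None,
    round_fn: Callable[[float], int] = math.floor,
) -> list[list[int]]:
    """Hypothetical sketch: shuffle the indices, then cut them by ratio."""
    if isinstance(num_samples_or_indices, int):
        indices = list(range(num_samples_or_indices))
    else:
        indices = list(num_samples_or_indices)
    random.Random(seed).shuffle(indices)
    total = len(indices)
    splits: list[list[int]] = []
    start = 0
    for ratio in lengths:
        # Size of this split, rounded to an integer by round_fn.
        size = round_fn(ratio * total)
        splits.append(indices[start:start + size])
        start += size
    return splits


train, val = split_indices(10, [0.8, 0.2], seed=0)
print(len(train), len(val))  # 8 2
```

Because the chunks are taken from one shuffled list, the splits never share an index; with math.floor, any remainder left by rounding is simply not assigned.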
Subpackages
- torchwrench.utils.data.dataset package
  - torchwrench.utils.data.dataset.DatasetSlicer
  - torchwrench.utils.data.dataset.DatasetSlicerWrapper
  - torchwrench.utils.data.dataset.EmptyDataset
  - torchwrench.utils.data.dataset.IterableSubset
  - torchwrench.utils.data.dataset.IterableTransformWrapper
  - torchwrench.utils.data.dataset.IterableWrapper
  - torchwrench.utils.data.dataset.Subset
  - torchwrench.utils.data.dataset.TabularDataset
  - torchwrench.utils.data.dataset.TransformWrapper
  - torchwrench.utils.data.dataset.Wrapper
  - Submodules
    - torchwrench.utils.data.dataset.slicer module
    - torchwrench.utils.data.dataset.tabular module
    - torchwrench.utils.data.dataset.wrapper module
      - torchwrench.utils.data.dataset.wrapper.EmptyDataset
      - torchwrench.utils.data.dataset.wrapper.IterableSubset
      - torchwrench.utils.data.dataset.wrapper.IterableTransformWrapper
      - torchwrench.utils.data.dataset.wrapper.IterableWrapper
      - torchwrench.utils.data.dataset.wrapper.Subset
      - torchwrench.utils.data.dataset.wrapper.TransformWrapper
      - torchwrench.utils.data.dataset.wrapper.Wrapper