Source code for torchwrench.utils.data.sampler
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import itertools
from typing import Iterable, Iterator, List, Literal, Sequence, Union
import torch
from pythonwrench.semver import Version
from torch import Tensor
from torch.utils.data.sampler import Sampler
from torchwrench.nn.functional.transform import (
GeneratorLike,
as_generator,
as_tensor,
shuffled,
)
[docs]
class SubsetSampler(Sampler[int]):
    """Sampler that yields a fixed collection of dataset indices, in the order given."""

    def __init__(self, indices: Union[List[int], Tensor]) -> None:
        """
        Build a sampler restricted to the given indices.

        Args:
            indices: Indices to yield, given as a list of ints or a tensor.
                They are passed through ``as_tensor(..., dtype="long")``.
        """
        # The base constructor receives a single positional ``None`` on torch
        # versions before 2.10.0 and no argument from 2.10.0 onwards.
        is_legacy_torch = Version(str(torch.__version__)) < "2.10.0"
        base_args = (None,) if is_legacy_torch else ()

        index_tensor = as_tensor(indices, dtype="long")

        super().__init__(*base_args)
        self._indices = index_tensor

    def __iter__(self) -> Iterator[int]:
        """Return an iterator over the stored indices as plain Python values."""
        plain_indices = self._indices.tolist()
        return iter(plain_indices)

    def __len__(self) -> int:
        """Return the number of stored indices."""
        n_indices = len(self._indices)
        return n_indices
[docs]
class SubsetCycleSampler(Sampler[int]):
    """Sampler that cycles over a subset of indices, optionally reshuffling between passes."""

    def __init__(
        self,
        indices: Union[Tensor, Iterable[int]],
        n_max_iterations: Union[int, Literal["inf"]] = "inf",
        shuffle: bool = True,
        seed: GeneratorLike = None,
    ) -> None:
        """SubsetCycleSampler that cycles over indices indefinitely or until a maximal number of iterations is reached.

        Args:
            indices: The list of indices of the items.
            n_max_iterations: The maximal number of iterations.
                If "inf", any call to __len__ will raise a NotImplementedError exception.
                defaults to "inf".
            shuffle: If True, shuffle the indices at every len(indices).
                defaults to True.
            seed: Optional seed or generator used to shuffle indices.
                defaults to None.
        """
        # The base constructor receives a single positional ``None`` on torch
        # versions before 2.10.0 and no argument from 2.10.0 onwards.
        if Version(str(torch.__version__)) < "2.10.0":
            args = (None,)
        else:
            args = ()

        indices = as_tensor(indices, dtype="long")
        generator = as_generator(seed)

        super().__init__(*args)
        self._indices = indices
        self._n_max_iterations = n_max_iterations
        self._shuffle = shuffle
        self._generator = generator

        # Start from a shuffled order (no-op when shuffle is False).
        self._shuffle_indices()

    def __iter__(self) -> Iterator[int]:
        """Yield indices pass after pass, stopping after ``n_max_iterations`` items if it is an int.

        The current order is re-read from ``self._indices`` at the start of
        every pass. ``itertools.cycle`` is deliberately not used here: it
        replays a cached copy of the first pass, so a reshuffle performed
        during iteration would never change the order of later passes.
        """
        n_max = self._n_max_iterations
        limit = n_max if isinstance(n_max, int) else None

        n_yielded = 0
        while True:
            order = self._indices.tolist()
            if not order:
                # Nothing to cycle over: yield nothing instead of looping forever.
                return

            last_pos = len(order) - 1
            for pos, idx in enumerate(order):
                if pos == last_pos:
                    # Reshuffle when the last item of the pass is reached; the
                    # new order is picked up by the next pass (or the next
                    # call to __iter__).
                    self._shuffle_indices()
                if limit is not None and n_yielded >= limit:
                    return
                yield idx
                n_yielded += 1

    def __len__(self) -> int:
        """Return ``n_max_iterations`` when it is an int.

        Raises:
            NotImplementedError: If ``n_max_iterations`` is "inf".
            ValueError: If ``n_max_iterations`` is neither an int nor "inf".
        """
        if isinstance(self._n_max_iterations, int):
            return self._n_max_iterations
        elif self._n_max_iterations == "inf":
            msg = "Infinite sampler does not have __len__() method."
            raise NotImplementedError(msg)
        else:
            msg = f"Invalid argument {self._n_max_iterations=}."
            raise ValueError(msg)

    def _shuffle_indices(self) -> None:
        """Replace ``self._indices`` by a shuffled version when shuffling is enabled."""
        if not self._shuffle:
            return None
        self._indices = shuffled(self._indices, generator=self._generator)
[docs]
class BalancedSampler(Sampler[int]):
    """Sampler that visits every class once per round, so all classes are drawn equally often."""

    def __init__(
        self,
        indices_per_class: Sequence[Sequence[int]],
        n_max_iterations: int,
        shuffle: bool = True,
        seed: GeneratorLike = None,
    ) -> None:
        """BalancedSampler class.

        Args:
            indices_per_class: List of indices per class index.
                Item ``c`` holds the sample indices of class ``c``; every class needs at least one index.
            n_max_iterations: The number of indices yielded by one iteration over the sampler.
                Also the value returned by __len__.
            shuffle: If True, reshuffle the visiting order inside each class every
                ``max(len(indices))`` yielded samples, and reshuffle the class order after each round of classes.
                The initial class order is a random permutation even when False.
                defaults to True.
            seed: Optional seed or generator used to shuffle indices.
                defaults to None.

        Raises:
            ValueError: If ``indices_per_class`` contains no class.
            RuntimeError: If a class has no indices.
        """
        # The base constructor receives a single positional ``None`` on torch
        # versions before 2.10.0 and no argument from 2.10.0 onwards.
        if Version(str(torch.__version__)) < "2.10.0":
            args = (None,)
        else:
            args = ()

        indices_per_class = [list(indices) for indices in indices_per_class]
        if not indices_per_class:
            # Without this check max() below fails with an opaque message.
            msg = "Invalid argument indices_per_class: expected at least one class."
            raise ValueError(msg)

        for cls_idx, indices in enumerate(indices_per_class):
            if len(indices) == 0:
                msg = f"Found a class index {cls_idx} without any indices."
                raise RuntimeError(msg)

        max_idx = max(len(indices) for indices in indices_per_class)
        # Per class: a permutation of positions into that class's index list.
        pointers_per_class = [
            list(range(len(indices))) for indices in indices_per_class
        ]
        # Per class: position of the next pointer to use.
        local_idx_per_class = [0 for _ in indices_per_class]
        generator = as_generator(seed)

        super().__init__(*args)
        self._cls_to_sample_indices = indices_per_class
        self._n_max_iterations = n_max_iterations
        self._shuffle = shuffle
        self._generator = generator
        self._max_idx = max_idx
        self._pointers_per_class = pointers_per_class
        self._local_idx_per_class = local_idx_per_class

        # Start from shuffled pointers (no-op when shuffle is False).
        self._shuffle_indices()

    def __iter__(self) -> Iterator[int]:
        """Yield ``n_max_iterations`` sample indices, taking one sample per class in each round.

        The per-class read positions are stored on the instance, so a new
        iteration continues where the previous one stopped inside each class.
        """
        n_classes = len(self._cls_to_sample_indices)
        if n_classes == 0:
            # No class to draw from: yield nothing instead of looping forever.
            return

        cls_order = torch.randperm(n_classes, generator=self._generator)
        n_yielded = 0

        while True:
            for slot in range(n_classes):
                if n_yielded >= self._n_max_iterations:
                    return

                if n_yielded % self._max_idx == self._max_idx - 1:
                    self._shuffle_indices()

                cls_idx = int(cls_order[slot].item())
                sample_indices = self._cls_to_sample_indices[cls_idx]
                pointers = self._pointers_per_class[cls_idx]
                pointer_idx = self._local_idx_per_class[cls_idx]
                sample_idx = sample_indices[pointers[pointer_idx]]
                yield sample_idx

                # Advance this class's read position, wrapping at the end.
                self._local_idx_per_class[cls_idx] = (pointer_idx + 1) % len(pointers)
                n_yielded += 1

            # A full round over all classes is done: draw a new class order for
            # the next round (every round, not only the first one).
            if self._shuffle:
                cls_order = shuffled(cls_order, generator=self._generator)

    def __len__(self) -> int:
        """Return ``n_max_iterations``."""
        return self._n_max_iterations

    def _shuffle_indices(self) -> None:
        """Shuffle the pointer list of every class when shuffling is enabled."""
        if not self._shuffle:
            return None
        for i, pointers in enumerate(self._pointers_per_class):
            pointers_pt = as_tensor(pointers)
            pointers_pt = shuffled(pointers_pt, generator=self._generator)
            self._pointers_per_class[i] = pointers_pt.tolist()