Source code for torchwrench.nn.functional.cropping
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Iterable, List, Literal, Union, get_args
import torch
from torch import Tensor
from torchwrench.core.make import GeneratorLike, as_generator
# Alignment modes accepted by the `align` / `aligns` arguments of the crop functions in this module.
CropAlign = Literal["left", "right", "center", "random"]
[docs]
def crop_dim(
    x: Tensor,
    target_length: int,
    *,
    dim: int = -1,
    align: CropAlign = "left",
    generator: GeneratorLike = None,
) -> Tensor:
    """Crop a tensor along one dimension.

    Single-dimension convenience wrapper around ``crop_dims``.

    Args:
        x: Tensor to crop, of shape (..., D, ...) where D is the size of the dim-th dimension.
        target_length: Target length for the cropped dimension.
        dim: Dimension to crop. Defaults to -1.
        align: Alignment of the crop ("left", "right", "center" or "random").
        generator: Random generator forwarded to ``crop_dims`` (used when align is "random").

    Returns:
        The result of ``crop_dims`` applied to the single dimension ``dim``.
    """
    # Wrap each scalar argument in a one-element list, as expected by crop_dims.
    lengths = [target_length]
    axes = [dim]
    alignments = [align]
    return crop_dims(x, lengths, dims=axes, aligns=alignments, generator=generator)
[docs]
def crop_dims(
    x: Tensor,
    target_lengths: Iterable[int],
    *,
    dims: Union[Iterable[int], Literal["auto"], None] = "auto",
    aligns: Union[CropAlign, Iterable[CropAlign]] = "left",
    generator: GeneratorLike = None,
) -> Tensor:
    """Crop a tensor along multiple dimensions.

    A dimension whose current size is already smaller than or equal to its
    target length is left untouched (this function never pads).

    Args:
        x: Tensor to crop, with N dims.
        target_lengths: Target length for each cropped dimension. Has size M <= N.
        dims: Dimension index for each target length. Must be of size M.
            If "auto" or None, the M last dimensions are used.
        aligns: Alignment, or list of M alignments (one per dimension).
            "left" keeps the first elements, "right" keeps the last ones,
            "center" keeps the middle (when the number of removed elements is odd,
            the extra element is removed on the left side) and "random" picks a
            uniformly random offset among all valid ones.
        generator: Random generator used when an alignment is "random".
            It is converted with ``as_generator`` before use.

    Returns:
        Cropped tensor with N dims.

    Raises:
        ValueError: If the number of target lengths or of alignments differs from
            the number of dimensions, or if a dimension that needs cropping has
            an unknown alignment.
    """
    target_lengths = list(target_lengths)

    aligns_lst: List[CropAlign]
    if isinstance(aligns, str):
        # A single alignment is broadcast to every cropped dimension.
        aligns_lst = [aligns] * len(target_lengths)
    else:
        aligns_lst = list(aligns)
    del aligns

    if dims == "auto" or dims is None:
        # Default to the last M dimensions, expressed as negative indices.
        dims = list(range(-len(target_lengths), 0))
    else:
        dims = list(dims)

    generator = as_generator(generator)

    if len(target_lengths) != len(dims):
        msg = f"Invalid number of targets lengths ({len(target_lengths)}) with the number of dimensions ({len(dims)})."
        raise ValueError(msg)

    if len(aligns_lst) != len(dims):
        msg = f"Invalid number of aligns ({len(aligns_lst)}) with the number of dimensions ({len(dims)})."
        raise ValueError(msg)

    # Start from a full slice on every dimension; only cropped dims are overridden.
    slices = [slice(None)] * len(x.shape)

    for target_length, dim, align in zip(target_lengths, dims, aligns_lst):
        size = x.shape[dim]
        if size <= target_length:
            # Nothing to remove on this dimension.
            continue

        # Number of elements to remove; strictly positive at this point.
        diff = size - target_length

        if align == "left":
            start = 0
            end = target_length
        elif align == "right":
            start = diff
            end = None
        elif align == "center":
            # Ceil division: with an odd diff, one more element is dropped on the left.
            start = diff // 2 + diff % 2
            end = start + target_length
        elif align == "random":
            # torch.randint excludes `high`; use diff + 1 so that the last valid
            # offset (the right-aligned window) can also be drawn.
            start = int(
                torch.randint(
                    low=0,
                    high=diff + 1,
                    size=(),
                    generator=generator,
                ).item()
            )
            end = start + target_length
        else:
            msg = f"Invalid argument {align=}. (expected one of {get_args(CropAlign)})"
            raise ValueError(msg)

        slices[dim] = slice(start, end)

    return x[tuple(slices)]