torchwrench.nn.functional.cropping module

torchwrench.nn.functional.cropping.crop_dim(x: Tensor, target_length: int, *, dim: int = -1, align: 'left' | 'right' | 'center' | 'random' = 'left', generator: Generator | None | 'default' | int = None) → Tensor

Generic function to crop a tensor along a single dimension.

Args:

x: Tensor to crop, with N dims and shape (…, D, …), where D is the size of the dim-th dimension.
target_length: Target length along dim.
dim: Axis/dim to crop. Defaults to -1.
align: Alignment of the crop: 'left', 'right', 'center' or 'random'. Defaults to 'left'.
generator: Random generator used when align is “random”.

Returns:

Cropped tensor with N dims and shape (…, target_length, …).
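
To make the align options concrete, here is a minimal pure-Python sketch of how each alignment could translate into the start index of the kept window. This is not torchwrench's implementation: the helper crop_start is hypothetical, and it assumes that 'left' keeps the first target_length elements, 'right' keeps the last ones, 'center' splits the excess with floor rounding, 'random' draws a uniform offset, and a target longer than the input is rejected. The library's actual rounding, error handling and generator handling may differ.

```python
import random
from typing import Literal, Optional

Align = Literal["left", "right", "center", "random"]


def crop_start(
    size: int,
    target_length: int,
    align: Align = "left",
    rng: Optional[random.Random] = None,
) -> int:
    """Hypothetical helper: start index of the kept window along one dim."""
    if target_length > size:
        # Assumption: cropping to a longer length is treated as an error here.
        raise ValueError("target_length must not exceed the dimension size")
    excess = size - target_length  # number of elements to drop
    if align == "left":
        return 0
    if align == "right":
        return excess
    if align == "center":
        return excess // 2  # assumption: floor rounding for odd excess
    if align == "random":
        rng = rng or random.Random()
        return rng.randint(0, excess)  # inclusive on both ends
    raise ValueError(f"unknown align: {align!r}")


# Crop a 1-D sequence of length 10 down to 4 elements.
data = list(range(10))
for align in ("left", "right", "center"):
    start = crop_start(len(data), 4, align)
    print(align, data[start:start + 4])
# left [0, 1, 2, 3]
# right [6, 7, 8, 9]
# center [3, 4, 5, 6]
```

On a tensor, a start index like this would typically be applied with x.narrow(dim, start, target_length), which returns a view of the selected window.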

torchwrench.nn.functional.cropping.crop_dims(x: Tensor, target_lengths: Iterable[int], *, dims: Iterable[int] | 'auto' | None = 'auto', aligns: 'left' | 'right' | 'center' | 'random' | Iterable['left' | 'right' | 'center' | 'random'] = 'left', generator: Generator | None | 'default' | int = None) → Tensor

Generic function to crop a tensor along multiple dimensions.

Args:

x: Tensor to crop, with N dims.
target_lengths: List of target lengths, one per cropped dimension. The list has size M <= N.
dims: Dimension for each target length. Must be of size M. If “auto”, uses the last M dimensions.
aligns: A single alignment applied to every cropped dimension, or a list of M alignments, one per cropped dimension.
generator: Random generator used for “random” alignment.

Returns:

Cropped tensor of N dims.
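
The relationship between target_lengths, dims and the output shape can be sketched in plain Python. The helpers resolve_dims and cropped_shape below are hypothetical and only model shape bookkeeping, under the assumption that “auto” pairs the M target lengths with the last M dimensions in order, as described above. The behavior of dims=None is not covered, since the documentation does not describe it.

```python
from typing import List, Sequence, Tuple, Union


def resolve_dims(
    ndim: int,
    target_lengths: Sequence[int],
    dims: Union[Sequence[int], str] = "auto",
) -> List[int]:
    """Hypothetical helper: non-negative dim index for each target length."""
    m = len(target_lengths)
    if m > ndim:
        raise ValueError("more target lengths than tensor dimensions")
    if isinstance(dims, str):
        if dims != "auto":
            raise ValueError(f"unknown dims option: {dims!r}")
        return list(range(ndim - m, ndim))  # the last M dimensions
    if len(dims) != m:
        raise ValueError("dims must have one entry per target length")
    resolved = []
    for d in dims:
        if not -ndim <= d < ndim:
            raise ValueError(f"dim {d} out of range for {ndim} dims")
        resolved.append(d % ndim)  # normalize negative indices
    return resolved


def cropped_shape(
    shape: Sequence[int],
    target_lengths: Sequence[int],
    dims: Union[Sequence[int], str] = "auto",
) -> Tuple[int, ...]:
    """Hypothetical helper: shape after cropping the selected dims."""
    out = list(shape)
    for d, length in zip(resolve_dims(len(shape), target_lengths, dims), target_lengths):
        out[d] = length
    return tuple(out)


print(cropped_shape((8, 3, 32, 32), [24, 16]))               # (8, 3, 24, 16)
print(cropped_shape((8, 3, 32, 32), [4, 16], dims=[0, -1]))  # (4, 3, 32, 16)
```

In both calls the number of dimensions is unchanged; only the sizes of the selected dimensions shrink to the requested lengths.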