# Source code for torchwrench.nn.modules.tensor

#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""Module versions of tensor functions that do not already exist in PyTorch."""

from typing import List, Optional, Sequence, Tuple, Union, overload

import torch
from pythonwrench.collections import dump_dict
from pythonwrench.semver import Version
from torch import Tensor
from torch.nn import functional as F
from torch.types import Number

from torchwrench.nn.functional.make import DTypeLike, as_dtype
from torchwrench.utils import return_types

from .module import Module


class Abs(Module):
    """
    Module version of :func:`~torch.abs`.

    Computes the element-wise absolute value of the input tensor.
    """

    def forward(self, x: Tensor) -> Tensor:
        return torch.abs(x)
class Angle(Module):
    """
    Module version of :func:`~torch.angle`.

    Computes the element-wise angle (argument) of the input tensor.
    """

    def forward(self, x: Tensor) -> Tensor:
        return torch.angle(x)
class Exp(Module):
    """
    Module version of :func:`~torch.exp`.

    Computes ``e ** x`` element-wise.
    """

    def forward(self, x: Tensor) -> Tensor:
        return torch.exp(x)
class Exp2(Module):
    """
    Module version of :func:`~torch.exp2`.

    Computes ``2 ** x`` element-wise.
    """

    def forward(self, x: Tensor) -> Tensor:
        return torch.exp2(x)
class FFT(Module):
    """
    Module version of :func:`~torch.fft.fft`.

    Applies a one-dimensional discrete Fourier transform with the default
    arguments of :func:`torch.fft.fft`.
    """

    def forward(self, x: Tensor) -> Tensor:
        spectrum = torch.fft.fft(x)
        return spectrum
class IFFT(Module):
    """
    Module version of :func:`~torch.fft.ifft`.

    Applies a one-dimensional inverse discrete Fourier transform with the
    default arguments of :func:`torch.fft.ifft`.
    """

    def forward(self, x: Tensor) -> Tensor:
        signal = torch.fft.ifft(x)
        return signal
class Imag(Module):
    """
    Module version of :func:`~torch.Tensor.imag`.
    """

    def __init__(self, *, return_zeros: bool = False) -> None:
        """Return the imaginary part of a complex tensor.

        Args:
            return_zeros: If True and the input is not a complex tensor, the module will return a tensor
                of same shape containing zeros. If False and the input is not a complex tensor, raises
                the default RuntimeError of PyTorch.
        """
        super().__init__()
        self.return_zeros = return_zeros

    def forward(self, x: Tensor) -> Tensor:
        # ``x.imag`` is not defined for real-valued tensors, so optionally
        # substitute a zero tensor shaped like the input instead of raising.
        if self.return_zeros and not x.is_complex():
            return torch.zeros_like(x)
        return x.imag

    def extra_repr(self) -> str:
        # Expose the configuration in ``repr(module)`` like the other configurable modules.
        return f"return_zeros={self.return_zeros}"
class Interpolate(Module):
    """
    Module version of :func:`~torch.nn.functional.interpolate`.

    All constructor arguments are stored and forwarded to
    :func:`torch.nn.functional.interpolate` on every call.
    """

    def __init__(
        self,
        size: Union[int, Tuple[int, ...], None] = None,
        scale_factor: Union[float, Tuple[float, ...], None] = None,
        mode: str = "nearest",
        align_corners: Optional[bool] = None,
        recompute_scale_factor: Optional[bool] = None,
        antialias: bool = False,
    ) -> None:
        super().__init__()
        self.size = size
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners
        self.recompute_scale_factor = recompute_scale_factor
        self.antialias = antialias

    def forward(self, x: Tensor) -> Tensor:
        # ``recompute_scale_factor`` and ``antialias`` are only forwarded on
        # torch >= 2.0; on older versions they are silently not passed.
        version_specific = (
            dict(
                recompute_scale_factor=self.recompute_scale_factor,
                antialias=self.antialias,
            )
            if Version(torch.__version__) >= Version("2.0.0")
            else {}
        )
        return F.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
            **version_specific,
        )
class Log(Module):
    """
    Module version of :func:`~torch.log`.

    Computes the element-wise natural logarithm.
    """

    def forward(self, x: Tensor) -> Tensor:
        return torch.log(x)
class Log10(Module):
    """
    Module version of :func:`~torch.log10`.

    Computes the element-wise base-10 logarithm.
    """

    def forward(self, x: Tensor) -> Tensor:
        return torch.log10(x)
class Log2(Module):
    """
    Module version of :func:`~torch.log2`.

    Computes the element-wise base-2 logarithm.
    """

    def forward(self, x: Tensor) -> Tensor:
        return torch.log2(x)
class Max(Module):
    """
    Module version of :func:`~torch.max`.

    When ``dim`` is None the maximum is taken over all elements, and the
    returned index comes from ``argmax`` on the whole input (i.e. an index
    into the flattened tensor).
    """

    def __init__(
        self,
        dim: Optional[int] = None,
        keepdim: bool = False,
        *,
        return_values: bool = True,
        return_indices: Optional[bool] = None,
    ) -> None:
        """
        Args:
            dim: Dimension to reduce. If None, reduces over all elements.
            keepdim: Keep the reduced dimension. Only allowed when ``dim`` is not None.
            return_values: If True, output the maximum values.
            return_indices: If True, output the indices of the maxima.
                If None, defaults to ``dim is not None``.

        Raises:
            ValueError: If both ``return_values`` and ``return_indices`` are False,
                or if ``keepdim`` is True while ``dim`` is None.
        """
        # Indices are only requested by default for a per-dimension reduction.
        return_indices = (dim is not None) if return_indices is None else return_indices

        if not (return_values or return_indices):
            msg = f"Invalid combinaison of arguments {return_values=} and {return_indices=}. (at least one of them must be True)"
            raise ValueError(msg)
        if dim is None and keepdim:
            msg = f"Invalid combinaison of arguments {dim=} and {keepdim=}. (expected dim is not None or keepdim=False)"
            raise ValueError(msg)

        super().__init__()
        self.dim = dim
        self.return_values = return_values
        self.return_indices = return_indices
        self.keepdim = keepdim

    def forward(self, x: Tensor) -> Union[Tensor, return_types.max]:
        if self.dim is not None:
            result = x.max(dim=self.dim, keepdim=self.keepdim)
        else:
            flat_index = x.argmax()
            result = return_types.max([x.flatten()[flat_index], flat_index])

        if self.return_values and self.return_indices:
            return result  # type: ignore
        if self.return_values:
            return result.values
        if self.return_indices:
            return result.indices

        # Only reachable if the flags were mutated after construction.
        msg = f"Invalid combinaison of arguments {self.return_values=} and {self.return_indices=}. (at least one of them must be True)"
        raise ValueError(msg)

    def extra_repr(self) -> str:
        return dump_dict(
            dict(
                dim=self.dim,
                return_values=self.return_values,
                return_indices=self.return_indices,
                keepdim=self.keepdim,
            ),
        )
class Mean(Module):
    """
    Module version of :func:`~torch.mean`.
    """

    def __init__(
        self,
        dim: Optional[int] = None,
        keepdim: bool = False,
        dtype: DTypeLike = None,
    ) -> None:
        """
        Args:
            dim: Dimension to reduce. If None, reduces over all elements.
            keepdim: If True, reduced dimensions are kept with size 1.
            dtype: Optional output dtype, converted with ``as_dtype`` at each forward call.
        """
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim
        self.dtype = dtype

    def forward(self, x: Tensor) -> Tensor:
        dtype = as_dtype(self.dtype)
        if (Version(torch.__version__) >= Version("2.0.0")) or (self.dim is not None):
            return x.mean(dim=self.dim, keepdim=self.keepdim, dtype=dtype)  # type: ignore

        # Fallback for torch < 2.0: this code never passes ``dim=None`` to
        # ``Tensor.mean`` there, so reduce over all elements explicitly.
        result = x.mean(dtype=dtype)
        if self.keepdim:
            # Match ``mean(dim=None, keepdim=True)``: every dimension is kept
            # with size 1. Reshaping (instead of rebuilding via ``.item()``)
            # keeps the autograd graph and the result dtype intact.
            result = result.reshape((1,) * x.ndim)
        return result

    def extra_repr(self) -> str:
        return dump_dict(
            dict(
                dim=self.dim,
                keepdim=self.keepdim,
                dtype=self.dtype,
            ),
            ignore_lst=(None,),
        )
class Min(Module):
    """
    Module version of :func:`~torch.min`.

    When ``dim`` is None the minimum is taken over all elements, and the
    returned index comes from ``argmin`` on the whole input (i.e. an index
    into the flattened tensor).
    """

    def __init__(
        self,
        dim: Optional[int] = None,
        keepdim: bool = False,
        *,
        return_values: bool = True,
        return_indices: Optional[bool] = None,
    ) -> None:
        """
        Args:
            dim: Dimension to reduce. If None, reduces over all elements.
            keepdim: Keep the reduced dimension. Only allowed when ``dim`` is not None.
            return_values: If True, output the minimum values.
            return_indices: If True, output the indices of the minima.
                If None, defaults to ``dim is not None``.

        Raises:
            ValueError: If both ``return_values`` and ``return_indices`` are False,
                or if ``keepdim`` is True while ``dim`` is None.
        """
        # Indices are only requested by default for a per-dimension reduction.
        return_indices = (dim is not None) if return_indices is None else return_indices

        if not (return_values or return_indices):
            msg = f"Invalid combinaison of arguments {return_values=} and {return_indices=}. (at least one of them must be True)"
            raise ValueError(msg)
        if dim is None and keepdim:
            msg = f"Invalid combinaison of arguments {dim=} and {keepdim=}. (expected dim is not None or keepdim=False)"
            raise ValueError(msg)

        super().__init__()
        self.dim = dim
        self.return_values = return_values
        self.return_indices = return_indices
        self.keepdim = keepdim

    def forward(self, x: Tensor) -> Union[Tensor, return_types.min]:
        if self.dim is not None:
            result = x.min(dim=self.dim, keepdim=self.keepdim)
        else:
            flat_index = x.argmin()
            result = return_types.min([x.flatten()[flat_index], flat_index])

        if self.return_values and self.return_indices:
            return result  # type: ignore
        if self.return_values:
            return result.values
        if self.return_indices:
            return result.indices

        # Only reachable if the flags were mutated after construction.
        msg = f"Invalid combinaison of arguments {self.return_values=} and {self.return_indices=}. (at least one of them must be True)"
        raise ValueError(msg)

    def extra_repr(self) -> str:
        return dump_dict(
            dict(
                dim=self.dim,
                return_values=self.return_values,
                return_indices=self.return_indices,
                keepdim=self.keepdim,
            ),
        )
class Normalize(Module):
    """
    Module version of :func:`~torch.nn.functional.normalize`.

    Divides the input by its ``p``-norm along ``dim``; ``eps`` is forwarded
    to :func:`torch.nn.functional.normalize` to avoid division by zero.
    """

    def __init__(
        self,
        p: float = 2.0,
        dim: int = 1,
        eps: float = 1e-12,
    ) -> None:
        super().__init__()
        self.p = p
        self.dim = dim
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.normalize(x, p=self.p, dim=self.dim, eps=self.eps)

    def extra_repr(self) -> str:
        return dump_dict(
            dict(
                p=self.p,
                dim=self.dim,
                eps=self.eps,
            )
        )
class Permute(Module):
    """
    Module version of :func:`~torch.permute`.

    The target dimension order is given as separate integer arguments,
    e.g. ``Permute(2, 0, 1)``.
    """

    def __init__(self, *args: int) -> None:
        super().__init__()
        self.dims = tuple(args)

    def forward(self, x: Tensor) -> Tensor:
        order = self.dims
        return x.permute(order)

    def extra_repr(self) -> str:
        return dump_dict(
            dict(dims=self.dims),
            fmt="{value}",
        )
class Pow(Module):
    """
    Module version of :func:`~torch.Tensor.pow`.

    Raises the input to ``exponent``, which may be a Python number or a tensor.
    """

    def __init__(self, exponent: Union[Number, Tensor]) -> None:
        super().__init__()
        self.exponent = exponent

    def forward(self, x: Tensor) -> Tensor:
        return torch.pow(x, self.exponent)

    def extra_repr(self) -> str:
        return dump_dict(exponent=self.exponent)
class Real(Module):
    """
    Module version of :func:`~torch.Tensor.real`.

    Returns the real part of the input tensor.
    """

    def forward(self, x: Tensor) -> Tensor:
        return torch.real(x)
class Repeat(Module):
    """
    Module version of :func:`~torch.repeat`.

    The number of repetitions per dimension is given as separate integer
    arguments, e.g. ``Repeat(2, 1)``.
    """

    def __init__(self, *repeats: int) -> None:
        super().__init__()
        self.repeats = repeats

    def forward(self, x: Tensor) -> Tensor:
        counts = self.repeats
        return x.repeat(counts)

    def extra_repr(self) -> str:
        return dump_dict(repeats=self.repeats)
class RepeatInterleave(Module):
    """
    Module version of :func:`~torch.repeat_interleave`.

    Repeats each element along ``dim``; ``output_size`` is forwarded unchanged.
    """

    def __init__(
        self,
        repeats: Union[int, Tensor],
        dim: int,
        output_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.repeats = repeats
        self.dim = dim
        self.output_size = output_size

    def forward(self, x: Tensor) -> Tensor:
        return torch.repeat_interleave(
            x,
            self.repeats,
            dim=self.dim,
            output_size=self.output_size,
        )

    def extra_repr(self) -> str:
        return dump_dict(
            dict(
                repeats=self.repeats,
                dim=self.dim,
                output_size=self.output_size,
            ),
            ignore_lst=(None,),
        )
class Reshape(Module):
    """
    Module version of :func:`~torch.reshape`.

    The target shape is given as separate integer arguments, e.g. ``Reshape(-1, 4)``.
    """

    def __init__(self, *shape: int) -> None:
        super().__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        return torch.reshape(x, self.shape)

    def extra_repr(self) -> str:
        return dump_dict(
            dict(shape=self.shape),
        )
class Sort(Module):
    """
    Module version of :func:`~torch.Tensor.sort`.
    """

    def __init__(
        self,
        dim: int = -1,
        descending: bool = False,
        *,
        return_values: bool = True,
        return_indices: bool = True,
    ) -> None:
        """
        Args:
            dim: Dimension along which to sort.
            descending: If True, sort in descending order.
            return_values: If True, output the sorted values.
            return_indices: If True, output the indices that sort the input.

        Raises:
            ValueError: If both ``return_values`` and ``return_indices`` are False.
        """
        if not return_values and not return_indices:
            msg = f"Invalid combinaison of arguments {return_values=} and {return_indices=}. (at least one of them must be True)"
            raise ValueError(msg)

        super().__init__()
        self.dim = dim
        self.descending = descending
        self.return_values = return_values
        self.return_indices = return_indices

    def forward(self, x: Tensor) -> Union[return_types.sort, Tensor]:
        result = x.sort(dim=self.dim, descending=self.descending)
        # Re-wrap torch's result into this package's ``return_types.sort``.
        result = return_types.sort(result)

        if self.return_values and self.return_indices:
            return result
        elif self.return_values:
            return result.values
        elif self.return_indices:
            return result.indices
        else:
            # Only reachable if the flags were mutated after construction.
            msg = f"Invalid combinaison of arguments {self.return_values=} and {self.return_indices=}. (at least one of them must be True)"
            raise ValueError(msg)

    def extra_repr(self) -> str:
        # Show the configuration in ``repr(module)``, consistent with Max/Min.
        return dump_dict(
            dict(
                dim=self.dim,
                descending=self.descending,
                return_values=self.return_values,
                return_indices=self.return_indices,
            ),
        )
class TensorTo(Module):
    """
    Module version of :func:`~torch.Tensor.to`.

    Any keyword arguments given to the constructor (e.g. ``dtype``, ``device``)
    are stored and passed to ``Tensor.to`` on every call.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__()
        self.kwargs = kwargs

    def forward(self, x: Tensor) -> Tensor:
        to_kwargs = self.kwargs
        return x.to(**to_kwargs)

    def extra_repr(self) -> str:
        return dump_dict(self.kwargs)
class ToList(Module):
    """
    Module version of :func:`~torch.Tensor.tolist`.

    Converts the input tensor into (nested) Python lists.
    """

    def forward(self, x: Tensor) -> List:
        as_list = x.tolist()
        return as_list
class Transpose(Module):
    """
    Module version of :func:`~torch.transpose`.
    """

    def __init__(self, dim0: int, dim1: int, copy: bool = False) -> None:
        """
        Args:
            dim0: First dimension to swap.
            dim1: Second dimension to swap.
            copy: If True, use ``torch.transpose_copy`` instead of ``torch.transpose``.
                Forward raises ValueError when the installed torch has no ``transpose_copy``.
        """
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1
        self.copy = copy

    def forward(self, x: Tensor) -> Tensor:
        if self.copy and not hasattr(torch, "transpose_copy"):
            msg = f"Invalid argument {self.copy=} in torch {torch.__version__}."
            raise ValueError(msg)

        if self.copy:
            return torch.transpose_copy(x, self.dim0, self.dim1)  # type: ignore
        else:
            return torch.transpose(x, self.dim0, self.dim1)

    def extra_repr(self) -> str:
        # Positional-style repr ("dim0, dim1"); the copy flag is only shown when
        # enabled so that copying and non-copying modules are distinguishable.
        text = f"{self.dim0}, {self.dim1}"
        if self.copy:
            text += f", copy={self.copy}"
        return text
class View(Module):
    """
    Module version of :meth:`~torch.Tensor.view`.

    The constructor arguments are stored unchanged and unpacked into
    ``Tensor.view`` on every call, so the module can be built from a target
    dtype, a shape sequence, or the shape as separate integers.
    """

    @overload
    def __init__(self, dtype: torch.dtype, /) -> None: ...

    @overload
    def __init__(self, size: Sequence[int], /) -> None: ...

    @overload
    def __init__(self, *size: int) -> None: ...

    def __init__(self, *args) -> None:
        super().__init__()
        self.args = args

    def forward(self, x: Tensor) -> Tensor:
        view_args = self.args
        return x.view(*view_args)