#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Module versions of tensor functions that do not already exists in PyTorch."""
from typing import List, Optional, Sequence, Tuple, Union, overload
import torch
from pythonwrench.collections import dump_dict
from pythonwrench.semver import Version
from torch import Tensor
from torch.nn import functional as F
from torch.types import Number
from torchwrench.nn.functional.make import DTypeLike, as_dtype
from torchwrench.utils import return_types
from .module import Module
[docs]
class Abs(Module):
    """
    Module version of :func:`~torch.abs`.

    Computes the element-wise absolute value of its input.
    """

    def forward(self, x: Tensor) -> Tensor:
        # Functional form of ``x.abs()``.
        return torch.abs(x)
[docs]
class Angle(Module):
    """
    Module version of :func:`~torch.angle`.

    Computes the element-wise angle (argument) of its input.
    """

    def forward(self, x: Tensor) -> Tensor:
        # Functional form of ``x.angle()``.
        return torch.angle(x)
[docs]
class Exp(Module):
    """
    Module version of :func:`~torch.exp`.

    Computes the element-wise natural exponential of its input.
    """

    def forward(self, x: Tensor) -> Tensor:
        # Functional form of ``x.exp()``.
        return torch.exp(x)
[docs]
class Exp2(Module):
    """
    Module version of :func:`~torch.exp2`.

    Computes ``2 ** x`` element-wise.
    """

    def forward(self, x: Tensor) -> Tensor:
        # Functional form of ``x.exp2()``.
        return torch.exp2(x)
[docs]
class FFT(Module):
    """
    Module version of :func:`~torch.fft.fft`.

    Computes the one-dimensional discrete Fourier transform of its input.
    With no constructor arguments it behaves exactly like ``torch.fft.fft(x)``.
    """

    def __init__(
        self,
        n: Optional[int] = None,
        dim: int = -1,
        norm: Optional[str] = None,
    ) -> None:
        """
        Args:
            n: Optional signal length. If given, the input is zero-padded or
                trimmed to this length along ``dim`` before the transform.
            dim: Dimension along which the FFT is computed. defaults to -1.
            norm: Normalization mode forwarded to :func:`torch.fft.fft`
                ("forward", "backward" or "ortho"). None uses PyTorch's default.
        """
        super().__init__()
        self.n = n
        self.dim = dim
        self.norm = norm

    def forward(self, x: Tensor) -> Tensor:
        return torch.fft.fft(x, n=self.n, dim=self.dim, norm=self.norm)
[docs]
class IFFT(Module):
    """
    Module version of :func:`~torch.fft.ifft`.

    Computes the one-dimensional inverse discrete Fourier transform of its input.
    With no constructor arguments it behaves exactly like ``torch.fft.ifft(x)``.
    """

    def __init__(
        self,
        n: Optional[int] = None,
        dim: int = -1,
        norm: Optional[str] = None,
    ) -> None:
        """
        Args:
            n: Optional signal length. If given, the input is zero-padded or
                trimmed to this length along ``dim`` before the transform.
            dim: Dimension along which the inverse FFT is computed. defaults to -1.
            norm: Normalization mode forwarded to :func:`torch.fft.ifft`
                ("forward", "backward" or "ortho"). None uses PyTorch's default.
        """
        super().__init__()
        self.n = n
        self.dim = dim
        self.norm = norm

    def forward(self, x: Tensor) -> Tensor:
        return torch.fft.ifft(x, n=self.n, dim=self.dim, norm=self.norm)
[docs]
class Imag(Module):
    """
    Module version of :func:`~torch.Tensor.imag`.
    """

    def __init__(self, *, return_zeros: bool = False) -> None:
        """Return the imaginary part of a complex tensor.

        Args:
            return_zeros:
                If True and the input is not a complex tensor, the module returns
                ``torch.zeros_like(input)`` (a tensor of zeros with the same shape).
                If False and the input is not a complex tensor, PyTorch's default
                RuntimeError (raised by ``Tensor.imag``) propagates.
        """
        super().__init__()
        self.return_zeros = return_zeros

    def forward(self, x: Tensor) -> Tensor:
        # Complex inputs (or strict mode) go straight to ``Tensor.imag``.
        if not self.return_zeros or x.is_complex():
            return x.imag
        # Real input with return_zeros=True: the imaginary part is zero everywhere.
        return torch.zeros_like(x)
[docs]
class Interpolate(Module):
    """
    Module version of :func:`~torch.nn.functional.interpolate`.

    All constructor arguments are stored and forwarded to
    ``F.interpolate`` on every call. ``recompute_scale_factor`` and
    ``antialias`` are only forwarded when the installed torch version is
    at least 2.0.0; on older versions they are not passed.
    """

    def __init__(
        self,
        size: Union[int, Tuple[int, ...], None] = None,
        scale_factor: Union[float, Tuple[float, ...], None] = None,
        mode: str = "nearest",
        align_corners: Optional[bool] = None,
        recompute_scale_factor: Optional[bool] = None,
        antialias: bool = False,
    ) -> None:
        super().__init__()
        # Output geometry.
        self.size = size
        self.scale_factor = scale_factor
        # Resampling behaviour.
        self.mode = mode
        self.align_corners = align_corners
        self.recompute_scale_factor = recompute_scale_factor
        self.antialias = antialias

    def forward(self, x: Tensor) -> Tensor:
        recent_torch = Version(torch.__version__) >= Version("2.0.0")
        extra_kwds = (
            dict(
                recompute_scale_factor=self.recompute_scale_factor,
                antialias=self.antialias,
            )
            if recent_torch
            else {}
        )
        return F.interpolate(
            x,
            size=self.size,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=self.align_corners,
            **extra_kwds,
        )
[docs]
class Log(Module):
    """
    Module version of :func:`~torch.log`.

    Computes the element-wise natural logarithm of its input.
    """

    def forward(self, x: Tensor) -> Tensor:
        # Functional form of ``x.log()``.
        return torch.log(x)
[docs]
class Log10(Module):
    """
    Module version of :func:`~torch.log10`.

    Computes the element-wise base-10 logarithm of its input.
    """

    def forward(self, x: Tensor) -> Tensor:
        # Functional form of ``x.log10()``.
        return torch.log10(x)
[docs]
class Log2(Module):
    """
    Module version of :func:`~torch.log2`.

    Computes the element-wise base-2 logarithm of its input.
    """

    def forward(self, x: Tensor) -> Tensor:
        # Functional form of ``x.log2()``.
        return torch.log2(x)
[docs]
class Max(Module):
    """
    Module version of :func:`~torch.max`.

    When ``dim`` is None, the maximum is taken over the flattened input and the
    index is the flat index returned by ``Tensor.argmax``.
    """

    def __init__(
        self,
        dim: Optional[int] = None,
        keepdim: bool = False,
        *,
        return_values: bool = True,
        return_indices: Optional[bool] = None,
    ) -> None:
        """
        Args:
            dim: Dimension to reduce. If None, reduce over all elements.
            keepdim: Keep the reduced dimension with size 1. Only allowed when ``dim`` is set.
            return_values: If True, output the maximal values.
            return_indices: If True, output the indices of the maximal values.
                When None, indices are returned only if ``dim`` is not None.

        Raises:
            ValueError: If neither values nor indices are requested, or if
                ``keepdim`` is True while ``dim`` is None.
        """
        # Default: only report indices when reducing along an explicit dimension.
        if return_indices is None:
            return_indices = dim is not None

        if not (return_values or return_indices):
            msg = f"Invalid combinaison of arguments {return_values=} and {return_indices=}. (at least one of them must be True)"
            raise ValueError(msg)
        if keepdim and dim is None:
            msg = f"Invalid combinaison of arguments {dim=} and {keepdim=}. (expected dim is not None or keepdim=False)"
            raise ValueError(msg)

        super().__init__()
        self.dim = dim
        self.return_values = return_values
        self.return_indices = return_indices
        self.keepdim = keepdim

    def forward(self, x: Tensor) -> Union[Tensor, return_types.max]:
        if self.dim is not None:
            values_indices = x.max(dim=self.dim, keepdim=self.keepdim)
        else:
            # Global reduction: pair the flat argmax index with its value.
            flat_index = x.argmax()
            values_indices = return_types.max([x.flatten()[flat_index], flat_index])

        if self.return_values and self.return_indices:
            return values_indices  # type: ignore
        if self.return_values:
            return values_indices.values
        if self.return_indices:
            return values_indices.indices

        # Only reachable if the flags were both set to False after construction.
        msg = f"Invalid combinaison of arguments {self.return_values=} and {self.return_indices=}. (at least one of them must be True)"
        raise ValueError(msg)
[docs]
class Mean(Module):
    """
    Module version of :func:`~torch.mean`.
    """

    def __init__(
        self,
        dim: Optional[int] = None,
        keepdim: bool = False,
        dtype: DTypeLike = None,
    ) -> None:
        """
        Args:
            dim: Dimension to reduce. If None, the mean of all elements is computed.
            keepdim: If True, reduced dimensions are kept with size 1. When ``dim``
                is None, every dimension is kept with size 1, i.e. the output has
                shape ``(1,) * x.ndim``.
            dtype: Optional dtype of the output, converted with ``as_dtype`` on each call.
        """
        super().__init__()
        self.dim = dim
        self.keepdim = keepdim
        self.dtype = dtype

    def forward(self, x: Tensor) -> Tensor:
        dtype = as_dtype(self.dtype)
        if self.dim is not None:
            return x.mean(dim=self.dim, keepdim=self.keepdim, dtype=dtype)  # type: ignore

        # Full reduction. ``Tensor.mean(dim=None, ...)`` is not accepted by every torch
        # version, so reduce everything, then restore the singleton dimensions for
        # keepdim. Reshaping (instead of ``torch.full(x.shape, result.item())``) gives
        # the correct ``(1,) * ndim`` shape, keeps the result's dtype/device, and avoids
        # a host synchronisation.
        result = x.mean(dtype=dtype)
        if self.keepdim:
            result = result.reshape((1,) * x.ndim)
        return result
[docs]
class Min(Module):
    """
    Module version of :func:`~torch.min`.

    When ``dim`` is None, the minimum is taken over the flattened input and the
    index is the flat index returned by ``Tensor.argmin``.
    """

    def __init__(
        self,
        dim: Optional[int] = None,
        keepdim: bool = False,
        *,
        return_values: bool = True,
        return_indices: Optional[bool] = None,
    ) -> None:
        """
        Args:
            dim: Dimension to reduce. If None, reduce over all elements.
            keepdim: Keep the reduced dimension with size 1. Only allowed when ``dim`` is set.
            return_values: If True, output the minimal values.
            return_indices: If True, output the indices of the minimal values.
                When None, indices are returned only if ``dim`` is not None.

        Raises:
            ValueError: If neither values nor indices are requested, or if
                ``keepdim`` is True while ``dim`` is None.
        """
        # Default: only report indices when reducing along an explicit dimension.
        if return_indices is None:
            return_indices = dim is not None

        if not (return_values or return_indices):
            msg = f"Invalid combinaison of arguments {return_values=} and {return_indices=}. (at least one of them must be True)"
            raise ValueError(msg)
        if keepdim and dim is None:
            msg = f"Invalid combinaison of arguments {dim=} and {keepdim=}. (expected dim is not None or keepdim=False)"
            raise ValueError(msg)

        super().__init__()
        self.dim = dim
        self.return_values = return_values
        self.return_indices = return_indices
        self.keepdim = keepdim

    def forward(self, x: Tensor) -> Union[Tensor, return_types.min]:
        if self.dim is not None:
            values_indices = x.min(dim=self.dim, keepdim=self.keepdim)
        else:
            # Global reduction: pair the flat argmin index with its value.
            flat_index = x.argmin()
            values_indices = return_types.min([x.flatten()[flat_index], flat_index])

        if self.return_values and self.return_indices:
            return values_indices  # type: ignore
        if self.return_values:
            return values_indices.values
        if self.return_indices:
            return values_indices.indices

        # Only reachable if the flags were both set to False after construction.
        msg = f"Invalid combinaison of arguments {self.return_values=} and {self.return_indices=}. (at least one of them must be True)"
        raise ValueError(msg)
[docs]
class Normalize(Module):
    """
    Module version of :func:`~torch.nn.functional.normalize`.

    Divides the input by its ``p``-norm along ``dim`` (the norm is clamped
    below by ``eps`` inside ``F.normalize``).
    """

    def __init__(
        self,
        p: float = 2.0,
        dim: int = 1,
        eps: float = 1e-12,
    ) -> None:
        super().__init__()
        # Norm order, reduced dimension and numerical floor.
        self.p = p
        self.dim = dim
        self.eps = eps

    def forward(self, x: Tensor) -> Tensor:
        return F.normalize(x, p=self.p, dim=self.dim, eps=self.eps)
[docs]
class Permute(Module):
    """
    Module version of :func:`~torch.permute`.
    """

    def __init__(self, *args: Union[int, Sequence[int]]) -> None:
        """
        Args:
            *args: The desired ordering of dimensions, given either as separate
                integers (``Permute(0, 2, 1)``) or as a single list/tuple
                (``Permute((0, 2, 1))``), mirroring the two call forms accepted
                by :meth:`torch.Tensor.permute`.
        """
        super().__init__()
        # Unwrap a single list/tuple (incl. torch.Size) so it is not nested.
        if len(args) == 1 and isinstance(args[0], (list, tuple)):
            dims = args[0]
        else:
            dims = args
        self.dims = tuple(dims)

    def forward(self, x: Tensor) -> Tensor:
        return x.permute(self.dims)
[docs]
class Pow(Module):
    """
    Module version of :func:`~torch.Tensor.pow`.

    Raises the input to a fixed exponent (a Python number or a tensor) given
    at construction time.
    """

    def __init__(self, exponent: Union[Number, Tensor]) -> None:
        super().__init__()
        self.exponent = exponent

    def forward(self, x: Tensor) -> Tensor:
        # Functional form of ``x.pow(self.exponent)``.
        return torch.pow(x, self.exponent)
[docs]
class Real(Module):
    """
    Module version of :func:`~torch.Tensor.real`.

    Returns the real part of its input.
    """

    def forward(self, x: Tensor) -> Tensor:
        # Functional form of the ``x.real`` attribute.
        return torch.real(x)
[docs]
class Repeat(Module):
    """
    Module version of :meth:`~torch.Tensor.repeat`.
    """

    def __init__(self, *repeats: Union[int, Sequence[int]]) -> None:
        """
        Args:
            *repeats: Number of repetitions along each dimension, given either as
                separate integers (``Repeat(2, 3)``) or as a single list/tuple
                (``Repeat((2, 3))``), mirroring the two call forms accepted by
                :meth:`torch.Tensor.repeat`.
        """
        super().__init__()
        # Unwrap a single list/tuple (incl. torch.Size) so it is not nested.
        if len(repeats) == 1 and isinstance(repeats[0], (list, tuple)):
            repeats = tuple(repeats[0])
        self.repeats = repeats

    def forward(self, x: Tensor) -> Tensor:
        return x.repeat(self.repeats)
[docs]
class RepeatInterleave(Module):
    """
    Module version of :func:`~torch.repeat_interleave`.

    Repeats each element of the input ``repeats`` times along ``dim``.
    """

    def __init__(
        self,
        repeats: Union[int, Tensor],
        dim: int,
        output_size: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.repeats = repeats
        self.dim = dim
        # Optional total output length along ``dim``, forwarded as-is to torch.
        self.output_size = output_size

    def forward(self, x: Tensor) -> Tensor:
        return torch.repeat_interleave(
            x,
            self.repeats,
            dim=self.dim,
            output_size=self.output_size,
        )
[docs]
class Reshape(Module):
    """
    Module version of :func:`~torch.reshape`.
    """

    def __init__(self, *shape: Union[int, Sequence[int]]) -> None:
        """
        Args:
            *shape: The target shape, given either as separate integers
                (``Reshape(2, -1)``) or as a single list/tuple/``torch.Size``
                (``Reshape((2, -1))``), mirroring the two call forms accepted by
                :meth:`torch.Tensor.reshape`.
        """
        super().__init__()
        # Unwrap a single list/tuple (incl. torch.Size) so it is not nested.
        if len(shape) == 1 and isinstance(shape[0], (list, tuple)):
            shape = tuple(shape[0])
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        return x.reshape(self.shape)
[docs]
class Sort(Module):
    """
    Module version of :func:`~torch.Tensor.sort`.

    Depending on the flags, outputs the sorted values, the sorting indices,
    or both (as a ``return_types.sort`` pair).
    """

    def __init__(
        self,
        dim: int = -1,
        descending: bool = False,
        *,
        return_values: bool = True,
        return_indices: bool = True,
    ) -> None:
        """
        Raises:
            ValueError: If both ``return_values`` and ``return_indices`` are False.
        """
        if not (return_values or return_indices):
            msg = f"Invalid combinaison of arguments {return_values=} and {return_indices=}. (at least one of them must be True)"
            raise ValueError(msg)

        super().__init__()
        self.dim = dim
        self.descending = descending
        self.return_values = return_values
        self.return_indices = return_indices

    def forward(self, x: Tensor) -> Union[return_types.sort, Tensor]:
        sorted_pair = return_types.sort(
            x.sort(dim=self.dim, descending=self.descending)
        )

        if self.return_values:
            return sorted_pair if self.return_indices else sorted_pair.values
        if self.return_indices:
            return sorted_pair.indices

        # Only reachable if the flags were both set to False after construction.
        msg = f"Invalid combinaison of arguments {self.return_values=} and {self.return_indices=}. (at least one of them must be True)"
        raise ValueError(msg)
[docs]
class TensorTo(Module):
    """
    Module version of :func:`~torch.Tensor.to`.

    Stores positional and keyword arguments at construction and forwards them
    to ``Tensor.to`` on every call. This allows e.g. ``TensorTo(torch.float64)``,
    ``TensorTo("cpu")`` and ``TensorTo(dtype=torch.float16)``.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__()
        self.args = args
        self.kwargs = kwargs

    def forward(self, x: Tensor) -> Tensor:
        return x.to(*self.args, **self.kwargs)
[docs]
class ToList(Module):
    """
    Module version of :func:`~torch.Tensor.tolist`.

    Converts the input tensor into a (possibly nested) Python list.
    """

    def forward(self, x: Tensor) -> List:
        as_python = x.tolist()
        return as_python
[docs]
class Transpose(Module):
    """
    Module version of :func:`~torch.transpose`.
    """

    def __init__(self, dim0: int, dim1: int, copy: bool = False) -> None:
        """
        Args:
            dim0: First dimension to swap.
            dim1: Second dimension to swap.
            copy: If True, use ``torch.transpose_copy`` instead of ``torch.transpose``.
                A ValueError is raised at call time if the installed torch does not
                provide ``transpose_copy``.
        """
        super().__init__()
        self.dim0 = dim0
        self.dim1 = dim1
        self.copy = copy

    def forward(self, x: Tensor) -> Tensor:
        if not self.copy:
            return torch.transpose(x, self.dim0, self.dim1)

        transpose_copy = getattr(torch, "transpose_copy", None)
        if transpose_copy is None:
            msg = f"Invalid argument {self.copy=} in torch {torch.__version__}."
            raise ValueError(msg)
        return transpose_copy(x, self.dim0, self.dim1)  # type: ignore
[docs]
class View(Module):
    """
    Module version of :meth:`~torch.Tensor.view`.

    The constructor arguments are stored and unpacked into ``Tensor.view`` on
    every call. The accepted forms are a dtype (``View(torch.int32)``), a single
    size sequence (``View((2, 3))``), or separate integer sizes (``View(2, 3)``).
    """

    @overload
    def __init__(self, dtype: torch.dtype, /) -> None: ...

    @overload
    def __init__(self, size: Sequence[int], /) -> None: ...

    @overload
    def __init__(self, *size: int) -> None: ...

    def __init__(self, *args) -> None:
        super().__init__()
        self.args = args

    def forward(self, x: Tensor) -> Tensor:
        view_args = self.args
        return x.view(*view_args)