#!/usr/bin/env python
# -*- coding: utf-8 -*-
import itertools
import math
from typing import (
Any,
Callable,
Iterable,
List,
Literal,
Mapping,
Optional,
Tuple,
TypeVar,
Union,
overload,
)
import torch
from pythonwrench.collections import all_eq as builtin_all_eq
from pythonwrench.collections import prod as builtin_prod
from pythonwrench.collections import unzip
from pythonwrench.functools import function_alias, identity
from pythonwrench.semver import Version
from pythonwrench.typing import BuiltinNumber, SupportsIterLen, T_BuiltinNumber
from pythonwrench.warnings import warn_once
from torch import ( # noqa: F401
Tensor,
equal,
initial_seed,
manual_seed,
matmul,
nn,
no_grad,
seed,
split,
where,
)
from torchwrench.extras.numpy import np
from torchwrench.extras.pandas import pd
from torchwrench.nn import functional as F
from torchwrench.types._typing import (
LongTensor,
ScalarLike,
T_TensorOrArray,
TensorOrArray,
)
from torchwrench.types.guards import is_scalar_like
from torchwrench.types.tensor_subclasses import Tensor0D, Tensor1D, Tensor2D, Tensor3D
from torchwrench.utils import return_types
# Type variables shared by the annotations in this module: ``T`` is the result
# type of the ``output_type`` callable in get_shape (and the operand type of
# deep_equal); ``U`` is the type of the user-supplied ``default`` fallback.
T = TypeVar("T", covariant=True)
U = TypeVar("U", covariant=True)
[docs]
def count_parameters(
    model: nn.Module,
    *,
    recurse: bool = True,
    only_trainable: bool = False,
    buffers: bool = False,
) -> int:
    """Return the total number of scalar elements held by a module.

    Args:
        model: Module whose tensors are counted.
        recurse: If True, include tensors of all submodules; otherwise only those
            registered directly on ``model``. defaults to True.
        only_trainable: If True, skip parameters whose ``requires_grad`` is False.
            This filter is not applied to buffers. defaults to False.
        buffers: If True, also count the module buffers. defaults to False.

    Returns:
        Sum of ``numel()`` over the selected parameters (and buffers).
    """
    selected = [
        p for p in model.parameters(recurse) if p.requires_grad or not only_trainable
    ]
    if buffers:
        selected.extend(model.buffers(recurse))
    return sum(t.numel() for t in selected)
[docs]
def find(
    value: Any,
    x: Tensor,
    *,
    default: Union[None, Tensor, BuiltinNumber] = None,
    dim: int = -1,
) -> LongTensor:
    """Return the index of the first occurrence of value in a tensor along a dimension.

    Args:
        value: Value to search for, compared element-wise with ``x.eq(value)``.
        x: Input tensor with at least one dimension.
        default: Index used for slices that do not contain ``value``. If None, a
            RuntimeError is raised instead. defaults to None.
        dim: Dimension searched. defaults to -1.

    Returns:
        Long tensor of indices, with ``dim`` reduced.

    Raises:
        ValueError: If ``x`` is a 0-d tensor.
        RuntimeError: If ``default`` is None and ``value`` is missing from at least
            one slice (including when ``dim`` has size 0).
    """
    if x.ndim == 0:
        msg = f"Function 'find' does not supports 0-d tensors. (found {x.ndim=})"
        raise ValueError(msg)

    mask = x.eq(value)
    contains = mask.any(dim=dim)

    if x.shape[dim] == 0:
        # argmax cannot reduce an empty dimension. No slice contains ``value`` here,
        # so placeholder indices are used; they are replaced by ``default`` (or an
        # error is raised) below.
        indices = torch.zeros_like(contains, dtype=torch.long)
    else:
        # argmax returns the index of the first maximal value, i.e. the first True.
        indices = mask.long().argmax(dim=dim)

    if default is None:
        if not contains.all():
            raise RuntimeError(f"Cannot find {value=} in tensor.")
        return indices  # type: ignore

    output = torch.where(contains, indices, default)
    return output  # type: ignore
# Overload: neither return_indicator nor return_default_on_invalid is set, so
# heterogeneous input raises ValueError and a plain int is returned.
@overload
def get_ndim(
    x: Union[ScalarLike, Tensor, np.ndarray, Iterable],
    *,
    use_first_for_list_tuple: bool = False,
    return_indicator: Literal[False] = False,
    return_default_on_invalid: Literal[False] = False,
    default: Any = -1,
    return_valid: Optional[bool] = None,
) -> int: ...
# Overload: return_default_on_invalid is given, so ``default`` may be returned
# instead of an int for heterogeneous input.
@overload
def get_ndim(
    x: Union[ScalarLike, Tensor, np.ndarray, Iterable],
    *,
    use_first_for_list_tuple: bool = False,
    return_indicator: Literal[False] = False,
    return_default_on_invalid: bool,
    default: U = -1,
    return_valid: Optional[bool] = None,
) -> Union[int, U]: ...
# Overload: return_indicator=True, so heterogeneous input does not raise and the
# result is a (valid, ndim) pair annotated as return_types.ndim.
@overload
def get_ndim(
    x: Union[ScalarLike, Tensor, np.ndarray, Iterable],
    *,
    use_first_for_list_tuple: bool = False,
    return_indicator: Literal[True],
    return_default_on_invalid: Literal[False] = False,
    default: Any = -1,
    return_valid: Optional[bool] = None,
) -> return_types.ndim[int]: ...
# Overload: return_indicator=True with return_default_on_invalid given; the ndim
# field may hold ``default``.
@overload
def get_ndim(
    x: Union[ScalarLike, Tensor, np.ndarray, Iterable],
    *,
    use_first_for_list_tuple: bool = False,
    return_indicator: Literal[True],
    return_default_on_invalid: bool,
    default: U = -1,
    return_valid: Optional[bool] = None,
) -> return_types.ndim[Union[int, U]]: ...
[docs]
def get_ndim(
    x: Union[ScalarLike, Tensor, np.ndarray, Iterable],
    *,
    use_first_for_list_tuple: bool = False,
    return_indicator: bool = False,
    return_default_on_invalid: bool = False,
    default: U = -1,
    return_valid: Optional[bool] = None,
) -> Union[Union[int, U], return_types.ndim[Union[int, U]]]:
    """Scan first argument to return its number of dimension(s). Works recursively with Tensors, numpy arrays and builtins types instances.

    Note: Sets and dicts are considered as scalars with a ndim equal to 0.
    Lists and tuples add one dimension to the ndim of their elements; an empty
    list/tuple has a ndim equal to 1.

    Args:
        x: Input value to scan.
        use_first_for_list_tuple: If True, a list/tuple whose first element is valid gets that element's ndim + 1, without requiring the other elements to match (they are still scanned, so unsupported types still raise). Otherwise every element must have the same ndim. defaults to False.
        return_indicator: If True, returns a ``return_types.ndim`` tuple ``(valid, ndim)`` instead of raising a ValueError on heterogeneous data. defaults to False.
        return_default_on_invalid: If True and return_indicator=False, returns the default value instead of raising a ValueError. defaults to False.
        default: Value used as ndim if input is a heterogeneous list/tuple. defaults to -1.
        return_valid: Deprecated. Use return_indicator instead.

    Returns:
        The ndim as an int, ``default`` for heterogeneous input when
        return_default_on_invalid=True, or ``return_types.ndim(valid, ndim)`` when
        return_indicator=True.

    Raises:
        ValueError: if input has an heterogeneous number of dimensions while both return_indicator and return_default_on_invalid are False.
        TypeError: if input (or any nested element) has an unsupported type.
    """
    if return_valid is not None:
        msg = f"Deprecated argument {return_valid=}. Use return_indicator instead."
        warn_once(msg)
        return_indicator = return_valid
    del return_valid

    def _impl(
        x: Union[ScalarLike, Tensor, np.ndarray, Iterable],
    ) -> Tuple[bool, Union[int, U]]:
        # Returns (valid, ndim); ndim is ``default`` whenever valid is False.
        if is_scalar_like(x):
            return True, 0
        elif isinstance(x, (Tensor, np.ndarray, np.generic, pd.DataFrame)):
            return True, x.ndim
        elif isinstance(x, (set, frozenset, dict)):
            return True, 0
        elif isinstance(x, (list, tuple)):
            valids_and_ndims = unzip(_impl(xi) for xi in x)
            if len(valids_and_ndims) == 0:
                # Empty list/tuple: one dimension of size 0.
                return True, 1
            valids, ndims = valids_and_ndims
            if (use_first_for_list_tuple and valids[0]) or (
                all(valids) and builtin_all_eq(ndims)
            ):
                return True, ndims[0] + 1  # type: ignore
            else:
                return False, default
        else:
            raise TypeError(f"Invalid argument type {type(x)}.")

    valid, ndim = _impl(x)
    if not valid and not return_indicator and not return_default_on_invalid:
        msg = f"Invalid argument {x}. (cannot compute ndim for heterogeneous data)"
        raise ValueError(msg)
    if return_indicator:
        # Use the ndim-specific result type declared by the overloads.
        ndim = return_types.ndim(valid, ndim)
    return ndim  # type: ignore
[docs]
# Alias of get_ndim, created with pythonwrench's ``function_alias`` decorator.
# The ``...`` body is a placeholder; presumably the decorator forwards all
# arguments to get_ndim (verify against pythonwrench.functools).
@function_alias(get_ndim)
def ndim(*args, **kwargs): ...
# Overload: neither return_indicator nor return_default_on_invalid is set, so
# heterogeneous input raises ValueError and ``output_type(shape)`` is returned.
@overload
def get_shape(
    x: Union[
        ScalarLike, Tensor, np.ndarray, pd.DataFrame, list, tuple, set, frozenset, dict
    ],
    *,
    output_type: Callable[[Tuple[int, ...]], T] = identity,
    use_first_for_list_tuple: bool = False,
    return_indicator: Literal[False] = False,
    return_default_on_invalid: Literal[False] = False,
    default: Any = (),
    return_valid: Optional[bool] = None,
) -> T: ...
# Overload: return_default_on_invalid is given, so ``default`` (not passed through
# output_type) may be returned for heterogeneous input.
@overload
def get_shape(
    x: Union[
        ScalarLike, Tensor, np.ndarray, pd.DataFrame, list, tuple, set, frozenset, dict
    ],
    *,
    output_type: Callable[[Tuple[int, ...]], T] = identity,
    use_first_for_list_tuple: bool = False,
    return_indicator: Literal[False] = False,
    return_default_on_invalid: bool,
    default: U = (),
    return_valid: Optional[bool] = None,
) -> Union[T, U]: ...
# Overload: return_indicator=True, so the result is return_types.shape(valid, shape)
# and heterogeneous input does not raise.
@overload
def get_shape(
    x: Union[
        ScalarLike, Tensor, np.ndarray, pd.DataFrame, list, tuple, set, frozenset, dict
    ],
    *,
    output_type: Callable[[Tuple[int, ...]], T] = identity,
    use_first_for_list_tuple: bool = False,
    return_indicator: Literal[True],
    return_default_on_invalid: Literal[False] = False,
    default: Any = (),
    return_valid: Optional[bool] = None,
) -> return_types.shape[T]: ...
# Overload: return_indicator=True with return_default_on_invalid given; the shape
# field may hold ``default``.
@overload
def get_shape(
    x: Union[
        ScalarLike, Tensor, np.ndarray, pd.DataFrame, list, tuple, set, frozenset, dict
    ],
    *,
    output_type: Callable[[Tuple[int, ...]], T] = identity,
    use_first_for_list_tuple: bool = False,
    return_indicator: Literal[True],
    return_default_on_invalid: bool,
    default: U = (),
    return_valid: Optional[bool] = None,
) -> return_types.shape[Union[T, U]]: ...
[docs]
def get_shape(
    x: Union[
        ScalarLike, Tensor, np.ndarray, pd.DataFrame, list, tuple, set, frozenset, dict
    ],
    *,
    output_type: Callable[[Tuple[int, ...]], T] = identity,
    use_first_for_list_tuple: bool = False,
    return_indicator: bool = False,
    return_default_on_invalid: bool = False,
    default: U = (),
    return_valid: Optional[bool] = None,
) -> Union[T, U, return_types.shape[Union[T, U]]]:
    """Compute the shape of a (possibly nested) tensor-like value.

    Tensors, numpy arrays and DataFrames contribute their own ``.shape``. Scalars,
    sets, frozensets and dicts have shape ``()``. Lists and tuples are scanned
    recursively and get a leading dimension equal to their length (``(0,)`` when
    empty).

    Args:
        x: Value to inspect.
        output_type: Callable applied to the tuple of ints when the shape is valid.
            It is NOT applied to ``default``. defaults to identity.
        use_first_for_list_tuple: If True, a list/tuple whose first element has a
            valid shape gets ``(len,) + first_shape`` without requiring the other
            elements to match (they are still scanned). defaults to False.
        return_indicator: If True, return ``return_types.shape(valid, shape)``
            instead of raising on heterogeneous data. defaults to False.
        return_default_on_invalid: If True, return ``default`` instead of raising
            on heterogeneous data. defaults to False.
        default: Shape reported for a heterogeneous list/tuple. defaults to ().
        return_valid: Deprecated alias of ``return_indicator``.

    Raises:
        ValueError: If the data is heterogeneous while both ``return_indicator``
            and ``return_default_on_invalid`` are False.
        TypeError: If ``x`` (or a nested element) has an unsupported type.
    """
    if return_valid is not None:
        msg = f"Deprecated argument {return_valid=}. Use return_indicator instead."
        warn_once(msg)
        return_indicator = return_valid

    def _scan(
        x: Union[ScalarLike, Tensor, np.ndarray, Iterable],
    ) -> Tuple[bool, Union[Tuple[int, ...], U]]:
        # Leaves: anything other than a list/tuple has a fixed shape or is rejected.
        if is_scalar_like(x):
            return True, ()
        if isinstance(x, (Tensor, np.ndarray, np.generic, pd.DataFrame)):
            return True, tuple(x.shape)
        if isinstance(x, (set, frozenset, dict)):
            return True, ()
        if not isinstance(x, (list, tuple)):
            raise TypeError(f"Invalid argument type {type(x)}.")

        # Recurse into every element, then check whether the children agree.
        children = unzip(_scan(item) for item in x)  # type: ignore
        if len(children) == 0:
            return True, (0,)
        child_valids, child_shapes = children
        trust_first = use_first_for_list_tuple and child_valids[0]
        if not trust_first and not (
            all(child_valids) and builtin_all_eq(child_shapes)
        ):
            return False, default
        return True, (len(child_shapes),) + child_shapes[0]  # type: ignore

    is_valid, result = _scan(x)
    if not (is_valid or return_indicator or return_default_on_invalid):
        msg = f"Invalid argument {x}. (cannot compute shape for heterogeneous data)"
        raise ValueError(msg)
    if is_valid:
        result = output_type(result)  # type: ignore
    if not return_indicator:
        return result  # type: ignore
    return return_types.shape(is_valid, result)  # type: ignore
[docs]
# Alias of get_shape, created with pythonwrench's ``function_alias`` decorator.
# The ``...`` body is a placeholder; presumably the decorator forwards all
# arguments to get_shape (verify against pythonwrench.functools).
@function_alias(get_shape)
def shape(*args, **kwargs): ...
[docs]
def ranks(x: Tensor, dim: int = -1, descending: bool = False) -> LongTensor:
    """Return the rank of each value of ``x`` along ``dim``.

    Ranks are integers in the half-open range ``[0, x.shape[dim])``. Rank 0 is the
    smallest value, or the largest when ``descending=True``. Tied values receive
    distinct ranks whose relative order is not guaranteed, because the sort is
    not requested to be stable.
    """
    sorting_indices = torch.argsort(x, dim=dim, descending=descending)
    # Inverting the sorting permutation maps each position to its sorted rank.
    return torch.argsort(sorting_indices, dim=dim)  # type: ignore
[docs]
def nelement(x: Union[ScalarLike, Tensor, np.ndarray, Iterable]) -> int:
    """Return the number of elements in a tensor-like object.

    Tensors and numpy arrays/scalars report their native element count. Any other
    input falls back to the product of the dimensions computed by ``get_shape``.
    """
    if isinstance(x, Tensor):
        return x.numel()
    if isinstance(x, (np.ndarray, np.generic)):
        return x.size
    dims = get_shape(x)
    return builtin_prod(dims)
# Overload: Tensor / numpy array input returns the same array type.
@overload
def prod(
    x: T_TensorOrArray,
    *,
    dim: Optional[int] = None,
    start: Any = 1,
) -> T_TensorOrArray: ...
# Overload: an iterable of builtin numbers returns a builtin number. The
# implementation raises ValueError if ``dim`` is not None for this input kind.
@overload
def prod(
    x: Iterable[T_BuiltinNumber],
    *,
    dim: Any = None,
    start: T_BuiltinNumber = 1,
) -> T_BuiltinNumber: ...
[docs]
def prod(
    x: Union[TensorOrArray, Iterable[T_BuiltinNumber]],
    *,
    dim: Optional[int] = None,
    start: T_BuiltinNumber = 1,
) -> Union[Tensor, T_BuiltinNumber]:
    """Multiply together all elements of ``x``, starting from ``start``.

    Args:
        x: Tensor, numpy array or iterable of builtin numbers.
        dim: Dimension (numpy axis) to reduce; None reduces every element. Must be
            None for plain iterables. defaults to None.
        start: Initial value of the product. defaults to 1.

    Raises:
        ValueError: If ``dim`` is given together with a plain iterable.
        TypeError: If ``x`` is neither a Tensor, an ndarray nor an Iterable.
    """
    if isinstance(x, Tensor):
        reduced = torch.prod(x) if dim is None else torch.prod(x, dim=dim)
        return reduced * start
    if isinstance(x, np.ndarray):
        return np.prod(x, axis=dim, initial=start)
    if not isinstance(x, Iterable):
        msg = (
            f"Invalid argument type {type(x)=}. (expected Tensor, ndarray or Iterable)"
        )
        raise TypeError(msg)
    if dim is not None:
        msg = f"Invalid argument {dim=}. (expected None with {type(x)=})"
        raise ValueError(msg)
    return builtin_prod(x, start=start)  # type: ignore
[docs]
def average_power(
    x: T_TensorOrArray,
    dim: Union[int, Tuple[int, ...], None] = -1,
) -> T_TensorOrArray:
    """Compute average power of a signal along a specified dim/axis.

    The average power is the mean of ``abs(x) ** 2``. The magnitude is taken
    first, so complex inputs yield a real result.

    Args:
        x: Signal as a Tensor or numpy array.
        dim: Dimension(s) to average over. None averages over every element.
            defaults to -1.

    Returns:
        Mean of the squared magnitudes, reduced over ``dim``.
    """
    power = abs(x) ** 2
    if dim is None:
        # Call mean() without a dim: older torch versions reject ``mean(None)``.
        return power.mean()  # type: ignore
    return power.mean(dim)  # type: ignore
[docs]
def mse(
    x1: Tensor,
    x2: Tensor,
    *,
    dim: Union[int, Tuple[int, ...], None] = None,
) -> Tensor:
    """Mean squared error function.

    Computes ``mean((x1 - x2) ** 2)``, with no square root applied. The inputs are
    combined with ``x1 - x2``, so the usual tensor broadcasting rules apply.

    Args:
        x1: First tensor.
        x2: Second tensor.
        dim: Dimension(s) to average over. None averages over every element.
            defaults to None.

    Returns:
        The mean squared error, reduced over ``dim``.
    """
    squared_error = (x1 - x2) ** 2
    if dim is None:
        # Call mean() without a dim: older torch versions reject ``mean(None)``.
        return squared_error.mean()
    return squared_error.mean(dim)  # type: ignore
[docs]
def rmse(
    x1: Tensor,
    x2: Tensor,
    *,
    dim: Union[int, Tuple[int, ...], None] = None,
) -> Tensor:
    """Root mean squared error function.

    Computes ``sqrt(mean((x1 - x2) ** 2))``. The inputs are combined with
    ``x1 - x2``, so the usual tensor broadcasting rules apply.

    Args:
        x1: First tensor.
        x2: Second tensor.
        dim: Dimension(s) to average over. None averages over every element.
            defaults to None.

    Returns:
        The root mean squared error, reduced over ``dim``.
    """
    # Computed inline so exactly one square root is applied to the mean.
    squared_error = (x1 - x2) ** 2
    if dim is None:
        # Call mean() without a dim: older torch versions reject ``mean(None)``.
        mean_squared_error = squared_error.mean()
    else:
        mean_squared_error = squared_error.mean(dim)  # type: ignore
    return mean_squared_error.sqrt()
[docs]
def deep_equal(x: T, y: T, *args: T) -> bool:
    """Recursively compare two or more objects for equality.

    Every object in ``(y, *args)`` is compared with ``x``; evaluation stops at the
    first mismatch. Scalar-like values, NDArrays, Tensors, DataFrames, Mappings
    and list-like objects are supported. Unlike ``==``, NaN values are considered
    equal. Tensors and NDArrays of different shapes compare unequal instead of
    raising.
    """
    return all(_deep_equal_binary(x, other) for other in (y, *args))
def _deep_equal_binary(x: T, y: T) -> bool:
    """Compare exactly two objects on behalf of deep_equal, treating NaNs as equal.

    Branches are tried in order: scalar-like pair, Tensor pair, numpy array pair,
    DataFrame pair, Mapping pair, pair of sized iterables, and finally a
    same-type ``==`` fallback.
    """
    if is_scalar_like(x) and is_scalar_like(y):
        # Two NaN scalars compare equal here, unlike with ``==``.
        x_isnan = math.isnan(x) if F.is_floating_point(x) else False
        y_isnan = math.isnan(y) if F.is_floating_point(y) else False
        return (x_isnan and y_isnan) or F.to_item(x == y)  # type: ignore
    if isinstance(x, Tensor) and isinstance(y, Tensor):
        x_isnan = x.isnan()
        y_isnan = y.isnan()
        # NaNs must be at the same positions; all other values must match exactly.
        return (
            (x.shape == y.shape)
            and bool((x_isnan == y_isnan).all().item())
            and torch.equal(x[~x_isnan], y[~y_isnan])
        )
    if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
        # Non-floating arrays cannot hold NaN: use an all-False mask for them.
        x_isnan = (
            np.isnan(x)
            if F.is_floating_point(x)
            else np.full(x.shape, False, dtype=bool)
        )
        y_isnan = (
            np.isnan(y)
            if F.is_floating_point(y)
            else np.full(y.shape, False, dtype=bool)
        )
        return (
            (x.shape == y.shape)
            and (x_isnan == y_isnan).all().item()
            and np.equal(x[~x_isnan], y[~y_isnan]).all().item()
        )
    if isinstance(x, pd.DataFrame) and isinstance(y, pd.DataFrame):
        # Labels are checked first; ``x == y`` requires identically-labeled frames.
        if not (deep_equal(x.index, y.index) and deep_equal(x.columns, y.columns)):
            return False
        x_isna = x.isna()
        y_isna = y.isna()
        return (x_isna == y_isna).all(axis=None).item() and (
            (x == y) | x_isna | y_isna
        ).all(axis=None).item()  # type: ignore
    if isinstance(x, Mapping) and isinstance(y, Mapping):
        if len(x) != len(y):
            return False
        # Match values by key so that insertion order does not matter, like dict ==.
        if all(key in y for key in x):
            return all(deep_equal(x[key], y[key]) for key in x)
        # Some keys (e.g. distinct NaN floats) cannot be found by lookup: fall
        # back to an ordered, element-wise comparison of the items.
        return deep_equal(list(x.items()), list(y.items()))
    if isinstance(x, SupportsIterLen) and isinstance(y, SupportsIterLen):
        return len(x) == len(y) and all(deep_equal(xi, yi) for xi, yi in zip(x, y))
    return (type(x) is type(y)) and (x == y)  # type: ignore
# Typing-only overloads: stacking N-d tensors yields an (N+1)-d tensor. These
# refine the static return type for 0-d, 1-d and 2-d inputs.
@overload
def stack(
    tensors: Union[List[Tensor0D], Tuple[Tensor0D, ...]],
    dim: int = 0,
    *,
    out: Optional[Tensor1D] = None,
) -> Tensor1D: ...
@overload
def stack(
    tensors: Union[List[Tensor1D], Tuple[Tensor1D, ...]],
    dim: int = 0,
    *,
    out: Optional[Tensor2D] = None,
) -> Tensor2D: ...
@overload
def stack(
    tensors: Union[List[Tensor2D], Tuple[Tensor2D, ...]],
    dim: int = 0,
    *,
    out: Optional[Tensor3D] = None,
) -> Tensor3D: ...
# Fallback for tensors of any (or statically unknown) rank.
@overload
def stack(
    tensors: Union[List[Tensor], Tuple[Tensor, ...]],
    dim: int = 0,
    *,
    out: Optional[Tensor] = None,
) -> Tensor: ...
[docs]
def stack(
    tensors,
    dim: int = 0,
    *,
    out: Optional[Tensor] = None,
) -> Tensor:  # type: ignore
    """Stack same-shaped tensors along a new dimension ``dim``.

    Thin wrapper around ``torch.stack``: the result has one more dimension than
    each input tensor.

    Args:
        tensors: List or tuple of tensors with identical shapes.
        dim: Index of the inserted dimension. defaults to 0.
        out: Optional output tensor forwarded to ``torch.stack``. defaults to None.
    """
    stacked = torch.stack(tensors, dim, out=out)
    return stacked
[docs]
def cat(
    tensors: Union[List[Tensor], Tuple[Tensor, ...]],
    dim: int = 0,
    *,
    out: Optional[Tensor] = None,
) -> Tensor:
    """Concatenate tensors along an existing dimension ``dim``.

    Thin wrapper around ``torch.cat``.

    Args:
        tensors: List or tuple of tensors whose shapes match except along ``dim``.
        dim: Dimension to concatenate along. defaults to 0.
        out: Optional output tensor forwarded to ``torch.cat``. defaults to None.
    """
    joined = torch.cat(tensors, dim, out=out)
    return joined
[docs]
# Alias of cat, created with pythonwrench's ``function_alias`` decorator.
# The ``...`` body is a placeholder; presumably the decorator forwards all
# arguments to cat (verify against pythonwrench.functools).
@function_alias(cat)
def concat(*args, **kwargs): ...