Source code for torchwrench.nn.functional.others

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import itertools
import math
from typing import (
    Any,
    Callable,
    Iterable,
    List,
    Literal,
    Mapping,
    Optional,
    Tuple,
    TypeVar,
    Union,
    overload,
)

import torch
from pythonwrench.collections import all_eq as builtin_all_eq
from pythonwrench.collections import prod as builtin_prod
from pythonwrench.collections import unzip
from pythonwrench.functools import function_alias, identity
from pythonwrench.semver import Version
from pythonwrench.typing import BuiltinNumber, SupportsIterLen, T_BuiltinNumber
from pythonwrench.warnings import warn_once
from torch import (  # noqa: F401
    Tensor,
    equal,
    initial_seed,
    manual_seed,
    matmul,
    nn,
    no_grad,
    seed,
    split,
    where,
)

from torchwrench.extras.numpy import np
from torchwrench.extras.pandas import pd
from torchwrench.nn import functional as F
from torchwrench.types._typing import (
    LongTensor,
    ScalarLike,
    T_TensorOrArray,
    TensorOrArray,
)
from torchwrench.types.guards import is_scalar_like
from torchwrench.types.tensor_subclasses import Tensor0D, Tensor1D, Tensor2D, Tensor3D
from torchwrench.utils import return_types

# Generic type variables used by the annotations in this module: T for
# deep_equal's x/y/args and get_shape's output_type result, U for the `default`
# value returned by get_ndim/get_shape on heterogeneous input.
# NOTE(review): both are declared covariant, yet they appear in parameter
# positions (e.g. deep_equal(x: T, ...), default: U); static type checkers
# usually reject covariant TypeVars there — consider dropping covariant=True.
T = TypeVar("T", covariant=True)
U = TypeVar("U", covariant=True)


def count_parameters(
    model: nn.Module,
    *,
    recurse: bool = True,
    only_trainable: bool = False,
    buffers: bool = False,
) -> int:
    """Count the scalar elements stored in a module's parameters.

    Args:
        model: Module to inspect.
        recurse: If True, include the parameters (and buffers) of every
            submodule; otherwise only the direct members of ``model``.
            defaults to True.
        only_trainable: If True, skip parameters whose ``requires_grad`` is
            False. This filter does not apply to buffers. defaults to False.
        buffers: If True, also add the number of elements of the module
            buffers. defaults to False.

    Returns:
        The total number of elements (``numel``) counted.
    """
    total = 0
    for param in model.parameters(recurse):
        if only_trainable and not param.requires_grad:
            continue
        total += param.numel()

    if buffers:
        # Buffers are never trainable, so `only_trainable` is not applied here.
        total += sum(buf.numel() for buf in model.buffers(recurse))

    return total
def find(
    value: Any,
    x: Tensor,
    *,
    default: Union[None, Tensor, BuiltinNumber] = None,
    dim: int = -1,
) -> LongTensor:
    """Return the index of the first occurrence of ``value`` along ``dim``.

    Args:
        value: Value to look for (compared with ``Tensor.eq``).
        x: Tensor to search. Must have at least one dimension.
        default: Index to use for slices that do not contain ``value``. When
            None, a missing value raises instead. defaults to None.
        dim: Dimension along which to search. defaults to -1.

    Returns:
        A long tensor of indices, with ``dim`` reduced.

    Raises:
        ValueError: if ``x`` is a 0-d tensor.
        RuntimeError: if ``default`` is None and at least one slice does not
            contain ``value``.
    """
    if x.ndim == 0:
        raise ValueError(
            f"Function 'find' does not supports 0-d tensors. (found {x.ndim=})"
        )

    matches = x.eq(value)
    # argmax returns the index of the FIRST maximal entry, i.e. the first match.
    first_idx = matches.long().argmax(dim=dim)
    found = matches.any(dim=dim)

    if default is not None:
        return torch.where(found, first_idx, default)  # type: ignore

    if not found.all():
        raise RuntimeError(f"Cannot find {value=} in tensor.")
    return first_idx  # type: ignore
# Typing overloads for get_ndim. The runtime implementation follows below;
# these stubs only select the return type from the flag combination.

# No indicator, heterogeneous input raises ValueError -> plain int.
@overload
def get_ndim(
    x: Union[ScalarLike, Tensor, np.ndarray, Iterable],
    *,
    use_first_for_list_tuple: bool = False,
    return_indicator: Literal[False] = False,
    return_default_on_invalid: Literal[False] = False,
    default: Any = -1,
    return_valid: Optional[bool] = None,
) -> int: ...
# No indicator, but heterogeneous input may yield `default` -> int or U.
@overload
def get_ndim(
    x: Union[ScalarLike, Tensor, np.ndarray, Iterable],
    *,
    use_first_for_list_tuple: bool = False,
    return_indicator: Literal[False] = False,
    return_default_on_invalid: bool,
    default: U = -1,
    return_valid: Optional[bool] = None,
) -> Union[int, U]: ...
# Indicator requested -> return_types.ndim (valid flag + int ndim).
@overload
def get_ndim(
    x: Union[ScalarLike, Tensor, np.ndarray, Iterable],
    *,
    use_first_for_list_tuple: bool = False,
    return_indicator: Literal[True],
    return_default_on_invalid: Literal[False] = False,
    default: Any = -1,
    return_valid: Optional[bool] = None,
) -> return_types.ndim[int]: ...
# Indicator requested and `default` may be returned for invalid data.
@overload
def get_ndim(
    x: Union[ScalarLike, Tensor, np.ndarray, Iterable],
    *,
    use_first_for_list_tuple: bool = False,
    return_indicator: Literal[True],
    return_default_on_invalid: bool,
    default: U = -1,
    return_valid: Optional[bool] = None,
) -> return_types.ndim[Union[int, U]]: ...
def get_ndim(
    x: Union[ScalarLike, Tensor, np.ndarray, Iterable],
    *,
    use_first_for_list_tuple: bool = False,
    return_indicator: bool = False,
    return_default_on_invalid: bool = False,
    default: U = -1,
    return_valid: Optional[bool] = None,
) -> Union[Union[int, U], return_types.ndim[Union[int, U]]]:
    """Scan first argument to return its number of dimension(s).

    Works recursively with Tensors, numpy arrays and builtins types instances.

    Note:
        Sets and dicts are considered as scalars with a ndim equal to 0.

    Args:
        x: Input value to scan.
        use_first_for_list_tuple: If True, use first value to determine ndim
            for list and tuple argument. Otherwise it will scan each value in
            argument to determine its ndim. defaults to False.
        return_indicator: If True, returns a ``return_types.ndim`` tuple
            ``(valid, ndim)`` where ``valid`` tells whether the data has an
            homogeneous ndim, instead of raising a ValueError.
            defaults to False.
        return_default_on_invalid: If True and return_indicator=False, returns
            the default value instead of raising a ValueError.
            defaults to False.
        default: Value to return if input is a heterogeneous list/tuple.
            defaults to -1.
        return_valid: Deprecated. Use return_indicator instead.

    Raises:
        ValueError: if input has an heterogeneous number of dimensions and
            both return_indicator and return_default_on_invalid are False.
        TypeError: if input (or a nested element) has an unsupported type.
    """
    if return_valid is not None:
        msg = f"Deprecated argument {return_valid=}. Use return_indicator instead."
        warn_once(msg)
        return_indicator = return_valid
    del return_valid

    def _impl(
        x: Union[ScalarLike, Tensor, np.ndarray, Iterable],
    ) -> Tuple[bool, Union[int, U]]:
        # Returns (is_homogeneous, ndim_or_default) for x.
        if is_scalar_like(x):
            return True, 0
        elif isinstance(x, (Tensor, np.ndarray, np.generic, pd.DataFrame)):
            return True, x.ndim
        elif isinstance(x, (set, frozenset, dict)):
            return True, 0
        elif isinstance(x, (list, tuple)):
            valids_and_ndims = unzip(_impl(xi) for xi in x)
            if len(valids_and_ndims) == 0:
                # An empty sequence is treated as a 1-d container.
                return True, 1
            valids, ndims = valids_and_ndims
            if (use_first_for_list_tuple and valids[0]) or (
                all(valids) and builtin_all_eq(ndims)
            ):
                return True, ndims[0] + 1  # type: ignore
            else:
                return False, default
        else:
            raise TypeError(f"Invalid argument type {type(x)}.")

    valid, ndim = _impl(x)
    if not valid and not return_indicator and not return_default_on_invalid:
        msg = f"Invalid argument {x}. (cannot compute ndim for heterogeneous data)"
        raise ValueError(msg)

    if return_indicator:
        # Must build return_types.ndim (as declared by the signature and the
        # overloads), not return_types.shape.
        ndim = return_types.ndim(valid, ndim)
    return ndim  # type: ignore
# Public alias of get_ndim. The stub body `...` is wrapped by
# function_alias(get_ndim); presumably the decorator forwards all calls to
# get_ndim — confirm in pythonwrench.functools.function_alias.
[docs] @function_alias(get_ndim) def ndim(*args, **kwargs): ...
# Typing overloads for get_shape. The runtime implementation follows below;
# these stubs only select the return type from the flag combination.

# No indicator, heterogeneous input raises ValueError -> output_type(shape).
@overload
def get_shape(
    x: Union[
        ScalarLike, Tensor, np.ndarray, pd.DataFrame, list, tuple, set, frozenset, dict
    ],
    *,
    output_type: Callable[[Tuple[int, ...]], T] = identity,
    use_first_for_list_tuple: bool = False,
    return_indicator: Literal[False] = False,
    return_default_on_invalid: Literal[False] = False,
    default: Any = (),
    return_valid: Optional[bool] = None,
) -> T: ...
# No indicator, but heterogeneous input may yield `default` -> T or U.
@overload
def get_shape(
    x: Union[
        ScalarLike, Tensor, np.ndarray, pd.DataFrame, list, tuple, set, frozenset, dict
    ],
    *,
    output_type: Callable[[Tuple[int, ...]], T] = identity,
    use_first_for_list_tuple: bool = False,
    return_indicator: Literal[False] = False,
    return_default_on_invalid: bool,
    default: U = (),
    return_valid: Optional[bool] = None,
) -> Union[T, U]: ...
# Indicator requested -> return_types.shape (valid flag + shape).
@overload
def get_shape(
    x: Union[
        ScalarLike, Tensor, np.ndarray, pd.DataFrame, list, tuple, set, frozenset, dict
    ],
    *,
    output_type: Callable[[Tuple[int, ...]], T] = identity,
    use_first_for_list_tuple: bool = False,
    return_indicator: Literal[True],
    return_default_on_invalid: Literal[False] = False,
    default: Any = (),
    return_valid: Optional[bool] = None,
) -> return_types.shape[T]: ...
# Indicator requested and `default` may be returned for invalid data.
@overload
def get_shape(
    x: Union[
        ScalarLike, Tensor, np.ndarray, pd.DataFrame, list, tuple, set, frozenset, dict
    ],
    *,
    output_type: Callable[[Tuple[int, ...]], T] = identity,
    use_first_for_list_tuple: bool = False,
    return_indicator: Literal[True],
    return_default_on_invalid: bool,
    default: U = (),
    return_valid: Optional[bool] = None,
) -> return_types.shape[Union[T, U]]: ...
def get_shape(
    x: Union[
        ScalarLike, Tensor, np.ndarray, pd.DataFrame, list, tuple, set, frozenset, dict
    ],
    *,
    output_type: Callable[[Tuple[int, ...]], T] = identity,
    use_first_for_list_tuple: bool = False,
    return_indicator: bool = False,
    return_default_on_invalid: bool = False,
    default: U = (),
    return_valid: Optional[bool] = None,
) -> Union[T, U, return_types.shape[Union[T, U]]]:
    """Compute the shape of a (possibly nested) value.

    Tensors, numpy arrays/scalars and DataFrames report their own ``shape``.
    Scalar-like values, sets, frozensets and dicts have shape ``()``. Lists
    and tuples are scanned recursively: their shape is their length followed
    by the common shape of their elements. An empty list/tuple has shape
    ``(0,)``.

    Args:
        x: Value to scan.
        output_type: Callable applied to the computed shape tuple when the
            data is homogeneous. defaults to identity (tuple of ints).
        use_first_for_list_tuple: If True, the shape of a list/tuple is
            derived from its first element only, as long as that element is
            itself homogeneous. defaults to False.
        return_indicator: If True, return ``return_types.shape(valid, shape)``
            instead of raising on heterogeneous data. defaults to False.
        return_default_on_invalid: If True (and return_indicator is False),
            return ``default`` instead of raising on heterogeneous data.
            defaults to False.
        default: Value produced for heterogeneous lists/tuples. It is NOT
            passed through ``output_type``. defaults to ().
        return_valid: Deprecated alias of return_indicator.

    Raises:
        ValueError: if the data is heterogeneous and neither return_indicator
            nor return_default_on_invalid is set.
        TypeError: if ``x`` (or a nested element) has an unsupported type.
    """
    if return_valid is not None:
        warn_once(f"Deprecated argument {return_valid=}. Use return_indicator instead.")
        return_indicator = return_valid
    del return_valid

    def _scan(
        item: Union[ScalarLike, Tensor, np.ndarray, Iterable],
    ) -> Tuple[bool, Union[Tuple[int, ...], U]]:
        # Returns (is_homogeneous, shape_or_default) for `item`.
        if is_scalar_like(item):
            return True, ()
        if isinstance(item, (Tensor, np.ndarray, np.generic, pd.DataFrame)):
            return True, tuple(item.shape)
        if isinstance(item, (set, frozenset, dict)):
            return True, ()
        if not isinstance(item, (list, tuple)):
            raise TypeError(f"Invalid argument type {type(item)}.")

        children = [_scan(child) for child in item]
        if not children:
            return True, (0,)

        child_valids = [ok for ok, _ in children]
        child_shapes = [shp for _, shp in children]
        trust_first = use_first_for_list_tuple and child_valids[0]
        if not trust_first and not (
            all(child_valids) and builtin_all_eq(child_shapes)
        ):
            return False, default
        return True, (len(child_shapes),) + child_shapes[0]  # type: ignore

    valid, result = _scan(x)
    if not (valid or return_indicator or return_default_on_invalid):
        raise ValueError(
            f"Invalid argument {x}. (cannot compute shape for heterogeneous data)"
        )

    if valid:
        result = output_type(result)  # type: ignore
    return return_types.shape(valid, result) if return_indicator else result  # type: ignore
# Public alias of get_shape. The stub body `...` is wrapped by
# function_alias(get_shape); presumably the decorator forwards all calls to
# get_shape — confirm in pythonwrench.functools.function_alias.
[docs] @function_alias(get_shape) def shape(*args, **kwargs): ...
def ranks(x: Tensor, dim: int = -1, descending: bool = False) -> LongTensor:
    """Return the rank of every element of ``x`` along ``dim``.

    Ranks are integers in ``[0, x.shape[dim])``. The smallest value gets rank
    0, or the largest when ``descending=True``. The ranks are the inverse of
    the sorting permutation, obtained by applying argsort twice. Tied values
    receive distinct ranks whose relative order depends on argsort's tie
    handling.

    Args:
        x: Input tensor.
        dim: Dimension along which to rank. defaults to -1.
        descending: If True, rank from largest to smallest. defaults to False.

    Returns:
        A long tensor with the same shape as ``x``.
    """
    sort_order = x.argsort(dim=dim, descending=descending)
    inverse_permutation = sort_order.argsort(dim=dim)
    return inverse_permutation  # type: ignore
def nelement(x: Union[ScalarLike, Tensor, np.ndarray, Iterable]) -> int:
    """Count the elements of a tensor-like object.

    Tensors and numpy arrays/scalars report their own element count. Any
    other supported value (builtins, nested lists/tuples, DataFrames) is
    measured as the product of its ``get_shape`` result.

    Args:
        x: Value to measure.

    Returns:
        Number of elements.
    """
    if isinstance(x, (np.ndarray, np.generic)):
        return x.size
    if isinstance(x, Tensor):
        return x.nelement()
    return builtin_prod(get_shape(x))
# Typing overloads for prod.
# Tensor/ndarray input -> result of the same array type.
@overload
def prod(
    x: T_TensorOrArray,
    *,
    dim: Optional[int] = None,
    start: Any = 1,
) -> T_TensorOrArray: ...
# Iterable of builtin numbers -> builtin number (dim must stay None).
@overload
def prod(
    x: Iterable[T_BuiltinNumber],
    *,
    dim: Any = None,
    start: T_BuiltinNumber = 1,
) -> T_BuiltinNumber: ...
def prod(
    x: Union[TensorOrArray, Iterable[T_BuiltinNumber]],
    *,
    dim: Optional[int] = None,
    start: T_BuiltinNumber = 1,
) -> Union[Tensor, T_BuiltinNumber]:
    """Multiply together the elements of ``x``, starting from ``start``.

    Args:
        x: A Tensor, a numpy array, or any iterable of builtin numbers.
        dim: Dimension/axis to reduce for Tensors and arrays. All elements
            are reduced when None. Must be None for plain iterables.
            defaults to None.
        start: Initial value multiplied into the product. defaults to 1.

    Returns:
        A Tensor for Tensor input, a numpy value/array for ndarray input, or
        a builtin number for iterable input.

    Raises:
        ValueError: if ``dim`` is given together with a plain iterable.
        TypeError: if ``x`` is neither a Tensor, an ndarray nor an Iterable.
    """
    if isinstance(x, Tensor):
        reduced = torch.prod(x) if dim is None else torch.prod(x, dim=dim)
        return reduced * start

    if isinstance(x, np.ndarray):
        return np.prod(x, axis=dim, initial=start)

    if not isinstance(x, Iterable):
        raise TypeError(
            f"Invalid argument type {type(x)=}. (expected Tensor, ndarray or Iterable)"
        )

    if dim is not None:
        raise ValueError(f"Invalid argument {dim=}. (expected None with {type(x)=})")
    return builtin_prod(x, start=start)  # type: ignore
def average_power(
    x: T_TensorOrArray,
    dim: Union[int, Tuple[int, ...], None] = -1,
) -> T_TensorOrArray:
    """Return the mean of ``|x| ** 2`` along ``dim`` (the signal's average power).

    Works with both Tensors and numpy arrays, since it only uses ``abs``,
    ``**`` and ``.mean``. Complex inputs use their magnitude.

    Args:
        x: Signal (Tensor or ndarray).
        dim: Dimension(s) to average over. defaults to -1.

    Returns:
        The averaged squared magnitude, same array type as ``x``.
    """
    squared_magnitude = abs(x) ** 2
    return squared_magnitude.mean(dim)  # type: ignore
def mse(
    x1: Tensor,
    x2: Tensor,
    *,
    dim: Union[int, Tuple[int, ...], None] = None,
) -> Tensor:
    """Mean squared error function.

    Computes ``mean((x1 - x2) ** 2)``. No square root is applied; see rmse for
    the root variant.

    Args:
        x1: First tensor.
        x2: Second tensor (broadcastable with ``x1``).
        dim: Dimension(s) to average over. All elements are averaged when
            None. defaults to None.

    Returns:
        The mean squared error, reduced over ``dim``.
    """
    squared_error = (x1 - x2) ** 2
    if dim is None:
        # `.mean()` without arguments reduces every element on all torch
        # versions (the original code gated `.mean(None)` on torch >= 2.0).
        return squared_error.mean()
    return squared_error.mean(dim)  # type: ignore
def rmse(
    x1: Tensor,
    x2: Tensor,
    *,
    dim: Union[int, Tuple[int, ...], None] = None,
) -> Tensor:
    """Root mean squared error function.

    Computes ``sqrt(mean((x1 - x2) ** 2))``. It is computed directly rather
    than as ``mse(...).sqrt()``, so the root is taken exactly once whatever
    ``mse`` returns.

    Args:
        x1: First tensor.
        x2: Second tensor (broadcastable with ``x1``).
        dim: Dimension(s) to average over. All elements are averaged when
            None. defaults to None.

    Returns:
        The root mean squared error, reduced over ``dim``.
    """
    squared_error = (x1 - x2) ** 2
    # `.mean()` without arguments reduces every element on all torch versions.
    mean_squared = squared_error.mean() if dim is None else squared_error.mean(dim)
    return mean_squared.sqrt()  # type: ignore
def deep_equal(x: T, y: T, *args: T) -> bool:
    """Recursively compare ``x`` against ``y`` and every extra argument.

    Supports scalar-like values, numpy arrays, Tensors, DataFrames, Mappings
    and sized iterables. Unlike ``==``, NaN values are treated as equal to
    each other. Tensors and arrays with different shapes compare as False
    instead of raising.

    Returns:
        True only if ``x`` deep-equals every other argument. Comparison stops
        at the first mismatch.
    """
    return all(_deep_equal_binary(x, other) for other in (y, *args))
def _deep_equal_binary(x: T, y: T) -> bool:
    """Deep-compare exactly two objects (helper of deep_equal).

    NaNs compare equal to NaNs. Containers are compared recursively through
    deep_equal. Mappings are compared by key, so insertion order does not
    matter (as with builtin dict equality).
    """
    if is_scalar_like(x) and is_scalar_like(y):
        x_isnan = math.isnan(x) if F.is_floating_point(x) else False
        y_isnan = math.isnan(y) if F.is_floating_point(y) else False
        return (x_isnan and y_isnan) or F.to_item(x == y)  # type: ignore

    if isinstance(x, Tensor) and isinstance(y, Tensor):
        x_isnan = x.isnan()
        y_isnan = y.isnan()
        # Shape check first so the mask comparisons below never broadcast.
        return (
            (x.shape == y.shape)
            and bool((x_isnan == y_isnan).all().item())
            and torch.equal(x[~x_isnan], y[~y_isnan])
        )

    if isinstance(x, np.ndarray) and isinstance(y, np.ndarray):
        x_isnan = (
            np.isnan(x) if F.is_floating_point(x) else np.full(x.shape, False, dtype=bool)
        )
        y_isnan = (
            np.isnan(y) if F.is_floating_point(y) else np.full(y.shape, False, dtype=bool)
        )
        return (
            (x.shape == y.shape)
            and (x_isnan == y_isnan).all().item()
            and np.equal(x[~x_isnan], y[~y_isnan]).all().item()
        )

    if isinstance(x, pd.DataFrame) and isinstance(y, pd.DataFrame):
        if not (deep_equal(x.index, y.index) and deep_equal(x.columns, y.columns)):
            return False
        x_isna = x.isna()
        y_isna = y.isna()
        return (x_isna == y_isna).all(axis=None).item() and (
            (x == y) | x_isna | y_isna
        ).all(axis=None).item()  # type: ignore

    if isinstance(x, Mapping) and isinstance(y, Mapping):
        if len(x) != len(y):
            return False
        # Same keys in the same order: keeps every result the previous
        # ordered comparison accepted.
        if deep_equal(list(x.items()), list(y.items())):
            return True
        # Otherwise match values by key, so that insertion order does not
        # matter (previously {"a": 1, "b": 2} != {"b": 2, "a": 1}).
        return all(key in y and deep_equal(x[key], y[key]) for key in x)

    if isinstance(x, SupportsIterLen) and isinstance(y, SupportsIterLen):
        return len(x) == len(y) and all(deep_equal(xi, yi) for xi, yi in zip(x, y))

    return (type(x) is type(y)) and (x == y)  # type: ignore


# Typing overloads for stack: stacking N-d subclass tensors yields (N+1)-d.
@overload
def stack(
    tensors: Union[List[Tensor0D], Tuple[Tensor0D, ...]],
    dim: int = 0,
    *,
    out: Optional[Tensor1D] = None,
) -> Tensor1D: ...
@overload
def stack(
    tensors: Union[List[Tensor1D], Tuple[Tensor1D, ...]],
    dim: int = 0,
    *,
    out: Optional[Tensor2D] = None,
) -> Tensor2D: ...
@overload
def stack(
    tensors: Union[List[Tensor2D], Tuple[Tensor2D, ...]],
    dim: int = 0,
    *,
    out: Optional[Tensor3D] = None,
) -> Tensor3D: ...
# Fallback overload for generic tensors.
@overload
def stack(
    tensors: Union[List[Tensor], Tuple[Tensor, ...]],
    dim: int = 0,
    *,
    out: Optional[Tensor] = None,
) -> Tensor: ...
def stack(
    tensors,
    dim: int = 0,
    *,
    out: Optional[Tensor] = None,
) -> Tensor:  # type: ignore
    """Join same-shaped tensors along a new dimension ``dim``.

    Thin wrapper around ``torch.stack``. It exists so that the typed
    overloads can narrow the result type (e.g. 1-d inputs give a 2-d output).

    Args:
        tensors: List or tuple of tensors with identical shapes.
        dim: Index of the new dimension. defaults to 0.
        out: Optional output tensor. defaults to None.

    Returns:
        The stacked tensor.
    """
    stacked = torch.stack(tensors, dim, out=out)
    return stacked
def cat(
    tensors: Union[List[Tensor], Tuple[Tensor, ...]],
    dim: int = 0,
    *,
    out: Optional[Tensor] = None,
) -> Tensor:
    """Concatenate tensors along an existing dimension ``dim``.

    Thin wrapper around ``torch.cat``.

    Args:
        tensors: List or tuple of tensors whose shapes match except along
            ``dim``.
        dim: Dimension to concatenate along. defaults to 0.
        out: Optional output tensor. defaults to None.

    Returns:
        The concatenated tensor.
    """
    joined = torch.cat(tensors, dim, out=out)
    return joined
# Public alias of cat. The stub body `...` is wrapped by function_alias(cat);
# presumably the decorator forwards all calls to cat — confirm in
# pythonwrench.functools.function_alias.
[docs] @function_alias(cat) def concat(*args, **kwargs): ...