#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Any, Iterable, List, Sized, Tuple, TypeVar, Union, overload
import torch
from pythonwrench.collections import all_eq as builtin_all_eq
from pythonwrench.collections import all_ne as builtin_all_ne
from pythonwrench.collections import is_sorted as builtin_is_sorted
from pythonwrench.functools import function_alias
from pythonwrench.typing import is_builtin_number, isinstance_generic
from torch import Tensor, is_tensor # noqa: F401
from typing_extensions import TypeGuard
from torchwrench.extras.numpy import (
ACCEPTED_NUMPY_DTYPES,
np,
numpy_all_eq,
numpy_all_ne,
numpy_is_complex,
numpy_is_floating_point,
)
from torchwrench.nn.functional.others import nelement
from torchwrench.types._typing import (
ComplexFloatingTensor,
FloatingTensor,
ScalarLike,
T_TensorOrArray,
Tensor0D,
TensorOrArray,
)
from torchwrench.types.guards import is_scalar_like
T = TypeVar("T")
U = TypeVar("U")
[docs]
def is_stackable(
    tensors: Union[List[Any], Tuple[Any, ...]],
) -> TypeGuard[Union[List[Tensor], Tuple[Tensor, ...]]]:
    """Tell whether the input can be given to `torch.stack`.

    The input is accepted when it is a non-empty list or tuple that contains
    only tensors, all of them sharing the same shape.

    Args:
        tensors: Candidate list or tuple to inspect.

    Returns:
        True if the input is a non-empty homogeneous list/tuple of same-shape
        tensors, False otherwise.
    """
    is_tensor_seq = isinstance_generic(tensors, List[Tensor]) or isinstance_generic(
        tensors, Tuple[Tensor, ...]
    )
    if not is_tensor_seq:
        return False
    if not tensors:
        # torch.stack needs at least one tensor.
        return False

    first, *others = tensors
    reference_shape = first.shape
    for tensor in others:
        if tensor.shape != reference_shape:
            return False
    return True
[docs]
def is_convertible_to_tensor(x: Any) -> bool:
    """Tell whether the input can be given to `torch.as_tensor`.

    Tensors are always accepted; any other value is checked recursively.
    Heterogeneous nested inputs such as `[[], 1]` are rejected by this
    function, even though (per the original author's note) this kind of value
    can be accepted by `torch.as_tensor`.

    Args:
        x: Value to inspect.

    Returns:
        True if the value is considered convertible, False otherwise.
    """
    return isinstance(x, Tensor) or __can_be_converted_to_tensor_nested(x)
def __can_be_converted_to_tensor_list_tuple(x: Union[List, Tuple]) -> bool:
    """Check whether a list or tuple has a regular, tensor-convertible layout.

    An empty sequence is accepted. Otherwise every item must itself be
    convertible, and the items must either all be scalar-like (no usable
    length) or all be sized containers with the same length.

    Args:
        x: List or tuple to inspect.

    Returns:
        True if the sequence is considered convertible, False otherwise.
    """
    if len(x) == 0:
        return True
    if not all(__can_be_converted_to_tensor_nested(xi) for xi in x):
        return False

    def has_no_length(xi: Any) -> bool:
        # 0-d tensors and 0-d arrays define `__len__` (so they pass the
        # `Sized` check) but raise TypeError when `len()` is actually called.
        # They must therefore be classified as scalar-like explicitly.
        if not isinstance(xi, Sized) or isinstance(xi, Tensor0D):
            return True
        return isinstance(xi, np.ndarray) and xi.ndim == 0

    scalar_flags = [has_no_length(xi) for xi in x]

    # Only scalar-like items: always a regular layout.
    if all(scalar_flags):
        return True
    # Mix of scalar-like and sized items: ragged layout, rejected without
    # ever calling len() on an unsized object.
    if any(scalar_flags):
        return False

    # Only sized items: they must all share the same length.
    len0 = len(x[0])
    return all(len(xi) == len0 for xi in x[1:])
def __can_be_converted_to_tensor_nested(
    x: Any,
) -> bool:
    """Recursively check whether a non-tensor value is tensor-convertible.

    Accepted values are builtin numbers, `Tensor0D` instances, numpy arrays or
    numpy scalars whose dtype belongs to `ACCEPTED_NUMPY_DTYPES`, and lists or
    tuples validated by `__can_be_converted_to_tensor_list_tuple`.

    Args:
        x: Value to inspect.

    Returns:
        True if the value is considered convertible, False otherwise.
    """
    if is_builtin_number(x) or isinstance(x, Tensor0D):
        return True
    if isinstance(x, (np.ndarray, np.generic)) and x.dtype in ACCEPTED_NUMPY_DTYPES:
        return True
    if isinstance(x, (list, tuple)):
        return __can_be_converted_to_tensor_list_tuple(x)
    return False
@overload
def is_floating_point(x: Tensor) -> TypeGuard[FloatingTensor]: ...
@overload
def is_floating_point(x: np.ndarray) -> TypeGuard[np.ndarray]: ...
@overload
def is_floating_point(x: float) -> TypeGuard[float]: ...
@overload
def is_floating_point(
    x: Any,
) -> TypeGuard[Union[FloatingTensor, np.ndarray, float]]: ...


def is_floating_point(x: Any) -> TypeGuard[Union[FloatingTensor, np.ndarray, float]]:
    """Tell whether the input is, or holds, floating-point data.

    Tensors are checked with `Tensor.is_floating_point`, numpy arrays and
    numpy scalars with `numpy_is_floating_point`, and any other object is
    accepted only when it is a `float` instance.

    Args:
        x: Value to inspect.

    Returns:
        True for floating-point tensors, arrays, numpy scalars or floats.
    """
    if isinstance(x, Tensor):
        return x.is_floating_point()
    if isinstance(x, (np.ndarray, np.generic)):
        # Checked before the plain float test so numpy scalars are always
        # handled by the numpy helper.
        return numpy_is_floating_point(x)
    return isinstance(x, float)
@overload
def is_complex(x: Tensor) -> TypeGuard[ComplexFloatingTensor]: ...
@overload
def is_complex(x: np.ndarray) -> TypeGuard[np.ndarray]: ...
@overload
def is_complex(x: complex) -> TypeGuard[complex]: ...
@overload
def is_complex(
    x: Any,
) -> TypeGuard[Union[ComplexFloatingTensor, np.ndarray, complex]]: ...


def is_complex(x: Any) -> TypeGuard[Union[ComplexFloatingTensor, np.ndarray, complex]]:
    """Tell whether the input is, or holds, complex-valued data.

    Tensors are checked with `Tensor.is_complex`, numpy arrays and numpy
    scalars with `numpy_is_complex`, and any other object is accepted only
    when it is a `complex` instance.

    Args:
        x: Value to inspect.

    Returns:
        True for complex tensors, arrays, numpy scalars or Python complex.
    """
    if isinstance(x, Tensor):
        return x.is_complex()
    if isinstance(x, (np.ndarray, np.generic)):
        # Checked before the plain complex test so numpy scalars are always
        # handled by the numpy helper.
        return numpy_is_complex(x)
    return isinstance(x, complex)
[docs]
def is_sorted(
    x: Union[Tensor, np.ndarray, Iterable],
    *,
    reverse: bool = False,
    strict: bool = False,
) -> bool:
    """Tell whether a 1-D tensor, 1-D array or iterable is sorted.

    Args:
        x: Sequence to inspect. Tensors and arrays must have exactly one dim.
        reverse: If True, check for descending order instead of ascending.
        strict: If True, consecutive equal values are not allowed.

    Returns:
        True if every consecutive pair respects the requested order.

    Raises:
        ValueError: If a tensor or array does not have exactly one dimension.
        TypeError: If the argument is neither a tensor, an array nor iterable.
    """
    if not isinstance(x, (Tensor, np.ndarray)):
        if isinstance(x, Iterable):
            return builtin_is_sorted(x, reverse=reverse, strict=strict)
        raise TypeError(f"Invalid argument type {type(x)=}.")

    if x.ndim != 1:
        raise ValueError(f"Invalid number of dims in argument {x.ndim=}. (expected 1)")

    # Compare each element with its successor in a single vectorized pass.
    head = x[:-1]
    tail = x[1:]
    if reverse:
        ordered = head > tail if strict else head >= tail
    else:
        ordered = head < tail if strict else head <= tail
    return ordered.all().item()  # type: ignore
@overload
def all_eq(
    x: Union[Tensor, np.ndarray, ScalarLike, Iterable],
    dim: None = None,
) -> bool: ...
@overload
def all_eq(
    x: T_TensorOrArray,
    dim: int,
) -> T_TensorOrArray: ...


def all_eq(
    x: Union[T_TensorOrArray, ScalarLike, Iterable],
    dim: Union[int, None] = None,
) -> Union[bool, T_TensorOrArray]:
    """Check if all elements are equal in a tensor, ndarray, iterable or scalar.

    Args:
        x: Object to inspect.
        dim: Optional dimension along which the comparison is reduced. Only
            supported for tensors and numpy arrays. Negative values index
            from the last dimension.

    Returns:
        A bool when `dim` is None. Otherwise a boolean tensor/array with
        `dim` reduced, telling for each slice whether its values are equal.

    Raises:
        ValueError: If `dim` is given for a non-tensor, non-array input.
        TypeError: If the argument type is not supported.
    """
    if isinstance(x, Tensor):
        if dim is None:
            if x.ndim == 0 or x.nelement() == 0:
                return True
            flat = x.reshape(-1)
            return (flat[0] == flat[1:]).all().item()  # type: ignore

        # Take the first slice along `dim` and keep it as a size-1 axis so it
        # broadcasts against `x`. select/unsqueeze resolve negative dims
        # themselves; a hand-built index list with `insert(dim + 1, None)`
        # would place the new axis at the front when dim == -1.
        reference = x.select(dim, 0).unsqueeze(dim)
        return (x == reference).all(dim)  # type: ignore

    if isinstance(x, (np.ndarray, np.generic)):
        return numpy_all_eq(x, dim=dim)  # type: ignore
    if dim is not None:
        raise ValueError(f"Invalid argument {dim=} with {type(x)=}.")
    if is_scalar_like(x):
        return True
    if isinstance(x, Iterable):
        return builtin_all_eq(x)
    raise TypeError(f"Invalid argument type {type(x)=}.")
[docs]
def all_ne(x: Union[Tensor, np.ndarray, ScalarLike, Iterable]) -> bool:
    """Check that no value appears twice in a tensor, ndarray, iterable or scalar.

    Args:
        x: Object to inspect.

    Returns:
        True if all elements are pairwise different. Scalar-like inputs are
        always accepted.

    Raises:
        TypeError: If the argument type is not supported.
    """
    if isinstance(x, Tensor):
        # As many distinct values as elements means there is no duplicate.
        n_distinct = torch.unique(x).shape[0]
        return n_distinct == x.nelement()
    if isinstance(x, (np.ndarray, np.generic)):
        return numpy_all_ne(x)
    if is_scalar_like(x):
        return True
    if isinstance(x, Iterable):
        return builtin_all_ne(x)
    raise TypeError(f"Invalid argument type {type(x)=}.")
[docs]
def is_full(x: TensorOrArray, target: Any = ...) -> bool:
    """Check if all elements of a tensor or array are equal.

    Args:
        x: Tensor or array to inspect.
        target: Optional expected value. When given, the elements must also be
            equal to it. The default `...` is a sentinel meaning "no expected
            value".

    Returns:
        True if all elements are equal (and equal to `target` when given).
        For an empty input, True only when no `target` is given.
    """
    if nelement(x) == 0:
        # No element can match an explicit target; without a target an empty
        # input is accepted.
        return target is ...

    if target is not ...:
        first_elem = x[(0,) * x.ndim]
        # Convert explicitly: the comparison yields a 0-d tensor / numpy bool,
        # which must not be returned as-is from a function annotated `bool`.
        if not bool(first_elem == target):
            return False

    return bool(all_eq(x))
[docs]
# Public name `is_unique`, declared by applying `function_alias(all_ne)` to a
# stub whose own body is only `...`.
# NOTE(review): presumably calls are forwarded unchanged to `all_ne` — confirm
# against the pythonwrench `function_alias` documentation.
@function_alias(all_ne)
def is_unique(*args, **kwargs): ...