Source code for torchwrench.extras.numpy.functional

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Any, Iterable, List, Literal, Tuple, Union, overload

import torch
from pythonwrench import (
    BuiltinScalar,
    function_alias,
    get_current_fn_name,
    is_builtin_scalar,
    prod,
    reduce_and,
    reduce_or,
)
from pythonwrench.semver import Version
from torch import Tensor
from typing_extensions import TypeGuard

from torchwrench.core.make import DeviceLike, DTypeLike, as_device, as_dtype
from torchwrench.extras.numpy.definitions import NumpyNumberLike, NumpyScalarLike, np


def to_ndarray(
    x: Union[Tensor, np.ndarray, Iterable, BuiltinScalar],
    *,
    dtype: Union[str, np.dtype, None] = None,
    force: bool = False,
) -> np.ndarray:
    """Build a numpy array from a tensor, an array, an iterable or a scalar.

    Args:
        x: Value to convert. Tensors are delegated to ``tensor_to_ndarray``;
            everything else is handed to ``np.array``.
        dtype: Optional numpy dtype of the result.
        force: Only forwarded on the tensor path; ignored for other inputs.

    Returns:
        The converted numpy array.
    """
    if not isinstance(x, Tensor):
        # Non-tensor inputs never look at `force`.
        return np.array(x, dtype=dtype)  # type: ignore
    return tensor_to_ndarray(x, dtype=dtype, force=force)
def tensor_to_ndarray(
    x: Tensor,
    *,
    dtype: Union[str, np.dtype, None] = None,
    force: bool = False,
) -> np.ndarray:
    """Convert PyTorch tensor to numpy array.

    Args:
        x: Tensor to convert. It is detached and moved to CPU first.
        dtype: Optional numpy dtype; when given, the result is cast with
            ``ndarray.astype``.
        force: Forwarded as ``Tensor.numpy(force=...)`` when the installed
            torch version is at least 1.13. Must stay False on older versions.

    Returns:
        The numpy array holding the tensor data.

    Raises:
        ValueError: If ``force`` is True while torch is older than 1.13.
    """
    if Version(str(torch.__version__)) >= Version("1.13.0"):
        kwargs = dict(force=force)
    elif not force:
        # Older torch has no `force` keyword, so nothing is forwarded.
        kwargs = dict()
    else:
        # Only force=True reaches this branch, hence the expected value is False.
        msg = f"Invalid argument {force=} for {get_current_fn_name()}. (expected False because torch version is below 1.13)"
        raise ValueError(msg)

    x_arr: np.ndarray = x.detach().cpu().numpy(**kwargs)
    if dtype is not None:  # supports older numpy version
        x_arr = x_arr.astype(dtype=dtype)  # type: ignore
    return x_arr
def ndarray_to_tensor(
    x: Union[np.ndarray, np.number],
    *,
    device: DeviceLike = None,
    dtype: DTypeLike = None,
) -> Tensor:
    """Convert numpy array or numpy number to PyTorch tensor.

    Args:
        x: Numpy array or numpy scalar number to convert.
        device: Target device, normalized with ``as_device``.
        dtype: Target dtype, normalized with ``as_dtype``.

    Returns:
        The tensor built from ``x``, moved/cast with ``Tensor.to``.
    """
    device = as_device(device)
    dtype = as_dtype(dtype)
    if not isinstance(x, np.ndarray):
        # torch.from_numpy only accepts ndarrays: wrap numpy scalars in a
        # zero-dimensional array so the annotated np.number input works.
        x = np.asarray(x)
    return torch.from_numpy(x).to(dtype=dtype, device=device)
def numpy_view_as_real(x: np.ndarray) -> np.ndarray:
    """Expose a complex array as a float array with a trailing axis of size 2.

    Args:
        x: The input complex array of any shape (...,)

    Returns:
        x_real: The same data in a float array of shape (..., 2), where the
            last axis holds the real then the imaginary part.
    """
    assert numpy_is_complex(x)
    real_dtype = numpy_complex_dtype_to_float_dtype(x.dtype)

    if x.ndim == 0:
        # note: rebuild array here because view does not work on 0d arrays
        parts = [x.real, x.imag]
        return np.array(parts, dtype=real_dtype)  # type: ignore

    target_shape = tuple(x.shape) + (2,)
    return x.view(real_dtype).reshape(target_shape)
def numpy_complex_dtype_to_float_dtype(dtype: np.dtype) -> np.dtype:
    """Map a complex dtype to the float dtype of its real part.

    A non-complex dtype is returned unchanged, since the real part of such an
    array keeps the same dtype.
    """
    # An empty probe array lets numpy resolve the dtype of the real part.
    probe = np.empty((0,), dtype=dtype)
    real_part = probe.real
    return real_part.dtype  # type: ignore
def numpy_view_as_complex(x: np.ndarray) -> np.ndarray:
    """Convert float array to complex array.

    Args:
        x: The input float array of any shape (..., 2), where index 0 of the
            last axis is the real part and index 1 the imaginary part.

    Returns:
        x_complex: The same data in a complex array of shape (...,)
    """
    assert not numpy_is_complex(x)
    real_part = x[..., 0]
    imag_part = x[..., 1]
    return real_part + imag_part * 1j
def numpy_is_floating_point(x: Union[np.ndarray, np.generic]) -> bool:
    """Tell whether the dtype of ``x`` has the floating-point kind code ``"f"``."""
    kind = x.dtype.kind
    return kind == "f"
def numpy_is_complex(x: Union[np.ndarray, np.generic]) -> bool:
    """Tell whether ``x`` has a complex dtype, as reported by ``np.iscomplexobj``."""
    is_complex = np.iscomplexobj(x)
    return is_complex
def numpy_is_complex_dtype(dtype: np.dtype) -> bool:
    """Tell whether ``dtype`` is a complex dtype.

    The check is done with ``np.iscomplexobj`` on an empty array of that dtype.
    """
    probe = np.empty((0,), dtype=dtype)
    return np.iscomplexobj(probe)
[docs] def is_numpy_bool_array(x: Any) -> TypeGuard[Union[np.bool_, np.ndarray]]: return isinstance(x, (np.generic, np.ndarray)) and x.dtype.kind == "b"
[docs] def is_numpy_str_array(x: Any) -> TypeGuard[Union[np.str_, np.ndarray]]: return isinstance(x, (np.generic, np.ndarray)) and x.dtype.kind in ("U", "S")
[docs] def is_numpy_integral_array(x: Any) -> TypeGuard[Union[np.ndarray, np.generic]]: return isinstance(x, (np.generic, np.ndarray)) and issubclass(x.dtype, np.integer)
[docs] def is_numpy_number_like(x: Any) -> TypeGuard[NumpyNumberLike]: """Returns True if x is an instance of a numpy number type, a np.bool_ or a zero-dimensional numpy array. If numpy is not installed, this function always returns False. """ return isinstance(x, (np.number, np.bool_)) or ( isinstance(x, np.ndarray) and x.ndim == 0 and np.issubdtype(x.dtype, np.number) )
[docs] def is_numpy_scalar_like(x: Any) -> TypeGuard[NumpyScalarLike]: """Returns True if x is an instance of a numpy number type or a zero-dimensional numpy array. If numpy is not installed, this function always returns False. """ return isinstance(x, np.generic) or (isinstance(x, np.ndarray) and x.ndim == 0)
def numpy_topk(
    x: np.ndarray,
    k: int,
    dim: int = -1,
    largest: bool = True,
    sorted: bool = True,
) -> Tuple[np.ndarray, np.ndarray]:
    """Top-k selection on a numpy array, computed through ``Tensor.topk``.

    Args:
        x: Input array, converted to a tensor with ``ndarray_to_tensor``.
        k: Number of elements to keep along ``dim``.
        dim: Axis along which to select.
        largest: Forwarded to ``Tensor.topk``.
        sorted: Forwarded to ``Tensor.topk``.

    Returns:
        A ``(values, indices)`` pair of numpy arrays.
    """
    top_values, top_indices = ndarray_to_tensor(x).topk(
        k=k,
        dim=dim,
        largest=largest,
        sorted=sorted,
    )
    return tensor_to_ndarray(top_values), tensor_to_ndarray(top_indices)
def numpy_item(x: Union[np.ndarray, np.generic, BuiltinScalar]) -> np.generic:
    """Extract the single element of ``x`` as a numpy scalar.

    Numpy scalars are returned as-is, builtin scalars are converted through a
    zero-dimensional array, and arrays must contain exactly one element.

    Raises:
        ValueError: If ``x`` is an array that does not hold exactly one element.
    """
    if isinstance(x, np.generic):
        return x

    if is_builtin_scalar(x, strict=True):
        # Indexing a 0-d array with an empty tuple yields the numpy scalar.
        wrapped = np.array(x)
        return wrapped[()]  # type: ignore

    if prod(x.shape) != 1:
        raise ValueError(
            f"Invalid argument shape {x.shape=}. (expected nd-array with 1 element)"
        )

    first = (0,) * x.ndim
    return x[first]  # type: ignore
@overload
def numpy_all_eq(
    x: Union[np.generic, np.ndarray],
    dim: Literal[None] = None,
) -> bool: ...


@overload
def numpy_all_eq(
    x: Union[np.generic, np.ndarray],
    dim: int,
) -> np.ndarray: ...


def numpy_all_eq(
    x: Union[np.generic, np.ndarray],
    dim: Union[int, None] = None,
) -> Union[bool, np.ndarray]:
    """Check whether all elements of ``x`` are equal.

    Args:
        x: Numpy scalar or array to inspect.
        dim: When None, all elements of ``x`` are compared together. Otherwise
            elements are compared along this axis; negative values are
            supported.

    Returns:
        A bool when ``x`` is a numpy scalar or ``dim`` is None (True for empty
        or zero-dimensional arrays). Otherwise a boolean array with axis
        ``dim`` reduced, which is all True when that axis is empty.

    Raises:
        IndexError: If ``dim`` is out of range for ``x``.
    """
    if isinstance(x, np.generic):
        return True

    if dim is None:
        if x.ndim == 0 or x.size == 0:
            return True
        # Compare every element to the first one; bool() matches the
        # annotated return type instead of leaking a np.bool_.
        return bool((x.flat[0] == x.flat[1:]).all())

    # Keep `dim` as a size-1 axis so the reference slice broadcasts against x.
    # A length-1 slice works for negative `dim` too, whereas inserting a new
    # axis at `dim + 1` puts it at position 0 when dim == -1.
    indexer: List[Union[int, slice, None]] = [slice(None) for _ in range(x.ndim)]
    indexer[dim] = slice(0, 1)
    reference = x[tuple(indexer)]
    return (x == reference).all(dim)
def numpy_all_ne(x: Union[np.generic, np.ndarray]) -> bool:
    """Tell whether every element of ``x`` is distinct from all the others.

    The number of unique values is compared with the total number of elements.
    """
    unique_values = np.unique(x)
    return unique_values.size == x.size
# Declared as an alias of pythonwrench's `reduce_and` through `function_alias`;
# the stub body is never meant to hold logic.
# NOTE(review): presumably calls are forwarded to `reduce_and` — confirm
# against the `function_alias` implementation.
@function_alias(reduce_and)
def logical_and_lst(*args, **kwargs): ...


# Declared as an alias of pythonwrench's `reduce_or` through `function_alias`.
# NOTE(review): presumably calls are forwarded to `reduce_or` — confirm
# against the `function_alias` implementation.
@function_alias(reduce_or)
def logical_or_lst(*args, **kwargs): ...
# Declared as an alias of `to_ndarray` through `function_alias`.
# NOTE(review): presumably calls are forwarded to `to_ndarray` — confirm
# against the `function_alias` implementation.
@function_alias(to_ndarray)
def to_numpy(*args, **kwargs): ...
# Declared as an alias of `tensor_to_ndarray` through `function_alias`.
# NOTE(review): presumably calls are forwarded to `tensor_to_ndarray` —
# confirm against the `function_alias` implementation.
@function_alias(tensor_to_ndarray)
def tensor_to_numpy(*args, **kwargs): ...
# Declared as an alias of `ndarray_to_tensor` through `function_alias`.
# NOTE(review): presumably calls are forwarded to `ndarray_to_tensor` —
# confirm against the `function_alias` implementation.
@function_alias(ndarray_to_tensor)
def numpy_to_tensor(*args, **kwargs): ...