Source code for torchwrench.extras.numpy.scan_info

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from dataclasses import dataclass
from typing import Any, Generic, Iterable, Tuple, TypeVar, Union

import pythonwrench as pw
import torch
from pythonwrench import BuiltinScalar, get_current_fn_name
from torch import Tensor

import torchwrench as tw
from torchwrench.extras.numpy.definitions import ACCEPTED_NUMPY_DTYPES, np

# Type of the `invalid` value that the torch dtype helpers in this module return
# when no torch dtype applies (e.g. `scan_torch_dtype` on a str).
T_Invalid = TypeVar("T_Invalid", covariant=True)
# Type of the `empty` value returned by the numpy dtype helpers for empty
# lists/tuples (`scan_numpy_dtype`, `merge_numpy_dtypes`).
T_EmptyNp = TypeVar("T_EmptyNp", covariant=True)
# Type of the `empty` value returned by `scan_torch_dtype` for empty lists/tuples.
T_EmptyTorch = TypeVar("T_EmptyTorch", covariant=True)


class InvalidTorchDType(metaclass=pw.Singleton):
    """Marker returned in place of a torch dtype when the data has none.

    An instance is the default ``invalid`` value of ``scan_torch_dtype``, which
    hands it back for data such as ``str`` that cannot be given a torch dtype.
    The class is created through the ``pw.Singleton`` metaclass.
    """

    def __repr__(self) -> str:
        # Rendered like a constructor call, e.g. "InvalidTorchDType()".
        class_name = self.__class__.__name__
        return f"{class_name}()"
@dataclass(frozen=True)
class ShapeDTypeInfo(Generic[T_Invalid, T_EmptyTorch, T_EmptyNp]):
    """Immutable record of the shape and dtypes found while scanning some data."""

    # Shape found for the scanned data.
    shape: Tuple[int, ...]
    # Torch dtype of the data, or the caller's "invalid"/"empty" placeholder.
    torch_dtype: Union[torch.dtype, T_Invalid, T_EmptyTorch]
    # Numpy dtype of the data, or the caller's "empty" placeholder.
    numpy_dtype: Union[np.dtype, T_EmptyNp]
    # Indicator stored next to the shape; scan_shape_dtypes fills it with the
    # flag returned by tw.get_shape(..., return_indicator=True).
    valid_shape: bool

    @property
    def fill_value(self) -> BuiltinScalar:
        """Fill value that ``numpy_dtype_to_fill_value`` gives for ``numpy_dtype``."""
        dtype = self.numpy_dtype
        return numpy_dtype_to_fill_value(dtype)

    @property
    def get_ndim(self) -> int:
        """Number of dimensions, i.e. the length of ``shape``."""
        ndim = len(self.shape)
        return ndim

    @property
    def kind(self) -> str:
        """Kind code of ``numpy_dtype``; ``"V"`` when it is not a ``np.dtype``."""
        dtype = self.numpy_dtype
        return dtype.kind if isinstance(dtype, np.dtype) else "V"
def scan_shape_dtypes(
    x: Any,
    *,
    accept_heterogeneous_shape: bool = False,
    empty_torch: T_EmptyTorch = None,
    empty_np: T_EmptyNp = np.dtype("V"),
) -> ShapeDTypeInfo[InvalidTorchDType, T_EmptyTorch, T_EmptyNp]:
    """Scan an input and return its shape together with its torch and numpy dtypes.

    Args:
        x: Data to inspect.
        accept_heterogeneous_shape: When False, a shape flagged as not valid by
            ``tw.get_shape`` raises instead of being reported.
        empty_torch: Value forwarded to ``scan_torch_dtype`` as ``empty``.
        empty_np: Value forwarded to ``scan_numpy_dtype`` as ``empty``.

    Returns:
        A ``ShapeDTypeInfo`` holding the shape, both dtypes and the validity flag.

    Raises:
        ValueError: If the shape is not valid and ``accept_heterogeneous_shape``
            is False.
    """
    valid_shape, shape = tw.get_shape(x, return_indicator=True)

    # Reject early: the dtype scans below are skipped when the shape is refused.
    if not (accept_heterogeneous_shape or valid_shape):
        msg = f"Invalid argument {x} for {get_current_fn_name()}. (cannot compute shape for heterogeneous data)"
        raise ValueError(msg)

    return ShapeDTypeInfo[InvalidTorchDType, T_EmptyTorch, T_EmptyNp](
        shape=shape,
        torch_dtype=scan_torch_dtype(x, empty=empty_torch),
        numpy_dtype=scan_numpy_dtype(x, empty=empty_np),
        valid_shape=valid_shape,
    )
def scan_torch_dtype(
    x: Any,
    *,
    invalid: T_Invalid = InvalidTorchDType(),
    empty: T_EmptyTorch = None,
) -> Union[torch.dtype, T_Invalid, T_EmptyTorch]:
    """Return the torch dtype of an arbitrary object, recursing into lists and tuples.

    Args:
        x: Python scalar, tensor, numpy array/scalar, str/bytes/bytearray, or a
            (possibly nested) list/tuple of those.
        invalid: Value returned when the data has no torch dtype (str, bytes,
            bytearray, or a numpy dtype rejected by ``numpy_dtype_to_torch_dtype``).
        empty: Value returned for an empty list or tuple.

    Returns:
        The torch dtype, or ``invalid`` / ``empty`` as described above.

    Raises:
        TypeError: If ``x`` is of none of the supported types.
    """
    # NOTE: the order of the checks is deliberate; numpy scalar types may also
    # be instances of the Python scalar or str types tested around them.
    if isinstance(x, (int, float, bool, complex)):
        return torch.as_tensor(x).dtype

    if isinstance(x, Tensor):
        return x.dtype

    if isinstance(x, (np.ndarray, np.generic)):
        return numpy_dtype_to_torch_dtype(x.dtype, invalid=invalid)

    if isinstance(x, (str, bytes, bytearray)):
        return invalid

    if not isinstance(x, (list, tuple)):
        msg = f"Unsupported type {x.__class__.__name__} in function {pw.get_current_fn_name()}."
        raise TypeError(msg)

    if len(x) == 0:
        return empty

    # Scan every element, then reduce the per-element results to a single dtype.
    item_dtypes = [scan_torch_dtype(item, invalid=invalid, empty=empty) for item in x]
    return merge_torch_dtypes(item_dtypes, invalid=invalid, empty=empty)
def scan_numpy_dtype(
    x: Any,
    *,
    empty: T_EmptyNp = np.dtype("V"),
) -> Union[np.dtype, T_EmptyNp]:
    """Return the numpy dtype of an arbitrary object, recursing into lists and tuples.

    Args:
        x: Python scalar, tensor, numpy array/scalar, str/bytes/bytearray, or a
            (possibly nested) list/tuple of those.
        empty: Value returned for an empty list or tuple.

    Returns:
        The numpy dtype of ``x``, or ``empty`` when there is nothing to scan.

    Raises:
        TypeError: If ``x`` is of none of the supported types.
    """
    # NOTE: the order of the checks is deliberate; numpy scalar types may also
    # be instances of the Python scalar or str types tested around them.
    if isinstance(x, (int, float, bool, complex)):
        return np.array(x).dtype

    if isinstance(x, Tensor):
        return torch_dtype_to_numpy_dtype(x.dtype)

    if isinstance(x, (np.ndarray, np.generic)):
        return x.dtype

    if isinstance(x, (str, bytes, bytearray)):
        return np.array(x).dtype

    if not isinstance(x, (list, tuple)):
        msg = f"Unsupported type {x.__class__.__name__} in function {pw.get_current_fn_name()}."
        raise TypeError(msg)

    if len(x) == 0:
        return empty

    # Scan every element, then reduce the per-element results to a single dtype.
    item_dtypes = [scan_numpy_dtype(item, empty=empty) for item in x]
    return merge_numpy_dtypes(item_dtypes, empty=empty)
def torch_dtype_to_numpy_dtype(dtype: torch.dtype) -> np.dtype:
    """Return the numpy dtype obtained by converting a tensor of ``dtype`` to an array.

    A zero-length tensor is built and passed through ``tw.tensor_to_ndarray``;
    the dtype of the resulting array is returned.
    """
    empty_tensor = torch.empty((0,), dtype=dtype)
    return tw.tensor_to_ndarray(empty_tensor).dtype
def numpy_dtype_to_torch_dtype(
    dtype: np.dtype,
    *,
    invalid: T_Invalid = InvalidTorchDType(),
) -> Union[torch.dtype, T_Invalid]:
    """Return the torch dtype obtained by converting an array of ``dtype`` to a tensor.

    Args:
        dtype: Numpy dtype to translate.
        invalid: Value returned when ``dtype`` is not in ``ACCEPTED_NUMPY_DTYPES``.

    Returns:
        The dtype of the tensor produced by ``tw.ndarray_to_tensor`` from a
        zero-length array of ``dtype``, or ``invalid``.
    """
    if dtype not in ACCEPTED_NUMPY_DTYPES:
        return invalid

    empty_array = np.empty((0,), dtype=dtype)
    return tw.ndarray_to_tensor(empty_array).dtype
def numpy_dtype_to_fill_value(dtype: Any) -> BuiltinScalar:
    """Return the fill value associated with a numpy dtype.

    The value is chosen from ``dtype.kind``: ``False`` for "b", ``0`` for "u"
    and "i", ``0.0`` for "f", ``0j`` for "c" and ``""`` for "U" and "S".

    Args:
        dtype: Numpy dtype to look up. Anything that is not a ``np.dtype``
            yields ``None``.

    Raises:
        ValueError: If ``dtype`` is a ``np.dtype`` whose kind is not listed above.
    """
    if not isinstance(dtype, np.dtype):
        return None

    # Insertion order matters: it is the order in which kinds appear in the error.
    fill_values = {
        "b": False,
        "u": 0,
        "i": 0,
        "f": 0.0,
        "c": 0j,
        "U": "",
        "S": "",
    }
    kind = dtype.kind
    if kind in fill_values:
        return fill_values[kind]

    KINDS = tuple(fill_values)
    msg = f"Invalid argument {dtype=}. (expected dtype.kind in {KINDS})"
    raise ValueError(msg)
def merge_numpy_dtypes(
    dtypes: Iterable[Union[np.dtype, T_EmptyNp]],
    *,
    empty: T_EmptyNp = np.dtype("V"),
) -> Union[np.dtype, T_EmptyNp]:
    """Reduce several numpy dtypes to the single dtype numpy uses when stacking them.

    Args:
        dtypes: Dtypes to merge; entries equal to ``empty`` are ignored.
        empty: Placeholder for "no dtype"; returned when nothing else remains.

    Returns:
        ``empty`` if no dtype is left after filtering, otherwise the dtype of
        ``np.stack`` applied to zero-length arrays of the remaining dtypes.
    """
    # dict.fromkeys removes duplicates while keeping first-seen order.
    unique_dtypes = dict.fromkeys(dtypes)
    candidates = [candidate for candidate in unique_dtypes if candidate != empty]
    if not candidates:
        return empty

    placeholders = [np.empty((0,), dtype=candidate) for candidate in candidates]  # type: ignore
    return np.stack(placeholders).dtype
def merge_torch_dtypes(
    dtypes: Iterable[Union[torch.dtype, T_Invalid, T_EmptyTorch]],
    *,
    invalid: T_Invalid = InvalidTorchDType(),
    empty: T_EmptyTorch = None,
) -> Union[torch.dtype, T_Invalid, T_EmptyTorch]:
    """Reduce several torch dtypes to the single dtype torch uses when stacking them.

    Args:
        dtypes: Dtypes to merge. Entries equal to ``empty`` are ignored.
        invalid: Placeholder for "no torch dtype"; if any remaining entry
            equals it, it is returned as the merged result.
        empty: Placeholder for "no data"; returned when nothing else remains.

    Returns:
        ``empty`` if no dtype is left after filtering, ``invalid`` if one of the
        remaining entries equals ``invalid``, otherwise the dtype of
        ``torch.stack`` applied to zero-length tensors of the remaining dtypes.
    """
    # The torch-side placeholder is typed with T_EmptyTorch, matching
    # scan_torch_dtype; this only affects annotations, not runtime behavior.

    # dict.fromkeys removes duplicates while keeping first-seen order.
    dtypes = list(dict.fromkeys(dtypes))
    dtypes = [dtype for dtype in dtypes if dtype != empty]
    if len(dtypes) == 0:
        return empty

    # An invalid entry wins over any real dtype: no tensor could hold the data.
    if any(dtype == invalid for dtype in dtypes):
        return invalid

    dummy_tensors = [torch.empty((0,), dtype=dtype) for dtype in dtypes]  # type: ignore
    dtype = torch.stack(dummy_tensors).dtype
    return dtype
def get_default_numpy_dtype() -> np.dtype:
    """Return the dtype of an array created by ``np.empty`` without an explicit dtype."""
    placeholder = np.empty((0,))
    return placeholder.dtype