Source code for torchwrench.types.guards
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Any
import torch
from pythonwrench.typing.checks import ( # noqa: F401
is_builtin_collection,
is_builtin_number,
is_builtin_obj,
is_builtin_scalar,
is_dataclass_instance,
is_namedtuple_instance,
is_typed_dict,
isinstance_generic,
)
from torch import Tensor
from typing_extensions import TypeGuard, TypeIs
from torchwrench.core.make import DTypeLike, as_dtype
from torchwrench.extras.numpy import is_numpy_number_like, is_numpy_scalar_like, np
from ._typing import (
IntegralTensor,
NumberLike,
ScalarLike,
Tensor0D,
TensorOrArray,
)
[docs]
def is_number_like(x: Any) -> TypeGuard[NumberLike]:
"""Returns True if input is a scalar number.
Accepted numbers-like objects are:
- Python numbers (int, float, bool, complex)
- Numpy zero-dimensional arrays
- Numpy numbers
- PyTorch zero-dimensional tensors
"""
return is_builtin_number(x) or is_numpy_number_like(x) or isinstance(x, Tensor0D)
[docs]
def is_scalar_like(x: Any) -> TypeGuard[ScalarLike]:
"""Returns True if input is a scalar number.
Accepted scalar-like objects are:
- Python scalars like (int, float, bool, complex, None, str, bytes)
- Numpy zero-dimensional arrays
- Numpy generic
- PyTorch zero-dimensional tensors
"""
return is_builtin_scalar(x) or is_numpy_scalar_like(x) or isinstance(x, Tensor0D)
[docs]
def is_tensor_or_array(x: Any) -> TypeIs[TensorOrArray]:
return isinstance(x, (Tensor, np.ndarray))
[docs]
def is_integral_dtype(dtype: DTypeLike) -> bool:
dtype = as_dtype(dtype)
return isinstance(torch.empty((0,), dtype=dtype), IntegralTensor)