torchwrench.nn.functional.predicate module

torchwrench.nn.functional.predicate.all_eq(x: Tensor | ndarray | bool | int | float | complex | None | str | bytes | generic | Tensor0D | Iterable, dim: None = None) bool
torchwrench.nn.functional.predicate.all_eq(x: T_TensorOrArray, dim: int) T_TensorOrArray

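Check if all elements are equal in a tensor, ndarray, iterable or scalar object. When an integer dim is given, the check is done along that dimension and the result is a tensor or array of the same type as x instead of a bool.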
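For plain Python iterables, the bool-returning form of the check can be sketched as below. This is a minimal pure-Python illustration, not torchwrench's implementation; the name `all_eq_sketch` is hypothetical, and treating str/bytes as scalars and an empty iterable as "all equal" are assumptions.

```python
from collections.abc import Iterable
from typing import Any


def all_eq_sketch(x: Any) -> bool:
    """Hypothetical pure-Python version of all_eq for scalars and flat iterables."""
    # Assumption: str and bytes are treated as scalars (they appear among the
    # documented scalar types), so they are not iterated character by character.
    if isinstance(x, (str, bytes)) or not isinstance(x, Iterable):
        return True  # a single scalar is trivially "all equal"
    it = iter(x)
    try:
        first = next(it)
    except StopIteration:
        return True  # assumption: an empty iterable counts as all-equal
    # Compare every remaining element against the first one.
    return all(item == first for item in it)


print(all_eq_sketch([3, 3, 3]))  # True
print(all_eq_sketch((1, 2)))     # False
print(all_eq_sketch(5))          # True
```

The dim: int overload, which returns a tensor or array rather than a bool, is not covered by this sketch.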

torchwrench.nn.functional.predicate.all_ne(x: Tensor | ndarray | bool | int | float | complex | None | str | bytes | generic | Tensor0D | Iterable) bool

Check if all elements are pairwise distinct (no two elements are equal) in a tensor, ndarray, iterable or scalar object.
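A minimal pure-Python sketch of this check on flat iterables follows. It assumes "NOT equal" means that no two elements compare equal; the name `all_ne_sketch` is hypothetical and this is not the library's code.

```python
from collections.abc import Iterable
from typing import Any


def all_ne_sketch(x: Any) -> bool:
    """Hypothetical pure-Python version of all_ne: True if no two elements are equal."""
    if isinstance(x, (str, bytes)) or not isinstance(x, Iterable):
        return True  # assumption: a single scalar has no duplicates
    # A list rather than a set, so unhashable items (e.g. nested lists) still work.
    seen = []
    for item in x:
        if any(item == other for other in seen):
            return False  # found a duplicate
        seen.append(item)
    return True


print(all_ne_sketch([1, 2, 3]))  # True
print(all_ne_sketch([1, 2, 1]))  # False
```

The quadratic scan keeps the sketch simple; for hashable elements, comparing len(set(items)) with len(items) is the usual faster alternative.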

torchwrench.nn.functional.predicate.is_complex(x: Tensor) TypeGuard[ComplexFloatingTensor]
torchwrench.nn.functional.predicate.is_complex(x: ndarray) TypeGuard[ndarray]
torchwrench.nn.functional.predicate.is_complex(x: complex) TypeGuard[complex]
torchwrench.nn.functional.predicate.is_complex(x: Any) TypeGuard[ComplexFloatingTensor | ndarray | complex]

Returns True if the object is complex-valued or contains complex values (e.g. a tensor or array with a complex dtype).

torchwrench.nn.functional.predicate.is_convertible_to_tensor(x: Any) bool

Returns True if the input can be passed to torch.as_tensor.

Note that this function returns False for heterogeneous inputs such as [[], 1], even though torch.as_tensor can accept this kind of value.
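To illustrate why an input like [[], 1] is rejected, here is a simplified sketch of a structural check: it accepts only nested lists/tuples of numbers whose nesting is rectangular. The helper `regular_shape` is hypothetical and much narrower than the real function (it ignores tensors, arrays and other input types).

```python
from numbers import Number
from typing import Any, Optional


def regular_shape(x: Any) -> Optional[tuple]:
    """Hypothetical helper: return the shape of a nested list/tuple of numbers,
    or None if the nesting is ragged or mixes sequences with scalars."""
    if isinstance(x, Number):  # bool, int, float and complex are all Numbers
        return ()
    if not isinstance(x, (list, tuple)):
        return None  # assumption: every other type is rejected in this sketch
    shapes = [regular_shape(item) for item in x]
    if any(s is None for s in shapes):
        return None
    if shapes and any(s != shapes[0] for s in shapes):
        return None  # heterogeneous: [[], 1] mixes shape (0,) with shape ()
    inner = shapes[0] if shapes else ()
    return (len(x),) + inner


print(regular_shape([[1, 2], [3, 4]]))  # (2, 2)
print(regular_shape([[], 1]))           # None
print(regular_shape([]))                # (0,)
```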

torchwrench.nn.functional.predicate.is_floating_point(x: Tensor) TypeGuard[FloatingTensor]
torchwrench.nn.functional.predicate.is_floating_point(x: ndarray) TypeGuard[ndarray]
torchwrench.nn.functional.predicate.is_floating_point(x: float) TypeGuard[float]
torchwrench.nn.functional.predicate.is_floating_point(x: Any) TypeGuard[FloatingTensor | ndarray | float]

Returns True if the object is floating-point or contains floating-point values (e.g. a tensor or array with a floating-point dtype).

torchwrench.nn.functional.predicate.is_full(x: Tensor | ndarray, target: Any = Ellipsis) bool

Check if all elements are equal to target in a tensor or array. Accepts an optional value ‘target’ to specify the expected value.
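A minimal flat-sequence sketch of this check follows. The name `is_full_sketch` is hypothetical, and the behaviour when target is omitted (compare against the first element) is an assumption, not documented library behaviour.

```python
from typing import Any, Sequence


def is_full_sketch(values: Sequence[Any], target: Any = ...) -> bool:
    """Hypothetical flat-sequence version of is_full."""
    if target is ...:
        # Assumption: with no target given, compare against the first element.
        if len(values) == 0:
            return True  # assumption: an empty sequence is trivially full
        target = values[0]
    return all(v == target for v in values)


print(is_full_sketch([0, 0, 0], target=0))  # True
print(is_full_sketch([0, 0, 1], target=0))  # False
print(is_full_sketch([7, 7]))               # True
```

Using Ellipsis as the default sentinel, rather than None, leaves None available as a legitimate target value.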

torchwrench.nn.functional.predicate.is_sorted(x: Tensor | ndarray | Iterable, *, reverse: bool = False, strict: bool = False) bool

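Returns True if the sequence is sorted: in ascending order by default, or in descending order if reverse=True. With strict=True, equal adjacent elements are not allowed.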
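A pure-Python sketch of the check on flat iterables, assuming reverse=True checks descending order and strict=True forbids equal neighbours. The name `is_sorted_sketch` is hypothetical; it compares each adjacent pair with the operator the two flags select.

```python
import operator
from typing import Any, Iterable


def is_sorted_sketch(x: Iterable[Any], *, reverse: bool = False, strict: bool = False) -> bool:
    """Hypothetical pure-Python version of is_sorted for flat iterables."""
    # Pick the comparison each adjacent pair (a, b) must satisfy.
    if reverse:
        ok = operator.gt if strict else operator.ge
    else:
        ok = operator.lt if strict else operator.le
    items = list(x)
    return all(ok(a, b) for a, b in zip(items, items[1:]))


print(is_sorted_sketch([1, 2, 2, 3]))               # True
print(is_sorted_sketch([1, 2, 2, 3], strict=True))  # False
print(is_sorted_sketch([3, 2, 1], reverse=True))    # True
```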

torchwrench.nn.functional.predicate.is_stackable(tensors: list[Any] | tuple[Any, ...]) TypeGuard[list[Tensor] | tuple[Tensor, ...]]

Returns True if the input can be passed to torch.stack, i.e. it is a non-empty list or tuple of tensors that all have the same shape.

torchwrench.nn.functional.predicate.is_unique(x: Tensor | ndarray | bool | int | float | complex | None | str | bytes | generic | Tensor0D | Iterable) bool

Check if all elements are pairwise distinct (no two elements are equal) in a tensor, ndarray, iterable or scalar object.