Source code for torchwrench.nn.modules.numpy

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Union

from torch import Tensor

from torchwrench.core.make import DeviceLike, DTypeLike
from torchwrench.core.packaging import _NUMPY_AVAILABLE
from torchwrench.extras.numpy.definitions import np
from torchwrench.extras.numpy.functional import (
    ndarray_to_tensor,
    tensor_to_ndarray,
    to_ndarray,
)
from torchwrench.nn.modules.module import Module


[docs] class ToNDArray(Module): """ For more information, see :func:`~torchwrench.nn.functional.numpy.to_ndarray`. """ def __init__( self, *, dtype: Union[str, np.dtype, None] = None, force: bool = False, ) -> None: if not _NUMPY_AVAILABLE: msg = f"Cannot use {self.__class__.__name__} because numpy dependancy is not installed." raise RuntimeError(msg) super().__init__() self.dtype = dtype self.force = force
[docs] def forward(self, x: Union[Tensor, np.ndarray, list]) -> np.ndarray: return to_ndarray(x, dtype=self.dtype, force=self.force)
[docs] class TensorToNDArray(Module): """ For more information, see :func:`~torchwrench.nn.functional.numpy.tensor_to_ndarray`. """ def __init__( self, *, dtype: Union[str, np.dtype, None] = None, force: bool = False, ) -> None: if not _NUMPY_AVAILABLE: msg = f"Cannot use {self.__class__.__name__} because numpy dependancy is not installed." raise RuntimeError(msg) super().__init__() self.dtype = dtype self.force = force
[docs] def forward(self, x: Tensor) -> np.ndarray: return tensor_to_ndarray(x, dtype=self.dtype, force=self.force)
[docs] class NDArrayToTensor(Module): """ For more information, see :func:`~torchwrench.nn.functional.numpy.ndarray_to_tensor`. """ def __init__( self, *, device: DeviceLike = None, dtype: DTypeLike = None, ) -> None: if not _NUMPY_AVAILABLE: msg = f"Cannot use {self.__class__.__name__} because numpy dependancy is not installed." raise RuntimeError(msg) super().__init__() self.device = device self.dtype = dtype
[docs] def forward(self, x: np.ndarray) -> Tensor: return ndarray_to_tensor(x, dtype=self.dtype, device=self.device)