Source code for torchwrench.nn.modules.padding

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from typing import Iterable, Literal, Union

from pythonwrench.collections import dump_dict
from torch import Tensor
from torch.types import Number

from torchwrench.core.make import DeviceLike, DTypeLike, GeneratorLike
from torchwrench.nn.functional.padding import (
    PadAlign,
    PadMode,
    PadValue,
    pad_and_stack_rec,
    pad_dim,
    pad_dims,
)

from .module import Module


class PadDim(Module):
    """Module form of :func:`~torchwrench.nn.functional.padding.pad_dim`.

    Every constructor argument is stored as an attribute of the same name
    and forwarded unchanged, by keyword, to ``pad_dim`` each time the module
    is called. Refer to ``pad_dim`` for the meaning of each option.

    Args:
        target_length: Forwarded as ``target_length``.
        dim: Forwarded as ``dim``.
        align: Forwarded as ``align``.
        pad_value: Forwarded as ``pad_value``.
        mode: Forwarded as ``mode``.
        generator: Forwarded as ``generator``.
    """

    def __init__(
        self,
        target_length: int,
        *,
        dim: int = -1,
        align: PadAlign = "left",
        pad_value: PadValue = 0.0,
        mode: PadMode = "constant",
        generator: GeneratorLike = None,
    ) -> None:
        super().__init__()

        # Size / axis options.
        self.target_length = target_length
        self.dim = dim

        # Fill options.
        self.align: PadAlign = align
        self.pad_value = pad_value
        self.mode: PadMode = mode
        self.generator: GeneratorLike = generator

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``pad_dim`` to ``x`` with the options stored on this module."""
        padded = pad_dim(
            x,
            target_length=self.target_length,
            dim=self.dim,
            align=self.align,
            pad_value=self.pad_value,
            mode=self.mode,
            generator=self.generator,
        )
        return padded

    def extra_repr(self) -> str:
        """Return the stored options formatted by ``dump_dict``.

        The ``generator`` attribute is not part of the summary.
        """
        shown = {
            "target_length": self.target_length,
            "dim": self.dim,
            "align": self.align,
            "pad_value": self.pad_value,
            "mode": self.mode,
        }
        return dump_dict(shown)
class PadDims(Module):
    """Module form of :func:`~torchwrench.nn.functional.padding.pad_dims`.

    Every constructor argument is stored as an attribute of the same name
    and forwarded, by keyword, to ``pad_dims`` each time the module is
    called. Refer to ``pad_dims`` for the meaning of each option.

    ``target_lengths``, ``dims`` and ``aligns`` may be given as one-shot
    iterators (generators, ``map`` objects, ...). Those are copied into a
    list at construction so that the module can be called more than once;
    any other value is stored as the very object that was passed in.

    Args:
        target_lengths: Forwarded as ``target_lengths``.
        dims: Forwarded as ``dims``.
        aligns: Forwarded as ``aligns``.
        pad_value: Forwarded as ``pad_value``.
        mode: Forwarded as ``mode``.
        generator: Forwarded as ``generator``.
    """

    def __init__(
        self,
        target_lengths: Iterable[int],
        *,
        dims: Union[Iterable[int], None, Literal["auto"]] = None,
        aligns: Union[PadAlign, Iterable[PadAlign]] = "left",
        pad_value: PadValue = 0.0,
        mode: PadMode = "constant",
        generator: GeneratorLike = None,
    ) -> None:
        super().__init__()
        # These three are re-read on every forward() call and by
        # extra_repr(); an iterator would be empty after its first use.
        self.target_lengths = self._as_reusable(target_lengths)
        self.dims: Union[Iterable[int], None, Literal["auto"]] = self._as_reusable(
            dims
        )
        self.aligns: Union[PadAlign, Iterable[PadAlign]] = self._as_reusable(aligns)
        self.pad_value = pad_value
        self.mode: PadMode = mode
        self.generator: GeneratorLike = generator

    @staticmethod
    def _as_reusable(value):
        """Return ``value``, or a list of its items if it is a one-shot iterator."""
        # Local import: keeps this block independent of the file-level imports.
        from collections.abc import Iterator

        # Only true iterators are consumed by iteration. Lists, tuples,
        # strings, None, tensors, ranges... are not Iterator instances and
        # are returned as the same object, so existing callers see no change.
        if isinstance(value, Iterator):
            return list(value)
        return value

    def forward(
        self,
        x: Tensor,
    ) -> Tensor:
        """Apply ``pad_dims`` to ``x`` with the options stored on this module."""
        return pad_dims(
            x,
            target_lengths=self.target_lengths,
            dims=self.dims,
            aligns=self.aligns,
            pad_value=self.pad_value,
            mode=self.mode,
            generator=self.generator,
        )

    def extra_repr(self) -> str:
        """Return the stored options formatted by ``dump_dict``.

        The ``generator`` attribute is not part of the summary.
        """
        return dump_dict(
            dict(
                target_lengths=self.target_lengths,
                dims=self.dims,
                aligns=self.aligns,
                pad_value=self.pad_value,
                mode=self.mode,
            )
        )
class PadAndStackRec(Module):
    """Module form of :func:`~torchwrench.nn.functional.padding.pad_and_stack_rec`.

    Every constructor argument is stored as an attribute of the same name
    and handed unchanged to ``pad_and_stack_rec`` each time the module is
    called. Refer to ``pad_and_stack_rec`` for the meaning of each option.

    Args:
        pad_value: Passed as the second positional argument.
        align: Forwarded as ``align``.
        device: Forwarded as ``device``.
        dtype: Forwarded as ``dtype``.
    """

    def __init__(
        self,
        pad_value: Number = 0,
        *,
        align: PadAlign = "left",
        device: DeviceLike = None,
        dtype: DTypeLike = None,
    ) -> None:
        super().__init__()

        # Padding options.
        self.pad_value = pad_value
        self.align: PadAlign = align

        # Options describing the produced tensor.
        self.device = device
        self.dtype = dtype

    def forward(self, sequence: Union[Tensor, int, float, tuple, list]) -> Tensor:
        """Apply ``pad_and_stack_rec`` to ``sequence`` with the stored options."""
        stacked = pad_and_stack_rec(
            sequence,
            self.pad_value,
            align=self.align,
            device=self.device,
            dtype=self.dtype,
        )
        return stacked

    def extra_repr(self) -> str:
        """Return the stored options formatted by ``dump_dict``."""
        shown = {
            "pad_value": self.pad_value,
            "align": self.align,
            "device": self.device,
            "dtype": self.dtype,
        }
        return dump_dict(shown)