Source code for torchwrench.nn.modules.activation
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from typing import Iterable, Union
from pythonwrench.collections import dump_dict
from torch import Tensor
from torchwrench.nn.functional.activation import log_softmax_multidim, softmax_multidim
from .module import Module
[docs]
class SoftmaxMultidim(Module):
    """Module wrapper around ``softmax_multidim``.

    For more information, see :func:`~torchwrench.nn.functional.activation.softmax_multidim`.
    """

    def __init__(
        self,
        dims: Union[Iterable[int], None] = (-1,),
    ) -> None:
        """Store the dimensions forwarded to ``softmax_multidim``.

        Args:
            dims: Value passed as the ``dims`` keyword argument on every
                forward pass. Any iterable of ints is stored as a tuple;
                ``None`` is kept and forwarded unchanged. Defaults to ``(-1,)``.
        """
        super().__init__()
        # Materialise the iterable once: a one-shot iterator (generator,
        # ``iter(...)``, ``map``) would be exhausted after its first use, so
        # later ``forward`` calls would silently see no dimensions at all.
        self.dims = None if dims is None else tuple(dims)

    def forward(
        self,
        input: Tensor,
    ) -> Tensor:
        """Return ``softmax_multidim(input, dims=self.dims)``."""
        return softmax_multidim(input, dims=self.dims)
[docs]
class LogSoftmaxMultidim(Module):
    """Module wrapper around ``log_softmax_multidim``.

    For more information, see :func:`~torchwrench.nn.functional.activation.log_softmax_multidim`.
    """

    def __init__(
        self,
        dims: Union[Iterable[int], None] = (-1,),
    ) -> None:
        """Store the dimensions forwarded to ``log_softmax_multidim``.

        Args:
            dims: Value passed as the ``dims`` keyword argument on every
                forward pass. Any iterable of ints is stored as a tuple;
                ``None`` is kept and forwarded unchanged. Defaults to ``(-1,)``.
        """
        super().__init__()
        # Materialise the iterable once: a one-shot iterator (generator,
        # ``iter(...)``, ``map``) would be exhausted after its first use, so
        # later ``forward`` calls would silently see no dimensions at all.
        self.dims = None if dims is None else tuple(dims)

    def forward(
        self,
        input: Tensor,
    ) -> Tensor:
        """Return ``log_softmax_multidim(input, dims=self.dims)``."""
        return log_softmax_multidim(input, dims=self.dims)