Source code for torchwrench.nn.modules._mixins

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import copy
import logging
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    Generic,
    Iterator,
    List,
    Literal,
    Mapping,
    MutableMapping,
    Optional,
    OrderedDict,
    Protocol,
    Tuple,
    TypeVar,
    Union,
    overload,
    runtime_checkable,
)

import torch
from pythonwrench.collections import dump_dict
from pythonwrench.re import match_patterns
from pythonwrench.typing import NoneType, isinstance_generic
from torch import Tensor, nn
from torch.nn.parameter import Parameter
from typing_extensions import TypeAlias

from torchwrench.nn.functional.checksum import checksum_module
from torchwrench.nn.functional.others import count_parameters

T = TypeVar("T", covariant=True)
InType = TypeVar("InType", covariant=False, contravariant=True)
OutType = TypeVar("OutType", covariant=True, contravariant=False)
OutType2 = TypeVar("OutType2", covariant=True, contravariant=False)
OutType3 = TypeVar("OutType3", covariant=False, contravariant=False)
T_MutableMappingStr = TypeVar("T_MutableMappingStr", bound=MutableMapping[str, Any])

DeviceDetectMode = Literal["proxy", "first_param", "none"]
DEVICE_DETECT_MODES = ("proxy", "first_param", "none")
_DEFAULT_DEVICE_DETECT_MODE = "first_param"


pylog = logging.getLogger(__name__)


@runtime_checkable
class SupportsTypedForward(Protocol[InType, OutType]):
    def __call__(self, *args, **kwargs): ...

    def forward(self, x: InType, /) -> OutType: ...


TypedModuleLike: TypeAlias = Union[
    SupportsTypedForward[InType, OutType],
    "TypedModule[InType, OutType]",
]


class ProxyDeviceModule(nn.Module):
    def __init__(
        self,
        *,
        device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE,
    ) -> None:
        if device_detect_mode not in DEVICE_DETECT_MODES:
            msg = f"Invalid argument {device_detect_mode=}. (expected one of {DEVICE_DETECT_MODES})"
            raise ValueError(msg)

        super().__init__()
        self.__device_detect_mode: DeviceDetectMode = device_detect_mode
        if device_detect_mode == "proxy":
            self.register_buffer("__proxy", torch.empty((0,)), persistent=False)

    @property
    def device_detect_mode(self) -> DeviceDetectMode:
        return self.__device_detect_mode

    def get_device(self) -> Optional[torch.device]:
        """Returns the Module device according to device_detect_mode property."""
        if self.__device_detect_mode == "proxy":
            return self._buffers["__proxy"].device  # type: ignore
        elif self.__device_detect_mode == "first_param":
            try:
                device0 = next(self._get_devices_iterator(params=True, buffers=True))
                return device0
            except StopIteration:
                return None
        else:
            return None

    def get_devices(
        self,
        *,
        params: bool = True,
        buffers: bool = True,
        recurse: bool = True,
        output_type: Callable[[Iterator[torch.device]], T] = list,
    ) -> T:
        return output_type(
            self._get_devices_iterator(
                params=params,
                buffers=buffers,
                recurse=recurse,
            )
        )

    def _get_devices_iterator(
        self,
        *,
        params: bool = True,
        buffers: bool = True,
        recurse: bool = True,
    ) -> Iterator[torch.device]:
        """Returns an iterator over all unique devices in module."""
        its: List[Iterator[Union[Tensor, Parameter]]] = []
        if params:
            its.append(self.parameters(recurse=recurse))
        if buffers:
            its.append(self.buffers(recurse=recurse))

        devices = {}
        for it in its:
            for param_or_buffer in it:
                device = param_or_buffer.device
                if device not in devices:
                    yield device
                devices[param_or_buffer.device] = None


class ConfigModule(Generic[T_MutableMappingStr], nn.Module):
    _CONFIG_TYPES: ClassVar[Tuple[type, ...]] = (int, str, bool, float, NoneType)
    _CONFIG_EXCLUDE: ClassVar[Tuple[str, ...]] = ("^_.*",) + tuple(
        f".*{k}$" for k in nn.Module().__dict__.keys()
    )
    _DEFAULT_EXTRA_REPR: ClassVar[bool] = False

    def __init__(
        self,
        *,
        strict_load: bool = False,
        config_to_extra_repr: Optional[bool] = None,
        config: Optional[T_MutableMappingStr] = None,
    ) -> None:
        if config is None:
            config = {}  # type: ignore
        if config_to_extra_repr is None:
            config_to_extra_repr = self._DEFAULT_EXTRA_REPR

        attrs = {
            "config": config,
            "strict_load": strict_load,
            "config_to_extra_repr": config_to_extra_repr,
        }
        for name, value in attrs.items():
            object.__setattr__(self, f"_{ConfigModule.__name__}__{name}", value)

        super().__init__()
        self.__config: T_MutableMappingStr
        self.__strict_load: bool
        self.__config_to_extra_repr: bool

    @property
    def config(self) -> T_MutableMappingStr:
        return self.__config

    def __setattr__(self, name: str, value: Any) -> None:
        self.__update_config(name, value)
        return super().__setattr__(name, value)

    def __delattr__(self, name) -> None:
        self.__config.pop(name, None)
        return super().__delattr__(name)

    def extra_repr(self) -> str:
        if not self.__config_to_extra_repr:
            return super().extra_repr()
        else:
            return dump_dict(self.config)

    def add_module(self, name: str, module: Union[nn.Module, None]) -> None:
        self.__update_config(name, module)
        return super().add_module(name, module)

    def get_extra_state(self) -> Any:
        state = {"config": self.__config}
        return state

    def set_extra_state(self, state: Any) -> Any:
        # return type is Any because parent class typed NoReturn in some versions, which is incompatible with None and the usage of returns code.
        if not self.__strict_load:
            return None

        in_config = state["config"]
        if self.config == in_config:
            return None

        if isinstance_generic(in_config, Dict[str, Any]) and isinstance_generic(
            self.config, Dict[str, Any]
        ):
            MISSING = "<missing>"
            union = set(in_config.keys()).union(self.config.keys())
            msgs = []
            for key in union:
                v1 = in_config.get(key, MISSING)
                v2 = in_config.get(key, MISSING)
                if v1 != v2:
                    msgs.append(f"{v1} != {v2}")
            msg = (
                "Invalid loaded config with current one. Invalid keys are:\n"
                + "\n\t".join(msgs)
            )
        else:
            msg = f"Invalid loaded config {in_config} with current one {self.config}."

        raise ValueError(msg)

    def __update_config(self, name: str, value: Any) -> None:
        subconfig = self.__class__._detect_subconfig(name, value)
        self.__config.update(subconfig)

    @classmethod
    def _detect_subconfig(cls, name: str, value: Any) -> Dict[str, Any]:
        prefix = f"{name}."
        if cls._is_config_name_value(name, value):
            subconfig = {name: value}
            prefix = ""
        elif isinstance(value, ConfigModule):
            subconfig = value.config
        elif hasattr(value, "_hparams") and isinstance_generic(
            value._hparams, Mapping[str, Any]
        ):
            subconfig = dict(value._hparams.items())  # type: ignore
        elif isinstance(value, torch.nn.Module):
            subconfig = cls._detect_torch_module_subconfig(value)
        elif hasattr(value, "__dict__"):
            subconfig = value.__dict__
        else:
            subconfig = {}

        subconfig = {f"{prefix}{k}": v for k, v in subconfig.items()}
        subconfig = {
            k: v for k, v in subconfig.items() if cls._is_config_name_value(k, v)
        }
        return subconfig

    @classmethod
    def _detect_torch_module_subconfig(cls, value: torch.nn.Module) -> Dict[str, Any]:
        subconfig = {
            k: v
            for k, v in value.__dict__.items()
            if k != "_modules" and match_patterns(k, exclude=cls._CONFIG_EXCLUDE)
        }
        subconfig.update(
            {
                kv: vv
                for k, v in value.__dict__["_modules"].items()
                for kv, vv in cls._detect_subconfig(k, v).items()
            }
        )
        return subconfig

    @classmethod
    def _is_config_name_value(cls, name: str, value: Any) -> bool:
        if not match_patterns(name, exclude=cls._CONFIG_EXCLUDE):
            return False
        else:
            return cls._is_config_value(value)

    @classmethod
    def _is_config_value(cls, value) -> bool:
        if isinstance(value, cls._CONFIG_TYPES):
            return True
        elif isinstance(value, (list, tuple, set, frozenset)):
            return all(cls._is_config_value(vi) for vi in value)
        elif isinstance(value, dict):
            return all(
                cls._is_config_value(k) and cls._is_config_value(v)
                for k, v in value.items()
            )
        else:
            return False


class TypedModule(Generic[InType, OutType], nn.Module):
    """Typed version of torch.nn.Module. Can specify an input and output type."""

    def __call__(self, *args: InType, **kwargs: InType) -> OutType:
        return super().__call__(*args, **kwargs)


class TypedSequential(
    Generic[InType, OutType],
    TypedModule[InType, OutType],
    nn.Sequential,
):
    def __init__(
        self,
        *args,
        unpack_tuple: bool = False,
        unpack_dict: bool = False,
    ) -> None:
        TypedModule.__init__(self)
        nn.Sequential.__init__(self, *args)

        self.__unpack_tuple = unpack_tuple
        self.__unpack_dict = unpack_dict

    @property
    def unpack_tuple(self) -> bool:
        return self.__unpack_tuple

    @property
    def unpack_dict(self) -> bool:
        return self.__unpack_dict

    def __call__(self, x: InType) -> OutType:  # type: ignore
        return nn.Sequential.__call__(self, x)

    def forward(self, x: InType) -> OutType:  # type: ignore
        for module in self:
            if self.__unpack_tuple and isinstance(x, tuple):
                x = module(*x)
            elif self.__unpack_dict and isinstance_generic(x, Dict[str, Any]):
                x = module(**x)
            else:
                x = module(x)
        return x  # type: ignore

    def tolist(self) -> List[nn.Module]:
        return list(self._modules.values())

    def todict(self) -> Dict[str, nn.Module]:
        return copy.copy(self._modules)


[docs] class EModule( Generic[InType, OutType], ConfigModule, TypedModule[InType, OutType], ProxyDeviceModule, ): """Enriched torch.nn.Module with proxy device, forward typing and automatic configuration detection from attributes. The default behaviour is the same than PyTorch Module class. """ def __init__( self, *, strict_load: bool = False, config_to_extra_repr: bool = False, device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE, ) -> None: """ Args: strict_load: If True, Module config will be compared during load_state_dict(...) method call and raises a ValueError. defaults to False. config_to_extra_repr: If True, add config to extra repr. defaults to False. device_detect_mode: Enable automatic detection of the module device. defaults to "first_param". """ # ConfigModule must be first ConfigModule.__init__( self, strict_load=strict_load, config_to_extra_repr=config_to_extra_repr, ) TypedModule.__init__(self) ProxyDeviceModule.__init__( self, device_detect_mode=device_detect_mode, )
[docs] def count_parameters( self, *, recurse: bool = True, only_trainable: bool = False, buffers: bool = False, ) -> int: """Returns the number of parameters in this module.""" return count_parameters( self, recurse=recurse, only_trainable=only_trainable, buffers=buffers, )
[docs] def checksum( self, *, only_trainable: bool = False, with_names: bool = False, buffers: bool = False, training: bool = False, ) -> int: return checksum_module( self, only_trainable=only_trainable, with_names=with_names, buffers=buffers, training=training, )
@overload def chain( self, *others: TypedModuleLike[Any, OutType], ) -> "ESequential[InType, OutType]": ... @overload def chain(self, *others: nn.Module) -> "ESequential[InType, Any]": ...
[docs] def chain(self, *others): return ESequential(self, *others)
def __or__( self, other: TypedModuleLike[Any, OutType], ) -> "ESequential[InType, OutType]": return self.chain(other) def __ror__( self, other: TypedModuleLike[InType, Any], ) -> "ESequential[InType, OutType]": return ESequential(other, self)
[docs] class ESequential( Generic[InType, OutType], EModule[InType, OutType], TypedSequential[InType, OutType], ): """Enriched torch.nn.Sequential with proxy device, forward typing and automatic configuration detection from attributes. Designed to work with `torchwrench.nn.EModule` instances. The default behaviour is the same than PyTorch Sequential class. """ @overload def __init__( self, /, *, unpack_tuple: bool = False, unpack_dict: bool = False, strict_load: bool = False, config_to_extra_repr: bool = False, device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE, ) -> None: ... @overload def __init__( self, arg0: TypedModuleLike[InType, OutType], /, *, unpack_tuple: bool = False, strict_load: bool = False, config_to_extra_repr: bool = False, device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE, unpack_dict: bool = False, ) -> None: ... @overload def __init__( self, arg0: TypedModuleLike[InType, Any], arg1: TypedModuleLike[Any, OutType], /, *, unpack_tuple: bool = False, unpack_dict: bool = False, strict_load: bool = False, config_to_extra_repr: bool = False, device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE, ) -> None: ... @overload def __init__( self, arg0: TypedModuleLike[InType, Any], arg1: TypedModuleLike[Any, Any], arg2: TypedModuleLike[Any, OutType], /, *, unpack_tuple: bool = False, unpack_dict: bool = False, strict_load: bool = False, config_to_extra_repr: bool = False, device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE, ) -> None: ... @overload def __init__( self, arg0: TypedModuleLike[InType, Any], arg1: TypedModuleLike[Any, Any], arg2: TypedModuleLike[Any, Any], arg3: TypedModuleLike[Any, OutType], /, *, unpack_tuple: bool = False, unpack_dict: bool = False, strict_load: bool = False, config_to_extra_repr: bool = False, device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE, ) -> None: ... @overload def __init__( self, arg0: TypedModuleLike[InType, Any], arg1: TypedModuleLike[Any, Any], arg2: TypedModuleLike[Any, Any], arg3: TypedModuleLike[Any, Any], arg4: TypedModuleLike[Any, OutType], /, *, unpack_tuple: bool = False, unpack_dict: bool = False, strict_load: bool = False, config_to_extra_repr: bool = False, device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE, ) -> None: ... @overload def __init__( self, arg0: TypedModuleLike[InType, Any], arg1: TypedModuleLike[Any, Any], arg2: TypedModuleLike[Any, Any], arg3: TypedModuleLike[Any, Any], arg4: TypedModuleLike[Any, Any], arg5: TypedModuleLike[Any, OutType], /, *, unpack_tuple: bool = False, unpack_dict: bool = False, strict_load: bool = False, config_to_extra_repr: bool = False, device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE, ) -> None: ... @overload def __init__( self, arg0: TypedModuleLike[InType, Any], arg1: TypedModuleLike[Any, Any], arg2: TypedModuleLike[Any, Any], arg3: TypedModuleLike[Any, Any], arg4: TypedModuleLike[Any, Any], arg5: TypedModuleLike[Any, Any], arg6: TypedModuleLike[Any, OutType], /, *, unpack_tuple: bool = False, unpack_dict: bool = False, strict_load: bool = False, config_to_extra_repr: bool = False, device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE, ) -> None: ... @overload def __init__( self, arg0: TypedModuleLike[InType, Any], arg1: TypedModuleLike[Any, Any], arg2: TypedModuleLike[Any, Any], arg3: TypedModuleLike[Any, Any], arg4: TypedModuleLike[Any, Any], arg5: TypedModuleLike[Any, Any], arg6: TypedModuleLike[Any, Any], arg7: TypedModuleLike[Any, OutType], /, *, unpack_tuple: bool = False, unpack_dict: bool = False, strict_load: bool = False, config_to_extra_repr: bool = False, device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE, ) -> None: ... @overload def __init__( self, arg0: TypedModuleLike[InType, Any], arg1: TypedModuleLike[Any, Any], arg2: TypedModuleLike[Any, Any], arg3: TypedModuleLike[Any, Any], arg4: TypedModuleLike[Any, Any], arg5: TypedModuleLike[Any, Any], arg6: TypedModuleLike[Any, Any], arg7: TypedModuleLike[Any, Any], arg8: TypedModuleLike[Any, OutType], /, *, unpack_tuple: bool = False, unpack_dict: bool = False, strict_load: bool = False, config_to_extra_repr: bool = False, device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE, ) -> None: ... @overload def __init__( self, arg0: TypedModuleLike[InType, Any], arg1: TypedModuleLike[Any, Any], arg2: TypedModuleLike[Any, Any], arg3: TypedModuleLike[Any, Any], arg4: TypedModuleLike[Any, Any], arg5: TypedModuleLike[Any, Any], arg6: TypedModuleLike[Any, Any], arg7: TypedModuleLike[Any, Any], arg8: TypedModuleLike[Any, Any], arg9: TypedModuleLike[Any, OutType], /, *, unpack_tuple: bool = False, unpack_dict: bool = False, strict_load: bool = False, config_to_extra_repr: bool = False, device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE, ) -> None: ... @overload def __init__( self, arg: "OrderedDict[str, TypedModuleLike[InType, OutType]]", /, *, unpack_tuple: bool = False, unpack_dict: bool = False, strict_load: bool = False, config_to_extra_repr: bool = False, device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE, ) -> None: ... @overload def __init__( self, arg: "OrderedDict[str, nn.Module]", /, *, unpack_tuple: bool = False, unpack_dict: bool = False, strict_load: bool = False, config_to_extra_repr: bool = False, device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE, ) -> None: ... @overload def __init__( self, *args: nn.Module, unpack_tuple: bool = False, unpack_dict: bool = False, strict_load: bool = False, config_to_extra_repr: bool = False, device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE, ) -> None: ... def __init__( self, *args, unpack_tuple: bool = False, unpack_dict: bool = False, strict_load: bool = False, config_to_extra_repr: bool = False, device_detect_mode: DeviceDetectMode = _DEFAULT_DEVICE_DETECT_MODE, ) -> None: """ Args: unpack_tuple: If True, the outputs of a module that returns a tuple at position i will be unpacked for positional arguments for the next module at position i+1. defaults to False. unpack_tuple: If True, the outputs of a module that returns a dict at position i will be unpacked for keywords arguments for the next module at position i+1. defaults to False. strict_load: If True, Module config will be compared during load_state_dict(...) method call and raises a ValueError. defaults to False. config_to_extra_repr: If True, add config to extra repr. defaults to False. device_detect_mode: Enable automatic detection of the module device. defaults to "first_param". """ EModule.__init__( self, strict_load=strict_load, config_to_extra_repr=config_to_extra_repr, device_detect_mode=device_detect_mode, ) TypedSequential.__init__( self, *args, unpack_tuple=unpack_tuple, unpack_dict=unpack_dict, )