#!/usr/bin/env python
# -*- coding: utf-8 -*-
import itertools
import math
import struct
from typing import Callable, Union
import torch
from pythonwrench.checksum import (
_cached_checksum_str,
_checksum_iterable,
checksum_any, # noqa: F401
checksum_bytes,
checksum_dict,
checksum_float,
checksum_list_tuple,
checksum_str,
register_checksum_fn,
)
from pythonwrench.inspect import get_fullname
from torch import Tensor, nn
from torchwrench.core.packaging import _NUMPY_AVAILABLE, _PANDAS_AVAILABLE
from torchwrench.extras.numpy import np
from torchwrench.extras.pandas import pd
from torchwrench.nn.functional.predicate import is_complex, is_floating_point
[docs]
@register_checksum_fn(pd.DataFrame)
def checksum_dataframe(x: pd.DataFrame, **kwargs) -> int:
    """Compute a checksum of a pandas DataFrame.

    The frame is converted with ``x.to_dict()`` (column -> {index -> value}) and that
    dict is checksummed, after the checksum of ``get_fullname(x)`` has been added to
    the ``accumulator`` keyword (which defaults to 0).

    Raises:
        NotImplementedError: If the optional dependency pandas is not installed.
    """
    if _PANDAS_AVAILABLE:
        previous_acc = kwargs.get("accumulator", 0)
        kwargs["accumulator"] = previous_acc + _cached_checksum_str(get_fullname(x))
        as_dict = x.to_dict()
        return checksum_dict(as_dict, **kwargs)  # type: ignore

    msg = "Cannot call function 'checksum_dataframe' because optional dependency 'pandas' is not installed. Please install it using 'pip install torchwrench[extras]'"
    raise NotImplementedError(msg)
[docs]
@register_checksum_fn(pd.Series)
def checksum_series(x: pd.Series, **kwargs) -> int:
    """Compute a checksum of a pandas Series.

    Only the values returned by ``x.tolist()`` are checksummed, in order; the index is
    not part of the checksum. The checksum of ``get_fullname(x)`` is added to the
    ``accumulator`` keyword (default 0) beforehand.

    Raises:
        NotImplementedError: If the optional dependency pandas is not installed.
    """
    if _PANDAS_AVAILABLE:
        previous_acc = kwargs.get("accumulator", 0)
        kwargs["accumulator"] = previous_acc + _cached_checksum_str(get_fullname(x))
        values = x.tolist()
        return checksum_list_tuple(values, **kwargs)  # type: ignore

    msg = "Cannot call function 'checksum_series' because optional dependency 'pandas' is not installed. Please install it using 'pip install torchwrench[extras]'"
    raise NotImplementedError(msg)
[docs]
@register_checksum_fn((torch.dtype, np.dtype))
def checksum_dtype(x: Union[torch.dtype, np.dtype], **kwargs) -> int:
    """Compute a checksum of a torch or numpy dtype.

    The dtype is checksummed through its string form ``str(x)``, after the checksum of
    ``get_fullname(x)`` has been added to the ``accumulator`` keyword (default 0).
    """
    type_csum = _cached_checksum_str(get_fullname(x))
    kwargs["accumulator"] = kwargs.get("accumulator", 0) + type_csum
    dtype_name = str(x)
    return checksum_str(dtype_name, **kwargs)
[docs]
@register_checksum_fn(nn.Module)
def checksum_module(
    x: nn.Module,
    *,
    only_trainable: bool = False,
    with_names: bool = False,
    buffers: bool = False,
    training: bool = False,
    **kwargs,
) -> int:
    """Compute a simple checksum over module parameters (and optionally buffers).

    Args:
        x: Module to checksum.
        only_trainable: If True, skip parameters whose ``requires_grad`` is False.
            Buffers are never filtered by this flag.
        with_names: If True, checksum ``(name, tensor)`` pairs instead of bare tensors.
        buffers: If True, also checksum the module buffers (after the parameters).
        training: Training mode ``x`` is put in (via ``x.train``) while the checksum is
            computed. Every submodule's previous ``training`` flag is restored afterwards,
            even if checksumming raises.
        **kwargs: Forwarded to the iterable checksum function.

    Returns:
        Integer checksum.
    """
    # Save each submodule's own flag: ``nn.Module.train`` is recursive, so restoring with
    # a single ``x.train(old_mode)`` would overwrite submodules whose mode differed from x.
    prev_modes = [(module, module.training) for module in x.modules()]
    x.train(training)
    try:
        if with_names:
            params_it = (
                (name, param)
                for name, param in x.named_parameters()
                if not only_trainable or param.requires_grad
            )
            buffers_it = x.named_buffers()
        else:
            params_it = (
                param
                for param in x.parameters()
                if not only_trainable or param.requires_grad
            )
            buffers_it = x.buffers()

        iterator = itertools.chain(params_it, buffers_it) if buffers else params_it
        # The generators are lazy: they are consumed here, inside the try block.
        return _checksum_iterable(iterator, **kwargs)
    finally:
        for module, mode in prev_modes:
            module.training = mode
# Intermediate functions
[docs]
@register_checksum_fn(Tensor)
# inference_mode is applied *below* register_checksum_fn so that the function handed to
# the registry is the inference-mode wrapper. With the previous order, the registry held
# the unwrapped function and only direct calls to ``checksum_tensor`` ran in inference mode.
@torch.inference_mode()
def checksum_tensor(x: Tensor, **kwargs) -> int:
    """Compute a simple checksum of a tensor. Order of values matters for the checksum.

    Delegates to ``_checksum_tensor_array_like`` with ``torch.nan_to_num`` as the
    function used to replace NaN and infinite values.

    Args:
        x: Tensor to checksum.
        **kwargs: Forwarded to the underlying checksum functions (e.g. ``accumulator``).

    Returns:
        Integer checksum.
    """
    return _checksum_tensor_array_like(
        x,
        nan_to_num_fn=torch.nan_to_num,
        **kwargs,
    )
[docs]
@torch.inference_mode()
@register_checksum_fn((np.ndarray, np.generic))
def checksum_numpy(x: Union[np.ndarray, np.generic], **kwargs) -> int:
    """Compute a simple checksum of a numpy array or numpy scalar.

    The order of the values matters. Delegates to ``_checksum_tensor_array_like``
    with ``np.nan_to_num`` as the NaN/infinity replacement function.
    """
    to_finite = np.nan_to_num
    return _checksum_tensor_array_like(x, nan_to_num_fn=to_finite, **kwargs)
# Private functions
def _checksum_tensor_array_like(
    x: Union[Tensor, np.ndarray, np.generic],
    *,
    nan_to_num_fn: Callable,
    **kwargs,
) -> int:
    """Checksum a tensor or numpy array/scalar from its dtype, shape, type name and bytes.

    Args:
        x: Tensor, numpy ndarray or numpy scalar.
        nan_to_num_fn: Function called on floating-point/complex inputs with ``nan``,
            ``neginf`` and ``posinf`` keyword arguments to replace non-finite values
            (``torch.nan_to_num`` or ``np.nan_to_num``).
        **kwargs: Forwarded to the checksum functions. ``accumulator`` (default 0) is
            incremented with the dtype, shape and ``get_fullname`` checksums before the
            raw bytes are checksummed.

    Returns:
        Integer checksum. Order of values matters.

    Raises:
        TypeError: If ``x`` is neither a numpy array/scalar nor a Tensor.
    """
    # Validate up front: previously this error could only be raised after ``x.dtype`` /
    # ``x.shape`` had already been accessed, so unsupported inputs failed with other errors.
    if not isinstance(x, (np.ndarray, np.generic, Tensor)):
        msg = f"Invalid argument type {type(x)}. (expected ndarray or Tensor)"
        raise TypeError(msg)

    if is_floating_point(x) or is_complex(x):
        # Replace non-finite values by the checksums of the matching Python floats, so
        # that NaN / -inf / +inf contribute deterministic values to the data bytes.
        x = nan_to_num_fn(
            x,
            nan=checksum_float(math.nan, **kwargs),
            neginf=checksum_float(-math.inf, **kwargs),
            posinf=checksum_float(math.inf, **kwargs),
        )

    # Ensure that accumulator exists
    kwargs["accumulator"] = kwargs.get("accumulator", 0)
    kwargs["accumulator"] += checksum_dtype(x.dtype, **kwargs)
    kwargs["accumulator"] += _checksum_iterable(x.shape, **kwargs)
    kwargs["accumulator"] += _cached_checksum_str(get_fullname(x))

    if isinstance(x, (np.ndarray, np.generic)):
        xbytes = x.tobytes()
    elif _NUMPY_AVAILABLE:
        try:
            xbytes = x.detach().cpu().numpy().tobytes()
        except TypeError:
            # ``Tensor.numpy()`` raises TypeError for dtypes numpy has no equivalent for
            # (e.g. bfloat16); the pure-torch serializer handles any dtype.
            xbytes = _serialize_tensor_to_bytes(x)
    else:
        xbytes = _serialize_tensor_to_bytes(x)

    return checksum_bytes(xbytes, **kwargs)
def _serialize_tensor_to_bytes(x: Tensor) -> bytes:
    """Return the raw bytes of a tensor's data in row-major (C) order, without numpy.

    Gives the same bytes as ``x.detach().cpu().numpy().tobytes()`` for dtypes numpy
    supports. It also works for dtypes numpy lacks (e.g. bfloat16), 0-dim tensors and
    non-contiguous tensors.

    Args:
        x: Tensor of any dtype, on any device.

    Returns:
        The tensor data as native-byte-order bytes.
    """
    # ``view(dtype)`` with a different element size needs at least one dimension and a
    # unit last stride, so first build a flat, contiguous CPU copy (detached from autograd).
    flat = x.detach().cpu().contiguous().reshape(-1)
    # Reinterpreting as uint8 yields the same bytes as packing the int8 view with
    # ``struct.pack("b", ...)``, but through a single ``tolist()`` instead of per-element packing.
    return bytes(flat.view(torch.uint8).tolist())