Source code for torchwrench.serialization.torch

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import inspect
import io
import os
import pickle
from io import BufferedWriter
from pathlib import Path
from typing import IO, Any, BinaryIO, Callable, Dict, Optional, Union

import torch
from pythonwrench._core import _setup_output_fpath
from pythonwrench.semver import Version
from torch import (  # noqa: F401
    load,
    save,
)
from torch.serialization import DEFAULT_PROTOCOL
from torch.types import Storage
from typing_extensions import TypeAlias

FileLike: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
MapLocationLike: TypeAlias = Optional[
    Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]
]


def dump_torch(
    obj: object,
    f: Optional[FileLike] = None,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
    _use_new_zipfile_serialization: bool = True,
    _disable_byteorder_record: bool = False,
    *,
    overwrite: bool = True,
    make_parents: bool = True,
) -> bytes:
    """Serialize ``obj`` with ``torch.save`` and return the resulting bytes.

    The object is first saved into an in-memory buffer. The bytes are then
    optionally written to ``f`` and always returned.

    Args:
        obj: Object to serialize.
        f: Optional destination.
            - ``str`` / ``os.PathLike`` (or ``None``): passed through
              ``_setup_output_fpath`` together with ``overwrite`` and
              ``make_parents``. If the result is a path, the bytes are written
              to it. If it is ``None``, nothing is written.
            - Any object with a ``write`` method (a file opened in binary mode,
              ``io.BytesIO``, ...): receives the bytes and is flushed when it
              has a ``flush`` method.
        pickle_module: Module forwarded to ``torch.save`` for pickling.
        pickle_protocol: Pickle protocol forwarded to ``torch.save``.
        _use_new_zipfile_serialization: Forwarded to ``torch.save``.
        _disable_byteorder_record: Forwarded to ``torch.save`` only when the
            installed torch version declares this parameter; ignored otherwise.
        overwrite: Forwarded to ``_setup_output_fpath`` for path destinations.
        make_parents: Forwarded to ``_setup_output_fpath`` for path destinations.

    Returns:
        The serialized bytes.
    """
    # Older torch releases do not accept ``_disable_byteorder_record``; only
    # forward it when torch.save actually declares that parameter.
    if "_disable_byteorder_record" in inspect.signature(torch.save).parameters:
        kwds = dict(_disable_byteorder_record=_disable_byteorder_record)
    else:
        kwds = {}

    with io.BytesIO() as buffer:
        torch.save(
            obj,
            buffer,
            pickle_module,
            pickle_protocol,
            _use_new_zipfile_serialization,
            **kwds,
        )
        content = buffer.getvalue()

    if isinstance(f, (str, os.PathLike)) or f is None:
        f = _setup_output_fpath(f, overwrite, make_parents)

    if isinstance(f, (str, os.PathLike)):
        Path(f).write_bytes(content)
    elif f is not None and callable(getattr(f, "write", None)):
        # Duck-typed: real file objects (BytesIO, BufferedRandom, gzip, ...)
        # are not subclasses of ``typing.BinaryIO``, so an isinstance check
        # against it would silently skip them.
        f.write(content)
        flush = getattr(f, "flush", None)
        if callable(flush):
            flush()

    return content
def load_torch(
    f: FileLike,
    map_location: MapLocationLike = None,
    pickle_module: Any = None,
    *,
    weights_only: bool = ...,
    mmap: Optional[bool] = None,
    **pickle_load_args: Any,
) -> Any:
    """Load an object with ``torch.load``, adapting arguments to the torch version.

    Args:
        f: Path or binary stream forwarded to ``torch.load``.
        map_location: Forwarded to ``torch.load``.
        pickle_module: Forwarded to ``torch.load``. On torch < 2.1, ``None`` is
            replaced by the stdlib ``pickle`` module. An explicitly given module
            is always respected.
        weights_only: Forwarded to ``torch.load`` on torch >= 2.1. When left at
            its ``...`` sentinel, it becomes ``True`` on torch >= 2.6 and
            ``False`` otherwise.
        mmap: Forwarded to ``torch.load`` on torch >= 2.1.
        **pickle_load_args: Extra keyword arguments forwarded to ``torch.load``.

    Returns:
        Whatever ``torch.load`` returns.
    """
    # NOTE: with weights_only=False, torch.load unpickles arbitrary objects;
    # do not use it on untrusted files.
    torch_version = Version(torch.__version__)
    kwds: Dict[str, Any] = {}

    if torch_version < Version("2.1.0"):
        # NOTE(review): ``weights_only`` and ``mmap`` are not forwarded on
        # these versions, so they are silently ignored here.
        if pickle_module is None:
            pickle_module = pickle
    else:
        if weights_only is ...:
            weights_only = torch_version >= Version("2.6.0")
        kwds.update(
            weights_only=weights_only,
            mmap=mmap,
        )

    return torch.load(
        f,
        map_location,
        pickle_module,
        **kwds,
        **pickle_load_args,
    )