#!/usr/bin/env python
# -*- coding: utf-8 -*-
import inspect
import io
import os
import pickle
from io import BufferedWriter
from pathlib import Path
from typing import IO, Any, BinaryIO, Callable, Dict, Optional, Union
import torch
from pythonwrench._core import _setup_output_fpath
from pythonwrench.semver import Version
from torch import ( # noqa: F401
load,
save,
)
from torch.serialization import DEFAULT_PROTOCOL
from torch.types import Storage
from typing_extensions import TypeAlias
FileLike: TypeAlias = Union[str, os.PathLike, BinaryIO, IO[bytes]]
MapLocationLike: TypeAlias = Optional[
Union[Callable[[Storage, str], Storage], torch.device, str, Dict[str, str]]
]
[docs]
def dump_torch(
    obj: object,
    f: Optional[FileLike] = None,
    pickle_module: Any = pickle,
    pickle_protocol: int = DEFAULT_PROTOCOL,
    _use_new_zipfile_serialization: bool = True,
    _disable_byteorder_record: bool = False,
    *,
    overwrite: bool = True,
    make_parents: bool = True,
) -> bytes:
    """Serialize ``obj`` with ``torch.save``, optionally write it out, and return the bytes.

    Args:
        obj: Object to serialize.
        f: Destination.
            - ``None``, ``str`` or ``os.PathLike``: passed through
              ``_setup_output_fpath(f, overwrite, make_parents)``. If that returns
              a path, the bytes are written to it with ``Path.write_bytes``.
            - Any object with a callable ``write`` method (an open binary file,
              ``io.BytesIO``, ...): the bytes are written to it and it is flushed
              if it has a ``flush`` method.
        pickle_module: Forwarded positionally to ``torch.save``.
        pickle_protocol: Forwarded positionally to ``torch.save``.
        _use_new_zipfile_serialization: Forwarded positionally to ``torch.save``.
        _disable_byteorder_record: Forwarded to ``torch.save`` only when the
            installed version's ``torch.save`` has a parameter with that name.
        overwrite: Forwarded to ``_setup_output_fpath`` for path-like ``f``
            (presumably controls replacing an existing file — confirm there).
        make_parents: Forwarded to ``_setup_output_fpath`` for path-like ``f``
            (presumably creates missing parent directories — confirm there).

    Returns:
        The serialized bytes, regardless of where (or whether) they were written.

    Raises:
        TypeError: If ``f`` is neither ``None``, path-like, nor has a ``write`` method.
    """
    is_path_like = f is None or isinstance(f, (str, os.PathLike))
    # Validate before the (potentially expensive) serialization. Previously an
    # unsupported ``f`` was silently ignored.
    if not is_path_like and not callable(getattr(f, "write", None)):
        raise TypeError(
            f"Invalid argument type {type(f)} for f: expected None, a path, "
            "or a writable binary file object."
        )

    kwds: Dict[str, Any] = {}
    # Older torch versions do not accept `_disable_byteorder_record`; inspect the
    # full signature (not only positional args) to decide whether to pass it.
    if "_disable_byteorder_record" in inspect.signature(torch.save).parameters:
        kwds["_disable_byteorder_record"] = _disable_byteorder_record

    with io.BytesIO() as buffer:
        torch.save(
            obj,
            buffer,
            pickle_module,
            pickle_protocol,
            _use_new_zipfile_serialization,
            **kwds,
        )
        content = buffer.getvalue()

    if is_path_like:
        fpath = _setup_output_fpath(f, overwrite, make_parents)
        if fpath is not None:
            Path(fpath).write_bytes(content)
    else:
        # Duck-typed: `typing.BinaryIO` cannot be used with isinstance for real
        # streams (io.BytesIO, FileIO, ... do not subclass it).
        f.write(content)
        flush = getattr(f, "flush", None)
        if callable(flush):
            flush()
    return content
[docs]
def load_torch(
    f: FileLike,
    map_location: MapLocationLike = None,
    pickle_module: Any = None,
    *,
    weights_only: bool = ...,
    mmap: Optional[bool] = None,
    **pickle_load_args: Any,
) -> Any:
    """Load an object with ``torch.load``, adapting arguments to the installed torch.

    Args:
        f: Path or binary file-like object to read from.
        map_location: Forwarded positionally to ``torch.load``.
        pickle_module: Forwarded positionally to ``torch.load``. On torch < 2.1.0,
            ``None`` is replaced by the standard :mod:`pickle` module. A module
            supplied by the caller is kept.
        weights_only: Forwarded on torch >= 2.1.0. When left unset (``...``), it
            becomes ``True`` on torch >= 2.6.0 and ``False`` on earlier versions.
            Not forwarded on torch < 2.1.0.
        mmap: Forwarded on torch >= 2.1.0; not forwarded otherwise.
        **pickle_load_args: Extra keyword arguments forwarded to ``torch.load``.

    Returns:
        Whatever ``torch.load`` returns.
    """
    torch_version = Version(torch.__version__)
    kwds: Dict[str, Any] = {}

    if torch_version < Version("2.1.0"):
        # Fill in the default pickle module only when none was given; previously a
        # caller-supplied module was unconditionally overwritten here.
        if pickle_module is None:
            pickle_module = pickle
        # NOTE(review): `weights_only` and `mmap` are not forwarded on these
        # versions, so torch.load uses its own defaults. Do not rely on
        # `weights_only` for safety when loading untrusted files here.
    else:
        if weights_only is ...:
            weights_only = torch_version >= Version("2.6.0")
        kwds.update(weights_only=weights_only, mmap=mmap)

    return torch.load(
        f,
        map_location,
        pickle_module,
        **kwds,
        **pickle_load_args,
    )