Source code for torchwrench.serialization.dump_fn
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from pathlib import Path
from typing import Any, BinaryIO, Callable, Dict, Optional, Union
from pythonwrench.functools import function_alias
from pythonwrench.jsonl import dump_jsonl
from pythonwrench.warnings import deprecated_alias
from typing_extensions import TypeAlias
from torchwrench.core.packaging import (
_H5PY_AVAILABLE,
_NUMPY_AVAILABLE,
_SAFETENSORS_AVAILABLE,
_TORCHAUDIO_AVAILABLE,
_YAML_AVAILABLE,
)
from .common import SavingBackend, _fpath_to_saving_backend
from .csv import dump_csv
from .json import dump_json
from .pickle import dump_pickle
from .torch import dump_torch
# A backend serializer; ``dump_to`` calls it as ``fn(obj, fpath, *args, **kwargs)``.
DumpFn: TypeAlias = Callable[..., Any]
# Either a serializer callable or the name of a saving backend.
DumpFnLike: TypeAlias = Union[DumpFn, SavingBackend]

# Registry mapping backend names to their dump functions.
# The entries below are registered unconditionally.
DUMP_FNS: Dict[SavingBackend, DumpFn] = {
    "csv": dump_csv,
    "json": dump_json,
    "jsonl": dump_jsonl,
    "pickle": dump_pickle,
    "torch": dump_torch,
}

# Optional backends: each one is imported and registered only when the matching
# availability flag from ``torchwrench.core.packaging`` is true. The set of keys
# in DUMP_FNS therefore depends on the environment. Insertion order here is also
# the order listed in the "expected one of" error message built by ``dump_to``.
if _H5PY_AVAILABLE:
    from .hdf import dump_hdf

    DUMP_FNS["h5py"] = dump_hdf

if _NUMPY_AVAILABLE:
    from .numpy import dump_ndarray

    DUMP_FNS["numpy"] = dump_ndarray

if _SAFETENSORS_AVAILABLE:
    # Unlike the other backends, this one lives in torchwrench.extras rather than
    # in a sibling module of this package.
    from torchwrench.extras.safetensors import dump_safetensors

    DUMP_FNS["safetensors"] = dump_safetensors

if _TORCHAUDIO_AVAILABLE:
    from .torchaudio import dump_with_torchaudio

    DUMP_FNS["torchaudio"] = dump_with_torchaudio

if _YAML_AVAILABLE:
    from .yaml import dump_yaml

    DUMP_FNS["yaml"] = dump_yaml
[docs]
def dump_to(
    obj: Any,
    fpath: Union[str, Path, os.PathLike, None, BinaryIO] = None,
    *args,
    saving_backend: Optional[SavingBackend] = "torch",
    **kwargs,
) -> Union[str, bytes]:
    """Save an object using the selected serialization backend.

    Args:
        obj: Object to serialize.
        fpath: Destination. A ``str`` or ``os.PathLike`` is converted to
            :class:`pathlib.Path` before being forwarded. Any other value
            (e.g. ``None`` or a binary file object) is forwarded unchanged to
            the backend dump function.
        *args: Extra positional arguments forwarded to the backend dump function.
        saving_backend: Name of a backend registered in ``DUMP_FNS``. Defaults
            to ``"torch"``. If ``None``, the backend is inferred from ``fpath``
            with ``_fpath_to_saving_backend``, which requires ``fpath`` to be a
            path.
        **kwargs: Extra keyword arguments forwarded to the backend dump function.

    Returns:
        The value returned by the backend dump function.

    Raises:
        ValueError: If ``saving_backend`` is ``None`` while ``fpath`` is not a
            path, or if the requested or inferred backend is not registered in
            ``DUMP_FNS`` (for example, when its optional dependency is not
            installed).
    """
    if isinstance(fpath, (str, os.PathLike)):
        fpath = Path(fpath)

    if saving_backend is None:
        # After the conversion above, every str/PathLike input is a Path, so
        # anything else (None, file objects) cannot be used for inference.
        if not isinstance(fpath, Path):
            msg = f"Invalid combinaison of arguments {fpath=} and {saving_backend=}."
            raise ValueError(msg)
        saving_backend = _fpath_to_saving_backend(fpath)

    # Validate after inference as well. Optional backends are only present in
    # DUMP_FNS when their dependency is available, so an inferred backend may be
    # missing; report that as ValueError rather than a bare KeyError.
    dump_fn = DUMP_FNS.get(saving_backend)
    if dump_fn is None:
        msg = f"Invalid argument {saving_backend=}. (expected one of {tuple(DUMP_FNS.keys())})"
        raise ValueError(msg)

    return dump_fn(obj, fpath, *args, **kwargs)
[docs]
# Public alias of ``dump_to``, produced by ``function_alias`` from
# ``pythonwrench.functools``.
# NOTE(review): the ``...`` stub body is presumably replaced by the decorator,
# with calls routed to ``dump_to``; confirm against pythonwrench.
@function_alias(dump_to)
def save_to(*args, **kwargs): ...
[docs]
# Alias of ``dump_to`` wrapped with ``deprecated_alias`` from
# ``pythonwrench.warnings``; prefer ``dump_to`` or ``save_to`` in new code.
# NOTE(review): presumably emits a deprecation warning and forwards to
# ``dump_to``; the stub body is not what runs. Confirm against pythonwrench.
@deprecated_alias(dump_to)
def dump(*args, **kwargs): ...
[docs]
# Alias of ``dump_to`` wrapped with ``deprecated_alias`` from
# ``pythonwrench.warnings``; prefer ``dump_to`` or ``save_to`` in new code.
# NOTE(review): presumably emits a deprecation warning and forwards to
# ``dump_to``; the stub body is not what runs. Confirm against pythonwrench.
@deprecated_alias(dump_to)
def save(*args, **kwargs): ...