Source code for torchwrench.serialization.dump_fn

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
from pathlib import Path
from typing import Any, BinaryIO, Callable, Dict, Optional, Union

from pythonwrench.functools import function_alias
from pythonwrench.jsonl import dump_jsonl
from pythonwrench.warnings import deprecated_alias
from typing_extensions import TypeAlias

from torchwrench.core.packaging import (
    _H5PY_AVAILABLE,
    _NUMPY_AVAILABLE,
    _SAFETENSORS_AVAILABLE,
    _TORCHAUDIO_AVAILABLE,
    _YAML_AVAILABLE,
)

from .common import SavingBackend, _fpath_to_saving_backend
from .csv import dump_csv
from .json import dump_json
from .pickle import dump_pickle
from .torch import dump_torch

# Any callable usable as a backend dump function; ``dump_to`` invokes it as
# ``fn(obj, fpath, *args, **kwargs)``.
DumpFn: TypeAlias = Callable[..., Any]
# Either a dump callable or the name of a saving backend.
DumpFnLike: TypeAlias = Union[DumpFn, SavingBackend]


# Registry mapping each saving-backend name to its dump function.
# The entries below have no optional dependency and are always registered.
DUMP_FNS: Dict[SavingBackend, DumpFn] = {
    "csv": dump_csv,
    "json": dump_json,
    "jsonl": dump_jsonl,
    "pickle": dump_pickle,
    "torch": dump_torch,
}

# Optional backends: each one is imported and registered only when its
# availability flag from torchwrench.core.packaging is true, so importing this
# module does not require these extra dependencies. Insertion order here is the
# key order reported by dump_to's "expected one of" error message.
if _H5PY_AVAILABLE:
    from .hdf import dump_hdf

    DUMP_FNS["h5py"] = dump_hdf

if _NUMPY_AVAILABLE:
    from .numpy import dump_ndarray

    DUMP_FNS["numpy"] = dump_ndarray

if _SAFETENSORS_AVAILABLE:
    # Lives outside this subpackage, hence the absolute import.
    from torchwrench.extras.safetensors import dump_safetensors

    DUMP_FNS["safetensors"] = dump_safetensors

if _TORCHAUDIO_AVAILABLE:
    from .torchaudio import dump_with_torchaudio

    DUMP_FNS["torchaudio"] = dump_with_torchaudio

if _YAML_AVAILABLE:
    from .yaml import dump_yaml

    DUMP_FNS["yaml"] = dump_yaml


def dump_to(
    obj: Any,
    fpath: Union[str, Path, os.PathLike, None, BinaryIO] = None,
    *args,
    saving_backend: Optional[SavingBackend] = "torch",
    **kwargs,
) -> Union[str, bytes]:
    """Save an object using the selected (or inferred) saving backend.

    Args:
        obj: Object to serialize.
        fpath: Destination path, binary file object, or None. ``str`` and
            ``os.PathLike`` values are normalized to ``pathlib.Path`` before
            being forwarded to the backend.
        *args: Extra positional arguments forwarded to the backend dump function.
        saving_backend: Name of the backend to use, i.e. a key of ``DUMP_FNS``.
            If None, the backend is inferred from ``fpath`` with
            ``_fpath_to_saving_backend``, which requires ``fpath`` to be a path.
            Defaults to "torch".
        **kwargs: Extra keyword arguments forwarded to the backend dump function.

    Returns:
        The value returned by the selected backend's dump function.

    Raises:
        ValueError: If ``saving_backend`` is None while ``fpath`` is not a path,
            or if the resolved backend is not registered in ``DUMP_FNS``
            (optional backends are only registered when their dependency is
            available).
    """
    if isinstance(fpath, (str, os.PathLike)):
        fpath = Path(fpath)

    if saving_backend is None:
        # Inference works from the path itself; file objects and None cannot be used.
        # (str / PathLike inputs were already normalized to Path above.)
        if not isinstance(fpath, Path):
            msg = f"Invalid combination of arguments {fpath=} and {saving_backend=}."
            raise ValueError(msg)
        saving_backend = _fpath_to_saving_backend(fpath)

    # Validate after inference as well: an inferred backend is not guaranteed to be
    # registered, and indexing DUMP_FNS directly would raise a bare KeyError.
    dump_fn = DUMP_FNS.get(saving_backend)
    if dump_fn is None:
        msg = f"Invalid argument {saving_backend=}. (expected one of {tuple(DUMP_FNS.keys())})"
        raise ValueError(msg)

    return dump_fn(obj, fpath, *args, **kwargs)
@function_alias(dump_to)
def save_to(*args, **kwargs):
    """Alias of :func:`dump_to`, produced by the ``function_alias`` decorator."""
@deprecated_alias(dump_to)
def dump(*args, **kwargs):
    """Deprecated alias of :func:`dump_to`; prefer calling ``dump_to`` directly."""
@deprecated_alias(dump_to)
def save(*args, **kwargs):
    """Deprecated alias of :func:`dump_to`; prefer calling ``dump_to`` directly."""