Source code for torchwrench.serialization.load_fn

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import os
from pathlib import Path
from typing import (
    Any,
    BinaryIO,
    Callable,
    Dict,
    Optional,
    TextIO,
    TypeVar,
    Union,
    overload,
)

from pythonwrench.functools import function_alias
from pythonwrench.jsonl import load_jsonl
from pythonwrench.warnings import deprecated_alias
from typing_extensions import TypeAlias

from torchwrench.core.packaging import (
    _H5PY_AVAILABLE,
    _NUMPY_AVAILABLE,
    _SAFETENSORS_AVAILABLE,
    _TORCHAUDIO_AVAILABLE,
    _YAML_AVAILABLE,
)

from .common import SavingBackend, _fpath_to_saving_backend
from .csv import load_csv
from .json import load_json
from .pickle import load_pickle
from .torch import load_torch

T = TypeVar("T", covariant=True)
pylog = logging.getLogger(__name__)

# A loader function. ``load_from`` calls it as ``load_fn(fpath, *args, **kwargs)``,
# so it must accept the file argument plus arbitrary extra arguments; it returns
# the loaded object (typed T).
LoadFn: TypeAlias = Callable[..., T]
# Either a loader function or a SavingBackend name identifying a registered loader.
LoadFnLike: TypeAlias = Union[LoadFn[T], SavingBackend]


# Registry mapping each saving backend name to the function that loads it.
# ``load_from`` looks the requested backend up in this dict.
# The backends below are registered unconditionally.
LOAD_FNS: Dict[SavingBackend, LoadFn[Any]] = {
    "csv": load_csv,
    "json": load_json,
    "jsonl": load_jsonl,
    "pickle": load_pickle,
    "torch": load_torch,
}

# The remaining backends are registered only when the matching availability
# flag from ``torchwrench.core.packaging`` is set. Each import sits inside its
# guard so it only runs when that flag is true.
if _H5PY_AVAILABLE:
    from .hdf import load_hdf

    LOAD_FNS["h5py"] = load_hdf


if _NUMPY_AVAILABLE:
    from .numpy import load_ndarray

    LOAD_FNS["numpy"] = load_ndarray


if _SAFETENSORS_AVAILABLE:
    # Note: this loader lives in ``torchwrench.extras`` rather than in this package.
    from torchwrench.extras.safetensors import load_safetensors

    LOAD_FNS["safetensors"] = load_safetensors


if _TORCHAUDIO_AVAILABLE:
    from .torchaudio import load_with_torchaudio

    LOAD_FNS["torchaudio"] = load_with_torchaudio


if _YAML_AVAILABLE:
    from .yaml import load_yaml

    LOAD_FNS["yaml"] = load_yaml


# Typing overload for already-opened text/binary file objects: here
# ``saving_backend`` must be a concrete SavingBackend (``None`` is not accepted).
@overload
def load_from(
    fpath: Union[TextIO, BinaryIO],
    *args,
    saving_backend: SavingBackend = "torch",
    **kwargs,
) -> Any: ...


# Typing overload for filesystem paths: ``saving_backend`` may be ``None``, in
# which case the implementation infers the backend from the path.
@overload
def load_from(
    fpath: Union[str, Path, os.PathLike],
    *args,
    saving_backend: Optional[SavingBackend] = "torch",
    **kwargs,
) -> Any: ...


def load_from(
    fpath: Union[str, Path, os.PathLike, TextIO, BinaryIO],
    *args,
    saving_backend: Optional[SavingBackend] = "torch",
    **kwargs,
) -> Any:
    """Load from file using the correct backend.

    Args:
        fpath: Path to the file to load, or an already-opened text/binary
            file object. ``str`` and ``os.PathLike`` values are converted to
            :class:`pathlib.Path` and must point to an existing file.
        *args: Extra positional arguments forwarded to the backend loader.
        saving_backend: Name of the backend used to load the data. Defaults
            to ``"torch"``. If ``None``, the backend is inferred from
            ``fpath`` (the file-object overload requires an explicit backend).
        **kwargs: Extra keyword arguments forwarded to the backend loader.

    Returns:
        Whatever the selected backend loader returns.

    Raises:
        FileNotFoundError: If ``fpath`` is a path that does not refer to an
            existing file (e.g. it is missing or is a directory).
        ValueError: If ``saving_backend`` has no registered loader. Some
            backends are only registered when their optional dependency is
            installed.
    """
    if isinstance(fpath, (str, os.PathLike)):
        fpath = Path(fpath)
        if not fpath.is_file():
            msg = f"Invalid argument {fpath=}. (path is not a file)"
            raise FileNotFoundError(msg)

    if saving_backend is None:
        saving_backend = _fpath_to_saving_backend(fpath)

    load_fn = LOAD_FNS.get(saving_backend)
    if load_fn is None:
        # A backend can be valid yet absent when its optional dependency is
        # missing (see the guarded registrations of LOAD_FNS), so say so.
        msg = (
            f"Invalid argument {saving_backend=}. "
            f"(expected one of {tuple(LOAD_FNS.keys())}; "
            "some backends require optional dependencies to be installed)"
        )
        raise ValueError(msg)

    return load_fn(fpath, *args, **kwargs)
@function_alias(load_from)
def read_from(*args, **kwargs):
    # Stub body: this function does no work itself. ``function_alias`` is
    # expected to make ``read_from`` an alias of ``load_from``.
    ...
@deprecated_alias(load_from)
def load(*args, **kwargs):
    # Deprecated name kept for backward compatibility. The stub body does no
    # work; ``deprecated_alias`` presumably forwards calls to ``load_from``
    # and emits a deprecation warning (confirm in pythonwrench.warnings).
    ...
@deprecated_alias(load_from)
def read(*args, **kwargs):
    # Deprecated name kept for backward compatibility. The stub body does no
    # work; ``deprecated_alias`` presumably forwards calls to ``load_from``
    # and emits a deprecation warning (confirm in pythonwrench.warnings).
    ...