#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
import os
from pathlib import Path
from typing import (
Any,
BinaryIO,
Callable,
Dict,
Optional,
TextIO,
TypeVar,
Union,
overload,
)
from pythonwrench.functools import function_alias
from pythonwrench.jsonl import load_jsonl
from pythonwrench.warnings import deprecated_alias
from typing_extensions import TypeAlias
from torchwrench.core.packaging import (
_H5PY_AVAILABLE,
_NUMPY_AVAILABLE,
_SAFETENSORS_AVAILABLE,
_TORCHAUDIO_AVAILABLE,
_YAML_AVAILABLE,
)
from .common import SavingBackend, _fpath_to_saving_backend
from .csv import load_csv
from .json import load_json
from .pickle import load_pickle
from .torch import load_torch
# Module-level logger named after this module.
pylog = logging.getLogger(__name__)

# Covariant type variable standing for the value produced by a loader.
T = TypeVar("T", covariant=True)

# A loader is annotated as a callable that receives one argument and returns T.
LoadFn: TypeAlias = Callable[[Any], T]
# Either a loader callable or the name of a saving backend.
LoadFnLike: TypeAlias = Union[LoadFn[T], SavingBackend]

# Registry mapping a backend name to its loader function. These backends are
# registered unconditionally; optional ones are appended below.
LOAD_FNS: Dict[SavingBackend, LoadFn[Any]] = {
    "csv": load_csv,
    "json": load_json,
    "jsonl": load_jsonl,
    "pickle": load_pickle,
    "torch": load_torch,
}

# Optional backends: each loader is imported and registered only when the
# matching availability flag from torchwrench.core.packaging is truthy.
# The registration order is kept as-is because it determines the order of the
# keys listed in the `load_from` error message.
if _H5PY_AVAILABLE:
    from .hdf import load_hdf

    LOAD_FNS["h5py"] = load_hdf

if _NUMPY_AVAILABLE:
    from .numpy import load_ndarray

    LOAD_FNS["numpy"] = load_ndarray

if _SAFETENSORS_AVAILABLE:
    from torchwrench.extras.safetensors import load_safetensors

    LOAD_FNS["safetensors"] = load_safetensors

if _TORCHAUDIO_AVAILABLE:
    from .torchaudio import load_with_torchaudio

    LOAD_FNS["torchaudio"] = load_with_torchaudio

if _YAML_AVAILABLE:
    from .yaml import load_yaml

    LOAD_FNS["yaml"] = load_yaml
@overload
def load_from(
    fpath: Union[TextIO, BinaryIO],
    *args,
    saving_backend: SavingBackend = "torch",
    **kwargs,
) -> Any: ...


@overload
def load_from(
    fpath: Union[str, Path, os.PathLike],
    *args,
    saving_backend: Optional[SavingBackend] = "torch",
    **kwargs,
) -> Any: ...


def load_from(
    fpath: Union[str, Path, os.PathLike, TextIO, BinaryIO],
    *args,
    saving_backend: Optional[SavingBackend] = "torch",
    **kwargs,
) -> Any:
    """Load from file using the correct backend.

    Args:
        fpath: Path to the file to read, or an already opened text/binary
            file object. Path-like values are converted to ``Path`` and must
            point to an existing file.
        *args: Extra positional arguments forwarded to the selected loader.
        saving_backend: Name of the backend (a key of ``LOAD_FNS``) used to
            read the data. When ``None``, the backend is resolved by
            ``_fpath_to_saving_backend(fpath)``.
        **kwargs: Extra keyword arguments forwarded to the selected loader.

    Returns:
        Whatever the selected loader returns.

    Raises:
        FileNotFoundError: If ``fpath`` is path-like and is not a file.
        ValueError: If the backend is not registered in ``LOAD_FNS``.
    """
    # Only path-like inputs can be checked on disk; file objects are passed
    # to the loader untouched.
    if isinstance(fpath, (str, os.PathLike)):
        fpath = Path(fpath)
        if not fpath.is_file():
            msg = f"Invalid argument {fpath=}. (path is not a file)"
            raise FileNotFoundError(msg)

    if saving_backend is None:
        # NOTE(review): the file-object overload does not allow None here, so
        # the helper presumably expects a path — confirm against
        # `_fpath_to_saving_backend`.
        saving_backend = _fpath_to_saving_backend(fpath)

    # Optional backends are registered only when their dependency is
    # installed, so a known backend name can still be missing here.
    if saving_backend not in LOAD_FNS:
        msg = f"Invalid argument {saving_backend=}. (expected one of {tuple(LOAD_FNS.keys())})"
        raise ValueError(msg)

    load_fn = LOAD_FNS[saving_backend]
    result = load_fn(fpath, *args, **kwargs)
    return result
# Alias of `load_from` created with `function_alias`. The stub below has no
# logic of its own; presumably the decorator replaces it with a forwarder to
# `load_from` — confirm in pythonwrench.functools.
@function_alias(load_from)
def read_from(*args, **kwargs): ...
# Deprecated alias of `load_from` created with `deprecated_alias`. The stub
# below has no logic of its own; presumably the decorator forwards to
# `load_from` and signals the deprecation — confirm in pythonwrench.warnings.
@deprecated_alias(load_from)
def load(*args, **kwargs): ...
# Deprecated alias of `load_from` created with `deprecated_alias`. The stub
# below has no logic of its own; presumably the decorator forwards to
# `load_from` and signals the deprecation — confirm in pythonwrench.warnings.
@deprecated_alias(load_from)
def read(*args, **kwargs): ...