# Source code for torchwrench.extras.safetensors

#!/usr/bin/env python
# -*- coding: utf-8 -*-

from pathlib import Path
from typing import Any, Dict, Literal, Optional, Tuple, Union, overload

from pythonwrench._core import _setup_output_fpath
from pythonwrench.inspect import get_fullname
from pythonwrench.typing.checks import isinstance_generic
from safetensors import safe_open
from safetensors.torch import save
from torch import Tensor

from torchwrench.nn import functional as F


@overload
def load_safetensors(
    fpath: Union[str, Path],
    *,
    device: str = "cpu",
    return_metadata: Literal[False] = False,
) -> Dict[str, Tensor]: ...


@overload
def load_safetensors(
    fpath: Union[str, Path],
    *,
    device: str = "cpu",
    return_metadata: Literal[True],
) -> Tuple[Dict[str, Tensor], Dict[str, str]]: ...


def load_safetensors(
    fpath: Union[str, Path],
    *,
    device: str = "cpu",
    return_metadata: bool = False,
) -> Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], Dict[str, str]]]:
    """Load every tensor stored in a safetensors file.

    Requires safetensors package installed.

    Args:
        fpath: Path of the safetensors file to read.
        device: Device the tensors are loaded onto (forwarded to ``safe_open``).
        return_metadata: If True, also return the file's string metadata.

    Returns:
        A dict mapping tensor names to tensors when ``return_metadata`` is
        False, otherwise a ``(tensors, metadata)`` tuple. ``metadata`` is an
        empty dict when the file does not store any metadata.
    """
    with safe_open(fpath, framework="pt", device=device) as f:
        tensors = {key: f.get_tensor(key) for key in f.keys()}
        # Read metadata while the file handle is still open.
        raw_metadata = f.metadata() if return_metadata else None

    if not return_metadata:
        return tensors

    # safe_open(...).metadata() yields None when the header has no metadata
    # entry (e.g. files written with metadata=None); normalize to an empty
    # dict so the result matches the declared Dict[str, str] return type.
    metadata: Dict[str, str] = dict(raw_metadata) if raw_metadata is not None else {}
    return tensors, metadata
def _invalid_tensors_message(tensors: Any) -> str:
    """Build the TypeError message for a ``tensors`` argument that is not a Dict[str, Tensor]."""
    expected = "dict[str, Tensor]"
    if isinstance(tensors, dict):
        # Point at the first offending entry instead of only reporting "dict".
        for key, value in tensors.items():
            if not isinstance(key, str):
                return (
                    f"Invalid argument type. (expected {expected} but found "
                    f"non-str key {key!r} of type {get_fullname(type(key))})"
                )
            if not isinstance(value, Tensor):
                return (
                    f"Invalid argument type. (expected {expected} but found "
                    f"value of type {get_fullname(type(value))} for key {key!r})"
                )
    return (
        f"Invalid argument type. (expected {expected} but found "
        f"{get_fullname(type(tensors))})"
    )


@overload
def dump_safetensors(
    tensors: Dict[str, Tensor],
    fpath: Union[str, Path, None] = None,
    metadata: Optional[Dict[str, str]] = None,
    *,
    overwrite: bool = True,
    make_parents: bool = True,
    convert_to_tensor: Literal[False] = False,
) -> bytes: ...


@overload
def dump_safetensors(
    tensors: Dict[str, Any],
    fpath: Union[str, Path, None] = None,
    metadata: Optional[Dict[str, str]] = None,
    *,
    overwrite: bool = True,
    make_parents: bool = True,
    convert_to_tensor: Literal[True],
) -> bytes: ...


def dump_safetensors(
    tensors: Dict[str, Any],
    fpath: Union[str, Path, None] = None,
    metadata: Optional[Dict[str, str]] = None,
    *,
    overwrite: bool = True,
    make_parents: bool = True,
    convert_to_tensor: bool = False,
) -> bytes:
    """Dump tensors to safetensors format.

    Requires safetensors package installed.

    Args:
        tensors: Mapping from tensor names to tensors. When ``convert_to_tensor``
            is True, values are converted with ``F.as_tensor`` first.
        fpath: Optional output file path. When the prepared path is not None,
            the serialized bytes are also written to it.
        metadata: Optional string-to-string metadata stored in the file header.
        overwrite: Forwarded to ``_setup_output_fpath`` (presumably controls
            whether an existing file may be replaced — confirm in pythonwrench).
        make_parents: Forwarded to ``_setup_output_fpath`` (presumably creates
            missing parent directories — confirm in pythonwrench).
        convert_to_tensor: If True, convert every value to a tensor instead of
            validating that ``tensors`` is already a ``Dict[str, Tensor]``.

    Returns:
        The serialized safetensors content as bytes.

    Raises:
        TypeError: If ``convert_to_tensor`` is False and ``tensors`` is not a
            ``Dict[str, Tensor]``.
    """
    if convert_to_tensor:
        tensors = {k: F.as_tensor(v) for k, v in tensors.items()}
    elif not isinstance_generic(tensors, Dict[str, Tensor]):
        raise TypeError(_invalid_tensors_message(tensors))

    fpath = _setup_output_fpath(fpath, overwrite, make_parents)
    content = save(tensors, metadata)
    if fpath is not None:
        fpath.write_bytes(content)
    return content