torchwrench.extras.safetensors module

torchwrench.extras.safetensors.dump_safetensors(tensors: dict[str, Tensor], fpath: str | Path | None = None, metadata: dict[str, str] | None = None, *, overwrite: bool = True, make_parents: bool = True, convert_to_tensor: False = False) → bytes
torchwrench.extras.safetensors.dump_safetensors(tensors: dict[str, Any], fpath: str | Path | None = None, metadata: dict[str, str] | None = None, *, overwrite: bool = True, make_parents: bool = True, convert_to_tensor: True) → bytes

Dump tensors (and optional string-to-string metadata) to the safetensors format, returning the serialized bytes. Requires the safetensors package to be installed.
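To show what the returned bytes and the `metadata: dict[str, str]` argument correspond to, here is a minimal, standard-library-only sketch of the safetensors layout: an 8-byte little-endian header length, a JSON header giving each tensor's `dtype`, `shape` and `data_offsets` (plus an optional `"__metadata__"` string-to-string map), then the raw little-endian data. It does not call torchwrench or the safetensors package, and it is not how `dump_safetensors` is implemented (it presumably delegates to the safetensors library). `encode_safetensors` and `_STRUCT_CODES` are hypothetical names, and tensors are modeled as `(dtype, shape, flat_values)` triples instead of `torch.Tensor` objects.

```python
from __future__ import annotations

import json
import math
import struct

# Hypothetical, illustrative subset: safetensors dtype tags -> struct format codes.
_STRUCT_CODES = {"F32": "f", "F64": "d", "I32": "i", "I64": "q"}


def encode_safetensors(
    tensors: dict[str, tuple[str, list[int], list]],
    metadata: dict[str, str] | None = None,
) -> bytes:
    """Serialize (dtype, shape, flat_values) triples into safetensors-layout bytes."""
    header: dict[str, object] = {}
    buffer = bytearray()
    for name, (dtype, shape, values) in tensors.items():
        if math.prod(shape) != len(values):
            raise ValueError(f"{name}: shape {shape} does not match {len(values)} values")
        # Tensor data is stored as contiguous little-endian bytes.
        raw = struct.pack(f"<{len(values)}{_STRUCT_CODES[dtype]}", *values)
        header[name] = {
            "dtype": dtype,
            "shape": shape,
            # Offsets are relative to the start of the data buffer, not the file.
            "data_offsets": [len(buffer), len(buffer) + len(raw)],
        }
        buffer += raw
    if metadata is not None:
        # The format only allows string keys mapped to string values.
        if not all(isinstance(k, str) and isinstance(v, str) for k, v in metadata.items()):
            raise TypeError("safetensors metadata must map str to str")
        header["__metadata__"] = metadata
    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    # Pad the JSON header with spaces so the data buffer starts on an 8-byte boundary.
    header_bytes += b" " * (-len(header_bytes) % 8)
    return struct.pack("<Q", len(header_bytes)) + header_bytes + bytes(buffer)
```

The string-only metadata check mirrors the `dict[str, str]` annotation in the signature above: numeric values such as a training step have to be stored as strings (for example `{"step": "10"}`).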

torchwrench.extras.safetensors.load_safetensors(fpath: str | Path, *, device: str = 'cpu', return_metadata: False = False) → dict[str, Tensor]
torchwrench.extras.safetensors.load_safetensors(fpath: str | Path, *, device: str = 'cpu', return_metadata: True) → tuple[dict[str, Tensor], dict[str, str]]
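The two overloads differ only in whether the header's string-to-string metadata is returned alongside the tensors. As a rough, standard-library-only illustration of where that second dictionary comes from, the sketch below parses safetensors-layout bytes and returns `(tensors, metadata)`, analogous to `return_metadata=True`. It does not use torchwrench or the safetensors package, ignores `device`, and models tensors as `(dtype, shape, flat_values)` triples; `decode_safetensors` and `_STRUCT_CODES` are hypothetical names. Returning `{}` when the header has no `"__metadata__"` entry is this sketch's own choice, not documented behavior of `load_safetensors`.

```python
from __future__ import annotations

import json
import struct

# Hypothetical, illustrative subset: safetensors dtype tags -> struct format codes.
_STRUCT_CODES = {"F32": "f", "F64": "d", "I32": "i", "I64": "q"}


def decode_safetensors(data: bytes) -> tuple[dict[str, tuple[str, list[int], list]], dict[str, str]]:
    """Parse safetensors-layout bytes into (tensors, metadata)."""
    # First 8 bytes: little-endian u64 length of the JSON header.
    (header_len,) = struct.unpack("<Q", data[:8])
    # json.loads accepts bytes and tolerates trailing space padding.
    header = json.loads(data[8 : 8 + header_len])
    # Metadata lives under a reserved key; it is not a tensor entry.
    metadata = header.pop("__metadata__", {})
    buffer = data[8 + header_len :]
    tensors = {}
    for name, info in header.items():
        begin, end = info["data_offsets"]
        code = _STRUCT_CODES[info["dtype"]]
        # Standard ("<") sizes, so the element count does not depend on the platform.
        count = (end - begin) // struct.calcsize("<" + code)
        values = list(struct.unpack(f"<{count}{code}", buffer[begin:end]))
        tensors[name] = (info["dtype"], info["shape"], values)
    return tensors, metadata
```

Because metadata is stored in the JSON header, reading it does not require touching the tensor data buffer.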