torchwrench.extras.safetensors module
torchwrench.extras.safetensors.dump_safetensors(tensors: dict[str, Tensor], fpath: str | Path | None = None, metadata: dict[str, str] | None = None, *, overwrite: bool = True, make_parents: bool = True, convert_to_tensor: False = False) -> bytes
torchwrench.extras.safetensors.dump_safetensors(tensors: dict[str, Any], fpath: str | Path | None = None, metadata: dict[str, str] | None = None, *, overwrite: bool = True, make_parents: bool = True, convert_to_tensor: True) -> bytes

    Dump tensors to the safetensors format. Requires the safetensors package to be installed.
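
For orientation on what the returned bytes contain: the safetensors format documentation describes a file as an 8-byte little-endian header length, a JSON header, and then the raw tensor data. The sketch below builds that layout by hand for flat float32 lists using only the standard library. It illustrates the file format and is not the code that torchwrench or safetensors runs; `dump_f32_sketch` is a hypothetical name, and real writers may additionally pad the header with spaces.

```python
import json
import struct


def dump_f32_sketch(tensors, metadata=None) -> bytes:
    """Hypothetical helper: serialize {name: list of floats} in the safetensors layout."""
    header = {}
    if metadata:
        # Optional string-to-string metadata lives under a reserved key.
        header["__metadata__"] = metadata

    chunks = []
    offset = 0
    for name, values in tensors.items():
        raw = struct.pack(f"<{len(values)}f", *values)  # little-endian float32
        header[name] = {
            "dtype": "F32",
            "shape": [len(values)],
            # Offsets are relative to the start of the data section.
            "data_offsets": [offset, offset + len(raw)],
        }
        chunks.append(raw)
        offset += len(raw)

    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    # 8-byte unsigned little-endian header size, then header, then data.
    return struct.pack("<Q", len(header_bytes)) + header_bytes + b"".join(chunks)


blob = dump_f32_sketch({"w": [1.0, 2.0]}, {"note": "demo"})
header_size = struct.unpack("<Q", blob[:8])[0]
print(json.loads(blob[8 : 8 + header_size]))
```

In practice, dump_safetensors handles this encoding for you. The sketch only shows why the function can return bytes even when fpath is None.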
torchwrench.extras.safetensors.load_safetensors(fpath: str | Path, *, device: str = 'cpu', return_metadata: False = False) -> dict[str, Tensor]
torchwrench.extras.safetensors.load_safetensors(fpath: str | Path, *, device: str = 'cpu', return_metadata: True) -> tuple[dict[str, Tensor], dict[str, str]]
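
Each function on this page is listed twice because a boolean flag (return_metadata here, convert_to_tensor for dump_safetensors) appears to be typed with a literal value, so that the declared return type depends on the flag. The following is a minimal sketch of that typing pattern using `typing.overload` and `typing.Literal`. `load_demo` is a hypothetical stand-in that returns hard-coded data; it does not read files and is not torchwrench's implementation.

```python
from typing import Literal, overload


# Hypothetical stand-in, NOT torchwrench code: lists of floats play the role
# of tensors so the example runs with the standard library alone.
@overload
def load_demo(
    fpath: str, *, return_metadata: Literal[False] = False
) -> dict[str, list[float]]: ...
@overload
def load_demo(
    fpath: str, *, return_metadata: Literal[True]
) -> tuple[dict[str, list[float]], dict[str, str]]: ...


def load_demo(fpath, *, return_metadata=False):
    # Single runtime implementation shared by both overloads.
    tensors = {"weight": [1.0, 2.0]}
    metadata = {"source": fpath}
    if return_metadata:
        return tensors, metadata
    return tensors


tensors = load_demo("model.safetensors")
tensors, metadata = load_demo("model.safetensors", return_metadata=True)
```

With this pattern a type checker infers a plain dictionary for the first call and a (tensors, metadata) tuple for the second, which matches the two signatures listed above.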