torchwrench.hub package
- class torchwrench.hub.RegistryEntry
Bases: TypedDict
- architecture : NotRequired[str]
- hash_type : NotRequired[Literal['sha256', 'md5']]
- hash_value : NotRequired[str]
- state_dict_key : NotRequired[str]
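To illustrate the shape of such an entry, here is a minimal sketch that mirrors only the optional fields listed above. `EntrySketch` is a hypothetical local stand-in, not the library class, and `total=False` is used here as a portable way to mark every key as optional.

```python
from typing import Literal, TypedDict


class EntrySketch(TypedDict, total=False):
    """Hypothetical stand-in for the optional fields listed above.

    total=False makes every key optional, similar in effect to NotRequired.
    """

    architecture: str
    hash_type: Literal["sha256", "md5"]
    hash_value: str
    state_dict_key: str


# Because every key is optional, a partial entry is still valid.
entry: EntrySketch = {"hash_type": "sha256", "hash_value": "0" * 64}
has_hash = "hash_value" in entry
has_architecture = "architecture" in entry
```

At runtime a TypedDict is a plain dict, so membership tests such as `"hash_value" in entry` are the usual way to check which optional fields were supplied.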
- class torchwrench.hub.RegistryHub(infos: Mapping[T_Hashable, RegistryEntry], register_root: str | Path = '~/.cache/torch/hub/checkpoints')
Bases: Generic[T_Hashable]
- download_file(name: T_Hashable, force: bool = False, check_hash: bool = True, verbose: int = 0) -> tuple[Path, bool]
Download checkpoint file.
- property infos : dict[T_Hashable, RegistryEntry]
- is_valid_hash(name: T_Hashable) -> bool
Returns True if target file hash is valid. If no hash is provided in infos, this function also returns True.
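The rule described above (a matching hash is valid, and a missing hash is also treated as valid) can be sketched with the standard library. This is an illustration under assumptions, not torchwrench's implementation: `file_hash_matches` is a hypothetical helper, and chunked reading and case-insensitive comparison are choices made here.

```python
import hashlib
import tempfile
from pathlib import Path
from typing import Optional


def file_hash_matches(
    path: Path,
    hash_type: Optional[str] = None,
    hash_value: Optional[str] = None,
) -> bool:
    """Hypothetical helper: compare a file digest with an expected value."""
    # Mirror the documented rule: no hash information means "valid".
    if hash_type is None or hash_value is None:
        return True
    hasher = hashlib.new(hash_type)
    with open(path, "rb") as file:
        # Read in 1 MiB chunks so large checkpoints are not loaded at once.
        for chunk in iter(lambda: file.read(1 << 20), b""):
            hasher.update(chunk)
    return hasher.hexdigest() == hash_value.lower()


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "weights.bin"
    ckpt.write_bytes(b"abc")
    expected = hashlib.sha256(b"abc").hexdigest()
    ok = file_hash_matches(ckpt, "sha256", expected)
    bad = file_hash_matches(ckpt, "sha256", "0" * 64)
    skipped = file_hash_matches(ckpt)
```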
- load_state_dict(name_or_path: T_Hashable | str | Path, *, device: torch.device | None | Literal['default', 'cuda_if_available'] | str | int = None, offline: bool = False, load_fn: Callable[[Any], T] | Literal['csv', 'json', 'jsonl', 'h5py', 'numpy', 'pickle', 'safetensors', 'torch', 'torchaudio', 'yaml'] = load_torch, load_kwds: Dict[str, Any] | None = None, verbose: int = 0) -> dict[str, Tensor]
Load state_dict weights.
- Args:
name_or_path: Model name (case-sensitive) or path to a checkpoint file.
device: Device of checkpoint weights. (deprecated)
offline: If False, the checkpoint for a model name is downloaded automatically.
load_fn: Load function backend. Defaults to torch.load.
load_kwds: Optional keyword arguments passed to load_fn. Defaults to None.
verbose: Verbosity level. Defaults to 0.
- Returns:
Loaded file content.
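The signature lets load_fn be either a callable or a backend name, with load_kwds forwarded to it. How torchwrench resolves backend names internally is not shown here, so the following is only a sketch of that pattern using two standard-library backends; `BACKENDS`, `load_with` and the private loaders are hypothetical names.

```python
import json
import pickle
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union


def _load_json(path: Path, **kwds: Any) -> Any:
    with open(path, "r") as file:
        return json.load(file, **kwds)


def _load_pickle(path: Path, **kwds: Any) -> Any:
    with open(path, "rb") as file:
        return pickle.load(file, **kwds)


# Hypothetical name-to-loader table covering two of the listed backend names.
BACKENDS: Dict[str, Callable[..., Any]] = {
    "json": _load_json,
    "pickle": _load_pickle,
}


def load_with(
    path: Path,
    load_fn: Union[str, Callable[..., Any]] = "json",
    load_kwds: Optional[Dict[str, Any]] = None,
) -> Any:
    """Resolve a backend name to a loader, or call the given callable."""
    fn = BACKENDS[load_fn] if isinstance(load_fn, str) else load_fn
    return fn(path, **(load_kwds or {}))


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "weights.json"
    path.write_text(json.dumps({"weight": [1.0, 2.0]}))
    by_name = load_with(path, load_fn="json")
    by_callable = load_with(path, load_fn=lambda p: Path(p).read_text())
```

Accepting both forms keeps the common case short (a backend name) while still allowing a custom loader for unusual file formats.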
- torchwrench.hub.download_file(url: str, dst: str | Path | None = '.', *, hash_prefix: str | None = None, make_parents: bool = False, make_intermediate: bool | None = None, verbose: int = 0) -> Path
Download file to target filepath or directory.
- Args:
url: Target URL.
dst: Target filepath or directory. None means the current working directory. Defaults to ".".
hash_prefix: Optional hash prefix present in the destination filename. Defaults to None.
make_parents: If True, create intermediate directories leading to the destination. Defaults to False.
make_intermediate: Deprecated alias for 'make_parents'. If not None, it overrides any value of 'make_parents'. Defaults to None.
verbose: Verbosity level. Defaults to 0.
- Returns:
Path to the downloaded file.
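The argument rules above (dst may be a file or a directory, None means the current working directory, and the deprecated alias overrides make_parents) can be sketched without any network access. `resolve_destination` is a hypothetical helper, and taking the file name from the last URL path component is an assumption made for this illustration, not documented behavior.

```python
import tempfile
from pathlib import Path
from typing import Optional, Union
from urllib.parse import urlparse


def resolve_destination(
    url: str,
    dst: Union[str, Path, None] = ".",
    make_parents: bool = False,
    make_intermediate: Optional[bool] = None,
) -> Path:
    """Hypothetical helper: compute where a download would be written."""
    # Deprecated alias: when given, it overrides make_parents.
    if make_intermediate is not None:
        make_parents = make_intermediate
    dst_path = Path.cwd() if dst is None else Path(dst)
    if dst_path.is_dir():
        # Assumption: name the file after the last URL path component.
        dst_path = dst_path / Path(urlparse(url).path).name
    if make_parents:
        dst_path.parent.mkdir(parents=True, exist_ok=True)
    return dst_path


url = "https://example.com/models/weights.pt"
with tempfile.TemporaryDirectory() as tmp:
    in_dir = resolve_destination(url, tmp)
    in_dir_ok = in_dir == Path(tmp) / "weights.pt"
    nested = resolve_destination(
        url, Path(tmp) / "a" / "b" / "file.pt", make_intermediate=True
    )
    nested_parent_created = nested.parent.is_dir()
```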
- torchwrench.hub.get_cache_dir(mkdir: bool = False, make_parents: bool = True) -> Path
Returns the torchwrench cache directory for storing checkpoints, data and models.
Defaults to ~/.cache/torchwrench. Can be overridden with the 'TORCHWRENCH_CACHEDIR' environment variable.
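The environment-variable override described above follows a common pattern, sketched below with the standard library. `cache_dir_sketch` is a hypothetical function, not the library's code; expanding `~` and mapping make_parents onto `Path.mkdir(parents=...)` are assumptions.

```python
import os
import tempfile
from pathlib import Path


def cache_dir_sketch(mkdir: bool = False, make_parents: bool = True) -> Path:
    """Hypothetical re-creation of the documented lookup order."""
    # The environment variable wins; otherwise use the documented default.
    raw = os.environ.get("TORCHWRENCH_CACHEDIR", "~/.cache/torchwrench")
    path = Path(raw).expanduser()
    if mkdir:
        path.mkdir(parents=make_parents, exist_ok=True)
    return path


previous = os.environ.pop("TORCHWRENCH_CACHEDIR", None)
try:
    default_dir = cache_dir_sketch()
    with tempfile.TemporaryDirectory() as tmp:
        os.environ["TORCHWRENCH_CACHEDIR"] = str(Path(tmp) / "cache")
        overridden = cache_dir_sketch(mkdir=True)
        override_used = overridden == Path(tmp) / "cache"
        created = overridden.is_dir()
finally:
    # Restore the caller's environment.
    os.environ.pop("TORCHWRENCH_CACHEDIR", None)
    if previous is not None:
        os.environ["TORCHWRENCH_CACHEDIR"] = previous
```

The same pattern would apply to the temporary directory and its 'TORCHWRENCH_TMPDIR' variable.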
- torchwrench.hub.get_tmp_dir(mkdir: bool = False, make_parents: bool = True) -> Path
Returns the torchwrench temporary directory.
Defaults to /tmp/torchwrench. Can be overridden with the 'TORCHWRENCH_TMPDIR' environment variable.