torchwrench.hub.registry module
- class torchwrench.hub.registry.RegistryEntry
Bases: TypedDict
- architecture : NotRequired[str]
- hash_type : NotRequired[Literal['sha256', 'md5']]
- hash_value : NotRequired[str]
- state_dict_key : NotRequired[str]
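For illustration, the fields above can be mirrored with a plain `typing.TypedDict`. This is a minimal sketch, not torchwrench's own definition: the class name `RegistryEntrySketch` is hypothetical, `total=False` is used so that every key is optional (the same effect `NotRequired` has on each field), and the model names and digest are made-up placeholders.

```python
from typing import Literal, TypedDict


class RegistryEntrySketch(TypedDict, total=False):
    """Hypothetical mirror of the documented RegistryEntry fields.

    total=False makes every key optional, like NotRequired[...] on each field.
    """

    architecture: str
    hash_type: Literal["sha256", "md5"]
    hash_value: str
    state_dict_key: str


# Hypothetical registry: one entry declares a hash, the other declares none.
infos: dict[str, RegistryEntrySketch] = {
    "model_a": {
        "architecture": "resnet18",
        "hash_type": "sha256",
        "hash_value": "0" * 64,  # placeholder digest, not a real checkpoint hash
        "state_dict_key": "model",
    },
    "model_b": {"architecture": "resnet18"},
}

for name, entry in infos.items():
    # Optional keys may be absent, so read them with .get() and a fallback.
    print(name, entry.get("hash_type", "no hash"))
```

Because every key is optional, an entry can be as small as a single field, which is why code consuming these entries should not index optional keys directly.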
- class torchwrench.hub.registry.RegistryHub(infos: Mapping[T_Hashable, RegistryEntry], register_root: str | Path = '~/.cache/torch/hub/checkpoints')
Bases: Generic[T_Hashable]
- download_file(name: T_Hashable, force: bool = False, check_hash: bool = True, verbose: int = 0) -> tuple[Path, bool]
Download checkpoint file.
- property infos : dict[T_Hashable, RegistryEntry]
- is_valid_hash(name: T_Hashable) -> bool
Returns True if the target file's hash is valid. If no hash is provided in infos for this name, it also returns True.
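The rule above (a file with no declared hash counts as valid) can be sketched with the standard library's `hashlib`. This is an illustrative stand-in, not torchwrench's implementation: the function name `check_file_hash` is hypothetical, its `entry` argument is a plain dict using the `hash_type` and `hash_value` keys of `RegistryEntry`, and falling back to sha256 when only `hash_value` is given is an assumption.

```python
import hashlib
import tempfile
from collections.abc import Mapping
from pathlib import Path


def check_file_hash(path: Path, entry: Mapping[str, str], chunk_size: int = 1 << 20) -> bool:
    """Hypothetical sketch: compare a file's digest with the one declared in entry."""
    hash_value = entry.get("hash_value")
    if hash_value is None:
        # No hash declared: nothing to check, so the file is treated as valid.
        return True
    # Assumption: fall back to sha256 when hash_type is missing.
    hasher = hashlib.new(entry.get("hash_type", "sha256"))
    with open(path, "rb") as file:
        # Read in chunks so large checkpoints are not loaded into memory at once.
        for chunk in iter(lambda: file.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest() == hash_value


# Usage on a throwaway file standing in for a checkpoint.
with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp) / "ckpt.bin"
    ckpt.write_bytes(b"fake weights")
    declared = {"hash_type": "sha256", "hash_value": hashlib.sha256(b"fake weights").hexdigest()}
    print(check_file_hash(ckpt, declared))  # True
    print(check_file_hash(ckpt, {}))  # True: no hash declared
    print(check_file_hash(ckpt, {"hash_type": "md5", "hash_value": "0" * 32}))  # False
```

Hashing in fixed-size chunks keeps memory use constant regardless of checkpoint size, which matters for multi-gigabyte weight files.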
- load_state_dict(name_or_path: T_Hashable | str | Path, *, device: torch.device | None | Literal['default', 'cuda_if_available'] | str | int = None, offline: bool = False, load_fn: Callable[[Any], T] | Literal['csv', 'json', 'jsonl', 'h5py', 'numpy', 'pickle', 'safetensors', 'torch', 'torchaudio', 'yaml'] = load_torch, load_kwds: Dict[str, Any] | None = None, verbose: int = 0) -> dict[str, Tensor]
Load state_dict weights.
- Args:
  name_or_path: Model name (case sensitive) or path to a checkpoint file.
  device: Device of the checkpoint weights. (deprecated)
  offline: If False, the checkpoint for a model name is downloaded automatically.
  load_fn: Load function backend. Defaults to torch.load.
  load_kwds: Optional keyword arguments passed to load_fn. Defaults to None.
  verbose: Verbose level. Defaults to 0.
- Returns:
  Loaded file content.
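Since `name_or_path` accepts either a registered model name or a filesystem path, the method has to decide which one it was given. The sketch below shows one plausible way such dispatching could work, using only the standard library and a JSON file in place of a real checkpoint; it is an assumption-laden illustration, not torchwrench's code. The helper `load_registered`, the `<register_root>/<name>.json` file layout, and the rule that `state_dict_key` selects a sub-dictionary of the loaded content are all hypothetical, and the download step (skipped when `offline=True`) is replaced by a `FileNotFoundError`.

```python
from __future__ import annotations

import json
import tempfile
from collections.abc import Mapping
from pathlib import Path
from typing import Any


def load_registered(
    name_or_path: str | Path,
    infos: Mapping[str, Mapping[str, str]],
    register_root: str | Path = "~/.cache/torch/hub/checkpoints",
) -> dict[str, Any]:
    """Hypothetical sketch: resolve a registered name or a path, then load JSON."""
    entry = infos.get(name_or_path) if isinstance(name_or_path, str) else None
    if entry is not None:
        # Hypothetical layout: registered checkpoints live at <register_root>/<name>.json.
        path = Path(register_root).expanduser() / f"{name_or_path}.json"
    else:
        # Not a registered name: treat the argument as a filesystem path.
        entry = {}
        path = Path(name_or_path).expanduser()
    if not path.is_file():
        # A real implementation would download here unless offline; this sketch only reads local files.
        raise FileNotFoundError(path)
    with open(path, "r", encoding="utf-8") as file:
        content = json.load(file)
    # Assumption: state_dict_key names the sub-dictionary that holds the weights.
    key = entry.get("state_dict_key")
    return content[key] if key is not None else content


# Usage: the same file loaded by registered name (key applied) and by raw path (no key).
with tempfile.TemporaryDirectory() as root:
    Path(root, "tiny.json").write_text(json.dumps({"model": {"w": [1.0, 2.0]}, "epoch": 3}))
    registry = {"tiny": {"state_dict_key": "model"}}
    by_name = load_registered("tiny", registry, register_root=root)
    by_path = load_registered(Path(root, "tiny.json"), registry)

print(by_name)  # {'w': [1.0, 2.0]}
```

Checking the registry before falling back to a path keeps registered names authoritative, which is consistent with the documented note that model names are case sensitive.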
- download_file(name: T_Hashable, force: bool =