torchwrench.hub package

class torchwrench.hub.RegistryEntry

Bases: TypedDict

architecture : NotRequired[str]
fname : str
hash_type : NotRequired[Literal['sha256', 'md5']]
hash_value : NotRequired[str]
state_dict_key : NotRequired[str]
url : str
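
To illustrate the fields above, the sketch below builds one entry as a plain dictionary. It does not import torchwrench: EntrySketch is a hypothetical local stand-in that mirrors the listed keys, and the URL, file name and hash are placeholder values, not a real checkpoint.

```python
from typing import Literal, TypedDict


class _RequiredFields(TypedDict):
    # Keys listed above without NotRequired.
    fname: str
    url: str


class EntrySketch(_RequiredFields, total=False):
    # Hypothetical stand-in for RegistryEntry; total=False makes these keys optional.
    architecture: str
    hash_type: Literal["sha256", "md5"]
    hash_value: str
    state_dict_key: str


# Placeholder values for illustration only.
entry: EntrySketch = {
    "url": "https://example.com/checkpoints/model.pt",
    "fname": "model.pt",
    "hash_type": "sha256",
    "hash_value": "0" * 64,
}

# Required keys that the entry fails to provide (empty when the entry is complete).
missing = EntrySketch.__required_keys__ - set(entry)
```

Only url and fname are mandatory; the hash fields can be left out when no integrity check is wanted.
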
class torchwrench.hub.RegistryHub(infos: Mapping[T_Hashable, RegistryEntry], register_root: str | Path = '~/.cache/torch/hub/checkpoints')

Bases: Generic[T_Hashable]

download_file(name: T_Hashable, force: bool = False, check_hash: bool = True, verbose: int = 0) -> tuple[Path, bool]

Download checkpoint file.

classmethod from_file(path: str | Path) -> RegistryHub

Load register info from JSON file.

get_path(name: T_Hashable) -> Path

Returns the expected filepath of an element.

property infos : dict[T_Hashable, RegistryEntry]
is_valid_hash(name: T_Hashable) -> bool

Returns True if target file hash is valid. If no hash is provided in infos, this function also returns True.
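
The behaviour described above can be sketched with the standard library alone. This is a minimal illustration, not the library's implementation: is_valid_hash_sketch is a hypothetical helper, and it assumes that hash_value holds the full hexadecimal digest of the file.

```python
import hashlib
from pathlib import Path
from typing import Optional


def is_valid_hash_sketch(
    path: Path,
    hash_type: Optional[str] = None,
    hash_value: Optional[str] = None,
) -> bool:
    """Return True when the file digest matches, or when no hash is given."""
    if hash_type is None or hash_value is None:
        # No reference hash to compare against: treated as valid.
        return True
    digest = hashlib.new(hash_type)
    with path.open("rb") as f:
        # Read in 1 MiB chunks so large checkpoints are not loaded at once.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    # Assumption: the stored value is the full hex digest.
    return digest.hexdigest() == hash_value.lower()
```

Returning True when no hash is stored means entries without hash_type or hash_value are never rejected, so an integrity check is only meaningful for entries that declare both fields.
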

load_state_dict(name_or_path: T_Hashable | str | Path, *, device: torch.device | None | Literal['default', 'cuda_if_available'] | str | int = None, offline: bool = False, load_fn: Callable[[Any], T] | Literal['csv', 'json', 'jsonl', 'h5py', 'numpy', 'pickle', 'safetensors', 'torch', 'torchaudio', 'yaml'] = load_torch, load_kwds: Dict[str, Any] | None = None, verbose: int = 0) -> dict[str, Tensor]

Load state_dict weights.

Args:

name_or_path: Model name (case sensitive) or path to checkpoint file.
device: Device of checkpoint weights. (deprecated)
offline: If False, the checkpoint from a model name will be automatically downloaded.
load_fn: Load function backend. Defaults to torch.load.
load_kwds: Optional keyword arguments passed to load_fn. Defaults to None.
verbose: Verbose level. Defaults to 0.

Returns:

Loaded file content.

property names : list[T_Hashable]
property paths : list[Path]
property register_root : Path
remove_file(name: T_Hashable) -> None
save(path: str | Path) -> None

Save info to JSON file.

torchwrench.hub.download_file(url: str, dst: str | Path | None = '.', *, hash_prefix: str | None = None, make_parents: bool = False, make_intermediate: bool | None = None, verbose: int = 0) -> Path

Download file to target filepath or directory.

Args:

url: Target URL.
dst: Target filepath or directory. None means the current working directory. Defaults to “.”.
hash_prefix: Optional hash prefix present in destination filename. Defaults to None.
make_parents: If True, create intermediate directories to the destination. Defaults to False.
make_intermediate: Deprecated alias for ‘make_parents’. If not None, overrides the value of ‘make_parents’. Defaults to None.
verbose: Verbose level. Defaults to 0.

Returns:

Path to the downloaded file.

torchwrench.hub.get_cache_dir(mkdir: bool = False, make_parents: bool = True) -> Path

Returns torchwrench cache directory for storing checkpoints, data and models.

The default is ~/.cache/torchwrench. It can be overridden with the ‘TORCHWRENCH_CACHEDIR’ environment variable.
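
The lookup order described above (environment variable first, default path otherwise) can be sketched as follows. resolve_cache_dir_sketch is a hypothetical helper, not the library function, and the handling of mkdir and make_parents is an assumption inferred from the parameter names.

```python
import os
from pathlib import Path


def resolve_cache_dir_sketch(mkdir: bool = False, make_parents: bool = True) -> Path:
    """Resolve the cache directory from the environment, else use the default."""
    raw = os.environ.get("TORCHWRENCH_CACHEDIR", "~/.cache/torchwrench")
    path = Path(raw).expanduser()
    if mkdir:
        # Assumption: make_parents controls creation of missing parent directories.
        path.mkdir(parents=make_parents, exist_ok=True)
    return path
```

The same pattern would apply to the temporary directory, with ‘TORCHWRENCH_TMPDIR’ and /tmp/torchwrench in place of the cache values.
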

torchwrench.hub.get_tmp_dir(mkdir: bool = False, make_parents: bool = True) -> Path

Returns torchwrench temporary directory.

The default is /tmp/torchwrench. It can be overridden with the ‘TORCHWRENCH_TMPDIR’ environment variable.

Submodules