Source code for torchwrench.hub.paths

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import tempfile
from pathlib import Path

from pythonwrench.functools import function_alias
from torch.hub import get_dir


def get_tmp_dir(mkdir: bool = False, make_parents: bool = True) -> Path:
    """Return the torchwrench temporary directory.

    The base directory is read from the 'TORCHWRENCH_TMPDIR' environment
    variable when it is set, otherwise from ``tempfile.gettempdir()``
    (typically `/tmp`). A `torchwrench` sub-directory is always appended,
    so the usual default is `/tmp/torchwrench`.

    Args:
        mkdir: If True, create the directory when it does not exist.
        make_parents: If True, also create missing parent directories.
            Only used when ``mkdir`` is True.

    Returns:
        The absolute path of the temporary directory, with a leading `~`
        expanded to the user's home directory.
    """
    base = os.getenv("TORCHWRENCH_TMPDIR", tempfile.gettempdir())
    # Expand '~' BEFORE resolving: resolve() makes the path absolute against
    # the current working directory, after which a leading '~' is buried in
    # the middle of the path and expanduser() can no longer expand it.
    result = Path(base).joinpath("torchwrench").expanduser().resolve()
    if mkdir:
        result.mkdir(parents=make_parents, exist_ok=True)
    return result
def get_cache_dir(mkdir: bool = False, make_parents: bool = True) -> Path:
    """Return the torchwrench cache directory for checkpoints, data and models.

    The directory is read from the 'TORCHWRENCH_CACHEDIR' environment
    variable when it is set, otherwise it defaults to
    `~/.cache/torchwrench`.

    Args:
        mkdir: If True, create the directory when it does not exist.
        make_parents: If True, also create missing parent directories.
            Only used when ``mkdir`` is True.

    Returns:
        The absolute path of the cache directory, with a leading `~`
        expanded to the user's home directory.
    """
    override = os.getenv("TORCHWRENCH_CACHEDIR")
    if override is None:
        # Only query the home directory when it is actually needed, so an
        # explicit override still works where the home cannot be determined.
        result = Path.home().joinpath(".cache", "torchwrench")
    else:
        result = Path(override)
    # Expand '~' BEFORE resolving: resolve() makes the path absolute against
    # the current working directory, after which a leading '~' is buried in
    # the middle of the path and expanduser() can no longer expand it.
    result = result.expanduser().resolve()
    if mkdir:
        result.mkdir(parents=make_parents, exist_ok=True)
    return result
# Stub exposed under the name `get_torch_cache_dir`, decorated with
# `function_alias(get_dir)` where `get_dir` is imported from `torch.hub`.
# The body is intentionally just `...`: the decorator supplies the behaviour.
# NOTE(review): presumably `function_alias` makes this forward every call
# (and its arguments) to `torch.hub.get_dir` — confirm against pythonwrench.
@function_alias(get_dir)
def get_torch_cache_dir(*args, **kwargs): ...