Source code for torchwrench.extras.hdf.dataset

#!/usr/bin/env python
# -*- coding: utf-8 -*-

import copy
import json
import logging
import os
import os.path as osp
import pickle
from functools import cached_property
from json import JSONDecodeError
from pathlib import Path
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    Iterable,
    List,
    Literal,
    Optional,
    Tuple,
    TypeVar,
    Union,
    get_args,
    overload,
)

import h5py
import numpy as np
import pythonwrench as pw
from h5py import Dataset as HDFRawDataset
from pythonwrench.typing import is_iterable_bytes_or_list, is_iterable_str
from torch import Tensor
from typing_extensions import Self, TypeAlias, override

import torchwrench as tw
from torchwrench.extras.hdf.common import (
    _DEFAULTS_RAW_HDF_ATTRIBUTES,
    _DUMPED_JSON_KEYS,
    HDFDatasetAttributes,
    HDFItemType,
    _dict_to_tuple,
)
from torchwrench.extras.numpy.scan_info import numpy_dtype_to_torch_dtype
from torchwrench.nn.functional.indices import get_inverse_perm
from torchwrench.serialization.common import as_builtin
from torchwrench.types._typing import ScalarLike
from torchwrench.types.guards import is_scalar_like
from torchwrench.utils.data.dataset.slicer import DatasetSlicer

# T: item type accepted by the ``transform`` callable of HDFDataset.
T = TypeVar("T", covariant=True)
# U: item type produced by ``transform`` and returned when indexing with an int.
U = TypeVar("U", covariant=False)

# Accepted row selectors: a single int, several ints, a tensor, a slice,
# or None (get_item converts None to ``slice(None)``, i.e. all rows).
IndexLike: TypeAlias = Union[int, Iterable[int], Tensor, slice, None]
# Accepted column selectors: one name, several names, or None
# (get_item replaces None by ``column_names``).
ColumnLike: TypeAlias = Union[str, Iterable[str], None]
# Cast policies dispatched by HDFDataset._cast_values:
# - "to_torch_or_builtin": tensor when possible, otherwise builtin python values.
# - "to_torch_or_numpy": tensor when possible, otherwise numpy array.
# - "as_builtin": values converted with ``as_builtin``.
# - "to_numpy_src": numpy array cast to the dtype stored in ``src_np_dtypes``.
# - "to_torch_src": array viewed with the dtype stored in ``src_np_dtypes``, then converted to tensor.
# - "none": values are returned unchanged.
CastMode = Literal[
    "to_torch_or_builtin",
    "to_torch_or_numpy",
    "as_builtin",
    "to_numpy_src",
    "to_torch_src",
    "none",
]

# Module-level logger.
pylog = logging.getLogger(__name__)


class HDFDataset(Generic[T, U], DatasetSlicer[U]):
    def __init__(
        self,
        hdf_fpath: Union[str, Path],
        *,
        transform: Optional[Callable[[T], U]] = pw.identity,
        keep_padding: Iterable[str] = (),
        return_added_columns: bool = False,
        open_hdf: bool = True,
        cast: CastMode = "none",
        file_kwds: Optional[Dict[str, Any]] = None,
    ) -> None:
        """HDFDataset to read a packed HDF file.

        In HDF, all tensors are padded internally then cropped on-the-fly.
        Several options allow to extract non-padded tensors or to return the internal shape of each column.

        Args:
            hdf_fpath: The path to the HDF file. A leading "~" is expanded.
            transform: The transform to apply to values. defaults to ``pw.identity``.
            keep_padding: Keys to keep padding values. defaults to ().
            return_added_columns: If True, returns the columns added by pack_to_hdf(.) function. defaults to False.
            open_hdf: If True, open the HDF file at start. defaults to True.
            cast: Cast policy when loading data. defaults to "none".
            file_kwds: Options given to h5py file object. defaults to None.

        Raises:
            FileNotFoundError: If hdf_fpath is not an existing file.
        """
        # "~" must be expanded BEFORE resolving: resolve() makes the path absolute
        # relative to the current directory, after which "~" is no longer a prefix.
        hdf_fpath = Path(hdf_fpath).expanduser().resolve()

        if not hdf_fpath.is_file():
            msg = f"Cannot find HDF file in path {hdf_fpath=}."

            # Suggest sibling HDF files only when the parent directory can be listed,
            # otherwise os.listdir would hide the message above behind its own error.
            parent_dpath = hdf_fpath.parent
            if parent_dpath.is_dir():
                names = sorted(
                    name for name in os.listdir(parent_dpath) if name.endswith(".hdf")
                )
            else:
                names = []

            if len(names) > 0:
                names_str = "\n - ".join(names)
                msg += f" Possible HDF files are:\n - {names_str}"
            raise FileNotFoundError(msg)

        keep_padding = list(keep_padding)
        if file_kwds is None:
            file_kwds = {}

        super().__init__(
            add_indices_support=False,
            add_mask_support=False,
            add_slice_support=False,
            add_none_support=True,
        )
        self._hdf_fpath = hdf_fpath
        self._transform = transform
        self._keep_padding = keep_padding
        self._return_added_columns = return_added_columns
        self._cast: CastMode = cast
        self._file_kwds = file_kwds

        # h5py file handle; None while the dataset is closed.
        self._hdf_file: Any = None

        if open_hdf:
            self.open()

    # Public properties
    @property
    def added_columns(self) -> List[str]:
        """Return the list of columns added by pack_to_hdf function."""
        return self.attrs["added_columns"]

    @property
    def all_columns(self) -> List[str]:
        """The name of all columns of the dataset."""
        return list(self.get_hdf_keys())

    @cached_property
    def attrs(self) -> HDFDatasetAttributes:
        """Attributes of the HDF file merged over the default attributes.

        Values listed in ``_DUMPED_JSON_KEYS`` are decoded from JSON; a decoding failure is logged and
        the undecoded value is kept. The result is cached until the caches are cleared.
        """
        attrs = copy.copy(_DEFAULTS_RAW_HDF_ATTRIBUTES)
        attrs.update(self._hdf_file.attrs)

        for name in _DUMPED_JSON_KEYS:
            try:
                attrs[name] = json.loads(attrs[name])
            except JSONDecodeError:
                pylog.error(f"Cannot load JSON data {attrs[name]=} from {name=}.")

        attrs["added_columns"] = list(attrs["added_columns"])
        attrs["src_np_dtypes"] = {
            k: np.dtype(v) for k, v in attrs["src_np_dtypes"].items()
        }
        return attrs  # type: ignore

    @property
    def column_names(self) -> Tuple[str, ...]:
        """The name of each column of the dataset.

        Columns added by pack_to_hdf are excluded unless ``return_added_columns`` is True.
        """
        if self._return_added_columns:
            return tuple(self.all_columns)
        added_columns = self.added_columns
        return tuple(name for name in self.all_columns if name not in added_columns)

    @property
    def info(self) -> Dict[str, Any]:
        """Return the global dataset info."""
        return self.attrs["info"]

    @property
    def item_type(self) -> HDFItemType:
        """Return the item type stored in the HDF file attributes ("dict" when missing)."""
        return self._hdf_file.attrs.get("item_type", "dict")

    @property
    def keep_padding(self) -> List[str]:
        """Names of the columns for which padding values are kept."""
        return self._keep_padding

    @keep_padding.setter
    def keep_padding(self, new_value: Iterable[str]) -> None:
        self._keep_padding = list(new_value)

    @property
    def num_columns(self) -> int:
        """Number of columns in ``column_names``."""
        return len(self.column_names)

    @property
    def num_rows(self) -> int:
        """Number of rows, same as ``len(self)``."""
        return len(self)

    @property
    def shape(self) -> Tuple[int, ...]:
        """The shape of the dataset as (number of rows, number of columns)."""
        return len(self), len(self.column_names)

    @property
    def transform(self) -> Optional[Callable[[T], U]]:
        """The transform given at construction."""
        return self._transform

    @property
    def user_attrs(self) -> Any:
        """The "user_attrs" entry of the dataset attributes."""
        return self.attrs["user_attrs"]

    # Private properties to define internal behaviours
    @property
    def _encoding(self) -> str:
        return self.attrs["encoding"]

    @property
    def _shape_suffix(self) -> str:
        """Return the tensor shape suffix in column names."""
        return self.attrs["shape_suffix"]

    @property
    def _src_np_dtypes(self) -> Dict[str, np.dtype]:
        return self.attrs["src_np_dtypes"]

    @cached_property
    def _src_is_unicode(self) -> Dict[str, bool]:
        # Numpy dtype kind "U" is a unicode string.
        is_unicode = {name: dt.kind == "U" for name, dt in self._src_np_dtypes.items()}
        return is_unicode

    @property
    def _load_as_complex(self) -> Dict[str, bool]:
        return self.attrs["load_as_complex"]

    # Public methods
[docs] @pw.deprecated_function() def at(self, *args, **kwargs) -> Any: """Deprecated: Use get_item method instead.""" return self.get_item(*args, **kwargs)
[docs] def close(self, ignore_if_closed: bool = False, remove_file: bool = False) -> None: if self.is_closed() and not ignore_if_closed: raise RuntimeError("Cannot close the HDF file twice.") if not ignore_if_closed: self._hdf_file.close() if remove_file: os.remove(self._hdf_fpath) self._hdf_file = None self._clear_caches()
[docs] def get_attrs(self) -> HDFDatasetAttributes: return self.attrs
[docs] def get_hdf_fpath(self) -> Path: return self._hdf_fpath
[docs] def get_hdf_keys(self) -> Tuple[str, ...]: if self.is_closed(): raise RuntimeError("Cannot get keys from a closed HDF file.") return tuple(self._hdf_file.keys())
[docs] def get_column_shape(self, column_name: str) -> Tuple[int, ...]: if self.is_closed(): msg = f"Cannot get_column_shape with a closed HDF file. ({self._hdf_file is None=} or {not bool(self._hdf_file)=})" raise RuntimeError(msg) return tuple(self._hdf_file[column_name].shape)
[docs] def get_columns_shapes(self) -> Dict[str, Tuple[int, ...]]: if self.is_closed(): msg = f"Cannot get_columns_shapes with a closed HDF file. ({self._hdf_file is None=} or {not bool(self._hdf_file)=})" raise RuntimeError(msg) return { column_name: tuple(self._hdf_file[column_name].shape) for column_name in self.column_names }
[docs] def get_column_dtype(self, column_name: str) -> np.dtype: if self.is_closed(): msg = f"Cannot get dtype with a closed HDF file. ({self._hdf_file is None=} or {not bool(self._hdf_file)=})" raise RuntimeError(msg) return self._hdf_file[column_name].dtype
@overload def get_item( self, index: int, column: None = None, ) -> U: ... @overload def get_item( self, index: Union[Iterable[int], slice, None], column: str, ) -> List: ... @overload def get_item( self, index: Union[Iterable[int], slice, None], column: Union[List[str], None] = None, ) -> Dict[str, List]: ... @overload def get_item( self, index: Any, column: Any, raw: bool = False, ) -> Any: ...
[docs] @override def get_item( self, index: IndexLike, column: ColumnLike = None, raw: bool = False, ) -> Any: if self.is_closed(): msg = f"Cannot get_raw value with closed HDF file. ({self._hdf_file is not None=} and {bool(self._hdf_file)=})" raise RuntimeError(msg) if index is None: index = slice(None) elif is_scalar_like(index): index = tw.to_item(index) # type: ignore elif ( pw.isinstance_generic(index, Iterable[int]) or pw.isinstance_generic(index, Iterable[bool]) or tw.is_tensor_or_array(index) ): index = tw.to_ndarray(index) elif isinstance(index, (int, slice)): pass else: raise TypeError(f"Invalid argument type {type(index)=}.") if column is None: column = self.column_names if is_iterable_str(column, accept_str=False): result_dict = { column_i: self.get_item(index, column_i) for column_i in column } if self.item_type == "tuple": result = _dict_to_tuple(result_dict) else: result = result_dict if ( isinstance(index, int) and self._transform is not None and set(column) == set(self.column_names) ): result = self._transform(result) # type: ignore return result # type: ignore assert isinstance(column, str) if column not in self.all_columns: closest = pw.find_closest_in_list(column, self.all_columns) msg = f"Invalid argument {column=}. (did you mean '{closest}'? Expected one of {tuple(self.all_columns)})" raise ValueError(msg) if isinstance(index, slice) or ( isinstance(index, np.ndarray) and index.ndim == 1 and index.dtype.kind in ("b", "i") ): is_mult = True elif isinstance(index, int): if not (-len(self) <= index < len(self)): msg = f"Invalid argument {index=}. 
(expected int in range [{-len(self)}, {len(self) - 1}])" raise IndexError(msg) is_mult = False else: raise TypeError(f"Invalid argument type {type(index)=}.") hdf_value = self._get_raw_item(index, column) if raw: return hdf_value if is_mult: hdf_values = hdf_value else: hdf_values = hdf_value[None] del hdf_value shape_column = f"{column}{self._shape_suffix}" must_remove_padding = ( shape_column in self._hdf_file.keys() and column not in self._keep_padding ) hdf_ds: HDFRawDataset = self._hdf_file[column] hdf_dtype: np.dtype = hdf_ds.dtype if must_remove_padding: shapes = self._get_raw_item(index, shape_column) if not is_mult: shapes = shapes[None] slices_lst = [ tuple(slice(shape_i) for shape_i in shape) for shape in shapes ] else: slices_lst = [None] * int(hdf_ds.shape[0]) if ( self._src_is_unicode.get(column, False) or h5py.check_vlen_dtype(hdf_dtype) is str ): hdf_values = _decode_bytes(hdf_values, encoding=self._encoding) if must_remove_padding: outputs = [] for hdf_value, slices in zip(hdf_values, slices_lst): # Remove the padding part if slices is not None: assert isinstance(hdf_value, np.ndarray) and not isinstance( hdf_value, str ) hdf_value = hdf_value[slices] hdf_value = self._cast_values(hdf_value, column, hdf_dtype) outputs.append(hdf_value) else: outputs = self._cast_values(hdf_values, column, hdf_dtype) del hdf_values if not is_mult: outputs = outputs[0] return outputs
[docs] def is_closed(self) -> bool: return not self.is_open()
[docs] def is_open(self) -> bool: return self._hdf_file is not None and bool(self._hdf_file)
[docs] def keys(self) -> Tuple[str, ...]: return self.column_names
[docs] def open(self, ignore_if_opened: bool = False) -> None: if self.is_open(): if ignore_if_opened: return None else: raise RuntimeError("Cannot open the HDF file twice.") self._clear_caches() self._hdf_file = h5py.File(self._hdf_fpath, "r", **self._file_kwds) self._sanity_check()
[docs] def to_dict(self, raw: bool = False) -> Dict[str, np.ndarray]: return { col: self.get_item(slice(None), col, raw=raw) for col in self.column_names }
# Magic methods def __eq__(self, __o: object) -> bool: return isinstance(__o, HDFDataset) and pickle.dumps(self) == pickle.dumps(__o) def __enter__(self) -> Self: return self def __exit__(self, *args, **kwargs) -> None: if self.is_open(): self.close() @overload def __getitem__(self, index: int, /) -> U: ... @overload def __getitem__(self, index: str, /) -> np.ndarray: ... @overload def __getitem__(self, index: Tuple[IndexLike, str], /) -> np.ndarray: ... @overload def __getitem__( self, index: Union[Iterable[int], List[str], Tuple[str, ...], slice, None], ) -> Dict[str, np.ndarray]: ... @overload def __getitem__(self, index: Any, /) -> Any: ... def __getitem__( # type: ignore self, index: Union[IndexLike, ColumnLike, Tuple[IndexLike, ColumnLike]], /, ) -> Any: if isinstance(index, str) or pw.isinstance_generic(index, Iterable[str]): index = slice(None), index return super().__getitem__(index) # type: ignore def __getstate__(self) -> Dict[str, Any]: return { "hdf_fpath": self._hdf_fpath, "transform": self._transform, "keep_padding": self._keep_padding, "return_added_columns": self._return_added_columns, "cast": self._cast, "file_kwds": self._file_kwds, "is_open": self.is_open(), } def __hash__(self) -> int: hash_value = 0 if self.is_open(): hash_value += self.attrs["global_hash_value"] if self._transform is not None: hash_value += hash(self._transform) hash_value += sum(map(hash, self._keep_padding)) return hash_value def __len__(self) -> int: if self.is_closed(): msg = f"Cannot length of a closed HDF file. ({self._hdf_file is not None=} and {bool(self._hdf_file)=})" raise RuntimeError(msg) if "length" in self._hdf_file.attrs: length = self._hdf_file.attrs["length"] elif len(self._hdf_file) > 0: hdf_dsets: List[HDFRawDataset] = list(self._hdf_file.values()) hdf_dsets_lens = [len(ds) for ds in hdf_dsets] if not pw.all_eq(hdf_dsets_lens): msg = f"Found an different number of lengths in hdf sub-datasets. 
(found {set(hdf_dsets_lens)})" raise ValueError(msg) length = hdf_dsets_lens[0] else: length = 0 return length def __repr__(self) -> str: repr_hparams: Dict[str, Any] = {"file": osp.basename(self._hdf_fpath)} if self.is_open(): repr_hparams["shape"] = self.shape repr_ = ", ".join(f"{k}={v}" for k, v in repr_hparams.items()) return f"{HDFDataset.__name__}({repr_})" def __setstate__(self, data: Dict[str, Any]) -> None: is_init = hasattr(self, "_hdf_fpath") and hasattr(self, "_hdf_file") files_are_different = is_init and self._hdf_fpath != data["hdf_fpath"] is_open = is_init and self.is_open() if is_init and files_are_different and is_open: self.close() self._hdf_fpath = data["hdf_fpath"] self._transform = data["transform"] self._keep_padding = data["keep_padding"] self._return_added_columns = data["return_added_columns"] self._cast = data["cast"] self._file_kwds = data["file_kwds"] self._hdf_file = None if not is_init or (files_are_different and is_open): self.open() # Private methods def _get_raw_item( self, index: Union[int, slice, np.ndarray], column: str, ) -> np.ndarray: if isinstance(index, (int, slice)) or ( isinstance(index, np.ndarray) and index.ndim == 1 and index.dtype.kind == "b" ): hdf_value: Any = self._hdf_file[column][index] if self._load_as_complex.get(column, False): hdf_value = tw.view_as_complex(hdf_value) hdf_value = np.array(hdf_value) elif ( isinstance(index, np.ndarray) and index.ndim == 1 and index.dtype.kind == "i" ): # Note: slicing with indices required strict sorted list of indices, so we sort + remove duplicates before loading local_idxs = np.argsort(index, axis=-1) sorted_idxs = index[local_idxs] uniq, counts = np.unique(sorted_idxs, return_counts=True) hdf_value: Any = self._hdf_file[column][uniq] hdf_value = np.repeat(hdf_value, counts, axis=0) inv_local_idxs = get_inverse_perm(tw.ndarray_to_tensor(local_idxs)).numpy() hdf_value = hdf_value[inv_local_idxs] if self._load_as_complex.get(column, False): hdf_value = 
[tw.view_as_complex(value) for value in hdf_value] else: raise TypeError(f"Invalid argument type {type(index)=}.") return hdf_value def _sanity_check(self) -> None: lens = [dset.shape[0] for dset in self._hdf_file.values()] if not pw.all_eq(lens) or (len(lens) > 0 and lens[0] != len(self)): msg = ( f"Incorrect length stored in HDF file. (found {lens=} and {len(self)=})" ) pylog.error(msg) if not hasattr(self, "__orig_class__"): return None t_type = self.__orig_class__.__args__[0] # type: ignore if t_type is not Any and ( (issubclass(t_type, dict) and self.item_type != "dict") or (issubclass(t_type, tuple) and self.item_type != "tuple") ): msg = f"Invalid HDFDataset typing. (found specified type '{t_type.__name__}' but the internal dataset contains type '{self.item_type}')" raise TypeError(msg) def _clear_caches(self) -> None: if hasattr(self, "attrs"): del self.attrs if hasattr(self, "_is_unicode"): del self._src_is_unicode def _cast_values( self, hdf_values: Union[ScalarLike, np.ndarray, List], column: str, hdf_dtype: np.dtype, ) -> Any: if self._cast == "none": result = hdf_values elif self._cast == "to_torch_or_builtin": valid = tw.get_shape(hdf_values, return_indicator=True).valid if valid and hdf_dtype.kind not in ("V", "S", "O"): result = tw.as_tensor(hdf_values) elif isinstance(hdf_values, np.ndarray): result = hdf_values.tolist() else: result = as_builtin(hdf_values) elif self._cast == "to_torch_or_numpy": valid = tw.get_shape(hdf_values, return_indicator=True).valid if valid and hdf_dtype.kind not in ("V", "S", "O"): result = tw.as_tensor(hdf_values) else: result = np.array(hdf_values) elif self._cast == "as_builtin": result = as_builtin(hdf_values) elif self._cast == "to_numpy_src": assert isinstance(hdf_values, np.ndarray), f"{type(hdf_values)=}" valid = tw.get_shape(hdf_values, return_indicator=True).valid src_np_dtypes = self.attrs["src_np_dtypes"] target_np_dtype = src_np_dtypes.get(column, hdf_values.dtype) if isinstance(hdf_values, np.ndarray): 
result = hdf_values.astype(target_np_dtype) elif valid: result = np.array(hdf_values, dtype=target_np_dtype) else: result = hdf_values elif self._cast == "to_torch_src": assert isinstance(hdf_values, np.ndarray), f"{type(hdf_values)=}" valid = tw.get_shape(hdf_values, return_indicator=True).valid src_np_dtypes = self.attrs["src_np_dtypes"] target_np_dtype = src_np_dtypes.get(column, hdf_values.dtype) target_pt_dtype = numpy_dtype_to_torch_dtype(target_np_dtype, invalid=None) if isinstance(hdf_values, np.ndarray): hdf_values_view = hdf_values.view(target_np_dtype) result = tw.ndarray_to_tensor(hdf_values_view) elif valid: result = tw.as_tensor(hdf_values, dtype=target_pt_dtype) else: result = hdf_values else: msg = f"Invalid argument {self._cast=}. (expected one of {get_args(CastMode)})" raise ValueError(msg) return result
def _decode_bytes( encoded: Union[bytes, bytearray, np.ndarray, list], encoding: str, ) -> Union[str, np.ndarray, list]: """Decode bytes to str with the specified encoding. Works recursively on list of bytes, list of list of bytes, etc.""" if isinstance(encoded, (bytes, bytearray)): return encoded.decode(encoding=encoding) elif isinstance(encoded, np.ndarray): if encoded.dtype.kind == "S": return np.char.decode(encoded, encoding=encoding) elif encoded.dtype.kind == "O" and encoded.ndim > 0: return [ _decode_bytes(encoded_i, encoding=encoding) for encoded_i in encoded ] else: return _decode_bytes(encoded.item(), encoding=encoding) elif is_iterable_bytes_or_list(encoded): return [_decode_bytes(elt, encoding) for elt in encoded] # type: ignore else: msg = f"Invalid argument type {type(encoded)} for {pw.get_current_fn_name()}. (expected bytes, bytes ndarray or Iterable)" raise TypeError(msg)