#!/usr/bin/env python
# -*- coding: utf-8 -*-
import json
import logging
import os
from dataclasses import asdict
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import (
Any,
Callable,
Dict,
Generic,
Hashable,
List,
Literal,
Mapping,
Optional,
Set,
Tuple,
TypeVar,
Union,
get_args,
)
import h5py
import numpy as np
import torch
from h5py import Dataset as HDFRawDataset
from torch.utils.data.dataloader import DataLoader
from typing_extensions import TypeAlias
try:
    from tqdm import tqdm  # type: ignore
except ImportError:
    # tqdm is optional: fall back to a no-op replacement that returns the
    # iterable unchanged and ignores every progress-bar argument (desc, disable, ...).
    def tqdm(x, *args, **kwargs):
        return x
import pythonwrench as pw
from pythonwrench.functools import Compose
from pythonwrench.typing import (
BuiltinScalar,
SupportsGetitemLen,
SupportsIterLen,
is_dataclass_instance,
isinstance_generic,
)
import torchwrench as tw
from torchwrench import nn
from torchwrench.extras.hdf.common import (
_DUMPED_JSON_KEYS,
HDF_ENCODING,
HDF_STRING_DTYPE,
HDF_VOID_DTYPE,
SHAPE_SUFFIX,
ExistsMode,
HDFItemType,
_tuple_to_dict,
)
from torchwrench.extras.hdf.dataset import HDFDataset
from torchwrench.extras.numpy import (
merge_numpy_dtypes,
numpy_is_complex_dtype,
scan_shape_dtypes,
)
# K and V parametrize _DictWrapper (column name type and column item type).
K = TypeVar("K", covariant=True, bound=Hashable)
V = TypeVar("V", covariant=True)
# Generic item type used in the signature of _scan_dataset.
T = TypeVar("T", covariant=True)
# Item type of a packed dataset: constrained to tuple or dict.
T_DictOrTuple = TypeVar("T_DictOrTuple", tuple, dict, covariant=True)
# Accepted dtype specifications: a numpy dtype, a one-letter code
# ("b", "i", "u", "f", "c") or a type object (e.g. a numpy scalar type).
HDFDType: TypeAlias = Union[np.dtype, Literal["b", "i", "u", "f", "c"], type]
pylog = logging.getLogger(__name__)
@torch.inference_mode()
def pack_to_hdf(
    dataset: Union[
        SupportsGetitemLen[T_DictOrTuple],
        SupportsIterLen[T_DictOrTuple],
        Mapping[str, SupportsGetitemLen],
    ],
    hdf_fpath: Union[str, Path],
    pre_transform: Optional[Callable[[T_DictOrTuple], T_DictOrTuple]] = pw.identity,
    *,
    # Loader args
    batch_size: int = 32,
    num_workers: Union[int, Literal["auto"]] = "auto",
    skip_scan: bool = False,
    # Packing args
    encoding: str = HDF_ENCODING,
    file_kwds: Optional[Dict[str, Any]] = None,
    col_kwds: Optional[Dict[str, Any]] = None,
    shape_suffix: str = SHAPE_SUFFIX,
    store_str_as_vlen: bool = False,
    user_attrs: Any = None,
    # Others args
    exists: ExistsMode = "error",
    ds_kwds: Optional[Dict[str, Any]] = None,
    verbose: int = 0,
) -> HDFDataset[T_DictOrTuple, T_DictOrTuple]:
    """Pack a dataset to HDF file.

    Args:
        dataset: The sized dataset to pack. It can be an indexable sized dataset, a sized iterable,
            or a mapping of column names to equal-length indexable columns.
            Items must be dicts with string keys or tuples. Values can be int, float, str, Tensor,
            non-empty List[int], non-empty List[float], non-empty List[str].
            If values are tensors or lists, the number of dimensions must be the same for all items in the dataset.
        hdf_fpath: The path to the HDF file. A leading "~" is expanded to the user home directory.
        pre_transform: The optional transform to apply to each item returned by the dataset BEFORE storing it in HDF file.
            Can be used for deterministic transforms like Resample, LogMelSpectrogram, etc. defaults to pw.identity.
        batch_size: The batch size of the dataloader. defaults to 32.
        num_workers: The number of workers of the dataloader.
            If "auto", it will be set to `pw.get_num_cpus_available()`. defaults to "auto".
        skip_scan: If True, the input dataset will be considered as fully homogeneous, which means that all columns
            values contains the same shape and dtype, which will be inferred from the first batch.
            It is meant to skip the first step which scans each dataset item once and speed up packing to HDF file.
            defaults to False.
        encoding: String encoding used in file. defaults to HDF_ENCODING.
        file_kwds: Options given to h5py.File object. defaults to None.
        col_kwds: Options given to all dataset columns, i.e. h5py.File().create_dataset(.) method. defaults to None.
        shape_suffix: Shape column suffix in HDF file. defaults to SHAPE_SUFFIX.
        store_str_as_vlen: If True, store strings as variable length string dtype. defaults to False.
        user_attrs: Additional metadata to add to the hdf file. It must be convertible to JSON with `json.dumps`. defaults to None.
        exists: Determine which action should be performed if the target HDF file already exists.
            "overwrite": Replace the target file then pack dataset.
            "skip": Skip this function and returns the packed dataset.
            "error": Raises a ValueError.
        ds_kwds: Keywords arguments passed to the returned HDFDataset instance. defaults to None.
        verbose: Verbose level. defaults to 0.

    Returns:
        hdf_dataset: The target HDF dataset object.

    Raises:
        ValueError: If the dataset is empty, if the target file already exists and `exists` is "error"
            (or not a valid mode), or if an item does not match the layout found during the scan.
        RuntimeError: If `hdf_fpath` exists but is not a file, or if a shape column conflicts with an existing column.
    """
    if len(dataset) == 0:
        msg = f"Cannot pack to hdf an empty dataset. (found {len(dataset)=})"
        raise ValueError(msg)
    if ds_kwds is None:
        ds_kwds = {}
    if file_kwds is None:
        file_kwds = {}
    if col_kwds is None:
        col_kwds = {}

    # Expand "~" BEFORE resolving: resolve() treats "~" as a regular directory
    # name, so "~/data.hdf" would otherwise become "<cwd>/~/data.hdf".
    hdf_fpath = Path(hdf_fpath).expanduser().resolve()
    if hdf_fpath.exists() and not hdf_fpath.is_file():
        msg = f"Item {hdf_fpath=} exists but it is not a file."
        raise RuntimeError(msg)

    if hdf_fpath.is_file() and exists != "overwrite":
        if exists == "skip":
            return HDFDataset(hdf_fpath, **ds_kwds)
        if exists == "error":
            msg = f"Cannot overwrite file {hdf_fpath}. Please remove it or use exists='overwrite' or exists='skip' option."
            raise ValueError(msg)
        msg = f"Invalid argument {exists=}. (expected one of {get_args(ExistsMode)})"
        raise ValueError(msg)

    if num_workers == "auto":
        num_workers = pw.get_num_cpus_available()

    if verbose >= 2:
        pylog.debug(f"Start packing data into HDF file '{hdf_fpath}'...")

    # _scan_dataset may replace the dataset by an internal wrapper (e.g. for a
    # mapping of columns); keep the user's object for the metadata stored below.
    src_dataset = dataset

    # Step 1: First pass to the dataset to build static HDF dataset shapes (much faster for read the resulting file)
    pre_transform_name = pw.get_fullname(pre_transform)
    (
        dataset,
        dict_pre_transform,
        item_type,
        max_shapes,
        hdf_dtypes,
        all_eq_shapes,
        src_np_dtypes,
    ) = _scan_dataset(
        dataset=dataset,
        pre_transform=pre_transform,
        batch_size=batch_size,
        num_workers=num_workers,
        verbose=verbose,
        store_str_as_vlen=store_str_as_vlen,
        encoding=encoding,
        skip_scan=skip_scan,
    )
    # Number of rows, measured on the dataset returned by _scan_dataset.
    num_items = len(dataset)

    # Share of the per-item element count taken by each column (debug logs only).
    total = sum(pw.prod(shape) for shape in max_shapes.values())
    max_shapes_ratios = {
        attr_name: (pw.prod(shape) / total if total > 0 else 0.0)
        for attr_name, shape in max_shapes.items()
    }

    # For debugging purposes: scan results are dumped to a temporary JSON file
    # which is only removed at the end, once packing succeeded.
    scan_results = pw.as_builtin(
        {
            "item_type": item_type,
            "max_shapes": max_shapes,
            "hdf_dtypes": hdf_dtypes,
            "all_eq_shapes": all_eq_shapes,
            "src_np_dtypes": src_np_dtypes,
        }
    )
    with NamedTemporaryFile(
        "w",
        prefix="HDF_scan_results_",
        suffix=".json",
        delete=False,
    ) as file:
        json.dump(scan_results, file)
        scan_results_fpath = Path(file.name)
    if verbose >= 2:
        pylog.debug(f"Scan results saved in '{scan_results_fpath}'.")

    creation_date = pw.get_now()

    with h5py.File(hdf_fpath, "w", **file_kwds) as hdf_file:
        # Step 2: Build hdf datasets in file
        hdf_dsets: Dict[str, HDFRawDataset] = {}

        # Create sub-datasets for main data
        for attr_name, shape in max_shapes.items():
            hdf_dtype = hdf_dtypes.get(attr_name)
            kwargs: Dict[str, Any] = {}
            fill_value = hdf_dtype_to_fill_value(hdf_dtype)
            if fill_value is not None:
                kwargs["fillvalue"] = fill_value
            kwargs.update(col_kwds)
            hdf_ds_shape = (num_items,) + shape

            try:
                hdf_dsets[attr_name] = hdf_file.create_dataset(
                    name=attr_name,
                    shape=hdf_ds_shape,
                    dtype=hdf_dtype,
                    **kwargs,
                )
            except ValueError:
                msg = f"Cannot create hdf dataset {attr_name=} of shape '{hdf_ds_shape}' with dtype '{hdf_dtype}' and {kwargs=}."
                pylog.error(msg)
                raise

        if verbose >= 2:
            num_scalars = sum(len(hdf_ds.shape) == 1 for hdf_ds in hdf_dsets.values())
            # Fraction of COLUMNS holding a single scalar per item.
            ratio = num_scalars / len(hdf_dsets) if len(hdf_dsets) > 0 else 0.0
            msg = f"{num_scalars}/{len(hdf_dsets)} column dsets contains a single dim. ({ratio * 100:.3f}%)"
            pylog.debug(msg)

            if num_scalars < len(hdf_dsets):
                pylog.debug("Others multidims column dsets are:")
                for attr_name, hdf_ds in hdf_dsets.items():
                    if len(hdf_ds.shape) == 1:
                        continue
                    ratio = max_shapes_ratios[attr_name]
                    msg = f"HDF column dset multidim '{attr_name}' has been built. (with shape={hdf_ds.shape}, nelement_per_item={pw.prod(hdf_ds.shape[1:])} ({ratio * 100:.3f}%), dtype={hdf_ds.dtype})"
                    pylog.debug(msg)

        added_columns: List[str] = []

        # Create sub-datasets for shape data (only for columns whose values do
        # not all share the same shape).
        for attr_name, shape in max_shapes.items():
            if len(shape) == 0 or all_eq_shapes[attr_name]:
                continue

            shape_name = f"{attr_name}{shape_suffix}"
            raw_dset_shape = num_items, len(shape)

            if shape_name in hdf_dsets:
                if hdf_dsets[shape_name].shape == raw_dset_shape:
                    continue
                msg = f"Column {shape_name} already exists in source dataset with a different shape. (found shape={hdf_dsets[shape_name].shape} but expected shape is {raw_dset_shape})"
                raise RuntimeError(msg)

            hdf_dsets[shape_name] = hdf_file.create_dataset(
                shape_name,
                raw_dset_shape,
                np.int32,
                fillvalue=-1,
            )
            added_columns.append(shape_name)

        # Step 3: Fill sub-datasets with a second pass through the whole dataset
        i = 0
        global_hash_value = 0
        loader = DataLoader(
            dataset,  # type: ignore
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
            collate_fn=nn.Identity(),
            drop_last=False,
            pin_memory=False,
        )

        for batch in tqdm(
            loader,
            desc="Pack data into HDF...",
            disable=verbose <= 0,
        ):
            batch = [dict_pre_transform(item) for item in batch]

            for item in batch:
                for attr_name, value in item.items():
                    hdf_dset = hdf_dsets[attr_name]
                    shape = tw.get_shape(value)

                    # Check every shape
                    if len(shape) != hdf_dset.ndim - 1:
                        msg = f"Invalid number of dimensions for column '{attr_name}' at index {i}. (expected {hdf_dset.ndim - 1}, found {len(shape)})"
                        raise ValueError(msg)

                    # Check dataset size
                    if any(
                        shape_i > dset_shape_i
                        for shape_i, dset_shape_i in zip(shape, hdf_dset.shape[1:])
                    ):
                        msg = f"Resize hdf_dset {attr_name} of shape {tuple(hdf_dset.shape[1:])} with new {shape=}."
                        pylog.error(msg)
                        msg = "INTERNAL ERROR: Cannot resize dataset when pre-computing shapes."
                        raise RuntimeError(msg)

                    # Note: "hdf_dset[slices]" is a generic version of "hdf_dset[i, :shape_0, :shape_1]"
                    slices = (i,) + tuple(slice(shape_i) for shape_i in shape)
                    try:
                        hdf_dset[slices] = value
                    except (TypeError, ValueError, OSError):
                        # Encoded values can be lists or Python scalars without a
                        # dtype: do not let the error message itself raise.
                        value_dtype = getattr(value, "dtype", type(value))
                        msg = f"Cannot set data {value} of shape {shape} into {hdf_dset.shape=} ({attr_name=}, {i=}, {slices=}, {value_dtype=} {hdf_dset.dtype=})"
                        pylog.error(msg)
                        raise

                    # Store original shape if needed
                    shape_name = f"{attr_name}{shape_suffix}"
                    if shape_name in hdf_dsets:
                        hdf_dsets[shape_name][i] = shape

                    global_hash_value += tw.checksum_any(value)

                i += 1

        # note: HDF cannot save too large int values with too many bits
        global_hash_value = global_hash_value % (2**31)

        info = getattr(src_dataset, "info", {})
        if is_dataclass_instance(info):
            info = asdict(info)
        elif isinstance(info, Mapping):
            info = dict(info.items())  # type: ignore
        else:
            info = {}

        src_np_dtypes_dumped = {
            name: str(merge_numpy_dtypes(np_dtypes, empty=HDF_VOID_DTYPE))
            for name, np_dtypes in src_np_dtypes.items()
        }

        attributes = {
            "added_columns": added_columns,
            "creation_date": creation_date,
            "encoding": encoding,
            "file_kwds": file_kwds,
            "global_hash_value": global_hash_value,
            "info": pw.as_builtin(info),
            "item_type": item_type,
            "length": num_items,
            "load_as_complex": {},  # for backward compatibility only
            "pre_transform": pre_transform_name,
            "shape_suffix": shape_suffix,
            "source_dataset": src_dataset.__class__.__name__,
            "src_np_dtypes": src_np_dtypes_dumped,
            "store_complex_as_real": False,  # for backward compatibility only
            "store_str_as_vlen": store_str_as_vlen,
            "user_attrs": pw.as_builtin(user_attrs),
            "torchwrench_version": str(tw.__version__),
        }
        for name in _DUMPED_JSON_KEYS:
            attributes[name] = json.dumps(attributes[name])

        if verbose >= 2:
            dumped_attributes = json.dumps(attributes, indent="\t")
            pylog.debug(f"Saving attributes in HDF file:\n{dumped_attributes}")

        attrs_errors: List[TypeError] = []
        for attr_name, attr_val in attributes.items():
            try:
                hdf_file.attrs[attr_name] = attr_val
            except TypeError as err:
                msg = f"Cannot store attribute {attr_name=} with value {attr_val=} in HDF."
                pylog.error(msg)
                attrs_errors.append(err)

    # Raise the first attribute error only once the HDF file has been closed.
    if len(attrs_errors) > 0:
        raise attrs_errors[0]

    if verbose >= 2:
        pylog.debug(f"Data has been packed into HDF file '{hdf_fpath}'.")

    if scan_results_fpath.is_file():
        os.remove(scan_results_fpath)

    return HDFDataset(hdf_fpath, **ds_kwds)
def hdf_dtype_to_fill_value(hdf_dtype: Optional[HDFDType]) -> BuiltinScalar:
    """Return the default fill value of an HDF column with the given dtype.

    Args:
        hdf_dtype: A numpy dtype, a numpy scalar type, a builtin scalar type
            (bool, int, float, complex, str or bytes) or one of the one-letter
            codes "b", "i", "u", "f", "c".

    Returns:
        False for booleans, 0 for integers, 0.0 for floats, and None for
        complex, void, object, bytes and unicode string types.

    Raises:
        ValueError: If the dtype is not supported (e.g. None or datetime64).
    """
    if isinstance(hdf_dtype, np.dtype):
        # Work on the scalar type (np.float32, np.void, ...) of the dtype.
        hdf_dtype = hdf_dtype.type
    elif hdf_dtype in (bool, int, float, complex, str, bytes):
        # Builtin Python types are mapped to numpy's default scalar type for them.
        hdf_dtype = np.dtype(hdf_dtype).type

    is_type = isinstance(hdf_dtype, type)

    if hdf_dtype == "b" or hdf_dtype == np.bool_:
        return False
    if hdf_dtype in ("i", "u") or (is_type and issubclass(hdf_dtype, np.integer)):
        return 0
    if hdf_dtype == "f" or (is_type and issubclass(hdf_dtype, np.floating)):
        return 0.0
    if hdf_dtype == "c" or (
        is_type
        and (
            hdf_dtype in (np.void, np.object_, np.bytes_, np.str_)
            or issubclass(hdf_dtype, np.complexfloating)
        )
    ):
        return None

    msg = f"Unsupported type {hdf_dtype=}."
    raise ValueError(msg)
def numpy_dtype_to_hdf_dtype(
    dtype: Optional[np.dtype],
    *,
    encoding: str = HDF_ENCODING,
) -> np.dtype:
    """Map a numpy dtype to the dtype used to store it in an HDF file.

    Args:
        dtype: The numpy dtype to convert, or None when no dtype is available.
        encoding: Encoding of the h5py string dtype used for unicode ("U") dtypes.

    Returns:
        An h5py variable-length string dtype for unicode dtypes, HDF_VOID_DTYPE
        when dtype is None, and dtype itself otherwise.
    """
    is_unicode = isinstance(dtype, np.dtype) and dtype.kind == "U"
    if is_unicode:
        # length=None asks h5py for a variable-length string dtype.
        return h5py.string_dtype(encoding, None)
    return HDF_VOID_DTYPE if dtype is None else dtype
def hdf_dtype_to_numpy_dtype(hdf_dtype: HDFDType) -> np.dtype:
    """Convert an HDF dtype specification to a numpy dtype.

    Args:
        hdf_dtype: A numpy dtype (returned as is), HDF_VOID_DTYPE, HDF_STRING_DTYPE,
            a one-letter code ("f", "i", "u", "b", "c"), a numpy scalar type or a
            builtin scalar type (bool, int, float, complex, str, bytes).

    Returns:
        The corresponding numpy dtype.

    Raises:
        ValueError: If hdf_dtype is not supported.
    """
    if isinstance(hdf_dtype, np.dtype):
        return hdf_dtype
    if hdf_dtype == HDF_VOID_DTYPE:
        return np.dtype("V")
    if hdf_dtype == HDF_STRING_DTYPE:
        return np.dtype("<U")

    # One-letter codes declared by HDFDType.
    if hdf_dtype == "f":
        return np.dtype("float32")
    if hdf_dtype == "i":
        return np.dtype("int32")
    if hdf_dtype == "u":
        return np.dtype("uint32")
    if hdf_dtype == "b":
        return np.dtype("int8")
    if hdf_dtype == "c":
        return np.dtype("|S1")

    # HDFDType also accepts types: numpy scalar types (np.float32, ...) and
    # builtin scalar types are both understood by np.dtype. Arbitrary classes
    # are rejected instead of silently becoming object dtype.
    if isinstance(hdf_dtype, type) and (
        issubclass(hdf_dtype, np.generic)
        or hdf_dtype in (bool, int, float, complex, str, bytes)
    ):
        return np.dtype(hdf_dtype)

    raise ValueError(f"Unsupported dtype {hdf_dtype=} for numpy dtype.")
class _IterableWrapper(torch.utils.data.IterableDataset):
    """Expose a sized iterable that cannot be indexed as a torch IterableDataset.

    DataLoader indexes map-style datasets, so a plain iterable cannot be given
    to it directly. IterableDataset instances are replicated in every DataLoader
    worker, so only worker 0 yields items: this avoids duplicated items and keeps
    the source order, which must match the row order of the packed HDF file.
    """

    def __init__(self, iterable: SupportsIterLen) -> None:
        super().__init__()
        self.iterable = iterable

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None and worker_info.id != 0:
            return iter(())
        return iter(self.iterable)

    def __len__(self) -> int:
        return len(self.iterable)


def _scan_dataset(
    dataset: Union[
        SupportsGetitemLen[T_DictOrTuple],
        SupportsIterLen[T_DictOrTuple],
        Mapping[str, SupportsGetitemLen],
    ],
    pre_transform: Optional[Callable[[T], T_DictOrTuple]],
    batch_size: int,
    num_workers: int,
    store_str_as_vlen: bool,
    verbose: int,
    encoding: str,
    skip_scan: bool,
) -> Tuple[
    Union[SupportsGetitemLen[T], SupportsIterLen[T_DictOrTuple]],
    Callable[[T], Dict[str, Any]],
    HDFItemType,
    Dict[str, Tuple[int, ...]],
    Dict[str, HDFDType],
    Dict[str, bool],
    Dict[str, Set[np.dtype]],
]:
    """Scan a dataset once to infer the static layout of the HDF columns.

    Args:
        dataset: A mapping of columns, an indexable sized dataset or a sized iterable.
        pre_transform: Transform applied to each item before scanning. None means identity.
        batch_size: Batch size of the scanning dataloader.
        num_workers: Number of workers of the scanning dataloader.
        store_str_as_vlen: If False, unicode values are encoded to bytes before their shape is measured.
        verbose: Verbose level.
        encoding: Encoding used for unicode values.
        skip_scan: If True, only the first batch is scanned.

    Returns:
        A tuple (wrapped_dataset, dict_pre_transform, item_type, max_shapes, hdf_dtypes,
        all_eq_shapes, src_np_dtypes) where wrapped_dataset can be given to a DataLoader,
        dict_pre_transform applies pre_transform, converts the item to a dict and (unless
        store_str_as_vlen) encodes its values, item_type is "dict" or "tuple", and the
        dicts map each column to its maximal shape, its HDF dtype, whether all its values
        share the same shape, and the numpy dtypes found before encoding.

    Raises:
        TypeError: If dataset is not a mapping, an indexable sized dataset or a sized iterable.
        ValueError: If items are neither dicts nor tuples, or if a column contains values
            with different numbers of dimensions.
    """
    if pre_transform is None:
        pre_transform = nn.Identity()

    # Peek the first item (to detect the item type) and wrap the dataset into
    # an object that DataLoader can consume.
    if isinstance(dataset, Mapping):
        item_0 = {k: next(iter(v)) for k, v in dataset.items()}
        wrapped_dataset = _DictWrapper(dataset)  # type: ignore
    elif isinstance(dataset, SupportsGetitemLen):
        item_0 = dataset[0]
        wrapped_dataset = dataset
    elif isinstance(dataset, SupportsIterLen):
        item_0 = next(iter(dataset))
        # An iterator has no __len__ and cannot be indexed by DataLoader.
        wrapped_dataset = _IterableWrapper(dataset)  # type: ignore
    else:
        raise TypeError(f"Invalid argument type {type(dataset)}.")

    def encode_array(x: np.ndarray) -> Any:
        # Unicode arrays are encoded to bytes, bytes arrays become (nested) Python lists.
        if x.dtype.kind == "U":
            x = np.char.encode(x, encoding=encoding)
        if x.dtype.kind == "S":
            x = x.tolist()
        return x

    def encode_dict_array(x: Dict[str, np.ndarray]) -> Dict[str, Any]:
        return {k: encode_array(tw.to_ndarray(v)) for k, v in x.items()}  # type: ignore

    to_dict_fn: Callable[[T], Dict[str, Any]]
    if isinstance_generic(item_0, Dict[str, Any]):
        item_type = "dict"
        to_dict_fn = tw.identity  # type: ignore
    elif isinstance(item_0, tuple):
        item_type = "tuple"
        to_dict_fn = _tuple_to_dict  # type: ignore
    else:
        msg = f"Invalid item type for {pw.get_fullname(dataset)}. (expected Dict[str, Any] or tuple but found {type(item_0)})"
        raise ValueError(msg)
    del item_0

    encode_dict_fn = tw.identity if store_str_as_vlen else encode_dict_array
    dict_pre_transform: Callable[[T], Dict[str, Any]] = Compose(
        pre_transform,
        to_dict_fn,
        encode_dict_fn,
    )
    del dataset

    loader = DataLoader(
        wrapped_dataset,  # type: ignore
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=nn.Identity(),
        drop_last=False,
        pin_memory=False,
    )

    infos_dict: Dict[str, Set[Tuple[Tuple[int, ...], np.dtype]]] = {}
    src_np_dtypes: Dict[str, Set[np.dtype]] = {}

    for batch in tqdm(
        loader,
        desc="Pre compute shapes...",
        disable=verbose <= 0 or skip_scan,
    ):
        batch = [pre_transform(item) for item in batch]
        batch = [to_dict_fn(item) for item in batch]  # type: ignore

        for item in batch:
            for attr_name, value in item.items():
                info = scan_shape_dtypes(value, empty_np=HDF_VOID_DTYPE)
                shape = info.shape
                np_dtype = info.numpy_dtype
                kind = np_dtype.kind

                src_np_dtypes.setdefault(attr_name, set()).add(np_dtype)  # type: ignore

                value = tw.to_ndarray(value)
                if kind == "U" and not store_str_as_vlen:
                    value = encode_array(value)  # type: ignore
                    # update shape and np_dtype after encoding
                    info = scan_shape_dtypes(value, empty_np=HDF_VOID_DTYPE)
                    shape = info.shape
                    np_dtype = info.numpy_dtype

                infos_dict.setdefault(attr_name, set()).add((shape, np_dtype))  # type: ignore

        if skip_scan:
            break

    max_shapes: Dict[str, Tuple[int, ...]] = {}
    hdf_dtypes: Dict[str, HDFDType] = {}
    all_eq_shapes: Dict[str, bool] = {}

    for attr_name, info in infos_dict.items():
        shapes = [shape for shape, _ in info]
        ndims = list(map(len, shapes))

        if not pw.all_eq(ndims):
            ndims_set = tuple(set(ndims))
            indices = [ndims.index(ndim) for ndim in ndims_set]
            msg = f"Invalid ndim for attribute {attr_name}. (found multiple ndims: {ndims_set} at {indices=})"
            raise ValueError(msg)

        np_dtypes = [np_dtype for _, np_dtype in info]
        np_dtype = merge_numpy_dtypes(np_dtypes, empty=HDF_VOID_DTYPE)
        hdf_dtype = numpy_dtype_to_hdf_dtype(np_dtype, encoding=encoding)

        all_eq_shapes[attr_name] = pw.all_eq(shapes)
        max_shapes[attr_name] = tuple(map(max, zip(*shapes)))
        hdf_dtypes[attr_name] = hdf_dtype

    del infos_dict

    if verbose >= 2:
        pylog.debug(f"Found max_shapes:\n{max_shapes}")
        pylog.debug(f"Found hdf_dtypes:\n{hdf_dtypes}")
        pylog.debug(f"Found all_eq_shapes:\n{all_eq_shapes}")
        pylog.debug(f"Found src_np_dtypes:\n{src_np_dtypes}")

    return (
        wrapped_dataset,  # type: ignore
        dict_pre_transform,
        item_type,
        max_shapes,
        hdf_dtypes,
        all_eq_shapes,
        src_np_dtypes,
    )
class _DictWrapper(Generic[K, V]):
def __init__(self, mapping: Mapping[K, SupportsGetitemLen[V]]) -> None:
super().__init__()
self.mapping = mapping
def __getitem__(self, index: int) -> Dict[K, V]:
return {k: v[index] for k, v in self.mapping.items()}
def __len__(self) -> int:
return len(next(iter(self.mapping.values())))