torchwrench.utils.data.collate module

class torchwrench.utils.data.collate.AdvancedCollateDict(pad_values: dict[str, Any] | None = None, include_keys: str | Pattern | Iterable[str | Pattern] | None = None, exclude_keys: str | Pattern | Iterable[str | Pattern] | None = None, key_mode: 'intersect' | 'same' | 'union' = 'same')

Bases: object

Advanced collate object for DataLoader.

Merge a list of dicts into a single dict of lists. Values under a key (e.g. "audio") are padded only if a fill value for that key is given in pad_values in __init__.

Example
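>>> from torch.utils.data import DataLoader
>>> from torchwrench.utils.data.collate import AdvancedCollateDict
>>> collate = AdvancedCollateDict({"audio": 0.0})
>>> loader = DataLoader(..., collate_fn=collate)
>>> next(iter(loader))
{"audio": tensor([[...]]), ...}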
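The merge-and-pad behaviour described above can be sketched in plain Python. This is a minimal illustration, not torchwrench's implementation: the function name `pad_collate_dict` is hypothetical, it works on lists instead of tensors, it assumes every dict in the batch has the same keys, and it assumes right-padding of 1-D sequences to the longest length in the batch. With PyTorch, the padded lists could then be passed to `torch.tensor` to obtain a batch like the one in the example.

```python
from typing import Any


def pad_collate_dict(batch: list[dict[str, Any]], pad_values: dict[str, Any]) -> dict[str, list]:
    """Hypothetical sketch: merge a batch of dicts, right-padding sequences under keys in pad_values."""
    if not batch:
        return {}
    out: dict[str, list] = {}
    # Assumption: all dicts share the keys of the first one.
    for key in batch[0]:
        values = [item[key] for item in batch]
        if key in pad_values:
            # Assumption: right-pad every 1-D sequence to the longest one in the batch.
            max_len = max(len(v) for v in values)
            fill = pad_values[key]
            values = [list(v) + [fill] * (max_len - len(v)) for v in values]
        out[key] = values
    return out


batch = [
    {"audio": [0.1, 0.2, 0.3], "id": "a"},
    {"audio": [0.4], "id": "b"},
]
print(pad_collate_dict(batch, {"audio": 0.0}))
# {'audio': [[0.1, 0.2, 0.3], [0.4, 0.0, 0.0]], 'id': ['a', 'b']}
```

Keys without an entry in pad_values (here "id") are simply gathered into a list, which mirrors the "dict of lists" output described above.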

class torchwrench.utils.data.collate.CollateDict(key_mode: 'intersect' | 'same' | 'union' = 'same')

Bases: object

Collate object for DataLoader.

Merge a list of dicts into a single dict of lists. No padding is applied.
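A minimal sketch of the unpadded merge, including one possible reading of key_mode. The function name `collate_dict` is hypothetical, and the semantics of the three modes are assumptions inferred from their names, not taken from torchwrench: "same" rejects batches whose dicts have different keys, "intersect" keeps only keys shared by every dict, and "union" keeps every key and fills missing entries with None.

```python
from typing import Any, Literal


def collate_dict(
    batch: list[dict[str, Any]],
    key_mode: Literal["intersect", "same", "union"] = "same",
) -> dict[str, list]:
    """Hypothetical sketch: merge a list of dicts into a dict of lists, without padding."""
    if not batch:
        return {}
    key_sets = [set(item) for item in batch]
    if key_mode == "same":
        # Assumed semantics: every dict must carry exactly the same keys.
        if any(ks != key_sets[0] for ks in key_sets):
            raise ValueError("All dicts must have the same keys when key_mode='same'.")
        keys = list(batch[0])
    elif key_mode == "intersect":
        # Assumed semantics: keep only keys present in every dict.
        keys = [k for k in batch[0] if all(k in ks for ks in key_sets)]
    elif key_mode == "union":
        # Assumed semantics: keep keys present in any dict, in first-seen order.
        keys = list(dict.fromkeys(k for item in batch for k in item))
    else:
        raise ValueError(f"Invalid key_mode: {key_mode!r}")
    # Missing entries (only possible in "union" mode) become None.
    return {k: [item.get(k) for item in batch] for k in keys}
```

For example, `collate_dict([{"x": 1, "y": 2}, {"x": 3}], "intersect")` returns `{'x': [1, 3]}`, "union" returns `{'x': [1, 3], 'y': [2, None]}`, and the default "same" mode raises ValueError for that batch.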