torchwrench.utils.data.collate module
-
class torchwrench.utils.data.collate.AdvancedCollateDict(pad_values: dict[str, Any] | None = None, include_keys: str | Pattern | Iterable[str | Pattern] | None = None, exclude_keys: str | Pattern | Iterable[str | Pattern] | None = None, key_mode: 'intersect' | 'same' | 'union' = 'same')
Bases: object
Advanced collate object for DataLoader. Merges a list of dicts into a single dict of lists. The values stored under a key (for example "audio") are padded when pad_values, passed to __init__, provides a fill value for that key.
Example
>>> collate = AdvancedCollateDict({"audio": 0.0})
>>> loader = DataLoader(..., collate_fn=collate)
>>> next(iter(loader))
... {"audio": tensor([[...]]), ...}
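The merge-and-pad behaviour described above can be sketched in plain Python. This is a minimal sketch, not torchwrench's implementation: the names advanced_collate_sketch, _as_list, _matches and _pad_to_max are hypothetical, plain lists stand in for tensors so the sketch has no dependencies, include_keys/exclude_keys patterns are assumed to use re.fullmatch semantics, and key_mode is not modelled (all items are assumed to share the keys of the first item).

```python
import re
from typing import Any, Iterable, Optional, Union

# Hypothetical alias mirroring the documented include_keys/exclude_keys types.
KeyPatterns = Union[str, re.Pattern, Iterable[Union[str, re.Pattern]]]


def _as_list(patterns: Optional[KeyPatterns]) -> Optional[list]:
    """Normalize a single pattern or an iterable of patterns into a list."""
    if patterns is None:
        return None
    if isinstance(patterns, (str, re.Pattern)):
        return [patterns]
    return list(patterns)


def _matches(key: str, patterns: list) -> bool:
    """Assumption: a key is selected when any pattern matches it entirely."""
    return any(re.fullmatch(p, key) is not None for p in patterns)


def _pad_to_max(seqs: list, fill: Any) -> list:
    """Right-pad each sequence with `fill` up to the longest one in the batch."""
    max_len = max((len(s) for s in seqs), default=0)
    return [list(s) + [fill] * (max_len - len(s)) for s in seqs]


def advanced_collate_sketch(
    batch: list,
    pad_values: Optional[dict] = None,
    include_keys: Optional[KeyPatterns] = None,
    exclude_keys: Optional[KeyPatterns] = None,
) -> dict:
    """Hypothetical sketch: list of dicts -> dict of lists, padding selected keys."""
    if not batch:
        return {}
    pad_values = pad_values or {}
    include = _as_list(include_keys)
    exclude = _as_list(exclude_keys) or []
    # Simplification: keys come from the first item (key_mode is not modelled).
    keys = [
        k for k in batch[0]
        if (include is None or _matches(k, include)) and not _matches(k, exclude)
    ]
    collated = {}
    for key in keys:
        values = [item[key] for item in batch]
        if key in pad_values:
            # Only keys with a fill value are padded; others stay ragged.
            values = _pad_to_max(values, pad_values[key])
        collated[key] = values
    return collated
```

For example, with batch = [{"audio": [0.1, 0.2, 0.3], "id": "a"}, {"audio": [0.4], "id": "b"}], calling advanced_collate_sketch(batch, {"audio": 0.0}) pads the second "audio" entry to [0.4, 0.0, 0.0] and leaves "id" as the list ["a", "b"]. Padding to the longest item in the batch is what allows variable-length entries to be stacked into one rectangular array.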
-
class torchwrench.utils.data.collate.CollateDict(key_mode: 'intersect' | 'same' | 'union' = 'same')
Bases: object
Collate object for DataLoader. Merges a list of dicts into a single dict of lists. No padding is applied.
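A dependency-free sketch of the list-of-dicts to dict-of-lists merge, with one possible reading of key_mode. The function name collate_dict_sketch is hypothetical, and the key_mode semantics below are assumptions rather than torchwrench's documented behaviour: "same" raises if items have different keys, "intersect" keeps only the keys present in every item, and "union" keeps every key and fills missing entries with None.

```python
from typing import Any, Literal

KeyMode = Literal["intersect", "same", "union"]


def collate_dict_sketch(batch: list, key_mode: KeyMode = "same") -> dict:
    """Hypothetical sketch: merge a list of dicts into a dict of lists, no padding."""
    if not batch:
        return {}
    key_sets = [set(item) for item in batch]
    # Keys in first-seen order, so the output order is deterministic.
    ordered_keys = list(dict.fromkeys(k for item in batch for k in item))

    if key_mode == "same":
        # Assumption: mismatched keys are an error in "same" mode.
        if any(ks != key_sets[0] for ks in key_sets[1:]):
            raise ValueError("All items must have the same keys when key_mode='same'.")
        keys = ordered_keys
    elif key_mode == "intersect":
        common = set.intersection(*key_sets)
        keys = [k for k in ordered_keys if k in common]
    elif key_mode == "union":
        keys = ordered_keys
    else:
        raise ValueError(f"Invalid key_mode: {key_mode!r}")

    # No padding: each key maps to the plain list of per-item values.
    # Assumption: in "union" mode, items lacking a key contribute None.
    return {k: [item.get(k) for item in batch] for k in keys}
```

A callable with this shape can be passed as collate_fn to a DataLoader, which calls it with the list of samples that make up each batch.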