#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Helper functions for conversion between classes indices, onehot, names and probabilities for multiclass classification."""
from typing import (
Any,
Callable,
Hashable,
Iterable,
List,
Mapping,
Optional,
Sequence,
TypeVar,
Union,
overload,
)
import torch
from pythonwrench.collections import prod
from pythonwrench.functools import function_alias
from pythonwrench.typing import SupportsIterLen
from pythonwrench.warnings import warn_once
from torch import Tensor
from torch.nn import functional as F
from torchwrench.core.make import DeviceLike, DTypeLike, as_device, as_dtype
from torchwrench.extras.numpy import np
from torchwrench.nn.functional.others import get_ndim, get_shape
from torchwrench.nn.functional.transform import to_item
from torchwrench.types import (
BoolTensor2D,
BoolTensor3D,
LongTensor,
is_number_like,
)
from torchwrench.types._typing import T_TensorOrArray, TensorOrArray
# Type variable for class names. It is bound to ``Hashable`` because names are
# used as dictionary keys when building the reverse name -> index lookup
# (see ``name_to_index``).
T_Name = TypeVar("T_Name", bound=Hashable, covariant=True)
@overload
def index_to_onehot(
    index: Iterable[int],
    num_classes: int,
    *,
    padding_idx: Optional[int] = None,
    device: DeviceLike = None,
) -> BoolTensor2D: ...


@overload
def index_to_onehot(
    index: Iterable[Iterable[int]],
    num_classes: int,
    *,
    padding_idx: Optional[int] = None,
    device: DeviceLike = None,
) -> BoolTensor3D: ...


@overload
def index_to_onehot(
    index: Iterable,
    num_classes: int,
    *,
    padding_idx: Optional[int] = None,
    device: DeviceLike = None,
    dtype: DTypeLike = torch.bool,
) -> Tensor: ...


def index_to_onehot(
    index: Iterable,
    num_classes: int,
    *,
    padding_idx: Optional[int] = None,
    device: DeviceLike = None,
    dtype: DTypeLike = torch.bool,
) -> Tensor:
    """Convert indices of labels to onehot encoding for **multiclass** classification.

    Args:
        index: Label indices.
            Can be a nested list of indices, but it should be convertible to Tensor.
        num_classes: Number of classes, i.e. the size of the last dimension of the output.
        padding_idx: Index value to ignore. Output will contain only zeroes for
            entries equal to this value. defaults to None.
        device: PyTorch device of the output tensor.
        dtype: PyTorch DType of the output tensor. defaults to torch.bool.

    Returns:
        Onehot tensor with one extra trailing dimension of size `num_classes`.

    Raises:
        ValueError: If a non-padding index is outside of ``[0, num_classes)``.
    """
    device = as_device(device)
    dtype = as_dtype(dtype)
    index = torch.as_tensor(index, device=device, dtype=torch.long)

    if padding_idx is None:
        is_padding = None
        checked = index
    else:
        is_padding = index == padding_idx
        checked = index[~is_padding]

    # Validate against the caller's num_classes BEFORE padding entries are
    # remapped. Otherwise an out-of-range index equal to num_classes would be
    # silently accepted and encoded as an all-zero row.
    if checked.nelement() > 0:
        min_index = int(checked.min())
        max_index = int(checked.max())
        if not (0 <= min_index <= max_index < num_classes):
            msg = f"Invalid argument {index=}. (expected 0 <= min={min_index} <= max={max_index} < {num_classes=})"
            raise ValueError(msg)

    if is_padding is None:
        onehot: Tensor = F.one_hot(index, num_classes)
        return onehot.to(dtype=dtype)

    # Send padding entries to an extra trailing class, then drop that column so
    # those entries end up with only zeroes.
    index = torch.where(is_padding, num_classes, index)
    onehot = F.one_hot(index, num_classes + 1)
    onehot = onehot.to(dtype=dtype)
    return onehot[..., :-1]
[docs]
# Alias of `index_to_onehot`, declared through `function_alias`.
# The body is only a stub (`...`).
# NOTE(review): presumably `function_alias` forwards calls to `index_to_onehot`;
# the first parameter is named `tensor` here but `index` there, so confirm
# whether keyword calls such as `one_hot(tensor=...)` are supported.
@function_alias(index_to_onehot)
def one_hot(
    tensor: Union[Sequence[int], TensorOrArray, Sequence],
    num_classes: int,
    *,
    padding_idx: Optional[int] = None,
    device: DeviceLike = None,
    dtype: DTypeLike = torch.bool,
) -> Tensor: ...
[docs]
def index_to_name(
    index: Union[Sequence[int], TensorOrArray, Sequence],
    idx_to_name: Union[Mapping[int, T_Name], Iterable[T_Name]],
    *,
    is_number_fn: Callable[[Any], bool] = is_number_like,
) -> List[T_Name]:
    """Convert indices of labels to names using a mapping for **multiclass** classification.

    Args:
        index: Label indices, possibly nested. The nesting structure is
            preserved in the output.
        idx_to_name: Mapping to convert a class index to its name. A non-mapping
            iterable is materialized as a list and indexed by position.
        is_number_fn: Type guard to check if a value is a scalar number. defaults to `is_number_like`.

    Returns:
        Names with the same nesting as `index`.

    Raises:
        ValueError: If an element is neither number-like nor an iterable of
            indices (text values such as ``str`` are rejected).
    """
    index_ndim = get_ndim(index)
    if index_ndim > 1 and prod(get_shape(index)[:-1]) == 0:
        msg = f"Found 0 elements in {index=} but {index_ndim=} > 1, which means that we will lose information about shape when converting to names."
        warn_once(msg)

    if not isinstance(idx_to_name, Mapping):
        idx_to_name = list(idx_to_name)

    def _impl(x) -> Union[T_Name, list]:
        if is_number_fn(x):
            return idx_to_name[to_item(x)]  # type: ignore
        # str/bytes are iterable, but a one-character string iterates to itself,
        # which would recurse forever; treat them as invalid leaves instead.
        if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
            return [_impl(xi) for xi in x]
        msg = f"Invalid argument {x=}. (not a number-like index and not an iterable of indices)"
        raise ValueError(msg)

    name = _impl(index)
    return name  # type: ignore
[docs]
def onehot_to_index(
    onehot: T_TensorOrArray,
    *,
    padding_idx: Optional[int] = None,
    dim: int = -1,
) -> T_TensorOrArray:
    """Convert onehot boolean encoding to indices of labels for **multiclass** classification.

    Args:
        onehot: Onehot labels, as a Tensor or a numpy array.
        padding_idx: Class index placeholder used where the input contains only
            zeroes along `dim`. defaults to None, in which case such entries
            get whatever ``argmax`` returns for an all-zero slice.
        dim: Dimension of classes. defaults to -1.

    Returns:
        Indices of the maximal value along `dim`, in the same container kind
        (Tensor or ndarray) as the input.

    Raises:
        TypeError: If `onehot` is neither a Tensor nor a numpy array.
    """
    # Work on an integer copy of the input for argmax and the zero test.
    # NOTE(review): presumably needed because argmax is not available for
    # boolean tensors on every backend - confirm.
    if isinstance(onehot, Tensor):
        onehot = onehot.int()  # type: ignore
    elif isinstance(onehot, np.ndarray):
        onehot = onehot.astype(int)  # type: ignore
    else:
        msg = f"Invalid argument type {type(onehot)}. (expected Tensor or ndarray)"
        raise TypeError(msg)

    index = onehot.argmax(dim)
    if padding_idx is None:
        return index  # type: ignore

    # Slices with no active class get the placeholder instead of argmax's result.
    empty = (onehot == 0).all(dim)
    if isinstance(empty, Tensor):
        return torch.where(empty, padding_idx, index)  # type: ignore
    return np.where(empty, padding_idx, index)  # type: ignore
[docs]
def onehot_to_name(
    onehot: Tensor,
    idx_to_name: Union[Mapping[int, T_Name], Iterable[T_Name]],
    *,
    dim: int = -1,
) -> List[T_Name]:
    """Convert onehot boolean encoding to names using a mapping for **multiclass** classification.

    This chains `onehot_to_index` and `index_to_name`.

    Args:
        onehot: Onehot encoded labels.
        idx_to_name: Mapping to convert a class index to its name.
        dim: Dimension of classes. defaults to -1.

    Returns:
        Names of the classes selected along `dim`.
    """
    return index_to_name(onehot_to_index(onehot, dim=dim), idx_to_name)
[docs]
def name_to_index(
    name: List[T_Name],
    idx_to_name: Union[Mapping[int, T_Name], Iterable[T_Name]],
) -> Tensor:
    """Convert names to indices of labels for **multiclass** classification.

    Args:
        name: Label names, possibly nested. The nesting structure is preserved,
            so it must be convertible to a Tensor.
        idx_to_name: Mapping to convert a class index to its name. A non-mapping
            iterable is indexed by position. If a name appears several times,
            the last index wins.

    Returns:
        Long tensor of class indices with the same nesting as `name`.

    Raises:
        ValueError: If an element is neither a known name nor an iterable of names.
    """
    if isinstance(idx_to_name, Mapping):
        name_to_idx = {class_name: idx for idx, class_name in idx_to_name.items()}
    else:
        name_to_idx = {class_name: idx for idx, class_name in enumerate(idx_to_name)}

    def _impl(x) -> Union[int, list]:
        try:
            return name_to_idx[x]
        except (KeyError, TypeError):
            # KeyError: unknown name; TypeError: unhashable value (e.g. a list).
            # Either way, fall through and try to treat x as a container.
            pass
        # str/bytes are iterable, but a one-character string iterates to itself,
        # so an unknown text name would recurse forever instead of failing.
        if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
            return [_impl(xi) for xi in x]
        msg = f"Invalid argument {x=}. (not present in idx_to_name and not an iterable of names)"
        raise ValueError(msg)

    index = _impl(name)
    index = torch.as_tensor(index, dtype=torch.long)
    return index  # type: ignore
[docs]
def name_to_onehot(
    name: List[T_Name],
    idx_to_name: Union[Mapping[int, T_Name], SupportsIterLen[T_Name]],
    *,
    device: DeviceLike = None,
    dtype: DTypeLike = torch.bool,
) -> Tensor:
    """Convert names to onehot boolean encoding for **multiclass** classification.

    This chains `name_to_index` and `index_to_onehot`, using ``len(idx_to_name)``
    as the number of classes.

    Args:
        name: Label names, possibly nested.
        idx_to_name: Mapping to convert a class index to its name.
        device: PyTorch device of the output tensor.
        dtype: PyTorch DType of the output tensor.

    Returns:
        Onehot encoding of the given names.
    """
    return index_to_onehot(
        name_to_index(name, idx_to_name),
        len(idx_to_name),
        device=device,
        dtype=dtype,
    )
[docs]
def probs_to_index(
    probs: Tensor,
    *,
    dim: int = -1,
) -> LongTensor:
    """Convert matrix of probabilities to indices of labels for **multiclass** classification.

    Args:
        probs: Output probabilities for each class.
        dim: Dimension of classes. defaults to -1.

    Returns:
        Index of the maximal value along `dim`.
    """
    return probs.argmax(dim)  # type: ignore
[docs]
def probs_to_onehot(
    probs: Tensor,
    *,
    dim: int = -1,
    device: DeviceLike = None,
    dtype: DTypeLike = torch.bool,
) -> Tensor:
    """Convert matrix of probabilities to onehot boolean encoding for **multiclass** classification.

    Args:
        probs: Output probabilities for each class.
        dim: Dimension of classes. defaults to -1.
        device: PyTorch device of the output tensor. defaults to the device of `probs`.
        dtype: PyTorch DType of the output tensor.

    Returns:
        Onehot tensor with the same shape as `probs`, where the most probable
        class along `dim` is set.
    """
    if device is None:
        device = probs.device

    indices = probs_to_index(probs, dim=dim)
    # The number of classes is the size of the class axis, which is not
    # necessarily the last one.
    num_classes = probs.shape[dim]
    onehot = index_to_onehot(indices, num_classes, device=device, dtype=dtype)
    # index_to_onehot appends the class axis last; move it back to `dim` so the
    # output lines up with `probs` (a no-op when `dim` is already the last axis).
    return onehot.movedim(-1, dim)
[docs]
def probs_to_name(
    probs: Tensor,
    idx_to_name: Union[Mapping[int, T_Name], Iterable[T_Name]],
    *,
    dim: int = -1,
) -> List[T_Name]:
    """Convert matrix of probabilities to labels names for **multiclass** classification.

    This chains `probs_to_index` and `index_to_name`.

    Args:
        probs: Output probabilities for each class.
        idx_to_name: Mapping to convert a class index to its name.
        dim: Dimension of classes. defaults to -1.

    Returns:
        Names of the most probable classes along `dim`.
    """
    return index_to_name(probs_to_index(probs, dim=dim), idx_to_name)