torchwrench.nn.functional.indices module

torchwrench.nn.functional.indices.get_inverse_perm(indices: Tensor, dim: int = -1) → Tensor

Return the inverse permutation indices, i.e. the indices that undo the permutation given by indices. The output is a tensor of shape (…, N).

Args:

indices: Original permutation indices as a tensor of shape (…, N).
dim: Dimension of indices along which the permutation is defined. Defaults to -1.

Example 1

>>> import torch
>>> from torchwrench.nn.functional.indices import get_inverse_perm
>>> x = torch.as_tensor([2, 4, 8, 10])
>>> indices = torch.randperm(len(x))
>>> x = x[indices]
>>> # x is now shuffled; the inverse permutation restores the original order
>>> inv_indices = get_inverse_perm(indices)
>>> x_reordered = x[inv_indices]
>>> x_reordered
tensor([ 2,  4,  8, 10])
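As a rough illustration of what the inverse permutation is, the sketch below builds it with `torch.Tensor.scatter_`, writing each position i into the slot named by indices[i] along `dim`. The helper name `inverse_perm_sketch` is hypothetical and this is not necessarily how torchwrench implements `get_inverse_perm`; it assumes `indices` is an int64 tensor holding a valid permutation of 0..N-1 along `dim`.

```python
import torch
from torch import Tensor


def inverse_perm_sketch(indices: Tensor, dim: int = -1) -> Tensor:
    """Hypothetical helper: return inv such that inv[..., indices[..., i]] = i along `dim`.

    Assumes `indices` is an int64 tensor that is a permutation of 0..N-1 along `dim`.
    """
    n = indices.shape[dim]
    # Shape that places the arange along `dim` and broadcasts over every other dim.
    shape = [1] * indices.ndim
    shape[dim] = n
    positions = torch.arange(n, dtype=indices.dtype, device=indices.device)
    positions = positions.view(shape).expand_as(indices)
    # Scatter each position i into the slot named by indices[i].
    return torch.empty_like(indices).scatter_(dim, indices, positions)


perm = torch.as_tensor([2, 0, 3, 1])
inv = inverse_perm_sketch(perm)
print(inv)  # tensor([1, 3, 0, 2])
```

For a valid permutation, `torch.argsort(indices, dim=dim)` returns the same result; the scatter form simply avoids a sort.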
torchwrench.nn.functional.indices.get_perm_indices(x1: Tensor, x2: Tensor) → LongTensor

Find the permutation between two vectors x1 and x2 that both contain the values from 0 to N-1. Return indices such that x2[indices] equals x1.

Example 1

>>> x1 = torch.as_tensor([0, 1, 2, 4, 3, 6, 5, 7])
>>> x2 = torch.as_tensor([0, 2, 1, 4, 3, 5, 6, 7])
>>> indices = get_perm_indices(x1, x2)
>>> torch.equal(x1, x2[indices])
True
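Because x2 is a permutation of 0..N-1, the position of each value v in x2 is given by the inverse permutation of x2, so indexing that inverse with x1 yields the required indices. The sketch below computes the inverse with `torch.argsort`; `perm_indices_sketch` is a hypothetical name, and the library may compute the result differently.

```python
import torch
from torch import Tensor


def perm_indices_sketch(x1: Tensor, x2: Tensor) -> Tensor:
    """Hypothetical helper: return indices such that x2[indices] == x1.

    Assumes x1 and x2 are 1D permutations of 0..N-1.
    """
    # For a permutation, argsort gives its inverse: pos_in_x2[v] is the index of v in x2.
    pos_in_x2 = torch.argsort(x2)
    # For each value of x1, look up where it sits in x2.
    return pos_in_x2[x1]


x1 = torch.as_tensor([0, 1, 2, 4, 3, 6, 5, 7])
x2 = torch.as_tensor([0, 2, 1, 4, 3, 5, 6, 7])
indices = perm_indices_sketch(x1, x2)
print(indices)  # tensor([0, 2, 1, 3, 4, 6, 5, 7])
print(torch.equal(x1, x2[indices]))  # True
```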
torchwrench.nn.functional.indices.insert_at_indices(x: Tensor, indices: Tensor | list | bool | int | float | complex, values: bool | int | float | complex | Tensor) → Tensor1D

Insert value(s) into a 1D vector before the specified indices of the original vector.

Example 1

>>> x = torch.as_tensor([1, 1, 2, 2, 2, 3])
>>> indices = torch.as_tensor([2, 5])
>>> values = 4
>>> insert_at_indices(x, indices, values)
tensor([1, 1, 4, 2, 2, 2, 4, 3])
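In the example, each inserted value lands before the element that originally sat at the given index. The sketch below reproduces that behaviour with a boolean mask: after sorting, the k-th inserted value ends up at `index + k` in the output, and the original elements fill the remaining slots in order. `insert_at_indices_sketch` is a hypothetical name; it assumes non-negative indices in [0, len(x)] and a `values` argument that is either a scalar or has one entry per index, and it may differ from torchwrench's implementation.

```python
import torch
from torch import Tensor


def insert_at_indices_sketch(x: Tensor, indices, values) -> Tensor:
    """Hypothetical helper: insert `values` before positions `indices` of 1D tensor `x`.

    Assumes non-negative indices in [0, len(x)]; `values` is a scalar or has
    one entry per index.
    """
    indices = torch.as_tensor(indices, dtype=torch.long, device=x.device).flatten()
    values = torch.as_tensor(values, dtype=x.dtype, device=x.device)
    if values.ndim == 0:
        values = values.expand(indices.shape[0])
    # Sort insertion points (stably) and keep each value paired with its index.
    indices, order = torch.sort(indices, stable=True)
    values = values.flatten()[order]

    n, k = x.shape[0], indices.shape[0]
    # The j-th inserted value (after sorting) is shifted right by the j values before it.
    new_pos = indices + torch.arange(k, device=x.device)
    out = torch.empty(n + k, dtype=x.dtype, device=x.device)
    out[new_pos] = values
    # Every slot not taken by an inserted value receives the original elements in order.
    keep = torch.ones(n + k, dtype=torch.bool, device=x.device)
    keep[new_pos] = False
    out[keep] = x
    return out


x = torch.as_tensor([1, 1, 2, 2, 2, 3])
print(insert_at_indices_sketch(x, [2, 5], 4))  # tensor([1, 1, 4, 2, 2, 2, 4, 3])
```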
torchwrench.nn.functional.indices.randperm_diff(size: int, generator: Generator | None | 'default' | int = None, device: device | None | 'default' | 'cuda_if_available' | str | int = None, *, dtype: dtype | None | 'default' | str | DTypeEnum = torch.int64) → LongTensor1D

Return a random permutation of the integers 0 to size-1 in which no value i is located at index i (i.e. a derangement). The output is a tensor of shape (size,).

Args:

size: The number of indices. Must be at least 2.
generator: The seed or torch.Generator used to generate the permutation. Defaults to None.
device: The PyTorch device of the output indices tensor. Defaults to None.
dtype: The PyTorch dtype of the output tensor. Defaults to torch.long.

Example 1

>>> torch.randperm(5)
tensor([1, 4, 2, 0, 3])  # 2 is the element at index 2!
>>> randperm_diff(5)
tensor([2, 0, 4, 1, 3])
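One simple way to draw a derangement, shown below as a swapped-in technique rather than torchwrench's documented method, is rejection sampling: call `torch.randperm` until no element equals its index. Every derangement is then equally likely, and only a few draws are needed on average (about e ≈ 2.72 for large size, since roughly 1/e of all permutations have no fixed point). `randperm_diff_sketch` is a hypothetical name, and unlike the documented signature it only accepts a `torch.Generator` or None for `generator` and ignores `device` and `dtype`.

```python
from typing import Optional

import torch
from torch import Tensor


def randperm_diff_sketch(size: int, generator: Optional[torch.Generator] = None) -> Tensor:
    """Hypothetical helper: uniform random derangement of 0..size-1 via rejection sampling."""
    if size < 2:
        # Mirrors the documented requirement that size is at least 2.
        raise ValueError(f"size must be at least 2, got {size}.")
    positions = torch.arange(size)
    while True:
        perm = torch.randperm(size, generator=generator)
        # Accept only permutations where no value i sits at index i.
        if not bool((perm == positions).any()):
            return perm


gen = torch.Generator().manual_seed(0)
perm = randperm_diff_sketch(5, generator=gen)
print(perm)
```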
torchwrench.nn.functional.indices.remove_at_indices(x: Tensor, indices: Tensor | list | bool | int | float | complex) → Tensor1D

Remove the value(s) at the specified indices from a 1D vector.
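No example is given for this function. The sketch below shows one plausible behaviour using a boolean keep-mask; `remove_at_indices_sketch` is a hypothetical name, it assumes a 1D `x` and integer indices, and it is not necessarily the library's implementation.

```python
import torch
from torch import Tensor


def remove_at_indices_sketch(x: Tensor, indices) -> Tensor:
    """Hypothetical helper: drop the elements of 1D tensor `x` at positions `indices`."""
    indices = torch.as_tensor(indices, dtype=torch.long, device=x.device)
    # Start by keeping everything, then switch off the positions to remove.
    keep = torch.ones(x.shape[0], dtype=torch.bool, device=x.device)
    keep[indices] = False
    return x[keep]


x = torch.as_tensor([1, 1, 2, 2, 2, 3])
print(remove_at_indices_sketch(x, [2, 5]))  # tensor([1, 1, 2, 2])
```

With this mask-based design, duplicate indices are harmless: switching off the same position twice still removes that element only once.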