torchwrench.nn.functional.indices module
- torchwrench.nn.functional.indices.get_inverse_perm(indices: Tensor, dim: int = -1) -> Tensor
Return the inverse permutation indices. The output is a tensor of shape (…, N).
- Args:
  indices: Original permutation indices as a tensor of shape (…, N).
  dim: Dimension of indices along which the permutation lies. Defaults to -1.
Example 1
>>> x = torch.as_tensor([2, 4, 8, 10])
>>> indices = torch.randperm(len(x))
>>> x = x[indices]
>>> # x is now shuffled; to restore the original order we need the inverse indices
>>> inv_indices = get_inverse_perm(indices)
>>> x_reordered = x[inv_indices]
>>> x_reordered
tensor([2, 4, 8, 10])
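To make the idea of an inverse permutation concrete, here is a minimal pure-Python sketch for the 1-D case. It is not torchwrench's implementation; the helper name `inverse_perm_sketch` is hypothetical, and plain lists stand in for tensors.

```python
def inverse_perm_sketch(indices):
    """Hypothetical helper: inverse of a 1-D permutation given as a list."""
    inverse = [0] * len(indices)
    # If position `pos` of the shuffled data holds original element `src`,
    # then the inverse must send `src` back to `pos`.
    for pos, src in enumerate(indices):
        inverse[src] = pos
    return inverse


x = [2, 4, 8, 10]
indices = [2, 0, 3, 1]                  # an arbitrary permutation
shuffled = [x[i] for i in indices]      # same role as x[indices]
inverse = inverse_perm_sketch(indices)
restored = [shuffled[i] for i in inverse]
```

Indexing the shuffled data with the inverse indices restores the original order, which is the property the example above relies on.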
- torchwrench.nn.functional.indices.get_perm_indices(x1: Tensor, x2: Tensor) -> LongTensor
Find the permutation between two vectors x1 and x2, each containing the values 0 to N-1. The returned indices satisfy x1 == x2[indices].
Example 1
>>> x1 = torch.as_tensor([0, 1, 2, 4, 3, 6, 5, 7])
>>> x2 = torch.as_tensor([0, 2, 1, 4, 3, 5, 6, 7])
>>> indices = get_perm_indices(x1, x2)
>>> torch.equal(x1, x2[indices])
True
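The relation checked in the example can be illustrated with a short pure-Python sketch. This is an assumption about the semantics inferred from the example, not the library's code; `perm_indices_sketch` is a hypothetical name.

```python
def perm_indices_sketch(x1, x2):
    """Hypothetical helper: indices such that x1 == [x2[i] for i in indices]."""
    # Record where each value sits in x2, then look up every value of x1.
    position_in_x2 = {value: i for i, value in enumerate(x2)}
    return [position_in_x2[value] for value in x1]


x1 = [0, 1, 2, 4, 3, 6, 5, 7]
x2 = [0, 2, 1, 4, 3, 5, 6, 7]
indices = perm_indices_sketch(x1, x2)
reordered = [x2[i] for i in indices]
```

The sketch assumes both vectors hold the same distinct values, as the description requires.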
- torchwrench.nn.functional.indices.insert_at_indices(x: Tensor, indices: Tensor | list | bool | int | float | complex, values: bool | int | float | complex | Tensor) -> Tensor1D
Insert value(s) in vector at specified indices.
Example 1
>>> x = torch.as_tensor([1, 1, 2, 2, 2, 3])
>>> indices = torch.as_tensor([2, 5])
>>> values = 4
>>> insert_at_indices(x, indices, values)
tensor([1, 1, 4, 2, 2, 2, 4, 3])
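The output above suggests that each index refers to a position in the original vector, with the value placed just before that position. A minimal pure-Python sketch under that assumption follows. `insert_at_sketch` is a hypothetical name, it handles only a single scalar value, and it is not the library's implementation.

```python
def insert_at_sketch(x, indices, value):
    """Hypothetical helper: insert `value` before each original position in `indices`."""
    out = list(x)
    # Every earlier insertion shifts later positions right by one,
    # so process the indices in ascending order and add the running offset.
    for offset, index in enumerate(sorted(indices)):
        out.insert(index + offset, value)
    return out


result = insert_at_sketch([1, 1, 2, 2, 2, 3], [2, 5], 4)
```

With the inputs from the example, this reproduces the documented result.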
- torchwrench.nn.functional.indices.randperm_diff(size: int, generator: Generator | None | 'default' | int = None, device: device | None | 'default' | 'cuda_if_available' | str | int = None, *, dtype: dtype | None | 'default' | str | DTypeEnum = torch.int64) -> LongTensor1D
Return a random permutation in which no value i is the element at index i. The output is a tensor of shape (size,).
- Args:
  size: The number of indices. Must be at least 2.
  generator: The seed or torch.Generator used to generate the permutation. Defaults to None.
  device: The PyTorch device of the output indices tensor. Defaults to None.
  dtype: The PyTorch datatype of the output tensor. Defaults to torch.long.
Example 1
>>> torch.randperm(5)
tensor([1, 4, 2, 0, 3])  # 2 is the element at index 2!
>>> randperm_diff(5)
tensor([2, 0, 4, 1, 3])
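A permutation with no element left in its own position is commonly called a derangement. One simple way to sample one is rejection sampling: shuffle, then retry whenever some value lands on its own index. The sketch below uses only the standard library. `random_derangement_sketch` is a hypothetical name, and nothing here claims that randperm_diff is implemented this way.

```python
import random


def random_derangement_sketch(size, seed=None):
    """Hypothetical helper: random permutation of range(size) with no fixed point."""
    if size < 2:
        # No derangement exists for a single element.
        raise ValueError("size must be at least 2")
    rng = random.Random(seed)
    perm = list(range(size))
    while True:
        rng.shuffle(perm)
        # Accept only if no value equals its own index.
        if all(value != index for index, value in enumerate(perm)):
            return list(perm)


perm = random_derangement_sketch(5, seed=0)
```

Roughly one shuffle in three is a derangement, so the loop is expected to finish after a few attempts.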