caldera.utils.reindex_tensor
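caldera.utils.reindex_tensor(a: torch.Tensor) → torch.Tensor

Reindex the values of a tensor to consecutive integers starting at 0. Each distinct value receives the next unused index in order of first appearance; tensors with more than one dimension are scanned in row-major order. Multiple tensors can be passed at once, in which case they share a single mapping, so equal values receive the same index in every output.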
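The mapping rule described above can be sketched in a few lines of plain Python. This is a minimal reference sketch under stated assumptions, not caldera's implementation: the helper name `reindex_by_first_appearance` is hypothetical, it always returns a tuple (even for a single input), and its outputs are `torch.long` tensors on the CPU.

```python
import torch
from typing import Dict, List, Tuple


def reindex_by_first_appearance(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """Hypothetical helper (not caldera's code): map each distinct value to
    0, 1, 2, ... in order of first appearance, sharing one mapping across
    all input tensors."""
    mapping: Dict[object, int] = {}
    outputs: List[torch.Tensor] = []
    for t in tensors:
        # reshape(-1) flattens in row-major order; tolist() copies to Python
        flat = t.reshape(-1).tolist()
        new_flat = []
        for value in flat:
            if value not in mapping:
                mapping[value] = len(mapping)  # next unused index
            new_flat.append(mapping[value])
        # Restore the original shape of this input
        outputs.append(torch.tensor(new_flat, dtype=torch.long).reshape(t.shape))
    return tuple(outputs)


a = torch.tensor([1, 1, 1, 4, 0, 5, 0, 0, 0])
(b,) = reindex_by_first_appearance(a)
print(b)  # tensor([0, 0, 0, 1, 2, 3, 2, 2, 2])
```

The explicit loop keeps the rule easy to read, but it is slow for large tensors and moves the data to the host through `tolist()`.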
```python
import torch
from caldera.utils import reindex_tensor

a = torch.tensor([1, 1, 1, 4, 0, 5, 0, 0, 0])
b = reindex_tensor(a)
print(b)
# >> tensor([0, 0, 0, 1, 2, 3, 2, 2, 2])
```

```python
# multiple tensors (including a 2-D tensor) share one mapping
a = torch.tensor([1, 1, 1, 1, 0, 2, 0, 5, 6])
b = torch.tensor([[6, 5, 1, 70], [0, 80, 5, 6]])
expected1 = torch.tensor([0, 0, 0, 0, 1, 2, 1, 3, 4])
expected2 = torch.tensor([[4, 3, 0, 5], [1, 6, 3, 4]])
c, d = reindex_tensor(a, b)
assert torch.all(c == expected1)
assert torch.all(d == expected2)
```
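For large inputs, the same first-appearance mapping can be computed without a Python loop. This is a sketch of one possible approach, not the library's code: `reindex_vectorized` is a hypothetical name, it assumes PyTorch 1.12 or later (for `Tensor.scatter_reduce` with `reduce="amin"`), and all inputs must share a dtype and device so they can be concatenated. It reproduces the expected values from the examples above.

```python
import torch
from typing import Tuple


def reindex_vectorized(*tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
    """Hypothetical vectorized sketch: shared first-appearance reindexing."""
    # Concatenate all inputs, each flattened in row-major order
    flat = torch.cat([t.reshape(-1) for t in tensors])
    n = flat.numel()
    # uniq is sorted; inverse[i] is the position of flat[i] within uniq
    uniq, inverse = torch.unique(flat, sorted=True, return_inverse=True)
    # first[k] = earliest position in flat where uniq[k] occurs
    positions = torch.arange(n, device=flat.device)
    first = torch.full((uniq.numel(),), n, dtype=torch.long, device=flat.device)
    first = first.scatter_reduce(0, inverse, positions, reduce="amin")
    # rank[k] = new index of uniq[k], ordered by first appearance
    order = torch.argsort(first)
    rank = torch.empty_like(order)
    rank[order] = torch.arange(order.numel(), device=flat.device)
    new_flat = rank[inverse]
    # Split back into one chunk per input and restore each shape
    sizes = [t.numel() for t in tensors]
    chunks = torch.split(new_flat, sizes)
    return tuple(chunk.reshape(t.shape) for chunk, t in zip(chunks, tensors))
```

On its own, `torch.unique(..., return_inverse=True)` numbers values in sorted order (in the first example, `0` would map to `0` rather than `2`); the ranking by first position is what converts sorted order into order of first appearance.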