caldera.utils.reindex_tensor

caldera.utils.reindex_tensor(a: torch.Tensor, *tensors: Tuple[torch.Tensor, ...]) → Tuple[torch.Tensor, ...]

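Reindex the values of one or more tensors to consecutive integers starting at 0, assigned in order of first appearance. When several tensors are passed they are reindexed jointly, so a value that occurs in more than one tensor receives the same new index everywhere. Tensors of any shape are supported.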

```python
import torch
from caldera.utils import reindex_tensor

# single tensor: values are renumbered in order of first appearance
a = torch.tensor([1, 1, 1, 4, 0, 5, 0, 0, 0])
b = reindex_tensor(a)
print(b)
# >> tensor([0, 0, 0, 1, 2, 3, 2, 2, 2])

# multiple tensors with multiple dimensions, reindexed jointly
a = torch.tensor([1, 1, 1, 1, 0, 2, 0, 5, 6])
b = torch.tensor([[6, 5, 1, 70], [0, 80, 5, 6]])
expected1 = torch.tensor([0, 0, 0, 0, 1, 2, 1, 3, 4])
expected2 = torch.tensor([[4, 3, 0, 5], [1, 6, 3, 4]])
c, d = reindex_tensor(a, b)
assert torch.all(c == expected1)
assert torch.all(d == expected2)
```
Parameters
  • a (Tensor) – tensor to reindex

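  • tensors (Tuple[Tensor, …]) – additional tensors to reindex jointly with a; a value shared between tensors receives the same new index in each of them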

Return type

Tuple[Tensor, …]

Returns

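the reindexed tensor, or a tuple of reindexed tensors in input order when more than one tensor is passed; each output keeps the shape of its input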
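The first-appearance reindexing described above can be sketched in plain PyTorch as follows. This is a minimal illustration, not caldera's actual implementation: the name `reindex_tensor_sketch` is hypothetical, and returning a bare tensor for a single input (rather than a 1-tuple) is an assumption based on the `print(b)` example.

```python
from typing import Dict, Tuple, Union

import torch


def reindex_tensor_sketch(
    a: torch.Tensor, *tensors: torch.Tensor
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    """Hypothetical sketch: map values to 0, 1, 2, ... in order of first
    appearance, scanning `a` first and then each extra tensor in turn."""
    all_tensors = (a,) + tensors
    mapping: Dict[int, int] = {}
    # Walk every element in row-major order; the first time a value is seen
    # it gets the next unused index, so shared values get one shared index.
    for t in all_tensors:
        for value in t.flatten().tolist():
            if value not in mapping:
                mapping[value] = len(mapping)
    # Rebuild each tensor from the mapping, restoring its original shape.
    out = tuple(
        torch.tensor(
            [mapping[v] for v in t.flatten().tolist()], dtype=torch.long
        ).view(t.shape)
        for t in all_tensors
    )
    # Assumption: a single input returns a tensor, not a 1-tuple.
    return out[0] if len(out) == 1 else out


a = torch.tensor([1, 1, 1, 1, 0, 2, 0, 5, 6])
b = torch.tensor([[6, 5, 1, 70], [0, 80, 5, 6]])
c, d = reindex_tensor_sketch(a, b)
print(c)
print(d)
```

The dictionary keeps the first-seen order, which is what produces indices like `1 → 0, 0 → 1` above. `torch.unique(..., return_inverse=True)` would give a vectorized inverse mapping, but it numbers values in sorted order, so it would need an extra step to match the first-appearance ordering shown in the examples.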