caldera.utils.torch_scatter_group
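caldera.utils.torch_scatter_group(x, idx)

Group a tensor by indices. This is equivalent to successive applications of x[torch.where(idx == index)] for each unique index in idx, taken in sorted order. As the example shows, the function returns the sorted unique indices together with the corresponding groups of x.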
Example:
    import torch
    from caldera.utils import torch_scatter_group

    idx = torch.tensor([2, 2, 0, 1, 1, 1, 2])
    x = torch.tensor([0, 1, 2, 3, 4, 5, 6])

    uniq_sorted_idx, out = torch_scatter_group(x, idx)

    # note that the returned indices are sorted
    assert torch.all(torch.eq(uniq_sorted_idx, torch.tensor([0, 1, 2])))

    # where idx == 0
    assert torch.all(torch.eq(out[0], torch.tensor([2])))

    # where idx == 1
    assert torch.all(torch.eq(out[1], torch.tensor([3, 4, 5])))

    # where idx == 2
    assert torch.all(torch.eq(out[2], torch.tensor([0, 1, 6])))
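The equivalence stated in the description can be written out directly. The following is a minimal sketch of that semantics only, not caldera's actual implementation; the name `group_by_index` is hypothetical, and the sketch assumes `idx` is a 1-D tensor with one entry per row of `x`.

```python
import torch


def group_by_index(x, idx):
    """Hypothetical reference: group rows of x by the values in idx."""
    # torch.unique returns the distinct values in ascending order by default.
    uniq = torch.unique(idx)
    # One selection per unique index, mirroring x[torch.where(idx == index)].
    groups = [x[torch.where(idx == i)] for i in uniq]
    return uniq, groups


idx = torch.tensor([2, 2, 0, 1, 1, 1, 2])
x = torch.tensor([0, 1, 2, 3, 4, 5, 6])
uniq, groups = group_by_index(x, idx)

for i, g in zip(uniq.tolist(), groups):
    print(i, g.tolist())
```

This version scans `idx` once per unique value, which is fine for illustration but may be slow when there are many groups; sorting once and slicing is a common alternative.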