caldera.utils.torch_scatter_group

caldera.utils.torch_scatter_group(x, idx)

Group a tensor by indices. This is equivalent to successive applications of x[torch.where(idx == index)] for each unique index in idx, taken in ascending order.
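The equivalence described above can be written out directly as a loop. The sketch below defines a hypothetical `naive_scatter_group` helper (not part of caldera) that applies `x[torch.where(idx == i)]` once per unique index. It is meant as a readable reference for the semantics, not as an efficient implementation or as caldera's own code.

```python
import torch
from typing import List, Tuple


def naive_scatter_group(
    x: torch.Tensor, idx: torch.Tensor
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    # Hypothetical reference helper, not caldera's implementation.
    # torch.unique(sorted=True) yields the group keys in ascending order.
    uniq_sorted_idx = torch.unique(idx, sorted=True)
    # One mask lookup per key: x[torch.where(idx == i)].
    groups = [x[torch.where(idx == i)] for i in uniq_sorted_idx]
    return uniq_sorted_idx, groups


idx = torch.tensor([2, 2, 0, 1, 1, 1, 2])
x = torch.tensor([0, 1, 2, 3, 4, 5, 6])
keys, groups = naive_scatter_group(x, idx)
print(keys.tolist())                 # [0, 1, 2]
print([g.tolist() for g in groups])  # [[2], [3, 4, 5], [0, 1, 6]]
```

Because torch.where returns positions in ascending order, each group keeps the original relative order of its elements.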

Example:

import torch
from caldera.utils import torch_scatter_group

idx = torch.tensor([2, 2, 0, 1, 1, 1, 2])
x = torch.tensor([0, 1, 2, 3, 4, 5, 6])

uniq_sorted_idx, groups = torch_scatter_group(x, idx)

# note: the unique indices are returned sorted
assert torch.all(torch.eq(uniq_sorted_idx, torch.tensor([0, 1, 2])))

# where idx == 0
assert torch.all(torch.eq(groups[0], torch.tensor([2])))

# where idx == 1
assert torch.all(torch.eq(groups[1], torch.tensor([3, 4, 5])))

# where idx == 2
assert torch.all(torch.eq(groups[2], torch.tensor([0, 1, 6])))
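Looping over every unique index costs one full pass over idx per group. A common vectorized alternative, sketched below as a hypothetical `sorted_scatter_group` (an assumption, not necessarily how caldera implements this function), stably sorts idx once, reorders x with the resulting permutation, and splits it by group size. The stable sort keeps the elements of each group in their original order, matching the [0, 1, 6] group for idx == 2 in the example above. It assumes a PyTorch version whose `torch.sort` accepts `stable=True`.

```python
import torch
from typing import List, Tuple


def sorted_scatter_group(
    x: torch.Tensor, idx: torch.Tensor
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    # Hypothetical vectorized sketch, not caldera's implementation.
    # A stable sort keeps equal indices in their original relative order.
    _, perm = torch.sort(idx, stable=True)
    # Unique keys (ascending) and the number of elements in each group.
    uniq_sorted_idx, counts = torch.unique(idx, sorted=True, return_counts=True)
    # After reordering, each group occupies a contiguous block of length counts[i].
    groups = list(torch.split(x[perm], counts.tolist()))
    return uniq_sorted_idx, groups


idx = torch.tensor([2, 2, 0, 1, 1, 1, 2])
x = torch.tensor([0, 1, 2, 3, 4, 5, 6])
keys, groups = sorted_scatter_group(x, idx)
print(keys.tolist())                 # [0, 1, 2]
print([g.tolist() for g in groups])  # [[2], [3, 4, 5], [0, 1, 6]]
```

torch.split returns a tuple, so the sketch converts it to a list to match the documented List[Tensor] return type.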
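
Parameters

x (Tensor) – tensor whose elements are grouped
idx (Tensor) – index tensor assigning each element of x to a group

Returns

tuple (uniq_sorted_idx, groups): the unique values of idx in ascending order, and a list of tensors where groups[i] holds the elements of x whose index equals uniq_sorted_idx[i]

Return type

Tuple[Tensor, List[Tensor]]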