caldera.utils.scatter_indices

caldera.utils.scatter_indices(indices, shape)

Unroll the COO (coordinate-format) indices using the provided shape. For example:

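import torch
from caldera.utils import scatter_indices

# COO indices to unroll
indices = torch.tensor([
    [0, 1, 2],
    [2, 3, 4],
    [4, 5, 4]
])
# shape used to unroll the indices
shape = (3, 2)
print(scatter_indices(indices, shape))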

# tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2,
#          0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2],
#         [2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4,
#          2, 3, 4, 2, 3, 4, 2, 3, 4, 2, 3, 4],
#         [4, 5, 4, 4, 5, 4, 4, 5, 4, 4, 5, 4, 4, 5, 4, 4, 5, 4, 4, 5, 4, 4, 5, 4,
#          4, 5, 4, 4, 5, 4, 4, 5, 4, 4, 5, 4],
#         [0, 0, 1, 1, 2, 2, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 2, 2, 0, 1, 0, 1, 0, 1,
#          0, 0, 1, 1, 2, 2, 0, 1, 0, 1, 0, 1]])

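The last row of the printed output repeats the 12-element pattern 0, 0, 1, 1, 2, 2, 0, 1, 0, 1, 0, 1. That pattern is exactly the two rows of a row-major enumeration of every index in shape (3, 2), laid end to end. The sketch below reproduces it with a hypothetical helper, `shape_grid`, built on `itertools.product`; it only illustrates where the pattern can come from and is not caldera's implementation.

```python
from itertools import product
from typing import List, Sequence


def shape_grid(shape: Sequence[int]) -> List[List[int]]:
    """Hypothetical helper: enumerate every index combination of `shape`
    in row-major order, returning one list per dimension."""
    combos = list(product(*(range(s) for s in shape)))
    return [[combo[d] for combo in combos] for d in range(len(shape))]


grid = shape_grid((3, 2))
print(grid)
# [[0, 0, 1, 1, 2, 2], [0, 1, 0, 1, 0, 1]]

# Concatenating the two rows gives the repeating unit of the last printed row.
print(grid[0] + grid[1])
# [0, 0, 1, 1, 2, 2, 0, 1, 0, 1, 0, 1]
```
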
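Parameters
  • indices (LongTensor) – the COO indices to unroll.

  • shape (Union[Size, Tuple[int, …]]) – the shape used to unroll the indices.

Returns
  the unrolled indices, as a LongTensor.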
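As a rough illustration, here is a pure-Python sketch under one assumed reading of "unroll": every COO entry (one column of `indices`) is paired with every index combination of `shape`, the way a hybrid sparse layout with dense trailing dimensions can be expanded into fully enumerated coordinates. The helper `unroll_coo` is hypothetical and not part of caldera, and its output layout (ndim + len(shape) rows, nnz × prod(shape) columns) is an assumption that need not match caldera's actual output, including the printed example above.

```python
from itertools import product
from typing import List, Sequence


def unroll_coo(indices: List[List[int]], shape: Sequence[int]) -> List[List[int]]:
    """Hypothetical sketch: pair each COO entry (a column of `indices`)
    with every index combination of `shape`, in row-major order.

    Returns ndim + len(shape) rows and nnz * prod(shape) columns.
    """
    ndim = len(indices)
    nnz = len(indices[0]) if ndim else 0
    out: List[List[int]] = [[] for _ in range(ndim + len(shape))]
    for j in range(nnz):
        entry = [row[j] for row in indices]  # coordinates of the j-th entry
        for combo in product(*(range(s) for s in shape)):
            # append the entry's coordinates followed by the extra indices
            for d, value in enumerate(entry + list(combo)):
                out[d].append(value)
    return out


indices = [[0, 1, 2], [2, 3, 4], [4, 5, 4]]
result = unroll_coo(indices, (3, 2))
print(len(result), len(result[0]))  # 5 18
print(result[0][:6], result[3][:6], result[4][:6])
# [0, 0, 0, 0, 0, 0] [0, 0, 1, 1, 2, 2] [0, 1, 0, 1, 0, 1]
```

With an empty `shape` the sketch returns the indices unchanged, because `itertools.product()` called with no iterables yields a single empty tuple.
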