Pytorch: all-but-one summation?

Stack Overflow Asked by sterne on December 16, 2021

I am in the process of moving some message passing code from Numpy to Pytorch. I am unsure how to implement a single step of a much larger algorithm. Below is the simplest explanation of that step.

Given the following:

index = [[2,0,1], [2,2,0]]
value = [[0.1, 1.2, 2.3], [3.4, 4.5, 5.6]]

I would like to compute the "all-but-one" sum of messages to each index. Here is a graphical representation:
[figure: messages sent to nodes]

The answer I am looking for is:

ans = [[7.9, 5.6, 0], [4.6, 3.5, 1.2]]

The explanation is that, for example, index[0][0] points at node 2. The sum of all messages at node 2 is 0.1+3.4+4.5=8.0. However, we want to exclude the message we are considering (value[0][0]=0.1), so we obtain ans[0][0]=7.9. If only a single index points at a node, then the answer is 0 (e.g. node 1, giving ans[0][2]=0).

I would be happy with computing the sums for each node and then subtracting out the individual messages. I am aware that this can lead to loss of significance, but I believe my use case is very well-behaved (e.g. no floating-point infinities).

I can also provide the minimal numpy code, but even the minimal example is a bit long. I have looked at pytorch’s scatter and gather commands, but I don’t think that they’re appropriate here.
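
For reference, the core sum-then-subtract step on the toy example above can be sketched in a few lines of NumPy with np.add.at (a sketch of just this step, not the asker's actual message-passing code):

import numpy as np

index = np.array([[2, 0, 1], [2, 2, 0]])
value = np.array([[0.1, 1.2, 2.3], [3.4, 4.5, 5.6]])

num_nodes = 3
totals = np.zeros(num_nodes)
# unbuffered scatter-add: totals[index[i, j]] += value[i, j], with repeats accumulated
np.add.at(totals, index, value)
# each message's node total, minus the message itself
ans = totals[index] - value
# ans: [[7.9, 5.6, 0.0], [4.6, 3.5, 1.2]] (up to float rounding)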

3 Answers

After working with pytorch for a while longer, and writing some code for other situations, I realized that there is a much more efficient solution that I hadn't considered. So I am pasting it here for anyone else who comes after me:

import torch

index = [[2, 0, 1], [2, 2, 0]]
value = [[0.1, 1.2, 2.3], [3.4, 4.5, 5.6]]

# convert to tensor
index_tensor = torch.tensor(index)
value_tensor = torch.tensor(value)

num_nodes = 3
# accumulate the total of all messages arriving at each node
totals = torch.zeros(num_nodes)
totals.index_add_(0, index_tensor.flatten(), value_tensor.flatten())
# for each message, gather its node's total and subtract the message itself
result = totals[index_tensor] - value_tensor
print(result)
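
which, for the example inputs, prints the same result as the answers below:

tensor([[7.9000, 5.6000, 0.0000],
        [4.6000, 3.5000, 1.2000]])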

It uses much less memory than the scatter_add solution given by jodag and avoids all the for loops in the other solutions. Hooray for much faster code!

Answered by sterne on December 16, 2021

Here is an approach which only requires a loop over the number of nodes, i.e. the maximum value in index plus 1 (in this case 3). It is not clear from the question whether this value is always at most the number of columns in your input tensors, so we just compute it explicitly.

This approach also allows you to compute gradients w.r.t. value_tensor (a quick gradient check is sketched after the output below) and should be easily translatable to numpy as well.

import torch

index = [[2, 0, 1], [2, 2, 0]]
value = [[0.1, 1.2, 2.3], [3.4, 4.5, 5.6]]

# convert to tensor
index_tensor = torch.tensor(index)
value_tensor = torch.tensor(value)

# optionally require gradients for value_tensor
# value_tensor.requires_grad_(True)

# perhaps this is always index_tensor.shape[1]? not clear from question
num_nodes = int(index_tensor.max()) + 1  # .max() returns a 0-dim tensor, so convert to a plain int

# compute total sum for each node
total_sum = torch.empty(num_nodes, device=value_tensor.device)
for n in range(num_nodes):
    total_sum[n] = value_tensor[index_tensor == n].sum()

# compute all-but-one
result = total_sum[index_tensor] - value_tensor

print(result)

which results in

tensor([[7.9000, 5.6000, 0.0000],
        [4.6000, 3.5000, 1.2000]])
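
As a quick check that gradients flow through this version (assuming the value_tensor.requires_grad_(True) line above is uncommented): each message contributes to the result once per message at its node, minus once for itself, so the gradient of result.sum() w.r.t. each entry is (number of messages at its node) - 1:

result.sum().backward()
print(value_tensor.grad)
# tensor([[2., 1., 0.],
#         [2., 2., 1.]])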

Alternative using Tensor.scatter_add_

Here's an interesting alternative to the above which avoids the loop entirely by using the built-in scatter-add operation. It requires more memory and is probably harder to follow than the previous solution, but will likely be faster in some cases.

import torch

index = [[2, 0, 1], [2, 2, 0]]
value = [[0.1, 1.2, 2.3], [3.4, 4.5, 5.6]]

# convert to tensor
index_tensor = torch.tensor(index)
value_tensor = torch.tensor(value)

# optionally require gradients for value_tensor
# value_tensor.requires_grad_(True)

num_rows = index_tensor.shape[0]
# perhaps this is always index_tensor.shape[1]? but doesn't need to be
num_nodes = int(index_tensor.max()) + 1  # convert the 0-dim tensor to a plain int

# scatter_add will be applied to this tensor of zeros
scattered = torch.zeros((num_rows, num_nodes), device=value_tensor.device)

# apply scatter_add_
scattered.scatter_add_(1, index_tensor, value_tensor)

# which is equivalent to:
# num_cols = index_tensor.shape[1]
# for r in range(num_rows):
#     for c in range(num_cols):
#         scattered[r][index_tensor[r, c]] += value_tensor[r, c]

# sum the rows to get the total sum
total_sum = scattered.sum(dim=0)

# compute all-but-one
result = total_sum[index_tensor] - value_tensor

print(result)
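
For the example inputs the intermediates work out as follows, which may help in following the scatter step (values up to float rounding):

# scattered:  tensor([[1.2000, 2.3000, 0.1000],
#                     [5.6000, 0.0000, 7.9000]])
# total_sum:  tensor([6.8000, 2.3000, 8.0000])

and result matches the loop-based version above.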

Answered by jodag on December 16, 2021

Not sure if this is an improvement over your current solution, but you can do something like this:

import torch

index = [[2,0,1], [2,2,0]]
value = [[0.1, 1.2, 2.3], [3.4, 4.5, 5.6]]

# convert to tensor
index_tensor = torch.tensor(index)
value_tensor = torch.tensor(value)

# initialize a tensor to store the result
ans = torch.empty_like(value_tensor)

# all-but-one sum: total at this message's node, minus the message itself
for i, v_row in enumerate(value):
    for j, v in enumerate(v_row):
        ans[i, j] = value_tensor[index_tensor == index_tensor[i, j]].sum() - v

print(ans)
# tensor([[7.9000, 5.6000, 0.0000],
#         [4.6000, 3.5000, 1.2000]])

# if you need a list, just use ans.tolist()

Pretty sure there is a way to remove at least one of these for loops. I'll update the answer if I can figure it out.
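
For illustration, one fully vectorized possibility is to broadcast a pairwise "same node" mask over all messages (a sketch; it removes both loops at the cost of O(n^2) memory in the number of messages n):

flat_idx = index_tensor.reshape(-1)
flat_val = value_tensor.reshape(-1)
# (n, n) mask: entry [k, m] is True when messages k and m point at the same node
same = flat_idx.unsqueeze(0) == flat_idx.unsqueeze(1)
# row sums give the total at each message's node
totals = (same * flat_val).sum(dim=1)
ans = totals.reshape(value_tensor.shape) - value_tensor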

Answered by Berriel on December 16, 2021
