2017-10-18 6 views
0

J'essaie d'utiliser la fonction gather dans pytorch mais je ne comprends pas le rôle du paramètre dim.Impact de la dimension du paramètre dans la fonction de regroupement

code:

t = torch.Tensor([[1,2],[3,4]]) 
print(torch.gather(t, 0, torch.LongTensor([[0,0],[1,0]]))) 

Sortie:

1 2 
3 2 
[torch.FloatTensor of size 2x2] 

Dimension réglé sur 1:

print(torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]]))) 

sortie devient:

1 1 
4 3 
[torch.FloatTensor of size 2x2] 

Comment fonctionne la fonction gather?

Répondre

2

Je me suis rendu compte comment fonctionne la fonction de collecte.

t = torch.Tensor([[1,2],[3,4]]) 
index = torch.LongTensor([[0,0],[1,0]]) 
torch.gather(t, 0, index) 

Depuis le dimension est égal à zéro, de sorte que la sortie sera:

| t[index[0, 0] 0] t[index[0, 1] 1] | 
| t[index[1, 0] 0] t[index[1, 1] 1] | 

Si le dimension est réglé sur un, la sortie devient:

| t[0 index[0, 0]] t[0 index[0, 1]] | 
| t[1 index[1, 0]] t[1 index[1, 1]] | 

Ainsi, la formule est :

For a 3-D tensor the output is specified by: 

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0 
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1 
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2 

Référence: http://pytorch.org/docs/master/torch.html?highlight=gather#torch.gather