2017-07-29 1 views
0

J'ai le code suivant writtein dans lua.Supprimer l'article de torch.Tensor

Je voudrais obtenir des indices pour N scores maximum de scores et leurs scores correspondants.

Il semble que je vais devoir itérativement supprimer la valeur maximale actuelle de scores et récupérer le maximum à nouveau, mais ne trouve pas un moyen approprié de le faire.

nqs=dataset['question']:size(1); 
scores=torch.Tensor(nqs,noutput); 
qids=torch.LongTensor(nqs); 
for i=1,nqs,batch_size do 
    xlua.progress(i, nqs) 
    r=math.min(i+batch_size-1,nqs); 
    scores[{{i,r},{}}],qids[{{i,r}}]=forward(i,r); 
-- print(scores) 
end 

tmp,pred=torch.max(scores,2); 

Répondre

1

J'espère que je ne l'ai pas mal compris, puisque le code que vous montrer (en particulier la boucle Foor) ne semble pas vraiment pertinent de vouloir que vous voulez faire. Quoi qu'il en soit, voici comment je le ferais.

sr=scores:view(-1,scores:size(1)*scores:size(2)) 
val,id=sr:sort() 
--val is a row vector with the values stored in increasing order 
--id will be the corresponding index in sr 
--now you can slice val and id from the end to find the N values you want, then you can recover the original index in the scores matrix simply with 
col=(index-1)%scores:size(2)+1 
row=math.ceil(index/scores:size(2)) 

espérons que cette aide.

+0

pourriez-vous élaborer sur la partie "tranche val et id de la fin pour trouver les valeurs N"? – ytrewq

+0

Je veux dire tout comme 'val [{{1}, {val: taille (2) -N + 1, val: taille (2)}}]' et la même chose avec 'id', puisque les plus grands éléments' N' sont à la fin du tenseur trié. – Ash

+0

Notez que cela ne résout pas le problème des doublons (je veux dire si "scores" contient, * par exemple *, deux fois la valeur maximale de, disons, "100"), mais je suppose que ce n'est pas un problème car mentionné dans votre question. – Ash