2017-05-17 1 views
7

J'ai le code suivantThéano commutateur sage ligne efficace

output = T.switch(cond, a, b) 

cond est un (N,1) bool Tensor, tandis que a et b sont (N, M) tenseurs numériques avec M étant assez grand. La condition fonctionne de manière rangée. Je peux facilement faire fonctionner le commutateur en exécutant T.repeat() sur cond, mais c'est assez lent. Y at-il un moyen que je peux faire efficacement les bools dans cond décider si a ou b doit être retourné?

Répondre

3

Y a-t-il un moyen de faire en sorte que les bools dans cond déterminent si a ou b devrait être retourné?

Oui, vous pouvez le faire

cond * a + (1-cond) * b 

cond sera diffusée à (N, M) forme.

Cela devrait être proche de la limite théorique, qui est la bande passante mémoire: cette opération doit lire environ N*M éléments et écrire N*M.

Au lieu de cela, nous lisons 2*N*M, mais supprimez la logique conditionnelle.

(je n'ai pas Théano devant moi, donc je ne suis pas sûr que ce soit plus rapide que T.switch, mais il devrait être à peu près aussi bon qu'il obtient. De plus, je vais essayer casting cond au même dtype comme a et b)


Si vous souhaitez mettre à jour a en place, vous pouvez le faire en utilisant T.set_subtensor:

a = np.random.uniform(size=(N, M)).astype(np.float32) 
b = np.random.uniform(size=(N, M)).astype(np.float32) 

a = theano.shared(a) 
b = theano.shared(b) 

c = T.vector() # mostly 0, presumably (1-cond) 

nz = T.nonzero(c) 

s = T.set_subtensor(a[nz], b[nz]) 
fn = theano.function([c], [], updates=[(a, s)]) 

... 

fn(1-cond) 

il peut ou ne peut pas être plus rapide t han la première approche, en fonction de N, M et d'autres facteurs.

+0

Merci pour la réponse, je vais essayer! Pensées intéressantes sur la limite théorique. Je suppose que je pourrais éviter les grandes lectures et écritures en exploitant le fait que le plus souvent 'a' serait la bonne valeur à retourner et c'est bien pour la méthode de modifier' a' directement. Supposons que seulement 5% du temps 'b' devrait être retourné pour une ligne donnée, ne pourrait-on pas obtenir de meilleures performances en modifiant' a' directement seulement sur les lignes nécessitant une modification? – pir

+0

@pir Optimisez-vous pour le processeur ou le GPU? Quels sont les N, N et dtype typiques? – MaxB

+0

@pir aussi, est-ce que cette partie d'un NN ou quelque chose qui a besoin du gradient? – MaxB