2017-06-15 3 views
1

Je travaille sur un framework de calcul conditionnel utilisant MxNet. Supposons que nous avons N échantillons dans notre minibatch. Je dois d'exécuter ce genre d'opérations dans mon graphe de calcul, en utilisant pseudocode:L'informatique conditionnelle est-elle possible avec Tensorflow?

x = graph.Variable("x") 
y = graph.DoSomeTranformations(x) 
# The following operation generates a Nxk sized matrix, k responses for each sample. 
z = graph.DoDecision(y) 
for i in range(k): 
    argmax_sample_indices_for_i = graph.ArgMaxIndices(z, i) 
    y_selected_samples = graph.TakeSelectedSample(y, argmax_sample_indices_for_i) 
    result = graph.DoSomeTransformations(y_selected_samples) 

Ce que je veux atteindre est la suivante: après l'obtention de y, j'applique une fonction de décision (cela peut être un D à k Couche entièrement connectée, où D est la dimension de données) et obtenir k activations pour chaque échantillon dans ma minibatch de taille N. Ensuite, je veux diviser dynamiquement mon minibatch en k parties différentes (k peut être 2, 3, un petit nombre), basé sur l'index de la colonne de l'activation maximale pour chaque échantillon. Ma hypothétique fonction "graph.ArgMaxIndices" fait que, étant donné z, une matrice de taille Nxk, et i, la fonction recherche les indices d'échantillon qui donnent les activations maximales le long de la colonne i et retourne leurs indices. (Notez que je cherche n'importe quelle série ou combinaison de fonctions qui donne le résultat équivalent à "graph.ArgMaxIndices", pas une seule fonction, spécifiquement). Puis enfin, pour chaque i, je sélectionne les échantillons avec des activations maximales et leur applique des transformations spécifiques. Actuellement, à ma connaissance, MxNet ne supporte pas ce genre de calculs conditionnels dans leurs réseaux symboliques. Par conséquent, je construis des graphiques symboliques séparés après chaque décision et ai dû coder ma comptabilité séparée - structures de graphe conditionnel pour chaque division de minibatch, qui produit 1) Code très complexe et encombrant pour maintenir et développer 2) Performance de course dégradée pendant la formation et l'évaluation.

Ma question est, puis-je faire ce qui précède à l'aide des opérateurs symboliques de tensorflow? Permet-il de sélectionner des sous-ensembles de la minibatch, en fonction d'un critère? Y at-il une fonction ou une série de fonctions qui est équivalente à la "graphique.ArgMaxIndices" dans le pseudo-code ci-dessus? (Étant donné une matrice Nxk et un index de colonne i, renvoie les indices de lignes, qui ont l'activation maximale à la colonne k).

Répondre

2

Vous pouvez le faire en tensorflow.

La meilleure façon que je vois est d'utiliser un masque et tf.boolean_maskk fois, avec le masque i -ème étant donné par tf.equal(i, tf.argmax(z, axis=-1))

x = graph.Variable("x") 
y = graph.DoSomeTranformations(x) 
# The following operation generates a Nxk sized matrix, k responses for each sample. 
z = graph.DoDecision(y) 
max_indices = tf.argmax(z, axis=-1) 
for i in range(k): 
    argmax_sample_indices_for_i = tf.equal(i, max_indices) 
    y_selected_samples = tf.boolean_mask(y, mask=argmax_sample_indices_for_i) 
    result = graph.DoSomeTransformations(y_selected_samples) 
+0

Lorsque nous utilisons boolean_mask, ne tensorflow sauter le calcul de ces échantillons dans la prochaine couches? Je veux dire, ils ne devraient pas être inclus dans des opérations par lots comme la normalisation par lots après cette division k-aire. –

+1

Oui, le résultat de tf.boolean_mask est fondamentalement la liste des éléments qui ont passé le test, le reste n'est pas mis à 0 mais entièrement retiré du tenseur. Si votre entrée y est de forme '(batch_size, d1, d2, ..., dn)', 'y_selected_samples' sera de forme' (number_of_samples_selected, d1, ..., dn) '. Donc, pour tous les calculs qui arrivent à 'y_selected_samples', c'est vraiment comme si les échantillons non-sélectionnés n'existaient pas. – gdelab

+1

C'était une réponse très utile. Il semble que j'ai besoin d'une transition vers Tensorflow dès que possible ... –