2017-10-16 3 views
2

J'ai essayé d'utiliser while_loop dans tensorflow, mais lorsque je tente de retourner la cible sortie de appelable en boucle while, il me donne une erreur parce que la forme est augmentée à chaque fois.while_loop erreur dans tensorflow

La sortie doit contenir des valeurs (0 ou 1) basées sur data valeur (tableau d'entrée). Si données la valeur est supérieure à 5 retour sinon retour . La valeur retournée doit être ajoutée en sortie

Voici le code ::

import numpy as np 
import tensorflow as tf 

data = np.random.randint(10, size=(30)) 
data = tf.constant(data, dtype= tf.float32) 

global output 
output= tf.constant([], dtype= tf.float32) 
i = tf.constant(0) 
c = lambda i: tf.less(i, 30) 


def b(i): 
    i= tf.add(i,1) 
    cond= tf.cond(tf.greater(data[i-1], tf.constant(5.)), lambda: tf.constant(1.0), lambda: tf.constant([0.0])) 
    output =tf.expand_dims(cond, axis = i-1) 
    return i, output 

r,out = tf.while_loop(c, b, [i]) 
print(out) 
sess= tf.Session() 
sess.run(out) 

L'erreur ::

r, out = tf.while_loop(c, b, [i])

ValueError: The two structures don't have the same number of elements.

First structure (1 elements): [tf.Tensor 'while/Identity:0' shape=() dtype=int32]

Second structure (2 elements): [tf.Tensor 'while/Add:0' shape=() dtype=int32, tf.Tensor 'while/ExpandDims:0' shape=unknown dtype=float32>]

J'utilise tensorflow-1.1.3 et python -3,5

Comment puis-je changer mon code t o me donne le résultat cible?

EDIT ::

-je modifier le code basé sur @mrry réponse, mais je reste un problème que la sortie est réponse incorrecte la sortie est le nombre sommation

a = tf.ones([10,4]) 
print(a) 
a = tf.reduce_sum(a, axis = 1) 
i =tf.constant(0) 
c = lambda i, _:tf.less(i,10) 

def Smooth(x): 
    return tf.add(x,2) 

summ = tf.constant(0.) 
def b(i,_): 
    global summ 
    summ = tf.add(summ, tf.cast(Smooth(a[i]), tf.float32)) 
    i= tf.add(i,1) 
    return i, summ 

r, smooth_l1 = tf.while_loop(c, b, [i, smooth_l1]) 

print(smooth_l1) 

sess = tf.Session() 
print(sess.run(smooth_l1)) 

l'out mettre est 6.0 (faux).

Répondre

3

La fonction tf.while_loop() exige que les quatre listes ont la même longueur et le même type pour chaque élément suivant:

  • La liste des arguments de la fonction cond (c dans ce cas).
  • La liste des arguments de la fonction body (b dans ce cas). La liste des valeurs renvoyées par la fonction body. La liste des loop_vars représentant les variables de boucle.

Par conséquent, si votre corps de la boucle a deux sorties, vous devez ajouter un argument correspondant à b et c, et un élément correspondant à loop_vars:

c = lambda i, _: tf.less(i, 30) 

def b(i, _): 
    i = tf.add(i, 1) 
    cond = tf.cond(tf.greater(data[i-1], tf.constant(5.)), 
       lambda: tf.constant(1.0), 
       lambda: tf.constant([0.0])) 

    # NOTE: This line fails with a shape error, because the output of `cond` has 
    # a rank of either 0 or 1, but axis may be as large as 28. 
    output = tf.expand_dims(cond, axis=i-1) 
    return i, output 

# NOTE: Use a shapeless `tf.placeholder_with_default()` because the shape 
# of the output will vary from one iteration to the next. 
r, out = tf.while_loop(c, b, [i, tf.placeholder_with_default(0., None)]) 

Comme il est indiqué dans les commentaires, le corps de la boucle (en particulier l'appel à tf.expand_dims()) semble être incorrecte et ce programme ne fonctionnera pas tel quel, mais j'espère que c'est suffisant pour vous aider à démarrer.

+0

merci pour votre réponse, je édite le code pour que la sortie soit le résultat de la sommation des nombres. et je n'obtiens aucune erreur de syntaxe, mais la sortie n'est pas la bonne réponse. Je ne sais pas pourquoi c'est arrivé.Je vais modifier ma question en fonction de votre réponse ici – CCCC

+0

Quelle est la réponse attendue pour votre code? Le 'global summ' et en ignorant l'argument du second corps est suspect: vous voulez probablement passer' 0.' comme valeur initiale pour la deuxième variable de boucle, et utiliser le second argument de corps au lieu de 'global summ' comme accumulateur. – mrry