2017-09-29 9 views
0

Je viens de parcourir Stack Overflow et d'autres forums, mais je n'ai rien trouvé d'utile pour mon problème. Mais il semble lié à this question.Tensorflow ne peut pas alimenter la valeur de forme (1,) pour Tensor 'x: 0' qui a la forme '(?, 128)'

J'ai actuellement un modèle de Tensorflow (128 entrées et 11 sorties) que j'ai enregistré, suivant le tutoriel MNIST de Tensorflow.

Il semblait avoir du succès et j'ai maintenant un modèle dans ce dossier avec les 3 fichiers (.meta, .ckpt.data et .index). Cependant, je veux restaurer et l'utiliser pour les prévisions:

#encoding[0] => numpy ndarray (128,) # anyway a list with only one entry 
#unknowndata = np.array(encoding[0])[None] 
unknowndata = np.expand_dims(encoding[0], axis=0) 
print(unknowndata.shape) # Output (1, 128) 

# Restore pre-trained tf model 
with tf.Session() as sess: 
    #saver.restore(sess, "models/model_1.ckpt") 
    saver = tf.train.import_meta_graph('models/model_1.ckpt.meta') 
    saver.restore(sess,tf.train.latest_checkpoint('models/./')) 
    y = tf.get_collection('final tensor') # tf.nn.softmax(tf.matmul(y2, W3) + b3) 
    X = tf.get_collection('input') # tf.placeholder(tf.float32, [None, 128]) 

    # W1 = tf.get_collection('vars')[0] 
    # b1 = tf.get_collection('vars')[1] 
    # W2 = tf.get_collection('vars')[2] 
    # b2 = tf.get_collection('vars')[3] 
    # W3 = tf.get_collection('vars')[4] 
    # b3 = tf.get_collection('vars')[5] 

    # y1 = tf.nn.relu(tf.matmul(X, W1) + b1) 
    # y2 = tf.nn.relu(tf.matmul(y1, W2) + b2) 
    # yLog = tf.matmul(y2, W3) + b3 
    # y = tf.nn.softmax(yLog) 

    prediction = tf.argmax(y, 1) 

    print(sess.run(prediction, feed_dict={i: d for i,d in zip(X, unknowndata.T)})) 
    # also had sess.run(prediction, feed_dict={X: unknowndata.T}) and also not transposed, still errors 

# Output: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] # one should be 1 obviously with a specific percentage 

Là, je ne diffusez que des problèmes ...

ValueError: Cannot feed value of shape (1,) for Tensor 'x:0', which has shape '(?, 128)' Altough I print the shape of the 'unknowndata' and it matches the (1, 128). I also tried it with

sess.run(prediction, feed_dict={X: unknownData})) # with transposed etc. but nothing worked for me there I got the other error 

TypeError: unhashable type: 'list'

Je veux seulement quelques prédictions de ce beau Modèle entraîné par Tensorflow.

+1

Bonjour et bienvenue sur StackOverflow! Les gens ici sont occupés à aider autant qu'ils le peuvent, donc tout le monde n'aura pas le temps de lire un mur de code. Je vous recommande [edit] (https://stackoverflow.com/posts/46496213/edit) votre message d'inclure un [** Exemple minimal, complet et vérifiable **] (https://stackoverflow.com/help/mcve) de votre code. Cela vous aidera à obtenir des réponses qui vous aideront. – LW001

+1

qu'en est-il de 'sess.run (prediction, feed_dict = {X [0]: unknownData}))'? – lejlot

+0

c'est ce que j'ai essayé et cela a fonctionné mais là il ne prend qu'un seul échantillon des 128 données et pas tous d'accord? La sortie me donnerait aussi onze fois zéro (il n'y a personne au moins un devrait être là) – lenlehm

Répondre

0

Le tenseur prediction est obtenu par un argmax sur y. Au lieu de renvoyer uniquement prediction, vous pouvez ajouter y à votre flux de sortie lorsque vous exécutez sess.run.

output_feed = [prediction, y] 
preds, probs = sess.run(output_feed, print(sess.run(prediction, feed_dict={i: d for i,d in zip(X, unknowndata.T)})) 

preds auront les prédictions du modèle et probs auront les scores de probabilité.

1

J'ai compris le problème! D'abord j'ai besoin de restaurer toutes les valeurs (les poids et les biais et matmul eux séparément). Deuxièmement je dois créer la même entrée que dans le modèle formé, dans mon cas:

X = tf.placeholder(tf.float32, [None, 128]) 

puis juste appeler la prédiction:

sess.run(prediction, feed_dict={X: unknownData}) 

Mais je ne reçois pas de distribution en pourcentage, mais je Attendez-vous à ce que grâce à la fonction SoftMax. Est-ce que quelqu'un sait comment y accéder?