2016-08-04 2 views
1

J'ai un problème dans la bibliothèque Machine Learning de Google - Tensorflow. Lorsque je veux initialiser ma session, il me dit que cela doit être une chaîne ou un tenseur. Je n'ai pas remarqué d'erreur.type invalide, doit être une chaîne ou un Tenseur [TensorFlow]

import tensorflow as tf 
from tensorflow.examples.tutorials.mnist import input_data 
mnist = input_data.read_data_sets('MNIST_data', one_hot=True) 

sess = tf.InteractiveSession() 

x = tf.placeholder(tf.float32, shape=[None, 784]) 
y_ = tf.placeholder(tf.float32, shape=[None, 10]) 

W = tf.Variable(tf.zeros([784,10])) 
b = tf.Variable(tf.zeros([10])) 

sess.run(tf.initialize_all_variables) 

y = tf.nn.softmax(tf.matmul(x,W) + b) 

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) 

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) 

for i in range(1000): 
    batch = mnist.train.next_batch(50) 
    train_step.run(feed_dict={x: batch[0], y_: batch[1]}) 

correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) 
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) 

print accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels}) 

Ceci est sortie du programme suivant dans le terminal:

(tensorflow) [email protected]:~/tensorflow$ python mnist_e.py 
Extracting MNIST_data/train-images-idx3-ubyte.gz 
Extracting MNIST_data/train-labels-idx1-ubyte.gz 
Extracting MNIST_data/t10k-images-idx3-ubyte.gz 
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz 
Traceback (most recent call last): 
    File "mnist_e.py", line 13, in <module> 
    sess.run(tf.initialize_all_variables) 
    File "/home/juldou-box/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 372, in run 
    run_metadata_ptr) 
    File "/home/juldou-box/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 584, in _run 
    processed_fetches = self._process_fetches(fetches) 
    File "/home/juldou-box/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 540, in _process_fetches 
    % (subfetch, fetch, type(subfetch), str(e))) 
TypeError: Fetch argument <function initialize_all_variables at 0x7fe4ca157c80> of <function initialize_all_variables at 0x7fe4ca157c80> has invalid type <type 'function'>, must be a string or Tensor. (Can not convert a function into a Tensor or Operation.) 
+1

c'est juste une erreur de syntaxe, voir ma réponse – Wolf

Répondre

4

Je pense que vous venez de manquer une paire de parenthèses () après tf.initialize_all_variables;)

Comme le dit python, il est dans la ligne 13 , prendre soin

sess.run(tf.initialize_all_variables)

+0

Oui! Je sais où j'ai fait une erreur, j'ai devant 'sess.run (tf.initialize_all_variables))' et j'ai effacé une parenthèse à la fin. Je vous remercie! –