Invalid type, must be a string or Tensor [TensorFlow]
I have a problem with TensorFlow, Google's machine learning library. When I try to initialize the variables in my session, it tells me the argument must be a string or a Tensor. I haven't spotted any mistake in my code.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
sess.run(tf.initialize_all_variables)
y = tf.nn.softmax(tf.matmul(x,W) + b)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
for i in range(1000):
    batch = mnist.train.next_batch(50)
    train_step.run(feed_dict={x: batch[0], y_: batch[1]})
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print accuracy.eval(feed_dict={x: mnist.test.images, y_: mnist.test.labels})
This is the program's output in the terminal:
(tensorflow) ~/tensorflow$ python mnist_e.py
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Traceback (most recent call last):
  File "mnist_e.py", line 13, in <module>
    sess.run(tf.initialize_all_variables)
  File "/home/juldou-box/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 372, in run
    run_metadata_ptr)
  File "/home/juldou-box/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 584, in _run
    processed_fetches = self._process_fetches(fetches)
  File "/home/juldou-box/tensorflow/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 540, in _process_fetches
    % (subfetch, fetch, type(subfetch), str(e)))
TypeError: Fetch argument <function initialize_all_variables at 0x7fe4ca157c80> of <function initialize_all_variables at 0x7fe4ca157c80> has invalid type <type 'function'>, must be a string or Tensor. (Can not convert a function into a Tensor or Operation.)
It's just a syntax error, see my answer. – Wolf
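The last line of the traceback explains the failure. sess.run expects a fetch (a Tensor, an Operation, or a string name), but tf.initialize_all_variables written without parentheses is the function object itself, not the Operation it returns. Strictly speaking, Python parses the line fine; the error is raised at run time. The fix is to call the function: sess.run(tf.initialize_all_variables()). The minimal sketch below reproduces the same mistake in plain Python. Operation, initialize_all_variables and run are hypothetical stand-ins written for illustration, not TensorFlow's real implementation.

```python
# Hypothetical, simplified stand-ins for TensorFlow objects (NOT the real library code).

class Operation(object):
    """Stand-in for a graph operation, such as the variable-initialization op."""
    def __init__(self, name):
        self.name = name


def initialize_all_variables():
    """Stand-in: like the TF 0.x function, it *returns* an Operation when called."""
    return Operation("init")


def run(fetch):
    """Simplified Session.run type check: only strings and Operations are fetchable."""
    if isinstance(fetch, str):
        return "ran " + fetch
    if isinstance(fetch, Operation):
        return "ran " + fetch.name
    # Anything else, e.g. a bare function object, is rejected like in the traceback.
    raise TypeError(
        "Fetch argument %r has invalid type %s, must be a string or Tensor."
        % (fetch, type(fetch)))


try:
    run(initialize_all_variables)        # passes the function object itself
except TypeError as e:
    print(e)                             # same kind of error as in the question

print(run(initialize_all_variables()))   # calls the function first; prints "ran init"
```

The only difference between the failing and the working call is the pair of parentheses: without them, run receives a function; with them, it receives the Operation that the function builds.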