2017-10-16 1 views
2

Je l'exécution de code tutoriel de text classificationtensorflow: dépanner tf.estimator.inputs.numpy_input_fn fonction

je peux exécuter les scripts et cela a fonctionné, mais quand j'ai essayé de l'exécuter ligne par ligne en essayant de comprendre ce que chaque pas fait, je me suis un peu confus à cette étape:

test_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={WORDS_FEATURE: x_test}, 
    y=y_test, 
    num_epochs=1, 
    shuffle=False) 
classifier.train(input_fn=train_input_fn, steps=100) 

Je sais que sur le plan conceptuel est train_input_fn intégrer des données à la fonction de formation, mais comment je peux appeler manuellement cette fn pour inspecter ce qui est en elle?

J'ai retracé le code et a découvert la fonction train_input_fn flux de données aux 2 variables suivantes:

features 
Out[15]: {'words': <tf.Tensor 'random_shuffle_queue_DequeueMany:1' shape=(560, 10) dtype=int64>} 

labels 
Out[16]: <tf.Tensor 'random_shuffle_queue_DequeueMany:2' shape=(560,) dtype=int32> 

Quand j'ai essayé d'évaluer les caractéristiques variables en faisant une sess.run (caractéristiques), mon le terminal semble être bloqué et cesse de répondre.

Quelle est la bonne façon d'inspecter le contenu de variables comme celles-ci?

Merci!

Répondre

1

Basé sur le numpy_input_fn documentation et le comportement (suspendu) J'imagine que l'implémentation sous-jacente dépend d'un coureur de file d'attente. La suspension se produit lorsque les coureurs de file d'attente ne sont pas démarrés. Essayez de modifier l'exécution de votre session de script quelque chose comme ce qui suit, basé sur this guide:

with tf.Session() as sess: 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 
    try: 
     for step in xrange(1000000): 
      if coord.should_stop(): 
       break 
      features_data = sess.run(features) 
      print(features_data) 

    except Exception, e: 
     # Report exceptions to the coordinator. 
     coord.request_stop(e) 
    finally: 
     # Terminate as usual. It is safe to call `coord.request_stop()` twice. 
     coord.request_stop() 
     coord.join(threads) 

Sinon, je vous encourage à vérifier l'interface tf.data.Dataset (possible tf.contrib.data.Dataset dans tensorflow 1.3 ou avant). Vous pouvez obtenir des tenseurs d'entrée/étiquettes similaires sans utiliser de files d'attente avec Dataset.from_tensor_slices. La création est un peu plus complexe, mais l'interface est beaucoup plus flexible et l'implémentation n'utilise pas de file d'attente, ce qui signifie que la session est beaucoup plus simple.

import tensorflow as tf 
import numpy as np 

x_data = np.random.random((100000, 2)) 
y_data = np.random.random((100000,)) 

batch_size = 2 
buff = 100 


def input_fn(): 
    # possible tf.contrib.data.Dataset.from... in tf 1.3 or earlier 
    dataset = tf.data.Dataset.from_tensor_slices((x_data, y_data)) 
    dataset = dataset.repeat().shuffle(buff).batch(batch_size) 
    x, y = dataset.make_one_shot_iterator().get_next() 
    return x, y 


x, y = input_fn() 
with tf.Session() as sess: 
    print(sess.run([x, y])) 
+0

Merci DomJack. Le code a fonctionné avec quelques changements mineurs sur Python 3. Je ne pensais pas que ce soit si compliqué d'imprimer la valeur d'un tenseur dans Tensorflow. – Allen

+0

Ceci est dû à l'implémentation du programme d'exécution de file d'attente. J'ai modifié ma réponse pour inclure un exemple de jeu de données que vous pourriez trouver utile. Les jeux de données sont relativement nouveaux mais une fois que vous avez passé un peu de passe-partout, je les ai trouvés incroyablement simples, puissants et rapides. – DomJack

+0

Merci @DomJack. Je vais certainement vérifier. J'ai trouvé très contre-intuitif de déboguer le code Tensorflow parfois. – Allen