2017-08-29 2 views
7

Il semble que si un MonitoredTrainingSession faire certaines opérations (exploitation forestière?) Avant le premier appel à .RUN (..), ce qui signifie que lorsque je fais:tf.train.MonitoredTrainingSession et reinitializable iterator du dataset

train_data = reader.traindata() # returns a tf.contrib.data.Dataset 
it = tf.contrib.data.Iterator.from_structure(train_data.output_types, train_data.output_shapes) 
init_train = it.make_initializer(train_data) 
ne = it.get_next() 
ts = tf.train.MonitoredTrainingSession(checkpoint_dir=save_path) 

... no calls to ts.run ... 

ts.run(init_train) 

Cela donne l'erreur:

FailedPreconditionError (see above for traceback): GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element 

il coutures comme si le MonitoredTrainingSession fait certaines opérations avant de lancer l'opération, je le nourrir, ce qui rend impossible l'utilisation togeather avec un iterator reinitializable de Dataset.

Je suis sûr que je manque quelque chose et je serais ravi d'entendre ce que :-)

+0

Pour répondre en partie à moi-même, j'ai réussi à contourner le problème en utilisant: .ts._coordinated_creator.tf_sess.run (init_train Mais cela ressemble beaucoup à un hack et pas une approche recommandée? –

Répondre

5

On dirait qu'il n'y a pas de solution directe encore tensorflow. Oui, c'est bizarre qu'ils ne supportent pas complètement l'API Dataset.

La raison est que la session surveillée ignore init_op lors du chargement à partir du point de contrôle. Par conséquent, l'initialiseur Iterator doit être une variable locale.

Les suggestions de travail autour de courant est donnée dans ce numéro - https://github.com/tensorflow/tensorflow/issues/12859

scaffold = tf.train.Scaffold(local_init_op=tf.group(tf.local_variables_initializer(), 
            init_train)) 
with tf.train.MonitoredTrainingSession(scaffold=scaffold, 
             checkpoint_dir=checkpoint_dir) as sess: 
    while not sess.should_stop(): 
     sess.run(train_op)