2017-03-02 2 views
0

J'ai construit un CNN pour la classification d'image. Pendant l'entraînement, j'ai enregistré plusieurs points de contrôle. Les données sont transmises via un feed_dictionary dans le réseau.Tensorflow se plaint de manquer feed_dict pendant la restauration de graphique

Maintenant, je veux restaurer le modèle qui échoue et je ne comprends pas pourquoi. Les lignes importantes de code sont les suivantes:

with tf.Graph().as_default(): 

.... 

if checkpoint_dir is not None: 
    checkpoint_saver = tf.train.Saver() 
    session_hooks.append(tf.train.CheckpointSaverHook(checkpoint_dir, 
                 save_secs=flags.save_interval_secs, 
                 saver=checkpoint_saver)) 
.... 

with tf.train.MonitoredTrainingSession(
     save_summaries_steps=flags.save_summaries_steps, 
     hooks=session_hooks, 
     config=tf.ConfigProto(
      log_device_placement=flags.log_device_placement)) as mon_sess: 

    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir) 
    if checkpoint and checkpoint.model_checkpoint_path: 

     # restoring from the checkpoint file 
     checkpoint_saver.restore(mon_sess, checkpoint.model_checkpoint_path) 

     global_step_restore = checkpoint.model_checkpoint_path.split('/')[-1].split('-')[-1] 
     print("Model restored from checkpoint: global_step = %s" % global_step_restore) 

La ligne "checkpoint_saver.restore" renvoie une erreur:

retraçage (appel le plus récent en dernier): fichier « C: \ Program Files \ Anaconda3 \ envs \ tensorflow \ lib \ paquets-de-site \ tensorflow \ python \ client \ session.py ", ligne 1022, dans _do_call return fn (* args) Fichier" C: \ Program Files \ Anaconda3 \ envs \ tensorflow \ lib \ package de site \ tensorflow \ python \ client \ session.py ", ligne 1004, dans _run_fn status, run_metadata) Fichier" C: \ Fichiers programme \ Anaconda3 \ envs \ tensorflow \ lib \ contextlib.py ", ligne 6 6, dans exit suivant (self.gen) Fichier "C: \ Program Files \ Anaconda3 \ envs \ tensorflow \ bibliothèque \ packages \ tensorflow \ python \ framework \ errors_impl.py", ligne 469, dans raise_exception_on_not_ok_status pywrap_tensorflow.TF_GetCode (état)) tensorflow.python.framework.errors_impl.InvalidArgumentError: Vous devez nourrir une valeur pour tenseur d'espace réservé '' avec des input_images DTYPE flotteur [[Noeud: input_images = Placeholderdtype = DT_FLOAT, forme = [], _device = "/ job: localhost/réplique: 0/tâche: 0/cpu: 0"]]

Une solution pour résoudre ce problème? Pourquoi ai-je besoin d'un feed_dictionary rempli juste pour restaurer le graphique?

Merci d'avance!

Mise à jour:

Voici le code de la méthode de restauration de l'objet économiseur:

def restore(self, sess, save_path): 
    """Restores previously saved variables. 

    This method runs the ops added by the constructor for restoring variables. 
    It requires a session in which the graph was launched. The variables to 
    restore do not have to have been initialized, as restoring is itself a way 
    to initialize variables. 

    The `save_path` argument is typically a value previously returned from a 
    `save()` call, or a call to `latest_checkpoint()`. 

    Args: 
     sess: A `Session` to use to restore the parameters. 
     save_path: Path where parameters were previously saved. 
    """ 
    if self._is_empty: 
     return 
    sess.run(self.saver_def.restore_op_name, 
      {self.saver_def.filename_tensor_name: save_path}) 

Ce que je ne comprends pas: Pourquoi le graphique exécuté immédiatement? Est-ce que j'utilise la mauvaise méthode? Je veux juste restaurer toutes les vars entraînables.

+0

Nommez toutes les variables et tous les espaces réservés. Est-ce utile? http://stackoverflow.com/questions/34793978/tensorflow-complaining-about-placeholder-after-model-restore – hars

+0

Toutes les variables sont nommées. Le flux d'entrée pour mon tenseur d'image est manquant. Je pense que le problème est dû à l'utilisation combinée de MonitoredTrainingSession et d'un feed_dict. MonitoredTrainingSession est destiné à être utilisé pour des configurations plus importantes et peut-être pas compatible avec les dictionnaires de flux?!?. J'essaye de construire un cas de test pour mon "cadre de formation" personnalisé.Donc, je veux garder l'exemple du modèle léger (utiliser un feed_dict plutôt qu'une file d'attente d'importation) – monchi

Répondre

1

Le problème est causé par un SessionRunHook pour l'enregistrement de processus:

crochet d'origine:

class _LoggerHook(tf.train.SessionRunHook): 
    """Logs loss and runtime.""" 

    def begin(self): 
    self._step = -1 

    def before_run(self, run_context): 
    self._step += 1 
    self._start_time = time.time() 
    return tf.train.SessionRunArgs(loss) # Asks for loss value. 

    def after_run(self, run_context, run_values): 
    duration = time.time() - self._start_time 
    loss_value = run_values.results 
    if self._step % 5 == 0: 
     num_examples_per_step = FLAGS.batch_size 
     examples_per_sec = num_examples_per_step/duration 
     sec_per_batch = float(duration) 

     format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 
        'sec/batch)') 
     print (format_str % (datetime.now(), self._step, loss_value, 
          examples_per_sec, sec_per_batch)) 

crochet modifié:

class _LoggerHook(tf.train.SessionRunHook): 
    """Logs loss and runtime.""" 

    def __init__(self, flags, loss_op): 
     self._flags = flags 
     self._loss_op = loss_op 
     self._start_time = time.time() 

    def begin(self): 
     self._step = 0 

    def before_run(self, run_context): 
     if self._step == 0: 
      run_args = None 
     else: 
      run_args = tf.train.SessionRunArgs(self._loss_op) 

     return run_args 

    def after_run(self, run_context, run_values): 

     if self._step > 0: 
      duration_n_steps = time.time() - self._start_time 
      loss_value = run_values.results 
      if self._step % self._flags.log_every_n_steps == 0: 
       num_examples_per_step = self._flags.batch_size 

       duration = duration_n_steps/self._flags.log_every_n_steps 
       examples_per_sec = num_examples_per_step/duration 
       sec_per_batch = float(duration) 

       format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' 
           'sec/batch)') 
       print(format_str % (datetime.now(), self._step, loss_value, 
            examples_per_sec, sec_per_batch)) 

       self._start_time = time.time() 
     self._step += 1 

Explication:

L'enregistrement est maintenant skiped pour la première itération. Ainsi, le fichier session.run, qui est exécuté par Saver.restore (..), ne nécessite plus de dictionnaire de flux rempli.