2016-11-30 1 views
1

je charge le modèle initial pré-entraîné:Comment continuer le train modèle de création de point de contrôle dans tensorflow

if FLAGS.pretrained_model_checkpoint_path: assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path) variables_to_restore = tf.get_collection( slim.variables.VARIABLES_TO_RESTORE) restorer = tf.train.Saver(variables_to_restore) restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path) print('%s: Pre-trained model restored from %s' % (datetime.now(), FLAGS.pretrained_model_checkpoint_path)) Et le modèle formé sur mes données, en utilisant flowers_train.py

Après le train terminée, la perte était d'environ 1,0 et le modèle a été enregistré dans le répertoire spécifié.

Maintenant, je veux poursuivre la formation, Alors, je ReStor modèle:

if FLAGS.checkpoint_dir is not None: # restoring from the checkpoint file ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)

Et continuer le modèle de train, mais la perte lors de la première étape est d'environ 6,5, ce qui signifie en fait, ce modèle n » était pas t initialisé du tout.

Voici le contenu de inception_train.py, qui ont été modifiés de cette inception_train.py

Le premier train que je commence par:

bazel-bin/inception/flowers_train --train_dir="{$TRAIN_DIR}" --data_dir="{$DATA_DIR}" --fine_tune=True --initial_learning_rate=0.001 --input_queue_memory_factor=1 --batch_size=64 --max_steps=100 --pretrained_model_checkpoint_path="/home/tensorflow/inception-v3/model.ckpt-157585"

J'ai essayé de poursuivre la formation par cette commande:

bazel-bin/inception/flowers_train --train_dir="{$TRAIN_NEW_DIR}" --data_dir="{$DATA_DIR}" --fine_tune=False --initial_learning_rate=0.001 --input_queue_memory_factor=1 --batch_size=64 --max_steps=2000 --checkpoint_dir="{$TRAIN_DIR}"

S'il vous plaît, Quelqu'un peut-il m'expliquer, ce que je fais mal lors de l'initialisation du modèle formé?

+0

avez-vous réussi à résoudre ce problème ou de trouver ce qui se passait mal? – Pinocchio

Répondre

0

Je l'ai résolu en utilisant le droit arg_scope comme suit:

with slim.arg_scope(inception_v3.inception_v3_arg_scope()): logits, _ = inception_v3.inception_v3(eval_inputs, num_classes=1001, is_training=False)