Vous pouvez essayer d'utiliser le mode de restauration de slim - slim.assign_from_checkpoint
Il y a des documents connexes dans les sources minces: https://github.com/tensorflow/tensorflow/blob/129665119ea60640f7ed921f36db9b5c23455224/tensorflow/contrib/slim/python/slim/learning.py
partie correspondante:
*************************************************
* Fine-Tuning Part of a model from a checkpoint *
*************************************************
Rather than initializing all of the weights of a given model, we sometimes
only want to restore some of the weights from a checkpoint. To do this, one
need only filter those variables to initialize as follows:
...
# Create the train_op
train_op = slim.learning.create_train_op(total_loss, optimizer)
checkpoint_path = '/path/to/old_model_checkpoint'
# Specify the variables to restore via a list of inclusion or exclusion
# patterns:
variables_to_restore = slim.get_variables_to_restore(
include=["conv"], exclude=["fc8", "fc9])
# or
variables_to_restore = slim.get_variables_to_restore(exclude=["conv"])
init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
checkpoint_path, variables_to_restore)
# Create an initial assignment function.
def InitAssignFn(sess):
sess.run(init_assign_op, init_feed_dict)
# Run training.
slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn)
Mise à jour
J'ai essayé les éléments suivants:
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets as nets
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
predictions = nets.vgg.vgg_16(images)
print [v.name for v in slim.get_variables_to_restore(exclude=['fc8']) ]
Et a obtenu cette sortie (raccourci):
[u'vgg_16/conv1/conv1_1/weights:0',
u'vgg_16/conv1/conv1_1/biases:0',
…
u'vgg_16/fc6/weights:0',
u'vgg_16/fc6/biases:0',
u'vgg_16/fc7/weights:0',
u'vgg_16/fc7/biases:0',
u'vgg_16/fc8/weights:0',
u'vgg_16/fc8/biases:0']
il semble donc que vous devez préfixer portée avec vgg_16
:
print [v.name for v in slim.get_variables_to_restore(exclude=['vgg_16/fc8']) ]
DONNE (raccourcies):
[u'vgg_16/conv1/conv1_1/weights:0',
u'vgg_16/conv1/conv1_1/biases:0',
…
u'vgg_16/fc6/weights:0',
u'vgg_16/fc6/biases:0',
u'vgg_16/fc7/weights:0',
u'vgg_16/fc7/biases:0']
Update 2
Exemple complet qui s'exécute sans erreurs (sur mon système).
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.nets as nets
s = tf.Session(config=tf.ConfigProto(gpu_options={'allow_growth':True}))
images = tf.placeholder(tf.float32, [None, 224, 224, 3])
predictions = nets.vgg.vgg_16(images, 200)
variables_to_restore = slim.get_variables_to_restore(exclude=['vgg_16/fc8'])
init_assign_op, init_feed_dict = slim.assign_from_checkpoint('./vgg16.ckpt', variables_to_restore)
s.run(init_assign_op, init_feed_dict)
Dans l'exemple ci-dessus vgg16.ckpt
est un point de contrôle sauvé par tf.train.Saver
pour 1000 cours VGG16 modèle.
L'utilisation de ce point de contrôle avec toutes les variables du modèle 200 classes (y compris fc8) donne l'erreur suivante:
init_assign_op, init_feed_dict = slim.assign_from_checkpoint('./vgg16.ckpt', slim.get_variables_to_restore())
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
1 init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
----> 2 './vgg16.ckpt', slim.get_variables_to_restore())
/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/framework/python/ops/variables.pyc in assign_from_checkpoint(model_path, var_list)
527 assign_ops.append(var.assign(placeholder_value))
528
--> 529 feed_dict[placeholder_value] = var_value.reshape(var.get_shape())
530
531 assign_op = control_flow_ops.group(*assign_ops)
ValueError: total size of new array must be unchanged
J'ai essayé cette méthode déjà. Cela me donne toujours la même erreur: 'InvalidArgumentError (voir ci-dessus pour la traçabilité): Assign nécessite des formes des deux tenseurs pour correspondre.lhs shape = [1,1,4096,200] forme rhs = [1,1,4096,1000] \t [[Noeud: save_1/Assign_32 = Attribuer [T = DT_FLOAT, _class = ["loc: @ vgg_16/fc8/poids "], use_locking = true, validate_shape = true, _device ="/travail: hôte local/réplique: 0/tâche: 0/gpu: 0 "] (vgg_16/fc8/poids, save_1/restore_slice_32/_3)]]' – user1050648
S'il vous plaît voir une réponse mise à jour –
Salut, merci. Cela semble faire l'affaire, tant que le 'num_classes' est compatible avec VGG16. Si vous initialisez une instance de 'vgg_16' en utilisant dis 200, au lieu de 1000, classes, l'erreur apparaît toujours. – user1050648