2016-10-31 2 views
5

(1) J'essaie d'affiner un réseau VGG-16 en utilisant TFSlim en chargeant des poids pré-chargés dans tous les calques sauf le calque fc8. J'ai réalisé cela en utilisant la fonction TF-SLiM comme suit:TFSlim - Problèmes de chargement du point de contrôle sauvegardé pour VGG16

import tensorflow as tf 
import tensorflow.contrib.slim as slim 
import tensorflow.contrib.slim.nets as nets 

vgg = nets.vgg 

# Specify where the Model, trained on ImageNet, was saved. 
model_path = 'path/to/vgg_16.ckpt' 

# Specify where the new model will live: 
log_dir = 'path/to/log/' 

images = tf.placeholder(tf.float32, [None, 224, 224, 3]) 
predictions = vgg.vgg_16(images) 

variables_to_restore = slim.get_variables_to_restore(exclude=['fc8']) 
restorer = tf.train.Saver(variables_to_restore) 




init = tf.initialize_all_variables() 

with tf.Session() as sess: 
    sess.run(init) 
    restorer.restore(sess,model_path) 
    print "model restored" 

Cela fonctionne très bien tant que je ne change pas le num_classes pour le modèle VGG16. Ce que je voudrais faire est de changer le num_classes de 1000 à 200. J'avais l'impression que si je faisais cette modification en définissant une nouvelle classe vgg16-modified qui remplace le fc8 pour produire 200 sorties, (avec un variables_to_restore = slim.get_variables_to_restore(exclude=['fc8']) que tout sera .? bien et dandy Cependant, tensorflow se plaint d'un décalage de dimensions:

InvalidArgumentError (see above for traceback): Assign requires shapes of both tensors to match. lhs shape= [1,1,4096,200] rhs shape= [1,1,4096,1000] 

alors, comment fait-on vraiment faire cela la documentation pour TFSlim est vraiment inégale et il y a plusieurs versions dispersées sur Github - donc pas Obtenir beaucoup d'aide là

Répondre

7

Vous pouvez essayer d'utiliser le mode de restauration de slim - slim.assign_from_checkpoint

Il y a des documents connexes dans les sources minces: https://github.com/tensorflow/tensorflow/blob/129665119ea60640f7ed921f36db9b5c23455224/tensorflow/contrib/slim/python/slim/learning.py

partie correspondante:

************************************************* 
* Fine-Tuning Part of a model from a checkpoint * 
************************************************* 
Rather than initializing all of the weights of a given model, we sometimes 
only want to restore some of the weights from a checkpoint. To do this, one 
need only filter those variables to initialize as follows: 
    ... 
    # Create the train_op 
    train_op = slim.learning.create_train_op(total_loss, optimizer) 
    checkpoint_path = '/path/to/old_model_checkpoint' 
    # Specify the variables to restore via a list of inclusion or exclusion 
    # patterns: 
    variables_to_restore = slim.get_variables_to_restore(
     include=["conv"], exclude=["fc8", "fc9]) 
    # or 
    variables_to_restore = slim.get_variables_to_restore(exclude=["conv"]) 
    init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
     checkpoint_path, variables_to_restore) 
    # Create an initial assignment function. 
    def InitAssignFn(sess): 
     sess.run(init_assign_op, init_feed_dict) 
    # Run training. 
    slim.learning.train(train_op, my_log_dir, init_fn=InitAssignFn) 

Mise à jour

J'ai essayé les éléments suivants:

import tensorflow as tf 
import tensorflow.contrib.slim as slim 
import tensorflow.contrib.slim.nets as nets 
images = tf.placeholder(tf.float32, [None, 224, 224, 3]) 
predictions = nets.vgg.vgg_16(images) 
print [v.name for v in slim.get_variables_to_restore(exclude=['fc8']) ] 

Et a obtenu cette sortie (raccourci):

[u'vgg_16/conv1/conv1_1/weights:0', 
u'vgg_16/conv1/conv1_1/biases:0', 
… 
u'vgg_16/fc6/weights:0', 
u'vgg_16/fc6/biases:0', 
u'vgg_16/fc7/weights:0', 
u'vgg_16/fc7/biases:0', 
u'vgg_16/fc8/weights:0', 
u'vgg_16/fc8/biases:0'] 

il semble donc que vous devez préfixer portée avec vgg_16:

print [v.name for v in slim.get_variables_to_restore(exclude=['vgg_16/fc8']) ] 

DONNE (raccourcies):

[u'vgg_16/conv1/conv1_1/weights:0', 
u'vgg_16/conv1/conv1_1/biases:0', 
… 
u'vgg_16/fc6/weights:0', 
u'vgg_16/fc6/biases:0', 
u'vgg_16/fc7/weights:0', 
u'vgg_16/fc7/biases:0'] 

Update 2

Exemple complet qui s'exécute sans erreurs (sur mon système).

import tensorflow as tf 
import tensorflow.contrib.slim as slim 
import tensorflow.contrib.slim.nets as nets 

s = tf.Session(config=tf.ConfigProto(gpu_options={'allow_growth':True})) 

images = tf.placeholder(tf.float32, [None, 224, 224, 3]) 
predictions = nets.vgg.vgg_16(images, 200) 
variables_to_restore = slim.get_variables_to_restore(exclude=['vgg_16/fc8']) 
init_assign_op, init_feed_dict = slim.assign_from_checkpoint('./vgg16.ckpt', variables_to_restore) 
s.run(init_assign_op, init_feed_dict) 

Dans l'exemple ci-dessus vgg16.ckpt est un point de contrôle sauvé par tf.train.Saver pour 1000 cours VGG16 modèle.

L'utilisation de ce point de contrôle avec toutes les variables du modèle 200 classes (y compris fc8) donne l'erreur suivante:

init_assign_op, init_feed_dict = slim.assign_from_checkpoint('./vgg16.ckpt', slim.get_variables_to_restore()) 
--------------------------------------------------------------------------- 
ValueError        Traceback (most recent call last) 
     1 init_assign_op, init_feed_dict = slim.assign_from_checkpoint(
----> 2  './vgg16.ckpt', slim.get_variables_to_restore()) 

/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/framework/python/ops/variables.pyc in assign_from_checkpoint(model_path, var_list) 
    527  assign_ops.append(var.assign(placeholder_value)) 
    528 
--> 529  feed_dict[placeholder_value] = var_value.reshape(var.get_shape()) 
    530 
    531 assign_op = control_flow_ops.group(*assign_ops) 

ValueError: total size of new array must be unchanged 
+0

J'ai essayé cette méthode déjà. Cela me donne toujours la même erreur: 'InvalidArgumentError (voir ci-dessus pour la traçabilité): Assign nécessite des formes des deux tenseurs pour correspondre.lhs shape = [1,1,4096,200] forme rhs = [1,1,4096,1000] \t [[Noeud: save_1/Assign_32 = Attribuer [T = DT_FLOAT, _class = ["loc: @ vgg_16/fc8/poids "], use_locking = true, validate_shape = true, _device ="/travail: hôte local/réplique: 0/tâche: 0/gpu: 0 "] (vgg_16/fc8/poids, save_1/restore_slice_32/_3)]]' – user1050648

+0

S'il vous plaît voir une réponse mise à jour –

+0

Salut, merci. Cela semble faire l'affaire, tant que le 'num_classes' est compatible avec VGG16. Si vous initialisez une instance de 'vgg_16' en utilisant dis 200, au lieu de 1000, classes, l'erreur apparaît toujours. – user1050648