2017-10-04 2 views
1

J'ai chargé un modèle pré-entraîné (Model 1) en utilisant le code suivant:Comment ne pas réinitialiser le modèle pré-chargé dans Tensorflow?

def load_seq2seq_model(sess): 


    with open(os.path.join(seq2seq_config_dir_path, 'config.pkl'), 'rb') as f: 
     saved_args = pickle.load(f) 

    # Initialize the model with saved args 
    model = Model1(saved_args) 

    #Inititalize Tensorflow saver 
    saver = tf.train.Saver() 

    # Checkpoint 
    ckpt = tf.train.get_checkpoint_state(seq2seq_config_dir_path) 
    print('Loading model: ', ckpt.model_checkpoint_path) 

    # Restore the model at the checkpoint 
    saver.restore(sess, ckpt.model_checkpoint_path) 
    return model 

Maintenant, je veux former un autre modèle (Model 2) à partir de zéro qui prendra la sortie du Model 1. Mais pour cela j'ai besoin de définir une session et charger le modèle pré-formé et initialiser le modèle tf.initialize_all_variables(). Ainsi, le modèle pré-entraîné sera également initialisé.

Quelqu'un peut-il s'il vous plaît dites-moi comment former le Model 2 après avoir obtenu la sortie du modèle pré-formé Model 1 correctement?

Ce que je suis en train est donné ci-dessous -

with tf.Session() as sess: 
    # Initialize all the variables of the graph 
    seq2seq_model = load_seq2seq_model(sess) 
    sess.run(tf.initialize_all_variables()) 
    .... Rest of the training code goes here.... 
+0

Avez-vous essayé d'initialiser AVANT d'importer le modèle 1? – Pop

+0

Je ne connais pas la procédure exacte. J'ai essayé ça. Cela fonctionne aussi. Mais si quelqu'un pouvait me dire la procédure correcte, je peux être certain. –

Répondre

0

Toutes les variables qui sont restaurées à l'aide d'un économiseur ne pas besoin d'être initialisé. Par conséquent, au lieu d'utiliser tf.initialize_all_variables(), vous pouvez utiliser tf.variables_initializer(var_list) pour initialiser uniquement les poids du second réseau.

Pour obtenir une liste de tous les poids du second réseau, vous pouvez créer le réseau Model 2 dans un périmètre variable:

with tf.variable_scope("model2"): 
    model2 = Model2(...) 

Ensuite, utilisez

model_2_variables_list = tf.get_collection(
    tf.GraphKeys.GLOBAL_VARIABLES, 
    scope="model2" 
) 

pour obtenir la liste des variables de la Model 2 réseau. Enfin, vous pouvez créer l'initialisier pour le deuxième réseau:

init2 = tf.variables_initializer(model_2_variables_list) 

with tf.Session() as sess: 
    # Initialize all the variables of the graph 
    seq2seq_model = load_seq2seq_model(sess) 
    sess.run(init2) 
    .... Rest of the training code goes here.... 
+0

NotFoundError arrive bientôt. Semble, Il essaie de charger le model2 aussi. –

+0

Changer l'ordre de chargement résolu le problème, je suppose! –

+0

@AvijitDasgupta Le NotFoundError apparaît-il lors de l'exécution de 'saver.restore (sess, ckpt.model_checkpoint_path)'? Si oui, assurez-vous d'initialiser l'économiseur ('saver = tf.train.Saver()') * avant * d'initialiser le 2ème réseau (donc après 'model = Model1 (...)' mais avant 'model = Model2 (...) '). Sinon, l'économiseur essaiera de charger les deux réseaux, ce qui entraînera NotFoundError. – BlueSun