2016-11-22 1 views
0

Je suis le blog wildml sur la classification de texte en utilisant tensorflow. J'ai changé le code pour sauvegarder le graphique def comme suit:Erreur Tensorflow lors de la restauration du graphique def du fichier .pb

tf.train.write_graph(sess.graph_def,'./DeepLearn/model/','train.pb', as_text=False) 

plus tard dans un fichier séparé que je restaure le graphique comme suit:

with tf.gfile.FastGFile(os.path.join('./DeepLearn/model/','train.pb'), 'rb') as f: 
    graph_def = tf.GraphDef() 
    graph_def.ParseFromString(f.read()) 
    _ = tf.import_graph_def(graph_def, name='') 
with tf.Session() as sess: 
    t = sess.graph.get_tensor_by_name('embedding/W:0') 
    sess.run(t) 

Quand je tente d'exécuter le tenseur et obtenir sa valeur , j'obtiens l'erreur suivante:

tensorflow.python.framework.errors.FailedPreconditionError: Attempting to use uninitialized value embedding/W 

Quelle pourrait être la raison possible de cette erreur. Le tenseur doit avoir été initialisé car je le restaure à partir du graphe sauvegardé.

+0

'sess.run (tf.initialize_all_variables())'? – sygi

+0

Mais, je charge le tenseur du graphique précédemment sauvegardé, donc je ne pense pas que je dois l'initialiser en utilisant cette déclaration. – Nitin

+1

Vous devez toujours initialiser les variables, car la lecture de graphdef ne restaure que le graphique lui-même, pas les valeurs des variables. Si vous souhaitez restaurer les valeurs des variables, vous devez charger à partir d'un point de contrôle. –

Répondre

0

Merci Alexandre! Oui, je dois charger à la fois le graphique (à partir du fichier .pb) et les poids (à partir du fichier de points de contrôle.). Utilisé l'exemple de code suivant (tiré d'un blog) et cela a fonctionné pour moi.

with tf.Session() as persisted_sess: 
    print("load graph") 
    with gfile.FastGFile("/tmp/load/test.pb",'rb') as f: 
     graph_def = tf.GraphDef() 
     graph_def.ParseFromString(f.read()) 
     persisted_sess.graph.as_default() 
     tf.import_graph_def(graph_def, name='') 
    persisted_result = persisted_sess.graph.get_tensor_by_name("saved_result:0") 
    tf.add_to_collection(tf.GraphKeys.VARIABLES,persisted_result) 
    try: 
     saver = tf.train.Saver(tf.all_variables()) 
    except:pass 
     print("load data") 
    saver.restore(persisted_sess, "checkpoint.data") # now OK 
    print(persisted_result.eval()) 
    print("DONE")