2016-05-10 1 views
5

Si je tente d'importer une enregistrée définition TensorFlow graphique avecComment puis-je obtenir « la import_graph_def » pour revenir tenseurs de tensorflow

import tensorflow as tf 
from tensorflow.python.platform import gfile 

with gfile.FastGFile(FLAGS.model_save_dir.format(log_id) + '/graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef() 
    graph_def.ParseFromString(f.read()) 
x, y, y_ = tf.import_graph_def(graph_def, 
           return_elements=['data/inputs', 
               'output/network_activation', 
               'data/correct_outputs'], 
           name='') 

les valeurs renvoyées ne sont pas Tensor s comme prévu, mais autre chose: à la place, par exemple , d'obtenir x comme

Tensor("data/inputs:0", shape=(?, 784), dtype=float32) 

Je reçois

name: "data/inputs_1" 
op: "Placeholder" 
attr { 
    key: "dtype" 
    value { 
    type: DT_FLOAT 
    } 
} 
attr { 
    key: "shape" 
    value { 
    shape { 
    } 
    } 
} 

C'est, au lieu d'obtenir le tenseur attendu x je reçois, x.op. Cela me confond parce que le documentation semble dire que je devrais obtenir un Tensor (bien qu'il y ait un tas de ou s là qui le rendent difficile à comprendre).

Comment puis-je obtenir tf.import_graph_def pour renvoyer des Tensor s spécifiques que je peux ensuite utiliser (par exemple pour alimenter le modèle chargé ou exécuter des analyses)?

+0

La deuxième ligne de code doit être 'partir tensorflow.python.platform importer gfile'. – tobe

Répondre

3

Les noms 'data/inputs', 'output/network_activation' et 'data/correct_outputs' sont en fait des noms d'opération. Pour obtenir tf.import_graph_def() revenir tf.Tensor objets, vous devez ajouter l'index de sortie au nom de l'opération, ce qui est généralement ':0' pour les opérations de sortie unique:

x, y, y_ = tf.import_graph_def(graph_def, 
           return_elements=['data/inputs:0', 
               'output/network_activation:0', 
               'data/correct_outputs:0'], 
           name='')