2017-10-11 2 views
0

J'essaie d'extraire tous les poids/biais à partir d'un modèle enregistré output_graph.pb.tf.GraphKeys.TRAINABLE_VARIABLES sur output_graph.pb entraînant la liste vide

Je lis le modèle:

def create_graph(modelFullPath): 
    """Creates a graph from saved GraphDef file and returns a saver.""" 
    # Creates graph from saved graph_def.pb. 
    with tf.gfile.FastGFile(modelFullPath, 'rb') as f: 
     graph_def = tf.GraphDef() 
     graph_def.ParseFromString(f.read()) 
     tf.import_graph_def(graph_def, name='') 

GRAPH_DIR = r'C:\tmp\output_graph.pb' 
create_graph(GRAPH_DIR) 

Et en espérant que je tenté cette serais en mesure d'extraire tous les poids/biaise dans chaque couche.

with tf.Session() as sess: 
    all_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES) 
    print (len(all_vars)) 

Cependant, j'obtiens une valeur de 0 en tant que len.

L'objectif final est d'extraire les poids et les biais et de l'enregistrer dans un fichier texte/np.arrays.

Répondre

1

La fonction tf.import_graph_def() ne dispose pas d'informations suffisantes pour reconstruire la collection tf.GraphKeys.TRAINABLE_VARIABLES (pour cela, vous devez utiliser un MetaGraphDef). Toutefois, si output.pb contient un "gelé" GraphDef, alors tous les poids seront stockés dans tf.constant() nœuds dans le graphique. Pour les extraire, vous pouvez faire quelque chose comme ce qui suit:

create_graph(GRAPH_DIR) 

constant_values = {} 

with tf.Session() as sess: 
    constant_ops = [op for op in sess.graph.get_operations() if op.type == "Const"] 
    for constant_op in constant_ops: 
    constant_values[constant_op.name] = sess.run(constant_op.outputs[0]) 

Notez que constant_values contiendra probablement plus de valeurs que seulement les poids, vous devrez peut-être filtrer davantage par op.name ou un autre critère.