2016-06-19 1 views
0

J'ai un exemple de réseau récurrent simple, avec les variables tf.Saver et weight, bias et state en cours d'enregistrement.Accès aux valeurs d'une variable Tensorflow restaurée

Lorsque l'exemple est exécuté sans option, il initialise le vecteur d'état pour contenir des zéros, mais je veux passer une option load_model et d'utiliser les dernières valeurs du vecteur d'état comme un aliment pour l'invocation session.run .

Toute la documentation que je vois insiste sur le fait que l'on doit appeler session.run pour récupérer des valeurs stockées à partir de variables, mais dans ce cas je veux récupérer les valeurs pour pouvoir initialiser la variable d'état. Ai-je besoin de faire un graphique séparé juste pour récupérer les valeurs d'initialisation?

code exemple ci-dessous:

import tensorflow as tf 
import math 
import numpy as np 

INPUTS = 10 
HIDDEN_1 = 2 
BATCH_SIZE = 3 

def batch_vm2(m, x): 
    [input_size, output_size] = m.get_shape().as_list() 

    input_shape = tf.shape(x) 
    batch_rank = input_shape.get_shape()[0].value - 1 
    batch_shape = input_shape[:batch_rank] 
    output_shape = tf.concat(0, [batch_shape, [output_size]]) 

    x = tf.reshape(x, [-1, input_size]) 
    y = tf.matmul(x, m) 

    y = tf.reshape(y, output_shape) 

    return y 

def get_weight_and_biases(): 
    with tf.variable_scope(network_scope, reuse = True) as scope: 
     weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0/math.sqrt(float(INPUTS)))) 
     biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0)) 
    return weights, biases 

def get_saver(): 
    with tf.variable_scope('h1') as scope: 
     weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0/math.sqrt(float(INPUTS)))) 
     biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0)) 
     state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False) 
     saver = tf.train.Saver([weights, biases, state]) 
    return saver, scope 


def load(sess, saver, checkpoint_dir = './'): 

     print("loading a session") 
     ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 
     if ckpt and ckpt.model_checkpoint_path: 
      saver.restore(sess, ckpt.model_checkpoint_path) 
     else: 
      raise Exception("no checkpoint found") 
     return 

iteration = None 

def iterate_state(prev_state_tuple, input): 
    with tf.variable_scope(network_scope, reuse = True) as scope: 
     weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0/math.sqrt(float(INPUTS)))) 
     biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0)) 
     state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False) 
     print("input: ",input.get_shape()) 
     matmuladd = batch_vm2(weights, input) + biases 
     matmulpri = tf.Print(matmuladd,[matmuladd, weights], message=" malmul -> %i, weights " % iteration) 
     print("prev state: ",prev_state_tuple.get_shape()) 
     unpacked_state, unpacked_out = tf.split(0,2,prev_state_tuple) 
     prev_state = 0.99* unpacked_state 
     prev_state = tf.Print(prev_state, [unpacked_state, matmuladd], message=" -> prevstate, matmulpri ") 
     state = state.assign(prev_state + 0.01*matmulpri) 
     #output = tf.nn.relu(state) 
     output = tf.nn.tanh(state) 
     state = tf.Print(state, [state], message=" state -> ") 
     output = tf.Print(output, [output], message=" output -> ") 
     print(" state: ", state.get_shape()) 
     print(" output: ", output.get_shape()) 
     concat_result = tf.concat(0,[state, output]) 
     print (" concat return: ", concat_result.get_shape()) 
     return concat_result 

def data_iter(): 
    while True: 
     idxs = np.random.rand(BATCH_SIZE, INPUTS) 
     yield idxs 

flags = tf.app.flags 
FLAGS = flags.FLAGS 
flags.DEFINE_boolean('load_model', False, 'If true, uses model files ' 
        'to restore.') 


network_scope = None 

with tf.Graph().as_default(): 
    inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS)) 
    iteration = -1 
    saver, network_scope = get_saver() 
    initial_state = tf.placeholder(tf.float32, shape=(HIDDEN_1)) 
    initial_out = tf.zeros([HIDDEN_1], 
          name='initial_out') 
    concat_tensor = tf.concat(0,[initial_state, initial_out]) 
    print(" init state: ",initial_state.get_shape()) 
    print(" init out: ",initial_out.get_shape()) 
    print(" concat: ",concat_tensor.get_shape()) 
    scanout = tf.scan(iterate_state, inputs, initializer=concat_tensor, name='state_scan') 
    print ("scanout shape: ", scanout.get_shape()) 
    state, output = tf.split(1,2,scanout, name='split_scan_output') 
    print(" end state: ",state.get_shape()) 
    print(" end out: ",output.get_shape()) 


    sess = tf.Session() 
    # Run the Op to initialize the variables. 

    sess.run(tf.initialize_all_variables()) 
    tf.train.write_graph(sess.graph_def, './tenIrisSave/logsd','graph.pbtxt') 
    tf_weight, tf_bias = get_weight_and_biases() 
    tf.histogram_summary('weights', tf_weight) 
    tf.histogram_summary('bias', tf_bias) 
    tf.histogram_summary('state', state) 
    tf.histogram_summary('out', output) 
    summary_op = tf.merge_all_summaries() 
    summary_writer = tf.train.SummaryWriter('./tenIrisSave/summary',sess.graph_def) 
    if FLAGS.load_model: 
     load(sess, saver) 
     # HOW DO I LOAD restored state values?????? 
     #st = state[BATCH_SIZE - 1,:] 
     #st = sess.run([state], feed_dict={}) 
     print("LOADED last state vec: ", st) 
    else: 
     st = np.array([0.0 , 0.0]) 
    iter_ = data_iter() 
    for i in xrange(0, 1): 
     print ("iteration: ",i) 
     iteration = i 
     input_data = iter_.next() 
     out,st,so,summary_str = sess.run([output,state,scanout,summary_op], feed_dict={ inputs: input_data, initial_state: st }) 
     saver.save(sess, 'my-model', global_step=1+i) 
     summary_writer.add_summary(summary_str, i) 
     summary_writer.flush() 
     print("input vec: ", input_data) 
     print("state vec: ", st) 
     st = st[-1] 
     print("last state vec: ", st) 
     print("output vec: ", out) 
     print(" end state (runtime): ",st.shape) 
     print(" end out (runtime): ",out.shape) 
     print(" end scanout (runtime): ",so.shape) 

note à lignes 124-126 les lignes commentées pour les moyens que j'ai essayé d'initialiser les valeurs de dictionnaire d'alimentation. Aucun d'entre eux ne travaille.

Répondre

1

Vous avez deux espaces réservés:

  • inputs
  • initial_state

D'après ce que je comprends que vous voulez soit (selon FLAGS.load_model):

  1. Utiliser une première état plein de zéros

    • est simple, vous nourrissez juste un tableau numpy plein de zéros
  2. Utilisez la dernière ligne state, qui est un Tensor dans le graphique en fonction de les deux espaces réservés.

    • vous voulez juste charger la valeur à partir d'un point de contrôle précédent

Cette analyse fait, ma première hypothèse est que l'erreur vient juste du fait que vous utilisez un autre tenseur nommé state dans la ligne:

state, output = tf.split(1,2,scanout, name='split_scan_output') 

Alors tensorflow va essayer de récupérer ce state, qui dépend des deux espaces réservés, au lieu de récupérer la valeur de la variable state que vous voulez. Renommez simplement le second et cela pourrait fonctionner.

Vous pouvez essayer:

if FLAGS.load_model: 
    load(sess, saver) 
    with tf.variable_scope('h1', reuse=True) 
     state_saved = tf.get_variable('state') 
    st = sess.run(state_saved) 
+0

1) oui c'est ce qui se fait sur la clause 'else' à la ligne 129. Si vous exécutez le script sans option' load_model' vous verrez ce qu'il est Faire. Il n'y a qu'une variable 'state', il n'y a pas d'autre variable dans la même portée.Notez la ligne 75: 'concat_result = tf.concat (0, [état, sortie])' dans 'iterate_state'. – diffeomorphism

+0

Pouvez-vous essayer le code fourni lors de mon dernier montage? Ce que je voulais dire, c'est que vous redéfinissez 'state', donc quand vous l'exécutez avec' sess.run (state) ', vous obtenez le mauvais tenseur –

+0

cool, ça marche! – diffeomorphism