J'ai un exemple de réseau récurrent simple, avec les variables tf.Saver
et weight
, bias
et state
en cours d'enregistrement.Accès aux valeurs d'une variable Tensorflow restaurée
Lorsque l'exemple est exécuté sans option, il initialise le vecteur d'état pour contenir des zéros, mais je veux passer une option load_model
et d'utiliser les dernières valeurs du vecteur d'état comme un aliment pour l'invocation session.run
.
Toute la documentation que je vois insiste sur le fait que l'on doit appeler session.run
pour récupérer des valeurs stockées à partir de variables, mais dans ce cas je veux récupérer les valeurs pour pouvoir initialiser la variable d'état. Ai-je besoin de faire un graphique séparé juste pour récupérer les valeurs d'initialisation?
code exemple ci-dessous:
import tensorflow as tf
import math
import numpy as np
INPUTS = 10
HIDDEN_1 = 2
BATCH_SIZE = 3
def batch_vm2(m, x):
[input_size, output_size] = m.get_shape().as_list()
input_shape = tf.shape(x)
batch_rank = input_shape.get_shape()[0].value - 1
batch_shape = input_shape[:batch_rank]
output_shape = tf.concat(0, [batch_shape, [output_size]])
x = tf.reshape(x, [-1, input_size])
y = tf.matmul(x, m)
y = tf.reshape(y, output_shape)
return y
def get_weight_and_biases():
with tf.variable_scope(network_scope, reuse = True) as scope:
weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0/math.sqrt(float(INPUTS))))
biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
return weights, biases
def get_saver():
with tf.variable_scope('h1') as scope:
weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0/math.sqrt(float(INPUTS))))
biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False)
saver = tf.train.Saver([weights, biases, state])
return saver, scope
def load(sess, saver, checkpoint_dir = './'):
print("loading a session")
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
raise Exception("no checkpoint found")
return
iteration = None
def iterate_state(prev_state_tuple, input):
with tf.variable_scope(network_scope, reuse = True) as scope:
weights = tf.get_variable('W', shape=[INPUTS, HIDDEN_1], initializer=tf.truncated_normal_initializer(stddev=1.0/math.sqrt(float(INPUTS))))
biases = tf.get_variable('bias', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0))
state = tf.get_variable('state', shape=[HIDDEN_1], initializer=tf.constant_initializer(0.0), trainable=False)
print("input: ",input.get_shape())
matmuladd = batch_vm2(weights, input) + biases
matmulpri = tf.Print(matmuladd,[matmuladd, weights], message=" malmul -> %i, weights " % iteration)
print("prev state: ",prev_state_tuple.get_shape())
unpacked_state, unpacked_out = tf.split(0,2,prev_state_tuple)
prev_state = 0.99* unpacked_state
prev_state = tf.Print(prev_state, [unpacked_state, matmuladd], message=" -> prevstate, matmulpri ")
state = state.assign(prev_state + 0.01*matmulpri)
#output = tf.nn.relu(state)
output = tf.nn.tanh(state)
state = tf.Print(state, [state], message=" state -> ")
output = tf.Print(output, [output], message=" output -> ")
print(" state: ", state.get_shape())
print(" output: ", output.get_shape())
concat_result = tf.concat(0,[state, output])
print (" concat return: ", concat_result.get_shape())
return concat_result
def data_iter():
while True:
idxs = np.random.rand(BATCH_SIZE, INPUTS)
yield idxs
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_boolean('load_model', False, 'If true, uses model files '
'to restore.')
network_scope = None
with tf.Graph().as_default():
inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS))
iteration = -1
saver, network_scope = get_saver()
initial_state = tf.placeholder(tf.float32, shape=(HIDDEN_1))
initial_out = tf.zeros([HIDDEN_1],
name='initial_out')
concat_tensor = tf.concat(0,[initial_state, initial_out])
print(" init state: ",initial_state.get_shape())
print(" init out: ",initial_out.get_shape())
print(" concat: ",concat_tensor.get_shape())
scanout = tf.scan(iterate_state, inputs, initializer=concat_tensor, name='state_scan')
print ("scanout shape: ", scanout.get_shape())
state, output = tf.split(1,2,scanout, name='split_scan_output')
print(" end state: ",state.get_shape())
print(" end out: ",output.get_shape())
sess = tf.Session()
# Run the Op to initialize the variables.
sess.run(tf.initialize_all_variables())
tf.train.write_graph(sess.graph_def, './tenIrisSave/logsd','graph.pbtxt')
tf_weight, tf_bias = get_weight_and_biases()
tf.histogram_summary('weights', tf_weight)
tf.histogram_summary('bias', tf_bias)
tf.histogram_summary('state', state)
tf.histogram_summary('out', output)
summary_op = tf.merge_all_summaries()
summary_writer = tf.train.SummaryWriter('./tenIrisSave/summary',sess.graph_def)
if FLAGS.load_model:
load(sess, saver)
# HOW DO I LOAD restored state values??????
#st = state[BATCH_SIZE - 1,:]
#st = sess.run([state], feed_dict={})
print("LOADED last state vec: ", st)
else:
st = np.array([0.0 , 0.0])
iter_ = data_iter()
for i in xrange(0, 1):
print ("iteration: ",i)
iteration = i
input_data = iter_.next()
out,st,so,summary_str = sess.run([output,state,scanout,summary_op], feed_dict={ inputs: input_data, initial_state: st })
saver.save(sess, 'my-model', global_step=1+i)
summary_writer.add_summary(summary_str, i)
summary_writer.flush()
print("input vec: ", input_data)
print("state vec: ", st)
st = st[-1]
print("last state vec: ", st)
print("output vec: ", out)
print(" end state (runtime): ",st.shape)
print(" end out (runtime): ",out.shape)
print(" end scanout (runtime): ",so.shape)
note à lignes 124-126 les lignes commentées pour les moyens que j'ai essayé d'initialiser les valeurs de dictionnaire d'alimentation. Aucun d'entre eux ne travaille.
1) oui c'est ce qui se fait sur la clause 'else' à la ligne 129. Si vous exécutez le script sans option' load_model' vous verrez ce qu'il est Faire. Il n'y a qu'une variable 'state', il n'y a pas d'autre variable dans la même portée.Notez la ligne 75: 'concat_result = tf.concat (0, [état, sortie])' dans 'iterate_state'. – diffeomorphism
Pouvez-vous essayer le code fourni lors de mon dernier montage? Ce que je voulais dire, c'est que vous redéfinissez 'state', donc quand vous l'exécutez avec' sess.run (state) ', vous obtenez le mauvais tenseur –
cool, ça marche! – diffeomorphism