I understand that there are advantages (en particulier lorsque j'élargis la portée des modèles que je construis et la taille des jeux de données sur lesquels ils travaillent) pour utiliser le nouveau Dataset
de TensorFlow comme idiome pour mon pipeline d'alimentation de données. Cependant, je rencontre des difficultés pour mapper mon code existant feed_dict
sur ce nouveau modèle. Un problème auquel je suis confronté est que je n'arrive pas à déterminer comment le batching et les epochs interagissent, ou comment ils interfèrent avec la journalisation et la validation que je fais souvent. Par exemple, comment quelque chose comme la carte suivante à l'aide Dataset
?Comment puis-je convertir mon code TensorFlow basé sur le flux de base pour utiliser 'Dataset'?
# Load and process data into tensors of dimension (N, C_i) for input and (N, C_o) for output
# where N is the number of examples and C_ is the number of chanels, and the values are activations
train_x, train_y, valid_x, valid_y = load_data(file, [segments], ...)
train_size = len(train_x)
train_stats_feed = {input_activation: train_x, correct_output: train_y, is_train: False}
valid_stats_feed = {input_activation: valid_x, correct_output: valid_y, is_train: False}
with tf.Session(config=tf.ConfigProto(...)) as sess:
sess.run(tf.initialize_all_variables())
# Some analysis; not always done but the code needs to support it
train_writer.add_summary(sess.run(merged, feed_dict=train_stats_feed), 0)
test_writer.add_summary(sess.run(merged, feed_dict=valid_stats_feed), 0)
test_writer.add_summary(sess.run(gs_summary), 0)
print(log_fmt.format(0, float(sess.run(accuracy, feed_dict=valid_stats_feed)),
float(sess.run(loss, feed_dict=valid_stats_feed))))
for ep in range(epochs):
# Slice the training data into random batches
batch_indices = np.array_split(np.random.permutation(train_size), int(train_size/mb_size))
for mini_batch_indices in batch_indices:
sess.run(train_step, feed_dict={input_activation: train_x[mini_batch_indices],
correct_output: train_y[mini_batch_indices], is_train: True})
gs = int(sess.run(global_step))
if gs % log_steps == 0:
test_writer.add_summary(sess.run(merged, feed_dict=valid_stats_feed), gs)
train_writer.add_summary(sess.run(merged, feed_dict=train_stats_feed), gs)
acc = float(sess.run(accuracy, feed_dict=valid_stats_feed))
sess.run(validation_accuracy.assign(acc))
print(log_fmt.format(gs, acc, float(sess.run(loss, feed_dict=valid_stats_feed))))
print(ep_fmt.format(ep + 2))
test_writer.add_summary(sess.run(gs_summary), ep + 1)
Certaines des définitions moins évidentes pour ce qui précède, le cas échéant:
# Preliminaries
# Some basic preliminaries, the details of which are not important to the question
# Mostly pretty standard; obvious things omitted from MWE for brevity
global_step = tf.Variable(0, trainable=False, name='global_step')
validation_accuracy = tf.Variable(0.0, trainable=False, name='validation_accuracy', dtype=tf.float32)
is_train = tf.placeholder(tf.bool, [], name='is_train')
input_activation = tf.placeholder(tf.float32, shape=[None, in_nodes], name='inputs')
correct_output = tf.placeholder(tf.float32, shape=[None, out_nodes], name='correct_outputs')
network_output = tf.identity(out_activations)
correct_predictions = correct_fn(correct_output, network_output)
accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))
error = cost_fn(correct_output, network_output)
loss = error + FLAGS.regularization_weight * sum(tf.nn.l2_loss(w) for w in layer_weights)
train_step = tf.train.MomentumOptimizer(learning_rate, momentum=momentum).minimize(loss, global_step=global_step)
# Logging
train_writer = tf.summary.FileWriter(trainlogfile, tf.get_default_graph())
test_writer = tf.summary.FileWriter(testlogfile, tf.get_default_graph())
gs_summary = tf.summary.scalar('global_step_at_epoch', global_step)
merged = tf.summary.merge_all()
Il est pas clair pour moi comment cela correspond avec ce qui précède (par exemple, comment la formation et la validation entrelacer, où le chargement d'un fichier se produit, etc.) – orome
Cela génère également toutes sortes d'erreurs . – orome
Je pense que vous devez avoir une compréhension de base avant de poser des questions. –