J'ai des doutes sur la façon dont tf.train.string_input_producer
fonctionne. Supposons donc que j'ai fourni filename_list en tant que paramètre d'entrée au string_input_producer
. Ensuite, selon la documentation https://www.tensorflow.org/programmers_guide/reading_data, cela va créer un FIFOQueue
, où je peux définir le numéro d'époque, mélanger les noms de fichiers et ainsi de suite. Par conséquent, dans mon cas, j'ai 4 noms de fichiers ("db1.tfrecords", "db2.tfrecords" ...). Et j'ai utilisé tf.train.batch
pour alimenter le lot réseau d'images. De plus, chaque nom de fichier/base de données contient un ensemble d'images pour une personne. La deuxième base de données est pour la deuxième personne et ainsi de suite. Jusqu'à présent, j'ai le code suivant:Détermination du numéro d'époque avec tf.train.string_input_producer en tensorflow
tfrecords_filename_seq = [(common + "P16_db.tfrecords"), (common + "P17_db.tfrecords"), (common + "P19_db.tfrecords"),
(common + "P21_db.tfrecords")]
filename_queue = tf.train.string_input_producer(tfrecords_filename_seq, num_epochs=num_epoch, shuffle=False, name='queue')
reader = tf.TFRecordReader()
key, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
# Defaults are not specified since both keys are required.
features={
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'image_raw': tf.FixedLenFeature([], tf.string),
'annotation_raw': tf.FixedLenFeature([], tf.string)
})
image = tf.decode_raw(features['image_raw'], tf.uint8)
height = tf.cast(features['height'], tf.int32)
width = tf.cast(features['width'], tf.int32)
image = tf.reshape(image, [height, width, 3])
annotation = tf.cast(features['annotation_raw'], tf.string)
min_after_dequeue = 100
num_threads = 4
capacity = min_after_dequeue + num_threads * batch_size
label_batch, images_batch = tf.train.batch([annotation, image],
shapes=[[], [112, 112, 3]],
batch_size=batch_size,
capacity=capacity,
num_threads=num_threads)
Enfin, en essayant de voir l'image reconstruite à la sortie de la autoencoder, je me suis le premier des images de la 1ère base de données, puis je commencer à voir les images de la deuxième base de données et ainsi de suite.
Ma question: Comment puis-je savoir si je suis dans la même époque? Et si je suis dans la bonne époque, comment puis-je fusionner un lot d'images de tous les noms de fichiers que j'ai?
Enfin, j'ai essayé d'imprimer la valeur de l'époque en évaluant la variable locale dans le Session
comme suit:
epoch_var = tf.local_variables()[0]
Puis:
with tf.Session() as sess:
print(sess.run(epoch_var.eval())) # Here I got 9 as output. don't know y.
Toute aide est très appréciée !!
Vous pouvez compter le nombre d'enregistrements en utilisant 'tf.python_io.tf_record_iterator' et étant donné la taille du lot, vous devriez obtenir le numéro d'époque actuel. Je n'ai pas compris votre deuxième question. –
@vijaym, ce n'est pas ce que je vous demande. J'ai 'tf.train.string_input_producer' et pas' tf.python_io.tf_record_iterator'. –