2017-08-15 5 views
1

J'ai des doutes sur la façon dont tf.train.string_input_producer fonctionne. Supposons donc que j'ai fourni filename_list en tant que paramètre d'entrée au string_input_producer. Ensuite, selon la documentation https://www.tensorflow.org/programmers_guide/reading_data, cela va créer un FIFOQueue, où je peux définir le numéro d'époque, mélanger les noms de fichiers et ainsi de suite. Par conséquent, dans mon cas, j'ai 4 noms de fichiers ("db1.tfrecords", "db2.tfrecords" ...). Et j'ai utilisé tf.train.batch pour alimenter le lot réseau d'images. De plus, chaque nom de fichier/base de données contient un ensemble d'images pour une personne. La deuxième base de données est pour la deuxième personne et ainsi de suite. Jusqu'à présent, j'ai le code suivant:Détermination du numéro d'époque avec tf.train.string_input_producer en tensorflow

tfrecords_filename_seq = [(common + "P16_db.tfrecords"), (common + "P17_db.tfrecords"), (common + "P19_db.tfrecords"), 
          (common + "P21_db.tfrecords")] 

filename_queue = tf.train.string_input_producer(tfrecords_filename_seq, num_epochs=num_epoch, shuffle=False, name='queue') 
reader = tf.TFRecordReader() 

key, serialized_example = reader.read(filename_queue) 
features = tf.parse_single_example(
    serialized_example, 
    # Defaults are not specified since both keys are required. 
    features={ 
     'height': tf.FixedLenFeature([], tf.int64), 
     'width': tf.FixedLenFeature([], tf.int64), 
     'image_raw': tf.FixedLenFeature([], tf.string), 
     'annotation_raw': tf.FixedLenFeature([], tf.string) 
    }) 

image = tf.decode_raw(features['image_raw'], tf.uint8) 
height = tf.cast(features['height'], tf.int32) 
width = tf.cast(features['width'], tf.int32) 

image = tf.reshape(image, [height, width, 3]) 

annotation = tf.cast(features['annotation_raw'], tf.string) 

min_after_dequeue = 100 
num_threads = 4 
capacity = min_after_dequeue + num_threads * batch_size 
label_batch, images_batch = tf.train.batch([annotation, image], 
                 shapes=[[], [112, 112, 3]], 
                 batch_size=batch_size, 
                 capacity=capacity, 
                 num_threads=num_threads) 

Enfin, en essayant de voir l'image reconstruite à la sortie de la autoencoder, je me suis le premier des images de la 1ère base de données, puis je commencer à voir les images de la deuxième base de données et ainsi de suite.

Ma question: Comment puis-je savoir si je suis dans la même époque? Et si je suis dans la bonne époque, comment puis-je fusionner un lot d'images de tous les noms de fichiers que j'ai?

Enfin, j'ai essayé d'imprimer la valeur de l'époque en évaluant la variable locale dans le Session comme suit:

epoch_var = tf.local_variables()[0] 

Puis:

with tf.Session() as sess: 
    print(sess.run(epoch_var.eval())) # Here I got 9 as output. don't know y. 

Toute aide est très appréciée !!

+0

Vous pouvez compter le nombre d'enregistrements en utilisant 'tf.python_io.tf_record_iterator' et étant donné la taille du lot, vous devriez obtenir le numéro d'époque actuel. Je n'ai pas compris votre deuxième question. –

+0

@vijaym, ce n'est pas ce que je vous demande. J'ai 'tf.train.string_input_producer' et pas' tf.python_io.tf_record_iterator'. –

Répondre

0

Donc, ce que j'ai compris, c'est que l'utilisation de tf.train.shuffle_batch_join résout mon problème car il commence à mélanger des images provenant de différents ensembles de données. En d'autres termes, chaque lot contient maintenant des images de tous les jeux de données/noms_fichiers. Voici un exemple:

def read_my_file_format(filename_queue): 
    reader = tf.TFRecordReader() 
    key, serialized_example = reader.read(filename_queue) 
    features = tf.parse_single_example(
     serialized_example, 
     # Defaults are not specified since both keys are required. 
     features={ 
      'height': tf.FixedLenFeature([], tf.int64), 
      'width': tf.FixedLenFeature([], tf.int64), 
      'image_raw': tf.FixedLenFeature([], tf.string), 
      'annotation_raw': tf.FixedLenFeature([], tf.string) 
     }) 

    # This is how we create one example, that is, extract one example from the database. 
    image = tf.decode_raw(features['image_raw'], tf.uint8) 
    # The height and the weights are used to 
    height = tf.cast(features['height'], tf.int32) 
    width = tf.cast(features['width'], tf.int32) 

    # The image is reshaped since when stored as a binary format, it is flattened. Therefore, we need the 
    # height and the weight to restore the original image back. 
    image = tf.reshape(image, [height, width, 3]) 

    annotation = tf.cast(features['annotation_raw'], tf.string) 
    return annotation, image 

def input_pipeline(filenames, batch_size, num_threads, num_epochs=None): 
    filename_queue = tf.train.string_input_producer(filenames, num_epochs=num_epoch, shuffle=False, 
                name='queue') 
    # Therefore, Note that here we have created num_threads readers to read from the filename_queue. 
    example_list = [read_my_file_format(filename_queue=filename_queue) for _ in range(num_threads)] 
    min_after_dequeue = 100 
    capacity = min_after_dequeue + num_threads * batch_size 
    label_batch, images_batch = tf.train.shuffle_batch_join(example_list, 
                  shapes=[[], [112, 112, 3]], 
                  batch_size=batch_size, 
                  capacity=capacity, 
                  min_after_dequeue=min_after_dequeue) 
    return label_batch, images_batch, example_list 

label_batch, images_batch, input_ann_img = \ 
    input_pipeline(tfrecords_filename_seq, batch_size, num_threads, num_epochs=num_epoch) 

Et maintenant, cela va créer un certain nombre de lecteurs à lire à partir du FIFOQueue, et après chaque lecteur aura un décodeur différent. Enfin, après décodage des images, elles seront introduites dans un autre Queue créé après avoir appelé tf.train.shuffle_batch_join pour alimenter le réseau en lots d'images.