2017-03-01 1 views
0

Je veux utiliser CNN pour résoudre la tâche de désembrouillage, et j'ai des données d'apprentissage qui sont un répertoire d'images PNG et un fichier texte correspondant contenant le nom des fichiers. Comme les données sont trop volumineuses pour être ajoutées à la mémoire en une seule fois, et existe-t-il une API ou une méthode pour permettre de lire l'image blurissante en tant qu'entrée et sa vérité fondamentale comme résultat attendu pour former ?Comment est-ce que je pourrais lire l'image d'un répertoire comme entrée et sortie pendant que traing un modèle de CNN dans Tensorflow?

J'ai passé pas mal de temps à résoudre ce problème, mais j'ai été dérouté après avoir lu les API dans les introductions API en ligne.

+0

Vous cherchez ceci: http://stackoverflow.com/a/36947632/2505209? Utilisez l'image comme exemple ainsi que l'étiquette. – hars

Répondre

0

La méthode n'est pas si confuse. Le tensorflow fournit au fichier TFrecords un bon usage de la mémoire.

def create_cord(): 

    writer = tf.python_io.TFRecordWriter("train.tfrecords") 
    for index in xrange(66742): 
     blur_file_name = get_file_name(index, True) 
     orig_file_name = get_file_name(index, False) 
     blur_image_path = cwd + blur_file_name 
     orig_image_path = cwd + orig_file_name 

     blur_image = Image.open(blur_image_path) 
     orig_image = Image.open(orig_image_path) 

     blur_image = blur_image.resize((IMAGE_HEIGH, IMAGE_WIDTH)) 
     orig_image = orig_image.resize((IMAGE_HEIGH, IMAGE_WIDTH)) 

     blur_image_raw = blur_image.tobytes() 
     orig_image_raw = orig_image.tobytes() 
     example = tf.train.Example(features=tf.train.Features(feature={ 
     "blur_image_raw": tf.train.Feature(bytes_list=tf.train.BytesList(value=[blur_image_raw])), 
     'orig_image_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[orig_image_raw])) 
    })) 
    writer.write(example.SerializeToString()) 
    writer.close() 

pour lire l'ensemble de données:

def read_and_decode(filename): 
    filename_queue = tf.train.string_input_producer([filename]) 

    reader = tf.TFRecordReader() 
    _, serialized_example = reader.read(filename_queue) 
    features = tf.parse_single_example(serialized_example, 
            features={ 
             'blur_image_raw': tf.FixedLenFeature([], tf.string), 
             'orig_image_raw': tf.FixedLenFeature([], tf.string), 
            }) 

    blur_img = tf.decode_raw(features['blur_image_raw'], tf.uint8) 
    blur_img = tf.reshape(blur_img, [IMAGE_WIDTH, IMAGE_HEIGH, 3]) 
    blur_img = tf.cast(blur_img, tf.float32) * (1./255) - 0.5 

    orig_img = tf.decode_raw(features['blur_image_raw'], tf.uint8) 
    orig_img = tf.reshape(orig_img, [IMAGE_WIDTH, IMAGE_HEIGH, 3]) 
    orig_img = tf.cast(orig_img, tf.float32) * (1./255) - 0.5 

    return blur_img, orig_img 


if __name__ == '__main__': 

    # create_cord() 

    blur, orig = read_and_decode("train.tfrecords") 
    blur_batch, orig_batch = tf.train.shuffle_batch([blur, orig], 
               batch_size=3, capacity=1000, 
               min_after_dequeue=100) 
    init = tf.global_variables_initializer() 
    with tf.Session() as sess: 
     sess.run(init) 
    # 启动队列 
     threads = tf.train.start_queue_runners(sess=sess) 
     for i in range(3): 
      v, l = sess.run([blur_batch, orig_batch]) 
      print(v.shape, l.shape)