2017-10-12 2 views
3

J'utilise tensorflow dataset API. Et tester mon code avec un cas simple. Ci-dessous montre le code simple que j'ai utilisé. Le problème est que, lorsque la taille de l'ensemble de données est petite, il semble que la taille retournée par l'API de l'ensemble de données n'est pas cohérente. Je suis sûr qu'il y a une bonne façon d'y faire face. Mais même si j'ai lu toutes les fonctions dans cette page et tutoriel, je ne pouvais pas trouver cela.la taille retournée de l'ensemble de données de tensorflow API n'est pas constante

import numpy as np 
import tensorflow as tf 

data_source = tf.zeros([24, 200, 64, 64, 1]) #[number_of_video, steps, pixel_w, pixel_h, channel] 
dataset = tf.contrib.data.Dataset.from_tensor_slices(data_source) 
dataset = dataset.shuffle(buffer_size=100) 
dataset = dataset.batch(16) 
dataset = dataset.repeat() 

iterator = tf.contrib.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes) 
next_element = iterator.get_next() 
training_init_op = iterator.make_initializer(dataset) 

with tf.Session() as sess: 
    sess.run(training_init_op) 
    next_elem = next_element.eval() 
    print(np.shape(next_elem)) 
    next_elem = next_element.eval() 
    print(np.shape(next_elem)) 
    next_elem = next_element.eval() 
    print(np.shape(next_elem)) 
    next_elem = next_element.eval() 
    print(np.shape(next_elem)) 
    next_elem = next_element.eval() 
    print(np.shape(next_elem)) 
    next_elem = next_element.eval() 
    print(np.shape(next_elem)) 
    next_elem = next_element.eval() 
    print(np.shape(next_elem)) 

L'ensemble de données est une vidéo en niveaux de gris. Il y a une séquence de vidéo de 24 et la taille de pas est de 200. La taille de la trame est de 64 par 64 et un seul canal. Je mets la taille des lots comme 16 et taille du tampon 100. Mais le résultat du code est,

(16, 200, 64, 64, 1) 
(8, 200, 64, 64, 1) 
(16, 200, 64, 64, 1) 
(8, 200, 64, 64, 1) 
(16, 200, 64, 64, 1) 
(8, 200, 64, 64, 1) 
(16, 200, 64, 64, 1) 

La taille de retour de la vidéo est soit 16 ou 8. Je suppose que c'est parce que la taille des données d'origine est faible, 24, quand il atteint la fin des données, l'API retourne juste ce qui reste.

Mais je ne comprends pas. Je définis également la taille du tampon sur 100. Cela signifie que le tampon doit être rempli à l'avance avec un petit jeu de données. Et à partir de ce tampon, l'API doit sélectionner next_element dont la taille de lot est de 16.

Lorsque j'ai utilisé l'API de type queue dans tensorflow, je n'ai pas eu ce problème. Quelle que soit la taille des données d'origine, il existe un moment où l'itérateur atteint la fin de l'ensemble de données. Je me demande comment ce problème est résolu par d'autres personnes utilisant cette API.

Répondre

2

Essayez d'appeler repeat() avant batch():

data_source = tf.zeros([24, 200, 64, 64, 1]) #[number_of_video, steps, pixel_w, pixel_h, channel] 
dataset = tf.contrib.data.Dataset.from_tensor_slices(data_source) 
dataset = dataset.shuffle(buffer_size=100) 
dataset = dataset.repeat() 
dataset = dataset.batch(16) 

Le résultat que je reçois:

(16, 200, 64, 64, 1) 
(16, 200, 64, 64, 1) 
(16, 200, 64, 64, 1) 
(16, 200, 64, 64, 1) 
(16, 200, 64, 64, 1) 
(16, 200, 64, 64, 1) 
(16, 200, 64, 64, 1)