2017-09-30 2 views
0

J'ai mis un groupe d'entités de longueur fixe et de longueur variable dans un tf.train.SequenceExample.L'ensemble de données tensorflow API ne fonctionne pas de manière stable lorsque la taille du lot est supérieure à 1

context_features 
    length,   scalar,     tf.int64 
    site_code_raw,  scalar,     tf.string 
    Date_Local_raw, scalar,     tf.string 
    Time_Local_raw, scalar,     tf.string 
Sequence_features 
    Orig_RefPts,  [#batch, #RefPoints, 4] tf.float32 
    tgt_location,  [#batch, 3]    tf.float32 
    tgt_val   [#batch, 1]    tf.float32 

La valeur de #RefPoints est variable pour les différents exemples de séquences. Je stocke sa valeur dans la fonctionnalité length dans le context_features. Les autres caractéristiques ont des tailles fixes.

Voici le code que je utilise pour lire & analyser les données:

def read_batch_DatasetAPI(
    filenames, 
    batch_size = 20, 
    num_epochs = None, 
    buffer_size = 5000): 

    dataset = tf.contrib.data.TFRecordDataset(filenames) 
    dataset = dataset.map(_parse_SeqExample1) 
    if (buffer_size is not None): 
     dataset = dataset.shuffle(buffer_size=buffer_size) 
    dataset = dataset.repeat(num_epochs) 
    dataset = dataset.batch(batch_size) 
    iterator = dataset.make_initializable_iterator() 
    next_element = iterator.get_next() 

    # next_element contains a tuple of following tensors 
    # length,   scalar,     tf.int64 
    # site_code_raw,  scalar,     tf.string 
    # Date_Local_raw, scalar,     tf.string 
    # Time_Local_raw, scalar,     tf.string 
    # Orig_RefPts,  [#batch, #RefPoints, 4] tf.float32 
    # tgt_location,  [#batch, 3]    tf.float32 
    # tgt_val   [#batch, 1]    tf.float32 

    return iterator, next_element 

def _parse_SeqExample1(in_SeqEx_proto): 

    # Define how to parse the example 
    context_features = { 
     'length': tf.FixedLenFeature([], dtype=tf.int64), 
     'site_code': tf.FixedLenFeature([], dtype=tf.string), 
     'Date_Local': tf.FixedLenFeature([], dtype=tf.string), 
     'Time_Local': tf.FixedLenFeature([], dtype=tf.string) #, 
    } 

    sequence_features = { 
     "input_features": tf.VarLenFeature(dtype=tf.float32), 
     'tgt_location_features': tf.FixedLenSequenceFeature([3], dtype=tf.float32), 
     'tgt_val_feature': tf.FixedLenSequenceFeature([1], dtype=tf.float32) 
    }               

    context, sequence = tf.parse_single_sequence_example(
     in_SeqEx_proto, 
     context_features=context_features, 
     sequence_features=sequence_features) 

    # distribute the fetched context and sequence features into tensors 
    length = context['length'] 
    site_code_raw = context['site_code'] 
    Date_Local_raw = context['Date_Local'] 
    Time_Local_raw = context['Time_Local'] 

    # reshape the tensors according to the dimension definition above 
    Orig_RefPts = sequence['input_features'].values 
    Orig_RefPts = tf.reshape(Orig_RefPts, [-1, 4]) 
    tgt_location = sequence['tgt_location_features'] 
    tgt_location = tf.reshape(tgt_location, [-1]) 
    tgt_val = sequence['tgt_val_feature'] 
    tgt_val = tf.reshape(tgt_val, [-1]) 

    return length, site_code_raw, Date_Local_raw, Time_Local_raw, \ 
     Orig_RefPts, tgt_location, tgt_val 

Quand j'appelle read_batch_DatasetAPI avec batch_size = 1 (voir le code ci-dessous), il peut traiter tous (environ 200 000) Exemples de séquence de un by-one sans aucun problème. Mais si je change le batch_size en un nombre supérieur à 1, il s'arrête simplement après l'extraction de 320 à 700 exemples de séquence sans message d'erreur. Je ne sais pas comment résoudre ce problème. Toute aide est appréciée!

# the iterator to get the next_element for one sample (in sequence) 
iterator, next_element = read_batch_DatasetAPI(
    in_tf_FWN, # the file name of the tfrecords containing ~200,000 Sequence Examples 
    batch_size = 1, # works when it is 1, doesn't work if > 1 
    num_epochs = 1, 
    buffer_size = None) 

# tf session initialization 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 

## reset the iterator to the beginning 
sess.run(iterator.initializer) 

try: 
    step = 0 

    while (True): 

     # get the next batch data 
     length, site_code_raw, Date_Local_raw, Time_Local_raw, \ 
     Orig_RefPts, tgt_location, tgt_vale = sess.run(next_element) 

     step = step + 1 

except tf.errors.OutOfRangeError: 
    # Task Done (all SeqExs have been visited) 
    print("closing ", in_tf_FWN) 

except ValueError as err: 
    print("Error: {}".format(err.args)) 

except Exception as err: 
    print("Error: {}".format(err.args)) 

Répondre

0

J'ai vu certains postes (Example 1 et Example 2) mentionnant la nouvelle fonction datasetfrom_generator (https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/data/Dataset#from_generator). Je ne sais pas comment l'utiliser pour résoudre mon problème pour le moment. Tout le monde sait comment le faire, s'il vous plaît le poster comme une nouvelle réponse. Je vous remercie!

Voici mon diagnostic actuel et la solution à ma question:

La variation de la longueur de la séquence (#RefPoints) a causé le problème. Le dataset.map(_parse_SeqExample1) ne fonctionne que si les #RefPoints sont identiques dans le lot. C'est pourquoi si le batch_size était 1, cela fonctionnait toujours, mais s'il était supérieur à 1, il a échoué à un moment donné.

J'ai trouvé que dataset a la fonction padded_batch qui peut combler la longueur variable à la longueur maximale dans le lot.Quelques modifications ont été apportées pour résoudre temporairement mon problème (je suppose que from_generator sera la vraie solution à mon cas):

  1. Dans la fonction _parse_SeqExample1, la déclaration de retour a été changé pour

    return tf.tuple([length, site_code_raw, Date_Local_raw, Time_Local_raw, \ Orig_RefPts, tgt_location, tgt_val])

  2. Dans la fonction read_batch_DatasetAPI, la déclaration

    dataset = dataset.batch(batch_size)

    a été changé pour

    dataset = dataset.padded_batch(batch_size, padded_shapes=( tf.TensorShape([]), tf.TensorShape([]), tf.TensorShape([]), tf.TensorShape([]), tf.TensorShape([None, 4]), tf.TensorShape([3]), tf.TensorShape([1]) ) )

  3. Enfin, changer la déclaration de chercher

    length, site_code_raw, Date_Local_raw, Time_Local_raw, \ Orig_RefPts, tgt_location, tgt_vale = sess.run(next_element)

    à

Note: Je ne sais pas pourquoi, cela ne fonctionne que sur la version actuelle tf-nightly-gpu pas le tensorflow-gpu v1.3.