TensorFlow Dataset API does not work reliably when the batch size is greater than 1

I put a group of fixed-length and variable-length features into a tf.train.SequenceExample:
```
context_features
  length,          scalar,                   tf.int64
  site_code_raw,   scalar,                   tf.string
  Date_Local_raw,  scalar,                   tf.string
  Time_Local_raw,  scalar,                   tf.string
sequence_features
  Orig_RefPts,     [#batch, #RefPoints, 4],  tf.float32
  tgt_location,    [#batch, 3],              tf.float32
  tgt_val,         [#batch, 1],              tf.float32
```
The value of #RefPoints varies from one sequence example to another. I store it in the `length` feature of `context_features`. The other features have fixed sizes.
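To make the variable-length layout concrete: `input_features` is parsed as a `tf.VarLenFeature`, so its `.values` are a flat list of floats, and `tf.reshape(..., [-1, 4])` turns them into #RefPoints rows of 4. The plain-Python sketch below mimics that reshape for a single example and checks the result against the stored `length`. The helper `reshape_refpts` is hypothetical, and it assumes `length` holds #RefPoints (the number of 4-float rows), not the number of floats.

```python
from typing import List, Sequence


def reshape_refpts(flat_values: Sequence[float], length: int) -> List[List[float]]:
    """Mimic tf.reshape(values, [-1, 4]) for one example (hypothetical helper).

    Assumes `length` stores #RefPoints, i.e. the number of 4-float rows.
    """
    if len(flat_values) % 4 != 0:
        raise ValueError("number of values is not a multiple of 4: %d" % len(flat_values))
    rows = [list(flat_values[i:i + 4]) for i in range(0, len(flat_values), 4)]
    if len(rows) != length:
        raise ValueError("expected %d reference points, got %d" % (length, len(rows)))
    return rows


# Two reference points -> shape [2, 4]
example = reshape_refpts([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], length=2)
print(len(example), len(example[0]))  # 2 4
```

The point is that each example yields an `Orig_RefPts` tensor of shape [#RefPoints, 4], so two examples generally have different shapes.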
Here is the code I use to read and parse the data:
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.contrib.data, tf.Session)


def read_batch_DatasetAPI(
        filenames,
        batch_size=20,
        num_epochs=None,
        buffer_size=5000):
    dataset = tf.contrib.data.TFRecordDataset(filenames)
    dataset = dataset.map(_parse_SeqExample1)
    if buffer_size is not None:
        dataset = dataset.shuffle(buffer_size=buffer_size)
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    # next_element contains a tuple of the following tensors:
    #   length,          scalar, tf.int64
    #   site_code_raw,   scalar, tf.string
    #   Date_Local_raw,  scalar, tf.string
    #   Time_Local_raw,  scalar, tf.string
    #   Orig_RefPts,     [#batch, #RefPoints, 4] tf.float32
    #   tgt_location,    [#batch, 3] tf.float32
    #   tgt_val,         [#batch, 1] tf.float32
    return iterator, next_element


def _parse_SeqExample1(in_SeqEx_proto):
    # Define how to parse the example
    context_features = {
        'length': tf.FixedLenFeature([], dtype=tf.int64),
        'site_code': tf.FixedLenFeature([], dtype=tf.string),
        'Date_Local': tf.FixedLenFeature([], dtype=tf.string),
        'Time_Local': tf.FixedLenFeature([], dtype=tf.string),
    }
    sequence_features = {
        'input_features': tf.VarLenFeature(dtype=tf.float32),
        'tgt_location_features': tf.FixedLenSequenceFeature([3], dtype=tf.float32),
        'tgt_val_feature': tf.FixedLenSequenceFeature([1], dtype=tf.float32),
    }
    context, sequence = tf.parse_single_sequence_example(
        in_SeqEx_proto,
        context_features=context_features,
        sequence_features=sequence_features)

    # distribute the fetched context and sequence features into tensors
    length = context['length']
    site_code_raw = context['site_code']
    Date_Local_raw = context['Date_Local']
    Time_Local_raw = context['Time_Local']

    # reshape the tensors according to the dimension definition above
    Orig_RefPts = sequence['input_features'].values   # flat values of the SparseTensor
    Orig_RefPts = tf.reshape(Orig_RefPts, [-1, 4])     # -> [#RefPoints, 4]
    tgt_location = sequence['tgt_location_features']
    tgt_location = tf.reshape(tgt_location, [-1])
    tgt_val = sequence['tgt_val_feature']
    tgt_val = tf.reshape(tgt_val, [-1])

    return length, site_code_raw, Date_Local_raw, Time_Local_raw, \
        Orig_RefPts, tgt_location, tgt_val
```
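One thing worth checking in this pipeline: as I understand it, `dataset.batch(batch_size)` stacks each component of `batch_size` consecutive elements along a new leading dimension, which requires that component to have the same shape in every element of the batch. `Orig_RefPts` has a different #RefPoints per example, so a batch that happens to mix two lengths cannot be stacked. The plain-Python sketch below only imitates that stacking rule; `shape_of` and `stack_batch` are hypothetical helpers, not TensorFlow APIs.

```python
from typing import Any, List, Tuple


def shape_of(x: Any) -> Tuple[int, ...]:
    """Return the shape of a rectangular nested list, e.g. (3, 4)."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return tuple(shape)


def stack_batch(elements: List[Any]) -> List[Any]:
    """Group elements into one batch, mimicking an unpadded batch (hypothetical).

    Raises ValueError when the elements do not all share the same shape.
    """
    shapes = {shape_of(element) for element in elements}
    if len(shapes) > 1:
        raise ValueError("cannot batch elements with shapes %s" % sorted(shapes))
    return list(elements)


two_points = [[0.0] * 4 for _ in range(2)]    # Orig_RefPts with #RefPoints = 2
three_points = [[0.0] * 4 for _ in range(3)]  # Orig_RefPts with #RefPoints = 3

print(len(stack_batch([two_points])))              # batch_size 1: always fine
print(len(stack_batch([two_points, two_points])))  # equal lengths: fine
try:
    stack_batch([two_points, three_points])        # mixed lengths: fails
except ValueError as err:
    print("batch failed:", err)
```

If this is what happens, it would be consistent with batch_size = 1 never failing (a single element is always stackable) and with the stopping point varying, since it depends on when two different lengths first land in the same batch.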
When I call read_batch_DatasetAPI with batch_size = 1 (see the code below), it processes all of the roughly 200,000 sequence examples one by one without any problem. But if I change batch_size to any number greater than 1, it simply stops after fetching 320 to 700 sequence examples, without any error message. I don't know how to solve this problem. Any help is appreciated!
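A likely (but unconfirmed) cause is that a plain `dataset.batch` cannot stack `Orig_RefPts` tensors whose #RefPoints differ. If so, the usual remedy is to pad each batch to its longest element, which is what `Dataset.padded_batch(batch_size, padded_shapes)` does in TensorFlow; for the `Orig_RefPts` component a padded shape such as `[None, 4]` would be the natural choice, though the exact `padded_shapes` tuple for all seven components is an assumption to check. The plain-Python sketch below only imitates the padding rule; `pad_batch` is a hypothetical helper, and the per-example `length` feature is what lets you ignore the padded rows afterwards.

```python
from typing import List


def pad_batch(batch: List[List[List[float]]], row_width: int = 4,
              pad_value: float = 0.0) -> List[List[List[float]]]:
    """Pad each [n_i, row_width] element to [max_n, row_width] (hypothetical helper)."""
    max_len = max(len(element) for element in batch)
    padded = []
    for element in batch:
        filler = [[pad_value] * row_width for _ in range(max_len - len(element))]
        padded.append([list(row) for row in element] + filler)
    return padded


batch = [
    [[1.0, 1.0, 1.0, 1.0]],                         # 1 reference point
    [[2.0, 2.0, 2.0, 2.0], [3.0, 3.0, 3.0, 3.0]],  # 2 reference points
]
lengths = [len(element) for element in batch]       # plays the role of `length`
padded = pad_batch(batch)
print([len(element) for element in padded])         # [2, 2]
print(lengths)                                      # [1, 2]
```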
```python
# the iterator to get the next_element for one sample (in sequence)
iterator, next_element = read_batch_DatasetAPI(
    in_tf_FWN,       # file name of the TFRecords file containing ~200,000 SequenceExamples
    batch_size=1,    # works when it is 1, doesn't work if > 1
    num_epochs=1,
    buffer_size=None)

# TF session initialization
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# reset the iterator to the beginning
sess.run(iterator.initializer)
try:
    step = 0
    while True:
        # get the next batch of data
        length, site_code_raw, Date_Local_raw, Time_Local_raw, \
            Orig_RefPts, tgt_location, tgt_val = sess.run(next_element)
        step = step + 1
except tf.errors.OutOfRangeError:
    # task done (all SequenceExamples have been visited)
    print("closing ", in_tf_FWN)
except ValueError as err:
    print("Error: {}".format(err.args))
except Exception as err:
    print("Error: {}".format(err.args))
```
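A possible reason no message appears is the handler itself: `print("Error: {}".format(err.args))` prints only `args`, and an exception class that does not pass its message to `Exception.__init__` has an empty `args` tuple even though `str(err)` is informative. As far as I know, TensorFlow 1.x's `tf.errors.OpError` (the base of errors such as `InvalidArgumentError`) is built that way, but treat that as an assumption to verify for your version. The plain-Python sketch below uses a hypothetical `FakeOpError` to show the difference, plus a hypothetical `describe_error` helper that prints the type and `str(err)` instead.

```python
class FakeOpError(Exception):
    """Hypothetical stand-in for an exception that keeps its message out of args."""

    def __init__(self, message: str) -> None:
        super().__init__()      # nothing passed, so self.args == ()
        self._message = message

    def __str__(self) -> str:
        return self._message


def describe_error(err: Exception) -> str:
    """Include the exception type and str(err) instead of only err.args."""
    return "Error: {}: {}".format(type(err).__name__, err)


try:
    raise FakeOpError("illustrative message about mismatched shapes")
except Exception as err:
    print("Error: {}".format(err.args))   # Error: ()
    print(describe_error(err))            # Error: FakeOpError: illustrative message about mismatched shapes
```

In the loop above, printing `describe_error(err)` (or simply re-raising inside the generic `except Exception` branch) should reveal what actually stopped the iteration.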