Pour la formation d'un modèle de LSTM dans tensorflow, j'ai structuré mes données dans un tf.train.SequenceExample le format et stocké dans un fichier TFRecord . Je voudrais maintenant utiliser la nouvelle API DataSet à générer des lots rembourrés pour la formation. Dans the documentation il y a un exemple pour utiliser padded_batch, mais pour mes données je ne peux pas comprendre quelle devrait être la valeur de padded_shapes.Comment créer des lots matelassés dans Tensorflow pour les données tf.train.SequenceExample à l'aide de l'API DataSet?
Pour lire le fichier TFrecord dans les lots que j'ai écrit le code Python suivant:
import math
import tensorflow as tf
import numpy as np
import struct
import sys
import array
if(len(sys.argv) != 2):
print "Usage: createbatches.py [RFRecord file]"
sys.exit(0)
vectorSize = 40
inFile = sys.argv[1]
def parse_function_dataset(example_proto):
sequence_features = {
'inputs': tf.FixedLenSequenceFeature(shape=[vectorSize],
dtype=tf.float32),
'labels': tf.FixedLenSequenceFeature(shape=[],
dtype=tf.int64)}
_, sequence = tf.parse_single_sequence_example(example_proto, sequence_features=sequence_features)
length = tf.shape(sequence['inputs'])[0]
return sequence['inputs'], sequence['labels']
sess = tf.InteractiveSession()
filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.contrib.data.TFRecordDataset(filenames)
dataset = dataset.map(parse_function_dataset)
# dataset = dataset.batch(1)
dataset = dataset.padded_batch(4, padded_shapes=[None])
iterator = dataset.make_initializable_iterator()
batch = iterator.get_next()
# Initialize `iterator` with training data.
training_filenames = [inFile]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})
print(sess.run(batch))
Le code fonctionne bien si j'utilise dataset = dataset.batch(1)
(pas de remplissage nécessaire dans ce cas), mais quand j'utilise le padded_batch
variante, je reçois l'erreur suivante:
TypeError: If shallow structure is a sequence, input must also be a sequence. Input has type: .
Pouvez-vous me aider à comprendre ce que je passerais pour le paramètre padded_shapes?
(Je sais qu'il ya beaucoup de code exemple en utilisant le filetage et les files d'attente pour cela, mais je préfère utiliser la nouvelle API DataSet pour ce projet)
Merci Marijn! Vos questions m'ont beaucoup aidé! –