InvalidArgumentError using stacked LSTMs in Keras
I am working with a dataset in which each batch item is a text represented by a matrix of shape (max_sentences_per_text, max_tokens_per_sentence). Each item goes through an embedding layer (becoming 3D) and then through a time-distributed LSTM that produces one vector per sentence (back to 2D). A second LSTM layer then reads all the sentence vectors and outputs a single final vector for each batch item, which can go through ordinary dense layers.
This is illustrated below (generated with keras.utils.plot_model), with 85 sentences per text and 40 tokens per sentence:
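The intended shape flow can also be sketched as plain-Python bookkeeping, without Keras. Only the 85 and 40 come from the description above; the batch size, embedding width, and LSTM unit counts are assumed values chosen for illustration, and trace_shapes is a hypothetical helper, not part of the model.

```python
# Shape bookkeeping only; no Keras required.
# Everything except 85 and 40 is an assumption made for illustration.
batch_size = 32          # assumed (Keras' default batch size)
num_sentences = 85       # sentences per text
max_sentence_size = 40   # tokens per sentence
embedding_size = 300     # hypothetical embedding width
lstm1_units = 100        # hypothetical
lstm2_units = 100        # hypothetical


def trace_shapes():
    """Return (layer name, output shape) pairs for the described pipeline."""
    shapes = []
    x = (batch_size, num_sentences, max_sentence_size)
    shapes.append(("input", x))
    # Embedding appends one axis holding the vector of each token.
    x = x + (embedding_size,)
    shapes.append(("embedding", x))
    # A Dense layer only changes the last axis.
    x = x[:-1] + (lstm1_units,)
    shapes.append(("projection", x))
    # The time-distributed LSTM collapses the token axis of every sentence.
    x = x[:2] + (lstm1_units,)
    shapes.append(("token_lstm", x))
    # The second LSTM collapses the sentence axis.
    x = (x[0], lstm2_units)
    shapes.append(("sentence_lstm", x))
    return shapes


for name, shape in trace_shapes():
    print(name, shape)
```

Per batch item this matches the description: 2D input, 3D after the embedding, 2D after the token-level LSTM, and a single vector after the sentence-level LSTM.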
Here is the model code:
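import keras
from keras import initializers
from keras.layers import Input, Embedding, Dense, LSTM, TimeDistributed

# One batch item is a (num_sentences, max_sentence_size) matrix of token ids
inputs = Input([num_sentences, max_sentence_size])
vocab_size, embedding_size = embeddings.shape
init = initializers.constant(embeddings)
# Frozen pretrained embeddings; token id 0 is treated as padding
emb_layer = Embedding(vocab_size, embedding_size, mask_zero=True,
                      embeddings_initializer=init)
emb_layer.trainable = False
embedded = emb_layer(inputs)
projection_layer = Dense(lstm1_units, activation=None, use_bias=False,
                         name='projection')
projected = projection_layer(embedded)
# First LSTM: applied to each sentence, yields one vector per sentence
lstm1 = LSTM(lstm1_units, name='token_lstm')
sentence_vectors = TimeDistributed(lstm1)(projected)
# Second LSTM: reads the sentence vectors, yields one vector per text
lstm2 = LSTM(lstm2_units, name='sentence_lstm')
final_vector = lstm2(sentence_vectors)
hidden = Dense(hidden_units, activation='relu', name='hidden')(final_vector)
scores = Dense(num_scores, activation='sigmoid', name='scorer')(hidden)
model = keras.models.Model(inputs, scores)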
This looks fine to me, except that I get the following error:
Traceback (most recent call last):
File "src/network.py", line 43, in <module>
network.fit(x, y, validation_data=(xval, yval))
File "/Users/erick/anaconda2/lib/python2.7/site-packages/keras/engine/training.py", line 1507, in fit
initial_epoch=initial_epoch)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/keras/engine/training.py", line 1156, in _fit_loop
outs = f(ins_batch)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 2269, in __call__
**self.session_kwargs)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 767, in run
run_metadata_ptr)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 965, in _run
feed_dict_string, options, run_metadata)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1015, in _do_run
target_list, options, run_metadata)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 1035, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Inputs to operation sentence_lstm/while/Select_2 of type Select must have the same size and shape. Input 0: [32,4000] != input 1: [32,100]
[[Node: sentence_lstm/while/Select_2 = Select[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"](sentence_lstm/while/Tile, sentence_lstm/while/add_5, sentence_lstm/while/Identity_3)]]
Caused by op u'sentence_lstm/while/Select_2', defined at:
File "src/network.py", line 37, in <module>
args.hidden_units)
File "src/model.py", line 51, in create_model
final_vector = lstm2(sentence_vectors)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/keras/layers/recurrent.py", line 262, in __call__
return super(Recurrent, self).__call__(inputs, **kwargs)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/keras/engine/topology.py", line 596, in __call__
output = self.call(inputs, **kwargs)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/keras/layers/recurrent.py", line 341, in call
input_length=input_shape[1])
File "/Users/erick/anaconda2/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 2538, in rnn
swap_memory=True)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2605, in while_loop
result = context.BuildLoop(cond, body, loop_vars, shape_invariants)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2438, in BuildLoop
pred, body, original_loop_vars, loop_vars, shape_invariants)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2388, in _BuildLoop
body_result = body(*packed_vars_for_body)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/keras/backend/tensorflow_backend.py", line 2509, in _step
new_states = [tf.where(tiled_mask_t, new_states[i], states[i]) for i in range(len(states))]
File "/Users/erick/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 2301, in where
return gen_math_ops._select(condition=condition, t=x, e=y, name=name)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/tensorflow/python/ops/gen_math_ops.py", line 2386, in _select
name=name)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
op_def=op_def)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2327, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/Users/erick/anaconda2/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1226, in __init__
self._traceback = _extract_stack()
InvalidArgumentError (see above for traceback): Inputs to operation sentence_lstm/while/Select_2 of type Select must have the same size and shape. Input 0: [32,4000] != input 1: [32,100]
[[Node: sentence_lstm/while/Select_2 = Select[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"](sentence_lstm/while/Tile, sentence_lstm/while/add_5, sentence_lstm/while/Identity_3)]]
The training call is network.fit(x, y, validation_data=(xval, yval)), with the following shapes:
In [89]: x.shape
Out[89]: (1000, 85, 40)
In [90]: y.shape
Out[90]: (1000, 5)
In [91]: xval.shape
Out[91]: (500, 85, 40)
In [92]: yval.shape
Out[92]: (500, 5)
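The numbers in the error message can be checked against these shapes with a little arithmetic. The traceback shows the failing op is tf.where(tiled_mask_t, new_states[i], states[i]), and its first input is sentence_lstm/while/Tile, so the [32,4000] tensor is the tiled mask. One reading, offered as a hypothesis rather than a confirmed diagnosis, is that the per-token mask created by Embedding(mask_zero=True) reaches sentence_lstm with its token axis still attached. The sketch below assumes lstm2_units = 100 (read off the [32,100] in the message) and a batch size of 32; tile_shape is a hypothetical helper that only computes shapes.

```python
def tile_shape(shape, reps):
    """Shape that tiling an array of `shape` by `reps` would produce."""
    return tuple(dim * rep for dim, rep in zip(shape, reps))


batch_size = 32                          # assumed (Keras' default)
num_sentences, max_sentence_size = 85, 40
lstm2_units = 100                        # assumption, from [32,100] in the error

# Hypothesis: the mask still has one flag per token, not one per sentence.
token_mask = (batch_size, num_sentences, max_sentence_size)

# Slicing one time step (one sentence) out of that mask leaves 40 flags per item.
mask_t = (token_mask[0], token_mask[2])            # (32, 40)

# The mask is tiled along the last axis by the number of LSTM units.
tiled = tile_shape(mask_t, (1, lstm2_units))       # (32, 4000)
state = (batch_size, lstm2_units)                  # (32, 100)
print(tiled, state, tiled == state)

# With one flag per sentence, the tiled mask would match the state shape.
sentence_mask_t = (batch_size, 1)
tiled_ok = tile_shape(sentence_mask_t, (1, lstm2_units))
print(tiled_ok, tiled_ok == state)
```

Under these assumptions 40 * 100 = 4000 reproduces the mismatched shape exactly, which points at the mask passed to the second LSTM rather than at x or y.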
It looks like a problem with the size of your input. What are you using as input to model.fit? That is where the error is raised. – michetonu
I don't think the problem is with the inputs. I edited the question with more information, showing that the problem originates in the call to the second LSTM layer. – erickrf