2016-03-11 4 views
2

J'essaie d'implémenter un RNN avec état, mais il ne cesse de me demander un "input_shape complet (y compris la taille de lot)" . J'ai essayé différentes choses dans les arguments input_shape et input_batch_size, mais personne ne semble fonctionner. Quelqu'un peut-il faire la lumière?Keras/Python - Si un RNN est avec état, un input_shape complet doit être fourni (y compris la taille du lot)

code:

model=Sequential()  
model.add(SimpleRNN(init='uniform',output_dim=80,input_dim=len(pred_frame.columns),stateful=True,batch_input_shape=(len(pred_frame.index),len(pred_frame.columns)),input_shape=(len(pred_frame.index),len(pred_frame.columns)))) 
model.add(Dense(output_dim=200,input_dim=len(pred_frame.columns),init="glorot_uniform")) 
model.add(Dense(output_dim=1)) 
model.compile(loss="mse", class_mode='scalar', optimizer="sgd") 
model.fit(X=predictor_train, y=target_train, batch_size=len(pred_frame.index),show_accuracy=True) 

Traceback:

File "/Users/file.py", line 1483, in Pred 
model.add(SimpleRNN(init='uniform',output_dim=80,input_dim=len(pred_frame.columns),stateful=True,batch_input_shape=(len(pred_frame.index),len(pred_frame.columns)),input_shape=(len(pred_frame.index),len(pred_frame.columns)))) 
File "/Library/Python/2.7/site-packages/keras/layers/recurrent.py", line 194, in __init__ 
super(SimpleRNN, self).__init__(**kwargs) 
File "/Library/Python/2.7/site-packages/keras/layers/recurrent.py", line 97, in __init__ 
super(Recurrent, self).__init__(**kwargs) 
File "/Library/Python/2.7/site-packages/keras/layers/core.py", line 43, in __init__ 
self.set_input_shape((None,) + tuple(kwargs['input_shape'])) 
File "/Library/Python/2.7/site-packages/keras/layers/core.py", line 141, in set_input_shape 
self.build() 
File "/Library/Python/2.7/site-packages/keras/layers/recurrent.py", line 199, in build 
self.reset_states() 
File "/Library/Python/2.7/site-packages/keras/layers/recurrent.py", line 221, in reset_states 
'(including batch size).') 
Exception: If a RNN is stateful, a complete input_shape must be provided (including batch size). 

Répondre

3

Vous devez fournir uniquement le batch_input_shape = paramètre, et non le paramètre input_shape. En outre, pour éviter les erreurs de forme en entrée, assurez-vous que la taille des données d'apprentissage est un multiple de batch_size. Et enfin, si vous utilisez des scissions de validation, vous devez être sûr que les deux scissions sont également des multiples de batch_size.

# ensure data size is a multiple of batch_size 
data_size=data_size-data_size%batch_size 
# ensure validation splits are multiples of batch_size 
increment=float(batch_size)/len(data_size) 
val_split=float(int(val_split/(increment))) * increment 
+0

Merci pour Erik, Cependant, je reçois toujours la même erreur. Alors je demande, quel devrait être le "paramètre"? J'utilise un tuple avec (nombre de lignes dans la trame de prédiction, nombre de colonnes dans la trame de prédiction) - est-ce exact? – abutremutante

0

Dans votre définition de SimpleRNN, enlever input_dim et input_shape, définissez batch_input_shape = (Number_Of_sequences, Size_Of_Each_Sequence, Shape_Of_Element_In_Each_Sequence). batch_input_shape devrait être un tuple de longueur au moins 3.

Si vous passe vos séquences un par un, ensemble Number_Of_sequences = 1

Si la taille de vos séquences n'est pas fixe, réglez Size_Of_Each_Sequence = None