J'essaie de prédire la classe pour chaque nombre dans le vecteur d'entrée. Il y a 3 classes. Classe 1, si la valeur d'entrée est passée de 0 à 1. Classe 2, si elle est passée de 1 à 0. Classe 0 sinon. Après la deuxième époque et la précision suivante est bloqué à 0,8824.RNN ne surajuste pas sur des données simples
Un nombre plus élevé d'années de formation ne change rien. J'ai essayé de commutation LSTM
à GRU
ou SimpleRNN
, cela ne change rien. J'ai également essayé de générer des vecteurs d'entrée plus longs et plusieurs lots, même sans succès. La seule chose qui m'a aidé est d'augmenter la taille des couches LSTM à 128, en ajoutant trois couches TimeDistributedDense(128, relu)
et BatchNormalization
après chaque couche, y compris LSTM. Mais on dirait que c'est trop pour un problème aussi simple et ne donne pas des résultats parfaits de toute façon. J'ai passé plus d'une journée à essayer de le faire fonctionner. Qu'est-ce qui pourrait être un problème? Merci!
# complete code for training
from keras.models import Sequential
from keras.layers import Dense, LSTM, TimeDistributed
from keras.utils.np_utils import to_categorical
import numpy as np
np.random.seed(1337)
X = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0])
Y = np.array([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0])
Y_cat = to_categorical(Y, 3).reshape((1, len(X), 3))
X_r = X.reshape((1, len(X), 1))
model = Sequential()
model.add(LSTM(32, input_dim=1, return_sequences=True))
model.add(LSTM(32, return_sequences=True))
model.add(LSTM(32, return_sequences=True))
model.add(TimeDistributed(Dense(3, activation='softmax')))
model.compile(optimizer='rmsprop',
loss='categorical_crossentropy',
metrics=['accuracy'])
model.fit(X_r, Y_cat, nb_epoch=10)
model.predict_classes(X_r) # will print array filled with zeros
Merci! Bien que je n'aie aucun NaN à n'importe quelle époque, enlever deux couches de LSTM et m'entraîner pour 250 époques a aidé. – SSS
En outre, j'ai trouvé qu'avec l'optimiseur d'Adam (lr = 0.05) il s'entraîne complètement dans environ 30 époques. – SSS
@SSS À droite, je vois maintenant que j'avais Keras 1.2.0 au lieu de la dernière 1.2.1. Après la mise à jour, les NaN ne sont plus là, donc je suppose que le problème a été corrigé dans la nouvelle version. – jdehesa