2016-04-23 2 views
5

J'utilise keras 1.0.1 J'essaie d'ajouter une couche d'attention au-dessus d'un LSTM. C'est ce que j'ai jusqu'ici, mais ça ne marche pas.Keras attention layer over LSTM

input_ = Input(shape=(input_length, input_dim)) 
lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length = input_length, return_sequences=True)(input_) 
att = TimeDistributed(Dense(1)(lstm)) 
att = Reshape((-1, input_length))(att) 
att = Activation(activation="softmax")(att) 
att = RepeatVector(self.HID_DIM)(att) 
merge = Merge([att, lstm], "mul") 
hid = Merge("sum")(merge) 

last = Dense(self.HID_DIM, activation="relu")(hid) 

Le réseau doit appliquer un LSTM sur la séquence d'entrée. Ensuite, chaque état caché du LSTM doit être entré dans une couche entièrement connectée, sur laquelle une Softmax est appliquée. Le softmax est répliqué pour chaque dimension cachée et multiplié par les états cachés de LSTM par élément. Ensuite, le vecteur résultant devrait être moyenné.

EDIT: Cela compile, mais je ne suis pas sûr qu'il fasse ce que je pense qu'il devrait faire.

input_ = Input(shape=(input_length, input_dim)) 
lstm = GRU(self.HID_DIM, input_dim=input_dim, input_length = input_length, return_sequences=True)(input_) 
att = TimeDistributed(Dense(1))(lstm) 
att = Flatten()(att) 
att = Activation(activation="softmax")(att) 
att = RepeatVector(self.HID_DIM)(att) 
att = Permute((2,1))(att) 
mer = merge([att, lstm], "mul") 
hid = AveragePooling1D(pool_length=input_length)(mer) 
hid = Flatten()(hid) 
+0

Salut @siamii il y avait du succès avec le réseau d'attention? Actuellement j'essaie la même chose .. – Nacho

+0

Jetez un oeil à cette mise en œuvre de l'attention sur un LSTM: https://github.com/philipperemy/keras-attention-mechanism Cela devrait fonctionner sur votre exemple. –

Répondre

1

Here est une implémentation de LSTM de l'attention avec Keras, et un exemple de son instantiation. Je ne l'ai pas essayé moi-même, cependant.