2017-10-14 1 views
1

Lors de la création d'un client Keras Optimizer, la fonction Workhorse est Optimizer.get_updates(). J'ai été capable de créer un optimiseur à pas fixe, mais je ne suis pas sûr de savoir comment faire des choses telles que des moyennes en cours où je dois utiliser des valeurs calculées à partir d'appels précédents de la fonction.Création de votre propre Optimiseur Keras

Par exemple, consider RMSprop. L'accumulateur n'est-il pas réinitialisé à chaque appel de la fonction?

accumulators = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params] 
self.weights = accumulators 

Comment est RMSProp faisant la moyenne en cours d'exécution lorsque l'accumulateur est en cours de réinitialisation au début de chaque appel de mise à jour?

+1

Cela a été source de confusion pour moi aussi quand j'ai vu le code, mais cette fonction n'est pas appelée à chaque mise à jour, elle est appelée une fois pour construire le graphique, voir https://github.com/fchollet/keras/issues/5125 –

+1

Aussi la même question que l'issue https://stackoverflow.com/questions/41787873/how-adagrad-wroks-in-keras-what-does-self-weights-mean-in-keras-optimizer?rq=1 –

Répondre

1

Vous avez raison de dire que l'accumulateur est réglé sur zéro à chaque appel get_updates. Mais cette fonction n'est appelée qu'une seule fois, tandis que le graphe de calcul est construit.

Ce qui prête à confusion, c'est l'utilisation de fonctions symboliques. Comme Keras utilise des représentations symboliques, ce qui se passe dans get_updates est qu'une symbolique mise à jour est générée, dans la ligne 237-238:

new_a = self.rho * a + (1. - self.rho) * K.square(g) 
self.updates.append(K.update(a, new_a)) 

Ces mises à jour sont ensuite utilisées tout en effectuant une descente de gradient. Symboliquement, il est indiqué que lorsque vous appelez des mises à jour, comme dans les mises à jour d'une variable partagée, a est définie sur la valeur new_a qui prend en compte la valeur précédente de a. Cette partie fait l'accumulateur de moyenne courante.

Notez que plusieurs mises à jour sont créées, une pour chaque paramètre, puis ces mises à jour symboliques sont collectées dans une liste renvoyée à l'appelant.

+0

Tout s'explique maintenant. Merci beaucoup! –