2017-07-09 3 views
0

Je souhaite utiliser une règle de mise à jour alternée avec keras. I.e. par lot, je voudrais appeler une étape régulière basée sur le gradient, et ensuite appeler une étape personnalisée.Règle de mise à jour alternée personnalisée avec keras

J'ai pensé à l'implémenter en héritant d'un optimiseur ou d'un rappel (et en utilisant les appels sur le lot). Cependant, ni l'un ni l'autre ne le ferait, car ils n'ont pas tous les deux les données de lot et les étiquettes de lot (et j'ai besoin des deux).

Une idée sur la façon d'implémenter une mise à jour personnalisée en alternance avec keras?

Si nécessaire, je ne me dérange pas d'appeler directement tensorflow méthodes spécifiques, tant que je peux continuer à utiliser le projet enveloppé dans le cadre keras (avec model.fit, model.predict ..)

Répondre

0

essayer de créer un rappel personnalisé

import keras.callbacks as callbacks 

class JSONMetrics(callbacks.Callback): 

_model  = None 
_each_epoch = None 
_metrics = None 
_epoch  = None 
_file_json = None 

def __init__(self,model,each_epoch,logger=None): 

    self._file_json = "file_log.json" 
    self._model  = model 
    self._each_epoch= each_epoch 
    self._epoch  = 0 
    self._metrics = {'loss':[], 'acc':[]} 

def on_epoch_begin(self, epoch, logs): 
    # print('Epoch {0} begin'.format(epoch)) 
    try: 
     with open(self._file_json, 'r') as f: 
      self._metrics = json.load(f) 

def on_epoch_end(self, epoch, logs): 
    self._logger.info('Nemesis: Epoch {0} end'.format(epoch)) 

    self._metrics['loss'].append(logs.get('loss')) 
    self._metrics['acc'].append(logs.get('acc')) 
    with open(self._file_json, 'w') as f: 
     data = json.dump(self._metrics, f) 

    if self._epoch % self._each_epoch == 0: 

     file_name = 'weights%08d.h5' % self._epoch 
     #print('Saving weights at {0} file'.format(file_name)) 
     self._model.save_weights(file_name) 

    self._epoch += 1 

Vous pouvez évoquer le self.model pour résoudre votre problème et enregistrez le acc et la perte par exemple.