2017-10-12 7 views
0

Je veux effectuer une validation croisée de mes réseaux de neurones Keras avec la fonction cross_val_score() de scikit-learn.Comment puis-je exécuter la fonction après chaque fold dans cross_val_score() de scikit-learn?

Le problème est qu'après chaque pli non seulement on se souvient du résultat, mais aussi du modèle Keras entier. Donc, je voudrais effacer ce modèle en utilisant K.clear_session() après chaque pli. Mais ce ne sont que des détails pour le contexte.

Ma question principale est: Comment puis-je exécuter une fonction personnalisée après chaque fold avec cross_val_score() de scikit-learn? En d'autres termes: Il est possible d'exécuter un rappel qui doit être exécuté après chaque pli? Ou existe-t-il d'autres solutions de contournement?

Répondre

0

Vous pouvez probablement créer un rappel personnalisé et réécrire la méthode on_train_end (self, logs = {}) de ce rappel. Cette nouvelle méthode fera des choses à la fin de chaque étape de formation. Quelque chose comme ça:

class CustomCall(Callback): 

    def __init__(self): 
     super(CustomCall, self).__init__() 

    def on_epoch_begin(self, epoch, logs={}): 
     return 

    def on_epoch_end(self, epoch, logs={}): 
     return 

    def on_batch_begin(self, batch, logs={}): 
     return 

    def on_train_end(self, logs={}): 
     # Stuff here 
     print('\n Delete previous trained model : ') 
     K.clear_session() 
     return 
+0

Malheureusement, le problème est que K.clear_session() doit être appelé après l'évaluation du modèle, pas après s'être entraîné dans cross_val_score(). Je dois donc appeler K.clear_session() à la fin du cross-fold, pas à la fin de l'entraînement de Keras. –