2016-04-26 11 views
4

J'utilise Keras pour prédire une série chronologique. En standard j'utilise 20 époques. Je veux savoir ce que mon réseau de neurones a prévu pour chacune des 20 époques. En utilisant model.predict, j'obtiens la dernière prédiction. Cependant je veux toutes les prédictions, ou au moins les 10 dernières (qui ont des niveaux d'erreur acceptables).Python/Keras - accès au rappel ModelCheckpoint

Pour y accéder j'essaie la fonction ModelCheckpoint de Keras, mais j'ai du mal à y accéder par la suite. J'utilise le code suivant:

model=Sequential() 

model.add(GRU(input_dim=col,init='uniform',output_dim=20)) 
model.add(Dense(10)) 
model.add(Dense(5)) 
model.add(Activation("softmax")) 
model.add(Dense(1)) 

model.compile(loss="mae", optimizer="RMSprop") 

checkpoint=ModelCheckpoint(filepath='/Users/Alex/checkpoint.hdf5') 

model.fit(X=predictor_train, y=target_train, nb_epoch=20, batch_size=batch,validation_split=0.1) #best validation split at 0.1 
model.evaluate(X=predictor_train, y=target_train,batch_size=batch,show_accuracy=True) 

print checkpoint 

Objectivement, mes questions sont les suivantes:

  • Je m'y attendais après avoir exécuté le code que je trouverais un fichier nommé checkpoint.hdf5 dans le dossier/Users/Alex, cependant je ne l'ai pas fait. Qu'est-ce que je rate?

  • Lorsque j'imprime checkpoint sur ce que j'obtiens est un keras.callbacks.ModelCheckpoint object at 0x117471290. Existe-t-il un moyen d'imprimer ce que je veux? A quoi ressemblerait le code?

Votre aide est très appréciée :)

Répondre

8

Il y a deux problèmes dans ce code:

  • Vous n'êtes pas passer le rappel à la méthode ajustement du modèle. Ceci est fait avec l'argument mot-clé "callbacks".
  • Le filepath doit contenir des espaces réservés (comme "{Epoch: 02d} - {val_loss: .2f}". Qui sont utilisés avec str.format par Keras afin de sauver chaque époque à un autre fichier

Ainsi, la version correcte doit être quelque chose comme:.

checkpoint = ModelCheckpoint(filepath='/Users/Alex/checkpoint-{epoch:02d}-{val_loss:.2f}.hdf5') 

model.fit(X=predictor_train, y=target_train, nb_epoch=20, 
     batch_size=batch,validation_split=0.1, callbacks=[checkpoint]) 

Vous pouvez également ajouter d'autres types de callbacks dans la liste qui est attribué à ce mot-clé

Malheureusement, l'objet de rappel ne stocke pas les informations d'historique afin il ne peut pas être récupéré de i t.

+0

Y at-il un moyen d'avoir ce fichier en csv ou txt? hdf5 est assez difficile à utiliser ... – abutremutante

+0

@abutremutante Non, et HDF5 est assez facile à utiliser avec h5py, mais pourquoi avez-vous besoin de travailler avec? Vous pouvez charger les poids sur votre modèle avec load_weights –