2017-10-04 1 views
2

Je suis un peu confus comment utiliser fit_generator en keras.Comment: fit_generator dans keras

Par exemple permet de dire:

  • nous avons 10000 points de données
  • nous voulons courir pour 10 époques
  • avec la taille du lot de 512

En utilisant fit nous venons:

x, y = load_data() 
model.fit(x=x, y=y, batch_size=512, epochs=10) 

load_data charge toutes les données.

Maintenant, comment faire la même chose avec fit_generator.

Ce n'est pas clair pour moi comment il est traité lors de l'utilisation fit_generator. Si je le générateur suivant:

def data_generator(): 
    for x, y in load_data_per_line(): 
     yield x, y 

Dans le générateur ci-dessus chaque fois qu'il yields un point de données. Et:

def data_generator_2(): 
    x_output = [] 
    y_output = [] 
    i = 0 
    for x, y in load_data_per_line(): 
     x_output[i] = x 
     y_output[i] = y 
     i = i + 1 
     if i == batch_size: 
      yield x_output, y_output 
      i = 0 
      x_output = [] 
      y_output = [] 

Dans le générateur ci-dessus chaque fois qu'il yields points de données de taille de lot (512 dans ce cas).

Pour obtenir la même chose que fit mais en utilisant fit_generator:

model.fit_generator(data_generator(), steps_per_epoch=10000/512, epochs=10) 

ou

model.fit_generator(data_generator_2(), steps_per_epoch=10000/512, epochs=10) 

ou les deux sont faux (fit_generator et data_generator s)? Si l'un d'entre eux est correct, est-ce que cela garantit que tous les points de données seront traités et traités de manière séquentielle?

Toute idée est utile

Répondre

2

générateur 2 est presque correct, mais il devrait mieux renvoyer des tableaux numpy:

yield np.asarray(x_output),np.asarray(y_output) 

En outre, il devrait être infini:

while True: 

    #the code inside to loop infinitely 

Le premier ne retournera pas les lots et échouera.

Vous aurez probablement un problème dans steps_per_epoch, car 10000 n'est pas un multiple de 512. Vous avez besoin d'étapes entières. Vous pouvez à l'intérieur du générateur vérifier if i == 10000: et passer un plus petit lot que le dernier lot. Puis, vous avez (10000 //512) + (10000 % 512) étapes ou lots.

Tous les lots seront lus en séquence, mais les keras mélangent automatiquement le contenu de ces lots, utilisez suffle=False. Si vous utilisez le multithreading (pas le cas), vous devez créer des générateurs thread-safe ou utiliser un keras Sequence.

+0

Juste une curiosité, alors dans ce cas, le dernier lot n'aura pas la taille de 512 et c'est très bien non? – titipata

+1

C'est bien, tant que vous ne laissez pas votre générateur essayer de lire plus que ce qui est autorisé. –

+0

merci pour la réponse détaillée. comme le but du générateur d'ajustement est de former un modèle avec beaucoup de données pourquoi il suppose que je dois connaître le nombre de points de données? que se passe-t-il si pour une raison quelconque je ne connais pas le nombre exact de points de données? comment définir les étapes dans ce cas? –