2017-08-24 2 views
1

Excuses si ce n'est pas le bon endroit pour soulever mon problème (s'il vous plaît aidez-moi avec le meilleur endroit pour le soulever si c'est le cas). Je suis un novice avec Keras et Python, donc les réponses d'espoir ont cela à l'esprit.Comment former un CNN par lots avec Keras fit_generator?

Je suis en train de former un modèle de direction CNN qui prend des images en entrée. C'est un assez grand ensemble de données, j'ai donc créé un générateur de données pour fonctionner avec fit_generator(). Il n'est pas clair pour moi comment faire pour que cette méthode s'entraîne sur des lots, donc j'ai supposé que le générateur doit renvoyer des lots à fit_generator(). Le générateur ressemble à ceci:

def gen(file_name, batchsz = 64): 
    csvfile = open(file_name) 
    reader = csv.reader(csvfile) 
    batchCount = 0 
    while True: 
     for line in reader: 
      inputs = [] 
      targets = [] 
      temp_image = cv2.imread(line[1]) # line[1] is path to image 
      measurement = line[3] # steering angle 
      inputs.append(temp_image) 
      targets.append(measurement) 
      batchCount += 1 
      if batchCount >= batchsz: 
       batchCount = 0 
       X = np.array(inputs) 
       y = np.array(targets) 
       yield X, y 
     csvfile.seek(0) 

Il lit un fichier csv contenant des données de télémétrie (angle de direction, etc.) et les chemins d'échantillons d'image et renvoie des tableaux de taille: BATCHSZ L'appel à fit_generator() ressemble à ceci:

tgen = gen('h:/Datasets/dataset14-no.zero.speed.trn.csv', batchsz = 128) # Train data generator 
vgen = gen('h:/Datasets/dataset14-no.zero.speed.val.csv', batchsz = 128) # Validation data generator 
try: 
    model.fit_generator(
     tgen, 
     samples_per_epoch=113526, 
     nb_epoch=6, 
     validation_data=vgen, 
     nb_val_samples=20001 
    ) 

l'ensemble de données contient 113526 points d'échantillonnage encore la sortie de mise à jour de formation de modèle lit comme celui-ci (par exemple):

1020/113526 [..............................] - ETA: 27737s - loss: 0.0080 
    1021/113526 [..............................] - ETA: 27723s - loss: 0.0080 
    1022/113526 [..............................] - ETA: 27709s - loss: 0.0080 
    1023/113526 [..............................] - ETA: 27696s - loss: 0.0080 

qui semble s'entraîner échantillon par échantillon (stochastique?). Le modèle résultant est inutile. Auparavant, je m'étais entraîné sur un ensemble de données beaucoup plus petit en utilisant .fit() avec l'ensemble de données chargé en mémoire, et cela produisait un modèle qui fonctionnait au moins même si mal. Il est clair que quelque chose ne va pas avec mon approche fit_generator(). Serai très reconnaissant pour de l'aide avec cela.

+1

'samples_per_epoch' doit être égale à' total_samples/batch_size' comme suggéré dans [documentation keras] (https://keras.io/models/sequential/). 'samples_per_epoch' spécifie le nombre de fois que le générateur est appelé avant qu'une date soit considérée comme terminée, il ne sait pas ce que' batch_size' vous utilisez – gionni

+0

Merci @gionni. Mise à jour de Keras 1.0.2 au plus tard. Les params fit-generator() ont plus de sens avec cette version. – tinyMind

Répondre

2

Ce:

for line in reader: 
    inputs = [] 
    targets = [] 

... est remise à zéro votre lot pour chaque ligne dans les fichiers csv. Vous n'êtes pas la formation avec vos données entières, mais avec un seul échantillon 128.

Suggestion:

for line in reader: 

    if batchCount == 0: 
     inputs = [] 
     targets = [] 
    .... 
    .... 

Comme quelqu'un a fait remarquer, le générateur en forme, samples_per_epoch doit être égale à total_samples/batchsz

Même si, je pense que votre perte devrait être en baisse de toute façon. Si ce n'est pas le cas, il se peut qu'il y ait encore un autre problème dans le code, peut-être dans la façon dont vous chargez les données, ou dans l'initialisation ou la structure du modèle.

Essayez de tracer vos images et imprimer les données dans le générateur:

for X,y in tgen: #careful, this is an infinite loop, make it stop 

    print(X.shape[0]) # is this really the number of batches you expect? 

    for image in X: 
     ...some method to plot X so you can see it, or just print  

    print(y) 

Vérifiez si les valeurs obtenues sont ok avec ce que vous attendez d'être.

+0

"... réinitialise votre lot pour chaque ligne dans les fichiers CSV." Doh! aurait dû repérer celui-là. Wierd parce que j'ai un code de test pour imprimer les tableaux yeilded, et ils sont des lots de la bonne taille et la séquence. – tinyMind

+0

A propos de la perte, récemment, j'ai eu un problème avec une perte "gelée". J'ai décidé de former un seul échantillon encore et encore pour de nombreuses époques et soudainement la perte a trouvé son chemin vers le bas.Ensuite, j'ai présenté d'autres exemples progressivement et il a commencé à s'entraîner correctement. Je suppose que le modèle était trop complexe ou que je n'avais pas initialisé correctement mes poids, il a donc fallu plus de temps pour montrer des résultats intéressants. –

+0

Merci Daniel. Semble être entrainé ok maintenant. La charge GPU est assez faible, comme si GPU attendait sur le script. – tinyMind