2016-05-02 2 views
13

J'ai construit un générateur simple qui donne tuple(inputs, targets) avec un seul élément dans les listes inputs et targets - essentiellement l'analyse de l'ensemble de données, un élément à la fois.Dans la méthode Keras model.fit_generator(), à quoi sert le paramètre contrôlé par la file d'attente du générateur "max_q_size"?

Je passe ce générateur en:

model.fit_generator(my_generator(), 
         nb_epoch=10, 
         samples_per_epoch=1, 
         max_q_size=1 # defaults to 10 
        ) 

que je reçois:

  • nb_epoch est le nombre de fois que le lot de formation sera exécuté
  • samples_per_epoch est le nombre d'échantillons formés avec par époque

Mais qu'est-ce que max_q_size pour et pourquoi serait-il par défaut à 10? Je pensais que le but de l'utilisation d'un générateur était de regrouper les ensembles de données en morceaux raisonnables, alors pourquoi la file d'attente supplémentaire?

Répondre

23

Ceci définit simplement la taille maximale de la file d'attente d'apprentissage interne qui est utilisée pour "précacher" vos échantillons du générateur. Il est utilisé lors de la génération des files d'attente

def generator_queue(generator, max_q_size=10, 
        wait_time=0.05, nb_worker=1): 
    '''Builds a threading queue out of a data generator. 
    Used in `fit_generator`, `evaluate_generator`, `predict_generator`. 
    ''' 
    q = queue.Queue() 
    _stop = threading.Event() 

    def data_generator_task(): 
     while not _stop.is_set(): 
      try: 
       if q.qsize() < max_q_size: 
        try: 
         generator_output = next(generator) 
        except ValueError: 
         continue 
        q.put(generator_output) 
       else: 
        time.sleep(wait_time) 
      except Exception: 
       _stop.set() 
       raise 

    generator_threads = [threading.Thread(target=data_generator_task) 
         for _ in range(nb_worker)] 

    for thread in generator_threads: 
     thread.daemon = True 
     thread.start() 

    return q, _stop 

En d'autres termes, vous avez un fil remplissant la file d'attente jusqu'à donné, la capacité maximale directement à partir de votre générateur, alors que (par exemple) routine d'entraînement consomme ses éléments (et parfois attend la fin)

while samples_seen < samples_per_epoch: 
    generator_output = None 
    while not _stop.is_set(): 
     if not data_gen_queue.empty(): 
      generator_output = data_gen_queue.get() 
      break 
     else: 
      time.sleep(wait_time) 

et pourquoi défaut de 10? Aucune raison particulière, comme la plupart des défauts - cela a tout simplement du sens, mais vous pouvez aussi utiliser des valeurs différentes.

Une telle construction suggère que les auteurs ont pensé à des générateurs de données onéreux, ce qui pourrait prendre du temps à s'exécuter. Par exemple envisager de télécharger des données sur un réseau en appel générateur - alors il est logique de précacher certains lots suivants, et télécharger les suivants en parallèle par souci d'efficacité et d'être robuste aux erreurs de réseau, etc

+2

Ah, je vois, Donc idéalement, vous n'arrêtez jamais de vous entraîner à attendre que le générateur génère des résultats - vous avez un thread qui remplit la file d'attente silencieusement à l'arrière pendant que le modèle s'entraîne sur les échantillons récupérés précédemment. – Ray

+1

Oui, c'est un scénario parfait. Ce qui dépend évidemment de la taille de la file d'attente et de la conception globale du système. – lejlot