2016-05-28 1 views
3

J'ai besoin d'aide pour essayer de corriger ce code pour un autoencoder simple dans Keras. J'essayais d'ajouter un pré-traitement d'image pour le tutoriel d'autoencoder sur le blog de Keras. Voilà ce que je l'ai faitErreur ImageDataGenerator

input_image = Input(shape=(1,256,256,)) 
flattened = Flatten()(input_image) 
encoded = Dense(128,activation='relu',name='Dense1')(flattened) 
decoded = Dense(256*256, activation='sigmoid',name='Dense2')(encoded) 
output_image = Reshape((1,256,256,))(decoded) 
autoencoder = Model(input_image,output_image) 
autoencoder.compile(optimizer='adadelta', loss='binary_crossentropy') 

datagen = ImageDataGenerator(
    rotation_range=20, 
    width_shift_range=0.2, 
    height_shift_range=0.2, 
    horizontal_flip=True) 

autoencoder.fit_generator(datagen.flow(train_imgs, train_imgs, 
      batch_size=32), 
      samples_per_epoch=train_imgs.shape[0], 
      nb_epoch=50, 
      validation_data=(test_imgs,test_imgs)) 

train_imgs a une forme (1000,256,256) où 1000 est le nombre d'échantillons de formation. test_imgs a une forme (50,256,256).

C'est l'erreur que je suis

Exception: sortie du générateur doit être un tuple (x, y, sample_weight) ou (x, y). Trouvé: aucun

Ceci a été généré par la fonction fit_generator.

Répondre

-1

Je pense que vous avez oublié d'adapter le modèle Datagen. S'il vous plaît ajouter datagen.fit(train_imgs) avant autoencoder.fit_generator et essayer de former votre modèle.

+0

C'est pas. Obtenez la même erreur. – user2775878

4

J'ai inventé cette chose moi-même. Il s'avère que ImageDataGenerator suppose que l'entrée est dans la forme (number_of_samples, number_of_channels, width, height). Remodelage train_imgs et test_imgs fait l'affaire. J'ai modifié le code dans la question pour inclure cette dimension supplémentaire.

0

Vous devez changer class_mode pour 'entrée' comme ceci:

autoencoder.fit_generator(datagen.flow(train_imgs, train_imgs, 
     batch_size=32,class_mode='input'), 
     samples_per_epoch=train_imgs.shape[0], 
     nb_epoch=50, 
     validation_data=(test_imgs,test_imgs)) 

Vous pouvez en lire plus here