J'essaie de créer une classe pour initialiser rapidement et former un autoencoder pour le prototypage rapide. Une chose que j'aimerais pouvoir faire est d'ajuster rapidement le nombre d'époques pour lesquelles je m'entraîne. Cependant, il semble que peu importe ce que je fais, le modèle entraîne chaque couche pour 100 époques! J'utilise le backend tensorflow.Keras: Mauvais nombre d'épisodes de formation
Voici le code des deux méthodes incriminées.
def pretrain(self, X_train, nb_epoch = 10):
data = X_train
for ae in self.pretrains:
ae.fit(data, data, nb_epoch = nb_epoch)
ae.layers[0].output_reconstruction = False
ae.compile(optimizer='sgd', loss='mse')
data = ae.predict(data)
.........
def fine_train(self, X_train, nb_epoch):
weights = [ae.layers[0].get_weights() for ae in self.pretrains]
dims = self.dims
encoder = containers.Sequential()
decoder = containers.Sequential()
## add special input encoder
encoder.add(Dense(output_dim = dims[1], input_dim = dims[0],
weights = weights[0][0:2], activation = 'linear'))
## add the rest of the encoders
for i in range(1, len(dims) - 1):
encoder.add(Dense(output_dim = dims[i+1],
weights = weights[i][0:2], activation = self.act))
## add the decoders from the end
decoder.add(Dense(output_dim = dims[len(dims) - 2], input_dim = dims[len(dims) - 1],
weights = weights[len(dims) - 2][2:4], activation = self.act))
for i in range(len(dims) - 2, 1, -1):
decoder.add(Dense(output_dim = dims[i - 1],
weights = weights[i-1][2:4], activation = self.act))
## add the output layer decoder
decoder.add(Dense(output_dim = dims[0],
weights = weights[0][2:4], activation = 'linear'))
masterAE = AutoEncoder(encoder = encoder, decoder = decoder)
masterModel = models.Sequential()
masterModel.add(masterAE)
masterModel.compile(optimizer = 'sgd', loss = 'mse')
masterModel.fit(X_train, X_train, nb_epoch = nb_epoch)
self.model = masterModel
Des suggestions sur la façon de résoudre le problème seraient appréciées. Mon soupçon original était que c'était quelque chose à voir avec tensorflow, donc j'ai essayé de courir avec le backend theano mais j'ai rencontré le même problème.
Here est un lien vers le programme complet.