J'ai essayé de trouver des exemples simples que j'ai créés en travaillant, parce que je trouve les exemples donnés avec de grands ensembles de données complexes difficiles à saisir intuitivement. Le programme ci-dessous prend une liste de poids [x_0 x_1 ... x_n]
et les utilise pour créer une diffusion aléatoire de points sur un plan avec un bruit aléatoire ajouté. Ensuite, j'entraîne les réseaux neuronaux simples sur ces données et vérifie les résultats. Quand je fais cela avec les modèles Graph tout fonctionne parfaitement, le score de perte descend à zéro de façon prévisible lorsque le modèle converge sur les poids donnés. Cependant, lorsque j'essaie d'utiliser un modèle séquentiel, rien ne se passe. Code ci-dessousEssayer d'obtenir un simple exemple de réseau neuronal Keras
Si vous voulez, je peux poster mon autre script qui utilise le graphique au lieu de séquentiel et montrer qu'il trouve les poids d'entrée parfaitement.
#!/usr/bin/env python
from keras.models import Sequential, Graph
from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD
import numpy as np
import theano, sys
NUM_TRAIN = 100000
NUM_TEST = 10000
INDIM = 3
mn = 1
def myrand(a, b) :
return (b)*(np.random.random_sample()-0.5)+a
def get_data(count, ws, xno, bounds=100, rweight=0.0) :
xt = np.random.rand(count, len(ws))
xt = np.multiply(bounds, xt)
yt = np.random.rand(count, 1)
ws = np.array(ws, dtype=np.float)
xno = np.array([float(xno) + rweight*myrand(-mn, mn) for x in xt], dtype=np.float)
yt = np.dot(xt, ws)
yt = np.add(yt, xno)
return (xt, yt)
if __name__ == '__main__' :
if len(sys.argv) > 1 :
EPOCHS = int(sys.argv[1])
XNO = float(sys.argv[2])
WS = [float(x) for x in sys.argv[3:]]
mx = max([abs(x) for x in (WS+[XNO])])
mn = min([abs(x) for x in (WS+[XNO])])
mn = min(1, mn)
WS = [float(x)/mx for x in WS]
XNO = float(XNO)/mx
INDIM = len(WS)
else :
INDIM = 3
WS = [2.0, 1.0, 0.5]
XNO = 2.2
X_test, y_test = get_data(10000, WS, XNO, 10000, rweight=0.4)
X_train, y_train = get_data(100000, WS, XNO, 10000)
model = Sequential()
model.add(Dense(INDIM, input_dim=INDIM, init='uniform', activation='tanh'))
model.add(Dropout(0.5))
model.add(Dense(2, init='uniform', activation='tanh'))
model.add(Dropout(0.5))
model.add(Dense(1, init='uniform', activation='softmax'))
sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='mean_squared_error', optimizer=sgd)
model.fit(X_train, y_train, shuffle="batch", show_accuracy=True, nb_epoch=EPOCHS)
score, acc = model.evaluate(X_test, y_test, batch_size=16, show_accuracy=True)
print score
print acc
predict_data = np.random.rand(100*100, INDIM)
predictions = model.predict(predict_data)
for x in range(len(predict_data)) :
print "%s --> %s" % (str(predict_data[x]), str(predictions[x]))
La sortie est comme suit
$ ./keras_hello.py 20 10 5 4 3 2 1
Using gpu device 0: GeForce GTX 970 (CNMeM is disabled)
Epoch 1/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 2/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 3/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 4/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 5/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 6/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 7/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 8/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 9/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 10/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 11/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 12/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 13/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 14/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 15/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 16/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 17/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 18/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 19/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
Epoch 20/20
100000/100000 [==============================] - 0s - loss: 60726734.3061 - acc: 1.0000
10000/10000 [==============================] - 0s
60247198.6661
1.0
[ 0.06698217 0.70033048 0.4317502 0.78504855 0.26173543] --> [ 1.]
[ 0.28940025 0.21746189 0.93097653 0.94885535 0.56790348] --> [ 1.]
[ 0.69430499 0.1622601 0.22802859 0.75709315 0.88948355] --> [ 1.]
[ 0.90714721 0.99918648 0.31404901 0.83920051 0.84081288] --> [ 1.]
[ 0.02214092 0.03132355 0.14417082 0.33901317 0.91491426] --> [ 1.]
[ 0.31426055 0.80830795 0.46686523 0.58353359 0.50000842] --> [ 1.]
[ 0.27649579 0.77914451 0.33572287 0.08703303 0.50865592] --> [ 1.]
[ 0.99280349 0.24028343 0.05556034 0.31411902 0.41912574] --> [ 1.]
[ 0.91897031 0.96840695 0.23561379 0.16005505 0.06567748] --> [ 1.]
[ 0.27392867 0.44021533 0.44129147 0.40658522 0.47582736] --> [ 1.]
[ 0.82063221 0.95182938 0.64210378 0.69578691 0.2946907 ] --> [ 1.]
[ 0.12672415 0.35700418 0.89303047 0.80726545 0.79870725] --> [ 1.]
[ 0.6662085 0.41358115 0.76637022 0.82093095 0.76973305] --> [ 1.]
[ 0.96201937 0.29706843 0.22856618 0.59924945 0.05653825] --> [ 1.]
[ 0.34120276 0.71866377 0.18758929 0.52424856 0.64061623] --> [ 1.]
[ 0.25471237 0.35001821 0.63248632 0.45442404 0.96967989] --> [ 1.]
[ 0.79390087 0.00100834 0.49645204 0.55574269 0.33487764] --> [ 1.]
[ 0.41330261 0.38061826 0.33766183 0.23133121 0.80999653] --> [ 1.]
[ 0.49603561 0.33414841 0.10180184 0.9227252 0.35073833] --> [ 1.]
[ 0.17960345 0.05259438 0.565135 0.40465603 0.91518233] --> [ 1.]
[ 0.36129943 0.903603 0.63047644 0.96553285 0.94006713] --> [ 1.]
[ 0.7150973 0.93945141 0.31802763 0.15849441 0.92902078] --> [ 1.]
[ 0.23730571 0.65360248 0.68776259 0.79697206 0.86814652] --> [ 1.]
[ 0.47414382 0.75421265 0.32531333 0.43218305 0.4680773 ] --> [ 1.]
[ 0.4887811 0.66130135 0.79913557 0.68948405 0.48376372] --> [ 1.]
....