2017-05-15 3 views
0

Je suis en utilisant DNN de tflearn, et je veux changer mes caractéristiques et les étiquettes pour être catégoriques et non numériques.TFlearn à catégorique

ici est mon filet:

x = tf.placeholder(dtype= tf.float32, shape=[None, 6], name='x') 
# Build neural network 
input_layer = tflearn.input_data(shape=[None, 6]) 
net = input_layer 
net = tflearn.fully_connected(net, 128, activation='relu') 
net = tflearn.fully_connected(net, 64, activation='relu') 
net = tflearn.fully_connected(net, 16, activation='relu') 
net = tflearn.fully_connected(net, 2, activation='sigmoid') 
net = tflearn.regression(net, optimizer='adam', loss='mean_square', metric='R2') 

w = tf.Variable(tf.truncated_normal([2, 2], stddev=0.1)) 
b = tf.Variable(tf.constant(1.0, shape=[2])) 
y = tf.nn.softmax(tf.matmul(net, w) + b, name='y') 

model = tflearn.DNN(net, tensorboard_verbose=3) 
return model 

Je sais tflearn.data_utils.to_categorical mais je ne sais pas comment injecter cette méthode. grâce

EDIT: J'ai essayé quelques petites choses, comme:

train_goal = tflearn.data_utils.to_categorical(train_goal, nb_classes=2) 
      test_goal = tflearn.data_utils.to_categorical(test_goal, nb_classes=2) 

et changer aussi la perte:

net = tflearn.regression(net, optimizer='adadelta', loss='categorical_crossentropy', metric= self.accuracy) 

mais je suis une perte de plus de 1:

Training Step: 35 | total loss: 1.64734 | time: 1.322s 
| AdaDelta | epoch: 001 | loss: 1.64734 - acc: 1.0000 | val_loss: 1.64313 - val_acc: 1.0000 -- iter: 2204/2204 
-- 
Training Step: 70 | total loss: 1.61961 | time: 0.216s 
| AdaDelta | epoch: 002 | loss: 1.61961 - acc: 1.0000 | val_loss: 0.00000 - val_acc: 0.0000 -- iter: 2204/2204 
-- 
Training Step: 105 | total loss: 1.58511 | time: 1.188s 
| AdaDelta | epoch: 003 | loss: 1.58511 - acc: 1.0000 | val_loss: 1.57300 - val_acc: 1.0000 -- iter: 2204/2204 

où est le problème?

Répondre

0

J'ai une erreur similaire, aussi une perte très élevée. Essayez d'utiliser train_goal.T [0] au lieu de train_goal. Assurez-vous que l'entrée y de to_categorical a la forme suivante: (n,)