2017-09-27 3 views
-1

J'ai implémenté un MLP simple dans tensorflow. La structure est une classe NeuralNet:Tensoflow: l'appel de la prédiction d'une fonction renvoie 'RuntimeError: tentative d'utilisation d'une session fermée'

class NeuralNet: 
    def __init__(self, **options): 
     self.type = options.get('net_type') # MLP, CNN, RNN 
     self.n_class = options.get('classes') 
     self.alpha = options.get('alpha') 
     self.batch_size = options.get('batch_size') 
     self.epoch = options.get('epochs') 
     self.model = {} 

Il a 3 fonctions différentes:

  • Fit:

    def fit (self, features, labels): 
        if self.type == 'MLP': 
        input_size = len(features[0]) 
        n_nodes_hl1 = input_size//5 
        batch_size = 50 
    
        sess = tf.InteractiveSession() 
    
        x = tf.placeholder(tf.float32, [None, input_size]) 
        y = tf.placeholder(tf.float32, [None, self.n_class]) 
        labels = self.labels_to_onehot(labels) 
    
        weights = {'hidden_1': tf.Variable(tf.random_normal([input_size, n_nodes_hl1])), 
            'output': tf.Variable(tf.random_normal([n_nodes_hl1, self.n_class]))} 
    
        biases = {'hidden_1': tf.Variable(tf.random_normal([n_nodes_hl1])), 
            'output': tf.Variable(tf.random_normal([self.n_class]))} 
    
        def neural_network_model(data, weight, bias): 
    
          l1 = tf.add(tf.matmul(data, weight['hidden_1']), bias['hidden_1']) 
          l1 = tf.nn.relu(l1) 
    
          output = tf.matmul(l1, weight['output']) + bias['output'] 
    
          return output 
    
        sess.run(tf.global_variables_initializer()) 
    
        prediction = neural_network_model(x, weights, biases) 
        l2 = self.alpha * tf.nn.l2_loss(weights['hidden_1']) + self.alpha * tf.nn.l2_loss(weights['output']) 
        cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction)+l2) 
        train_step = tf.train.AdamOptimizer(0.005).minimize(cross_entropy) 
    
        sess=tf.Session() 
        sess.run(tf.global_variables_initializer()) 
    
        for epoch in range(self.epoch): 
          epoch_loss = 0 
          i = 0 
          while i < len(features): 
           start = i 
           end = i + batch_size 
           batch_x = np.array(features[start:end]) 
           batch_y = np.array(labels[start:end]) 
    
           _, c = sess.run([train_step, cross_entropy], feed_dict={x: batch_x, 
                       y: batch_y}) 
           epoch_loss += c 
           i += batch_size 
    
    
    
        self.model['session'] = sess 
        self.model['y'] = y 
        self.model['x'] = x 
        self.model['prediction'] = prediction 
    
  • Test (précision test):

    def test(self, test_features, test_labels): 
        with self.model['session']: 
         test_labels = np.eye(self.n_class)[[int(int(i)/2) for i in test_labels]] 
         correct = tf.equal(tf.argmax(self.model['prediction'], 1), tf.argmax(self.model['y'], 1)) 
         accuracy = tf.reduce_mean(tf.cast(correct, 'float')) 
         accuracy = accuracy.eval({self.model['x']: test_features, self.model['y']: test_labels}) 
         print('Accuracy:', accuracy) 
         return accuracy 
    
  • Pronostiez

    def predict(self, test_features): 
        with self.model['session']: 
         pred = self.model['prediction'] 
         predicted = pred.eval({self.model['x']: test_features}) 
         return predicted 
    

Lors de l'exécution de la méthode prédire, il renvoie une RuntimeError: ('Attempted to use a closed Session.')

Ma question est:

Pourquoi la méthode bon déroulement des test, tout en appelant la session de la même manière la méthode predict échoue?

Aurais-je besoin de créer un objet tf et de l'évaluer? Si oui, quel objet devrait-il être?

Répondre

0

Je ne pouvais pas exécuter votre code, mais j'ai une hypothèse. peut-être que vous exécutez la fonction de test après l'installation est terminée. Dans la fonction de test, vous gérez la session à l'aide du gestionnaire de contextes ('with block'). donc, mon hypothèse est, votre session est automatiquement fermée après que le bloc du gestionnaire de contexte de session est terminé.

def test(self, test_features, test_labels): 
    with self.model['session']: 
     test_labels = np.eye(self.n_class)[[int(int(i)/2) for i in test_labels]] 
     correct = tf.equal(tf.argmax(self.model['prediction'], 1), tf.argmax(self.model['y'], 1)) 
     accuracy = tf.reduce_mean(tf.cast(correct, 'float')) 
     accuracy = accuracy.eval({self.model['x']: test_features, self.model['y']: test_labels}) 
     print('Accuracy:', accuracy) 
     return accuracy 
    ## at this point, your session is maybe closed. 

Si mon hypothèse est juste, vous pouvez juste passer le sess, et exécuter le graphique en utilisant la méthode sess.run (~) et fermer manuellement utiliser sess.close().

P.S. Pourquoi avez-vous d'abord assigner interactiveSession() à sess, et lancez tf.global_variables_initializer()? Je pense que vous pouvez juste utiliser tf.global_variables_initializer() une fois, entre la construction du graphe de point est terminée et l'entraînement de point de départ.

MISE À JOUR Dans mon hypothèse, la fonction que vous exécutez à la première fois n'a pas d'importance, parce que les deux l'utilisation de la fonction « avec » bloc. toute fonction qui s'exécute en premier fermera la session à la fin de block.

Je suggère de passer le sess, cela signifie que,

def fit(self, ~): 
    # construct graph 
    self.sess = tf.Session() 

def test(self, ~): 
    # codes will be here 
    acc_val = self.sess.run([accuracy], feed_dict={~}) 
    return acc_val 

def predict(self, ~): 
    # codes will be here 
    predicted = self.sess.run([pred], feed_dict={~}) 
    return predicted 

J'espère que ce code vous donne une intuition.

+0

Salut, merci pour votre réponse. Cela m'a donné une bonne compréhension du problème, mais pas une solution réelle. En fait, il n'y a pas de problème quand je lance un test ou une prédiction en premier, mais je ne peux pas utiliser l'un d'entre eux une seconde fois. Mais je ne peux pas comprendre comment utiliser ce que vous avez mentionné. Pourriez-vous s'il vous plaît mettre à jour votre code? – ylnor