2016-08-10 2 views
2

Je m'entraînais en LSTM bidirectionnel de type RNN pendant près de 24 heures, et en raison de l'oscillation dans l'erreur j'ai décidé de diminuer l'apprentissage avant de lui permettre de continuer l'entraînement. Depuis que le modèle est enregistré en utilisant Saver.save (sess, fichier) à chaque époque, j'ai terminé l'entraînement avec la perte de CTC ayant réduit à environ 115.TensorFlow - Saver.restore ne restaure pas tous les paramètres

Maintenant, après la restauration du modèle, le taux d'erreur initial que je reçois est quelque part autour de 162, ce qui est incompatible avec le flux d'erreur que je recevais à la 7e époque, et c'est aussi ce que j'ai eu à la première époque. Donc, j'ai l'impression que la fonction "restore" ne fonctionne pas ou si elle fonctionne, alors il doit y avoir quelque chose d'autre qui ne lui permet pas de prendre effet.

Voici mon code:

graph = tf.Graph() 
    with graph.as_default(): 
     # Graph creation 
     graph_start = time.time() 
     seq_inputs = tf.placeholder(tf.float32, shape=  [None,batch_size,frame_length], name="sequence_inputs") 
     seq_lens = tf.placeholder(shape=[batch_size],dtype=tf.int32) 
     seq_inputs = seq_bn(seq_inputs,seq_lens) 

     initializer = tf.truncated_normal_initializer(mean=0,stddev=0.1) 
     forward = tf.nn.rnn_cell.LSTMCell(num_units=num_units, 
              num_proj = hidden_size, 
              use_peepholes=use_peephole, 
              initializer=initializer, 
              state_is_tuple=True) 

     forward = tf.nn.rnn_cell.MultiRNNCell([forward] * n_layers, state_is_tuple=True) 

     backward = tf.nn.rnn_cell.LSTMCell(num_units=num_units, 
              num_proj= hidden_size, 
              use_peepholes=use_peephole, 
              initializer=initializer, 
              state_is_tuple=True) 

     backward = tf.nn.rnn_cell.MultiRNNCell([backward] * n_layers, state_is_tuple=True) 

     [fw_out,bw_out], _ = tf.nn.bidirectional_dynamic_rnn(cell_fw=forward, cell_bw=backward, inputs=seq_inputs,time_major=True, dtype=tf.float32,            sequence_length=tf.cast(seq_lens,tf.int64)) 


     # Batch normalize forward output 
     mew,var_ = tf.nn.moments(fw_out,axes=[0]) 
     fw_out = tf.nn.batch_normalization(fw_out,mew,var_,0.1,1,1e-6) 
     # fw_out = seq_bn(fw_out,seq_lens) 

     # Batch normalize backward output 
     mew,var_ = tf.nn.moments(bw_out,axes=[0]) 
     bw_out = tf.nn.batch_normalization(bw_out,mew,var_,0.1,1,1e-6) 
     # bw_out = seq_bn(bw_out,seq_lens) 

     # Reshaping forward, and backward outputs for affine transformation 
     fw_out = tf.reshape(fw_out,[-1,hidden_size]) 
     bw_out = tf.reshape(bw_out,[-1,hidden_size]) 

     # Linear Layer params 
     W_fw = tf.Variable(tf.truncated_normal(shape=[hidden_size,n_chars],stddev=np.sqrt(2.0/(hidden_size)))) 
     W_bw = tf.Variable(tf.truncated_normal(shape=[hidden_size,n_chars],stddev=np.sqrt(2.0/(hidden_size)))) 
     b_out = tf.constant(0.1,shape=[n_chars]) 

     # Perform an affine transformation 
     logits = tf.add(tf.add(tf.matmul(fw_out,W_fw),tf.matmul(bw_out,W_bw)),b_out) 
     logits = tf.reshape(logits,[-1,batch_size,n_chars]) 

     # Use CTC Beam Search Decoder to decode pred string from the prob map 
     decoded, log_prob = tf.nn.ctc_beam_search_decoder(logits, seq_lens) 

     # Target params 
     indices = tf.placeholder(dtype=tf.int64, shape=[None,2]) 
     values = tf.placeholder(dtype=tf.int32, shape=[None]) 
     shape = tf.placeholder(dtype=tf.int64,shape=[2]) 
     # Make targets 
     targets = tf.SparseTensor(indices,values,shape) 

     # Compute Loss 
     loss = tf.reduce_mean(tf.nn.ctc_loss(logits, targets, seq_lens)) 
     # Compute error rate based on edit distance 
     predicted = tf.to_int32(decoded[0]) 
     error_rate = tf.reduce_sum(tf.edit_distance(predicted,targets,normalize=False))/ \ 
     tf.to_float(tf.size(targets.values))  

     tvars = tf.trainable_variables() 
     grad, _ = tf.clip_by_global_norm(tf.gradients(loss,tvars),max_grad_norm) 
     optimizer = tf.train.MomentumOptimizer(learning_rate=lr,momentum=momentum) 
     train_step = optimizer.apply_gradients(zip(grad,tvars)) 
     graph_end = time.time() 
     print("Time elapsed for creating graph: %.3f"%(round(graph_end-graph_start,3))) 
     # steps per epoch 
     start_time = 0 
     steps = int(np.ceil(len(data_train.files)/batch_size)) 

     loss_tr = [] 
     log_tr = [] 
     loss_vl = [] 
     log_vl = [] 
     err_tr = [] 
     err_vl = [] 
     saver = tf.train.Saver() 
     with tf.Session(config=config) as sess: 
      #sess.run(tf.initialize_all_variables()) 
      checkpt_path = tf.train.latest_checkpoint(checkpoint_dir) 
      print(saver.restore(sess,checkpt_path)) 
      print("Model restore from 7th epoch 188th step") 
      feed = None 
      epoch = None 
      step = None 
      try: 
       for epoch in range(7,epochs+1): 
        if epoch==7: 
         initial_step = 189 
        else: 
         initial_step = 0 
        transcript = [] 
        loss_val = 0 
        l_pr = 0 
        start_time = time.time() 
        for step in range(initial_step,steps): 
         train_data, transcript, \ 
         targ_indices, targ_values, \ 
         targ_shape, n_frames = data_train.next_batch() 
         n_frames = np.reshape(n_frames,[-1]) 
         feed = {seq_inputs: train_data, indices:targ_indices, values:targ_values, shape:targ_shape, seq_lens:n_frames} 
         del train_data,targ_indices,targ_values,targ_shape,n_frames 

         # Evaluate loss value, decoded transcript, and log probability 
         _,loss_val,deco,l_pr,err_rt_tr = sess.run([train_step,loss,decoded,log_prob,error_rate], 
                  feed_dict=feed) 
         del feed 
         loss_tr.append(loss_val) 
         log_tr.append(l_pr) 
         err_tr.append(err_rt_tr) 

         # On validation set 
         val_data, val_transcript, \ 
         targ_indices, targ_values, \ 
         targ_shape, n_frames = data_val.next_batch() 
         n_frames = np.reshape(n_frames, [-1]) 
         feed = {seq_inputs: val_data, indices: targ_indices,values: targ_values, shape: targ_shape, seq_lens: n_frames} 
         del val_data, val_transcript,targ_indices,targ_values,targ_shape,n_frames 
        vl_loss, l_val_pr, err_rt_vl = sess.run([loss, log_prob, error_rate], feed_dict=feed) 
         del feed 
         loss_vl.append(vl_loss) 
         log_vl.append(l_val_pr) 
         err_vl.append(err_rt_vl) 
         print("epoch %d, step: %d, tr_loss: %.2f, vl_loss: %.2f, tr_err: %.2f, vl_err: %.2f" 
          % (epoch, step, np.mean(loss_tr), np.mean(loss_vl), err_rt_tr, err_rt_vl)) 

        end_time = time.time() 
        elapsed = round(end_time - start_time, 3) 

        # On training set 
        # Select a random index within batch_size 
        sample_index = np.random.randint(0, batch_size) 

        # Fetch the target transcript 
        actual_str = [data_train.reverse_map[i] for i in transcript[sample_index]] 

        # Fetch the decoded path from probability map 
        pred_sparse = tf.SparseTensor(deco[0].indices, deco[0].values, deco[0].shape) 
        pred_dense = tf.sparse_tensor_to_dense(pred_sparse) 
        ans = pred_dense.eval() 
        #pred = [data_train.reverse_map[i] for i in ans[sample_index, :]] 
        pred = [] 
        for i in ans[sample_index,:]: 
         if i == n_chars-1: 
          pred.append(data_train.reverse_map[0]) 
         else: 
          pred.append(data_train.reverse_map[i]) 
        print("time_elapsed for 200 steps: %.3f, " % (elapsed)) 
        if epoch%2 == 0: 
         print("Sample mini-batch results: \n" \ 
           "predicted string: ", np.array(pred)) 
         print("actual string: ", np.array(actual_str)) 
        print("On training set, the loss: %.2f, log_pr: %.3f, error rate %.3f:"% (loss_val, np.mean(l_pr), err_rt_tr)) 
        print("On validation set, the loss: %.2f, log_pr: %.3f, error rate: %.3f" % (vl_loss, np.mean(l_val_pr), err_rt_vl)) 

        # Save the trainable parameters after the end of an epoch 
        if epoch > 7: 
         path = saver.save(sess, 'model_%d' % epoch) 
        print("Session saved at: %s" % path) 
        np.save(results_fn, np.array([loss_tr, log_tr, loss_vl, log_vl, err_tr, err_vl], dtype=np.object)) 
      except (KeyboardInterrupt, SystemExit, Exception), e: 
       print("Error/Interruption: %s" % str(e)) 
       exc_type, exc_obj, exc_tb = sys.exc_info() 
       print("Line no: %d" % exc_tb.tb_lineno) 
       if epoch > 7: 
        print("Saving model: %s" % saver.save(sess, 'Last.cpkt')) 
       print("Current batch: %d" % data_train.b_id) 
       print("Current epoch: %d" % epoch) 
       print("Current step: %d"%step) 
       np.save(results_fn, np.array([loss_tr, log_tr, loss_vl, log_vl, err_tr, err_vl], dtype=np.object)) 
       print("Clossing TF Session...") 
       sess.close() 
       print("Terminating Program...") 
       sys.exit(0) 
+0

si vous n'êtes pas en cours d'exécution 'initialize_all_variables', puis restaurer doit être obtenir toutes les variables de postes de contrôle (ou vous obtiendrez une erreur variable non initialisée) –

+0

BTW, un modèle commun de détecter rapidement ce genre de problèmes de point de contrôle faire l'évaluation en parallèle dans un processus différent, en même temps que le programme principal –

+0

@YaroslavBulatov J'ai d'abord eu l'appel à restaurer après l'initialisation des variables, mais dans un blog, j'ai lu qu'il n'est pas nécessaire que les variables doivent être initialisées lors de la restauration du point de contrôle, d'où je l'ai commenté. Je ne reçois aucune erreur, le programme fonctionne bien. Mon inquiétude est que ce n'est probablement PAS en restaurant les paramètres du modèle dans l'état où il est sauvegardé dans le fichier puisque je reçois le taux d'erreur d'entraînement que j'ai obtenu à la première époque. –

Répondre

0

Je pense que vous devez réinitialiser vos accumulateurs pour chaque époque.

Donc ceux-ci doivent être placés au début, à l'intérieur de la boucle.

loss_tr = [] 
log_tr = [] 
loss_vl = [] 
log_vl = [] 
err_tr = [] 
err_vl = []