2017-05-25 2 views
2

J'utilisais dynamic_rnn avec un LSTMCell, qui affichait un LSTMStateTuple contenant l'état interne. Appeler reshape sur cet objet (par erreur) donne un tenseur sans provoquer d'erreur lors de la création du graphe. Je n'ai pas non plus eu d'erreur lors de l'exécution lors de la saisie des données dans le graphique.Appeler reshape sur un LSTMStateTuple le transforme en un tenseur

code:

cell = tf.contrib.rnn.LSTMCell(size, state_is_tuple=True, ...) 
outputs, states = tf.nn.dynamic_rnn(cell, inputs, ...) 
print(states) # state is an LSTMStateTuple 
states = tf.reshape(states, [-1, size]) 
print(states) # state is a tensor of shape [?, size] 

Est-ce un bug (je demande parce qu'il est pas documenté nulle part)? Quel est le maintien du tenseur remodelé?

Répondre

0

J'ai mené une expérience similaire qui peut vous donne quelques conseils:

>>> s = tf.constant([[0, 0, 0, 1, 1, 1], 
        [2, 2, 2, 3, 3, 3]]) 
>>> t = tf.constant([[4, 4, 4, 5, 5, 5],                
        [6, 6, 6, 7, 7, 7]]) 
>>> g = tf.reshape((s, t), [-1, 3]) # <tf.Tensor 'Reshape_1:0' shape=(8, 3) dtype=int32> 
>>> sess.run(g) 
array([[0, 0, 0], 
     [1, 1, 1], 
     [2, 2, 2], 
     [3, 3, 3], 
     [4, 4, 4], 
     [5, 5, 5], 
     [6, 6, 6], 
     [7, 7, 7]], dtype=int32) 

Nous pouvons voir qu'il concaténer juste deux tenseurs dans la première dimension et effectue la remise en forme. Puisque le LSTMStateTuple est comme un namedtuple, il a le même effet que tuple et je pense que c'est aussi ce qui se passe dans votre cas.

Allons plus loin,

>>> st = tf.contrib.rnn.LSTMStateTuple(s, t) 
>>> gg = tf.reshape(st, [-1, 3]) 
>>> sess.run(gg) 
    array([[0, 0, 0], 
      [1, 1, 1], 
      [2, 2, 2], 
      [3, 3, 3], 
      [4, 4, 4], 
      [5, 5, 5], 
      [6, 6, 6], 
      [7, 7, 7]], dtype=int32) 

Nous pouvons voir que si nous créons un LSTMStateTuple, le résultat vérifie notre hypothèse.