2017-10-07 6 views
0

Donc, j'essaie d'utiliser un rnn en tensorflow pour générer du texte. Cependant, une fois que je suis passé d'un static_rnn à un dynamic_rnn, je suis arrivé cette erreur:Tensorflow dynamic_rnn entrée erreur de rang

File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/tensor_shape.py", line 654, in with_rank_at_least 
    raise ValueError("Shape %s must have rank at least %d" % (self, rank)) 
ValueError: Shape (100, 5) must have rank at least 3 

Ceci est la partie du code qui génère l'erreur:

inputs_series = self.input_layer() 
with tf.variable_scope(constants.HIDDEN): 
    self.hidden_state_placeholder = tf.placeholder(
     dtype=tf.float32, 
     shape=[self.settings.train.batch_size, self.settings.rnn.hidden_size], 
     name="hidden_state_placeholder") 
    cell = tf.contrib.rnn.GRUCell(self.settings.rnn.hidden_size) 
    states_series, self.current_state = tf.nn.dynamic_rnn(
     cell=cell, 
     inputs=inputs_series, 
     initial_state=self.hidden_state_placeholder) 

La forme de inputs_series est: (10, 5, 100), correspondant à (longueur de texte tronquée, taille de lot, nombre de classes)

La forme du hidden_state_placeholder est (5, 100) pour (taille de lot, taille de l'état masqué), mais l'erreur persiste même lorsque je ne fournis pas d'état initial.

La version tensorflow est 1.3, si cela aide.

Toute idée serait appréciée!

Répondre

0

Par défaut, time_major == False dans tf.nn.dynamic_rnn, mais votre inputs_series est time_major == True. Alors peut-être changer la dernière déclaration à

states_series, self.current_state = tf.nn.dynamic_rnn(
    cell=cell, 
    inputs=inputs_series, 
    initial_state=self.hidden_state_placeholder, 
    time_major=True) 
+0

Désolé pour la réponse tardive - Je ai juste essayé de cela, et cela n'a pas résolu le problème - J'ai reçu le même message d'erreur. C'était ** un bug légitime dans mon code, merci d'avoir attrapé ça! – frankie

+0

'inputs_series = self.input_layer()' Je n'ai pas pu tester cette instruction, donc je la remplace par une variable de taille fixe (ainsi que d'autres paramètres inconnus). Veuillez imprimer la taille de 'inputs_series' (par exemple' print (inputs_series.get_shape()) '). Je suppose que cela pourrait avoir un problème. –

+0

J'ai fait les tirages, et je l'ai compris. Une opération 'tf.unstack()' s'est déroulée dans 'self.input_layer()', qui a converti le tenseur en une liste de 10 tenseurs. Apparemment, cela ne joue pas bien avec l'API dynamique rnn. Une fois que j'ai supprimé l'opération de désempilage, tout a fonctionné comme prévu. Merci pour toutes les suggestions! – frankie