2016-12-19 2 views
4

Je trouve ce qui suit sur la page d'accueil de la documentation tensorflow pour l'utilisation de l'opération de matmul lorsque rang> 2:opération tensorflow de matmul pour le rang> 2 ne fonctionne pas

https://www.tensorflow.org/api_docs/python/math_ops/matrix_math_functions#matmul

# 3-D tensor `a` 
a = tf.constant(np.arange(1,13), shape=[2, 2, 3]) => [[[ 1. 2. 3.] 
                [ 4. 5. 6.]], 
                [[ 7. 8. 9.] 
                [10. 11. 12.]]] 

# 3-D tensor `b` 
b = tf.constant(np.arange(13,25), shape=[2, 3, 2]) => [[[13. 14.] 
                [15. 16.] 
                [17. 18.]], 
                [[19. 20.] 
                [21. 22.] 
                [23. 24.]]] 
c = tf.matmul(a, b) => [[[ 94 100] 
        [229 244]], 
        [[508 532] 
        [697 730]]] 

Il est tout simplement ne fonctionne pas quand je le branche en Python. Je reçois

c = tf.matmul(a, b) 
ValueError: Shape must be rank 2 but is rank 3 

Quelqu'un sait ce qui ne va pas?

+0

Je suis pas familier avec tensorflow donc je ne sais pas combien un tableau de tensorflow est comparé à un tableau numpy, mais vous pouvez essayer 'numpy.dot' (mais l'ordre des dimensions pourrait être différent là-haut). Quoi qu'il en soit, il est possible que votre version de tensorflow soit plus ancienne que ce à quoi correspond la documentation (en supposant que la documentation soit correcte). –

Répondre

1

Votre TensorFlow est-il trop ancien? Voici ce que je reçois dans la version 0.12rc0

a = tf.constant(np.arange(1,13).astype(np.float32), shape=[2, 2, 3]) 
b = tf.constant(np.arange(13,25).astype(np.float32), shape=[2, 3, 2]) 
sess.run(tf.matmul(a, b)) => 

array([[[ 94., 100.], 
     [ 229., 244.]], 

     [[ 508., 532.], 
     [ 697., 730.]]], dtype=float32) 
+0

Merci, mais votre code me donne la même erreur ... Installé tensorflow il y a quelques semaines, puis-je avoir la mauvaise version? Savez-vous comment je vérifie quelle version je cours actuellement? Merci – user7318197

+0

@ user7318197 la plupart des modules supportent la méthode '__version__', alors essayez d'imprimer' tf .__ version__'. –

+0

D'accord, merci. Je reçois "0.11.0" comme réponse. On dirait qu'il y a une version 0.12.0, donc jusqu'à essayer de l'installer et de voir si ça aide. – user7318197