2016-06-17 5 views
0

Je veux transformer un tenseur (appeler logits) de la formeMNIST Tensorflow: Comment manipuler un tenseur de la forme [i] vers un tenseur d'une forme [... 0,0,0,1,0,0 ...] où 1 est à la position i?

int32 - [batch_size] 

à un tenseur (appeler étiquettes) de la forme

[batch_size, 10] 

par exemple pour batch_size = 3

logits=[1,6,9] 
labels=[[0,1,0,0,0,0,0,0,0,0], 
     [0,0,0,0,0,0,1,0,0,0], 
     [0,0,0,0,0,0,0,0,0,1]] 

Cette question est venue parce que je veux changer la fonction de coût à une fonction quadratique dans l'exemple tensistflow mnist (https://github.com/tensorflow/tensorflow/tree/r0.9/tensorflow/examples/tutorials/mnist) J'utilise fully_connected_feed.py et dans mnist.py. Dans mnist.py Je souhaite modifier:

cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels, name='xentropy') 
    loss = tf.reduce_mean(cross_entropy, name='xentropy_mean') 

à

loss= tf.reduce_sum(tf.squared_difference(logits,labels)) 

Mais le problème ist, que:

Logits tensor, float - [batch_size, 10]; 
Labels tensor, int64 - [batch_size]. 

donc je dois "vectoriser" les étiquettes !? Est-ce que quelqu'un a une idée de comment faire cela?

Répondre

1

L'étiquette "vectorisation" est appelée un codage à chaud.

Vous recherchez une fonction tf.one_hot.

Cette fonction prend:

  1. Une liste d'indices (votre logits vecteur)
  2. Un paramètre depth: c'est la profondeur du vecteur un chaud (la longueur de l'étiquette codée d'un chaud)
  3. on_value & off_value que vous pouvez changer si vous voulez (mais la valeur par défaut de 1 et 0 sont ce que vous cherchez).
  4. dtype est le type de sortie du tenseur.

Ainsi, vous pouvez encoder un chaud avec vos étiquettes:

one_hot_labels = tf.one_hot(logits, 10, dtype=tf.uint8) 

one_hot_labels est un objet tf.Tensor.

Si vous avez besoin d'accéder à partir de python à son contenu, n'oubliez pas de l'évaluer (ou l'exécuter).

Voici un exemple de jouet:

import tensorflow as tf. 
tf.InteractiveSession() 
logits=[1,6,9] 
one_hot_labels = tf.one_hot(logits, 10, dtype=tf.uint8) 
print(one_hot_labels.eval()) 

Sorties:

[[0 1 0 0 0 0 0 0 0 0] 
[0 0 0 0 0 0 1 0 0 0] 
[0 0 0 0 0 0 0 0 0 1]] 
+1

Thak vous nessuno, qui était exactement ce que je cherchais. Mais si j'écris vectorized_labels = tf.one_hot (labels, 10) je reçois toujours TypeError: one_hot() prend au moins 4 arguments (2 donnés) ?? –