Dans l'exemple de code ci-dessous, l'entrée a la forme [2,3,4,5]
et la forme résultante est [2,3,4]
.
Les idées principales sont:
- Il est facile d'obtenir une ligne au lieu d'une colonne à l'aide
gather_nd
, donc je suis passé les deux dernières dimensions avec tf.transpose
.
- Nous devons convertir les indices que nous obtenons de
tf.argmax
(indices
ci-dessous) en quelque chose de vraiment utilisable (voir final_idx
ci-dessous) en tf.gather_nd
. La conversion se fait par empilement de trois composants:
[0 0 0 1 1 1]
[0 1 2 0 1 2]
[3 3 3 0 0 0]
On pourrait donc aller de [3, 0]
à
[[[0 0 3]
[0 1 3]
[0 2 3]]
[[1 0 0]
[1 1 0]
[1 2 0]]].
Batch,Y,X = 2, 3, 4
tf.reset_default_graph()
data = np.arange(Batch*Y*X*5)
np.random.shuffle(data)
Params = tf.constant(np.reshape(data, [Batch, Y, X, 5]), dtype=tf.int32)
indices = tf.argmax(tf.reduce_sum(Params, [1,2]), 1)
indices = tf.cast(tf.reshape(tf.tile(tf.reshape(indices, [-1,1]),
[1,Y]), [-1]), tf.int32)
idx = tf.reshape(tf.range(batch_size), [-1,1])
idx = tf.reshape(tf.tile(idx, [1, y]), [-1])
inc = tf.reshape(tf.tile(tf.range(Y), [Batch]), [-1])
final_idx = tf.reshape(tf.stack([idx, inc, indices], 1), [Batch, Y, -1])
transposed = tf.transpose(Params, [0, 1, 3, 2])
slice = tf.gather_nd(transposed, final_idx)
with tf.Session() as sess:
print sess.run(Params)
print sess.run(idx)
print sess.run(inc)
print sess.run(indices)
print sess.run(final_idx)
print sess.run(slice)
[[[[ 22 38 68 49 119]
[ 47 74 111 117 90]
[ 14 32 31 12 75]
[ 93 34 57 3 56]]
[[ 69 21 4 94 39]
[ 83 96 62 102 80]
[ 55 113 48 98 29]
[107 81 67 76 28]]
[[ 53 51 77 66 63]
[ 92 115 118 116 13]
[ 43 78 15 1 0]
[ 99 50 27 60 73]]]
[[[ 97 88 91 64 86]
[ 72 110 26 87 33]
[ 70 30 41 114 5]
[ 95 82 46 16 61]]
[[109 71 45 8 40]
[101 9 23 59 10]
[ 37 65 44 11 19]
[ 42 104 106 105 18]]
[[112 58 7 17 89]
[ 25 79 103 85 20]
[ 35 6 108 100 36]
[ 24 52 2 54 84]]]]
[0 0 0 1 1 1]
[0 1 2 0 1 2]
[3 3 3 0 0 0]
[[[0 0 3]
[0 1 3]
[0 2 3]]
[[1 0 0]
[1 1 0]
[1 2 0]]]
[[[ 49 117 12 3]
[ 94 102 98 76]
[ 66 116 1 60]]
[[ 97 72 70 95]
[109 101 37 42]
[112 25 35 24]]]
Probablement Je dois préciser: Les valeurs des indices-Tensor spécifient qui tranche à utiliser. Le résultat doit être obtenu pendant l'exécution de la session, car d'autres calculs en dépendent. J'ai édité mon post original, j'espère que c'est plus clair maintenant. – Benjamin
mis à jour. Je ne sais pas si c'est ce que vous voulez. Je vais ajouter quelques explications plus tard. – greeness
Merveilleux. Maintenant, je comprends aussi pourquoi d'autres exemples n'ont pas fonctionné. Hm, la transposition est connue pour être une opération lente, je suppose que l'on pourrait contourner cela en générant aussi des indices pour la troisième dimension, et en concaténant ceux-ci aussi bien? Peut-être aussi cher. D'autre part - pouvons-nous nous débarrasser d'une des opérations de pavage si nous la transposons immédiatement à [Lot, N, Y, X]? Alors seulement énumérer Batch, et empilez-le sur l'index? – Benjamin