2017-04-09 3 views
6

Je cherche quelque chose de similaire à numpy.random.choice(range(3),replacement=False,size=2,p=[0.1,0.2,0.7])
dans TensorFlow.Échantillonnage sans remplacement à partir d'une distribution non uniforme donnée dans TensorFlow

Le Op le plus proche semble être tf.multinomial(tf.log(p)) qui prend les logits en entrée mais ne peut pas échantillonner sans remplacement. Existe-t-il un autre moyen d'échantillonner à partir d'une distribution non uniforme dans TensorFlow?

Merci.

Répondre

2

Vous pouvez simplement utiliser tf.py_func pour envelopper numpy.random.choice et de le rendre disponible en op tensorflow:

a = tf.placeholder(tf.float32) 
size = tf.placeholder(tf.int32) 
replace = tf.placeholder(tf.bool) 
p = tf.placeholder(tf.float32) 

y = tf.py_func(np.random.choice, [a, size, replace, p], tf.float32) 

with tf.Session() as sess: 
    print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]})) 

Vous pouvez spécifier la graine numpy comme d'habitude:

np.random.seed(1) 
print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]})) 
print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]})) 
print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]})) 
np.random.seed(1) 
print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]})) 
print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]})) 
print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]})) 
np.random.seed(1) 
print(sess.run(y, {a: range(3), size: 2, replace:False, p:[0.1,0.2,0.7]})) 

imprimerait:

[ 2. 0.] 
[ 2. 1.] 
[ 0. 1.] 
[ 2. 0.] 
[ 2. 1.] 
[ 0. 1.] 
[ 2. 0.]