2017-08-25 2 views
1

J'ai besoin de créer une variable epsilon_n qui change la définition (et la valeur) en fonction du step actuel. Depuis que j'ai plus de deux cas, il semble que je ne peux pas utiliser tf.cond. Je suis en train d'utiliser tf.case comme suit:Tesnorflow: Impossible d'utiliser tf.case avec l'argument d'entrée

import tensorflow as tf 

#### 
EPSILON_DELTA_PHASE1 = 33e-4 
EPSILON_DELTA_PHASE2 = 2.5 
#### 
step = tf.placeholder(dtype=tf.float32, shape=None) 


def fn1(step): 
    return tf.constant([1.]) 

def fn2(step): 
    return tf.constant([1.+step*EPSILON_DELTA_PHASE1]) 

def fn3(step): 
    return tf.constant([1.+step*EPSILON_DELTA_PHASE2]) 

epsilon_n = tf.case(
     pred_fn_pairs=[ 
      (tf.less(step, 3e4), lambda step: fn1(step)), 
      (tf.less(step, 6e4), lambda step: fn2(step)), 
      (tf.less(step, 1e5), lambda step: fn3(step))], 
      default=lambda: tf.constant([1e5]), 
     exclusive=False) 

Cependant, je continue à recevoir ce message d'erreur:

TypeError: <lambda>() missing 1 required positional argument: 'step' 

J'ai essayé les éléments suivants:

epsilon_n = tf.case(
     pred_fn_pairs=[ 
      (tf.less(step, 3e4), fn1), 
      (tf.less(step, 6e4), fn2), 
      (tf.less(step, 1e5), fn3)], 
      default=lambda: tf.constant([1e5]), 
     exclusive=False) 

Cependant, je voudrais la même erreur . Les exemples de la documentation de Tensorflow pèsent sur les cas où aucun argument d'entrée n'est transmis aux fonctions appelables. Je n'ai pas trouvé assez d'informations sur tf.case sur internet! S'il vous plaît toute aide?

Répondre

2

Voici quelques modifications à apporter. Par souci de cohérence, vous pouvez définir toutes les valeurs de retour comme variables.

# Since step is a scalar, scalar shape [() or [], not None] much be provided 
step = tf.placeholder(dtype=tf.float32, shape=()) 


def fn1(step): 
    return tf.constant([1.]) 

# Here you need to use Variable not constant, since you are modifying the value using placeholder 
def fn2(step): 
    return tf.Variable([1.+step*EPSILON_DELTA_PHASE1]) 

def fn3(step): 
    return tf.Variable([1.+step*EPSILON_DELTA_PHASE2]) 

epsilon_n = tf.case(
    pred_fn_pairs=[ 
     (tf.less(step, 3e4), lambda : fn1(step)), 
     (tf.less(step, 6e4), lambda : fn2(step)), 
     (tf.less(step, 1e5), lambda : fn3(step))], 
     default=lambda: tf.constant([1e5]), 
    exclusive=False) 
+0

correction de petites fautes de frappe –