2017-10-17 2 views
0

J'essaie de créer une fonction eval_metric_op qui affichera la proportion de correspondances exactes à un seuil donné pour un problème de classification multi-étiquettes. La fonction suivante renvoie 0 (aucune correspondance exacte) ou 1 (correspondance exacte) en fonction du seuil donné.Comment créer une "correspondance exacte" eval_metric_op pour TensorFlow?

def exact_match(y_true, y_logits, threshold): 
    y_pred = np.round(y_logits-threshold+0.5) 
    return int(np.array_equal(y_true, y_pred)) 

y_true = np.array([1,1,0]) 
y_logits = np.array([0.67, 0.9, 0.55]) 

print(exact_match(y_true, y_logits, 0.5)) 
print(exact_match(y_true, y_logits, 0.6)) 

Un seuil de 0,5 fournit une prévision de [1,1,1] ce qui est incorrect si la fonction retourne 0. Un seuil de 0,6 fournit une prévision de [1,1,0] ce qui est correct si la fonction retourne 1.

Je voudrais transformer cette fonction en une métrique d'évaluation tensorflow op - quelqu'un peut-il conseiller la meilleure façon de le faire?

Je peux obtenir la même logique en utilisant ops tensorflow ci-dessous, mais je ne suis pas tout à fait sûr de savoir comment faire cela dans un eval_metric_op personnalisé:

import tensorflow as tf 

def exact_match_fn(y_true, y_logits, threshold): 
    #pred = tf.equal(tf.round(y_logits), tf.round(y_true)) 
    predictions = tf.to_float(tf.greater_equal(y_logits, threshold)) 
    pred_match = tf.equal(predictions, tf.round(y_true)) 
    exact_match = tf.reduce_min(tf.to_float(pred_match)) 
    return exact_match 

graph = tf.Graph() 
with graph.as_default(): 
    y_true = tf.constant([1,1,0], dtype=tf.float32) 
    y_logits = tf.constant([0.67,0.9,0.55], dtype=tf.float32) 
    exact_match_50 = exact_match_fn(y_true, y_logits, 0.5) 
    exact_match_60 = exact_match_fn(y_true, y_logits, 0.6) 

sess = tf.InteractiveSession(graph=graph) 
print(sess.run([exact_match_50, exact_match_60])) 

Le code ci-dessus entraînera exact_match_50 de 0 (à moins 1 prédiction incorrecte) et exact_match_60 de 1 (toutes les étiquettes correctes). Est-il suffisant d'utiliser simplement tf.contrib.metrics.streaming_mean() ou existe-t-il une meilleure alternative? J'implémenteriez cela comme:

tf.contrib.metrics.streaming_mean(exact_match(y_true, y_logits, threshold)) 

Répondre

1

La sortie de votre exact_match_fn est un op qui peut être utilisé pour l'évaluation. Si vous voulez la moyenne sur un lot, changez votre reduce_min pour réduire juste au-dessus de l'axe approprié.

E.g. si vous y_true/y_logits chacun ont une forme (batch_size, n)

def exact_match_fn(y_true, y_logits, threshold): 
    #pred = tf.equal(tf.round(y_logits), tf.round(y_true)) 
    predictions = tf.to_float(tf.greater_equal(y_logits, threshold)) 
    pred_match = tf.equal(predictions, tf.round(y_true)) 
    exact_match = tf.reduce_min(tf.to_float(pred_match), axis=1) 
    return exact_match 


def exact_match_prop_fn(*args): 
    return tf.reduce_mean(exact_match_fn(*args)) 

Cela vous donnera la moyenne sur un lot. Si vous voulez la moyenne sur un ensemble de données complet, je voudrais juste collecter les correspondances (ou correct et total compte) et évaluer en dehors de la session/tensorflow, mais streaming_mean fait probablement juste cela, pas sûr.

+0

Oui, c'est certainement correct d'ajouter l'axe = 1 - merci pour cette correction! Je pense que ça fait partie de ça. Je pense que d'après mes expériences, streaming_mean() gère la collecte des nombres totaux, ce qui serait la seule chose qui manque ici. – reese0106