J'essaie de créer une fonction eval_metric_op qui affichera la proportion de correspondances exactes à un seuil donné pour un problème de classification multi-étiquettes. La fonction suivante renvoie 0 (aucune correspondance exacte) ou 1 (correspondance exacte) en fonction du seuil donné.Comment créer une "correspondance exacte" eval_metric_op pour TensorFlow?
def exact_match(y_true, y_logits, threshold):
y_pred = np.round(y_logits-threshold+0.5)
return int(np.array_equal(y_true, y_pred))
y_true = np.array([1,1,0])
y_logits = np.array([0.67, 0.9, 0.55])
print(exact_match(y_true, y_logits, 0.5))
print(exact_match(y_true, y_logits, 0.6))
Un seuil de 0,5 fournit une prévision de [1,1,1] ce qui est incorrect si la fonction retourne 0. Un seuil de 0,6 fournit une prévision de [1,1,0] ce qui est correct si la fonction retourne 1.
Je voudrais transformer cette fonction en une métrique d'évaluation tensorflow op - quelqu'un peut-il conseiller la meilleure façon de le faire?
Je peux obtenir la même logique en utilisant ops tensorflow ci-dessous, mais je ne suis pas tout à fait sûr de savoir comment faire cela dans un eval_metric_op personnalisé:
import tensorflow as tf
def exact_match_fn(y_true, y_logits, threshold):
#pred = tf.equal(tf.round(y_logits), tf.round(y_true))
predictions = tf.to_float(tf.greater_equal(y_logits, threshold))
pred_match = tf.equal(predictions, tf.round(y_true))
exact_match = tf.reduce_min(tf.to_float(pred_match))
return exact_match
graph = tf.Graph()
with graph.as_default():
y_true = tf.constant([1,1,0], dtype=tf.float32)
y_logits = tf.constant([0.67,0.9,0.55], dtype=tf.float32)
exact_match_50 = exact_match_fn(y_true, y_logits, 0.5)
exact_match_60 = exact_match_fn(y_true, y_logits, 0.6)
sess = tf.InteractiveSession(graph=graph)
print(sess.run([exact_match_50, exact_match_60]))
Le code ci-dessus entraînera exact_match_50 de 0 (à moins 1 prédiction incorrecte) et exact_match_60 de 1 (toutes les étiquettes correctes). Est-il suffisant d'utiliser simplement tf.contrib.metrics.streaming_mean()
ou existe-t-il une meilleure alternative? J'implémenteriez cela comme:
tf.contrib.metrics.streaming_mean(exact_match(y_true, y_logits, threshold))
Oui, c'est certainement correct d'ajouter l'axe = 1 - merci pour cette correction! Je pense que ça fait partie de ça. Je pense que d'après mes expériences, streaming_mean() gère la collecte des nombres totaux, ce qui serait la seule chose qui manque ici. – reese0106