2017-10-12 1 views
0

Je suis le "module de couche" de Tensorflow à partir de ce lien de tutoriel https://www.tensorflow.org/tutorials/layers. Vous pourriez être en mesure de m'aider comment puis-je obtenir les résultats des prédictions et leurs probabilités respectives. J'ai besoin de le voir pour mieux comprendre le modèle. Et s'il y a un moyen de sauvegarder les résultats - prédictions et probabilités à csv.Tensorflow: extraire les prédictions de "Layer Module"

Merci beaucoup pour votre temps.

Répondre

0

J'ai trouvé un moyen de le faire et c'est très simple comme je le pensais. Je pensais que certains pourraient avoir une question similaire alors voilà. Tensorflow est un nouveau cadre pour travailler sur des modèles d'apprentissage automatique, mais je réalise finalement que c'est assez facile.

tensorflow a cette fonction * .predict (...) sur le modèle que vous avez créé avec tf et renvoie les prédictions variables que vous définissez votre modèle qui contient des « classes » et « probabilités »

Sur votre modèle pour un exemple:

some_classifier = tf.estimator.Estimator(model_fn=model_fn, 
    model_dir=...) 

vous pouvez le faire pour la prédiction (comme nous le savons)

prediction_input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": test_feats}, #given you have this variable containing the test data 
    y=test_labels, #and the equivalent label for the test data 
    num_epochs=1, 
    shuffle=False) 
prediction_results = some_classifier.predict(input_fn=prediction_input_fn) 

puis la variable prediction_r ésultats contient les valeurs de classe prédite et ses probabilités que vous pouvez ensuite enregistrer (par exemple en utilisant Panda)

save = panda.DataFrame(list(prediction_results)) 
save.to_csv("file.csv") 

Le code ci-dessus snip fonctionne bien donné que vous avez déjà écrit un code dans votre modèle un conteneur pour les prévisions, comme ci-dessous :

predictions = { 
    "classes": tf.argmax(input=logits, axis=1), 
    "probabilities": tf.nn.softmax(logits, name="softmax_tensor") 

}