2017-05-14 1 views
2

pyplot.scatter permet de passer à c= un tableau qui correspond à des groupes, qui colorera alors les points basés sur ces groupes. Cependant, cela ne semble pas soutenir la génération d'une légende sans spécifiquement traçage de chaque groupe séparément.nuage de points avec légende colorée par groupe sans appels multiples à plt.scatter

Ainsi, par exemple, un diagramme de dispersion avec des groupes de couleur peuvent être générés par itérer sur les groupes et tracer chacun séparément:

import matplotlib.pyplot as plt 
from sklearn.datasets import load_iris 
feats = load_iris()['data'] 
target = load_iris()['target'] 

f, ax = plt.subplots(1) 
for i in np.unique(target): 
    mask = target == i 
    plt.scatter(feats[mask, 0], feats[mask, 1], label=i) 
ax.legend() 

qui génère:

enter image description here

je peux obtenir une parcelle semblable sans itération sur chaque groupe:

f, ax = plt.subplots(1) 
ax.scatter(feats[:, 0], feats[:, 1], c=np.array(['C0', 'C1', 'C2'])[target]) 

Mais je n'arrive pas à trouver un moyen de générer une légende correspondante avec cette seconde stratégie. Tous les exemples que j'ai rencontrés se rapportent aux groupes, ce qui semble ... moins qu'idéal. Je sais que je peux générer manuellement une légende, mais encore une fois cela semble trop lourd.

Répondre

0

L'exemple de dispersion matplotlib qui résout ce problème utilise une boucle, de sorte que est probablement l'usage prévu aussi: https://matplotlib.org/examples/lines_bars_and_markers/scatter_with_legend.html

Si votre objectif plus large est de faire tout le traçage et le marquage des données catégoriques plus simple, vous devriez envisager Seaborn . Ceci est une question similaire à Scatter plots in Pandas/Pyplot: How to plot by category

Une façon d'atteindre votre objectif est d'utiliser des pandas avec des colonnes étiquetées. Une fois que vous avez des données dans une base de données Pandas, vous pouvez utiliser Seaborn pairplot pour créer ce type de tracé. (Seaborn a également l'ensemble de données de l'iris disponible sous forme de trame de données étiquetée)

import seaborn as sns 
iris = sns.load_dataset("iris") 
sns.pairplot(iris, hue="species") 

enter image description here

Si vous voulez juste les deux premières fonctionnalités, vous pouvez utiliser

sns.pairplot(x_vars=['sepal_length'], y_vars=['sepal_width'], data=iris, hue="species", size=5) 

enter image description here

Si vous voulez vraiment utiliser le dict de données sklearn, vous pouvez tirer cela dans un dataframe comme suit:

import pandas as pd 
from sklearn.datasets import load_iris 
import numpy as np 

feats = load_iris()['data'].astype('O') 
target = load_iris()['target'] 
feat_names = load_iris()['feature_names'] 
target_names = load_iris()['target_names'].astype('O') 

sk_df = pd.DataFrame(
    np.hstack([feats,target_names[target][:,np.newaxis]]), 
    columns=feat_names+['target',]) 
sns.pairplot(sk_df, vars=feat_names, hue="target") 
+0

Je suis conscient que vous pouvez le faire dans seaborn simplement, mais mon cas d'utilisation réelle (où je suis en train de tracer des parcelles de dispersion 3D) seaborn ne supporte pas. sous le capot seaborn utilise matplotlib pour faire le tracé - je suppose que je pourrais passer à travers et voir comment seaborn génère les nuages ​​de points et les légendes de figures associées dans le coupleplot (ou regplot). Je suppose que c'est en boucle sur les groupes comme dans mon premier exemple de code. – user3014097