2017-10-09 6 views
2

Cette question donne la solution pour trier l'axe y: Data order in seaborn heatmap from pivot Mais comment effectuer un tri personnalisé pour les axes x et y?Comment effectuer un tri personnalisé pour les axes x et y sur DataFrame indexé pour heatmap?

Sans tri personnalisé, nous voyons l'ordre:

  • axe x: téléphone, télévision
  • axe y: Apple, Google, Samsung

code :

lol = [['apple', 'phone', 10], ['samsung', 'tv', 20], ['apple', 'tv', 5], ['google', 'tv', 8], ['google', 'phone', 9], ['samsung', 'phone', 3]] 
df = pd.DataFrame(lol) 
df = df.rename(columns={0:'brand', 1:'product', 2:'count'}) 
df = df.pivot('brand', 'product', 'count') 
ax = sns.heatmap(df) 
plt.show() 

[out]:

enter image description here

Si je dois trier l'axe y pour montrer l'ordre samsung, apple, google, je pouvais faire:

lol = [['apple', 'phone', 10], ['samsung', 'tv', 20], ['apple', 'tv', 5], ['google', 'tv', 8], ['google', 'phone', 9], ['samsung', 'phone', 3]] 
df = pd.DataFrame(lol) 
df = df.rename(columns={0:'brand', 1:'product', 2:'count'}) 
df = df.pivot('brand', 'product', 'count') 

df.index = pd.CategoricalIndex(df.index, categories= ["samsung", "apple", "google"]) 
df.sortlevel(level=0, inplace=True) 
ax = sns.heatmap(df) 
plt.show() 

[out]:

enter image description here

Mais comment effectuer un tri personnalisé pour les axes x et y?, par ex.

  • axe y pour montrer l'ordre samsung, apple, google
  • axe x pour montrer l'ordre tv, phone (non seulement inverser l'ordre)

Répondre

2

Je pense que vous pouvez utiliser reindex:

a = ['samsung', 'apple', 'google'] 
b = ['tv','phone'] 

df = df.pivot('brand', 'product', 'count') 
df = df.reindex(index=a, columns=b) 
print (df) 
product tv phone 
brand    
samsung 20  3 
apple  5  10 
google 8  9 

ou ordered categorical:

df['brand'] = df['brand'].astype('category', categories=a, ordered=True) 
df['product'] = df['product'].astype('category', categories=b, ordered=True) 

df = df.pivot('brand', 'product', 'count') 
print (df) 
product tv phone 
brand    
samsung 20  3 
apple  5  10 
google 8  9 

ax = sns.heatmap(df) 
plt.show()