0

J'ai une question concernant l'ensemble de données sur le diabète sur sklearn. Je suis en train de tracer la courbe d'apprentissage pour un type d'estimateur, mais de toute façon je l'avertissement:J'ai obtenu une très faible note au test de Sklearn sur le tracé de la courbe d'apprentissage avec l'arbre de décision

« D: \ Users \ XXXX \ Anaconda2 \ lib \ site-packages \ sklearn \ cross_validation.p ing : La classe la moins peuplée de y a seulement 1 membres, ce qui est trop nombre d'étiquettes pour une classe ne peut pas être inférieur à n_folds = 3. "

et le code trace un résultat bizarre. Le jeu de données d'entraînement a un score très élevé (toujours 1, ce qui est peut-être logique car c'est un arbre), mais le score du test est très mauvais (0.03125 à son meilleur)

Je l'ai essayé dans différents jeux de données (chiffres) et ça a bien fonctionné. Le code que j'ai est le suivant:

import numpy as np 
import matplotlib.pyplot as plt 
from sklearn.datasets import load_diabetes 
from sklearn.learning_curve import learning_curve 
from sklearn import tree 


diabetes = load_diabetes() 
X, y = diabetes.data, diabetes.target 


estimator = tree.DecisionTreeClassifier() 
estimator.fit(X, y) 

title = "Learning Curves Decision Tree" 
plt.figure(1) 
plt.title(title) 
plt.xlabel("Training examples") 
plt.ylabel("Score") 
train_sizes, train_scores, test_scores = learning_curve(estimator, X, y) 

print train_sizes 
print train_scores 
print test_scores 

plt.grid() 
plt.plot(train_sizes, train_scores, 'o-', color="r",label="Training score") 
plt.plot(train_sizes, test_scores, 'o-', color="g",label="Cross-validation score") 

plt.legend(loc="best") 

plt.show() 

Quelqu'un peut-il s'il vous plaît me donner une explication pourquoi cela se produit? Merci

Répondre

1

L'ensemble de données diabetes représente un regression problem rather than a classification problem et ne peut donc pas être testé sur DecisionTreeClassifier. Selon the docs:

L'ensemble de données de diabète se compose de 10 variables physiologiques (âge, sexe, poids, pression artérielle) mesure sur 442 patients, et une indication de la progression de la maladie après un an

Le target Le vecteur doit être traité comme une variable dépendante continue (ou au moins ordinale) dont vous avez besoin pour prédire la valeur de, plutôt qu'un ensemble de catégories.

Si vous traitez target comme un ensemble d'étiquettes de classe, vous vous retrouvez avec un total de 214 classes, dont la plupart n'ont qu'un seul membre (d'où le message d'avertissement). Dans cette situation, votre arbre de décision ajusté se comporte essentiellement comme une "table de correspondance" qui peut parfaitement mapper chaque valeur X de votre jeu d'entraînement à la valeur y correspondante, mais n'a aucune valeur prédictive pour les exemples non vus dans votre jeu de validation croisée. Ceci est un exemple particulièrement extrême de overfitting.