2017-08-30 3 views
1

J'essaie de définir une classe appelée XGBExtended qui étend la classe xgboost.XGBClassifier, l'API scikit-learn pour xgboost. Je rencontre des problèmes avec la méthode get_params. Voici une session IPython illustrant le problème. Fondamentalement, get_params semble renvoyer uniquement les attributs que je définis au sein de XGBExtended.__init__, et les attributs définis pendant la méthode parent init (xgboost.XGBClassifier.__init__) sont ignorés. J'utilise IPython et exécute python 2.7. Spécification complète du système en bas.Extension de xgboost.XGBClassifier

In [182]: import xgboost as xgb 
    ...: 
    ...: class XGBExtended(xgb.XGBClassifier): 
    ...: def __init__(self, foo): 
    ...:  super(XGBExtended, self).__init__() 
    ...:  self.foo = foo 
    ...: 
    ...: clf = XGBExtended(foo = 1) 
    ...: 
    ...: clf.get_params() 
    ...: 
--------------------------------------------------------------------------- 
KeyError         Traceback (most recent call last) 
<ipython-input-182-431c4c3f334b> in <module>() 
     8 clf = XGBExtended(foo = 1) 
     9 
---> 10 clf.get_params() 

/Users/andrewhannigan/lib/xgboost/python-package/xgboost/sklearn.pyc in get_params(self, deep) 
    188   if isinstance(self.kwargs, dict): # if kwargs is a dict, update params accordingly 
    189    params.update(self.kwargs) 
--> 190   if params['missing'] is np.nan: 
    191    params['missing'] = None # sklearn doesn't handle nan. see #4725 
    192   if not params.get('eval_metric', True): 

KeyError: 'missing' 

J'ai frappé une erreur parce que « manquant » n'est pas une clé dans la params dict dans la méthode XGBClassifier.get_params. J'entre le débogueur pour fouiner:

In [183]: %debug 
> /Users/andrewhannigan/lib/xgboost/python-package/xgboost/sklearn.py(190)get_params() 
    188   if isinstance(self.kwargs, dict): # if kwargs is a dict, update params accordingly 
    189    params.update(self.kwargs) 
--> 190   if params['missing'] is np.nan: 
    191    params['missing'] = None # sklearn doesn't handle nan. see #4725 
    192   if not params.get('eval_metric', True): 

ipdb> params 
{'foo': 1} 
ipdb> self.__dict__ 
{'n_jobs': 1, 'seed': None, 'silent': True, 'missing': nan, 'nthread': None, 'min_child_weight': 1, 'random_state': 0, 'kwargs': {}, 'objective': 'binary:logistic', 'foo': 1, 'max_depth': 3, 'reg_alpha': 0, 'colsample_bylevel': 1, 'scale_pos_weight': 1, '_Booster': None, 'learning_rate': 0.1, 'max_delta_step': 0, 'base_score': 0.5, 'n_estimators': 100, 'booster': 'gbtree', 'colsample_bytree': 1, 'subsample': 1, 'reg_lambda': 1, 'gamma': 0} 
ipdb> 

Comme vous pouvez le voir, le params ne contient que la variable foo. Cependant, l'objet lui-même contient tous les paramètres définis par xgboost.XGBClassifier.__init__. Mais pour une raison quelconque, la méthode BaseEstimator.get_params appelée à partir de xgboost.XGBClassifier.get_params obtient uniquement les paramètres définis explicitement dans la méthode XGBExtended.__init__. Malheureusement, même si j'appeler explicitement get_params avec deep = True, cela ne fonctionne toujours pas correctement:

ipdb> super(XGBModel, self).get_params(deep=True) 
{'foo': 1} 
ipdb> 

Quelqu'un peut-il dire pourquoi cela se passe?

spécifications système:

In [186]: print IPython.sys_info() 
{'commit_hash': u'1149d1700', 
'commit_source': 'installation', 
'default_encoding': 'UTF-8', 
'ipython_path': '/Users/andrewhannigan/virtualenvironment/nimble_ai/lib/python2.7/site-packages/IPython', 
'ipython_version': '5.4.1', 
'os_name': 'posix', 
'platform': 'Darwin-14.5.0-x86_64-i386-64bit', 
'sys_executable': '/usr/local/Cellar/python/2.7.10/Frameworks/Python.framework/Versions/2.7/Resources/Python.app/Contents/MacOS/Python', 
'sys_platform': 'darwin', 
'sys_version': '2.7.10 (default, Jul 3 2015, 12:05:53) \n[GCC 4.2.1 Compatible Apple LLVM 6.1.0 (clang-602.0.53)]'} 

Répondre

1

Le problème ici est la déclaration incorrecte de la classe des enfants. Lorsque vous déclarez la méthode init en utilisant uniquement foo, vous remplacez la méthode d'origine. Il ne sera pas initialisé automatiquement, même si le constructeur de la classe de base est censé avoir des valeurs par défaut pour lui.

Vous devez utiliser les éléments suivants:

class XGBExtended(xgb.XGBClassifier): 
    def __init__(self, foo, max_depth=3, learning_rate=0.1, 
       n_estimators=100, silent=True, 
       objective="binary:logistic", 
       nthread=-1, gamma=0, min_child_weight=1, 
       max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1, 
       reg_alpha=0, reg_lambda=1, scale_pos_weight=1, 
       base_score=0.5, seed=0, missing=None, **kwargs): 

     # Pass the required parameters to super class 
     super(XGBExtended, self).__init__(max_depth, learning_rate, 
              n_estimators, silent, objective, 
              nthread, gamma, min_child_weight, 
              max_delta_step, subsample, 
              colsample_bytree, colsample_bylevel, 
              reg_alpha, reg_lambda, 
scale_pos_weight, base_score, seed, missing, **kwargs) 

     # Use other custom parameters 
     self.foo = foo 

Après cela, vous ne recevrez aucune erreur.

clf = XGBExtended(foo = 1) 
print(clf.get_params(deep=True)) 

>>> {'reg_alpha': 0, 'colsample_bytree': 1, 'silent': True, 
    'colsample_bylevel': 1, 'scale_pos_weight': 1, 'learning_rate': 0.1, 
    'missing': None, 'max_delta_step': 0, 'nthread': -1, 'base_score': 0.5, 
    'n_estimators': 100, 'subsample': 1, 'reg_lambda': 1, 'seed': 0, 
    'min_child_weight': 1, 'objective': 'binary:logistic', 
    'foo': 1, 'max_depth': 3, 'gamma': 0} 
+0

@andrew J'ai changé ma réponse et je ne reçois aucune erreur maintenant. –

+0

merci! Il y a donc une certaine magie de scikit-learn avec 'get_params'. Est-ce que 'get_params' regarde la signature de' __init __() 'pour décider quels champs de l'objet sont des paramètres de modélisation réels? Un peu ennuyeux que vous ayez à reformater tous les arguments init de la classe que vous étendez. – andrew

+0

@andrew Non, je ne pense pas que cela soit dû à scikit. IMO est une chose python lorsque vous surchargez la méthode init(), il n'y a aucun moyen de savoir quand la superclasse init est appelée ou non. –