J'essaie d'optimiser un pipeline et je voulais essayer de donner RandomizedSearchCV
un objet np.random.RandomState
. Je ne peux pas le faire mais je peux lui donner d'autres distributions.`RandomizedSearchCV` de sklearn ne fonctionne pas avec` np.random.RandomState`
Y at-il une syntaxe particulière que je peux utiliser pour donner un RandomSearchCV
np.random.RandomState(0).uniform(0.1,1.0)
?
from scipy import stats
import numpy as np
from sklearn.neighbors import KernelDensity
from sklearn.grid_search import RandomizedSearchCV
# Generate data
x = np.random.normal(5,1,size=int(1e3))
# Make model
model = KernelDensity()
# Gridsearch for best params
# This one works
search_params = RandomizedSearchCV(model, param_distributions={"bandwidth":stats.uniform(0.1, 1)}, n_iter=30, n_jobs=2)
search_params.fit(x[:, None])
# RandomizedSearchCV(cv=None, error_score='raise',
# estimator=KernelDensity(algorithm='auto', atol=0, bandwidth=1.0, breadth_first=True,
# kernel='gaussian', leaf_size=40, metric='euclidean',
# metric_params=None, rtol=0),
# fit_params={}, iid=True, n_iter=30, n_jobs=2,
# param_distributions={'bandwidth': <scipy.stats._distn_infrastructure.rv_frozen object at 0x106ab7da0>},
# pre_dispatch='2*n_jobs', random_state=None, refit=True,
# scoring=None, verbose=0)
# This one doesn't work :(
search_params = RandomizedSearchCV(model, param_distributions={"bandwidth":np.random.RandomState(0).uniform(0.1, 1)}, n_iter=30, n_jobs=2)
# TypeError: object of type 'float' has no len()
Merci! Y at-il une erreur dans 'RandomizedSearchCV (model, param_distributions = {" bande passante ": stats.uniform (0.1, 1)}, n_iter = 30, n_jobs = 2)'? Je me basais sur https://jakevdp.github.io/blog/2013/12/01/kernel-density-estimation/ –
@ O.rka La classe-uniforme n'implémente pas les arguments-constructeurs [comme je le vois ] (https://github.com/scipy/scipy/blob/v0.18.1/scipy/stats/_continuous_distns.py#L4883). La classe supérieure héritée n'utilise que des arguments nommés comme a et b pour la plage. Donc, je crains, que ce que vous faites des échantillons de la gamme par défaut de (0,1), mais je ne suis pas sûr à 100% à ce sujet. Mais cela devrait être facile à vérifier. – sascha
Est-ce qu'il fait quelque chose comme ça? 'stats.uniform (5,1) .rvs (3) .tolist() # [5.172340508345329, 5.137135749628878, 5.932595463037163]' ou est-ce différent dans le backend? –