2016-04-05 2 views
1

J'effectue une recherche de grille pour identifier les meilleurs paramètres SVM. J'utilise ipython et sklearn. Le code est lent et ne fonctionne que sur un seul cœur. Comment cela peut-il être semé et utiliser plusieurs cœurs? MerciAccélérer la recherche de grille dans sklearn

random_state = np.random.RandomState(10) 
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=.2,random_state=random_state) 

model_to_set = OneVsRestClassifier(svm.SVC(kernel="linear")) 

parameters = { 
    "estimator__C": [1, 2, 4, 8, 16, 32], 
    "estimator__kernel": ["linear", "rbf"], 
    "estimator__gamma":[1, 0.1, 1e-2, 1e-3, 1e-4], 
} 

model_tuning = GridSearchCV(model_to_set, param_grid=parameters) 
model_tuning.fit(X_train, y_train) 

print model_tuning.best_score_ 
print model_tuning.best_params_ 
print "Time passed: ", "{0:.1f}".format(time.time()-t), "sec" 

Répondre

4

Il y a un paramètre n_job dans GridSearchCV

n_jobs: int, par défaut = 1

Nombre d'emplois pour fonctionner en parallèle. Modifié dans version 0.17: mis à niveau vers joblib 0.9.3.

+0

parfait, mon pote merci. –

2

Par défaut, GridSearchCV utilise 1 tâche pour rechercher des valeurs de paramètre spécifiées pour un estimateur.

Donc, vous devez définir explicitement le nombre d'emplois parallèles que vous désirez par chaning la ligne suivante:

model_tuning = GridSearchCV(model_to_set, param_grid=parameters) 

dans ce qui suit pour permettre l'emploi en cours d'exécution en parallèle:

model_tuning = GridSearchCV(model_to_set, param_grid=parameters, n_jobs=4) 
0

Essayez d'utiliser sparkcontext pour paralléliser la recherche de grille sur plusieurs machines. En utilisant la bibliothèque étincelle sklearn, vous serez en mesure de faire fonctionner le programme plus rapide avec seulement le changement d'une seule ligne

from spark_sklearn import GridSearchCV 

Assurez-vous également que vous initialisez contexte d'allumage (« sc ») correctement. Ensuite, tout ce que vous devez faire est

gs = GridSearchCV(sc, classifier, param_grid) 
gs.fit(X, y) 

DataSet standards chiffres dans scikit-learn, grille de recherche a pris 318.92 secondes pour 864 paramètres candidats. Le même n'a pris que quelques secondes en utilisant 122,03 4 machines

Vous pouvez trouver plus d'informations ci-dessus here