0

Je suis en train de trouver les meilleurs paramètres pour le modèle de régression en utilisant NN GridSearchCV avec le code suivant:Comment obtenir des prédictions pour chaque ensemble de paramètres en utilisant GridSearchCV?

param_grid = dict(optimizer=optimizer, epochs=epochs, batch_size=batches, init=init 
grid = GridSearchCV(estimator=model, param_grid=param_grid, scoring='neg_mean_squared_error') 
grid_result = grid.fit(input_train, target_train) 

pred = grid.predict(input_test) 

Si je comprends bien, grid.predict(input_test) utilise les meilleurs paramètres pour prédire l'entrée ensemble donné. Existe-t-il un moyen d'évaluer GridSearchCV pour chaque ensemble de paramètres en utilisant l'ensemble de test?

En fait, mon ensemble de test comprend des enregistrements spéciaux et je veux tester la généralité du modèle ainsi que la précision. Je vous remercie.

Répondre

0

Vous pouvez remplacer le paramètre cv standard à 3 plis de GridSearchCV par un itérateur personnalisé, qui fournit les indices de train et de test des trains de données et des trains de données concaténés. En conséquence, alors que 1 fois la validation croisée you'l former votre modèle sur input_train objets et tester votre modèle ajusté sur input_test objets:

def modified_cv(input_train_len, input_test_len): 
    yield (np.array(range(input_train_len)), 
      np.array(range(input_train_len, input_train_len + input_test_len))) 

input_train_len = len(input_train) 
input_test_len = len(input_test) 
data = np.concatenate((input_train, input_test), axis=0) 
target = np.concatenate((target_train, target_test), axis=0) 
grid = GridSearchCV(estimator=model, 
        param_grid=param_grid, 
        cv=modified_cv(input_train_len, input_test_len), 
        scoring='neg_mean_squared_error') 
grid_result = grid.fit(data, target) 

En accédant grid_result.cv_results_ dictionnaire, you'l voir votre valeur sur les paramètres set test toute la grille des paramètres du modèle spécifiés.

+0

Merci @ eduard-ilyasov. Cela fonctionne parfaitement. – saleh