2017-08-15 4 views
0

Je suis nouveau avec Spark et je veux l'utiliser pour le classificateur de forêt aléatoire. J'utilise des données Iris au format libsvm pour construire un modèle.spark classificateur de forêt aléatoire - obtenir des étiquettes en tant que chaîne

Ma question est - comment puis-je obtenir des étiquettes sous forme de chaîne? (Dans ce cas - les étiquettes sont des types de fleurs d'iris).

Lorsque les données sont converties au format libsvm, chaque étiquette reçoit un entier qui la représente, mais je ne sais pas comment revenir à l'étiquette de chaîne.

Est-ce possible avec libsvm? Ou devrais-je utiliser un autre format?

Voici mon code:

public PipelineModel runRandomForestAlgorithm(String dataPath) { 

System.setProperty("hadoop.home.dir", "C:/hadoop"); 
SparkSession spark = 
    SparkSession.builder().appName("JavaRandomForestClassifierExample").master("local[*]").getOrCreate(); 

/* Load and parse the data file, converting it to a DataFrame. */ 
DataFrameReader dataFrameReader = spark.read().format("libsvm"); 
Dataset<Row> data = dataFrameReader.load(dataPath); 

/* Index labels, adding metadata to the label column. 
    Fit on whole dataset to include all labels in index. */ 
StringIndexerModel labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(data); 

/* Automatically identify categorical features, and index them. 
    Set maxCategories so features with > 4 distinct values are treated as continuous. */ 
VectorIndexerModel featureIndexer = 
    new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(4).fit(data); 

/* Split the data into training and test sets (30% held out for testing) */ 
Dataset<Row>[] splits = data.randomSplit(new double[]{0.9, 0.1}); 
Dataset<Row> trainingData = splits[0]; 
testData = splits[1]; 

/* Train a RandomForest model. */ 
RandomForestClassifier rf = 
    new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setNumTrees(10); 

/* Convert indexed labels back to original labels. */ 
IndexToString labelConverter = 
    new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels()); 

/* Chain indexers and forest in a Pipeline */ 
Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{labelIndexer, featureIndexer, rf, labelConverter}); 

/* Train model. This also runs the indexers. */ 
PipelineModel model = pipeline.fit(trainingData); 

/* Make predictions. */ 
Dataset<Row> predictions = model.transform(testData); 

/* Select example rows to display. */ 
List<Row> predictionAsRows = 
    predictions.select("predictedLabel", "label", "features", "rawPrediction", "probability").collectAsList(); 

predictionAsRows.forEach(row -> { 
    System.out.println("predictedLabel: " + row.get(0) + " , " + "label: " + row.get(1) + " , " + "features: " + row.get(2) + " , " + 
     "predictions: " + row.get(3) + " , " + "probabilities: " + row.get(4)); 
}); 

Et voici la sortie:

predictedLabel: 1.0 , label: 1.0 , features: (4,[0,1,2,3], 
    [-0.833333,0.333333,-1.0,-0.916667]) , predictions: [10.0,0.0,0.0] , 
    probabilities: [1.0,0.0,0.0] 
    predictedLabel: 1.0 , label: 1.0 , features: (4,[0,1,2,3],         
    [-0.555556,0.166667,-0.830508,-0.916667]) , predictions: [10.0,0.0,0.0] 
    , probabilities: [1.0,0.0,0.0] 
    predictedLabel: 2.0 , label: 2.0 , features: (4,[0,1,2,3], 
    [-0.333333,-0.75,0.0169491,-4.03573E-8]) , predictions: [0.0,0.0,10.0] , 
    probabilities: [0.0,0.0,1.0] 
    predictedLabel: 2.0 , label: 2.0 , features: (4,[0,1,2,3], 
    [-0.166667,-0.416667,-0.0169491,-0.0833333]) , predictions: 
    [0.0,0.0,10.0] , probabilities: [0.0,0.0,1.0] 
    predictedLabel: 2.0 , label: 2.0 , features: (4,[0,1,2,3], 
    [0.166667,-0.25,0.118644,-4.03573E-8]) , predictions: [0.0,0.0,10.0] , 
    probabilities: [0.0,0.0,1.0] 
    predictedLabel: 2.0 , label: 2.0 , features: (4,[0,1,2,3], 
    [0.277778,-0.166667,0.152542,0.0833333]) , predictions: [0.0,0.0,10.0] , 
    probabilities: [0.0,0.0,1.0] 
    predictedLabel: 2.0 , label: 2.0 , features: (4,[0,2,3], 
    [0.5,0.254237,0.0833333]) , predictions: [0.0,0.0,10.0] , probabilities: 
    [0.0,0.0,1.0] 
    predictedLabel: 3.0 , label: 3.0 , features: (4,[0,1,2,3], 
    [-0.166667,-0.416667,0.38983,0.5]) , predictions: [0.0,9.875,0.125] ,   
    probabilities: [0.0,0.9875,0.0125] 
    predictedLabel: 3.0 , label: 3.0 , features: (4,[0,1,2,3], 
    [0.555555,-0.166667,0.661017,0.666667]) , predictions: [0.0,10.0,0.0] , 
    probabilities: [0.0,1.0,0.0] 
    predictedLabel: 3.0 , label: 3.0 , features: (4,[0,1,2,3], 
    [0.833333,-0.166667,0.898305,0.666667]) , predictions: [0.0,10.0,0.0] , 
    probabilities: [0.0,1.0,0.0] 
    predictedLabel: 3.0 , label: 3.0 , features: (4,[0,2,3], 
    [0.222222,0.38983,0.583333]) , predictions: [0.0,10.0,0.0] , 
    probabilities: [0.0,1.0,0.0] 
    predictedLabel: 3.0 , label: 3.0 , features: (4,[0,2,3], 
    [0.388889,0.661017,0.833333]) , predictions: [0.0,10.0,0.0] , probabilities: [0.0,1.0,0.0] 

Répondre

0

L'utilisation du format libsvm vous ne pouvez obtenir un nombre entier pour chaque classe, de sorte que vous ne pouvez pas obtenir une chaîne étiquette de classe à partir de là. Vous pouvez utiliser le convertisseur IndexToString() en utilisant la méthode setLabels(). Entrez simplement le tableau des étiquettes que vous avez. Pour que cela fonctionne, vous devriez probablement supprimer le StringIndexerModel() (inutile de toute façon car les classes sont des nombres, pas des chaînes). Exemple:

String[] labels = {"Setosa", "Versicolor", "Virginica"}; 
IndexToString labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("pred‌​ictedLabel").setLabe‌​ls(labels); 

En option, vous pouvez créer un Map séparé où vous associez les entiers aux étiquettes de chaîne. Pour l'ensemble de données Iris, il pourrait ressembler à ceci:

Map labels = new HashMap(); 
labels.put(1, "Setosa"); 
labels.put(2, "Versicolour"); 
labels.put(3, "Virginica"); 

Ensuite, vous pouvez utiliser cette Map pour obtenir les étiquettes de chaîne après toutes les transformations Spark sont faites.

Espérons que ça aide.

+0

La carte peut être très utile mais je ne sais pas comment joindre cette carte aux objets Spark. J'ai ajouté ces lignes et cela m'a beaucoup aidé: 'String [] labels = new String [] {"Iris-Setosa", "Iris-versicolor", "Iris-virginica"}; IndexToString stringConverter = new IndexToString(). SetLabels (étiquettes); /* Convertit les étiquettes indexées en étiquettes d'origine. */ IndexToString labelConverter = new IndexToString(). SetInputCol ("prédiction"). SetOutputCol ("predictedLabel"). SetLabels (stringConverter.getLabels()); ' – Shimrit

+0

@Shimrit La 'Map' ne peut être utilisée séparément qu'après que toutes les transformations aient été effectuées, donc' IndexToString() 'est à préférer. Je vais mettre à jour la réponse pour refléter cela. S'il vous plaît envisager d'accepter la réponse en cliquant sur la coche si elle vous a aidé. :) – Shaido