2016-04-16 1 views
0

Je travaille sur un projet Spark avec scala. Je veux former un modèle qui peut être k_means, gaussian_mixture, la régression logistique, naive_bayes etc. Mais je ne peux pas définir un modèle générique comme type de retour. Puisque les types de ces algorithmes sont différents, comme GaussianMixtureModel, KMeansModel, etc. Je ne trouve aucun moyen logique de retourner ce modèle entraîné.Type multiple pour une variable en étincelle en utilisant scala

Voici une paix de code du projet:

model.model_algorithm match { 

     case "k_means" => 

     val model_k_means = k_means(data, parameters) 

     case "gaussian_mixture" => 

     val model_gaussian_mixture = gaussian_mixture(data, parameters) 

     case "logistic_regression" => 

     val model_logistic_regression = logistic_regression(data, parameters) 

} 

Ainsi est-il un moyen de retourner ce modèle formé ou de définir un modèle générique qui accepte tous les types?

+0

qu'est-ce que vous voulez _do_ avec le modèle formé? Ces classes étendent toutes les fonctions 'org.apache.spark.mllib.util.Saveable',' AntRef' et 'Any', donc votre méthode peut retourner n'importe lequel de ces types, mais cela ne vous aidera pas forcément. Si vous voulez effectuer l'action X sur ces résultats plus tard, vous pouvez créer un trait 'ModelResult' avec la méthode X, rendre ce pattern-matching' ModelResult', et avoir trois implémentations de ce trait, chacune gérant un modèle différent. –

+0

J'ai essayé de les faire de type Any, mais la méthode predict() ne peut pas être utilisée dans ce cas. Pouvez-vous expliquer comment puis-je implémenter le pattern-matching dans ce cas. Merci pour votre réponse. –

+0

Vous avez donc lancé trois modèles et l'appariement de modèles pour savoir lequel fonctionne. Si tel est le cas, c'est une mauvaise pratique. – eliasah

Répondre

1

Vous pouvez créer une interface commune pour envelopper toute votre logique interne de formation et de prédiction et simplement exposer une interface simple à réutiliser.

trait AlgorithmInterface extends Serializable { 
    def train(data: RDD[LabeledPoint]) 
    def predict(record: Vector) 
} 

Et ont des algorithmes mis en œuvre dans les classes comme

class LogisticRegressionAlgorithm extends AlgorithmInterface { 
    var model:LogisticRegressionModel = null 
    override def train(data: RDD[LabeledPoint]): Unit = { 
    model = new LogisticRegressionWithLBFGS() 
     .setNumClasses(10) 
     .run(data) 
    } 
    override def predict(record:Vector): Double = model.predict(record) 
} 

class GaussianMixtureAlgorithm extends AlgorithmInterface { 
    var model: GaussianMixtureModel = null 
    override def train(data: RDD[LabeledPoint]): Unit = { 
    model = new GaussianMixture().setK(2).run(data.map(_.features)) 
    } 
    override def predict(record: Vector) = model.predict(record) 
} 

Sa mise en œuvre

// Assigning the models to an Array[AlgorithmInterface] 
    val models: Array[AlgorithmInterface] = Array(
     new LogisticRegressionAlgorithm(), 
     new GaussianMixtureAlgorithm() 
    ) 
    // Training the Models using the Interfaces Train Function 
    models.foreach(_.train(data)) 
    //Predicting the Value 
    models.foreach(model=> println(model.predict(vectorData)))