2017-08-10 1 views
0

Alors que quelqu'un a déjà posé des questions sur le calcul d'un Weighted Average in Spark, dans cette question, je pose la question à propos de l'utilisation de jeux de données/DataFrames au lieu de RDD.Moyenne pondérée avec les datasets Spark sans UDF

Comment puis-je calculer une moyenne pondérée dans Spark? J'ai deux colonnes: le nombre et moyennes précédentes:

case class Stat(name:String, count: Int, average: Double) 
val statset = spark.createDataset(Seq(Stat("NY", 1,5.0), 
          Stat("NY",2,1.5), 
          Stat("LA",12,1.0), 
          Stat("LA",15,3.0))) 

Je voudrais être en mesure de calculer une moyenne pondérée comme ceci:

display(statset.groupBy($"name").agg(sum($"count").as("count"), 
        weightedAverage($"count",$"average").as("average"))) 

On peut utiliser une UDF pour se rapprocher:

val weightedAverage = udf(
    (row:Row)=>{ 
    val counts = row.getAs[WrappedArray[Int]](0) 
    val averages = row.getAs[WrappedArray[Double]](1) 
    val (count,total) = (counts zip averages).foldLeft((0,0.0)){ 
     case((cumcount:Int,cumtotal:Double),(newcount:Int,newaverage:Double))=>(cumcount+newcount,cumtotal+newcount*newaverage)} 
    (total/count) // Tested by returning count here and then extracting. Got same result as sum. 
    } 
) 

display(statset.groupBy($"name").agg(sum($"count").as("count"), 
        weightedAverage(struct(collect_list($"count"), 
            collect_list($"average"))).as("average"))) 

(Merci aux réponses à Passing a list of tuples as a parameter to a spark udf in scala pour aider à la rédaction de ce)

Newb ies: utilisez ces importations:

import org.apache.spark.sql._ 
import org.apache.spark.sql.functions._ 
import org.apache.spark.sql.types._ 
import scala.collection.mutable.WrappedArray 

Existe-t-il un moyen d'y parvenir avec des fonctions de colonne intégrées au lieu de fonctions UDF? L'UDF se sent mal à l'aise et si les nombres deviennent importants, vous devez convertir les Int à Long.

Répondre

1

On dirait que vous pourriez le faire en deux passes:

val totalCount = statset.select(sum($"count")).collect.head.getLong(0) 

statset.select(lit(totalCount) as "count", sum($"average" * $"count"/lit(totalCount)) as "average").show 

Ou, y compris le groupBy vous venez d'ajouter:

display(statset.groupBy($"name").agg(sum($"count").as("count"), 
        sum($"count"*$"average").as("total")) 
       .select($"name",$"count",($"total"/$"count"))) 
+0

Dans mon code actuel, j'ai groupBy ... encore, ce pourrait fonctionner ... –

+0

Je voudrais ajouter le nombre total comme une autre colonne dans la deuxième agrégation et ensuite faire la division à la fin. Le deuxième passage aurait besoin de beaucoup plus de données. –

+0

@MichelLemay: Merci! C'est juste ce dont j'avais besoin pour faire du jogging. J'ai suggéré une modification à votre réponse qui fonctionne également avec groupBy. –