2016-04-14 5 views
5

J'essaie de définir une UserDefinedAggregateFunction (UDAF) dans Spark, qui compte le nombre d'occurrences pour chaque valeur unique dans une colonne d'un groupe.Pourquoi la carte mutable devient automatiquement immutable dans UserDefinedAggregateFunction (UDAF) dans Spark

Voici un exemple: Supposons que j'ai un dataframe df comme ça,

+----+----+ 
|col1|col2| 
+----+----+ 
| a| a1| 
| a| a1| 
| a| a2| 
| b| b1| 
| b| b2| 
| b| b3| 
| b| b1| 
| b| b1| 
+----+----+ 

Je vais avoir un DistinctValues ​​UDAF

val func = new DistinctValues 

Alors je l'applique à la dataframe df

val agg_value = df.groupBy("col1").agg(func(col("col2")).as("DV")) 

Je m'attends à avoir quelque chose de lik e ceci:

+----+--------------------------+ 
|col1|DV      | 
+----+--------------------------+ 
| a| Map(a1->2, a2->1)  | 
| b| Map(b1->3, b2->1, b3->1)| 
+----+--------------------------+ 

Alors je suis sorti avec un UDAF comme celui-ci,

import org.apache.spark.sql.expressions.MutableAggregationBuffer 
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction 
import org.apache.spark.sql.Row 
import org.apache.spark.sql.types.StructType 
import org.apache.spark.sql.types.StructField 
import org.apache.spark.sql.types.DataType 
import org.apache.spark.sql.types.ArrayType 
import org.apache.spark.sql.types.StringType 
import org.apache.spark.sql.types.MapType 
import org.apache.spark.sql.types.LongType 
import Array._ 

class DistinctValues extends UserDefinedAggregateFunction { 
    def inputSchema: org.apache.spark.sql.types.StructType = StructType(StructField("value", StringType) :: Nil) 

    def bufferSchema: StructType = StructType(StructField("values", MapType(StringType, LongType))::Nil) 

    def dataType: DataType = MapType(StringType, LongType) 
    def deterministic: Boolean = true 

    def initialize(buffer: MutableAggregationBuffer): Unit = { 
    buffer(0) = scala.collection.mutable.Map() 
    } 

    def update(buffer: MutableAggregationBuffer, input: Row) : Unit = { 
    val str = input.getAs[String](0) 
    var mp = buffer.getAs[scala.collection.mutable.Map[String, Long]](0) 
    var c:Long = mp.getOrElse(str, 0) 
    c = c + 1 
    mp.put(str, c) 
    buffer(0) = mp 
    } 

    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) : Unit = { 
    var mp1 = buffer1.getAs[scala.collection.mutable.Map[String, Long]](0) 
    var mp2 = buffer2.getAs[scala.collection.mutable.Map[String, Long]](0) 
    mp2 foreach { 
     case (k ,v) => { 
      var c:Long = mp1.getOrElse(k, 0) 
      c = c + v 
      mp1.put(k ,c) 
     } 
    } 
    buffer1(0) = mp1 
    } 

    def evaluate(buffer: Row): Any = { 
     buffer.getAs[scala.collection.mutable.Map[String, LongType]](0) 
    } 
} 

J'ai cette fonction sur mon dataframe,

val func = new DistinctValues 
val agg_values = df.groupBy("col1").agg(func(col("col2")).as("DV")) 

Il a donné une telle erreur,

func: DistinctValues = [email protected] 
org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 32.0 failed 4 times, most recent failure: Lost task 1.3 in stage 32.0 (TID 884, ip-172-31-22-166.ec2.internal): java.lang.ClassCastException: scala.collection.immutable.Map$EmptyMap$ cannot be cast to scala.collection.mutable.Map 
at $iwC$$iwC$DistinctValues.update(<console>:39) 
at org.apache.spark.sql.execution.aggregate.ScalaUDAF.update(udaf.scala:431) 
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$12.apply(AggregationIterator.scala:187) 
at org.apache.spark.sql.execution.aggregate.AggregationIterator$$anonfun$12.apply(AggregationIterator.scala:180) 
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.processCurrentSortedGroup(SortBasedAggregationIterator.scala:116) 
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:152) 
at org.apache.spark.sql.execution.aggregate.SortBasedAggregationIterator.next(SortBasedAggregationIterator.scala:29) 
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328) 
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328) 
at org.apache.spark.shuffle.sort.BypassMergeSortShuffleWriter.write(BypassMergeSortShuffleWriter.java:149) 
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:73) 
at org.apache.spark.scheduler.ShuffleMapTask.runTask(ShuffleMapTask.scala:41) 
at org.apache.spark.scheduler.Task.run(Task.scala:89) 
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:213) 
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1145) 
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:615) 
at java.lang.Thread.run(Thread.java:745) 

On dirait dans le update(buffer: MutableAggregationBuffer, input: Row) méthode, la buffer variable est un immutable.Map, le programme fatigué de le jeter aux mutable.Map,

Mais j'utilisé mutable.Map pour initialiser la variable buffer dans la méthode initialize(buffer: MutableAggregationBuffer, input:Row). Est-ce la même variable transmise à la méthode update? Et aussi buffer est mutableAggregationBuffer, donc il devrait être mutable, non?

Pourquoi mon mutable.Map est devenu immuable? Est-ce que quelqu'un sait ce qui est arrivé?

J'ai vraiment besoin d'une carte mutable dans cette fonction pour terminer la tâche. Je sais qu'il existe une solution pour créer une carte modifiable à partir de la carte immuable, puis la mettre à jour. Mais je veux vraiment savoir pourquoi le mutable se transforme automatiquement en programme immuable, cela n'a pas de sens pour moi.

Répondre

4

Croyez que c'est le MapType dans votre StructType. buffer détient donc un Map, ce qui serait immuable.

Vous pouvez le convertir, mais pourquoi ne pas vous laisser juste et immuable faire:

mp = mp + (k -> c) 

pour ajouter une entrée à la Map immuable?

Exemple de travail ci-dessous:

class DistinctValues extends UserDefinedAggregateFunction { 
    def inputSchema: org.apache.spark.sql.types.StructType = StructType(StructField("_2", IntegerType) :: Nil) 

    def bufferSchema: StructType = StructType(StructField("values", MapType(StringType, LongType))::Nil) 

    def dataType: DataType = MapType(StringType, LongType) 
    def deterministic: Boolean = true 

    def initialize(buffer: MutableAggregationBuffer): Unit = { 
    buffer(0) = Map() 
    } 

    def update(buffer: MutableAggregationBuffer, input: Row) : Unit = { 
    val str = input.getAs[String](0) 
    var mp = buffer.getAs[Map[String, Long]](0) 
    var c:Long = mp.getOrElse(str, 0) 
    c = c + 1 
    mp = mp + (str -> c) 
    buffer(0) = mp 
    } 

    def merge(buffer1: MutableAggregationBuffer, buffer2: Row) : Unit = { 
    var mp1 = buffer1.getAs[Map[String, Long]](0) 
    var mp2 = buffer2.getAs[Map[String, Long]](0) 
    mp2 foreach { 
     case (k ,v) => { 
      var c:Long = mp1.getOrElse(k, 0) 
      c = c + v 
      mp1 = mp1 + (k -> c) 
     } 
    } 
    buffer1(0) = mp1 
    } 

    def evaluate(buffer: Row): Any = { 
     buffer.getAs[Map[String, LongType]](0) 
    } 
} 
+0

Belle prise! Hmm, le 'MapyType' dans' StructType' peut être le cas. Mais il n'y a pas d'autre type de carte modifiable dans 'spark.sql.types', sauf si je définis le mien. –

+0

Comme je l'ai dit, ne pas - il suffit d'utiliser une «carte» immuable. 'mp = mp + (k -> c)' sur un 'Map' immuable vous donne la même fonctionnalité que' mp.mettre (k, c) 'sur une' Map' mutable –

+0

'mp = mp + (k -> c)' fonctionne! Je suis nouveau à Scala, je ne savais pas que vous pouviez manipuler un type de données immuable comme celui-ci. Merci beaucoup! –