2010-02-04 4 views
11

Une coupe directe et coller de l'algorithme suivant:tri fusion de « programmation Scala » provoque un débordement de pile

def msort[T](less: (T, T) => Boolean) 
      (xs: List[T]): List[T] = { 
    def merge(xs: List[T], ys: List[T]): List[T] = 
    (xs, ys) match { 
     case (Nil, _) => ys 
     case (_, Nil) => xs 
     case (x :: xs1, y :: ys1) => 
     if (less(x, y)) x :: merge(xs1, ys) 
     else y :: merge(xs, ys1) 
    } 
    val n = xs.length/2 
    if (n == 0) xs 
    else { 
    val (ys, zs) = xs splitAt n 
    merge(msort(less)(ys), msort(less)(zs)) 
    } 
} 

provoque une StackOverflowError sur 5000 listes longues.

Y a-t-il un moyen d'optimiser cela afin que cela ne se produise pas?

Répondre

17

Il le fait parce qu'il n'est pas récursif. Vous pouvez résoudre ce problème en utilisant une collection non stricte ou en la rendant récursive.

La dernière solution va comme ceci:

def msort[T](less: (T, T) => Boolean) 
      (xs: List[T]): List[T] = { 
    def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] = 
    (xs, ys) match { 
     case (Nil, _) => ys.reverse ::: acc 
     case (_, Nil) => xs.reverse ::: acc 
     case (x :: xs1, y :: ys1) => 
     if (less(x, y)) merge(xs1, ys, x :: acc) 
     else merge(xs, ys1, y :: acc) 
    } 
    val n = xs.length/2 
    if (n == 0) xs 
    else { 
    val (ys, zs) = xs splitAt n 
    merge(msort(less)(ys), msort(less)(zs), Nil).reverse 
    } 
} 

L'utilisation non-implique soit des paramètres de rigueur en passant par nom, ou en utilisant des collections non strictes telles que Stream. Le code suivant utilise Stream juste pour éviter un débordement pile, et List ailleurs:

def msort[T](less: (T, T) => Boolean) 
      (xs: List[T]): List[T] = { 
    def merge(left: List[T], right: List[T]): Stream[T] = (left, right) match { 
    case (x :: xs, y :: ys) if less(x, y) => Stream.cons(x, merge(xs, right)) 
    case (x :: xs, y :: ys) => Stream.cons(y, merge(left, ys)) 
    case _ => if (left.isEmpty) right.toStream else left.toStream 
    } 
    val n = xs.length/2 
    if (n == 0) xs 
    else { 
    val (ys, zs) = xs splitAt n 
    merge(msort(less)(ys), msort(less)(zs)).toList 
    } 
} 
+0

J'ai pensé à essayer de rendre la queue récursive, puis j'ai vu pas mal d'informations affirmant que la JVM n'est pas très compréhensible et n'optimise pas toujours la récursion de la queue. Y a-t-il une sorte de ligne directrice pour quand cela réussit? – user44242

+0

La JVM ne fonctionne pas, donc le compilateur Scala le fera pour vous. Il ne le fait que sous certaines conditions: il doit être auto-récursif (c.-à-d. F appelant g, et g appelant f ne fonctionnera pas), il doit être _tail_ récursion, bien sûr (l'appel récursif _must_ toujours être la dernière chose sur ce chemin de code), sur les méthodes il doit être soit 'final' soit' private'. Dans l'exemple, parce que 'merge' est défini dans' msort', au lieu d'être défini sur une classe ou un objet, il est effectivement privé. –

+3

Je pense qu'il peut être utile de mentionner ici que msort lui-même n'est pas récursif, mais la fusion est. Pour quiconque n'est convaincu que par le compilateur, ajoutez @tailrec à la définition de fusion, et vous remarquerez qu'il est accepté comme une fonction récursive de queue, comme l'a souligné Daniel. –

3

Juste au cas où les solutions de Daniel n'a pas fait assez clair, le problème est la récursivité de cette fusion est aussi profond que la longueur de la liste , et ce n'est pas la récursion de queue, donc il ne peut pas être converti en itération.

Scala peut convertir la solution de fusion récursive queue de Daniel en quelque chose à peu près équivalent à ceci:

def merge(xs: List[T], ys: List[T]): List[T] = { 
    var acc:List[T] = Nil 
    var decx = xs 
    var decy = ys 
    while (!decx.isEmpty || !decy.isEmpty) { 
    (decx, decy) match { 
     case (Nil, _) => { acc = decy.reverse ::: acc ; decy = Nil } 
     case (_, Nil) => { acc = decx.reverse ::: acc ; decx = Nil } 
     case (x :: xs1, y :: ys1) => 
     if (less(x, y)) { acc = x :: acc ; decx = xs1 } 
     else { acc = y :: acc ; decy = ys1 } 
    } 
    } 
    acc.reverse 
} 

mais il garde une trace de toutes les variables pour vous.

(Une méthode récursive-queue est celle où la méthode seulement appels se pour obtenir une réponse complète de retraverser,. Il ne se dit et fait alors quelque chose avec le résultat avant de le passer en arrière aussi, la queue-récursion ne peut pas être utilisé si la méthode peut être polymorphe, donc il ne fonctionne généralement que dans les objets ou avec des classes marquées comme finales.)

+1

Est-ce que ce dernier acc devrait être inversé? Si vous utilisiez cela comme une fonction de fusion autonome, il devrait y en avoir, mais il y a peut-être quelque chose à propos de l'utilisation de msort que je ne comprends pas. – timday

+1

@timday - Droite. Fixé. –

6

Juste en jouant avec TailCalls de scala (support de trampoline), qui je soupçonne n'était pas là quand question a été posée à l'origine. Voici une version immuable récursive de la fusion dans Rex's answer.

import scala.util.control.TailCalls._ 

def merge[T <% Ordered[T]](x:List[T],y:List[T]):List[T] = { 

    def build(s:List[T],a:List[T],b:List[T]):TailRec[List[T]] = { 
    if (a.isEmpty) { 
     done(b.reverse ::: s) 
    } else if (b.isEmpty) { 
     done(a.reverse ::: s) 
    } else if (a.head<b.head) { 
     tailcall(build(a.head::s,a.tail,b)) 
    } else { 
     tailcall(build(b.head::s,a,b.tail)) 
    } 
    } 

    build(List(),x,y).result.reverse 
} 

Fonctionne aussi vite que la version mutable sur les grands List[Long] s sur Scala 2.9.1 sur 64bit OpenJDK (amd64 Debian/Squeeze sur un Core i7).

Questions connexes