2013-08-30 4 views
4

On m'a demandé de faire tourner du code MATLAB plus rapidement, et j'ai été confronté à quelque chose qui me semble étrange.matlab matrix operation speed

Dans l'une des fonctions il y a une boucle où l'on multiplie un vecteur 3x1 (appelons-le x) - une matrice 3x3 (appelons-le A) - et la transposition de x, ce qui donne un scalaire. Le code a l'ensemble des multiplications et des additions élément par élément, et il est assez lourd:

val = x(1)*A(1,1)*x(1) + x(1)*A(1,2)*x(2) + x(1)*A(1,3)*x(3) + ... 
     x(2)*A(2,1)*x(1) + x(2)*A(2,2)*x(2) + x(2)*A(2,3)*x(3) + ... 
     x(3)*A(3,1)*x(1) + x(3)*A(3,2)*x(2) + x(3)*A(3,3)*x(3); 

Je pensais que je venais de le remplacer par tout:

val = x*A*x'; 

À ma grande surprise, il a fonctionné beaucoup plus lentement (comme dans 4-5 fois plus lent). Est-ce juste que le vecteur et la matrice sont si petits que les optimisations de MATLAB ne s'appliquent pas?

+0

Pouvez-vous s'il vous plaît poster le code avant et après la modification? – higuaro

+0

Avant (version plus rapide) val = x (1) * A (1,1) * x (1) + ... x (1) * A (1,2) * x (2) + ... x (1) * A (1,3) * x (3) + ... x (2) * A (2,1) * x (1) + ... x (2) * A (2,2) * x (2) + ... x (2) * A (2,3) * x (3) + ... x (3) * A (3,1) * x (1) + ... x (3) * A (3,2) * x (2) + ... x (3) * A (3,3) * x (3); vs val = x * A * x '; – user888379

+0

parfois dérouler les équations de cette façon est plus rapide, même si ce n'est pas possible pour les grandes tailles ...Aussi, vous ne serez pas en profitant de la mise en œuvre de BLAS optimisée pour la multiplication de matrices (Intel MKL) – Amro

Répondre

8

EDIT: J'ai amélioré les tests pour donner des temps plus précis. J'ai aussi optimisé la version déroulée qui est maintenant bien meilleure que ce que j'avais initialement, la multiplication matricielle est encore plus rapide au fur et à mesure que vous augmentez la taille. EDIT2: Pour m'assurer que le compilateur JIT travaille sur les fonctions déroulées, j'ai modifié le code pour écrire les fonctions générées sous forme de fichiers M. Aussi la comparaison peut maintenant être considérée comme équitable que les deux méthodes s'évaluées en passant TimeIt la poignée de fonction: timeit(@myfunc)


Je ne suis pas convaincu que votre approche est plus rapide que la multiplication de matrices pour des tailles raisonnables. Comparons donc les deux méthodes.

J'utilise la Soroban pour me aider à la « déroula » forme de l'équation de x'*A*x (essayez multiplier manuellement une matrice 20x20 et un vecteur 20x1!):

function f = buildUnrolledFunction(N) 
    % avoid regenerating files, CCODE below can be really slow! 
    fname = sprintf('f%d',N); 
    if exist([fname '.m'], 'file') 
     f = str2func(fname); 
     return 
    end 

    % construct symbolic vector/matrix of the specified size 
    x = sym('x', [N 1]); 
    A = sym('A', [N N]); 

    % work out the expanded form of the matrix-multiplication 
    % and convert it to a string 
    s = ccode(expand(x.'*A*x)); % instead of char(.) to avoid x^2 

    % a bit of RegExp to fix the notation of the variable names 
    % also convert indexing into linear indices: A(3,3) into A(9) 
    s = regexprep(regexprep(s, '^.*=\s+', ''), ';$', ''); 
    s = regexprep(regexprep(s, 'x(\d+)', 'x($1)'), 'A(\d+)_(\d+)', ... 
     'A(${ int2str(sub2ind([N N],str2num($1),str2num($2))) })'); 

    % build an M-function from the string, and write it to file 
    fid = fopen([fname '.m'], 'wt'); 
    fprintf(fid, 'function v = %s(A,x)\nv = %s;\nend\n', fname, s); 
    fclose(fid); 

    % rehash path and return a function handle 
    rehash 
    clear(fname) 
    f = str2func(fname); 
end 

J'ai essayé de optimiser la fonction générée en évitant l'exponentiation (nous préférons x*x à x^2). J'ai également converti les indices en indices linéaires (A(9) au lieu de A(3,3)). Par conséquent, pour n=3 nous obtenons la même équation que vous aviez:

>> s 
s = 
A(1)*(x(1)*x(1)) + A(5)*(x(2)*x(2)) + A(9)*(x(3)*x(3)) + 
A(4)*x(1)*x(2) + A(7)*x(1)*x(3) + A(2)*x(1)*x(2) + 
A(8)*x(2)*x(3) + A(3)*x(1)*x(3) + A(6)*x(2)*x(3) 

Compte tenu de la méthode ci-dessus pour construire M-fonctions, nous évaluons maintenant pour différentes tailles et la comparer à la forme matrix-multiplication (je l'ai mis dans une fonction distincte tenir compte des frais généraux d'appel de fonction). J'utilise la fonction TIMEIT au lieu de tic/toc pour obtenir des horaires plus précis. De même, pour avoir une comparaison équitable, chaque méthode est implémentée comme une fonction de fichier M qui passe toutes les variables nécessaires comme arguments d'entrée.

function results = testMatrixMultVsUnrolled() 
    % vector/matrix size 
    N_vec = 2:50; 
    results = zeros(numel(N_vec),3); 
    for ii = 1:numel(N_vec); 
     % some random data 
     N = N_vec(ii); 
     x = rand(N,1); A = rand(N,N); 

     % matrix multiplication 
     f = @matMult; 
     results(ii,1) = timeit(@() feval(f, A,x)); 

     % unrolled equation 
     f = buildUnrolledFunction(N); 
     results(ii,2) = timeit(@() feval(f, A,x)); 

     % check result 
     results(ii,3) = norm(matMult(A,x) - f(A,x)); 
    end 

    % display results 
    fprintf('N = %2d: mtimes = %.6f ms, unroll = %.6f ms [error = %g]\n', ... 
     [N_vec(:) results(:,1:2)*1e3 results(:,3)]') 
    plot(N_vec, results(:,1:2)*1e3, 'LineWidth',2) 
    xlabel('size (N)'), ylabel('timing [msec]'), grid on 
    legend({'mtimes','unrolled'}) 
    title('Matrix multiplication: $$x^\mathsf{T}Ax$$', ... 
     'Interpreter','latex', 'FontSize',14) 
end 

function v = matMult(A,x) 
    v = x.' * A * x; 
end 

Les résultats:

timing timing_closeup

N = 2: mtimes = 0.008816 ms, unroll = 0.006793 ms [error = 0] 
N = 3: mtimes = 0.008957 ms, unroll = 0.007554 ms [error = 0] 
N = 4: mtimes = 0.009025 ms, unroll = 0.008261 ms [error = 4.44089e-16] 
N = 5: mtimes = 0.009075 ms, unroll = 0.008658 ms [error = 0] 
N = 6: mtimes = 0.009003 ms, unroll = 0.008689 ms [error = 8.88178e-16] 
N = 7: mtimes = 0.009234 ms, unroll = 0.009087 ms [error = 1.77636e-15] 
N = 8: mtimes = 0.008575 ms, unroll = 0.009744 ms [error = 8.88178e-16] 
N = 9: mtimes = 0.008601 ms, unroll = 0.011948 ms [error = 0] 
N = 10: mtimes = 0.009077 ms, unroll = 0.014052 ms [error = 0] 
N = 11: mtimes = 0.009339 ms, unroll = 0.015358 ms [error = 3.55271e-15] 
N = 12: mtimes = 0.009271 ms, unroll = 0.018494 ms [error = 3.55271e-15] 
N = 13: mtimes = 0.009166 ms, unroll = 0.020238 ms [error = 0] 
N = 14: mtimes = 0.009204 ms, unroll = 0.023326 ms [error = 7.10543e-15] 
N = 15: mtimes = 0.009396 ms, unroll = 0.024767 ms [error = 3.55271e-15] 
N = 16: mtimes = 0.009193 ms, unroll = 0.027294 ms [error = 2.4869e-14] 
N = 17: mtimes = 0.009182 ms, unroll = 0.029698 ms [error = 2.13163e-14] 
N = 18: mtimes = 0.009330 ms, unroll = 0.033295 ms [error = 7.10543e-15] 
N = 19: mtimes = 0.009411 ms, unroll = 0.152308 ms [error = 7.10543e-15] 
N = 20: mtimes = 0.009366 ms, unroll = 0.167336 ms [error = 7.10543e-15] 
N = 21: mtimes = 0.009335 ms, unroll = 0.183371 ms [error = 0] 
N = 22: mtimes = 0.009349 ms, unroll = 0.200859 ms [error = 7.10543e-14] 
N = 23: mtimes = 0.009411 ms, unroll = 0.218477 ms [error = 8.52651e-14] 
N = 24: mtimes = 0.009307 ms, unroll = 0.235668 ms [error = 4.26326e-14] 
N = 25: mtimes = 0.009425 ms, unroll = 0.256491 ms [error = 1.13687e-13] 
N = 26: mtimes = 0.009392 ms, unroll = 0.274879 ms [error = 7.10543e-15] 
N = 27: mtimes = 0.009515 ms, unroll = 0.296795 ms [error = 2.84217e-14] 
N = 28: mtimes = 0.009567 ms, unroll = 0.319032 ms [error = 5.68434e-14] 
N = 29: mtimes = 0.009548 ms, unroll = 0.339517 ms [error = 3.12639e-13] 
N = 30: mtimes = 0.009617 ms, unroll = 0.361897 ms [error = 1.7053e-13] 
N = 31: mtimes = 0.009672 ms, unroll = 0.387270 ms [error = 0] 
N = 32: mtimes = 0.009629 ms, unroll = 0.410932 ms [error = 1.42109e-13] 
N = 33: mtimes = 0.009605 ms, unroll = 0.434452 ms [error = 1.42109e-13] 
N = 34: mtimes = 0.009534 ms, unroll = 0.462961 ms [error = 0] 
N = 35: mtimes = 0.009696 ms, unroll = 0.489474 ms [error = 5.68434e-14] 
N = 36: mtimes = 0.009691 ms, unroll = 0.512198 ms [error = 8.52651e-14] 
N = 37: mtimes = 0.009671 ms, unroll = 0.544485 ms [error = 5.68434e-14] 
N = 38: mtimes = 0.009710 ms, unroll = 0.573564 ms [error = 8.52651e-14] 
N = 39: mtimes = 0.009946 ms, unroll = 0.604567 ms [error = 3.41061e-13] 
N = 40: mtimes = 0.009735 ms, unroll = 0.636640 ms [error = 3.12639e-13] 
N = 41: mtimes = 0.009858 ms, unroll = 0.665719 ms [error = 5.40012e-13] 
N = 42: mtimes = 0.009876 ms, unroll = 0.697364 ms [error = 0] 
N = 43: mtimes = 0.009956 ms, unroll = 0.730506 ms [error = 2.55795e-13] 
N = 44: mtimes = 0.009897 ms, unroll = 0.765358 ms [error = 4.26326e-13] 
N = 45: mtimes = 0.009991 ms, unroll = 0.800424 ms [error = 0] 
N = 46: mtimes = 0.009956 ms, unroll = 0.829717 ms [error = 2.27374e-13] 
N = 47: mtimes = 0.010210 ms, unroll = 0.865424 ms [error = 2.84217e-13] 
N = 48: mtimes = 0.010022 ms, unroll = 0.907974 ms [error = 3.97904e-13] 
N = 49: mtimes = 0.010098 ms, unroll = 0.944536 ms [error = 5.68434e-13] 
N = 50: mtimes = 0.010153 ms, unroll = 0.984486 ms [error = 4.54747e-13] 

Dans les petites tailles les deux méthodes effectuent un peu. De la même façon Bien que pour N<7 la version étendue bat mtimes, mais la différence est à peine significative.Une fois que nous avons dépassé les tailles minuscules, la multiplication de la matrice est plus rapide.

Ce n'est pas vraiment surprenant; avec seulement N=20 le formula est effrayant et implique l'ajout de 400 termes. Comme la langue MatLab interprété, je doute que ce soit très efficace ..


Maintenant, je suis d'accord qu'il ya une surcharge pour appeler une fonction externe en fonction de l'intégration directe du code en ligne, mais comment est pratique une telle approche. Même pour une petite taille comme N=20, la ligne générée est plus de 7000 caractères! J'ai aussi remarqué que l'éditeur MATLAB devenait lent à cause des longues lignes :)

En outre, l'avantage disparaît rapidement après environ N>10. J'ai comparé la multiplication par code intégré/écrit explicitement par rapport à la matrice, similaire à ce que @DennisJaheruddin a suggéré. Le results:

N=3: 
    Elapsed time is 0.062295 seconds. % unroll 
    Elapsed time is 1.117962 seconds. % mtimes 

N=12: 
    Elapsed time is 1.024837 seconds. % unroll 
    Elapsed time is 1.126147 seconds. % mtimes 

N=19: 
    Elapsed time is 140.915138 seconds. % unroll 
    Elapsed time is 1.305382 seconds. % mtimes 

... et il ne fait qu'empirer pour la version déroulée. Comme je l'ai déjà dit, MATLAB est interprété de sorte que le coût de l'analyse du code commence à apparaître dans des fichiers aussi volumineux. Comme je le vois, après avoir fait un million d'itérations, nous avons seulement gagné 1 seconde au mieux, ce qui, je pense, ne justifie pas tous les problèmes et les hacks, sur l'utilisation de la plus lisible et succincte v=x'*A*x. Donc, il ya d'autres endroits dans le code que l'on peut améliorer, plutôt que de se concentrer sur une opération déjà optimisée telle que la multiplication matricielle.

Matrix multiplication dans MATLAB isseriouslyfast (c'est ce que MATLAB est le meilleur!). Il brille vraiment une fois que vous atteignez des données assez grandes (comme multithreading) entre en jeu:

>> N=5000; x=rand(N,1); A=rand(N,N); 
>> tic, for i=1e4, v=x.'*A*x; end, toc 
Elapsed time is 0.021959 seconds. 
+0

peut-être un point important qui pourrait expliquer les temps plus rapides rapportés par l'OP, qui pourrait ne pas être couverts par vos tests: fait le travail JIT de Matlab pour les fonctions anonymes? Si ce n'est pas trop de travail, vous pouvez essayer de sauvegarder la fonction générée dans un fichier m, au lieu de créer une fonction à la volée et voir si cela fait une différence. –

+0

@BasSwinckels: terminé, voir la modification récente. Il n'y a pas eu un impact énorme, bien qu'il soit devenu un peu plus rapide pour les petites tailles (vous pouvez toujours le tester vous-même!). Je veux juste vous rappeler que cette gamme de 'N = 2: 50' est des cacahuètes pour' mtimes' qui peuvent facilement gérer à l'échelle de milliers. Je ne voudrais même pas essayer d'écrire la version élargie pour une telle taille :) – Amro

2

@Amro a donné en réponse exstensive, et je suis d'accord qu'en général, vous ne devriez pas la peine d'écrire les calculs explicites et il suffit d'utiliser la matrice multiplication partout dans votre code. Cependant, si votre matrice est assez petite et que vous avez vraiment besoin de calculer quelques milliards de fois, la forme écrite peut être beaucoup plus rapide (moins de frais généraux). Cependant, l'astuce consiste à ne pas placer votre code dans une fonction séparée, car le temps d'appel sera beaucoup plus long que le temps de calcul.

Voici un exemple smalle:

x = 1:3; 
A = rand(3); 
v=0; 

unroll = @(x) A(1)*(x(1)*x(1)) + A(5)*(x(2)*x(2)) + A(9)*(x(3)*x(3)) + A(4)*x(1)*x(2) + A(7)*x(1)*x(3) + A(2)*x(1)*x(2) + A(8)*x(2)*x(3) + A(3)*x(1)*x(3) + A(6)*x(2)*x(3); 
regular = @(x) x*A*x'; 

%Written out, no function call 
tic 
for t = 1:1e6 
    v = A(1)*(x(1)*x(1)) + A(5)*(x(2)*x(2)) + A(9)*(x(3)*x(3)) + A(4)*x(1)*x(2) + A(7)*x(1)*x(3) + A(2)*x(1)*x(2) + A(8)*x(2)*x(3) + A(3)*x(1)*x(3) + A(6)*x(2)*x(3);; 
end 
t1=toc; 

%Matrix form, no function call 
tic 
for t = 1:1e6 
    v = x*A*x'; 
end 
t2=toc; 

%Written out, function call 
tic 
for t = 1:1e6 
    v = unroll(x); 
end 
t3=toc; 

%Matrix form, function call 
tic 
for t = 1:1e6 
    v = regular(x); 
end 
t4=toc; 

[t1;t2;t3;t4] 

Ce qui vous donnera ces résultats:

0.0767 
1.6988 

6.1975 
7.9353 

Donc, si vous l'appelez via une fonction (anonyme), il ne sera pas intéressant d'utiliser la forme écrite, cependant si vous voulez vraiment obtenir la meilleure vitesse en utilisant simplement la forme écrite directement, vous pouvez obtenir une grande accélération pour les matrices minuscules.

+0

Je ne vois toujours pas l'avantage. Même avec 1e6 itérations de telles multiplications matricielles, nous avons à peine vu une amélioration de 1 seconde. Dans l'ensemble, ce n'est pas une grande amélioration et je préférerais rester avec la mise en œuvre simple, d'autant plus que cela ne vaut pas pour les grandes tailles :) – Amro

+0

L'observation sur la surcharge de l'appel de fonction est très pertinente. Un collègue et moi avons fait une expérience en comparant: (a) la forme écrite en ligne, (b) x '* A * x, et (c) la version d'appel de fonction du formulaire écrit. b s'est avéré être deux fois plus rapide que c. – user888379