EDIT: J'ai amélioré les tests pour donner des temps plus précis. J'ai aussi optimisé la version déroulée qui est maintenant bien meilleure que ce que j'avais initialement, la multiplication matricielle est encore plus rapide au fur et à mesure que vous augmentez la taille. EDIT2: Pour m'assurer que le compilateur JIT travaille sur les fonctions déroulées, j'ai modifié le code pour écrire les fonctions générées sous forme de fichiers M. Aussi la comparaison peut maintenant être considérée comme équitable que les deux méthodes s'évaluées en passant TimeIt la poignée de fonction: timeit(@myfunc)
Je ne suis pas convaincu que votre approche est plus rapide que la multiplication de matrices pour des tailles raisonnables. Comparons donc les deux méthodes.
J'utilise la Soroban pour me aider à la « déroula » forme de l'équation de x'*A*x
(essayez multiplier manuellement une matrice 20x20 et un vecteur 20x1!):
function f = buildUnrolledFunction(N)
% avoid regenerating files, CCODE below can be really slow!
fname = sprintf('f%d',N);
if exist([fname '.m'], 'file')
f = str2func(fname);
return
end
% construct symbolic vector/matrix of the specified size
x = sym('x', [N 1]);
A = sym('A', [N N]);
% work out the expanded form of the matrix-multiplication
% and convert it to a string
s = ccode(expand(x.'*A*x)); % instead of char(.) to avoid x^2
% a bit of RegExp to fix the notation of the variable names
% also convert indexing into linear indices: A(3,3) into A(9)
s = regexprep(regexprep(s, '^.*=\s+', ''), ';$', '');
s = regexprep(regexprep(s, 'x(\d+)', 'x($1)'), 'A(\d+)_(\d+)', ...
'A(${ int2str(sub2ind([N N],str2num($1),str2num($2))) })');
% build an M-function from the string, and write it to file
fid = fopen([fname '.m'], 'wt');
fprintf(fid, 'function v = %s(A,x)\nv = %s;\nend\n', fname, s);
fclose(fid);
% rehash path and return a function handle
rehash
clear(fname)
f = str2func(fname);
end
J'ai essayé de optimiser la fonction générée en évitant l'exponentiation (nous préférons x*x
à x^2
). J'ai également converti les indices en indices linéaires (A(9)
au lieu de A(3,3)
). Par conséquent, pour n=3
nous obtenons la même équation que vous aviez:
>> s
s =
A(1)*(x(1)*x(1)) + A(5)*(x(2)*x(2)) + A(9)*(x(3)*x(3)) +
A(4)*x(1)*x(2) + A(7)*x(1)*x(3) + A(2)*x(1)*x(2) +
A(8)*x(2)*x(3) + A(3)*x(1)*x(3) + A(6)*x(2)*x(3)
Compte tenu de la méthode ci-dessus pour construire M-fonctions, nous évaluons maintenant pour différentes tailles et la comparer à la forme matrix-multiplication (je l'ai mis dans une fonction distincte tenir compte des frais généraux d'appel de fonction). J'utilise la fonction TIMEIT au lieu de tic/toc
pour obtenir des horaires plus précis. De même, pour avoir une comparaison équitable, chaque méthode est implémentée comme une fonction de fichier M qui passe toutes les variables nécessaires comme arguments d'entrée.
function results = testMatrixMultVsUnrolled()
% vector/matrix size
N_vec = 2:50;
results = zeros(numel(N_vec),3);
for ii = 1:numel(N_vec);
% some random data
N = N_vec(ii);
x = rand(N,1); A = rand(N,N);
% matrix multiplication
f = @matMult;
results(ii,1) = timeit(@() feval(f, A,x));
% unrolled equation
f = buildUnrolledFunction(N);
results(ii,2) = timeit(@() feval(f, A,x));
% check result
results(ii,3) = norm(matMult(A,x) - f(A,x));
end
% display results
fprintf('N = %2d: mtimes = %.6f ms, unroll = %.6f ms [error = %g]\n', ...
[N_vec(:) results(:,1:2)*1e3 results(:,3)]')
plot(N_vec, results(:,1:2)*1e3, 'LineWidth',2)
xlabel('size (N)'), ylabel('timing [msec]'), grid on
legend({'mtimes','unrolled'})
title('Matrix multiplication: $$x^\mathsf{T}Ax$$', ...
'Interpreter','latex', 'FontSize',14)
end
function v = matMult(A,x)
v = x.' * A * x;
end
Les résultats:
N = 2: mtimes = 0.008816 ms, unroll = 0.006793 ms [error = 0]
N = 3: mtimes = 0.008957 ms, unroll = 0.007554 ms [error = 0]
N = 4: mtimes = 0.009025 ms, unroll = 0.008261 ms [error = 4.44089e-16]
N = 5: mtimes = 0.009075 ms, unroll = 0.008658 ms [error = 0]
N = 6: mtimes = 0.009003 ms, unroll = 0.008689 ms [error = 8.88178e-16]
N = 7: mtimes = 0.009234 ms, unroll = 0.009087 ms [error = 1.77636e-15]
N = 8: mtimes = 0.008575 ms, unroll = 0.009744 ms [error = 8.88178e-16]
N = 9: mtimes = 0.008601 ms, unroll = 0.011948 ms [error = 0]
N = 10: mtimes = 0.009077 ms, unroll = 0.014052 ms [error = 0]
N = 11: mtimes = 0.009339 ms, unroll = 0.015358 ms [error = 3.55271e-15]
N = 12: mtimes = 0.009271 ms, unroll = 0.018494 ms [error = 3.55271e-15]
N = 13: mtimes = 0.009166 ms, unroll = 0.020238 ms [error = 0]
N = 14: mtimes = 0.009204 ms, unroll = 0.023326 ms [error = 7.10543e-15]
N = 15: mtimes = 0.009396 ms, unroll = 0.024767 ms [error = 3.55271e-15]
N = 16: mtimes = 0.009193 ms, unroll = 0.027294 ms [error = 2.4869e-14]
N = 17: mtimes = 0.009182 ms, unroll = 0.029698 ms [error = 2.13163e-14]
N = 18: mtimes = 0.009330 ms, unroll = 0.033295 ms [error = 7.10543e-15]
N = 19: mtimes = 0.009411 ms, unroll = 0.152308 ms [error = 7.10543e-15]
N = 20: mtimes = 0.009366 ms, unroll = 0.167336 ms [error = 7.10543e-15]
N = 21: mtimes = 0.009335 ms, unroll = 0.183371 ms [error = 0]
N = 22: mtimes = 0.009349 ms, unroll = 0.200859 ms [error = 7.10543e-14]
N = 23: mtimes = 0.009411 ms, unroll = 0.218477 ms [error = 8.52651e-14]
N = 24: mtimes = 0.009307 ms, unroll = 0.235668 ms [error = 4.26326e-14]
N = 25: mtimes = 0.009425 ms, unroll = 0.256491 ms [error = 1.13687e-13]
N = 26: mtimes = 0.009392 ms, unroll = 0.274879 ms [error = 7.10543e-15]
N = 27: mtimes = 0.009515 ms, unroll = 0.296795 ms [error = 2.84217e-14]
N = 28: mtimes = 0.009567 ms, unroll = 0.319032 ms [error = 5.68434e-14]
N = 29: mtimes = 0.009548 ms, unroll = 0.339517 ms [error = 3.12639e-13]
N = 30: mtimes = 0.009617 ms, unroll = 0.361897 ms [error = 1.7053e-13]
N = 31: mtimes = 0.009672 ms, unroll = 0.387270 ms [error = 0]
N = 32: mtimes = 0.009629 ms, unroll = 0.410932 ms [error = 1.42109e-13]
N = 33: mtimes = 0.009605 ms, unroll = 0.434452 ms [error = 1.42109e-13]
N = 34: mtimes = 0.009534 ms, unroll = 0.462961 ms [error = 0]
N = 35: mtimes = 0.009696 ms, unroll = 0.489474 ms [error = 5.68434e-14]
N = 36: mtimes = 0.009691 ms, unroll = 0.512198 ms [error = 8.52651e-14]
N = 37: mtimes = 0.009671 ms, unroll = 0.544485 ms [error = 5.68434e-14]
N = 38: mtimes = 0.009710 ms, unroll = 0.573564 ms [error = 8.52651e-14]
N = 39: mtimes = 0.009946 ms, unroll = 0.604567 ms [error = 3.41061e-13]
N = 40: mtimes = 0.009735 ms, unroll = 0.636640 ms [error = 3.12639e-13]
N = 41: mtimes = 0.009858 ms, unroll = 0.665719 ms [error = 5.40012e-13]
N = 42: mtimes = 0.009876 ms, unroll = 0.697364 ms [error = 0]
N = 43: mtimes = 0.009956 ms, unroll = 0.730506 ms [error = 2.55795e-13]
N = 44: mtimes = 0.009897 ms, unroll = 0.765358 ms [error = 4.26326e-13]
N = 45: mtimes = 0.009991 ms, unroll = 0.800424 ms [error = 0]
N = 46: mtimes = 0.009956 ms, unroll = 0.829717 ms [error = 2.27374e-13]
N = 47: mtimes = 0.010210 ms, unroll = 0.865424 ms [error = 2.84217e-13]
N = 48: mtimes = 0.010022 ms, unroll = 0.907974 ms [error = 3.97904e-13]
N = 49: mtimes = 0.010098 ms, unroll = 0.944536 ms [error = 5.68434e-13]
N = 50: mtimes = 0.010153 ms, unroll = 0.984486 ms [error = 4.54747e-13]
Dans les petites tailles les deux méthodes effectuent un peu. De la même façon Bien que pour N<7
la version étendue bat mtimes
, mais la différence est à peine significative.Une fois que nous avons dépassé les tailles minuscules, la multiplication de la matrice est plus rapide.
Ce n'est pas vraiment surprenant; avec seulement N=20
le formula est effrayant et implique l'ajout de 400 termes. Comme la langue MatLab interprété, je doute que ce soit très efficace ..
Maintenant, je suis d'accord qu'il ya une surcharge pour appeler une fonction externe en fonction de l'intégration directe du code en ligne, mais comment est pratique une telle approche. Même pour une petite taille comme N=20
, la ligne générée est plus de 7000 caractères! J'ai aussi remarqué que l'éditeur MATLAB devenait lent à cause des longues lignes :)
En outre, l'avantage disparaît rapidement après environ N>10
. J'ai comparé la multiplication par code intégré/écrit explicitement par rapport à la matrice, similaire à ce que @DennisJaheruddin a suggéré. Le results:
N=3:
Elapsed time is 0.062295 seconds. % unroll
Elapsed time is 1.117962 seconds. % mtimes
N=12:
Elapsed time is 1.024837 seconds. % unroll
Elapsed time is 1.126147 seconds. % mtimes
N=19:
Elapsed time is 140.915138 seconds. % unroll
Elapsed time is 1.305382 seconds. % mtimes
... et il ne fait qu'empirer pour la version déroulée. Comme je l'ai déjà dit, MATLAB est interprété de sorte que le coût de l'analyse du code commence à apparaître dans des fichiers aussi volumineux. Comme je le vois, après avoir fait un million d'itérations, nous avons seulement gagné 1 seconde au mieux, ce qui, je pense, ne justifie pas tous les problèmes et les hacks, sur l'utilisation de la plus lisible et succincte v=x'*A*x
. Donc, il ya d'autres endroits dans le code que l'on peut améliorer, plutôt que de se concentrer sur une opération déjà optimisée telle que la multiplication matricielle.
Matrix multiplication dans MATLAB isseriouslyfast (c'est ce que MATLAB est le meilleur!). Il brille vraiment une fois que vous atteignez des données assez grandes (comme multithreading) entre en jeu:
>> N=5000; x=rand(N,1); A=rand(N,N);
>> tic, for i=1e4, v=x.'*A*x; end, toc
Elapsed time is 0.021959 seconds.
Pouvez-vous s'il vous plaît poster le code avant et après la modification? – higuaro
Avant (version plus rapide) val = x (1) * A (1,1) * x (1) + ... x (1) * A (1,2) * x (2) + ... x (1) * A (1,3) * x (3) + ... x (2) * A (2,1) * x (1) + ... x (2) * A (2,2) * x (2) + ... x (2) * A (2,3) * x (3) + ... x (3) * A (3,1) * x (1) + ... x (3) * A (3,2) * x (2) + ... x (3) * A (3,3) * x (3); vs val = x * A * x '; – user888379
parfois dérouler les équations de cette façon est plus rapide, même si ce n'est pas possible pour les grandes tailles ...Aussi, vous ne serez pas en profitant de la mise en œuvre de BLAS optimisée pour la multiplication de matrices (Intel MKL) – Amro