2015-08-29 2 views
0

J'essaye de multiplier une matrice avec sa transposition, mais je n'ai pas réussi à faire un appel de sgemm correct. Sgemm prend de nombreux paramètres. Certains de ceux comme lda, ldb sont confus pour moi. Si j'appelle la fonction ci-dessous avec une matrice carrée, cela fonctionne sinon cela ne fonctionne pas.Multiplier une matrice avec sa transposition en utilisant cuBlas

/*param inMatrix: contains the matrix data in major order like [1 2 3 1 2 3] 
    param rowNum: Number of rows in a matrix eg if matrix is 
       |1 1| 
       |2 2| 
       |3 3| than rowNum should be 3*/ 
void matrixtTransposeMult(std::vector<float>& inMatrix, int rowNum) 
{ 
    cublasHandle_t handle; 
    cublasCreate(&handle); 

    int colNum = (int)inMatrix.size()/rowNum; 
    thrust::device_vector<float> d_InMatrix(inMatrix); 
    thrust::device_vector<float> d_outputMatrix(rowNum*rowNum); 
    float alpha = 1.0f; 
    float beta = 0.0f; 

    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, rowNum, rowNum, colNum, &alpha, 
     thrust::raw_pointer_cast(d_InMatrix.data()), colNum, thrust::raw_pointer_cast(d_InMatrix.data()), colNum, &beta, 
     thrust::raw_pointer_cast(d_outputMatrix.data()), rowNum); 

    thrust::host_vector<float> result = d_outputMatrix; 
    for (auto elem : result) 
     std::cout << elem << ","; 
    std::cout << std::endl; 

    cublasDestroy(handle); 
} 

Que manque-t-il? Comment faire un appel de sgemm correct pour matrix * matrixTranspose?

Répondre

1

Ci-dessous les paramètres travaillés pour moi, si quelque chose me manque, veuillez me prévenir. J'espère que ce sera utile pour quelqu'un

void matrixtTransposeMult(std::vector<float>& inMatrix, int rowNum) 
{ 
    cublasHandle_t handle; 
    cublasCreate(&handle); 

    int colNum = (int)inMatrix.size()/rowNum; 
    thrust::device_vector<float> d_InMatrix(inMatrix); 
    thrust::device_vector<float> d_outputMatrix(rowNum*rowNum); 
    float alpha = 1.0f; 
    float beta = 0.0f; 

    cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_T, rowNum, rowNum, colNum, &alpha, 
     thrust::raw_pointer_cast(d_InMatrix.data()), rowNum, thrust::raw_pointer_cast(d_InMatrix.data()), rowNum, &beta, 
     thrust::raw_pointer_cast(d_outputMatrix.data()), rowNum); 

    thrust::host_vector<float> result = d_outputMatrix; 
    for (auto elem : result) 
     std::cout << elem << ","; 
    std::cout << std::endl; 

    cublasDestroy(handle); 
}