Grid 0.7.0
BatchedBlas.h
Go to the documentation of this file.
1/*************************************************************************************
2
3 Grid physics library, www.github.com/paboyle/Grid
4
5 Source file: BatchedBlas.h
6
7 Copyright (C) 2023
8
9Author: Peter Boyle <pboyle@bnl.gov>
10
11 This program is free software; you can redistribute it and/or modify
12 it under the terms of the GNU General Public License as published by
13 the Free Software Foundation; either version 2 of the License, or
14 (at your option) any later version.
15
16 This program is distributed in the hope that it will be useful,
17 but WITHOUT ANY WARRANTY; without even the implied warranty of
18 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
19 GNU General Public License for more details.
20
21 You should have received a copy of the GNU General Public License along
22 with this program; if not, write to the Free Software Foundation, Inc.,
23 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
24
25 See the full license in the file "LICENSE" in the top level distribution directory
26*************************************************************************************/
27/* END LEGAL */
28#pragma once
29
30#ifdef GRID_HIP
31#include <hipblas/hipblas.h>
32#endif
33#ifdef GRID_CUDA
34#include <cublas_v2.h>
35#endif
36#ifdef GRID_SYCL
37#include <oneapi/mkl.hpp>
38#endif
39#if 0
40#define GRID_ONE_MKL
41#endif
42#ifdef GRID_ONE_MKL
43#include <oneapi/mkl.hpp>
44#endif
45
47// Need to rearrange lattice data to be in the right format for a
48// batched multiply. Might as well make these static, dense packed
// gridblasHandle_t: compile-time selection of the backend BLAS handle type.
// Exactly one of the GPU backends (HIP / CUDA / SYCL / standalone oneMKL)
// is expected to be defined; otherwise an int32_t placeholder is used for
// the host (Eigen) fallback path.
51#ifdef GRID_HIP
52 typedef hipblasHandle_t gridblasHandle_t;
53#endif
54#ifdef GRID_CUDA
55 typedef cublasHandle_t gridblasHandle_t;
56#endif
57#ifdef GRID_SYCL
58 typedef sycl::queue *gridblasHandle_t;
59#endif
60#ifdef GRID_ONE_MKL
61 typedef sycl::queue *gridblasHandle_t;
62#endif
// Host fallback: no real handle needed, keep a dummy integer.
63#if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP) && !defined(GRID_ONE_MKL)
64 typedef int32_t gridblasHandle_t;
65#endif
66
68
69class GridBLAS {
70public:
71
72
74 static int gridblasInit;
75
// One-time creation of the backend BLAS handle/queue.
// NOTE(review): gridblasInit is tested here but no assignment to it is
// visible in this listing (a line appears to be missing from the
// extraction) — confirm Init() marks initialization complete, otherwise
// handle creation would repeat on every call.
76 static void Init(void)
77 {
78 if ( ! gridblasInit ) {
79#ifdef GRID_CUDA
80 std::cout << "cublasCreate"<<std::endl;
81 cublasCreate(&gridblasHandle);
// Device pointer mode: alpha/beta scalars are read from device memory,
// so the gemm wrappers below stage them via deviceVector.
82 cublasSetPointerMode(gridblasHandle, CUBLAS_POINTER_MODE_DEVICE);
83#endif
84#ifdef GRID_HIP
85 std::cout << "hipblasCreate"<<std::endl;
86 hipblasCreate(&gridblasHandle);
87#endif
88#ifdef GRID_SYCL
// Reuse the global Grid accelerator queue rather than creating our own.
89 gridblasHandle = theGridAccelerator;
90#endif
91#ifdef GRID_ONE_MKL
// Standalone oneMKL path: build a dedicated in-order GPU queue so
// successive gemm_batch submissions execute in submission order.
92 sycl::gpu_selector selector;
93 sycl::device selectedDevice { selector };
94 sycl::property_list q_prop{sycl::property::queue::in_order()};
95 gridblasHandle =new sycl::queue (selectedDevice,q_prop);
96#endif
98 }
99 }
100
101 // Force construct once
// Constructing any GridBLAS object guarantees the backend handle/queue
// exists before the first gemmBatched call (Init() is idempotent).
102 GridBLAS() { Init(); };
104
106 // BLAS GEMM conventions:
108 // - C = alpha A * B + beta C
109 // Dimensions:
110 // - C_m.n
111 // - A_m.k
112 // - B_k.n
113 // - Flops = 8 M N K
114 // - Bytes = 2*sizeof(word) * (MN+MK+KN)
115 // M=60, N=12
116 // Flop/Byte = 8 . 60.60.12 / (60.12+60.60+60.12)/16 = 4 so expect about 4 TF/s on a GCD
// Block the host until all previously submitted device work completes.
// Used after gemm_batch submissions so timing/results are well defined.
118 void synchronise(void)
119 {
120#ifdef GRID_HIP
121 auto err = hipDeviceSynchronize();
122 assert(err==hipSuccess);
123#endif
124#ifdef GRID_CUDA
125 auto err = cudaDeviceSynchronize();
126 assert(err==cudaSuccess);
127#endif
128#ifdef GRID_SYCL
// NOTE(review): the SYCL branch body (listing line 129) is missing from
// this extraction — presumably a queue wait / accelerator barrier; confirm
// against the repository source.
130#endif
131#ifdef GRID_ONE_MKL
// In-order queue: wait() drains everything submitted so far.
132 gridblasHandle->wait();
133#endif
134 }
135
136 void gemmBatched(int m,int n, int k,
137 ComplexD alpha,
138 deviceVector<ComplexD*> &Amk, // pointer list to matrices
140 ComplexD beta,
142 {
144 m,n,k,
145 alpha,
146 Amk,
147 Bkn,
148 beta,
149 Cmn);
150 }
151 void gemmBatched(int m,int n, int k,
152 ComplexF alpha,
153 deviceVector<ComplexF*> &Amk, // pointer list to matrices
155 ComplexF beta,
157 {
159 m,n,k,
160 alpha,
161 Amk,
162 Bkn,
163 beta,
164 Cmn);
165 }
166 void gemmBatched(int m,int n, int k,
167 RealD alpha,
168 deviceVector<RealD*> &Amk, // pointer list to matrices
170 RealD beta,
172 {
174 m,n,k,
175 alpha,
176 Amk,
177 Bkn,
178 beta,
179 Cmn);
180 }
181 void gemmBatched(int m,int n, int k,
182 RealF alpha,
183 deviceVector<RealF*> &Amk, // pointer list to matrices
185 RealF beta,
187 {
189 m,n,k,
190 alpha,
191 Amk,
192 Bkn,
193 beta,
194 Cmn);
195 }
196
199 int m,int n, int k,
200 ComplexD alpha,
201 deviceVector<ComplexD*> &Amk, // pointer list to matrices
203 ComplexD beta,
205 {
206 RealD t2=usecond();
207 int32_t batchCount = Amk.size();
208 assert(Bkn.size()==batchCount);
209 assert(Cmn.size()==batchCount);
210
211 //assert(OpA!=GridBLAS_OP_T); // Complex case expect no transpose
212 //assert(OpB!=GridBLAS_OP_T);
213
214 int lda = m; // m x k column major
215 int ldb = k; // k x n column major
216 int ldc = m; // m x b column major
217 if(OpA!=GridBLAS_OP_N)
218 lda = k;
219 if(OpB!=GridBLAS_OP_N)
220 ldb = n;
221
222 static deviceVector<ComplexD> alpha_p(1);
223 static deviceVector<ComplexD> beta_p(1);
224 // can prestore the 1 and the zero on device
225 acceleratorCopyToDevice((void *)&alpha,(void *)&alpha_p[0],sizeof(ComplexD));
226 acceleratorCopyToDevice((void *)&beta ,(void *)&beta_p[0],sizeof(ComplexD));
227 RealD t0=usecond();
228 // std::cout << "ZgemmBatched mnk "<<m<<","<<n<<","<<k<<" count "<<batchCount<<std::endl;
229#ifdef GRID_HIP
230 hipblasOperation_t hOpA;
231 hipblasOperation_t hOpB;
232 if ( OpA == GridBLAS_OP_N ) hOpA = HIPBLAS_OP_N;
233 if ( OpA == GridBLAS_OP_T ) hOpA = HIPBLAS_OP_T;
234 if ( OpA == GridBLAS_OP_C ) hOpA = HIPBLAS_OP_C;
235 if ( OpB == GridBLAS_OP_N ) hOpB = HIPBLAS_OP_N;
236 if ( OpB == GridBLAS_OP_T ) hOpB = HIPBLAS_OP_T;
237 if ( OpB == GridBLAS_OP_C ) hOpB = HIPBLAS_OP_C;
238 auto err = hipblasZgemmBatched(gridblasHandle,
239 hOpA,
240 hOpB,
241 m,n,k,
242 (hipblasDoubleComplex *) &alpha_p[0],
243 (hipblasDoubleComplex **)&Amk[0], lda,
244 (hipblasDoubleComplex **)&Bkn[0], ldb,
245 (hipblasDoubleComplex *) &beta_p[0],
246 (hipblasDoubleComplex **)&Cmn[0], ldc,
247 batchCount);
248 // std::cout << " hipblas return code " <<(int)err<<std::endl;
249 assert(err==HIPBLAS_STATUS_SUCCESS);
250#endif
251#ifdef GRID_CUDA
252 cublasOperation_t hOpA;
253 cublasOperation_t hOpB;
254 if ( OpA == GridBLAS_OP_N ) hOpA = CUBLAS_OP_N;
255 if ( OpA == GridBLAS_OP_T ) hOpA = CUBLAS_OP_T;
256 if ( OpA == GridBLAS_OP_C ) hOpA = CUBLAS_OP_C;
257 if ( OpB == GridBLAS_OP_N ) hOpB = CUBLAS_OP_N;
258 if ( OpB == GridBLAS_OP_T ) hOpB = CUBLAS_OP_T;
259 if ( OpB == GridBLAS_OP_C ) hOpB = CUBLAS_OP_C;
260 auto err = cublasZgemmBatched(gridblasHandle,
261 hOpA,
262 hOpB,
263 m,n,k,
264 (cuDoubleComplex *) &alpha_p[0],
265 (cuDoubleComplex **)&Amk[0], lda,
266 (cuDoubleComplex **)&Bkn[0], ldb,
267 (cuDoubleComplex *) &beta_p[0],
268 (cuDoubleComplex **)&Cmn[0], ldc,
269 batchCount);
270 assert(err==CUBLAS_STATUS_SUCCESS);
271#endif
272#ifdef GRID_SYCL
273 int64_t m64=m;
274 int64_t n64=n;
275 int64_t k64=k;
276 int64_t lda64=lda;
277 int64_t ldb64=ldb;
278 int64_t ldc64=ldc;
279 int64_t batchCount64=batchCount;
280
281 oneapi::mkl::transpose iOpA;
282 oneapi::mkl::transpose iOpB;
283
284 if ( OpA == GridBLAS_OP_N ) iOpA = oneapi::mkl::transpose::N;
285 if ( OpA == GridBLAS_OP_T ) iOpA = oneapi::mkl::transpose::T;
286 if ( OpA == GridBLAS_OP_C ) iOpA = oneapi::mkl::transpose::C;
287 if ( OpB == GridBLAS_OP_N ) iOpB = oneapi::mkl::transpose::N;
288 if ( OpB == GridBLAS_OP_T ) iOpB = oneapi::mkl::transpose::T;
289 if ( OpB == GridBLAS_OP_C ) iOpB = oneapi::mkl::transpose::C;
290
291 oneapi::mkl::blas::column_major::gemm_batch(*gridblasHandle,
292 &iOpA,
293 &iOpB,
294 &m64,&n64,&k64,
295 (ComplexD *) &alpha_p[0],
296 (const ComplexD **)&Amk[0], (const int64_t *)&lda64,
297 (const ComplexD **)&Bkn[0], (const int64_t *)&ldb64,
298 (ComplexD *) &beta_p[0],
299 (ComplexD **)&Cmn[0], (const int64_t *)&ldc64,
300 (int64_t)1,&batchCount64,std::vector<sycl::event>());
301 synchronise();
302#if 0
303 // This code was used to check the mat mul on Sunspot/OneMKL
304 std::cerr << " Called SYCL batched ZGEMM OpA "<< OpA << " OpB "<<OpB <<std::endl;
305 std::vector<ComplexD> A(m*k); // pointer list to matrices
306 std::vector<ComplexD> B(k*n);
307 std::vector<ComplexD> C(m*n);
308 // int sda = lda*k;
309 // int sdb = ldb*k;
310 // int sdc = ldc*n;
311 std::cerr << " Checking the GEMM results "<<std::endl;
312 for (int p = 0; p < 1; ++p) {
313 ComplexD * Amk_p; // pointer list to matrices
314 ComplexD * Bkn_p; // pointer list to matrices
315 ComplexD * Cmn_p; // pointer list to matrices
316 acceleratorCopyFromDevice((void *)&Amk[p],(void *)&Amk_p,sizeof(ComplexD*));
317 acceleratorCopyFromDevice((void *)&Bkn[p],(void *)&Bkn_p,sizeof(ComplexD*));
318 acceleratorCopyFromDevice((void *)&Cmn[p],(void *)&Cmn_p,sizeof(ComplexD*));
319 std::cerr << " p " << p << " copied pointers "<<std::endl;
320 acceleratorCopyFromDevice((void *)Amk_p,(void *)&A[0],m*k*sizeof(ComplexD));
321 acceleratorCopyFromDevice((void *)Bkn_p,(void *)&B[0],k*n*sizeof(ComplexD));
322 acceleratorCopyFromDevice((void *)Cmn_p,(void *)&C[0],m*n*sizeof(ComplexD));
323 std::cerr << " p " << p << " copied matrices "<<std::endl;
324 std::cerr << " C[0] "<<C[0]<<std::endl;
325 std::cerr << " A[0] "<<A[0]<<std::endl;
326 std::cerr << " B[0] "<<B[0]<<std::endl;
327 std::cerr << " m "<<m<<std::endl;
328 std::cerr << " n "<<n<<std::endl;
329 std::cerr << " k "<<k<<std::endl;
330 for (int mm = 0; mm < m; ++mm) {
331 for (int nn = 0; nn < n; ++nn) {
332 ComplexD c_mn(0.0);
333 for (int kk = 0; kk < k; ++kk) {
334 int idx_a, idx_b;
335 // int lda = m; // m x k column major
336 // int ldb = k; // k x n column major
337 // int ldc = m; // m x b column major
338 if(OpA!=GridBLAS_OP_N) {
339 idx_a =kk + mm*lda;
340 } else {
341 idx_a =mm + kk*lda;
342 }
343 if(OpB!=GridBLAS_OP_N) {
344 idx_b =nn + kk*ldb;
345 } else {
346 idx_b =kk + nn*ldb;
347 }
348 // std::cerr << " idx_a "<<idx_a<<" idx_b "<<idx_b<<std::endl;
349
350 ComplexD Ac = A[idx_a];
351 ComplexD Bc = B[idx_b];
352 if(OpA==GridBLAS_OP_C) Ac = conjugate(Ac);
353 if(OpB==GridBLAS_OP_C) Bc = conjugate(Bc);
354
355 c_mn += Ac*Bc;
356 }
357 std::cerr << " beta "<<beta<<" alpha "<<alpha<<" C_"<<mm<<","<<nn<<" "<<c_mn<<" "<<C[mm + nn*ldc]<<std::endl;
358 }
359 }
360 }
361#endif
362#endif
363#if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP)
364 // Need a default/reference implementation; use Eigen
365 if ( (OpA == GridBLAS_OP_N ) && (OpB == GridBLAS_OP_N) ) {
366 thread_for (p, batchCount, {
367 Eigen::Map<Eigen::MatrixXcd> eAmk(Amk[p],m,k);
368 Eigen::Map<Eigen::MatrixXcd> eBkn(Bkn[p],k,n);
369 Eigen::Map<Eigen::MatrixXcd> eCmn(Cmn[p],m,n);
370 if (std::abs(beta) != 0.0)
371 eCmn = beta * eCmn + alpha * eAmk * eBkn ;
372 else
373 eCmn = alpha * eAmk * eBkn ;
374 });
375 } else if ( (OpA == GridBLAS_OP_C ) && (OpB == GridBLAS_OP_N) ) {
376 thread_for (p, batchCount, {
377 Eigen::Map<Eigen::MatrixXcd> eAmk(Amk[p],k,m);
378 Eigen::Map<Eigen::MatrixXcd> eBkn(Bkn[p],k,n);
379 Eigen::Map<Eigen::MatrixXcd> eCmn(Cmn[p],m,n);
380 if (std::abs(beta) != 0.0)
381 eCmn = beta * eCmn + alpha * eAmk.adjoint() * eBkn ;
382 else
383 eCmn = alpha * eAmk.adjoint() * eBkn ;
384 });
385 } else if ( (OpA == GridBLAS_OP_T ) && (OpB == GridBLAS_OP_N) ) {
386 thread_for (p, batchCount, {
387 Eigen::Map<Eigen::MatrixXcd> eAmk(Amk[p],k,m);
388 Eigen::Map<Eigen::MatrixXcd> eBkn(Bkn[p],k,n);
389 Eigen::Map<Eigen::MatrixXcd> eCmn(Cmn[p],m,n);
390 if (std::abs(beta) != 0.0)
391 eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn ;
392 else
393 eCmn = alpha * eAmk.transpose() * eBkn ;
394 });
395 } else if ( (OpA == GridBLAS_OP_N ) && (OpB == GridBLAS_OP_C) ) {
396 thread_for (p, batchCount, {
397 Eigen::Map<Eigen::MatrixXcd> eAmk(Amk[p],m,k);
398 Eigen::Map<Eigen::MatrixXcd> eBkn(Bkn[p],n,k);
399 Eigen::Map<Eigen::MatrixXcd> eCmn(Cmn[p],m,n);
400 if (std::abs(beta) != 0.0)
401 eCmn = beta * eCmn + alpha * eAmk * eBkn.adjoint() ;
402 else
403 eCmn = alpha * eAmk * eBkn.adjoint() ;
404 });
405 } else if ( (OpA == GridBLAS_OP_N ) && (OpB == GridBLAS_OP_T) ) {
406 thread_for (p, batchCount, {
407 Eigen::Map<Eigen::MatrixXcd> eAmk(Amk[p],m,k);
408 Eigen::Map<Eigen::MatrixXcd> eBkn(Bkn[p],n,k);
409 Eigen::Map<Eigen::MatrixXcd> eCmn(Cmn[p],m,n);
410 eCmn = beta * eCmn + alpha * eAmk * eBkn.transpose() ;
411 });
412 } else if ( (OpA == GridBLAS_OP_C ) && (OpB == GridBLAS_OP_C) ) {
413 thread_for (p, batchCount, {
414 Eigen::Map<Eigen::MatrixXcd> eAmk(Amk[p],k,m);
415 Eigen::Map<Eigen::MatrixXcd> eBkn(Bkn[p],n,k);
416 Eigen::Map<Eigen::MatrixXcd> eCmn(Cmn[p],m,n);
417 if (std::abs(beta) != 0.0)
418 eCmn = beta * eCmn + alpha * eAmk.adjoint() * eBkn.adjoint() ;
419 else
420 eCmn = alpha * eAmk.adjoint() * eBkn.adjoint() ;
421 } );
422 } else if ( (OpA == GridBLAS_OP_T ) && (OpB == GridBLAS_OP_T) ) {
423 thread_for (p, batchCount, {
424 Eigen::Map<Eigen::MatrixXcd> eAmk(Amk[p],k,m);
425 Eigen::Map<Eigen::MatrixXcd> eBkn(Bkn[p],n,k);
426 Eigen::Map<Eigen::MatrixXcd> eCmn(Cmn[p],m,n);
427 if (std::abs(beta) != 0.0)
428 eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn.transpose() ;
429 else
430 eCmn = alpha * eAmk.transpose() * eBkn.transpose() ;
431 } );
432 } else {
433 assert(0);
434 }
435#endif
436 RealD t1=usecond();
437 RealD flops = 8.0*m*n*k*batchCount;
438 RealD bytes = 1.0*sizeof(ComplexD)*(m*k+k*n+m*n)*batchCount;
439 // std::cout <<GridLogMessage<< " batched Blas copy "<<(t0-t2)/1.e3 <<" ms "<<std::endl;
440 // std::cout <<GridLogMessage<< " batched Blas zGemm call "<<m<<","<<n<<","<<k<<" "<< flops/(t1-t0)/1.e3 <<" GF/s "<<(t1-t0)/1.e3<<" ms "<<std::endl;
441 // std::cout <<GridLogMessage<< " batched Blas zGemm call "<<m<<","<<n<<","<<k<<" "<< bytes/(t1-t0)/1.e3 <<" GB/s "<<(t1-t0)/1.e3<<" ms "<<std::endl;
442 }
443
446 int m,int n, int k,
447 ComplexF alpha,
448 deviceVector<ComplexF*> &Amk, // pointer list to matrices
450 ComplexF beta,
452 {
453 RealD t2=usecond();
454 int32_t batchCount = Amk.size();
455
456 //assert(OpA!=GridBLAS_OP_T); // Complex case expect no transpose
457 //assert(OpB!=GridBLAS_OP_T);
458
459 int lda = m; // m x k column major
460 int ldb = k; // k x n column major
461 int ldc = m; // m x b column major
462 if(OpA!=GridBLAS_OP_N)
463 lda = k;
464 if(OpB!=GridBLAS_OP_N)
465 ldb = n;
466 static deviceVector<ComplexF> alpha_p(1);
467 static deviceVector<ComplexF> beta_p(1);
468 // can prestore the 1 and the zero on device
469 acceleratorCopyToDevice((void *)&alpha,(void *)&alpha_p[0],sizeof(ComplexF));
470 acceleratorCopyToDevice((void *)&beta ,(void *)&beta_p[0],sizeof(ComplexF));
471 RealD t0=usecond();
472
473 assert(Bkn.size()==batchCount);
474 assert(Cmn.size()==batchCount);
475#ifdef GRID_HIP
476 hipblasOperation_t hOpA;
477 hipblasOperation_t hOpB;
478 if ( OpA == GridBLAS_OP_N ) hOpA = HIPBLAS_OP_N;
479 if ( OpA == GridBLAS_OP_T ) hOpA = HIPBLAS_OP_T;
480 if ( OpA == GridBLAS_OP_C ) hOpA = HIPBLAS_OP_C;
481 if ( OpB == GridBLAS_OP_N ) hOpB = HIPBLAS_OP_N;
482 if ( OpB == GridBLAS_OP_T ) hOpB = HIPBLAS_OP_T;
483 if ( OpB == GridBLAS_OP_C ) hOpB = HIPBLAS_OP_C;
484 auto err = hipblasCgemmBatched(gridblasHandle,
485 hOpA,
486 hOpB,
487 m,n,k,
488 (hipblasComplex *) &alpha_p[0],
489 (hipblasComplex **)&Amk[0], lda,
490 (hipblasComplex **)&Bkn[0], ldb,
491 (hipblasComplex *) &beta_p[0],
492 (hipblasComplex **)&Cmn[0], ldc,
493 batchCount);
494
495 assert(err==HIPBLAS_STATUS_SUCCESS);
496#endif
497#ifdef GRID_CUDA
498 cublasOperation_t hOpA;
499 cublasOperation_t hOpB;
500 if ( OpA == GridBLAS_OP_N ) hOpA = CUBLAS_OP_N;
501 if ( OpA == GridBLAS_OP_T ) hOpA = CUBLAS_OP_T;
502 if ( OpA == GridBLAS_OP_C ) hOpA = CUBLAS_OP_C;
503 if ( OpB == GridBLAS_OP_N ) hOpB = CUBLAS_OP_N;
504 if ( OpB == GridBLAS_OP_T ) hOpB = CUBLAS_OP_T;
505 if ( OpB == GridBLAS_OP_C ) hOpB = CUBLAS_OP_C;
506 auto err = cublasCgemmBatched(gridblasHandle,
507 hOpA,
508 hOpB,
509 m,n,k,
510 (cuComplex *) &alpha_p[0],
511 (cuComplex **)&Amk[0], lda,
512 (cuComplex **)&Bkn[0], ldb,
513 (cuComplex *) &beta_p[0],
514 (cuComplex **)&Cmn[0], ldc,
515 batchCount);
516 assert(err==CUBLAS_STATUS_SUCCESS);
517#endif
518#ifdef GRID_SYCL
519 int64_t m64=m;
520 int64_t n64=n;
521 int64_t k64=k;
522 int64_t lda64=lda;
523 int64_t ldb64=ldb;
524 int64_t ldc64=ldc;
525 int64_t batchCount64=batchCount;
526
527 oneapi::mkl::transpose iOpA;
528 oneapi::mkl::transpose iOpB;
529
530 if ( OpA == GridBLAS_OP_N ) iOpA = oneapi::mkl::transpose::N;
531 if ( OpA == GridBLAS_OP_T ) iOpA = oneapi::mkl::transpose::T;
532 if ( OpA == GridBLAS_OP_C ) iOpA = oneapi::mkl::transpose::C;
533 if ( OpB == GridBLAS_OP_N ) iOpB = oneapi::mkl::transpose::N;
534 if ( OpB == GridBLAS_OP_T ) iOpB = oneapi::mkl::transpose::T;
535 if ( OpB == GridBLAS_OP_C ) iOpB = oneapi::mkl::transpose::C;
536
537 oneapi::mkl::blas::column_major::gemm_batch(*gridblasHandle,
538 &iOpA,
539 &iOpB,
540 &m64,&n64,&k64,
541 (ComplexF *) &alpha_p[0],
542 (const ComplexF **)&Amk[0], (const int64_t *)&lda64,
543 (const ComplexF **)&Bkn[0], (const int64_t *)&ldb64,
544 (ComplexF *) &beta_p[0],
545 (ComplexF **)&Cmn[0], (const int64_t *)&ldc64,
546 (int64_t)1,&batchCount64,std::vector<sycl::event>());
547 synchronise();
548#endif
549#if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP)
550 // Need a default/reference implementation; use Eigen
551 if ( (OpA == GridBLAS_OP_N ) && (OpB == GridBLAS_OP_N) ) {
552 thread_for (p, batchCount, {
553 Eigen::Map<Eigen::MatrixXcf> eAmk(Amk[p],m,k);
554 Eigen::Map<Eigen::MatrixXcf> eBkn(Bkn[p],k,n);
555 Eigen::Map<Eigen::MatrixXcf> eCmn(Cmn[p],m,n);
556 if (std::abs(beta) != 0.0)
557 eCmn = beta * eCmn + alpha * eAmk * eBkn ;
558 else
559 eCmn = alpha * eAmk * eBkn ;
560 });
561 } else if ( (OpA == GridBLAS_OP_C ) && (OpB == GridBLAS_OP_N) ) {
562 thread_for (p, batchCount, {
563 Eigen::Map<Eigen::MatrixXcf> eAmk(Amk[p],k,m);
564 Eigen::Map<Eigen::MatrixXcf> eBkn(Bkn[p],k,n);
565 Eigen::Map<Eigen::MatrixXcf> eCmn(Cmn[p],m,n);
566 if (std::abs(beta) != 0.0)
567 eCmn = beta * eCmn + alpha * eAmk.adjoint() * eBkn ;
568 else
569 eCmn = alpha * eAmk.adjoint() * eBkn ;
570 });
571 } else if ( (OpA == GridBLAS_OP_T ) && (OpB == GridBLAS_OP_N) ) {
572 thread_for (p, batchCount, {
573 Eigen::Map<Eigen::MatrixXcf> eAmk(Amk[p],k,m);
574 Eigen::Map<Eigen::MatrixXcf> eBkn(Bkn[p],k,n);
575 Eigen::Map<Eigen::MatrixXcf> eCmn(Cmn[p],m,n);
576 if (std::abs(beta) != 0.0)
577 eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn ;
578 else
579 eCmn = alpha * eAmk.transpose() * eBkn ;
580 });
581 } else if ( (OpA == GridBLAS_OP_N ) && (OpB == GridBLAS_OP_C) ) {
582 thread_for (p, batchCount, {
583 Eigen::Map<Eigen::MatrixXcf> eAmk(Amk[p],m,k);
584 Eigen::Map<Eigen::MatrixXcf> eBkn(Bkn[p],n,k);
585 Eigen::Map<Eigen::MatrixXcf> eCmn(Cmn[p],m,n);
586 if (std::abs(beta) != 0.0)
587 eCmn = beta * eCmn + alpha * eAmk * eBkn.adjoint() ;
588 else
589 eCmn = alpha * eAmk * eBkn.adjoint() ;
590 });
591 } else if ( (OpA == GridBLAS_OP_N ) && (OpB == GridBLAS_OP_T) ) {
592 thread_for (p, batchCount, {
593 Eigen::Map<Eigen::MatrixXcf> eAmk(Amk[p],m,k);
594 Eigen::Map<Eigen::MatrixXcf> eBkn(Bkn[p],n,k);
595 Eigen::Map<Eigen::MatrixXcf> eCmn(Cmn[p],m,n);
596 if (std::abs(beta) != 0.0)
597 eCmn = beta * eCmn + alpha * eAmk * eBkn.transpose() ;
598 else
599 eCmn = alpha * eAmk * eBkn.transpose() ;
600 });
601 } else if ( (OpA == GridBLAS_OP_C ) && (OpB == GridBLAS_OP_C) ) {
602 thread_for (p, batchCount, {
603 Eigen::Map<Eigen::MatrixXcf> eAmk(Amk[p],k,m);
604 Eigen::Map<Eigen::MatrixXcf> eBkn(Bkn[p],n,k);
605 Eigen::Map<Eigen::MatrixXcf> eCmn(Cmn[p],m,n);
606 if (std::abs(beta) != 0.0)
607 eCmn = beta * eCmn + alpha * eAmk.adjoint() * eBkn.adjoint() ;
608 else
609 eCmn = alpha * eAmk.adjoint() * eBkn.adjoint() ;
610 } );
611 } else if ( (OpA == GridBLAS_OP_T ) && (OpB == GridBLAS_OP_T) ) {
612 thread_for (p, batchCount, {
613 Eigen::Map<Eigen::MatrixXcf> eAmk(Amk[p],k,m);
614 Eigen::Map<Eigen::MatrixXcf> eBkn(Bkn[p],n,k);
615 Eigen::Map<Eigen::MatrixXcf> eCmn(Cmn[p],m,n);
616 if (std::abs(beta) != 0.0)
617 eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn.transpose() ;
618 else
619 eCmn = alpha * eAmk.transpose() * eBkn.transpose() ;
620 } );
621 } else {
622 assert(0);
623 }
624#endif
625 RealD t1=usecond();
626 RealD flops = 8.0*m*n*k*batchCount;
627 RealD bytes = 1.0*sizeof(ComplexF)*(m*k+k*n+m*n)*batchCount;
628 }
629
631 // Single precision real GEMM
633
636 int m,int n, int k,
637 RealF alpha,
638 deviceVector<RealF*> &Amk, // pointer list to matrices
640 RealF beta,
642 {
643 RealD t2=usecond();
644 int32_t batchCount = Amk.size();
645
646 assert(OpA!=GridBLAS_OP_C); // Real case no conjugate
647 assert(OpB!=GridBLAS_OP_C);
648
649 int lda = m; // m x k column major
650 int ldb = k; // k x n column major
651 int ldc = m; // m x b column major
652 if(OpA!=GridBLAS_OP_N)
653 lda = k;
654 if(OpB!=GridBLAS_OP_N)
655 ldb = n;
656 static deviceVector<RealF> alpha_p(1);
657 static deviceVector<RealF> beta_p(1);
658 // can prestore the 1 and the zero on device
659 acceleratorCopyToDevice((void *)&alpha,(void *)&alpha_p[0],sizeof(RealF));
660 acceleratorCopyToDevice((void *)&beta ,(void *)&beta_p[0],sizeof(RealF));
661 RealD t0=usecond();
662
663 assert(Bkn.size()==batchCount);
664 assert(Cmn.size()==batchCount);
665#ifdef GRID_HIP
666 hipblasOperation_t hOpA;
667 hipblasOperation_t hOpB;
668 if ( OpA == GridBLAS_OP_N ) hOpA = HIPBLAS_OP_N;
669 if ( OpA == GridBLAS_OP_T ) hOpA = HIPBLAS_OP_T;
670 if ( OpA == GridBLAS_OP_C ) hOpA = HIPBLAS_OP_C;
671 if ( OpB == GridBLAS_OP_N ) hOpB = HIPBLAS_OP_N;
672 if ( OpB == GridBLAS_OP_T ) hOpB = HIPBLAS_OP_T;
673 if ( OpB == GridBLAS_OP_C ) hOpB = HIPBLAS_OP_C;
674 auto err = hipblasSgemmBatched(gridblasHandle,
675 hOpA,
676 hOpB,
677 m,n,k,
678 (float *) &alpha_p[0],
679 (float **)&Amk[0], lda,
680 (float **)&Bkn[0], ldb,
681 (float *) &beta_p[0],
682 (float **)&Cmn[0], ldc,
683 batchCount);
684 assert(err==HIPBLAS_STATUS_SUCCESS);
685#endif
686#ifdef GRID_CUDA
687 cublasOperation_t hOpA;
688 cublasOperation_t hOpB;
689 if ( OpA == GridBLAS_OP_N ) hOpA = CUBLAS_OP_N;
690 if ( OpA == GridBLAS_OP_T ) hOpA = CUBLAS_OP_T;
691 if ( OpA == GridBLAS_OP_C ) hOpA = CUBLAS_OP_C;
692 if ( OpB == GridBLAS_OP_N ) hOpB = CUBLAS_OP_N;
693 if ( OpB == GridBLAS_OP_T ) hOpB = CUBLAS_OP_T;
694 if ( OpB == GridBLAS_OP_C ) hOpB = CUBLAS_OP_C;
695 auto err = cublasSgemmBatched(gridblasHandle,
696 hOpA,
697 hOpB,
698 m,n,k,
699 (float *) &alpha_p[0],
700 (float **)&Amk[0], lda,
701 (float **)&Bkn[0], ldb,
702 (float *) &beta_p[0],
703 (float **)&Cmn[0], ldc,
704 batchCount);
705 assert(err==CUBLAS_STATUS_SUCCESS);
706#endif
707#ifdef GRID_SYCL
708 int64_t m64=m;
709 int64_t n64=n;
710 int64_t k64=k;
711 int64_t lda64=lda;
712 int64_t ldb64=ldb;
713 int64_t ldc64=ldc;
714 int64_t batchCount64=batchCount;
715
716 oneapi::mkl::transpose iOpA;
717 oneapi::mkl::transpose iOpB;
718
719 if ( OpA == GridBLAS_OP_N ) iOpA = oneapi::mkl::transpose::N;
720 if ( OpA == GridBLAS_OP_T ) iOpA = oneapi::mkl::transpose::T;
721 if ( OpA == GridBLAS_OP_C ) iOpA = oneapi::mkl::transpose::C;
722 if ( OpB == GridBLAS_OP_N ) iOpB = oneapi::mkl::transpose::N;
723 if ( OpB == GridBLAS_OP_T ) iOpB = oneapi::mkl::transpose::T;
724 if ( OpB == GridBLAS_OP_C ) iOpB = oneapi::mkl::transpose::C;
725
726 oneapi::mkl::blas::column_major::gemm_batch(*gridblasHandle,
727 &iOpA,
728 &iOpB,
729 &m64,&n64,&k64,
730 (float *) &alpha_p[0],
731 (const float **)&Amk[0], (const int64_t *)&lda64,
732 (const float **)&Bkn[0], (const int64_t *)&ldb64,
733 (float *) &beta_p[0],
734 (float **)&Cmn[0], (const int64_t *)&ldc64,
735 (int64_t)1,&batchCount64,std::vector<sycl::event>());
736 synchronise();
737#endif
738#if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP)
739 // Need a default/reference implementation; use Eigen
740 if ( (OpA == GridBLAS_OP_N ) && (OpB == GridBLAS_OP_N) ) {
741 thread_for (p, batchCount, {
742 Eigen::Map<Eigen::MatrixXf> eAmk(Amk[p],m,k);
743 Eigen::Map<Eigen::MatrixXf> eBkn(Bkn[p],k,n);
744 Eigen::Map<Eigen::MatrixXf> eCmn(Cmn[p],m,n);
745 if (std::abs(beta) != 0.0)
746 eCmn = beta * eCmn + alpha * eAmk * eBkn ;
747 else
748 eCmn = alpha * eAmk * eBkn ;
749 });
750 } else if ( (OpA == GridBLAS_OP_T ) && (OpB == GridBLAS_OP_N) ) {
751 thread_for (p, batchCount, {
752 Eigen::Map<Eigen::MatrixXf> eAmk(Amk[p],k,m);
753 Eigen::Map<Eigen::MatrixXf> eBkn(Bkn[p],k,n);
754 Eigen::Map<Eigen::MatrixXf> eCmn(Cmn[p],m,n);
755 if (std::abs(beta) != 0.0)
756 eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn ;
757 else
758 eCmn = alpha * eAmk.transpose() * eBkn ;
759 });
760 } else if ( (OpA == GridBLAS_OP_N ) && (OpB == GridBLAS_OP_T) ) {
761 thread_for (p, batchCount, {
762 Eigen::Map<Eigen::MatrixXf> eAmk(Amk[p],m,k);
763 Eigen::Map<Eigen::MatrixXf> eBkn(Bkn[p],n,k);
764 Eigen::Map<Eigen::MatrixXf> eCmn(Cmn[p],m,n);
765 if (std::abs(beta) != 0.0)
766 eCmn = beta * eCmn + alpha * eAmk * eBkn.transpose() ;
767 else
768 eCmn = alpha * eAmk * eBkn.transpose() ;
769 });
770 } else if ( (OpA == GridBLAS_OP_T ) && (OpB == GridBLAS_OP_T) ) {
771 thread_for (p, batchCount, {
772 Eigen::Map<Eigen::MatrixXf> eAmk(Amk[p],k,m);
773 Eigen::Map<Eigen::MatrixXf> eBkn(Bkn[p],n,k);
774 Eigen::Map<Eigen::MatrixXf> eCmn(Cmn[p],m,n);
775 if (std::abs(beta) != 0.0)
776 eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn.transpose() ;
777 else
778 eCmn = alpha * eAmk.transpose() * eBkn.transpose() ;
779 });
780 } else {
781 assert(0);
782 }
783#endif
784 RealD t1=usecond();
785 RealD flops = 2.0*m*n*k*batchCount;
786 RealD bytes = 1.0*sizeof(RealF)*(m*k+k*n+m*n)*batchCount;
787 }
788
789
791 // Double precision real GEMM
795 int m,int n, int k,
796 RealD alpha,
797 deviceVector<RealD*> &Amk, // pointer list to matrices
799 RealD beta,
801 {
802 RealD t2=usecond();
803 int32_t batchCount = Amk.size();
804
805 assert(OpA!=GridBLAS_OP_C); // Real case no conjugate
806 assert(OpB!=GridBLAS_OP_C);
807
808 int lda = m; // m x k column major
809 int ldb = k; // k x n column major
810 int ldc = m; // m x b column major
811 if(OpA!=GridBLAS_OP_N)
812 lda = k;
813 if(OpB!=GridBLAS_OP_N)
814 ldb = n;
815
816 static deviceVector<RealD> alpha_p(1);
817 static deviceVector<RealD> beta_p(1);
818 // can prestore the 1 and the zero on device
819 acceleratorCopyToDevice((void *)&alpha,(void *)&alpha_p[0],sizeof(RealD));
820 acceleratorCopyToDevice((void *)&beta ,(void *)&beta_p[0],sizeof(RealD));
821 RealD t0=usecond();
822
823 assert(Bkn.size()==batchCount);
824 assert(Cmn.size()==batchCount);
825#ifdef GRID_HIP
826 hipblasOperation_t hOpA;
827 hipblasOperation_t hOpB;
828 if ( OpA == GridBLAS_OP_N ) hOpA = HIPBLAS_OP_N;
829 if ( OpA == GridBLAS_OP_T ) hOpA = HIPBLAS_OP_T;
830 if ( OpA == GridBLAS_OP_C ) hOpA = HIPBLAS_OP_C;
831 if ( OpB == GridBLAS_OP_N ) hOpB = HIPBLAS_OP_N;
832 if ( OpB == GridBLAS_OP_T ) hOpB = HIPBLAS_OP_T;
833 if ( OpB == GridBLAS_OP_C ) hOpB = HIPBLAS_OP_C;
834 auto err = hipblasDgemmBatched(gridblasHandle,
835 HIPBLAS_OP_N,
836 HIPBLAS_OP_N,
837 m,n,k,
838 (double *) &alpha_p[0],
839 (double **)&Amk[0], lda,
840 (double **)&Bkn[0], ldb,
841 (double *) &beta_p[0],
842 (double **)&Cmn[0], ldc,
843 batchCount);
844 assert(err==HIPBLAS_STATUS_SUCCESS);
845#endif
846#ifdef GRID_CUDA
847 cublasOperation_t hOpA;
848 cublasOperation_t hOpB;
849 if ( OpA == GridBLAS_OP_N ) hOpA = CUBLAS_OP_N;
850 if ( OpA == GridBLAS_OP_T ) hOpA = CUBLAS_OP_T;
851 if ( OpA == GridBLAS_OP_C ) hOpA = CUBLAS_OP_C;
852 if ( OpB == GridBLAS_OP_N ) hOpB = CUBLAS_OP_N;
853 if ( OpB == GridBLAS_OP_T ) hOpB = CUBLAS_OP_T;
854 if ( OpB == GridBLAS_OP_C ) hOpB = CUBLAS_OP_C;
855 auto err = cublasDgemmBatched(gridblasHandle,
856 hOpA,
857 hOpB,
858 m,n,k,
859 (double *) &alpha_p[0],
860 (double **)&Amk[0], lda,
861 (double **)&Bkn[0], ldb,
862 (double *) &beta_p[0],
863 (double **)&Cmn[0], ldc,
864 batchCount);
865 assert(err==CUBLAS_STATUS_SUCCESS);
866#endif
867#ifdef GRID_SYCL
868 int64_t m64=m;
869 int64_t n64=n;
870 int64_t k64=k;
871 int64_t lda64=lda;
872 int64_t ldb64=ldb;
873 int64_t ldc64=ldc;
874 int64_t batchCount64=batchCount;
875
876 oneapi::mkl::transpose iOpA;
877 oneapi::mkl::transpose iOpB;
878
879 if ( OpA == GridBLAS_OP_N ) iOpA = oneapi::mkl::transpose::N;
880 if ( OpA == GridBLAS_OP_T ) iOpA = oneapi::mkl::transpose::T;
881 if ( OpA == GridBLAS_OP_C ) iOpA = oneapi::mkl::transpose::C;
882 if ( OpB == GridBLAS_OP_N ) iOpB = oneapi::mkl::transpose::N;
883 if ( OpB == GridBLAS_OP_T ) iOpB = oneapi::mkl::transpose::T;
884 if ( OpB == GridBLAS_OP_C ) iOpB = oneapi::mkl::transpose::C;
885
886 oneapi::mkl::blas::column_major::gemm_batch(*gridblasHandle,
887 &iOpA,
888 &iOpB,
889 &m64,&n64,&k64,
890 (double *) &alpha_p[0],
891 (const double **)&Amk[0], (const int64_t *)&lda64,
892 (const double **)&Bkn[0], (const int64_t *)&ldb64,
893 (double *) &beta_p[0],
894 (double **)&Cmn[0], (const int64_t *)&ldc64,
895 (int64_t)1,&batchCount64,std::vector<sycl::event>());
896 synchronise();
897#endif
898#if !defined(GRID_SYCL) && !defined(GRID_CUDA) && !defined(GRID_HIP)
899 // Need a default/reference implementation; use Eigen
900 if ( (OpA == GridBLAS_OP_N ) && (OpB == GridBLAS_OP_N) ) {
901 thread_for (p, batchCount, {
902 Eigen::Map<Eigen::MatrixXd> eAmk(Amk[p],m,k);
903 Eigen::Map<Eigen::MatrixXd> eBkn(Bkn[p],k,n);
904 Eigen::Map<Eigen::MatrixXd> eCmn(Cmn[p],m,n);
905 if (std::abs(beta) != 0.0)
906 eCmn = beta * eCmn + alpha * eAmk * eBkn ;
907 else
908 eCmn = alpha * eAmk * eBkn ;
909 });
910 } else if ( (OpA == GridBLAS_OP_T ) && (OpB == GridBLAS_OP_N) ) {
911 thread_for (p, batchCount, {
912 Eigen::Map<Eigen::MatrixXd> eAmk(Amk[p],k,m);
913 Eigen::Map<Eigen::MatrixXd> eBkn(Bkn[p],k,n);
914 Eigen::Map<Eigen::MatrixXd> eCmn(Cmn[p],m,n);
915 if (std::abs(beta) != 0.0)
916 eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn ;
917 else
918 eCmn = alpha * eAmk.transpose() * eBkn ;
919 });
920 } else if ( (OpA == GridBLAS_OP_N ) && (OpB == GridBLAS_OP_T) ) {
921 thread_for (p, batchCount, {
922 Eigen::Map<Eigen::MatrixXd> eAmk(Amk[p],m,k);
923 Eigen::Map<Eigen::MatrixXd> eBkn(Bkn[p],n,k);
924 Eigen::Map<Eigen::MatrixXd> eCmn(Cmn[p],m,n);
925 if (std::abs(beta) != 0.0)
926 eCmn = beta * eCmn + alpha * eAmk * eBkn.transpose() ;
927 else
928 eCmn = alpha * eAmk * eBkn.transpose() ;
929 });
930 } else if ( (OpA == GridBLAS_OP_T ) && (OpB == GridBLAS_OP_T) ) {
931 thread_for (p, batchCount, {
932 Eigen::Map<Eigen::MatrixXd> eAmk(Amk[p],k,m);
933 Eigen::Map<Eigen::MatrixXd> eBkn(Bkn[p],n,k);
934 Eigen::Map<Eigen::MatrixXd> eCmn(Cmn[p],m,n);
935 if (std::abs(beta) != 0.0)
936 eCmn = beta * eCmn + alpha * eAmk.transpose() * eBkn.transpose() ;
937 else
938 eCmn = alpha * eAmk.transpose() * eBkn.transpose() ;
939 });
940 } else {
941 assert(0);
942 }
943#endif
944 RealD t1=usecond();
945 RealD flops = 2.0*m*n*k*batchCount;
946 RealD bytes = 1.0*sizeof(RealD)*(m*k+k*n+m*n)*batchCount;
947 }
948
949 template<class CComplex>
950 double benchmark(int M, int N, int K, int BATCH)
951 {
952 int32_t N_A = M*K*BATCH;
953 int32_t N_B = K*N*BATCH;
954 int32_t N_C = M*N*BATCH;
955 deviceVector<CComplex> A(N_A); acceleratorMemSet(&A[0],0,N_A*sizeof(CComplex));
956 deviceVector<CComplex> B(N_B); acceleratorMemSet(&B[0],0,N_B*sizeof(CComplex));
957 deviceVector<CComplex> C(N_C); acceleratorMemSet(&C[0],0,N_C*sizeof(CComplex));
958 CComplex alpha(1.0);
959 CComplex beta (1.0);
960 RealD flops = 8.0*M*N*K*BATCH;
961 int ncall=1000;
962 deviceVector<CComplex *> As(BATCH);
963 deviceVector<CComplex *> Bs(BATCH);
964 deviceVector<CComplex *> Cs(BATCH);
965 for(int b = 0 ; b < BATCH;b++) {
966 CComplex *ptr;
967 ptr = &A[b*M*K]; acceleratorPut(As[b],ptr);
968 ptr = &B[b*K*N]; acceleratorPut(Bs[b],ptr);
969 ptr = &C[b*M*N]; acceleratorPut(Cs[b],ptr);
970 }
971
972 // Warm up call
973 gemmBatched(M,N,K,
974 alpha,
975 As, // m x k
976 Bs, // k x n
977 beta,
978 Cs);
979 synchronise();
980
981 RealD t0 = usecond();
982 for(int i=0;i<ncall;i++){
983 gemmBatched(M,N,K,
984 alpha,
985 As, // m x k
986 Bs, // k x n
987 beta,
988 Cs);
989 synchronise();
990 }
991 RealD t1 = usecond();
992 RealD bytes = 1.0*sizeof(CComplex)*(M*N*2+N*K+M*K)*BATCH;
993 flops = 8.0*M*N*K*BATCH*ncall;
994 flops = flops/(t1-t0)/1.e3;
995 return flops; // Returns gigaflops
996 }
997
998};
999
void acceleratorPut(T &dev, const T &host)
void acceleratorCopyToDevice(void *from, void *to, size_t bytes)
void acceleratorMemSet(void *base, int value, size_t bytes)
void acceleratorCopyFromDevice(void *from, void *to, size_t bytes)
#define accelerator_barrier(dummy)
std::vector< T, devAllocator< T > > deviceVector
int32_t gridblasHandle_t
Definition BatchedBlas.h:64
GridBLASOperation_t
Definition BatchedBlas.h:67
@ GridBLAS_OP_T
Definition BatchedBlas.h:67
@ GridBLAS_OP_N
Definition BatchedBlas.h:67
@ GridBLAS_OP_C
Definition BatchedBlas.h:67
B
Lattice< vobj > conjugate(const Lattice< vobj > &lhs)
#define NAMESPACE_BEGIN(A)
Definition Namespace.h:35
#define NAMESPACE_END(A)
Definition Namespace.h:36
std::complex< RealF > ComplexF
Definition Simd.h:78
float RealF
Definition Simd.h:60
std::complex< RealD > ComplexD
Definition Simd.h:79
double RealD
Definition Simd.h:61
#define thread_for(i, num,...)
Definition Threads.h:60
double usecond(void)
Definition Timer.h:50
static INTERNAL_PRECISION K
Definition Zolotarev.cc:230
void gemmBatched(GridBLASOperation_t OpA, GridBLASOperation_t OpB, int m, int n, int k, ComplexD alpha, deviceVector< ComplexD * > &Amk, deviceVector< ComplexD * > &Bkn, ComplexD beta, deviceVector< ComplexD * > &Cmn)
void gemmBatched(int m, int n, int k, RealD alpha, deviceVector< RealD * > &Amk, deviceVector< RealD * > &Bkn, RealD beta, deviceVector< RealD * > &Cmn)
void gemmBatched(GridBLASOperation_t OpA, GridBLASOperation_t OpB, int m, int n, int k, RealF alpha, deviceVector< RealF * > &Amk, deviceVector< RealF * > &Bkn, RealF beta, deviceVector< RealF * > &Cmn)
void synchronise(void)
static gridblasHandle_t gridblasHandle
Definition BatchedBlas.h:73
void gemmBatched(GridBLASOperation_t OpA, GridBLASOperation_t OpB, int m, int n, int k, RealD alpha, deviceVector< RealD * > &Amk, deviceVector< RealD * > &Bkn, RealD beta, deviceVector< RealD * > &Cmn)
static int gridblasInit
Definition BatchedBlas.h:74
void gemmBatched(int m, int n, int k, ComplexF alpha, deviceVector< ComplexF * > &Amk, deviceVector< ComplexF * > &Bkn, ComplexF beta, deviceVector< ComplexF * > &Cmn)
void gemmBatched(int m, int n, int k, ComplexD alpha, deviceVector< ComplexD * > &Amk, deviceVector< ComplexD * > &Bkn, ComplexD beta, deviceVector< ComplexD * > &Cmn)
void gemmBatched(int m, int n, int k, RealF alpha, deviceVector< RealF * > &Amk, deviceVector< RealF * > &Bkn, RealF beta, deviceVector< RealF * > &Cmn)
double benchmark(int M, int N, int K, int BATCH)
void gemmBatched(GridBLASOperation_t OpA, GridBLASOperation_t OpB, int m, int n, int k, ComplexF alpha, deviceVector< ComplexF * > &Amk, deviceVector< ComplexF * > &Bkn, ComplexF beta, deviceVector< ComplexF * > &Cmn)
static void Init(void)
Definition BatchedBlas.h:76