20#ifndef VOTCA_XTP_CUDAPIPELINE_H
21#define VOTCA_XTP_CUDAPIPELINE_H
26#error Cuda not enabled
68 template <
class M1,
class M2,
class M3>
69 void gemm(M1 &&A, M2 &&B, M3 &&C,
double beta = 0.0)
const;
89template <
class M1,
class M2,
class M3>
92 using m1 = std::decay_t<M1>;
93 using m2 = std::decay_t<M2>;
94 using m3 = std::decay_t<M3>;
95 static_assert(!m3::transposed(),
"C in gemm cannot be transposed atm");
98 const double *palpha = &alpha;
99 const double *pbeta = &beta;
100 cublasOperation_t transA = CUBLAS_OP_N;
101 int k = int(A.cols());
102 if (m1::transposed()) {
103 transA = CUBLAS_OP_T;
106 cublasOperation_t transB = CUBLAS_OP_N;
107 int k2 = int(B.rows());
108 if (m2::transposed()) {
109 transB = CUBLAS_OP_T;
114 throw std::runtime_error(
115 "Shape mismatch in cuda gemm " + std::to_string(k) +
":" +
121 cublasStatus_t status =
122 cublasDgemm(
handle_, transA, transB,
int(C.rows()),
int(C.cols()), k,
123 palpha, A.data(),
int(A.ld()), B.data(),
int(B.ld()), pbeta,
124 C.data(),
int(C.ld()));
125 if (status != CUBLAS_STATUS_SUCCESS) {
126 throw std::runtime_error(
"dgemm failed on gpu " +
136 if (b.
cols() != 1 && b.
rows() != 1) {
137 throw std::runtime_error(
"B Matrix in Cublas diag_gemm must be a vector");
140 cublasSideMode_t mode = CUBLAS_SIDE_RIGHT;
141 Index Adim = A.cols();
142 if (M::transposed()) {
143 mode = CUBLAS_SIDE_LEFT;
147 if (Adim != b.
size()) {
148 throw std::runtime_error(
"Shape mismatch in cuda diag_gemm: A" +
153 cublasStatus_t status =
154 cublasDdgmm(
handle_, mode,
int(A.rows()),
int(A.cols()), A.data(),
155 int(A.ld()), b.
data(), 1, C.
data(),
int(C.
ld()));
157 if (status != CUBLAS_STATUS_SUCCESS) {
158 throw std::runtime_error(
"diag_gemm failed on gpu " +
CudaPipeline(int deviceID)
CudaPipeline & operator=(const CudaPipeline &)=delete
void gemm(M1 &&A, M2 &&B, M3 &&C, double beta=0.0) const
void axpy(const CudaMatrix &A, CudaMatrix &B, double alpha=1.0) const
const cudaStream_t & get_stream() const
void diag_gemm(const M &A, const CudaMatrix &b, CudaMatrix &C) const
CudaPipeline(const CudaPipeline &)=delete
std::string cudaGetErrorEnum(cublasStatus_t error)
std::string OutputDimension(const M &mat)
void checkCuda(cudaError_t result)
base class for all analysis tools