mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 03:39:01 +08:00
add support for complex
This commit is contained in:
parent
a2415388ef
commit
0f2d480af0
@ -10,7 +10,8 @@ using namespace std;
|
||||
using namespace Eigen;
|
||||
|
||||
#ifndef SCALAR
|
||||
#define SCALAR std::complex<float>
|
||||
#define SCALAR std::complex<double>
|
||||
// #define SCALAR double
|
||||
#endif
|
||||
|
||||
typedef SCALAR Scalar;
|
||||
@ -28,6 +29,8 @@ static double done = 1;
|
||||
static double szero = 0;
|
||||
static std::complex<float> cfone = 1;
|
||||
static std::complex<float> cfzero = 0;
|
||||
static std::complex<double> cdone = 1;
|
||||
static std::complex<double> cdzero = 0;
|
||||
static char notrans = 'N';
|
||||
static char trans = 'T';
|
||||
static char nonunit = 'N';
|
||||
@ -57,6 +60,17 @@ void blas_gemm(const MatrixXcf& a, const MatrixXcf& b, MatrixXcf& c)
|
||||
(float*)c.data(),&ldc);
|
||||
}
|
||||
|
||||
void blas_gemm(const MatrixXcd& a, const MatrixXcd& b, MatrixXcd& c)
|
||||
{
|
||||
int M = c.rows(); int N = c.cols(); int K = a.cols();
|
||||
int lda = a.rows(); int ldb = b.rows(); int ldc = c.rows();
|
||||
|
||||
zgemm_(¬rans,¬rans,&M,&N,&K,(double*)&cdone,
|
||||
const_cast<double*>((const double*)a.data()),&lda,
|
||||
const_cast<double*>((const double*)b.data()),&ldb,(double*)&cdone,
|
||||
(double*)c.data(),&ldc);
|
||||
}
|
||||
|
||||
void blas_gemm(const MatrixXd& a, const MatrixXd& b, MatrixXd& c)
|
||||
{
|
||||
int M = c.rows(); int N = c.cols(); int K = a.cols();
|
||||
@ -71,7 +85,7 @@ void blas_gemm(const MatrixXd& a, const MatrixXd& b, MatrixXd& c)
|
||||
#endif
|
||||
|
||||
template<typename M>
|
||||
void gemm(const M& a, const M& b, M& c)
|
||||
EIGEN_DONT_INLINE void gemm(const M& a, const M& b, M& c)
|
||||
{
|
||||
c.noalias() += a * b;
|
||||
}
|
||||
@ -80,8 +94,10 @@ int main(int argc, char ** argv)
|
||||
{
|
||||
std::ptrdiff_t l1 = ei_queryL1CacheSize();
|
||||
std::ptrdiff_t l2 = ei_queryTopLevelCacheSize();
|
||||
std::cout << "L1 cache size = " << (l1>0 ? l1/1024 : -1) << " KB\n";
|
||||
std::cout << "L2/L3 cache size = " << (l2>0 ? l2/1024 : -1) << " KB\n";
|
||||
std::cout << "L1 cache size = " << (l1>0 ? l1/1024 : -1) << " KB\n";
|
||||
std::cout << "L2/L3 cache size = " << (l2>0 ? l2/1024 : -1) << " KB\n";
|
||||
typedef ei_product_blocking_traits<Scalar> Blocking;
|
||||
std::cout << "Register blocking = " << Blocking::mr << " x " << Blocking::nr << "\n";
|
||||
|
||||
int rep = 1; // number of repetitions per try
|
||||
int tries = 2; // number of tries, we keep the best
|
||||
|
Loading…
x
Reference in New Issue
Block a user