add support for complex

This commit is contained in:
Gael Guennebaud 2010-07-07 16:41:29 +02:00
parent a2415388ef
commit 0f2d480af0

View File

@ -10,7 +10,8 @@ using namespace std;
using namespace Eigen;
#ifndef SCALAR
#define SCALAR std::complex<float>
#define SCALAR std::complex<double>
// #define SCALAR double
#endif
typedef SCALAR Scalar;
@ -28,6 +29,8 @@ static double done = 1;
static double szero = 0;
static std::complex<float> cfone = 1;
static std::complex<float> cfzero = 0;
static std::complex<double> cdone = 1;
static std::complex<double> cdzero = 0;
static char notrans = 'N';
static char trans = 'T';
static char nonunit = 'N';
@ -57,6 +60,17 @@ void blas_gemm(const MatrixXcf& a, const MatrixXcf& b, MatrixXcf& c)
(float*)c.data(),&ldc);
}
void blas_gemm(const MatrixXcd& a, const MatrixXcd& b, MatrixXcd& c)
{
int M = c.rows(); int N = c.cols(); int K = a.cols();
int lda = a.rows(); int ldb = b.rows(); int ldc = c.rows();
zgemm_(&notrans,&notrans,&M,&N,&K,(double*)&cdone,
const_cast<double*>((const double*)a.data()),&lda,
const_cast<double*>((const double*)b.data()),&ldb,(double*)&cdone,
(double*)c.data(),&ldc);
}
void blas_gemm(const MatrixXd& a, const MatrixXd& b, MatrixXd& c)
{
int M = c.rows(); int N = c.cols(); int K = a.cols();
@ -71,7 +85,7 @@ void blas_gemm(const MatrixXd& a, const MatrixXd& b, MatrixXd& c)
#endif
template<typename M>
void gemm(const M& a, const M& b, M& c)
EIGEN_DONT_INLINE void gemm(const M& a, const M& b, M& c)
{
c.noalias() += a * b;
}
@ -80,8 +94,10 @@ int main(int argc, char ** argv)
{
std::ptrdiff_t l1 = ei_queryL1CacheSize();
std::ptrdiff_t l2 = ei_queryTopLevelCacheSize();
std::cout << "L1 cache size = " << (l1>0 ? l1/1024 : -1) << " KB\n";
std::cout << "L2/L3 cache size = " << (l2>0 ? l2/1024 : -1) << " KB\n";
std::cout << "L1 cache size = " << (l1>0 ? l1/1024 : -1) << " KB\n";
std::cout << "L2/L3 cache size = " << (l2>0 ? l2/1024 : -1) << " KB\n";
typedef ei_product_blocking_traits<Scalar> Blocking;
std::cout << "Register blocking = " << Blocking::mr << " x " << Blocking::nr << "\n";
int rep = 1; // number of repetitions per try
int tries = 2; // number of tries, we keep the best