Implement SDSDOT in terms of DSDOT, and avoid allocating temporary buffers in DSDOT.

This commit is contained in:
Chen-Pang He 2012-09-08 02:06:45 +08:00
parent b0b9b4d6b2
commit 1b61aadcbe
2 changed files with 10 additions and 32 deletions

View File

@ -19,25 +19,15 @@
#include "level2_real_impl.h" #include "level2_real_impl.h"
#include "level3_impl.h" #include "level3_impl.h"
// currently used by DSDOT only double BLASFUNC(dsdot)(int* n, float* x, int* incx, float* y, int* incy)
double* cast_vector_to_double(float* x, int n, int incx)
{ {
double* ret = new double[n]; if(*n<=0) return 0;
if(incx<0) vector(ret,n) = vector(x,n,-incx).reverse().cast<double>();
else vector(ret,n) = vector(x,n, incx).cast<double>(); if(*incx==1 && *incy==1) return (vector(x,*n).cast<double>().cwiseProduct(vector(y,*n).cast<double>())).sum();
return ret; else if(*incx>0 && *incy>0) return (vector(x,*n,*incx).cast<double>().cwiseProduct(vector(y,*n,*incy).cast<double>())).sum();
} else if(*incx<0 && *incy>0) return (vector(x,*n,-*incx).reverse().cast<double>().cwiseProduct(vector(y,*n,*incy).cast<double>())).sum();
else if(*incx>0 && *incy<0) return (vector(x,*n,*incx).cast<double>().cwiseProduct(vector(y,*n,-*incy).reverse().cast<double>())).sum();
double BLASFUNC(dsdot)(int* n, float* px, int* incx, float* py, int* incy) else if(*incx<0 && *incy<0) return (vector(x,*n,-*incx).reverse().cast<double>().cwiseProduct(vector(y,*n,-*incy).reverse().cast<double>())).sum();
{ else return 0;
if(*n <= 0) return 0;
double* x = cast_vector_to_double(px, *n, *incx);
double* y = cast_vector_to_double(py, *n, *incy);
double res = vector(x,*n).cwiseProduct(vector(y,*n)).sum();
delete[] x;
delete[] y;
return res;
} }

View File

@ -2,7 +2,6 @@
// for linear algebra. // for linear algebra.
// //
// Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr> // Copyright (C) 2009 Gael Guennebaud <gael.guennebaud@inria.fr>
// Copyright (C) 2012 Chen-Pang He <jdh8@ms63.hinet.net>
// //
// This Source Code Form is subject to the terms of the Mozilla // This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed // Public License v. 2.0. If a copy of the MPL was not distributed
@ -20,15 +19,4 @@
#include "level3_impl.h" #include "level3_impl.h"
float BLASFUNC(sdsdot)(int* n, float* alpha, float* x, int* incx, float* y, int* incy) float BLASFUNC(sdsdot)(int* n, float* alpha, float* x, int* incx, float* y, int* incy)
{ { return *alpha + BLASFUNC(dsdot)(n, x, incx, y, incy); }
float res = *alpha;
if(*n>0) {
if(*incx==1 && *incy==1) res += (vector(x,*n).cwiseProduct(vector(y,*n))).sum();
else if(*incx>0 && *incy>0) res += (vector(x,*n,*incx).cwiseProduct(vector(y,*n,*incy))).sum();
else if(*incx<0 && *incy>0) res += (vector(x,*n,-*incx).reverse().cwiseProduct(vector(y,*n,*incy))).sum();
else if(*incx>0 && *incy<0) res += (vector(x,*n,*incx).cwiseProduct(vector(y,*n,-*incy).reverse())).sum();
else if(*incx<0 && *incy<0) res += (vector(x,*n,-*incx).reverse().cwiseProduct(vector(y,*n,-*incy).reverse())).sum();
}
return res;
}