improve SYMV it is now faster and ready for use

This commit is contained in:
Gael Guennebaud 2009-07-23 14:20:45 +02:00
parent eee14846e3
commit 713c92140c
3 changed files with 126 additions and 63 deletions

View File

@ -184,34 +184,67 @@ struct ei_triangular_assignment_selector<Derived1, Derived2, SelfAdjoint, Dynami
* Wrapper to ei_product_selfadjoint_vector * Wrapper to ei_product_selfadjoint_vector
***************************************************************************/ ***************************************************************************/
template<typename Lhs, int LhsMode, typename Rhs, int RhsMode> template<typename Lhs, int LhsMode, typename Rhs>
struct ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,true> struct ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,0,true>
: public ReturnByValue<ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,RhsMode,true>, : public ReturnByValue<ei_selfadjoint_product_returntype<Lhs,LhsMode,false,Rhs,0,true>,
Matrix<typename ei_traits<Rhs>::Scalar, Matrix<typename ei_traits<Rhs>::Scalar,
Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> > Rhs::RowsAtCompileTime,Rhs::ColsAtCompileTime> >
{ {
typedef typename ei_cleantype<typename Rhs::Nested>::type RhsNested; typedef typename Lhs::Scalar Scalar;
typedef typename Lhs::Nested LhsNested;
typedef typename ei_cleantype<LhsNested>::type _LhsNested;
typedef ei_blas_traits<_LhsNested> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef typename ei_cleantype<ActualLhsType>::type _ActualLhsType;
typedef typename Rhs::Nested RhsNested;
typedef typename ei_cleantype<RhsNested>::type _RhsNested;
typedef ei_blas_traits<_RhsNested> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
typedef typename ei_cleantype<ActualRhsType>::type _ActualRhsType;
enum {
LhsUpLo = LhsMode&(UpperTriangularBit|LowerTriangularBit)
};
ei_selfadjoint_product_returntype(const Lhs& lhs, const Rhs& rhs) ei_selfadjoint_product_returntype(const Lhs& lhs, const Rhs& rhs)
: m_lhs(lhs), m_rhs(rhs) : m_lhs(lhs), m_rhs(rhs)
{} {}
template<typename Dest> inline void _addTo(Dest& dst) const
{ evalTo(dst,1); }
template<typename Dest> inline void _subTo(Dest& dst) const
{ evalTo(dst,-1); }
template<typename Dest> void evalTo(Dest& dst) const template<typename Dest> void evalTo(Dest& dst) const
{ {
dst.resize(m_rhs.rows(), m_rhs.cols()); dst.resize(m_lhs.rows(), m_rhs.cols());
ei_product_selfadjoint_vector<typename Lhs::Scalar,ei_traits<Lhs>::Flags&RowMajorBit, dst.setZero();
LhsMode&(UpperTriangularBit|LowerTriangularBit)> evalTo(dst,1);
}
template<typename Dest> void evalTo(Dest& dst, Scalar alpha) const
{
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
* RhsBlasTraits::extractScalarFactor(m_rhs);
ei_assert((&dst.coeff(1))-(&dst.coeff(0))==1 && "not implemented yet");
ei_product_selfadjoint_vector<Scalar, ei_traits<_ActualLhsType>::Flags&RowMajorBit, int(LhsUpLo), bool(LhsBlasTraits::NeedToConjugate), bool(RhsBlasTraits::NeedToConjugate)>
( (
m_lhs.rows(), // size lhs.rows(), // size
m_lhs.data(), // lhs &lhs.coeff(0,0), lhs.stride(), // lhs info
m_lhs.stride(), // lhsStride, &rhs.coeff(0), (&rhs.coeff(1))-(&rhs.coeff(0)), // rhs info
m_rhs.data(), // rhs &dst.coeffRef(0), // result info
// int rhsIncr, actualAlpha // scale factor
dst.data() // res
); );
} }
const typename Lhs::Nested m_lhs; const LhsNested m_lhs;
const typename Rhs::Nested m_rhs; const RhsNested m_rhs;
}; };
/*************************************************************************** /***************************************************************************

View File

@ -30,12 +30,12 @@
* the number of load/stores of the result by a factor 2 and to reduce * the number of load/stores of the result by a factor 2 and to reduce
* the instruction dependency. * the instruction dependency.
*/ */
template<typename Scalar, int StorageOrder, int UpLo> template<typename Scalar, int StorageOrder, int UpLo, bool ConjugateLhs, bool ConjugateRhs>
static EIGEN_DONT_INLINE void ei_product_selfadjoint_vector( static EIGEN_DONT_INLINE void ei_product_selfadjoint_vector(
int size, int size,
const Scalar* lhs, int lhsStride, const Scalar* lhs, int lhsStride,
const Scalar* rhs, //int rhsIncr, const Scalar* _rhs, int rhsIncr,
Scalar* res) Scalar* res, Scalar alpha)
{ {
typedef typename ei_packet_traits<Scalar>::type Packet; typedef typename ei_packet_traits<Scalar>::type Packet;
const int PacketSize = sizeof(Packet)/sizeof(Scalar); const int PacketSize = sizeof(Packet)/sizeof(Scalar);
@ -46,8 +46,22 @@ static EIGEN_DONT_INLINE void ei_product_selfadjoint_vector(
FirstTriangular = IsRowMajor == IsLower FirstTriangular = IsRowMajor == IsLower
}; };
ei_conj_if<NumTraits<Scalar>::IsComplex && IsRowMajor> conj0; ei_conj_helper<NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(ConjugateLhs, IsRowMajor), ConjugateRhs> cj0;
ei_conj_if<NumTraits<Scalar>::IsComplex && !IsRowMajor> conj1; ei_conj_helper<NumTraits<Scalar>::IsComplex && EIGEN_LOGICAL_XOR(ConjugateLhs, !IsRowMajor), ConjugateRhs> cj1;
Scalar cjAlpha = ConjugateRhs ? ei_conj(alpha) : alpha;
// if the rhs is not sequentially stored in memory we copy it to a temporary buffer,
// this is because we need to extract packets
const Scalar* EIGEN_RESTRICT rhs = _rhs;
if (rhsIncr!=1)
{
Scalar* r = ei_aligned_stack_new(Scalar, size);
const Scalar* it = _rhs;
for (int i=0; i<size; ++i, it+=rhsIncr)
r[i] = *it;
rhs = r;
}
for (int i=0;i<size;i++) for (int i=0;i<size;i++)
res[i] = 0; res[i] = 0;
@ -62,9 +76,9 @@ static EIGEN_DONT_INLINE void ei_product_selfadjoint_vector(
register const Scalar* EIGEN_RESTRICT A0 = lhs + j*lhsStride; register const Scalar* EIGEN_RESTRICT A0 = lhs + j*lhsStride;
register const Scalar* EIGEN_RESTRICT A1 = lhs + (j+1)*lhsStride; register const Scalar* EIGEN_RESTRICT A1 = lhs + (j+1)*lhsStride;
Scalar t0 = rhs[j]; Scalar t0 = cjAlpha * rhs[j];
Packet ptmp0 = ei_pset1(t0); Packet ptmp0 = ei_pset1(t0);
Scalar t1 = rhs[j+1]; Scalar t1 = cjAlpha * rhs[j+1];
Packet ptmp1 = ei_pset1(t1); Packet ptmp1 = ei_pset1(t1);
Scalar t2 = 0; Scalar t2 = 0;
@ -78,17 +92,17 @@ static EIGEN_DONT_INLINE void ei_product_selfadjoint_vector(
size_t alignedStart = (starti) + ei_alignmentOffset(&res[starti], endi-starti); size_t alignedStart = (starti) + ei_alignmentOffset(&res[starti], endi-starti);
alignedEnd = alignedStart + ((endi-alignedStart)/(PacketSize))*(PacketSize); alignedEnd = alignedStart + ((endi-alignedStart)/(PacketSize))*(PacketSize);
res[j] += t0 * conj0(A0[j]); res[j] += cj0.pmul(A0[j], t0);
if(FirstTriangular) if(FirstTriangular)
{ {
res[j+1] += t1 * conj0(A1[j+1]); res[j+1] += cj0.pmul(A1[j+1], t1);
res[j] += t1 * conj0(A1[j]); res[j] += cj0.pmul(A1[j], t1);
t3 += conj1(A1[j]) * rhs[j]; t3 += cj1.pmul(A1[j], rhs[j]);
} }
else else
{ {
res[j+1] += t0 * conj0(A0[j+1]) + t1 * conj0(A1[j+1]); res[j+1] += cj0.pmul(A0[j+1],t0) + cj0.pmul(A1[j+1],t1);
t2 += conj1(A0[j+1]) * rhs[j+1]; t2 += cj1.pmul(A0[j+1], rhs[j+1]);
} }
for (size_t i=starti; i<alignedStart; ++i) for (size_t i=starti; i<alignedStart; ++i)
@ -97,41 +111,50 @@ static EIGEN_DONT_INLINE void ei_product_selfadjoint_vector(
t2 += ei_conj(A0[i]) * rhs[i]; t2 += ei_conj(A0[i]) * rhs[i];
t3 += ei_conj(A1[i]) * rhs[i]; t3 += ei_conj(A1[i]) * rhs[i];
} }
// Yes this an optimization for gcc 4.3 and 4.4 (=> huge speed up)
// gcc 4.2 does this optimization automatically.
const Scalar* EIGEN_RESTRICT a0It = A0 + alignedStart;
const Scalar* EIGEN_RESTRICT a1It = A1 + alignedStart;
const Scalar* EIGEN_RESTRICT rhsIt = rhs + alignedStart;
Scalar* EIGEN_RESTRICT resIt = res + alignedStart;
for (size_t i=alignedStart; i<alignedEnd; i+=PacketSize) for (size_t i=alignedStart; i<alignedEnd; i+=PacketSize)
{ {
Packet A0i = ei_ploadu(&A0[i]); Packet A0i = ei_ploadu(a0It); a0It += PacketSize;
Packet A1i = ei_ploadu(&A1[i]); Packet A1i = ei_ploadu(a1It); a1It += PacketSize;
Packet Bi = ei_ploadu(&rhs[i]); // FIXME should be aligned in most cases Packet Bi = ei_ploadu(rhsIt); rhsIt += PacketSize; // FIXME should be aligned in most cases
Packet Xi = ei_pload(&res[i]); Packet Xi = ei_pload (resIt);
Xi = ei_padd(ei_padd(Xi, ei_pmul(ptmp0, conj0(A0i))), ei_pmul(ptmp1, conj0(A1i))); Xi = cj0.pmadd(A0i,ptmp0, cj0.pmadd(A1i,ptmp1,Xi));
ptmp2 = ei_padd(ptmp2, ei_pmul(conj1(A0i), Bi)); ptmp2 = cj1.pmadd(A0i, Bi, ptmp2);
ptmp3 = ei_padd(ptmp3, ei_pmul(conj1(A1i), Bi)); ptmp3 = cj1.pmadd(A1i, Bi, ptmp3);
ei_pstore(&res[i],Xi); ei_pstore(resIt,Xi); resIt += PacketSize;
} }
for (size_t i=alignedEnd; i<endi; i++) for (size_t i=alignedEnd; i<endi; i++)
{ {
res[i] += t0 * conj0(A0[i]) + t1 * conj0(A1[i]); res[i] += cj0.pmul(A0[i], t0) + cj0.pmul(A1[i],t1);
t2 += conj1(A0[i]) * rhs[i]; t2 += cj1.pmul(A0[i], rhs[i]);
t3 += conj1(A1[i]) * rhs[i]; t3 += cj1.pmul(A1[i], rhs[i]);
} }
res[j] += t2 + ei_predux(ptmp2); res[j] += alpha * (t2 + ei_predux(ptmp2));
res[j+1] += t3 + ei_predux(ptmp3); res[j+1] += alpha * (t3 + ei_predux(ptmp3));
} }
for (int j=FirstTriangular ? 0 : bound;j<(FirstTriangular ? bound : size);j++) for (int j=FirstTriangular ? 0 : bound;j<(FirstTriangular ? bound : size);j++)
{ {
register const Scalar* EIGEN_RESTRICT A0 = lhs + j*lhsStride; register const Scalar* EIGEN_RESTRICT A0 = lhs + j*lhsStride;
Scalar t1 = rhs[j]; Scalar t1 = cjAlpha * rhs[j];
Scalar t2 = 0; Scalar t2 = 0;
res[j] += t1 * conj0(A0[j]); res[j] += cj0.pmul(A0[j],t1);
for (int i=FirstTriangular ? 0 : j+1; i<(FirstTriangular ? j : size); i++) { for (int i=FirstTriangular ? 0 : j+1; i<(FirstTriangular ? j : size); i++) {
res[i] += t1 * conj0(A0[i]); res[i] += cj0.pmul(A0[i], t1);
t2 += conj1(A0[i]) * rhs[i]; t2 += cj1.pmul(A0[i], rhs[i]);
} }
res[j] += t2; res[j] += alpha * t2;
} }
if(rhsIncr!=1)
ei_aligned_stack_delete(Scalar, const_cast<Scalar*>(rhs), size);
} }

View File

@ -31,6 +31,8 @@ template<typename MatrixType> void product_selfadjoint(const MatrixType& m)
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> VectorType; typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, 1> VectorType;
typedef Matrix<Scalar, 1, MatrixType::RowsAtCompileTime> RowVectorType; typedef Matrix<Scalar, 1, MatrixType::RowsAtCompileTime> RowVectorType;
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, Dynamic, RowMajor> RhsMatrixType;
int rows = m.rows(); int rows = m.rows();
int cols = m.cols(); int cols = m.cols();
@ -38,31 +40,36 @@ template<typename MatrixType> void product_selfadjoint(const MatrixType& m)
m2 = MatrixType::Random(rows, cols), m2 = MatrixType::Random(rows, cols),
m3; m3;
VectorType v1 = VectorType::Random(rows), VectorType v1 = VectorType::Random(rows),
v2 = VectorType::Random(rows); v2 = VectorType::Random(rows),
v3(rows);
RowVectorType r1 = RowVectorType::Random(rows), RowVectorType r1 = RowVectorType::Random(rows),
r2 = RowVectorType::Random(rows); r2 = RowVectorType::Random(rows);
RhsMatrixType m4 = RhsMatrixType::Random(rows,10);
Scalar s1 = ei_random<Scalar>(), Scalar s1 = ei_random<Scalar>(),
s2 = ei_random<Scalar>(), s2 = ei_random<Scalar>(),
s3 = ei_random<Scalar>(); s3 = ei_random<Scalar>();
m1 = m1.adjoint()*m1; m1 = (m1.adjoint() + m1).eval();
// lower // lower
m2.setZero(); m2 = m1.template triangularView<LowerTriangular>();
m2.template triangularView<LowerTriangular>() = m1; VERIFY_IS_APPROX(v3 = (s1*m2).template selfadjointView<LowerTriangular>() * (s2*v1), (s1*m1) * (s2*v1));
ei_product_selfadjoint_vector<Scalar,MatrixType::Flags&RowMajorBit,LowerTriangularBit> VERIFY_IS_APPROX(v3 = (s1*m2.conjugate()).template selfadjointView<LowerTriangular>() * (s2*v1), (s1*m1.conjugate()) * (s2*v1));
(cols,m2.data(),cols, v1.data(), v2.data()); VERIFY_IS_APPROX(v3 = (s1*m2).template selfadjointView<LowerTriangular>() * (s2*m4.col(1)), (s1*m1) * (s2*m4.col(1)));
VERIFY_IS_APPROX(v2, m1 * v1);
VERIFY_IS_APPROX((m2.template selfadjointView<LowerTriangular>() * v1).eval(), m1 * v1); VERIFY_IS_APPROX(v3 = (s1*m2).template selfadjointView<LowerTriangular>() * (s2*v1.conjugate()), (s1*m1) * (s2*v1.conjugate()));
VERIFY_IS_APPROX(v3 = (s1*m2.conjugate()).template selfadjointView<LowerTriangular>() * (s2*v1.conjugate()), (s1*m1.conjugate()) * (s2*v1.conjugate()));
// upper // upper
m2.setZero(); m2 = m1.template triangularView<UpperTriangular>();
m2.template triangularView<UpperTriangular>() = m1; VERIFY_IS_APPROX(v3 = (s1*m2).template selfadjointView<UpperTriangular>() * (s2*v1), (s1*m1) * (s2*v1));
ei_product_selfadjoint_vector<Scalar,MatrixType::Flags&RowMajorBit,UpperTriangularBit>(cols,m2.data(),cols, v1.data(), v2.data()); VERIFY_IS_APPROX(v3 = (s1*m2.conjugate()).template selfadjointView<UpperTriangular>() * (s2*v1), (s1*m1.conjugate()) * (s2*v1));
VERIFY_IS_APPROX(v2, m1 * v1); VERIFY_IS_APPROX(v3 = (s1*m2.adjoint()).template selfadjointView<LowerTriangular>() * (s2*v1), (s1*m1.adjoint()) * (s2*v1));
VERIFY_IS_APPROX((m2.template selfadjointView<UpperTriangular>() * v1).eval(), m1 * v1); VERIFY_IS_APPROX(v3 = (s1*m2.transpose()).template selfadjointView<LowerTriangular>() * (s2*v1), (s1*m1.transpose()) * (s2*v1));
VERIFY_IS_APPROX(v3 = (s1*m2).template selfadjointView<UpperTriangular>() * (s2*v1.conjugate()), (s1*m1) * (s2*v1.conjugate()));
VERIFY_IS_APPROX(v3 = (s1*m2.conjugate()).template selfadjointView<UpperTriangular>() * (s2*v1.conjugate()), (s1*m1.conjugate()) * (s2*v1.conjugate()));
// rank2 update // rank2 update
m2 = m1.template triangularView<LowerTriangular>(); m2 = m1.template triangularView<LowerTriangular>();