fix trmv regarding strided vectors and static allocation of temporaries

This commit is contained in:
Gael Guennebaud 2011-02-01 11:38:46 +01:00
parent 0fdd01fe24
commit c60818fca8
3 changed files with 198 additions and 30 deletions

View File

@ -84,12 +84,14 @@ class ProductBase : public MatrixBase<Derived>
typedef internal::blas_traits<_LhsNested> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
typedef typename internal::remove_all<ActualLhsType>::type _ActualLhsType;
typedef typename internal::traits<Lhs>::Scalar LhsScalar;
typedef typename Rhs::Nested RhsNested;
typedef typename internal::remove_all<RhsNested>::type _RhsNested;
typedef internal::blas_traits<_RhsNested> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
typedef typename internal::remove_all<ActualRhsType>::type _ActualRhsType;
typedef typename internal::traits<Rhs>::Scalar RhsScalar;
// Diagonal of a product: no need to evaluate the arguments because they are going to be evaluated only once
typedef CoeffBasedProduct<LhsNested, RhsNested, 0> FullyLazyCoeffBaseProductType;

View File

@ -152,6 +152,10 @@ struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
: traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
{};
template<int StorageOrder>
struct trmv_selector;
} // end namespace internal
template<int Mode, typename Lhs, typename Rhs>
@ -166,19 +170,7 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
{
eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
* RhsBlasTraits::extractScalarFactor(m_rhs);
internal::product_triangular_matrix_vector
<Index,Mode,
typename _ActualLhsType::Scalar, LhsBlasTraits::NeedToConjugate,
typename _ActualRhsType::Scalar, RhsBlasTraits::NeedToConjugate,
(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>
::run(lhs.rows(),lhs.cols(),lhs.data(),lhs.outerStride(),rhs.data(),rhs.innerStride(),
dst.data(),dst.innerStride(),actualAlpha);
internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha);
}
};
@ -192,23 +184,167 @@ struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
{
eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
* RhsBlasTraits::extractScalarFactor(m_rhs);
internal::product_triangular_matrix_vector
<Index,(Mode & UnitDiag) | ((Mode & Lower) ? Upper : Lower),
typename _ActualRhsType::Scalar, RhsBlasTraits::NeedToConjugate,
typename _ActualLhsType::Scalar, LhsBlasTraits::NeedToConjugate,
(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>
::run(rhs.rows(),rhs.cols(),rhs.data(),rhs.outerStride(),lhs.data(),lhs.innerStride(),
dst.data(),dst.innerStride(),actualAlpha);
typedef TriangularProduct<(Mode & UnitDiag) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
Transpose<Dest> dstT(dst);
internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
}
};
namespace internal {
// TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
template<> struct trmv_selector<ColMajor>
{
template<int Mode, typename Lhs, typename Rhs, typename Dest>
static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
{
typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
typedef typename ProductType::Index Index;
typedef typename ProductType::LhsScalar LhsScalar;
typedef typename ProductType::RhsScalar RhsScalar;
typedef typename ProductType::Scalar ResScalar;
typedef typename ProductType::RealScalar RealScalar;
typedef typename ProductType::ActualLhsType ActualLhsType;
typedef typename ProductType::ActualRhsType ActualRhsType;
typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
const ActualLhsType actualLhs = LhsBlasTraits::extract(prod.lhs());
const ActualRhsType actualRhs = RhsBlasTraits::extract(prod.rhs());
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
* RhsBlasTraits::extractScalarFactor(prod.rhs());
enum {
// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
// on, the other hand it is good for the cache to pack the vector anyways...
EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
};
gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
bool alphaIsCompatible = (!ComplexByReal) || (imag(actualAlpha)==RealScalar(0));
bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
ResScalar* actualDestPtr;
bool freeDestPtr = false;
if (evalToDest)
{
actualDestPtr = dest.data();
}
else
{
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
int size = dest.size();
EIGEN_DENSE_STORAGE_CTOR_PLUGIN
#endif
if((actualDestPtr = static_dest.data())==0)
{
freeDestPtr = true;
actualDestPtr = ei_aligned_stack_new(ResScalar,dest.size());
}
if(!alphaIsCompatible)
{
MappedDest(actualDestPtr, dest.size()).setZero();
compatibleAlpha = RhsScalar(1);
}
else
MappedDest(actualDestPtr, dest.size()) = dest;
}
internal::product_triangular_matrix_vector
<Index,Mode,
LhsScalar, LhsBlasTraits::NeedToConjugate,
RhsScalar, RhsBlasTraits::NeedToConjugate,
ColMajor>
::run(actualLhs.rows(),actualLhs.cols(),
actualLhs.data(),actualLhs.outerStride(),
actualRhs.data(),actualRhs.innerStride(),
actualDestPtr,1,compatibleAlpha);
if (!evalToDest)
{
if(!alphaIsCompatible)
dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
else
dest = MappedDest(actualDestPtr, dest.size());
if(freeDestPtr) ei_aligned_stack_delete(ResScalar, actualDestPtr, dest.size());
}
}
};
template<> struct trmv_selector<RowMajor>
{
template<int Mode, typename Lhs, typename Rhs, typename Dest>
static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
{
typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
typedef typename ProductType::LhsScalar LhsScalar;
typedef typename ProductType::RhsScalar RhsScalar;
typedef typename ProductType::Scalar ResScalar;
typedef typename ProductType::Index Index;
typedef typename ProductType::ActualLhsType ActualLhsType;
typedef typename ProductType::ActualRhsType ActualRhsType;
typedef typename ProductType::_ActualRhsType _ActualRhsType;
typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
* RhsBlasTraits::extractScalarFactor(prod.rhs());
enum {
DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
};
gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
RhsScalar* actualRhsPtr;
bool freeRhsPtr = false;
if (DirectlyUseRhs)
{
actualRhsPtr = const_cast<RhsScalar*>(actualRhs.data());
}
else
{
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
int size = actualRhs.size();
EIGEN_DENSE_STORAGE_CTOR_PLUGIN
#endif
if((actualRhsPtr = static_rhs.data())==0)
{
freeRhsPtr = true;
actualRhsPtr = ei_aligned_stack_new(RhsScalar, actualRhs.size());
}
Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
}
internal::product_triangular_matrix_vector
<Index,Mode,
LhsScalar, LhsBlasTraits::NeedToConjugate,
RhsScalar, RhsBlasTraits::NeedToConjugate,
RowMajor>
::run(actualLhs.rows(),actualLhs.cols(),
actualLhs.data(),actualLhs.outerStride(),
actualRhsPtr,1,
dest.data(),dest.innerStride(),
actualAlpha);
if((!DirectlyUseRhs) && freeRhsPtr) ei_aligned_stack_delete(RhsScalar, actualRhsPtr, prod.rhs().size());
}
};
} // end namespace internal
#endif // EIGEN_TRIANGULARMATRIXVECTOR_H

View File

@ -81,6 +81,36 @@ template<typename MatrixType> void nomalloc(const MatrixType& m)
m2.row(0).noalias() -= m1.row(0) * m1.adjoint();
m2.row(0).noalias() -= m1.col(0).adjoint() * m1;
m2.row(0).noalias() -= m1.col(0).adjoint() * m1.adjoint();
VERIFY_IS_APPROX(m2,m2);
m2.col(0).noalias() = m1.template triangularView<Upper>() * m1.col(0);
m2.col(0).noalias() -= m1.adjoint().template triangularView<Upper>() * m1.col(0);
m2.col(0).noalias() -= m1.template triangularView<Upper>() * m1.row(0).adjoint();
m2.col(0).noalias() -= m1.adjoint().template triangularView<Upper>() * m1.row(0).adjoint();
m2.row(0).noalias() = m1.row(0) * m1.template triangularView<Upper>();
m2.row(0).noalias() -= m1.row(0) * m1.adjoint().template triangularView<Upper>();
m2.row(0).noalias() -= m1.col(0).adjoint() * m1.template triangularView<Upper>();
m2.row(0).noalias() -= m1.col(0).adjoint() * m1.adjoint().template triangularView<Upper>();
VERIFY_IS_APPROX(m2,m2);
m2.col(0).noalias() = m1.template selfadjointView<Upper>() * m1.col(0);
m2.col(0).noalias() -= m1.adjoint().template selfadjointView<Upper>() * m1.col(0);
m2.col(0).noalias() -= m1.template selfadjointView<Upper>() * m1.row(0).adjoint();
m2.col(0).noalias() -= m1.adjoint().template selfadjointView<Upper>() * m1.row(0).adjoint();
m2.row(0).noalias() = m1.row(0) * m1.template selfadjointView<Upper>();
m2.row(0).noalias() -= m1.row(0) * m1.adjoint().template selfadjointView<Upper>();
m2.row(0).noalias() -= m1.col(0).adjoint() * m1.template selfadjointView<Upper>();
m2.row(0).noalias() -= m1.col(0).adjoint() * m1.adjoint().template selfadjointView<Upper>();
VERIFY_IS_APPROX(m2,m2);
// The following fancy matrix-matrix products are not safe yet regarding static allocation
// m1 += m1.template triangularView<Upper>() * m2.col(;
// m1.template selfadjointView<Lower>().rankUpdate(m2);
// m1 += m1.template triangularView<Upper>() * m2;
// m1 += m1.template selfadjointView<Lower>() * m2;
// VERIFY_IS_APPROX(m1,m1);
}
template<typename Scalar>