mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-21 17:19:36 +08:00
fix trmv regarding strided vectors and static allocation of temporaries
This commit is contained in:
parent
0fdd01fe24
commit
c60818fca8
@ -84,12 +84,14 @@ class ProductBase : public MatrixBase<Derived>
|
||||
typedef internal::blas_traits<_LhsNested> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
typedef typename internal::remove_all<ActualLhsType>::type _ActualLhsType;
|
||||
typedef typename internal::traits<Lhs>::Scalar LhsScalar;
|
||||
|
||||
typedef typename Rhs::Nested RhsNested;
|
||||
typedef typename internal::remove_all<RhsNested>::type _RhsNested;
|
||||
typedef internal::blas_traits<_RhsNested> RhsBlasTraits;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||
typedef typename internal::remove_all<ActualRhsType>::type _ActualRhsType;
|
||||
typedef typename internal::traits<Rhs>::Scalar RhsScalar;
|
||||
|
||||
// Diagonal of a product: no need to evaluate the arguments because they are going to be evaluated only once
|
||||
typedef CoeffBasedProduct<LhsNested, RhsNested, 0> FullyLazyCoeffBaseProductType;
|
||||
|
@ -152,6 +152,10 @@ struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
|
||||
: traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
|
||||
{};
|
||||
|
||||
|
||||
template<int StorageOrder>
|
||||
struct trmv_selector;
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
template<int Mode, typename Lhs, typename Rhs>
|
||||
@ -166,19 +170,7 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
|
||||
{
|
||||
eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
|
||||
|
||||
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
|
||||
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
|
||||
|
||||
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
|
||||
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
||||
|
||||
internal::product_triangular_matrix_vector
|
||||
<Index,Mode,
|
||||
typename _ActualLhsType::Scalar, LhsBlasTraits::NeedToConjugate,
|
||||
typename _ActualRhsType::Scalar, RhsBlasTraits::NeedToConjugate,
|
||||
(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>
|
||||
::run(lhs.rows(),lhs.cols(),lhs.data(),lhs.outerStride(),rhs.data(),rhs.innerStride(),
|
||||
dst.data(),dst.innerStride(),actualAlpha);
|
||||
internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
@ -192,23 +184,167 @@ struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
|
||||
|
||||
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
|
||||
{
|
||||
|
||||
eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
|
||||
|
||||
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
|
||||
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
|
||||
|
||||
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
|
||||
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
||||
|
||||
internal::product_triangular_matrix_vector
|
||||
<Index,(Mode & UnitDiag) | ((Mode & Lower) ? Upper : Lower),
|
||||
typename _ActualRhsType::Scalar, RhsBlasTraits::NeedToConjugate,
|
||||
typename _ActualLhsType::Scalar, LhsBlasTraits::NeedToConjugate,
|
||||
(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>
|
||||
::run(rhs.rows(),rhs.cols(),rhs.data(),rhs.outerStride(),lhs.data(),lhs.innerStride(),
|
||||
dst.data(),dst.innerStride(),actualAlpha);
|
||||
typedef TriangularProduct<(Mode & UnitDiag) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
|
||||
Transpose<Dest> dstT(dst);
|
||||
internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
|
||||
TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
namespace internal {
|
||||
|
||||
// TODO: find a way to factorize this piece of code with gemv_selector since the logic is exactly the same.
|
||||
|
||||
template<> struct trmv_selector<ColMajor>
|
||||
{
|
||||
template<int Mode, typename Lhs, typename Rhs, typename Dest>
|
||||
static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
|
||||
{
|
||||
typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
|
||||
typedef typename ProductType::Index Index;
|
||||
typedef typename ProductType::LhsScalar LhsScalar;
|
||||
typedef typename ProductType::RhsScalar RhsScalar;
|
||||
typedef typename ProductType::Scalar ResScalar;
|
||||
typedef typename ProductType::RealScalar RealScalar;
|
||||
typedef typename ProductType::ActualLhsType ActualLhsType;
|
||||
typedef typename ProductType::ActualRhsType ActualRhsType;
|
||||
typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
|
||||
typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
|
||||
typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
|
||||
|
||||
const ActualLhsType actualLhs = LhsBlasTraits::extract(prod.lhs());
|
||||
const ActualRhsType actualRhs = RhsBlasTraits::extract(prod.rhs());
|
||||
|
||||
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
|
||||
* RhsBlasTraits::extractScalarFactor(prod.rhs());
|
||||
|
||||
enum {
|
||||
// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
|
||||
// on, the other hand it is good for the cache to pack the vector anyways...
|
||||
EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
|
||||
ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
|
||||
MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
|
||||
};
|
||||
|
||||
gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
|
||||
|
||||
bool alphaIsCompatible = (!ComplexByReal) || (imag(actualAlpha)==RealScalar(0));
|
||||
bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
|
||||
|
||||
RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
|
||||
|
||||
ResScalar* actualDestPtr;
|
||||
bool freeDestPtr = false;
|
||||
if (evalToDest)
|
||||
{
|
||||
actualDestPtr = dest.data();
|
||||
}
|
||||
else
|
||||
{
|
||||
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||
int size = dest.size();
|
||||
EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||
#endif
|
||||
if((actualDestPtr = static_dest.data())==0)
|
||||
{
|
||||
freeDestPtr = true;
|
||||
actualDestPtr = ei_aligned_stack_new(ResScalar,dest.size());
|
||||
}
|
||||
if(!alphaIsCompatible)
|
||||
{
|
||||
MappedDest(actualDestPtr, dest.size()).setZero();
|
||||
compatibleAlpha = RhsScalar(1);
|
||||
}
|
||||
else
|
||||
MappedDest(actualDestPtr, dest.size()) = dest;
|
||||
}
|
||||
|
||||
internal::product_triangular_matrix_vector
|
||||
<Index,Mode,
|
||||
LhsScalar, LhsBlasTraits::NeedToConjugate,
|
||||
RhsScalar, RhsBlasTraits::NeedToConjugate,
|
||||
ColMajor>
|
||||
::run(actualLhs.rows(),actualLhs.cols(),
|
||||
actualLhs.data(),actualLhs.outerStride(),
|
||||
actualRhs.data(),actualRhs.innerStride(),
|
||||
actualDestPtr,1,compatibleAlpha);
|
||||
|
||||
if (!evalToDest)
|
||||
{
|
||||
if(!alphaIsCompatible)
|
||||
dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
|
||||
else
|
||||
dest = MappedDest(actualDestPtr, dest.size());
|
||||
if(freeDestPtr) ei_aligned_stack_delete(ResScalar, actualDestPtr, dest.size());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template<> struct trmv_selector<RowMajor>
|
||||
{
|
||||
template<int Mode, typename Lhs, typename Rhs, typename Dest>
|
||||
static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
|
||||
{
|
||||
typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
|
||||
typedef typename ProductType::LhsScalar LhsScalar;
|
||||
typedef typename ProductType::RhsScalar RhsScalar;
|
||||
typedef typename ProductType::Scalar ResScalar;
|
||||
typedef typename ProductType::Index Index;
|
||||
typedef typename ProductType::ActualLhsType ActualLhsType;
|
||||
typedef typename ProductType::ActualRhsType ActualRhsType;
|
||||
typedef typename ProductType::_ActualRhsType _ActualRhsType;
|
||||
typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
|
||||
typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
|
||||
|
||||
typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
|
||||
typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
|
||||
|
||||
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
|
||||
* RhsBlasTraits::extractScalarFactor(prod.rhs());
|
||||
|
||||
enum {
|
||||
DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
|
||||
};
|
||||
|
||||
gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
|
||||
|
||||
RhsScalar* actualRhsPtr;
|
||||
bool freeRhsPtr = false;
|
||||
if (DirectlyUseRhs)
|
||||
{
|
||||
actualRhsPtr = const_cast<RhsScalar*>(actualRhs.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||
int size = actualRhs.size();
|
||||
EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||
#endif
|
||||
if((actualRhsPtr = static_rhs.data())==0)
|
||||
{
|
||||
freeRhsPtr = true;
|
||||
actualRhsPtr = ei_aligned_stack_new(RhsScalar, actualRhs.size());
|
||||
}
|
||||
Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
|
||||
}
|
||||
|
||||
internal::product_triangular_matrix_vector
|
||||
<Index,Mode,
|
||||
LhsScalar, LhsBlasTraits::NeedToConjugate,
|
||||
RhsScalar, RhsBlasTraits::NeedToConjugate,
|
||||
RowMajor>
|
||||
::run(actualLhs.rows(),actualLhs.cols(),
|
||||
actualLhs.data(),actualLhs.outerStride(),
|
||||
actualRhsPtr,1,
|
||||
dest.data(),dest.innerStride(),
|
||||
actualAlpha);
|
||||
|
||||
if((!DirectlyUseRhs) && freeRhsPtr) ei_aligned_stack_delete(RhsScalar, actualRhsPtr, prod.rhs().size());
|
||||
}
|
||||
};
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
#endif // EIGEN_TRIANGULARMATRIXVECTOR_H
|
||||
|
@ -81,6 +81,36 @@ template<typename MatrixType> void nomalloc(const MatrixType& m)
|
||||
m2.row(0).noalias() -= m1.row(0) * m1.adjoint();
|
||||
m2.row(0).noalias() -= m1.col(0).adjoint() * m1;
|
||||
m2.row(0).noalias() -= m1.col(0).adjoint() * m1.adjoint();
|
||||
VERIFY_IS_APPROX(m2,m2);
|
||||
|
||||
m2.col(0).noalias() = m1.template triangularView<Upper>() * m1.col(0);
|
||||
m2.col(0).noalias() -= m1.adjoint().template triangularView<Upper>() * m1.col(0);
|
||||
m2.col(0).noalias() -= m1.template triangularView<Upper>() * m1.row(0).adjoint();
|
||||
m2.col(0).noalias() -= m1.adjoint().template triangularView<Upper>() * m1.row(0).adjoint();
|
||||
|
||||
m2.row(0).noalias() = m1.row(0) * m1.template triangularView<Upper>();
|
||||
m2.row(0).noalias() -= m1.row(0) * m1.adjoint().template triangularView<Upper>();
|
||||
m2.row(0).noalias() -= m1.col(0).adjoint() * m1.template triangularView<Upper>();
|
||||
m2.row(0).noalias() -= m1.col(0).adjoint() * m1.adjoint().template triangularView<Upper>();
|
||||
VERIFY_IS_APPROX(m2,m2);
|
||||
|
||||
m2.col(0).noalias() = m1.template selfadjointView<Upper>() * m1.col(0);
|
||||
m2.col(0).noalias() -= m1.adjoint().template selfadjointView<Upper>() * m1.col(0);
|
||||
m2.col(0).noalias() -= m1.template selfadjointView<Upper>() * m1.row(0).adjoint();
|
||||
m2.col(0).noalias() -= m1.adjoint().template selfadjointView<Upper>() * m1.row(0).adjoint();
|
||||
|
||||
m2.row(0).noalias() = m1.row(0) * m1.template selfadjointView<Upper>();
|
||||
m2.row(0).noalias() -= m1.row(0) * m1.adjoint().template selfadjointView<Upper>();
|
||||
m2.row(0).noalias() -= m1.col(0).adjoint() * m1.template selfadjointView<Upper>();
|
||||
m2.row(0).noalias() -= m1.col(0).adjoint() * m1.adjoint().template selfadjointView<Upper>();
|
||||
VERIFY_IS_APPROX(m2,m2);
|
||||
|
||||
// The following fancy matrix-matrix products are not safe yet regarding static allocation
|
||||
// m1 += m1.template triangularView<Upper>() * m2.col(;
|
||||
// m1.template selfadjointView<Lower>().rankUpdate(m2);
|
||||
// m1 += m1.template triangularView<Upper>() * m2;
|
||||
// m1 += m1.template selfadjointView<Lower>() * m2;
|
||||
// VERIFY_IS_APPROX(m1,m1);
|
||||
}
|
||||
|
||||
template<typename Scalar>
|
||||
|
Loading…
x
Reference in New Issue
Block a user