mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-21 17:19:36 +08:00
fix trmv regarding strided vectors and static allocation of temporaries
This commit is contained in:
parent
0fdd01fe24
commit
c60818fca8
@ -84,12 +84,14 @@ class ProductBase : public MatrixBase<Derived>
|
|||||||
typedef internal::blas_traits<_LhsNested> LhsBlasTraits;
|
typedef internal::blas_traits<_LhsNested> LhsBlasTraits;
|
||||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||||
typedef typename internal::remove_all<ActualLhsType>::type _ActualLhsType;
|
typedef typename internal::remove_all<ActualLhsType>::type _ActualLhsType;
|
||||||
|
typedef typename internal::traits<Lhs>::Scalar LhsScalar;
|
||||||
|
|
||||||
typedef typename Rhs::Nested RhsNested;
|
typedef typename Rhs::Nested RhsNested;
|
||||||
typedef typename internal::remove_all<RhsNested>::type _RhsNested;
|
typedef typename internal::remove_all<RhsNested>::type _RhsNested;
|
||||||
typedef internal::blas_traits<_RhsNested> RhsBlasTraits;
|
typedef internal::blas_traits<_RhsNested> RhsBlasTraits;
|
||||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhsType;
|
||||||
typedef typename internal::remove_all<ActualRhsType>::type _ActualRhsType;
|
typedef typename internal::remove_all<ActualRhsType>::type _ActualRhsType;
|
||||||
|
typedef typename internal::traits<Rhs>::Scalar RhsScalar;
|
||||||
|
|
||||||
// Diagonal of a product: no need to evaluate the arguments because they are going to be evaluated only once
|
// Diagonal of a product: no need to evaluate the arguments because they are going to be evaluated only once
|
||||||
typedef CoeffBasedProduct<LhsNested, RhsNested, 0> FullyLazyCoeffBaseProductType;
|
typedef CoeffBasedProduct<LhsNested, RhsNested, 0> FullyLazyCoeffBaseProductType;
|
||||||
|
@ -152,6 +152,10 @@ struct traits<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false> >
|
|||||||
: traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
|
: traits<ProductBase<TriangularProduct<Mode,LhsIsTriangular,Lhs,true,Rhs,false>, Lhs, Rhs> >
|
||||||
{};
|
{};
|
||||||
|
|
||||||
|
|
||||||
|
template<int StorageOrder>
|
||||||
|
struct trmv_selector;
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
|
|
||||||
template<int Mode, typename Lhs, typename Rhs>
|
template<int Mode, typename Lhs, typename Rhs>
|
||||||
@ -166,19 +170,7 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
|
|||||||
{
|
{
|
||||||
eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
|
eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
|
||||||
|
|
||||||
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
|
internal::trmv_selector<(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dst, alpha);
|
||||||
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
|
|
||||||
|
|
||||||
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
|
|
||||||
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
|
||||||
|
|
||||||
internal::product_triangular_matrix_vector
|
|
||||||
<Index,Mode,
|
|
||||||
typename _ActualLhsType::Scalar, LhsBlasTraits::NeedToConjugate,
|
|
||||||
typename _ActualRhsType::Scalar, RhsBlasTraits::NeedToConjugate,
|
|
||||||
(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>
|
|
||||||
::run(lhs.rows(),lhs.cols(),lhs.data(),lhs.outerStride(),rhs.data(),rhs.innerStride(),
|
|
||||||
dst.data(),dst.innerStride(),actualAlpha);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -192,23 +184,167 @@ struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
|
|||||||
|
|
||||||
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
|
template<typename Dest> void scaleAndAddTo(Dest& dst, Scalar alpha) const
|
||||||
{
|
{
|
||||||
|
|
||||||
eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
|
eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
|
||||||
|
|
||||||
const ActualLhsType lhs = LhsBlasTraits::extract(m_lhs);
|
typedef TriangularProduct<(Mode & UnitDiag) | ((Mode & Lower) ? Upper : Lower),true,Transpose<const Rhs>,false,Transpose<const Lhs>,true> TriangularProductTranspose;
|
||||||
const ActualRhsType rhs = RhsBlasTraits::extract(m_rhs);
|
Transpose<Dest> dstT(dst);
|
||||||
|
internal::trmv_selector<(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>::run(
|
||||||
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
|
TriangularProductTranspose(m_rhs.transpose(),m_lhs.transpose()), dstT, alpha);
|
||||||
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
|
||||||
|
|
||||||
internal::product_triangular_matrix_vector
|
|
||||||
<Index,(Mode & UnitDiag) | ((Mode & Lower) ? Upper : Lower),
|
|
||||||
typename _ActualRhsType::Scalar, RhsBlasTraits::NeedToConjugate,
|
|
||||||
typename _ActualLhsType::Scalar, LhsBlasTraits::NeedToConjugate,
|
|
||||||
(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>
|
|
||||||
::run(rhs.rows(),rhs.cols(),rhs.data(),rhs.outerStride(),lhs.data(),lhs.innerStride(),
|
|
||||||
dst.data(),dst.innerStride(),actualAlpha);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
|
||||||
|
// TODO: find a way to share this piece of code with gemv_selector, since the logic is exactly the same.
|
||||||
|
|
||||||
|
template<> struct trmv_selector<ColMajor>
|
||||||
|
{
|
||||||
|
template<int Mode, typename Lhs, typename Rhs, typename Dest>
|
||||||
|
static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
|
||||||
|
{
|
||||||
|
typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
|
||||||
|
typedef typename ProductType::Index Index;
|
||||||
|
typedef typename ProductType::LhsScalar LhsScalar;
|
||||||
|
typedef typename ProductType::RhsScalar RhsScalar;
|
||||||
|
typedef typename ProductType::Scalar ResScalar;
|
||||||
|
typedef typename ProductType::RealScalar RealScalar;
|
||||||
|
typedef typename ProductType::ActualLhsType ActualLhsType;
|
||||||
|
typedef typename ProductType::ActualRhsType ActualRhsType;
|
||||||
|
typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
|
||||||
|
typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
|
||||||
|
typedef Map<Matrix<ResScalar,Dynamic,1>, Aligned> MappedDest;
|
||||||
|
|
||||||
|
const ActualLhsType actualLhs = LhsBlasTraits::extract(prod.lhs());
|
||||||
|
const ActualRhsType actualRhs = RhsBlasTraits::extract(prod.rhs());
|
||||||
|
|
||||||
|
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
|
||||||
|
* RhsBlasTraits::extractScalarFactor(prod.rhs());
|
||||||
|
|
||||||
|
enum {
|
||||||
|
// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
|
||||||
|
// on, the other hand it is good for the cache to pack the vector anyways...
|
||||||
|
EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
|
||||||
|
ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
|
||||||
|
MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
|
||||||
|
};
|
||||||
|
|
||||||
|
gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
|
||||||
|
|
||||||
|
bool alphaIsCompatible = (!ComplexByReal) || (imag(actualAlpha)==RealScalar(0));
|
||||||
|
bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
|
||||||
|
|
||||||
|
RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
|
||||||
|
|
||||||
|
ResScalar* actualDestPtr;
|
||||||
|
bool freeDestPtr = false;
|
||||||
|
if (evalToDest)
|
||||||
|
{
|
||||||
|
actualDestPtr = dest.data();
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||||
|
int size = dest.size();
|
||||||
|
EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||||
|
#endif
|
||||||
|
if((actualDestPtr = static_dest.data())==0)
|
||||||
|
{
|
||||||
|
freeDestPtr = true;
|
||||||
|
actualDestPtr = ei_aligned_stack_new(ResScalar,dest.size());
|
||||||
|
}
|
||||||
|
if(!alphaIsCompatible)
|
||||||
|
{
|
||||||
|
MappedDest(actualDestPtr, dest.size()).setZero();
|
||||||
|
compatibleAlpha = RhsScalar(1);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
MappedDest(actualDestPtr, dest.size()) = dest;
|
||||||
|
}
|
||||||
|
|
||||||
|
internal::product_triangular_matrix_vector
|
||||||
|
<Index,Mode,
|
||||||
|
LhsScalar, LhsBlasTraits::NeedToConjugate,
|
||||||
|
RhsScalar, RhsBlasTraits::NeedToConjugate,
|
||||||
|
ColMajor>
|
||||||
|
::run(actualLhs.rows(),actualLhs.cols(),
|
||||||
|
actualLhs.data(),actualLhs.outerStride(),
|
||||||
|
actualRhs.data(),actualRhs.innerStride(),
|
||||||
|
actualDestPtr,1,compatibleAlpha);
|
||||||
|
|
||||||
|
if (!evalToDest)
|
||||||
|
{
|
||||||
|
if(!alphaIsCompatible)
|
||||||
|
dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
|
||||||
|
else
|
||||||
|
dest = MappedDest(actualDestPtr, dest.size());
|
||||||
|
if(freeDestPtr) ei_aligned_stack_delete(ResScalar, actualDestPtr, dest.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<> struct trmv_selector<RowMajor>
|
||||||
|
{
|
||||||
|
template<int Mode, typename Lhs, typename Rhs, typename Dest>
|
||||||
|
static void run(const TriangularProduct<Mode,true,Lhs,false,Rhs,true>& prod, Dest& dest, typename TriangularProduct<Mode,true,Lhs,false,Rhs,true>::Scalar alpha)
|
||||||
|
{
|
||||||
|
typedef TriangularProduct<Mode,true,Lhs,false,Rhs,true> ProductType;
|
||||||
|
typedef typename ProductType::LhsScalar LhsScalar;
|
||||||
|
typedef typename ProductType::RhsScalar RhsScalar;
|
||||||
|
typedef typename ProductType::Scalar ResScalar;
|
||||||
|
typedef typename ProductType::Index Index;
|
||||||
|
typedef typename ProductType::ActualLhsType ActualLhsType;
|
||||||
|
typedef typename ProductType::ActualRhsType ActualRhsType;
|
||||||
|
typedef typename ProductType::_ActualRhsType _ActualRhsType;
|
||||||
|
typedef typename ProductType::LhsBlasTraits LhsBlasTraits;
|
||||||
|
typedef typename ProductType::RhsBlasTraits RhsBlasTraits;
|
||||||
|
|
||||||
|
typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
|
||||||
|
typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
|
||||||
|
|
||||||
|
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs())
|
||||||
|
* RhsBlasTraits::extractScalarFactor(prod.rhs());
|
||||||
|
|
||||||
|
enum {
|
||||||
|
DirectlyUseRhs = _ActualRhsType::InnerStrideAtCompileTime==1
|
||||||
|
};
|
||||||
|
|
||||||
|
gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
|
||||||
|
|
||||||
|
RhsScalar* actualRhsPtr;
|
||||||
|
bool freeRhsPtr = false;
|
||||||
|
if (DirectlyUseRhs)
|
||||||
|
{
|
||||||
|
actualRhsPtr = const_cast<RhsScalar*>(actualRhs.data());
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
#ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||||
|
int size = actualRhs.size();
|
||||||
|
EIGEN_DENSE_STORAGE_CTOR_PLUGIN
|
||||||
|
#endif
|
||||||
|
if((actualRhsPtr = static_rhs.data())==0)
|
||||||
|
{
|
||||||
|
freeRhsPtr = true;
|
||||||
|
actualRhsPtr = ei_aligned_stack_new(RhsScalar, actualRhs.size());
|
||||||
|
}
|
||||||
|
Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
internal::product_triangular_matrix_vector
|
||||||
|
<Index,Mode,
|
||||||
|
LhsScalar, LhsBlasTraits::NeedToConjugate,
|
||||||
|
RhsScalar, RhsBlasTraits::NeedToConjugate,
|
||||||
|
RowMajor>
|
||||||
|
::run(actualLhs.rows(),actualLhs.cols(),
|
||||||
|
actualLhs.data(),actualLhs.outerStride(),
|
||||||
|
actualRhsPtr,1,
|
||||||
|
dest.data(),dest.innerStride(),
|
||||||
|
actualAlpha);
|
||||||
|
|
||||||
|
if((!DirectlyUseRhs) && freeRhsPtr) ei_aligned_stack_delete(RhsScalar, actualRhsPtr, prod.rhs().size());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace internal
|
||||||
|
|
||||||
#endif // EIGEN_TRIANGULARMATRIXVECTOR_H
|
#endif // EIGEN_TRIANGULARMATRIXVECTOR_H
|
||||||
|
@ -81,6 +81,36 @@ template<typename MatrixType> void nomalloc(const MatrixType& m)
|
|||||||
m2.row(0).noalias() -= m1.row(0) * m1.adjoint();
|
m2.row(0).noalias() -= m1.row(0) * m1.adjoint();
|
||||||
m2.row(0).noalias() -= m1.col(0).adjoint() * m1;
|
m2.row(0).noalias() -= m1.col(0).adjoint() * m1;
|
||||||
m2.row(0).noalias() -= m1.col(0).adjoint() * m1.adjoint();
|
m2.row(0).noalias() -= m1.col(0).adjoint() * m1.adjoint();
|
||||||
|
VERIFY_IS_APPROX(m2,m2);
|
||||||
|
|
||||||
|
m2.col(0).noalias() = m1.template triangularView<Upper>() * m1.col(0);
|
||||||
|
m2.col(0).noalias() -= m1.adjoint().template triangularView<Upper>() * m1.col(0);
|
||||||
|
m2.col(0).noalias() -= m1.template triangularView<Upper>() * m1.row(0).adjoint();
|
||||||
|
m2.col(0).noalias() -= m1.adjoint().template triangularView<Upper>() * m1.row(0).adjoint();
|
||||||
|
|
||||||
|
m2.row(0).noalias() = m1.row(0) * m1.template triangularView<Upper>();
|
||||||
|
m2.row(0).noalias() -= m1.row(0) * m1.adjoint().template triangularView<Upper>();
|
||||||
|
m2.row(0).noalias() -= m1.col(0).adjoint() * m1.template triangularView<Upper>();
|
||||||
|
m2.row(0).noalias() -= m1.col(0).adjoint() * m1.adjoint().template triangularView<Upper>();
|
||||||
|
VERIFY_IS_APPROX(m2,m2);
|
||||||
|
|
||||||
|
m2.col(0).noalias() = m1.template selfadjointView<Upper>() * m1.col(0);
|
||||||
|
m2.col(0).noalias() -= m1.adjoint().template selfadjointView<Upper>() * m1.col(0);
|
||||||
|
m2.col(0).noalias() -= m1.template selfadjointView<Upper>() * m1.row(0).adjoint();
|
||||||
|
m2.col(0).noalias() -= m1.adjoint().template selfadjointView<Upper>() * m1.row(0).adjoint();
|
||||||
|
|
||||||
|
m2.row(0).noalias() = m1.row(0) * m1.template selfadjointView<Upper>();
|
||||||
|
m2.row(0).noalias() -= m1.row(0) * m1.adjoint().template selfadjointView<Upper>();
|
||||||
|
m2.row(0).noalias() -= m1.col(0).adjoint() * m1.template selfadjointView<Upper>();
|
||||||
|
m2.row(0).noalias() -= m1.col(0).adjoint() * m1.adjoint().template selfadjointView<Upper>();
|
||||||
|
VERIFY_IS_APPROX(m2,m2);
|
||||||
|
|
||||||
|
// The following fancy matrix-matrix products are not safe yet regarding static allocation
|
||||||
|
// m1 += m1.template triangularView<Upper>() * m2.col(;
|
||||||
|
// m1.template selfadjointView<Lower>().rankUpdate(m2);
|
||||||
|
// m1 += m1.template triangularView<Upper>() * m2;
|
||||||
|
// m1 += m1.template selfadjointView<Lower>() * m2;
|
||||||
|
// VERIFY_IS_APPROX(m1,m1);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Scalar>
|
template<typename Scalar>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user