mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 03:09:01 +08:00
make trmv uses direct access
This commit is contained in:
parent
437dff80ee
commit
d72a8f1e50
@ -27,43 +27,39 @@
|
|||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
template<bool LhsIsTriangular, typename Lhs, typename Rhs, typename Result,
|
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs, int StorageOrder>
|
||||||
int Mode, bool ConjLhs, bool ConjRhs, int StorageOrder>
|
|
||||||
struct product_triangular_vector_selector;
|
struct product_triangular_vector_selector;
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs, typename Result, int Mode, bool ConjLhs, bool ConjRhs, int StorageOrder>
|
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs>
|
||||||
struct product_triangular_vector_selector<false,Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs,StorageOrder>
|
struct product_triangular_vector_selector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,ColMajor>
|
||||||
{
|
{
|
||||||
static EIGEN_DONT_INLINE void run(const Lhs& lhs, const Rhs& rhs, Result& res, typename traits<Lhs>::Scalar alpha)
|
typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
|
||||||
{
|
|
||||||
typedef Transpose<Rhs> TrRhs; TrRhs trRhs(rhs);
|
|
||||||
typedef Transpose<Lhs> TrLhs; TrLhs trLhs(lhs);
|
|
||||||
typedef Transpose<Result> TrRes; TrRes trRes(res);
|
|
||||||
product_triangular_vector_selector<true,TrRhs,TrLhs,TrRes,
|
|
||||||
(Mode & UnitDiag) | (Mode & Lower) ? Upper : Lower, ConjRhs, ConjLhs, StorageOrder==RowMajor ? ColMajor : RowMajor>
|
|
||||||
::run(trRhs,trLhs,trRes,alpha);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs, typename Result, int Mode, bool ConjLhs, bool ConjRhs>
|
|
||||||
struct product_triangular_vector_selector<true,Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs,ColMajor>
|
|
||||||
{
|
|
||||||
typedef typename Rhs::Scalar Scalar;
|
|
||||||
typedef typename Rhs::Index Index;
|
|
||||||
enum {
|
enum {
|
||||||
IsLower = ((Mode&Lower)==Lower),
|
IsLower = ((Mode&Lower)==Lower),
|
||||||
HasUnitDiag = (Mode & UnitDiag)==UnitDiag
|
HasUnitDiag = (Mode & UnitDiag)==UnitDiag
|
||||||
};
|
};
|
||||||
static EIGEN_DONT_INLINE void run(const Lhs& lhs, const Rhs& rhs, Result& res, typename traits<Lhs>::Scalar alpha)
|
static EIGEN_DONT_INLINE void run(Index rows, Index cols, const LhsScalar* _lhs, Index lhsStride,
|
||||||
|
const RhsScalar* _rhs, Index rhsIncr, const ResScalar* _res, Index resIncr, ResScalar alpha)
|
||||||
{
|
{
|
||||||
|
EIGEN_UNUSED_VARIABLE(resIncr);
|
||||||
|
eigen_assert(resIncr==1);
|
||||||
|
|
||||||
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
|
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
|
||||||
typename conj_expr_if<ConjLhs,Lhs>::type cjLhs(lhs);
|
|
||||||
typename conj_expr_if<ConjRhs,Rhs>::type cjRhs(rhs);
|
|
||||||
|
|
||||||
Index size = lhs.cols();
|
typedef Map<Matrix<LhsScalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > LhsMap;
|
||||||
for (Index pi=0; pi<size; pi+=PanelWidth)
|
const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
|
||||||
|
typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
|
||||||
|
|
||||||
|
typedef Map<Matrix<RhsScalar,Dynamic,1>, 0, InnerStride<> > RhsMap;
|
||||||
|
const RhsMap rhs(_rhs,cols,InnerStride<>(rhsIncr));
|
||||||
|
typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
|
||||||
|
|
||||||
|
typedef Map<Matrix<ResScalar,Dynamic,1> > ResMap;
|
||||||
|
ResMap res(_res,rows);
|
||||||
|
|
||||||
|
for (Index pi=0; pi<cols; pi+=PanelWidth)
|
||||||
{
|
{
|
||||||
Index actualPanelWidth = std::min(PanelWidth, size-pi);
|
Index actualPanelWidth = std::min(PanelWidth, cols-pi);
|
||||||
for (Index k=0; k<actualPanelWidth; ++k)
|
for (Index k=0; k<actualPanelWidth; ++k)
|
||||||
{
|
{
|
||||||
Index i = pi + k;
|
Index i = pi + k;
|
||||||
@ -74,38 +70,50 @@ struct product_triangular_vector_selector<true,Lhs,Rhs,Result,Mode,ConjLhs,ConjR
|
|||||||
if (HasUnitDiag)
|
if (HasUnitDiag)
|
||||||
res.coeffRef(i) += alpha * cjRhs.coeff(i);
|
res.coeffRef(i) += alpha * cjRhs.coeff(i);
|
||||||
}
|
}
|
||||||
Index r = IsLower ? size - pi - actualPanelWidth : pi;
|
Index r = IsLower ? cols - pi - actualPanelWidth : pi;
|
||||||
if (r>0)
|
if (r>0)
|
||||||
{
|
{
|
||||||
Index s = IsLower ? pi+actualPanelWidth : 0;
|
Index s = IsLower ? pi+actualPanelWidth : 0;
|
||||||
general_matrix_vector_product<Index,Scalar,ColMajor,ConjLhs,Scalar,ConjRhs>::run(
|
general_matrix_vector_product<Index,LhsScalar,ColMajor,ConjLhs,RhsScalar,ConjRhs>::run(
|
||||||
r, actualPanelWidth,
|
r, actualPanelWidth,
|
||||||
&(lhs.const_cast_derived().coeffRef(s,pi)), lhs.outerStride(),
|
&lhs.coeff(s,pi), lhsStride,
|
||||||
&rhs.coeff(pi), rhs.innerStride(),
|
&rhs.coeff(pi), rhsIncr,
|
||||||
&res.coeffRef(s), res.innerStride(), alpha);
|
&res.coeffRef(s), resIncr, alpha);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename Lhs, typename Rhs, typename Result, int Mode, bool ConjLhs, bool ConjRhs>
|
template<typename Index, int Mode, typename LhsScalar, bool ConjLhs, typename RhsScalar, bool ConjRhs>
|
||||||
struct product_triangular_vector_selector<true,Lhs,Rhs,Result,Mode,ConjLhs,ConjRhs,RowMajor>
|
struct product_triangular_vector_selector<Index,Mode,LhsScalar,ConjLhs,RhsScalar,ConjRhs,RowMajor>
|
||||||
{
|
{
|
||||||
typedef typename Rhs::Scalar Scalar;
|
typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
|
||||||
typedef typename Rhs::Index Index;
|
|
||||||
enum {
|
enum {
|
||||||
IsLower = ((Mode&Lower)==Lower),
|
IsLower = ((Mode&Lower)==Lower),
|
||||||
HasUnitDiag = (Mode & UnitDiag)==UnitDiag
|
HasUnitDiag = (Mode & UnitDiag)==UnitDiag
|
||||||
};
|
};
|
||||||
static void run(const Lhs& lhs, const Rhs& rhs, Result& res, typename traits<Lhs>::Scalar alpha)
|
static void run(Index rows, Index cols, const LhsScalar* _lhs, Index lhsStride,
|
||||||
|
const RhsScalar* _rhs, Index rhsIncr, const ResScalar* _res, Index resIncr, ResScalar alpha)
|
||||||
{
|
{
|
||||||
|
eigen_assert(rhsIncr==1);
|
||||||
|
EIGEN_UNUSED_VARIABLE(rhsIncr);
|
||||||
|
|
||||||
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
|
static const Index PanelWidth = EIGEN_TUNE_TRIANGULAR_PANEL_WIDTH;
|
||||||
typename conj_expr_if<ConjLhs,Lhs>::type cjLhs(lhs);
|
|
||||||
typename conj_expr_if<ConjRhs,Rhs>::type cjRhs(rhs);
|
typedef Map<Matrix<LhsScalar,Dynamic,Dynamic,RowMajor>, 0, OuterStride<> > LhsMap;
|
||||||
Index size = lhs.cols();
|
const LhsMap lhs(_lhs,rows,cols,OuterStride<>(lhsStride));
|
||||||
for (Index pi=0; pi<size; pi+=PanelWidth)
|
typename conj_expr_if<ConjLhs,LhsMap>::type cjLhs(lhs);
|
||||||
|
|
||||||
|
typedef Map<Matrix<RhsScalar,Dynamic,1> > RhsMap;
|
||||||
|
const RhsMap rhs(_rhs,cols);
|
||||||
|
typename conj_expr_if<ConjRhs,RhsMap>::type cjRhs(rhs);
|
||||||
|
|
||||||
|
typedef Map<Matrix<ResScalar,Dynamic,1>, 0, InnerStride<> > ResMap;
|
||||||
|
ResMap res(_res,rows,InnerStride<>(resIncr));
|
||||||
|
|
||||||
|
for (Index pi=0; pi<cols; pi+=PanelWidth)
|
||||||
{
|
{
|
||||||
Index actualPanelWidth = std::min(PanelWidth, size-pi);
|
Index actualPanelWidth = std::min(PanelWidth, cols-pi);
|
||||||
for (Index k=0; k<actualPanelWidth; ++k)
|
for (Index k=0; k<actualPanelWidth; ++k)
|
||||||
{
|
{
|
||||||
Index i = pi + k;
|
Index i = pi + k;
|
||||||
@ -116,15 +124,15 @@ struct product_triangular_vector_selector<true,Lhs,Rhs,Result,Mode,ConjLhs,ConjR
|
|||||||
if (HasUnitDiag)
|
if (HasUnitDiag)
|
||||||
res.coeffRef(i) += alpha * cjRhs.coeff(i);
|
res.coeffRef(i) += alpha * cjRhs.coeff(i);
|
||||||
}
|
}
|
||||||
Index r = IsLower ? pi : size - pi - actualPanelWidth;
|
Index r = IsLower ? pi : cols - pi - actualPanelWidth;
|
||||||
if (r>0)
|
if (r>0)
|
||||||
{
|
{
|
||||||
Index s = IsLower ? 0 : pi + actualPanelWidth;
|
Index s = IsLower ? 0 : pi + actualPanelWidth;
|
||||||
general_matrix_vector_product<Index,Scalar,RowMajor,ConjLhs,Scalar,ConjRhs>::run(
|
general_matrix_vector_product<Index,LhsScalar,RowMajor,ConjLhs,RhsScalar,ConjRhs>::run(
|
||||||
actualPanelWidth, r,
|
actualPanelWidth, r,
|
||||||
&(lhs.const_cast_derived().coeffRef(pi,s)), lhs.outerStride(),
|
&(lhs.coeff(pi,s)), lhsStride,
|
||||||
&(rhs.const_cast_derived().coeffRef(s)), 1,
|
&(rhs.coeff(s)), rhsIncr,
|
||||||
&res.coeffRef(pi,0), res.innerStride(), alpha);
|
&res.coeffRef(pi), resIncr, alpha);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -165,12 +173,11 @@ struct TriangularProduct<Mode,true,Lhs,false,Rhs,true>
|
|||||||
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
||||||
|
|
||||||
internal::product_triangular_vector_selector
|
internal::product_triangular_vector_selector
|
||||||
<true,_ActualLhsType,_ActualRhsType,Dest,
|
<Index,Mode,
|
||||||
Mode,
|
typename _ActualLhsType::Scalar, LhsBlasTraits::NeedToConjugate,
|
||||||
LhsBlasTraits::NeedToConjugate,
|
typename _ActualRhsType::Scalar, RhsBlasTraits::NeedToConjugate,
|
||||||
RhsBlasTraits::NeedToConjugate,
|
|
||||||
(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>
|
(int(internal::traits<Lhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>
|
||||||
::run(lhs,rhs,dst,actualAlpha);
|
::run(lhs.rows(),lhs.cols(),lhs.data(),lhs.outerStride(),rhs.data(),rhs.innerStride(),dst.data(),dst.innerStride(),actualAlpha);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -194,12 +201,12 @@ struct TriangularProduct<Mode,false,Lhs,true,Rhs,false>
|
|||||||
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
* RhsBlasTraits::extractScalarFactor(m_rhs);
|
||||||
|
|
||||||
internal::product_triangular_vector_selector
|
internal::product_triangular_vector_selector
|
||||||
<false,_ActualLhsType,_ActualRhsType,Dest,
|
<Index,(Mode & UnitDiag) | (Mode & Lower) ? Upper : Lower,
|
||||||
Mode,
|
typename _ActualRhsType::Scalar, RhsBlasTraits::NeedToConjugate,
|
||||||
LhsBlasTraits::NeedToConjugate,
|
typename _ActualLhsType::Scalar, LhsBlasTraits::NeedToConjugate,
|
||||||
RhsBlasTraits::NeedToConjugate,
|
(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? ColMajor : RowMajor>
|
||||||
(int(internal::traits<Rhs>::Flags)&RowMajorBit) ? RowMajor : ColMajor>
|
::run(rhs.rows(),rhs.cols(),rhs.data(),rhs.outerStride(),lhs.data(),lhs.innerStride(),
|
||||||
::run(lhs,rhs,dst,actualAlpha);
|
dst.data(),dst.innerStride(),actualAlpha);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user