Mirror of https://gitlab.com/libeigen/eigen.git (synced 2025-06-04 18:54:00 +08:00)

fix dynamic allocation for fixed size objects in matrix-vector product

parent: 5ca407de54
commit: f46ace61d3
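In outline: the temporaries used by Eigen's matrix-vector (gemv) kernels were allocated at run time via ei_aligned_stack_new, which can hit the heap, even when the destination or right-hand side has a size known at compile time. The commit adds a small helper, gemv_static_vector_if, that supplies a stack buffer whenever the maximum size is a compile-time constant, and signals dynamic fallback by returning a null pointer otherwise. A minimal sketch of the idiom with hypothetical names (not Eigen's):

const int Dynamic = -1;           // same sentinel value Eigen uses

template<typename Scalar, int MaxSize>
struct maybe_static_buffer {      // max size known: embed a stack array
  Scalar m_data[MaxSize];
  Scalar* data() { return m_data; }
};

template<typename Scalar>
struct maybe_static_buffer<Scalar, Dynamic> {
  Scalar* data() { return 0; }    // size unknown: caller must allocate
};

template<int MaxSize>
void run_kernel(int n)            // n == MaxSize unless MaxSize == Dynamic
{
  maybe_static_buffer<float, MaxSize> buf;
  float* ptr = buf.data();
  bool mustFree = false;
  if (ptr == 0) { ptr = new float[n]; mustFree = true; }  // dynamic fallback
  // ... compute into ptr ...
  if (mustFree) delete[] ptr;
}

The diff below follows exactly this shape: a freeDestPtr/freeRhsPtr flag records whether the fallback fired, and ei_aligned_stack_delete runs only in that case.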
@@ -358,10 +358,31 @@ struct gemv_selector<OnTheLeft,StorageOrder,BlasCompatible>
   }
 };
 
+template<typename Scalar,int Size,int MaxSize,bool Cond> struct gemv_static_vector_if;
+
+template<typename Scalar,int Size,int MaxSize>
+struct gemv_static_vector_if<Scalar,Size,MaxSize,false>
+{
+  EIGEN_STRONG_INLINE Scalar* data() { eigen_internal_assert(false && "should never be called"); return 0; }
+};
+
+template<typename Scalar,int Size>
+struct gemv_static_vector_if<Scalar,Size,Dynamic,true>
+{
+  EIGEN_STRONG_INLINE Scalar* data() { return 0; }
+};
+
+template<typename Scalar,int Size,int MaxSize>
+struct gemv_static_vector_if<Scalar,Size,MaxSize,true>
+{
+  Scalar m_data[EIGEN_SIZE_MIN_PREFER_FIXED(Size,MaxSize)];
+  EIGEN_STRONG_INLINE Scalar* data() { return m_data; }
+};
+
 template<> struct gemv_selector<OnTheRight,ColMajor,true>
 {
   template<typename ProductType, typename Dest>
-  static void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha)
+  static inline void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha)
   {
     typedef typename ProductType::Index Index;
     typedef typename ProductType::LhsScalar LhsScalar;
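Taken together, the three specializations resolve at compile time: Cond==false means a scratch buffer can never be needed (data() must not be called, hence the assert), a true condition with MaxSize==Dynamic returns a null pointer so the caller falls back to dynamic allocation, and a true condition with a finite MaxSize embeds a stack array of EIGEN_SIZE_MIN_PREFER_FIXED(Size,MaxSize) elements, i.e. whichever of the two bounds is fixed. A usage sketch, assuming the definitions above are in scope (they sit in Eigen's internal machinery) and Dynamic == -1:

int main()
{
  // Fixed-size destination (think Vector4f): the <...,MaxSize,true>
  // specialization supplies stack storage, so no allocation is needed.
  gemv_static_vector_if<float,4,4,true> fixed_buf;
  float* p = fixed_buf.data();      // valid stack pointer

  // Dynamically sized destination (think VectorXf): data() returns 0,
  // which tells the caller to fall back to ei_aligned_stack_new.
  gemv_static_vector_if<float,Dynamic,Dynamic,true> dyn_buf;
  float* q = dyn_buf.data();        // null: dynamic fallback required

  return (p != 0 && q == 0) ? 0 : 1;
}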
@@ -382,30 +403,43 @@ template<> struct gemv_selector<OnTheRight,ColMajor,true>
 
     enum {
       // FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
+      // on, the other hand it is good for the cache to pack the vector anyways...
       EvalToDestAtCompileTime = Dest::InnerStrideAtCompileTime==1,
-      ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex)
+      ComplexByReal = (NumTraits<LhsScalar>::IsComplex) && (!NumTraits<RhsScalar>::IsComplex),
+      MightCannotUseDest = (Dest::InnerStrideAtCompileTime!=1) || ComplexByReal
     };
 
+    gemv_static_vector_if<ResScalar,Dest::SizeAtCompileTime,Dest::MaxSizeAtCompileTime,MightCannotUseDest> static_dest;
+
     bool alphaIsCompatible = (!ComplexByReal) || (imag(actualAlpha)==RealScalar(0));
     bool evalToDest = EvalToDestAtCompileTime && alphaIsCompatible;
 
     RhsScalar compatibleAlpha = get_factor<ResScalar,RhsScalar>::run(actualAlpha);
 
-    ResScalar* actualDest;
+    ResScalar* actualDestPtr;
+    bool freeDestPtr = false;
     if (evalToDest)
     {
-      actualDest = &dest.coeffRef(0);
+      actualDestPtr = &dest.coeffRef(0);
     }
     else
    {
-      actualDest = ei_aligned_stack_new(ResScalar,dest.size());
+      #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
+      int size = dest.size();
+      EIGEN_DENSE_STORAGE_CTOR_PLUGIN
+      #endif
+      if((actualDestPtr = static_dest.data())==0)
+      {
+        freeDestPtr = true;
+        actualDestPtr = ei_aligned_stack_new(ResScalar,dest.size());
+      }
       if(!alphaIsCompatible)
       {
-        MappedDest(actualDest, dest.size()).setZero();
+        MappedDest(actualDestPtr, dest.size()).setZero();
         compatibleAlpha = RhsScalar(1);
       }
       else
-        MappedDest(actualDest, dest.size()) = dest;
+        MappedDest(actualDestPtr, dest.size()) = dest;
     }
 
     general_matrix_vector_product
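The #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN block mirrors the hook that Eigen's DenseStorage constructors expose: it brings the temporary's length into scope as a local variable named size, then expands a user-definable macro, so a test can count the temporaries created on this fallback path. A sketch of how a test translation unit might define the hook (the counter is illustrative, not the test suite's exact code):

// The plugin must be defined before any Eigen header is included.
static long g_temporaries = 0;
#define EIGEN_DENSE_STORAGE_CTOR_PLUGIN { if (size != 0) ++g_temporaries; }
#include <Eigen/Core>
// From here on, every code path that expands the plugin with a non-zero
// size -- including the gemv fallback above -- bumps g_temporaries.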
@@ -413,16 +447,16 @@ template<> struct gemv_selector<OnTheRight,ColMajor,true>
       actualLhs.rows(), actualLhs.cols(),
       &actualLhs.coeffRef(0,0), actualLhs.outerStride(),
       actualRhs.data(), actualRhs.innerStride(),
-      actualDest, 1,
+      actualDestPtr, 1,
       compatibleAlpha);
 
     if (!evalToDest)
     {
       if(!alphaIsCompatible)
-        dest += actualAlpha * MappedDest(actualDest, dest.size());
+        dest += actualAlpha * MappedDest(actualDestPtr, dest.size());
       else
-        dest = MappedDest(actualDest, dest.size());
-      ei_aligned_stack_delete(ResScalar, actualDest, dest.size());
+        dest = MappedDest(actualDestPtr, dest.size());
+      if(freeDestPtr) ei_aligned_stack_delete(ResScalar, actualDestPtr, dest.size());
     }
   }
 };
@@ -455,24 +489,37 @@ template<> struct gemv_selector<OnTheRight,RowMajor,true>
           && (!(_ActualRhsType::Flags & RowMajorBit))
     };
 
-    RhsScalar* rhs_data;
+    gemv_static_vector_if<RhsScalar,_ActualRhsType::SizeAtCompileTime,_ActualRhsType::MaxSizeAtCompileTime,!DirectlyUseRhs> static_rhs;
+
+    RhsScalar* actualRhsPtr;
+    bool freeRhsPtr = false;
     if (DirectlyUseRhs)
-      rhs_data = const_cast<RhsScalar*>(&actualRhs.coeffRef(0));
+    {
+      actualRhsPtr = const_cast<RhsScalar*>(&actualRhs.coeffRef(0));
+    }
     else
     {
-      rhs_data = ei_aligned_stack_new(RhsScalar, actualRhs.size());
-      Map<typename _ActualRhsType::PlainObject>(rhs_data, actualRhs.size()) = actualRhs;
+      #ifdef EIGEN_DENSE_STORAGE_CTOR_PLUGIN
+      int size = actualRhs.size();
+      EIGEN_DENSE_STORAGE_CTOR_PLUGIN
+      #endif
+      if((actualRhsPtr = static_rhs.data())==0)
+      {
+        freeRhsPtr = true;
+        actualRhsPtr = ei_aligned_stack_new(RhsScalar, actualRhs.size());
+      }
+      Map<typename _ActualRhsType::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
     }
 
     general_matrix_vector_product
       <Index,LhsScalar,RowMajor,LhsBlasTraits::NeedToConjugate,RhsScalar,RhsBlasTraits::NeedToConjugate>::run(
         actualLhs.rows(), actualLhs.cols(),
         &actualLhs.coeffRef(0,0), actualLhs.outerStride(),
-        rhs_data, 1,
+        actualRhsPtr, 1,
         &dest.coeffRef(0,0), dest.innerStride(),
         actualAlpha);
 
-    if (!DirectlyUseRhs) ei_aligned_stack_delete(RhsScalar, rhs_data, prod.rhs().size());
+    if((!DirectlyUseRhs) && freeRhsPtr) ei_aligned_stack_delete(RhsScalar, actualRhsPtr, prod.rhs().size());
   }
 };
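When the rhs cannot be consumed in place (DirectlyUseRhs is false, e.g. for a row-major or conjugated expression), the else-branch packs it into a contiguous buffer through a Map. The same idiom in user-facing form, as a minimal standalone sketch:

#include <Eigen/Core>

int main()
{
  Eigen::MatrixXf m = Eigen::MatrixXf::Random(4,4);
  float buffer[4];                                   // contiguous scratch
  // A row of a column-major matrix is strided; assigning it through a Map
  // copies the coefficients into the contiguous buffer, just as the
  // else-branch above packs actualRhs into actualRhsPtr.
  Eigen::Map<Eigen::RowVectorXf>(buffer, 4) = m.row(0);
  return 0;
}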
@@ -71,6 +71,16 @@ template<typename MatrixType> void nomalloc(const MatrixType& m)
   VERIFY_IS_APPROX((m1+m2)(r,c), (m1(r,c))+(m2(r,c)));
   VERIFY_IS_APPROX(m1.cwiseProduct(m1.block(0,0,rows,cols)), (m1.array()*m1.array()).matrix());
   VERIFY_IS_APPROX((m1*m1.transpose())*m2, m1*(m1.transpose()*m2));
+
+  m2.col(0).noalias() = m1 * m1.col(0);
+  m2.col(0).noalias() -= m1.adjoint() * m1.col(0);
+  m2.col(0).noalias() -= m1 * m1.row(0).adjoint();
+  m2.col(0).noalias() -= m1.adjoint() * m1.row(0).adjoint();
+
+  m2.row(0).noalias() = m1.row(0) * m1;
+  m2.row(0).noalias() -= m1.row(0) * m1.adjoint();
+  m2.row(0).noalias() -= m1.col(0).adjoint() * m1;
+  m2.row(0).noalias() -= m1.col(0).adjoint() * m1.adjoint();
 }
 
 template<typename Scalar>
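The new lines exercise every gemv variant (plain, conjugated, strided rhs, strided destination) from inside the nomalloc unit test, so a heap allocation on any fixed-size path now fails the test. A standalone sketch of the same kind of check, assuming Eigen's EIGEN_NO_MALLOC guard (which turns any internal heap allocation into an assertion failure) and a 32x32 size large enough to take the gemv route:

#define EIGEN_NO_MALLOC            // assumed guard: assert on heap allocation
#include <Eigen/Core>

int main()
{
  typedef Eigen::Matrix<float,32,32> Matrix32f;   // fixed size, no heap
  Matrix32f m1 = Matrix32f::Random();
  Matrix32f m2 = Matrix32f::Random();
  // The rhs m1.row(0).adjoint() is strided, so the gemv path must pack it
  // into a temporary. Before this commit that temporary could be allocated
  // dynamically despite the compile-time size; with gemv_static_vector_if
  // it lives on the stack and this program runs clean under EIGEN_NO_MALLOC.
  m2.col(0).noalias() -= m1 * m1.row(0).adjoint();
  return 0;
}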