mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-12 17:33:15 +08:00
Fix bug #483: optimize outer-products to skip setZero and a scalar multiple when not needed.
This commit is contained in:
parent
96ad13abba
commit
5a0c5c0393
@ -244,35 +244,61 @@ class GeneralProduct<Lhs, Rhs, OuterProduct>
|
|||||||
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
|
YOU_MIXED_DIFFERENT_NUMERIC_TYPES__YOU_NEED_TO_USE_THE_CAST_METHOD_OF_MATRIXBASE_TO_CAST_NUMERIC_TYPES_EXPLICITLY)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Assignment functor: overwrites the destination with the source expression.
// The destination is taken by const reference and written through
// const_cast_derived() -- presumably so that temporary block expressions
// (e.g. dest.col(j)) can be passed directly.
struct set {
  template<typename Target, typename Source>
  void operator()(const Target& dst, const Source& src) const
  {
    dst.const_cast_derived() = src;
  }
};
|
||||||
|
// Accumulation functor: adds the source expression to the destination.
// The destination is taken by const reference and written through
// const_cast_derived() -- presumably so that temporary block expressions
// (e.g. dest.col(j)) can be passed directly.
struct add {
  template<typename Target, typename Source>
  void operator()(const Target& dst, const Source& src) const
  {
    dst.const_cast_derived() += src;
  }
};
|
||||||
|
// Subtraction functor: subtracts the source expression from the destination.
// The destination is taken by const reference and written through
// const_cast_derived() -- presumably so that temporary block expressions
// (e.g. dest.col(j)) can be passed directly.
struct sub {
  template<typename Target, typename Source>
  void operator()(const Target& dst, const Source& src) const
  {
    dst.const_cast_derived() -= src;
  }
};
|
||||||
|
struct adds {
|
||||||
|
Scalar m_scale;
|
||||||
|
adds(const Scalar& s) : m_scale(s) {}
|
||||||
|
template<typename Dst, typename Src> void operator()(const Dst& dst, const Src& src) const {
|
||||||
|
dst.const_cast_derived() += m_scale * src;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Dest>
|
||||||
|
inline void evalTo(Dest& dest) const {
|
||||||
|
internal::outer_product_selector<(int(Dest::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dest, set());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Dest>
|
||||||
|
inline void addTo(Dest& dest) const {
|
||||||
|
internal::outer_product_selector<(int(Dest::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dest, add());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename Dest>
|
||||||
|
inline void subTo(Dest& dest) const {
|
||||||
|
internal::outer_product_selector<(int(Dest::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dest, sub());
|
||||||
|
}
|
||||||
|
|
||||||
template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
|
template<typename Dest> void scaleAndAddTo(Dest& dest, Scalar alpha) const
|
||||||
{
|
{
|
||||||
internal::outer_product_selector<(int(Dest::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dest, alpha);
|
internal::outer_product_selector<(int(Dest::Flags)&RowMajorBit) ? RowMajor : ColMajor>::run(*this, dest, adds(alpha));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
|
||||||
template<> struct outer_product_selector<ColMajor> {
|
template<> struct outer_product_selector<ColMajor> {
|
||||||
template<typename ProductType, typename Dest>
|
template<typename ProductType, typename Dest, typename Func>
|
||||||
static EIGEN_DONT_INLINE void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha) {
|
static EIGEN_DONT_INLINE void run(const ProductType& prod, Dest& dest, const Func& func) {
|
||||||
typedef typename Dest::Index Index;
|
typedef typename Dest::Index Index;
|
||||||
// FIXME make sure lhs is sequentially stored
|
// FIXME make sure lhs is sequentially stored
|
||||||
// FIXME not very good if rhs is real and lhs complex while alpha is real too
|
// FIXME not very good if rhs is real and lhs complex while alpha is real too
|
||||||
const Index cols = dest.cols();
|
const Index cols = dest.cols();
|
||||||
for (Index j=0; j<cols; ++j)
|
for (Index j=0; j<cols; ++j)
|
||||||
dest.col(j) += (alpha * prod.rhs().coeff(j)) * prod.lhs();
|
func(dest.col(j), prod.rhs().coeff(j) * prod.lhs());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
template<> struct outer_product_selector<RowMajor> {
|
template<> struct outer_product_selector<RowMajor> {
|
||||||
template<typename ProductType, typename Dest>
|
template<typename ProductType, typename Dest, typename Func>
|
||||||
static EIGEN_DONT_INLINE void run(const ProductType& prod, Dest& dest, typename ProductType::Scalar alpha) {
|
static EIGEN_DONT_INLINE void run(const ProductType& prod, Dest& dest, const Func& func) {
|
||||||
typedef typename Dest::Index Index;
|
typedef typename Dest::Index Index;
|
||||||
// FIXME make sure rhs is sequentially stored
|
// FIXME make sure rhs is sequentially stored
|
||||||
// FIXME not very good if lhs is real and rhs complex while alpha is real too
|
// FIXME not very good if lhs is real and rhs complex while alpha is real too
|
||||||
const Index rows = dest.rows();
|
const Index rows = dest.rows();
|
||||||
for (Index i=0; i<rows; ++i)
|
for (Index i=0; i<rows; ++i)
|
||||||
dest.row(i) += (alpha * prod.lhs().coeff(i)) * prod.rhs();
|
func(dest.row(i), prod.lhs().coeff(i) * prod.rhs());
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user