mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Implement compound assignments using evaluator of SelfCwiseBinaryOp.
This commit is contained in:
parent
3b60d2dbc4
commit
06fb7cf470
@ -624,6 +624,59 @@ void swap_using_evaluator(const DstXprType& dst, const SrcXprType& src)
|
|||||||
copy_using_evaluator(SwapWrapper<DstXprType>(const_cast<DstXprType&>(dst)), src);
|
copy_using_evaluator(SwapWrapper<DstXprType>(const_cast<DstXprType&>(dst)), src);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Based on MatrixBase::operator+= (in CwiseBinaryOp.h)
|
||||||
|
template<typename DstXprType, typename SrcXprType>
|
||||||
|
void add_assign_using_evaluator(const MatrixBase<DstXprType>& dst, const MatrixBase<SrcXprType>& src)
|
||||||
|
{
|
||||||
|
typedef typename DstXprType::Scalar Scalar;
|
||||||
|
SelfCwiseBinaryOp<internal::scalar_sum_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived());
|
||||||
|
copy_using_evaluator(tmp, src.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Based on ArrayBase::operator+=
|
||||||
|
template<typename DstXprType, typename SrcXprType>
|
||||||
|
void add_assign_using_evaluator(const ArrayBase<DstXprType>& dst, const ArrayBase<SrcXprType>& src)
|
||||||
|
{
|
||||||
|
typedef typename DstXprType::Scalar Scalar;
|
||||||
|
SelfCwiseBinaryOp<internal::scalar_sum_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived());
|
||||||
|
copy_using_evaluator(tmp, src.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Add add_assign_using_evaluator for EigenBase ?
|
||||||
|
|
||||||
|
template<typename DstXprType, typename SrcXprType>
|
||||||
|
void subtract_assign_using_evaluator(const MatrixBase<DstXprType>& dst, const MatrixBase<SrcXprType>& src)
|
||||||
|
{
|
||||||
|
typedef typename DstXprType::Scalar Scalar;
|
||||||
|
SelfCwiseBinaryOp<internal::scalar_difference_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived());
|
||||||
|
copy_using_evaluator(tmp, src.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename DstXprType, typename SrcXprType>
|
||||||
|
void subtract_assign_using_evaluator(const ArrayBase<DstXprType>& dst, const ArrayBase<SrcXprType>& src)
|
||||||
|
{
|
||||||
|
typedef typename DstXprType::Scalar Scalar;
|
||||||
|
SelfCwiseBinaryOp<internal::scalar_difference_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived());
|
||||||
|
copy_using_evaluator(tmp, src.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename DstXprType, typename SrcXprType>
|
||||||
|
void multiply_assign_using_evaluator(const ArrayBase<DstXprType>& dst, const ArrayBase<SrcXprType>& src)
|
||||||
|
{
|
||||||
|
typedef typename DstXprType::Scalar Scalar;
|
||||||
|
SelfCwiseBinaryOp<internal::scalar_product_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived());
|
||||||
|
copy_using_evaluator(tmp, src.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename DstXprType, typename SrcXprType>
|
||||||
|
void divide_assign_using_evaluator(const ArrayBase<DstXprType>& dst, const ArrayBase<SrcXprType>& src)
|
||||||
|
{
|
||||||
|
typedef typename DstXprType::Scalar Scalar;
|
||||||
|
SelfCwiseBinaryOp<internal::scalar_quotient_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived());
|
||||||
|
copy_using_evaluator(tmp, src.derived());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
|
|
||||||
#endif // EIGEN_ASSIGN_EVALUATOR_H
|
#endif // EIGEN_ASSIGN_EVALUATOR_H
|
||||||
|
@ -1032,6 +1032,8 @@ struct evaluator_impl<SwapWrapper<ArgType> >
|
|||||||
typedef typename XprType::Scalar Scalar;
|
typedef typename XprType::Scalar Scalar;
|
||||||
typedef typename XprType::Packet Packet;
|
typedef typename XprType::Packet Packet;
|
||||||
|
|
||||||
|
// This function and the next one are needed by assign to correctly align loads/stores
|
||||||
|
// TODO make Assign use .data()
|
||||||
Scalar& coeffRef(Index row, Index col)
|
Scalar& coeffRef(Index row, Index col)
|
||||||
{
|
{
|
||||||
return m_argImpl.coeffRef(row, col);
|
return m_argImpl.coeffRef(row, col);
|
||||||
@ -1085,6 +1087,71 @@ protected:
|
|||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// ---------- SelfCwiseBinaryOp ----------
|
||||||
|
|
||||||
|
template<typename BinaryOp, typename LhsXpr, typename RhsXpr>
|
||||||
|
struct evaluator_impl<SelfCwiseBinaryOp<BinaryOp, LhsXpr, RhsXpr> >
|
||||||
|
: evaluator_impl_base<SelfCwiseBinaryOp<BinaryOp, LhsXpr, RhsXpr> >
|
||||||
|
{
|
||||||
|
typedef SelfCwiseBinaryOp<BinaryOp, LhsXpr, RhsXpr> XprType;
|
||||||
|
|
||||||
|
evaluator_impl(const XprType& selfCwiseBinaryOp)
|
||||||
|
: m_argImpl(selfCwiseBinaryOp.expression()),
|
||||||
|
m_functor(selfCwiseBinaryOp.functor())
|
||||||
|
{ }
|
||||||
|
|
||||||
|
typedef typename XprType::Index Index;
|
||||||
|
typedef typename XprType::Scalar Scalar;
|
||||||
|
typedef typename XprType::Packet Packet;
|
||||||
|
|
||||||
|
// This function and the next one are needed by assign to correctly align loads/stores
|
||||||
|
// TODO make Assign use .data()
|
||||||
|
Scalar& coeffRef(Index row, Index col)
|
||||||
|
{
|
||||||
|
return m_argImpl.coeffRef(row, col);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline Scalar& coeffRef(Index index)
|
||||||
|
{
|
||||||
|
return m_argImpl.coeffRef(index);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename OtherEvaluatorType>
|
||||||
|
void copyCoeff(Index row, Index col, const OtherEvaluatorType& other)
|
||||||
|
{
|
||||||
|
Scalar& tmp = m_argImpl.coeffRef(row, col);
|
||||||
|
tmp = m_functor(tmp, other.coeff(row, col));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename OtherEvaluatorType>
|
||||||
|
void copyCoeff(Index index, const OtherEvaluatorType& other)
|
||||||
|
{
|
||||||
|
Scalar& tmp = m_argImpl.coeffRef(index);
|
||||||
|
tmp = m_functor(tmp, other.coeff(index));
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int StoreMode, int LoadMode, typename OtherEvaluatorType>
|
||||||
|
void copyPacket(Index row, Index col, const OtherEvaluatorType& other)
|
||||||
|
{
|
||||||
|
const Packet res = m_functor.packetOp(m_argImpl.template packet<StoreMode>(row, col),
|
||||||
|
other.template packet<LoadMode>(row, col));
|
||||||
|
m_argImpl.template writePacket<StoreMode>(row, col, res);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int StoreMode, int LoadMode, typename OtherEvaluatorType>
|
||||||
|
void copyPacket(Index index, const OtherEvaluatorType& other)
|
||||||
|
{
|
||||||
|
const Packet res = m_functor.packetOp(m_argImpl.template packet<StoreMode>(index),
|
||||||
|
other.template packet<LoadMode>(index));
|
||||||
|
m_argImpl.template writePacket<StoreMode>(index, res);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
typename evaluator<LhsXpr>::type m_argImpl;
|
||||||
|
const BinaryOp& m_functor;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
|
|
||||||
#endif // EIGEN_COREEVALUATORS_H
|
#endif // EIGEN_COREEVALUATORS_H
|
||||||
|
@ -163,6 +163,16 @@ template<typename BinaryOp, typename Lhs, typename Rhs> class SelfCwiseBinaryOp
|
|||||||
return Base::operator=(rhs);
|
return Base::operator=(rhs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Lhs& expression() const
|
||||||
|
{
|
||||||
|
return m_matrix;
|
||||||
|
}
|
||||||
|
|
||||||
|
const BinaryOp& functor() const
|
||||||
|
{
|
||||||
|
return m_functor;
|
||||||
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
Lhs& m_matrix;
|
Lhs& m_matrix;
|
||||||
const BinaryOp& m_functor;
|
const BinaryOp& m_functor;
|
||||||
|
@ -237,4 +237,28 @@ void test_evaluators()
|
|||||||
VERIFY_IS_APPROX(mat2, mat2ref);
|
VERIFY_IS_APPROX(mat2, mat2ref);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// test compound assignment
|
||||||
|
const Matrix4d mat_const = Matrix4d::Random();
|
||||||
|
Matrix4d mat, mat_ref;
|
||||||
|
mat = mat_ref = Matrix4d::Identity();
|
||||||
|
add_assign_using_evaluator(mat, mat_const);
|
||||||
|
mat_ref += mat_const;
|
||||||
|
VERIFY_IS_APPROX(mat, mat_ref);
|
||||||
|
|
||||||
|
subtract_assign_using_evaluator(mat.row(1), 2*mat.row(2));
|
||||||
|
mat_ref.row(1) -= 2*mat_ref.row(2);
|
||||||
|
VERIFY_IS_APPROX(mat, mat_ref);
|
||||||
|
|
||||||
|
const ArrayXXf arr_const = ArrayXXf::Random(5,3);
|
||||||
|
ArrayXXf arr, arr_ref;
|
||||||
|
arr = arr_ref = ArrayXXf::Constant(5, 3, 0.5);
|
||||||
|
multiply_assign_using_evaluator(arr, arr_const);
|
||||||
|
arr_ref *= arr_const;
|
||||||
|
VERIFY_IS_APPROX(arr, arr_ref);
|
||||||
|
|
||||||
|
divide_assign_using_evaluator(arr.row(1), arr.row(2) + 1);
|
||||||
|
arr_ref.row(1) /= (arr_ref.row(2) + 1);
|
||||||
|
VERIFY_IS_APPROX(arr, arr_ref);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user