Implement compound assignments using evaluator of SelfCwiseBinaryOp.

This commit is contained in:
Jitse Niesen 2011-04-28 16:57:35 +01:00
parent 3b60d2dbc4
commit 06fb7cf470
4 changed files with 155 additions and 1 deletions

View File

@ -624,6 +624,59 @@ void swap_using_evaluator(const DstXprType& dst, const SrcXprType& src)
copy_using_evaluator(SwapWrapper<DstXprType>(const_cast<DstXprType&>(dst)), src);
}
// Based on MatrixBase::operator+= (in CwiseBinaryOp.h)
template<typename DstXprType, typename SrcXprType>
void add_assign_using_evaluator(const MatrixBase<DstXprType>& dst, const MatrixBase<SrcXprType>& src)
{
typedef typename DstXprType::Scalar Scalar;
SelfCwiseBinaryOp<internal::scalar_sum_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived());
copy_using_evaluator(tmp, src.derived());
}
// Based on ArrayBase::operator+=
template<typename DstXprType, typename SrcXprType>
void add_assign_using_evaluator(const ArrayBase<DstXprType>& dst, const ArrayBase<SrcXprType>& src)
{
typedef typename DstXprType::Scalar Scalar;
SelfCwiseBinaryOp<internal::scalar_sum_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived());
copy_using_evaluator(tmp, src.derived());
}
// TODO: Add add_assign_using_evaluator for EigenBase ?
template<typename DstXprType, typename SrcXprType>
void subtract_assign_using_evaluator(const MatrixBase<DstXprType>& dst, const MatrixBase<SrcXprType>& src)
{
typedef typename DstXprType::Scalar Scalar;
SelfCwiseBinaryOp<internal::scalar_difference_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived());
copy_using_evaluator(tmp, src.derived());
}
template<typename DstXprType, typename SrcXprType>
void subtract_assign_using_evaluator(const ArrayBase<DstXprType>& dst, const ArrayBase<SrcXprType>& src)
{
typedef typename DstXprType::Scalar Scalar;
SelfCwiseBinaryOp<internal::scalar_difference_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived());
copy_using_evaluator(tmp, src.derived());
}
template<typename DstXprType, typename SrcXprType>
void multiply_assign_using_evaluator(const ArrayBase<DstXprType>& dst, const ArrayBase<SrcXprType>& src)
{
typedef typename DstXprType::Scalar Scalar;
SelfCwiseBinaryOp<internal::scalar_product_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived());
copy_using_evaluator(tmp, src.derived());
}
template<typename DstXprType, typename SrcXprType>
void divide_assign_using_evaluator(const ArrayBase<DstXprType>& dst, const ArrayBase<SrcXprType>& src)
{
typedef typename DstXprType::Scalar Scalar;
SelfCwiseBinaryOp<internal::scalar_quotient_op<Scalar>, DstXprType, SrcXprType> tmp(dst.const_cast_derived());
copy_using_evaluator(tmp, src.derived());
}
} // namespace internal
#endif // EIGEN_ASSIGN_EVALUATOR_H

View File

@ -1032,6 +1032,8 @@ struct evaluator_impl<SwapWrapper<ArgType> >
typedef typename XprType::Scalar Scalar;
typedef typename XprType::Packet Packet;
// This function and the next one are needed by assign to correctly align loads/stores
// TODO make Assign use .data()
Scalar& coeffRef(Index row, Index col)
{
return m_argImpl.coeffRef(row, col);
@ -1085,6 +1087,71 @@ protected:
};
// ---------- SelfCwiseBinaryOp ----------
template<typename BinaryOp, typename LhsXpr, typename RhsXpr>
struct evaluator_impl<SelfCwiseBinaryOp<BinaryOp, LhsXpr, RhsXpr> >
: evaluator_impl_base<SelfCwiseBinaryOp<BinaryOp, LhsXpr, RhsXpr> >
{
typedef SelfCwiseBinaryOp<BinaryOp, LhsXpr, RhsXpr> XprType;
evaluator_impl(const XprType& selfCwiseBinaryOp)
: m_argImpl(selfCwiseBinaryOp.expression()),
m_functor(selfCwiseBinaryOp.functor())
{ }
typedef typename XprType::Index Index;
typedef typename XprType::Scalar Scalar;
typedef typename XprType::Packet Packet;
// This function and the next one are needed by assign to correctly align loads/stores
// TODO make Assign use .data()
Scalar& coeffRef(Index row, Index col)
{
return m_argImpl.coeffRef(row, col);
}
inline Scalar& coeffRef(Index index)
{
return m_argImpl.coeffRef(index);
}
template<typename OtherEvaluatorType>
void copyCoeff(Index row, Index col, const OtherEvaluatorType& other)
{
Scalar& tmp = m_argImpl.coeffRef(row, col);
tmp = m_functor(tmp, other.coeff(row, col));
}
template<typename OtherEvaluatorType>
void copyCoeff(Index index, const OtherEvaluatorType& other)
{
Scalar& tmp = m_argImpl.coeffRef(index);
tmp = m_functor(tmp, other.coeff(index));
}
template<int StoreMode, int LoadMode, typename OtherEvaluatorType>
void copyPacket(Index row, Index col, const OtherEvaluatorType& other)
{
const Packet res = m_functor.packetOp(m_argImpl.template packet<StoreMode>(row, col),
other.template packet<LoadMode>(row, col));
m_argImpl.template writePacket<StoreMode>(row, col, res);
}
template<int StoreMode, int LoadMode, typename OtherEvaluatorType>
void copyPacket(Index index, const OtherEvaluatorType& other)
{
const Packet res = m_functor.packetOp(m_argImpl.template packet<StoreMode>(index),
other.template packet<LoadMode>(index));
m_argImpl.template writePacket<StoreMode>(index, res);
}
protected:
typename evaluator<LhsXpr>::type m_argImpl;
const BinaryOp& m_functor;
};
} // namespace internal
#endif // EIGEN_COREEVALUATORS_H

View File

@ -163,6 +163,16 @@ template<typename BinaryOp, typename Lhs, typename Rhs> class SelfCwiseBinaryOp
return Base::operator=(rhs);
}
Lhs& expression() const
{
return m_matrix;
}
const BinaryOp& functor() const
{
return m_functor;
}
protected:
Lhs& m_matrix;
const BinaryOp& m_functor;

View File

@ -236,5 +236,29 @@ void test_evaluators()
VERIFY_IS_APPROX(mat1, mat1ref);
VERIFY_IS_APPROX(mat2, mat2ref);
}
{
// test compound assignment
const Matrix4d mat_const = Matrix4d::Random();
Matrix4d mat, mat_ref;
mat = mat_ref = Matrix4d::Identity();
add_assign_using_evaluator(mat, mat_const);
mat_ref += mat_const;
VERIFY_IS_APPROX(mat, mat_ref);
subtract_assign_using_evaluator(mat.row(1), 2*mat.row(2));
mat_ref.row(1) -= 2*mat_ref.row(2);
VERIFY_IS_APPROX(mat, mat_ref);
const ArrayXXf arr_const = ArrayXXf::Random(5,3);
ArrayXXf arr, arr_ref;
arr = arr_ref = ArrayXXf::Constant(5, 3, 0.5);
multiply_assign_using_evaluator(arr, arr_const);
arr_ref *= arr_const;
VERIFY_IS_APPROX(arr, arr_ref);
divide_assign_using_evaluator(arr.row(1), arr.row(2) + 1);
arr_ref.row(1) /= (arr_ref.row(2) + 1);
VERIFY_IS_APPROX(arr, arr_ref);
}
}