Implement swap using evaluators.

This commit is contained in:
Jitse Niesen 2011-04-28 15:52:15 +01:00
parent 2d11041e24
commit 3b60d2dbc4
4 changed files with 107 additions and 5 deletions

View File

@ -404,7 +404,7 @@ struct copy_using_evaluator_impl<DstXprType, SrcXprType, LinearVectorizedTravers
dstAlignment = PacketTraits::AlignedOnScalar ? Aligned : dstIsAligned,
srcAlignment = copy_using_evaluator_traits<DstXprType,SrcXprType>::JointAlignment
};
const Index alignedStart = dstIsAligned ? 0 : first_aligned(&dst.coeffRef(0), size);
const Index alignedStart = dstIsAligned ? 0 : first_aligned(&dstEvaluator.coeffRef(0), size);
const Index alignedEnd = alignedStart + ((size-alignedStart)/packetSize)*packetSize;
unaligned_copy_using_evaluator_impl<dstIsAligned!=0>::run(dstEvaluator, srcEvaluator, 0, alignedStart);
@ -614,6 +614,16 @@ const DstXprType& copy_using_evaluator(const DstXprType& dst, const SrcXprType&
return dst;
}
// Based on DenseBase::swap()
// TODO: Chech whether we need to do something special for swapping two
// Arrays or Matrices.
template<typename DstXprType, typename SrcXprType>
void swap_using_evaluator(const DstXprType& dst, const SrcXprType& src)
{
copy_using_evaluator(SwapWrapper<DstXprType>(const_cast<DstXprType&>(dst)), src);
}
} // namespace internal
#endif // EIGEN_ASSIGN_EVALUATOR_H

View File

@ -65,7 +65,7 @@ struct evaluator_impl_base
{
Index row = rowIndexByOuterInner(outer, inner);
Index col = colIndexByOuterInner(outer, inner);
derived().coeffRef(row, col) = other.coeff(row, col);
derived().copyCoeff(row, col, other);
}
template<typename OtherEvaluatorType>
@ -86,8 +86,7 @@ struct evaluator_impl_base
{
Index row = rowIndexByOuterInner(outer, inner);
Index col = colIndexByOuterInner(outer, inner);
derived().template writePacket<StoreMode>(row, col,
other.template packet<LoadMode>(row, col));
derived().template copyPacket<StoreMode, LoadMode>(row, col, other);
}
template<int StoreMode, int LoadMode, typename OtherEvaluatorType>
@ -1017,6 +1016,75 @@ private:
};
// ---------- SwapWrapper ----------
template<typename ArgType>
struct evaluator_impl<SwapWrapper<ArgType> >
: evaluator_impl_base<SwapWrapper<ArgType> >
{
typedef SwapWrapper<ArgType> XprType;
evaluator_impl(const XprType& swapWrapper)
: m_argImpl(swapWrapper.expression())
{ }
typedef typename XprType::Index Index;
typedef typename XprType::Scalar Scalar;
typedef typename XprType::Packet Packet;
Scalar& coeffRef(Index row, Index col)
{
return m_argImpl.coeffRef(row, col);
}
inline Scalar& coeffRef(Index index)
{
return m_argImpl.coeffRef(index);
}
template<typename OtherEvaluatorType>
void copyCoeff(Index row, Index col, const OtherEvaluatorType& other)
{
OtherEvaluatorType& nonconst_other = const_cast<OtherEvaluatorType&>(other);
Scalar tmp = m_argImpl.coeff(row, col);
m_argImpl.coeffRef(row, col) = nonconst_other.coeff(row, col);
nonconst_other.coeffRef(row, col) = tmp;
}
template<typename OtherEvaluatorType>
void copyCoeff(Index index, const OtherEvaluatorType& other)
{
OtherEvaluatorType& nonconst_other = const_cast<OtherEvaluatorType&>(other);
Scalar tmp = m_argImpl.coeff(index);
m_argImpl.coeffRef(index) = nonconst_other.coeff(index);
nonconst_other.coeffRef(index) = tmp;
}
template<int StoreMode, int LoadMode, typename OtherEvaluatorType>
void copyPacket(Index row, Index col, const OtherEvaluatorType& other)
{
OtherEvaluatorType& nonconst_other = const_cast<OtherEvaluatorType&>(other);
Packet tmp = m_argImpl.template packet<StoreMode>(row, col);
m_argImpl.template writePacket<StoreMode>
(row, col, nonconst_other.template packet<LoadMode>(row, col));
nonconst_other.template writePacket<LoadMode>(row, col, tmp);
}
template<int StoreMode, int LoadMode, typename OtherEvaluatorType>
void copyPacket(Index index, const OtherEvaluatorType& other)
{
OtherEvaluatorType& nonconst_other = const_cast<OtherEvaluatorType&>(other);
Packet tmp = m_argImpl.template packet<StoreMode>(index);
m_argImpl.template writePacket<StoreMode>
(index, nonconst_other.template packet<LoadMode>(index));
nonconst_other.template writePacket<LoadMode>(index, tmp);
}
protected:
typename evaluator<ArgType>::type m_argImpl;
};
} // namespace internal
#endif // EIGEN_COREEVALUATORS_H

View File

@ -119,6 +119,8 @@ template<typename ExpressionType> class SwapWrapper
_other.template writePacket<LoadMode>(index, tmp);
}
ExpressionType& expression() const { return m_expression; }
protected:
ExpressionType& m_expression;
};

View File

@ -214,5 +214,27 @@ void test_evaluators()
copy_using_evaluator(mat1.diagonal<-1>(), mat1.diagonal(1));
mat2.diagonal<-1>() = mat2.diagonal(1);
VERIFY_IS_APPROX(mat1, mat2);
VERIFY_IS_APPROX(mat1, mat2);
{
// test swapping
MatrixXd mat1, mat2, mat1ref, mat2ref;
mat1ref = mat1 = MatrixXd::Random(6, 6);
mat2ref = mat2 = 2 * mat1 + MatrixXd::Identity(6, 6);
swap_using_evaluator(mat1, mat2);
mat1ref.swap(mat2ref);
VERIFY_IS_APPROX(mat1, mat1ref);
VERIFY_IS_APPROX(mat2, mat2ref);
swap_using_evaluator(mat1.block(0, 0, 3, 3), mat2.block(3, 3, 3, 3));
mat1ref.block(0, 0, 3, 3).swap(mat2ref.block(3, 3, 3, 3));
VERIFY_IS_APPROX(mat1, mat1ref);
VERIFY_IS_APPROX(mat2, mat2ref);
swap_using_evaluator(mat1.row(2), mat2.col(3).transpose());
mat1.row(2).swap(mat2.col(3).transpose());
VERIFY_IS_APPROX(mat1, mat1ref);
VERIFY_IS_APPROX(mat2, mat2ref);
}
}