Make evaluators for Matrix and Array inherit from common base class.

This gets rid of some code duplication.
This commit is contained in:
Jitse Niesen 2011-04-04 15:35:14 +01:00
parent afdd26f229
commit ae06b8af5c

View File

@ -106,128 +106,92 @@ protected:
typename evaluator<ExpressionType>::type m_argImpl; typename evaluator<ExpressionType>::type m_argImpl;
}; };
// -------------------- Matrix -------------------- // -------------------- Matrix and Array--------------------
//
// evaluator_impl<PlainObjectBase> is a common base class for the
// Matrix and Array evaluators.
template<typename Derived>
struct evaluator_impl<PlainObjectBase<Derived> >
{
typedef PlainObjectBase<Derived> PlainObjectType;
evaluator_impl(const PlainObjectType& m) : m_plainObject(m) {}
typedef typename PlainObjectType::Index Index;
typedef typename PlainObjectType::Scalar Scalar;
typedef typename PlainObjectType::CoeffReturnType CoeffReturnType;
typedef typename PlainObjectType::PacketScalar PacketScalar;
typedef typename PlainObjectType::PacketReturnType PacketReturnType;
CoeffReturnType coeff(Index i, Index j) const
{
return m_plainObject.coeff(i, j);
}
CoeffReturnType coeff(Index index) const
{
return m_plainObject.coeff(index);
}
Scalar& coeffRef(Index i, Index j)
{
return m_plainObject.const_cast_derived().coeffRef(i, j);
}
Scalar& coeffRef(Index index)
{
return m_plainObject.const_cast_derived().coeffRef(index);
}
template<int LoadMode>
PacketReturnType packet(Index row, Index col) const
{
return m_plainObject.template packet<LoadMode>(row, col);
}
template<int LoadMode>
PacketReturnType packet(Index index) const
{
return m_plainObject.template packet<LoadMode>(index);
}
template<int StoreMode>
void writePacket(Index row, Index col, const PacketScalar& x)
{
m_plainObject.const_cast_derived().template writePacket<StoreMode>(row, col, x);
}
template<int StoreMode>
void writePacket(Index index, const PacketScalar& x)
{
m_plainObject.const_cast_derived().template writePacket<StoreMode>(index, x);
}
protected:
const PlainObjectType &m_plainObject;
};
template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols> template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
struct evaluator_impl<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > struct evaluator_impl<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
: evaluator_impl<PlainObjectBase<Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > >
{ {
typedef Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> MatrixType; typedef Matrix<Scalar, Rows, Cols, Options, MaxRows, MaxCols> MatrixType;
evaluator_impl(const MatrixType& m) : m_matrix(m) {} evaluator_impl(const MatrixType& m)
: evaluator_impl<PlainObjectBase<MatrixType> >(m)
typedef typename MatrixType::Index Index; { }
typename MatrixType::CoeffReturnType coeff(Index i, Index j) const
{
return m_matrix.coeff(i, j);
}
typename MatrixType::CoeffReturnType coeff(Index index) const
{
return m_matrix.coeff(index);
}
typename MatrixType::Scalar& coeffRef(Index i, Index j)
{
return m_matrix.const_cast_derived().coeffRef(i, j);
}
typename MatrixType::Scalar& coeffRef(Index index)
{
return m_matrix.const_cast_derived().coeffRef(index);
}
template<int LoadMode>
typename MatrixType::PacketReturnType packet(Index row, Index col) const
{
return m_matrix.template packet<LoadMode>(row, col);
}
template<int LoadMode>
typename MatrixType::PacketReturnType packet(Index index) const
{
// eigen_internal_assert(index >= 0 && index < size());
return m_matrix.template packet<LoadMode>(index);
}
template<int StoreMode>
void writePacket(Index row, Index col, const typename MatrixType::PacketScalar& x)
{
m_matrix.const_cast_derived().template writePacket<StoreMode>(row, col, x);
}
template<int StoreMode>
void writePacket(Index index, const typename MatrixType::PacketScalar& x)
{
// eigen_internal_assert(index >= 0 && index < size());
m_matrix.const_cast_derived().template writePacket<StoreMode>(index, x);
}
protected:
const MatrixType &m_matrix;
}; };
// -------------------- Array --------------------
// TODO: should be sharing code with Matrix case
template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols> template<typename Scalar, int Rows, int Cols, int Options, int MaxRows, int MaxCols>
struct evaluator_impl<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > struct evaluator_impl<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> >
: evaluator_impl<PlainObjectBase<Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> > >
{ {
typedef Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> ArrayType; typedef Array<Scalar, Rows, Cols, Options, MaxRows, MaxCols> ArrayType;
evaluator_impl(const ArrayType& a) : m_array(a) {} evaluator_impl(const ArrayType& m)
: evaluator_impl<PlainObjectBase<ArrayType> >(m)
typedef typename ArrayType::Index Index; { }
typename ArrayType::CoeffReturnType coeff(Index i, Index j) const
{
return m_array.coeff(i, j);
}
typename ArrayType::CoeffReturnType coeff(Index index) const
{
return m_array.coeff(index);
}
typename ArrayType::Scalar& coeffRef(Index i, Index j)
{
return m_array.const_cast_derived().coeffRef(i, j);
}
typename ArrayType::Scalar& coeffRef(Index index)
{
return m_array.const_cast_derived().coeffRef(index);
}
template<int LoadMode>
typename ArrayType::PacketReturnType packet(Index row, Index col) const
{
return m_array.template packet<LoadMode>(row, col);
}
template<int LoadMode>
typename ArrayType::PacketReturnType packet(Index index) const
{
// eigen_internal_assert(index >= 0 && index < size());
return m_array.template packet<LoadMode>(index);
}
template<int StoreMode>
void writePacket(Index row, Index col, const typename ArrayType::PacketScalar& x)
{
m_array.const_cast_derived().template writePacket<StoreMode>(row, col, x);
}
template<int StoreMode>
void writePacket(Index index, const typename ArrayType::PacketScalar& x)
{
// eigen_internal_assert(index >= 0 && index < size());
m_array.const_cast_derived().template writePacket<StoreMode>(index, x);
}
protected:
const ArrayType &m_array;
}; };
// -------------------- CwiseNullaryOp -------------------- // -------------------- CwiseNullaryOp --------------------
@ -400,8 +364,8 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir
CoeffReturnType coeff(Index index) const CoeffReturnType coeff(Index index) const
{ {
return m_argImpl.coeff(m_startRow + (RowsAtCompileTime == 1 ? 0 : index), return coeff(RowsAtCompileTime == 1 ? 0 : index,
m_startCol + (RowsAtCompileTime == 1 ? index : 0)); RowsAtCompileTime == 1 ? index : 0);
} }
Scalar& coeffRef(Index row, Index col) Scalar& coeffRef(Index row, Index col)
@ -411,8 +375,8 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir
Scalar& coeffRef(Index index) Scalar& coeffRef(Index index)
{ {
return m_argImpl.coeffRef(m_startRow + (RowsAtCompileTime == 1 ? 0 : index), return coeffRef(RowsAtCompileTime == 1 ? 0 : index,
m_startCol + (RowsAtCompileTime == 1 ? index : 0)); RowsAtCompileTime == 1 ? index : 0);
} }
template<int LoadMode> template<int LoadMode>
@ -424,8 +388,8 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir
template<int LoadMode> template<int LoadMode>
PacketReturnType packet(Index index) const PacketReturnType packet(Index index) const
{ {
return m_argImpl.template packet<LoadMode>(m_startRow + (RowsAtCompileTime == 1 ? 0 : index), return packet<LoadMode>(RowsAtCompileTime == 1 ? 0 : index,
m_startCol + (RowsAtCompileTime == 1 ? index : 0)); RowsAtCompileTime == 1 ? index : 0);
} }
template<int StoreMode> template<int StoreMode>
@ -437,9 +401,9 @@ struct evaluator_impl<Block<XprType, BlockRows, BlockCols, InnerPanel, /* HasDir
template<int StoreMode> template<int StoreMode>
void writePacket(Index index, const PacketScalar& x) void writePacket(Index index, const PacketScalar& x)
{ {
return m_argImpl.template writePacket<StoreMode>(m_startRow + (RowsAtCompileTime == 1 ? 0 : index), return writePacket<StoreMode>(RowsAtCompileTime == 1 ? 0 : index,
m_startCol + (RowsAtCompileTime == 1 ? index : 0), RowsAtCompileTime == 1 ? index : 0,
x); x);
} }
protected: protected: