Extend support for nvcc to Array objects and wrappers

This commit is contained in:
Gael Guennebaud 2013-07-31 15:30:50 +02:00
parent 2f593ee67c
commit 6126ad801f
8 changed files with 126 additions and 3 deletions

View File

@ -118,40 +118,52 @@ template<typename Derived> class ArrayBase
/** Special case of the template operator=, in order to prevent the compiler /** Special case of the template operator=, in order to prevent the compiler
* from generating a default operator= (issue hit with g++ 4.1) * from generating a default operator= (issue hit with g++ 4.1)
*/ */
EIGEN_DEVICE_FUNC
Derived& operator=(const ArrayBase& other) Derived& operator=(const ArrayBase& other)
{ {
return internal::assign_selector<Derived,Derived>::run(derived(), other.derived()); return internal::assign_selector<Derived,Derived>::run(derived(), other.derived());
} }
EIGEN_DEVICE_FUNC
Derived& operator+=(const Scalar& scalar) Derived& operator+=(const Scalar& scalar)
{ return *this = derived() + scalar; } { return *this = derived() + scalar; }
EIGEN_DEVICE_FUNC
Derived& operator-=(const Scalar& scalar) Derived& operator-=(const Scalar& scalar)
{ return *this = derived() - scalar; } { return *this = derived() - scalar; }
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_DEVICE_FUNC
Derived& operator+=(const ArrayBase<OtherDerived>& other); Derived& operator+=(const ArrayBase<OtherDerived>& other);
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_DEVICE_FUNC
Derived& operator-=(const ArrayBase<OtherDerived>& other); Derived& operator-=(const ArrayBase<OtherDerived>& other);
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_DEVICE_FUNC
Derived& operator*=(const ArrayBase<OtherDerived>& other); Derived& operator*=(const ArrayBase<OtherDerived>& other);
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_DEVICE_FUNC
Derived& operator/=(const ArrayBase<OtherDerived>& other); Derived& operator/=(const ArrayBase<OtherDerived>& other);
public: public:
EIGEN_DEVICE_FUNC
ArrayBase<Derived>& array() { return *this; } ArrayBase<Derived>& array() { return *this; }
EIGEN_DEVICE_FUNC
const ArrayBase<Derived>& array() const { return *this; } const ArrayBase<Derived>& array() const { return *this; }
/** \returns an \link Eigen::MatrixBase Matrix \endlink expression of this array /** \returns an \link Eigen::MatrixBase Matrix \endlink expression of this array
* \sa MatrixBase::array() */ * \sa MatrixBase::array() */
EIGEN_DEVICE_FUNC
MatrixWrapper<Derived> matrix() { return derived(); } MatrixWrapper<Derived> matrix() { return derived(); }
EIGEN_DEVICE_FUNC
const MatrixWrapper<const Derived> matrix() const { return derived(); } const MatrixWrapper<const Derived> matrix() const { return derived(); }
// template<typename Dest> // template<typename Dest>
// inline void evalTo(Dest& dst) const { dst = matrix(); } // inline void evalTo(Dest& dst) const { dst = matrix(); }
protected: protected:
EIGEN_DEVICE_FUNC
ArrayBase() : Base() {} ArrayBase() : Base() {}
private: private:

View File

@ -48,41 +48,54 @@ class ArrayWrapper : public ArrayBase<ArrayWrapper<ExpressionType> >
typedef typename internal::nested<ExpressionType>::type NestedExpressionType; typedef typename internal::nested<ExpressionType>::type NestedExpressionType;
EIGEN_DEVICE_FUNC
inline ArrayWrapper(ExpressionType& matrix) : m_expression(matrix) {} inline ArrayWrapper(ExpressionType& matrix) : m_expression(matrix) {}
EIGEN_DEVICE_FUNC
inline Index rows() const { return m_expression.rows(); } inline Index rows() const { return m_expression.rows(); }
EIGEN_DEVICE_FUNC
inline Index cols() const { return m_expression.cols(); } inline Index cols() const { return m_expression.cols(); }
EIGEN_DEVICE_FUNC
inline Index outerStride() const { return m_expression.outerStride(); } inline Index outerStride() const { return m_expression.outerStride(); }
EIGEN_DEVICE_FUNC
inline Index innerStride() const { return m_expression.innerStride(); } inline Index innerStride() const { return m_expression.innerStride(); }
EIGEN_DEVICE_FUNC
inline ScalarWithConstIfNotLvalue* data() { return m_expression.const_cast_derived().data(); } inline ScalarWithConstIfNotLvalue* data() { return m_expression.const_cast_derived().data(); }
EIGEN_DEVICE_FUNC
inline const Scalar* data() const { return m_expression.data(); } inline const Scalar* data() const { return m_expression.data(); }
EIGEN_DEVICE_FUNC
inline CoeffReturnType coeff(Index rowId, Index colId) const inline CoeffReturnType coeff(Index rowId, Index colId) const
{ {
return m_expression.coeff(rowId, colId); return m_expression.coeff(rowId, colId);
} }
EIGEN_DEVICE_FUNC
inline Scalar& coeffRef(Index rowId, Index colId) inline Scalar& coeffRef(Index rowId, Index colId)
{ {
return m_expression.const_cast_derived().coeffRef(rowId, colId); return m_expression.const_cast_derived().coeffRef(rowId, colId);
} }
EIGEN_DEVICE_FUNC
inline const Scalar& coeffRef(Index rowId, Index colId) const inline const Scalar& coeffRef(Index rowId, Index colId) const
{ {
return m_expression.const_cast_derived().coeffRef(rowId, colId); return m_expression.const_cast_derived().coeffRef(rowId, colId);
} }
EIGEN_DEVICE_FUNC
inline CoeffReturnType coeff(Index index) const inline CoeffReturnType coeff(Index index) const
{ {
return m_expression.coeff(index); return m_expression.coeff(index);
} }
EIGEN_DEVICE_FUNC
inline Scalar& coeffRef(Index index) inline Scalar& coeffRef(Index index)
{ {
return m_expression.const_cast_derived().coeffRef(index); return m_expression.const_cast_derived().coeffRef(index);
} }
EIGEN_DEVICE_FUNC
inline const Scalar& coeffRef(Index index) const inline const Scalar& coeffRef(Index index) const
{ {
return m_expression.const_cast_derived().coeffRef(index); return m_expression.const_cast_derived().coeffRef(index);
@ -113,9 +126,11 @@ class ArrayWrapper : public ArrayBase<ArrayWrapper<ExpressionType> >
} }
template<typename Dest> template<typename Dest>
EIGEN_DEVICE_FUNC
inline void evalTo(Dest& dst) const { dst = m_expression; } inline void evalTo(Dest& dst) const { dst = m_expression; }
const typename internal::remove_all<NestedExpressionType>::type& const typename internal::remove_all<NestedExpressionType>::type&
EIGEN_DEVICE_FUNC
nestedExpression() const nestedExpression() const
{ {
return m_expression; return m_expression;
@ -123,9 +138,11 @@ class ArrayWrapper : public ArrayBase<ArrayWrapper<ExpressionType> >
/** Forwards the resizing request to the nested expression /** Forwards the resizing request to the nested expression
* \sa DenseBase::resize(Index) */ * \sa DenseBase::resize(Index) */
EIGEN_DEVICE_FUNC
void resize(Index newSize) { m_expression.const_cast_derived().resize(newSize); } void resize(Index newSize) { m_expression.const_cast_derived().resize(newSize); }
/** Forwards the resizing request to the nested expression /** Forwards the resizing request to the nested expression
* \sa DenseBase::resize(Index,Index)*/ * \sa DenseBase::resize(Index,Index)*/
EIGEN_DEVICE_FUNC
void resize(Index nbRows, Index nbCols) { m_expression.const_cast_derived().resize(nbRows,nbCols); } void resize(Index nbRows, Index nbCols) { m_expression.const_cast_derived().resize(nbRows,nbCols); }
protected: protected:
@ -168,41 +185,54 @@ class MatrixWrapper : public MatrixBase<MatrixWrapper<ExpressionType> >
typedef typename internal::nested<ExpressionType>::type NestedExpressionType; typedef typename internal::nested<ExpressionType>::type NestedExpressionType;
EIGEN_DEVICE_FUNC
inline MatrixWrapper(ExpressionType& a_matrix) : m_expression(a_matrix) {} inline MatrixWrapper(ExpressionType& a_matrix) : m_expression(a_matrix) {}
EIGEN_DEVICE_FUNC
inline Index rows() const { return m_expression.rows(); } inline Index rows() const { return m_expression.rows(); }
EIGEN_DEVICE_FUNC
inline Index cols() const { return m_expression.cols(); } inline Index cols() const { return m_expression.cols(); }
EIGEN_DEVICE_FUNC
inline Index outerStride() const { return m_expression.outerStride(); } inline Index outerStride() const { return m_expression.outerStride(); }
EIGEN_DEVICE_FUNC
inline Index innerStride() const { return m_expression.innerStride(); } inline Index innerStride() const { return m_expression.innerStride(); }
EIGEN_DEVICE_FUNC
inline ScalarWithConstIfNotLvalue* data() { return m_expression.const_cast_derived().data(); } inline ScalarWithConstIfNotLvalue* data() { return m_expression.const_cast_derived().data(); }
EIGEN_DEVICE_FUNC
inline const Scalar* data() const { return m_expression.data(); } inline const Scalar* data() const { return m_expression.data(); }
EIGEN_DEVICE_FUNC
inline CoeffReturnType coeff(Index rowId, Index colId) const inline CoeffReturnType coeff(Index rowId, Index colId) const
{ {
return m_expression.coeff(rowId, colId); return m_expression.coeff(rowId, colId);
} }
EIGEN_DEVICE_FUNC
inline Scalar& coeffRef(Index rowId, Index colId) inline Scalar& coeffRef(Index rowId, Index colId)
{ {
return m_expression.const_cast_derived().coeffRef(rowId, colId); return m_expression.const_cast_derived().coeffRef(rowId, colId);
} }
EIGEN_DEVICE_FUNC
inline const Scalar& coeffRef(Index rowId, Index colId) const inline const Scalar& coeffRef(Index rowId, Index colId) const
{ {
return m_expression.derived().coeffRef(rowId, colId); return m_expression.derived().coeffRef(rowId, colId);
} }
EIGEN_DEVICE_FUNC
inline CoeffReturnType coeff(Index index) const inline CoeffReturnType coeff(Index index) const
{ {
return m_expression.coeff(index); return m_expression.coeff(index);
} }
EIGEN_DEVICE_FUNC
inline Scalar& coeffRef(Index index) inline Scalar& coeffRef(Index index)
{ {
return m_expression.const_cast_derived().coeffRef(index); return m_expression.const_cast_derived().coeffRef(index);
} }
EIGEN_DEVICE_FUNC
inline const Scalar& coeffRef(Index index) const inline const Scalar& coeffRef(Index index) const
{ {
return m_expression.const_cast_derived().coeffRef(index); return m_expression.const_cast_derived().coeffRef(index);
@ -232,6 +262,7 @@ class MatrixWrapper : public MatrixBase<MatrixWrapper<ExpressionType> >
m_expression.const_cast_derived().template writePacket<LoadMode>(index, val); m_expression.const_cast_derived().template writePacket<LoadMode>(index, val);
} }
EIGEN_DEVICE_FUNC
const typename internal::remove_all<NestedExpressionType>::type& const typename internal::remove_all<NestedExpressionType>::type&
nestedExpression() const nestedExpression() const
{ {
@ -240,9 +271,11 @@ class MatrixWrapper : public MatrixBase<MatrixWrapper<ExpressionType> >
/** Forwards the resizing request to the nested expression /** Forwards the resizing request to the nested expression
* \sa DenseBase::resize(Index) */ * \sa DenseBase::resize(Index) */
EIGEN_DEVICE_FUNC
void resize(Index newSize) { m_expression.const_cast_derived().resize(newSize); } void resize(Index newSize) { m_expression.const_cast_derived().resize(newSize); }
/** Forwards the resizing request to the nested expression /** Forwards the resizing request to the nested expression
* \sa DenseBase::resize(Index,Index)*/ * \sa DenseBase::resize(Index,Index)*/
EIGEN_DEVICE_FUNC
void resize(Index nbRows, Index nbCols) { m_expression.const_cast_derived().resize(nbRows,nbCols); } void resize(Index nbRows, Index nbCols) { m_expression.const_cast_derived().resize(nbRows,nbCols); }
protected: protected:

View File

@ -235,3 +235,4 @@ MatrixBase<Derived>::operator+=(const MatrixBase<OtherDerived>& other)
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_CWISE_BINARY_OP_H #endif // EIGEN_CWISE_BINARY_OP_H

View File

@ -70,20 +70,25 @@ template<typename MatrixType, int _DiagIndex> class Diagonal
typedef typename internal::dense_xpr_base<Diagonal>::type Base; typedef typename internal::dense_xpr_base<Diagonal>::type Base;
EIGEN_DENSE_PUBLIC_INTERFACE(Diagonal) EIGEN_DENSE_PUBLIC_INTERFACE(Diagonal)
EIGEN_DEVICE_FUNC
inline Diagonal(MatrixType& matrix, Index a_index = DiagIndex) : m_matrix(matrix), m_index(a_index) {} inline Diagonal(MatrixType& matrix, Index a_index = DiagIndex) : m_matrix(matrix), m_index(a_index) {}
EIGEN_INHERIT_ASSIGNMENT_OPERATORS(Diagonal) EIGEN_INHERIT_ASSIGNMENT_OPERATORS(Diagonal)
EIGEN_DEVICE_FUNC
inline Index rows() const inline Index rows() const
{ return m_index.value()<0 ? (std::min<Index>)(m_matrix.cols(),m_matrix.rows()+m_index.value()) : (std::min<Index>)(m_matrix.rows(),m_matrix.cols()-m_index.value()); } { return m_index.value()<0 ? (std::min<Index>)(m_matrix.cols(),m_matrix.rows()+m_index.value()) : (std::min<Index>)(m_matrix.rows(),m_matrix.cols()-m_index.value()); }
EIGEN_DEVICE_FUNC
inline Index cols() const { return 1; } inline Index cols() const { return 1; }
EIGEN_DEVICE_FUNC
inline Index innerStride() const inline Index innerStride() const
{ {
return m_matrix.outerStride() + 1; return m_matrix.outerStride() + 1;
} }
EIGEN_DEVICE_FUNC
inline Index outerStride() const inline Index outerStride() const
{ {
return 0; return 0;
@ -95,47 +100,57 @@ template<typename MatrixType, int _DiagIndex> class Diagonal
const Scalar const Scalar
>::type ScalarWithConstIfNotLvalue; >::type ScalarWithConstIfNotLvalue;
EIGEN_DEVICE_FUNC
inline ScalarWithConstIfNotLvalue* data() { return &(m_matrix.const_cast_derived().coeffRef(rowOffset(), colOffset())); } inline ScalarWithConstIfNotLvalue* data() { return &(m_matrix.const_cast_derived().coeffRef(rowOffset(), colOffset())); }
EIGEN_DEVICE_FUNC
inline const Scalar* data() const { return &(m_matrix.const_cast_derived().coeffRef(rowOffset(), colOffset())); } inline const Scalar* data() const { return &(m_matrix.const_cast_derived().coeffRef(rowOffset(), colOffset())); }
EIGEN_DEVICE_FUNC
inline Scalar& coeffRef(Index row, Index) inline Scalar& coeffRef(Index row, Index)
{ {
EIGEN_STATIC_ASSERT_LVALUE(MatrixType) EIGEN_STATIC_ASSERT_LVALUE(MatrixType)
return m_matrix.const_cast_derived().coeffRef(row+rowOffset(), row+colOffset()); return m_matrix.const_cast_derived().coeffRef(row+rowOffset(), row+colOffset());
} }
EIGEN_DEVICE_FUNC
inline const Scalar& coeffRef(Index row, Index) const inline const Scalar& coeffRef(Index row, Index) const
{ {
return m_matrix.const_cast_derived().coeffRef(row+rowOffset(), row+colOffset()); return m_matrix.const_cast_derived().coeffRef(row+rowOffset(), row+colOffset());
} }
EIGEN_DEVICE_FUNC
inline CoeffReturnType coeff(Index row, Index) const inline CoeffReturnType coeff(Index row, Index) const
{ {
return m_matrix.coeff(row+rowOffset(), row+colOffset()); return m_matrix.coeff(row+rowOffset(), row+colOffset());
} }
EIGEN_DEVICE_FUNC
inline Scalar& coeffRef(Index idx) inline Scalar& coeffRef(Index idx)
{ {
EIGEN_STATIC_ASSERT_LVALUE(MatrixType) EIGEN_STATIC_ASSERT_LVALUE(MatrixType)
return m_matrix.const_cast_derived().coeffRef(idx+rowOffset(), idx+colOffset()); return m_matrix.const_cast_derived().coeffRef(idx+rowOffset(), idx+colOffset());
} }
EIGEN_DEVICE_FUNC
inline const Scalar& coeffRef(Index idx) const inline const Scalar& coeffRef(Index idx) const
{ {
return m_matrix.const_cast_derived().coeffRef(idx+rowOffset(), idx+colOffset()); return m_matrix.const_cast_derived().coeffRef(idx+rowOffset(), idx+colOffset());
} }
EIGEN_DEVICE_FUNC
inline CoeffReturnType coeff(Index idx) const inline CoeffReturnType coeff(Index idx) const
{ {
return m_matrix.coeff(idx+rowOffset(), idx+colOffset()); return m_matrix.coeff(idx+rowOffset(), idx+colOffset());
} }
EIGEN_DEVICE_FUNC
const typename internal::remove_all<typename MatrixType::Nested>::type& const typename internal::remove_all<typename MatrixType::Nested>::type&
nestedExpression() const nestedExpression() const
{ {
return m_matrix; return m_matrix;
} }
EIGEN_DEVICE_FUNC
int index() const int index() const
{ {
return m_index.value(); return m_index.value();
@ -147,8 +162,11 @@ template<typename MatrixType, int _DiagIndex> class Diagonal
private: private:
// some compilers may fail to optimize std::max etc in case of compile-time constants... // some compilers may fail to optimize std::max etc in case of compile-time constants...
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Index absDiagIndex() const { return m_index.value()>0 ? m_index.value() : -m_index.value(); } EIGEN_STRONG_INLINE Index absDiagIndex() const { return m_index.value()>0 ? m_index.value() : -m_index.value(); }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Index rowOffset() const { return m_index.value()>0 ? 0 : -m_index.value(); } EIGEN_STRONG_INLINE Index rowOffset() const { return m_index.value()>0 ? 0 : -m_index.value(); }
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Index colOffset() const { return m_index.value()>0 ? m_index.value() : 0; } EIGEN_STRONG_INLINE Index colOffset() const { return m_index.value()>0 ? m_index.value() : 0; }
// triger a compile time error is someone try to call packet // triger a compile time error is someone try to call packet
template<int LoadMode> typename MatrixType::PacketReturnType packet(Index) const; template<int LoadMode> typename MatrixType::PacketReturnType packet(Index) const;

View File

@ -37,45 +37,59 @@ class DiagonalBase : public EigenBase<Derived>
typedef DenseMatrixType DenseType; typedef DenseMatrixType DenseType;
typedef DiagonalMatrix<Scalar,DiagonalVectorType::SizeAtCompileTime,DiagonalVectorType::MaxSizeAtCompileTime> PlainObject; typedef DiagonalMatrix<Scalar,DiagonalVectorType::SizeAtCompileTime,DiagonalVectorType::MaxSizeAtCompileTime> PlainObject;
EIGEN_DEVICE_FUNC
inline const Derived& derived() const { return *static_cast<const Derived*>(this); } inline const Derived& derived() const { return *static_cast<const Derived*>(this); }
EIGEN_DEVICE_FUNC
inline Derived& derived() { return *static_cast<Derived*>(this); } inline Derived& derived() { return *static_cast<Derived*>(this); }
EIGEN_DEVICE_FUNC
DenseMatrixType toDenseMatrix() const { return derived(); } DenseMatrixType toDenseMatrix() const { return derived(); }
template<typename DenseDerived> template<typename DenseDerived>
EIGEN_DEVICE_FUNC
void evalTo(MatrixBase<DenseDerived> &other) const; void evalTo(MatrixBase<DenseDerived> &other) const;
template<typename DenseDerived> template<typename DenseDerived>
EIGEN_DEVICE_FUNC
void addTo(MatrixBase<DenseDerived> &other) const void addTo(MatrixBase<DenseDerived> &other) const
{ other.diagonal() += diagonal(); } { other.diagonal() += diagonal(); }
template<typename DenseDerived> template<typename DenseDerived>
EIGEN_DEVICE_FUNC
void subTo(MatrixBase<DenseDerived> &other) const void subTo(MatrixBase<DenseDerived> &other) const
{ other.diagonal() -= diagonal(); } { other.diagonal() -= diagonal(); }
EIGEN_DEVICE_FUNC
inline const DiagonalVectorType& diagonal() const { return derived().diagonal(); } inline const DiagonalVectorType& diagonal() const { return derived().diagonal(); }
EIGEN_DEVICE_FUNC
inline DiagonalVectorType& diagonal() { return derived().diagonal(); } inline DiagonalVectorType& diagonal() { return derived().diagonal(); }
EIGEN_DEVICE_FUNC
inline Index rows() const { return diagonal().size(); } inline Index rows() const { return diagonal().size(); }
EIGEN_DEVICE_FUNC
inline Index cols() const { return diagonal().size(); } inline Index cols() const { return diagonal().size(); }
/** \returns the diagonal matrix product of \c *this by the matrix \a matrix. /** \returns the diagonal matrix product of \c *this by the matrix \a matrix.
*/ */
template<typename MatrixDerived> template<typename MatrixDerived>
EIGEN_DEVICE_FUNC
const DiagonalProduct<MatrixDerived, Derived, OnTheLeft> const DiagonalProduct<MatrixDerived, Derived, OnTheLeft>
operator*(const MatrixBase<MatrixDerived> &matrix) const operator*(const MatrixBase<MatrixDerived> &matrix) const
{ {
return DiagonalProduct<MatrixDerived, Derived, OnTheLeft>(matrix.derived(), derived()); return DiagonalProduct<MatrixDerived, Derived, OnTheLeft>(matrix.derived(), derived());
} }
EIGEN_DEVICE_FUNC
inline const DiagonalWrapper<const CwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const DiagonalVectorType> > inline const DiagonalWrapper<const CwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const DiagonalVectorType> >
inverse() const inverse() const
{ {
return diagonal().cwiseInverse(); return diagonal().cwiseInverse();
} }
EIGEN_DEVICE_FUNC
inline const DiagonalWrapper<const CwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const DiagonalVectorType> > inline const DiagonalWrapper<const CwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const DiagonalVectorType> >
operator*(const Scalar& scalar) const operator*(const Scalar& scalar) const
{ {
return diagonal() * scalar; return diagonal() * scalar;
} }
EIGEN_DEVICE_FUNC
friend inline const DiagonalWrapper<const CwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const DiagonalVectorType> > friend inline const DiagonalWrapper<const CwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const DiagonalVectorType> >
operator*(const Scalar& scalar, const DiagonalBase& other) operator*(const Scalar& scalar, const DiagonalBase& other)
{ {
@ -84,11 +98,13 @@ class DiagonalBase : public EigenBase<Derived>
#ifdef EIGEN2_SUPPORT #ifdef EIGEN2_SUPPORT
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_DEVICE_FUNC
bool isApprox(const DiagonalBase<OtherDerived>& other, typename NumTraits<Scalar>::Real precision = NumTraits<Scalar>::dummy_precision()) const bool isApprox(const DiagonalBase<OtherDerived>& other, typename NumTraits<Scalar>::Real precision = NumTraits<Scalar>::dummy_precision()) const
{ {
return diagonal().isApprox(other.diagonal(), precision); return diagonal().isApprox(other.diagonal(), precision);
} }
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_DEVICE_FUNC
bool isApprox(const MatrixBase<OtherDerived>& other, typename NumTraits<Scalar>::Real precision = NumTraits<Scalar>::dummy_precision()) const bool isApprox(const MatrixBase<OtherDerived>& other, typename NumTraits<Scalar>::Real precision = NumTraits<Scalar>::dummy_precision()) const
{ {
return toDenseMatrix().isApprox(other, precision); return toDenseMatrix().isApprox(other, precision);
@ -151,24 +167,31 @@ class DiagonalMatrix
public: public:
/** const version of diagonal(). */ /** const version of diagonal(). */
EIGEN_DEVICE_FUNC
inline const DiagonalVectorType& diagonal() const { return m_diagonal; } inline const DiagonalVectorType& diagonal() const { return m_diagonal; }
/** \returns a reference to the stored vector of diagonal coefficients. */ /** \returns a reference to the stored vector of diagonal coefficients. */
EIGEN_DEVICE_FUNC
inline DiagonalVectorType& diagonal() { return m_diagonal; } inline DiagonalVectorType& diagonal() { return m_diagonal; }
/** Default constructor without initialization */ /** Default constructor without initialization */
EIGEN_DEVICE_FUNC
inline DiagonalMatrix() {} inline DiagonalMatrix() {}
/** Constructs a diagonal matrix with given dimension */ /** Constructs a diagonal matrix with given dimension */
EIGEN_DEVICE_FUNC
inline DiagonalMatrix(Index dim) : m_diagonal(dim) {} inline DiagonalMatrix(Index dim) : m_diagonal(dim) {}
/** 2D constructor. */ /** 2D constructor. */
EIGEN_DEVICE_FUNC
inline DiagonalMatrix(const Scalar& x, const Scalar& y) : m_diagonal(x,y) {} inline DiagonalMatrix(const Scalar& x, const Scalar& y) : m_diagonal(x,y) {}
/** 3D constructor. */ /** 3D constructor. */
EIGEN_DEVICE_FUNC
inline DiagonalMatrix(const Scalar& x, const Scalar& y, const Scalar& z) : m_diagonal(x,y,z) {} inline DiagonalMatrix(const Scalar& x, const Scalar& y, const Scalar& z) : m_diagonal(x,y,z) {}
/** Copy constructor. */ /** Copy constructor. */
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline DiagonalMatrix(const DiagonalBase<OtherDerived>& other) : m_diagonal(other.diagonal()) {} inline DiagonalMatrix(const DiagonalBase<OtherDerived>& other) : m_diagonal(other.diagonal()) {}
#ifndef EIGEN_PARSED_BY_DOXYGEN #ifndef EIGEN_PARSED_BY_DOXYGEN
@ -178,11 +201,13 @@ class DiagonalMatrix
/** generic constructor from expression of the diagonal coefficients */ /** generic constructor from expression of the diagonal coefficients */
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_DEVICE_FUNC
explicit inline DiagonalMatrix(const MatrixBase<OtherDerived>& other) : m_diagonal(other) explicit inline DiagonalMatrix(const MatrixBase<OtherDerived>& other) : m_diagonal(other)
{} {}
/** Copy operator. */ /** Copy operator. */
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_DEVICE_FUNC
DiagonalMatrix& operator=(const DiagonalBase<OtherDerived>& other) DiagonalMatrix& operator=(const DiagonalBase<OtherDerived>& other)
{ {
m_diagonal = other.diagonal(); m_diagonal = other.diagonal();
@ -193,6 +218,7 @@ class DiagonalMatrix
/** This is a special case of the templated operator=. Its purpose is to /** This is a special case of the templated operator=. Its purpose is to
* prevent a default operator= from hiding the templated operator=. * prevent a default operator= from hiding the templated operator=.
*/ */
EIGEN_DEVICE_FUNC
DiagonalMatrix& operator=(const DiagonalMatrix& other) DiagonalMatrix& operator=(const DiagonalMatrix& other)
{ {
m_diagonal = other.diagonal(); m_diagonal = other.diagonal();
@ -201,14 +227,19 @@ class DiagonalMatrix
#endif #endif
/** Resizes to given size. */ /** Resizes to given size. */
EIGEN_DEVICE_FUNC
inline void resize(Index size) { m_diagonal.resize(size); } inline void resize(Index size) { m_diagonal.resize(size); }
/** Sets all coefficients to zero. */ /** Sets all coefficients to zero. */
EIGEN_DEVICE_FUNC
inline void setZero() { m_diagonal.setZero(); } inline void setZero() { m_diagonal.setZero(); }
/** Resizes and sets all coefficients to zero. */ /** Resizes and sets all coefficients to zero. */
EIGEN_DEVICE_FUNC
inline void setZero(Index size) { m_diagonal.setZero(size); } inline void setZero(Index size) { m_diagonal.setZero(size); }
/** Sets this matrix to be the identity matrix of the current size. */ /** Sets this matrix to be the identity matrix of the current size. */
EIGEN_DEVICE_FUNC
inline void setIdentity() { m_diagonal.setOnes(); } inline void setIdentity() { m_diagonal.setOnes(); }
/** Sets this matrix to be the identity matrix of the given size. */ /** Sets this matrix to be the identity matrix of the given size. */
EIGEN_DEVICE_FUNC
inline void setIdentity(Index size) { m_diagonal.setOnes(size); } inline void setIdentity(Index size) { m_diagonal.setOnes(size); }
}; };
@ -255,9 +286,11 @@ class DiagonalWrapper
#endif #endif
/** Constructor from expression of diagonal coefficients to wrap. */ /** Constructor from expression of diagonal coefficients to wrap. */
EIGEN_DEVICE_FUNC
inline DiagonalWrapper(DiagonalVectorType& a_diagonal) : m_diagonal(a_diagonal) {} inline DiagonalWrapper(DiagonalVectorType& a_diagonal) : m_diagonal(a_diagonal) {}
/** \returns a const reference to the wrapped expression of diagonal coefficients. */ /** \returns a const reference to the wrapped expression of diagonal coefficients. */
EIGEN_DEVICE_FUNC
const DiagonalVectorType& diagonal() const { return m_diagonal; } const DiagonalVectorType& diagonal() const { return m_diagonal; }
protected: protected:

View File

@ -98,6 +98,7 @@ template<typename Derived> class MatrixBase
/** \returns the size of the main diagonal, which is min(rows(),cols()). /** \returns the size of the main diagonal, which is min(rows(),cols()).
* \sa rows(), cols(), SizeAtCompileTime. */ * \sa rows(), cols(), SizeAtCompileTime. */
EIGEN_DEVICE_FUNC
inline Index diagonalSize() const { return (std::min)(rows(),cols()); } inline Index diagonalSize() const { return (std::min)(rows(),cols()); }
/** \brief The plain matrix type corresponding to this expression. /** \brief The plain matrix type corresponding to this expression.
@ -206,6 +207,7 @@ template<typename Derived> class MatrixBase
void applyOnTheRight(const EigenBase<OtherDerived>& other); void applyOnTheRight(const EigenBase<OtherDerived>& other);
template<typename DiagonalDerived> template<typename DiagonalDerived>
EIGEN_DEVICE_FUNC
const DiagonalProduct<Derived, DiagonalDerived, OnTheRight> const DiagonalProduct<Derived, DiagonalDerived, OnTheRight>
operator*(const DiagonalBase<DiagonalDerived> &diagonal) const; operator*(const DiagonalBase<DiagonalDerived> &diagonal) const;
@ -231,15 +233,23 @@ template<typename Derived> class MatrixBase
EIGEN_DEVICE_FUNC void adjointInPlace(); EIGEN_DEVICE_FUNC void adjointInPlace();
typedef Diagonal<Derived> DiagonalReturnType; typedef Diagonal<Derived> DiagonalReturnType;
EIGEN_DEVICE_FUNC
DiagonalReturnType diagonal(); DiagonalReturnType diagonal();
typedef typename internal::add_const<Diagonal<const Derived> >::type ConstDiagonalReturnType;
typedef typename internal::add_const<Diagonal<const Derived> >::type ConstDiagonalReturnType;
EIGEN_DEVICE_FUNC
ConstDiagonalReturnType diagonal() const; ConstDiagonalReturnType diagonal() const;
template<int Index> struct DiagonalIndexReturnType { typedef Diagonal<Derived,Index> Type; }; template<int Index> struct DiagonalIndexReturnType { typedef Diagonal<Derived,Index> Type; };
template<int Index> struct ConstDiagonalIndexReturnType { typedef const Diagonal<const Derived,Index> Type; }; template<int Index> struct ConstDiagonalIndexReturnType { typedef const Diagonal<const Derived,Index> Type; };
template<int Index> typename DiagonalIndexReturnType<Index>::Type diagonal(); template<int Index>
template<int Index> typename ConstDiagonalIndexReturnType<Index>::Type diagonal() const; EIGEN_DEVICE_FUNC
typename DiagonalIndexReturnType<Index>::Type diagonal();
template<int Index>
EIGEN_DEVICE_FUNC
typename ConstDiagonalIndexReturnType<Index>::Type diagonal() const;
// Note: The "MatrixBase::" prefixes are added to help MSVC9 to match these declarations with the later implementations. // Note: The "MatrixBase::" prefixes are added to help MSVC9 to match these declarations with the later implementations.
// On the other hand they confuse MSVC8... // On the other hand they confuse MSVC8...
@ -247,7 +257,10 @@ template<typename Derived> class MatrixBase
typename MatrixBase::template DiagonalIndexReturnType<DynamicIndex>::Type diagonal(Index index); typename MatrixBase::template DiagonalIndexReturnType<DynamicIndex>::Type diagonal(Index index);
typename MatrixBase::template ConstDiagonalIndexReturnType<DynamicIndex>::Type diagonal(Index index) const; typename MatrixBase::template ConstDiagonalIndexReturnType<DynamicIndex>::Type diagonal(Index index) const;
#else #else
EIGEN_DEVICE_FUNC
typename DiagonalIndexReturnType<DynamicIndex>::Type diagonal(Index index); typename DiagonalIndexReturnType<DynamicIndex>::Type diagonal(Index index);
EIGEN_DEVICE_FUNC
typename ConstDiagonalIndexReturnType<DynamicIndex>::Type diagonal(Index index) const; typename ConstDiagonalIndexReturnType<DynamicIndex>::Type diagonal(Index index) const;
#endif #endif
@ -285,6 +298,7 @@ template<typename Derived> class MatrixBase
static const BasisReturnType UnitZ(); static const BasisReturnType UnitZ();
static const BasisReturnType UnitW(); static const BasisReturnType UnitW();
EIGEN_DEVICE_FUNC
const DiagonalWrapper<const Derived> asDiagonal() const; const DiagonalWrapper<const Derived> asDiagonal() const;
const PermutationWrapper<const Derived> asPermutation() const; const PermutationWrapper<const Derived> asPermutation() const;

View File

@ -3,6 +3,7 @@
* \sa MatrixBase::cwiseProduct * \sa MatrixBase::cwiseProduct
*/ */
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const EIGEN_CWISE_PRODUCT_RETURN_TYPE(Derived,OtherDerived) EIGEN_STRONG_INLINE const EIGEN_CWISE_PRODUCT_RETURN_TYPE(Derived,OtherDerived)
operator*(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const operator*(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
{ {
@ -14,6 +15,7 @@ operator*(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
* \sa MatrixBase::cwiseQuotient * \sa MatrixBase::cwiseQuotient
*/ */
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const Derived, const OtherDerived> EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const Derived, const OtherDerived>
operator/(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const operator/(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
{ {
@ -33,6 +35,7 @@ EIGEN_MAKE_CWISE_BINARY_OP(min,internal::scalar_min_op)
* *
* \sa max() * \sa max()
*/ */
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived, EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_min_op<Scalar>, const Derived,
const CwiseNullaryOp<internal::scalar_constant_op<Scalar>, PlainObject> > const CwiseNullaryOp<internal::scalar_constant_op<Scalar>, PlainObject> >
#ifdef EIGEN_PARSED_BY_DOXYGEN #ifdef EIGEN_PARSED_BY_DOXYGEN
@ -58,6 +61,7 @@ EIGEN_MAKE_CWISE_BINARY_OP(max,internal::scalar_max_op)
* *
* \sa min() * \sa min()
*/ */
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived, EIGEN_STRONG_INLINE const CwiseBinaryOp<internal::scalar_max_op<Scalar>, const Derived,
const CwiseNullaryOp<internal::scalar_constant_op<Scalar>, PlainObject> > const CwiseNullaryOp<internal::scalar_constant_op<Scalar>, PlainObject> >
#ifdef EIGEN_PARSED_BY_DOXYGEN #ifdef EIGEN_PARSED_BY_DOXYGEN
@ -143,12 +147,14 @@ EIGEN_MAKE_CWISE_BINARY_OP(operator!=,std::not_equal_to)
* *
* \sa operator+=(), operator-() * \sa operator+=(), operator-()
*/ */
EIGEN_DEVICE_FUNC
inline const CwiseUnaryOp<internal::scalar_add_op<Scalar>, const Derived> inline const CwiseUnaryOp<internal::scalar_add_op<Scalar>, const Derived>
operator+(const Scalar& scalar) const operator+(const Scalar& scalar) const
{ {
return CwiseUnaryOp<internal::scalar_add_op<Scalar>, const Derived>(derived(), internal::scalar_add_op<Scalar>(scalar)); return CwiseUnaryOp<internal::scalar_add_op<Scalar>, const Derived>(derived(), internal::scalar_add_op<Scalar>(scalar));
} }
EIGEN_DEVICE_FUNC
friend inline const CwiseUnaryOp<internal::scalar_add_op<Scalar>, const Derived> friend inline const CwiseUnaryOp<internal::scalar_add_op<Scalar>, const Derived>
operator+(const Scalar& scalar,const EIGEN_CURRENT_STORAGE_BASE_CLASS<Derived>& other) operator+(const Scalar& scalar,const EIGEN_CURRENT_STORAGE_BASE_CLASS<Derived>& other)
{ {
@ -162,12 +168,14 @@ operator+(const Scalar& scalar,const EIGEN_CURRENT_STORAGE_BASE_CLASS<Derived>&
* *
* \sa operator+(), operator-=() * \sa operator+(), operator-=()
*/ */
EIGEN_DEVICE_FUNC
inline const CwiseUnaryOp<internal::scalar_add_op<Scalar>, const Derived> inline const CwiseUnaryOp<internal::scalar_add_op<Scalar>, const Derived>
operator-(const Scalar& scalar) const operator-(const Scalar& scalar) const
{ {
return *this + (-scalar); return *this + (-scalar);
} }
EIGEN_DEVICE_FUNC
friend inline const CwiseUnaryOp<internal::scalar_add_op<Scalar>, const CwiseUnaryOp<internal::scalar_opposite_op<Scalar>, const Derived> > friend inline const CwiseUnaryOp<internal::scalar_add_op<Scalar>, const CwiseUnaryOp<internal::scalar_opposite_op<Scalar>, const Derived> >
operator-(const Scalar& scalar,const EIGEN_CURRENT_STORAGE_BASE_CLASS<Derived>& other) operator-(const Scalar& scalar,const EIGEN_CURRENT_STORAGE_BASE_CLASS<Derived>& other)
{ {
@ -184,6 +192,7 @@ operator-(const Scalar& scalar,const EIGEN_CURRENT_STORAGE_BASE_CLASS<Derived>&
* \sa operator||(), select() * \sa operator||(), select()
*/ */
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<internal::scalar_boolean_and_op, const Derived, const OtherDerived> inline const CwiseBinaryOp<internal::scalar_boolean_and_op, const Derived, const OtherDerived>
operator&&(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const operator&&(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
{ {
@ -202,6 +211,7 @@ operator&&(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
* \sa operator&&(), select() * \sa operator&&(), select()
*/ */
template<typename OtherDerived> template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<internal::scalar_boolean_or_op, const Derived, const OtherDerived> inline const CwiseBinaryOp<internal::scalar_boolean_or_op, const Derived, const OtherDerived>
operator||(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const operator||(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
{ {
@ -209,3 +219,4 @@ operator||(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
THIS_METHOD_IS_ONLY_FOR_EXPRESSIONS_OF_BOOL); THIS_METHOD_IS_ONLY_FOR_EXPRESSIONS_OF_BOOL);
return CwiseBinaryOp<internal::scalar_boolean_or_op, const Derived, const OtherDerived>(derived(),other.derived()); return CwiseBinaryOp<internal::scalar_boolean_or_op, const Derived, const OtherDerived>(derived(),other.derived());
} }

View File

@ -201,6 +201,7 @@ cube() const
} }
#define EIGEN_MAKE_SCALAR_CWISE_UNARY_OP(METHOD_NAME,FUNCTOR) \ #define EIGEN_MAKE_SCALAR_CWISE_UNARY_OP(METHOD_NAME,FUNCTOR) \
EIGEN_DEVICE_FUNC \
inline const CwiseUnaryOp<std::binder2nd<FUNCTOR<Scalar> >, const Derived> \ inline const CwiseUnaryOp<std::binder2nd<FUNCTOR<Scalar> >, const Derived> \
METHOD_NAME(const Scalar& s) const { \ METHOD_NAME(const Scalar& s) const { \
return CwiseUnaryOp<std::binder2nd<FUNCTOR<Scalar> >, const Derived> \ return CwiseUnaryOp<std::binder2nd<FUNCTOR<Scalar> >, const Derived> \