From 59dc1da5bf6b5c57205ee8d695e0f6e852ca74ae Mon Sep 17 00:00:00 2001 From: Gael Guennebaud Date: Wed, 3 Sep 2008 17:16:28 +0000 Subject: [PATCH] Add a Select expression in the Array module which mimics a coeff-wise ?: operator. Example: mat = (mat.cwise().abs().cwise() < Ones()).select(0,mat); replaces all small values by 0. (the scalar version is "s = abs(s)<1 ? 0 : s") --- Eigen/Array | 1 + Eigen/src/Array/Select.h | 153 ++++++++++++++++++++++ Eigen/src/Core/MatrixBase.h | 12 ++ Eigen/src/Core/util/ForwardDeclarations.h | 7 +- test/array.cpp | 30 ++++- 5 files changed, 194 insertions(+), 9 deletions(-) create mode 100644 Eigen/src/Array/Select.h diff --git a/Eigen/Array b/Eigen/Array index 74d8fa888..fbfffe36c 100644 --- a/Eigen/Array +++ b/Eigen/Array @@ -25,6 +25,7 @@ namespace Eigen { #include "src/Array/CwiseOperators.h" #include "src/Array/Functors.h" #include "src/Array/AllAndAny.h" +#include "src/Array/Select.h" #include "src/Array/PartialRedux.h" #include "src/Array/Random.h" diff --git a/Eigen/src/Array/Select.h b/Eigen/src/Array/Select.h new file mode 100644 index 000000000..65feb42b2 --- /dev/null +++ b/Eigen/src/Array/Select.h @@ -0,0 +1,153 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. Eigen itself is part of the KDE project. +// +// Copyright (C) 2008 Gael Guennebaud +// +// Eigen is free software; you can redistribute it and/or +// modify it under the terms of the GNU Lesser General Public +// License as published by the Free Software Foundation; either +// version 3 of the License, or (at your option) any later version. +// +// Alternatively, you can redistribute it and/or +// modify it under the terms of the GNU General Public License as +// published by the Free Software Foundation; either version 2 of +// the License, or (at your option) any later version. +// +// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY +// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public +// License and a copy of the GNU General Public License along with +// Eigen. If not, see . + +#ifndef EIGEN_SELECT_H +#define EIGEN_SELECT_H + +/** \array_module \ingroup Array + * + * \class Select + * + * \brief Expression of a coefficient wise version of the C++ ternary operator ?: + * + * \param ConditionMatrixType the type of the \em condition expression which must be a boolean matrix + * \param ThenMatrixType the type of the \em then expression + * \param ElseMatrixType the type of the \em else expression + * + * This class represents an expression of a coefficient wise version of the C++ ternary operator ?:. + * It is the return type of MatrixBase::select() and most of the time this is the only way it is used. + * + * \sa MatrixBase::select(const MatrixBase&, const MatrixBase&) const + */ + +template +struct ei_traits > +{ + typedef typename ei_traits::Scalar Scalar; + enum { + RowsAtCompileTime = ConditionMatrixType::RowsAtCompileTime, + ColsAtCompileTime = ConditionMatrixType::ColsAtCompileTime, + MaxRowsAtCompileTime = ConditionMatrixType::MaxRowsAtCompileTime, + MaxColsAtCompileTime = ConditionMatrixType::MaxColsAtCompileTime, + Flags = (unsigned int)ThenMatrixType::Flags & ElseMatrixType::Flags & HereditaryBits, + CoeffReadCost = ei_traits::CoeffReadCost + + EIGEN_ENUM_MAX(ei_traits::CoeffReadCost, + ei_traits::CoeffReadCost) + }; +}; + +template +class Select : ei_no_assignment_operator, + public MatrixBase > +{ + public: + + EIGEN_GENERIC_PUBLIC_INTERFACE(Select) + + Select(const ConditionMatrixType& conditionMatrix, + const ThenMatrixType& thenMatrix, + const ElseMatrixType& elseMatrix) + : m_condition(conditionMatrix), m_then(thenMatrix), m_else(elseMatrix) + { + ei_assert(m_condition.rows() == m_then.rows() && m_condition.rows() == m_else.rows()); + ei_assert(m_condition.cols() == m_then.cols() && m_condition.cols() == m_else.cols()); + } + + int rows() const { return m_condition.rows(); } + int cols() const { return m_condition.cols(); } + + const Scalar coeff(int i, int j) const + { + if (m_condition.coeff(i,j)) + return m_then.coeff(i,j); + else + return m_else.coeff(i,j); + } + + const Scalar coeff(int i) const + { + if (m_condition.coeff(i)) + return m_then.coeff(i); + else + return m_else.coeff(i); + } + + protected: + const typename ConditionMatrixType::Nested m_condition; + const typename ThenMatrixType::Nested m_then; + const typename ElseMatrixType::Nested m_else; +}; + + +/** \array_module + * + * \returns a matrix where each coefficient (i,j) is equal to \a thenMatrix(i,j) + * if \c *this(i,j), and \a elseMatrix(i,j) otherwise. + * + * \sa class Select + */ +template +template +inline const Select +MatrixBase::select(const MatrixBase& thenMatrix, + const MatrixBase& elseMatrix) const +{ + return Select(derived(), thenMatrix.derived(), elseMatrix.derived()); +} + +/** \array_module + * + * Version of MatrixBase::select(const MatrixBase&, const MatrixBase&) with + * the \em else expression being a scalar value. + * + * \sa MatrixBase::select(const MatrixBase&, const MatrixBase&) const, class Select + */ +template +template +inline const Select > +MatrixBase::select(const MatrixBase& thenMatrix, + typename ThenDerived::Scalar elseScalar) const +{ + return Select >( + derived(), thenMatrix.derived(), ThenDerived::Constant(rows(),cols(),elseScalar)); +} + +/** \array_module + * + * Version of MatrixBase::select(const MatrixBase&, const MatrixBase&) with + * the \em then expression being a scalar value. + * + * \sa MatrixBase::select(const MatrixBase&, const MatrixBase&) const, class Select + */ +template +template +inline const Select, ElseDerived > +MatrixBase::select(typename ElseDerived::Scalar thenScalar, + const MatrixBase& elseMatrix) const +{ + return Select,ElseDerived>( + derived(), ElseDerived::Constant(rows(),cols(),thenScalar), elseMatrix.derived()); +} + +#endif // EIGEN_SELECT_H diff --git a/Eigen/src/Core/MatrixBase.h b/Eigen/src/Core/MatrixBase.h index 4248ae523..e6af97bb3 100644 --- a/Eigen/src/Core/MatrixBase.h +++ b/Eigen/src/Core/MatrixBase.h @@ -539,6 +539,18 @@ template class MatrixBase static const CwiseNullaryOp,Derived> Random(int size); static const CwiseNullaryOp,Derived> Random(); + template + const Select + select(const MatrixBase& thenMatrix, + const MatrixBase& elseMatrix) const; + + template + inline const Select > + select(const MatrixBase& thenMatrix, typename ThenDerived::Scalar elseScalar) const; + + template + inline const Select, ElseDerived > + select(typename ElseDerived::Scalar thenScalar, const MatrixBase& elseMatrix) const; /////////// LU module /////////// diff --git a/Eigen/src/Core/util/ForwardDeclarations.h b/Eigen/src/Core/util/ForwardDeclarations.h index 92852cbfa..8a8beb9b7 100644 --- a/Eigen/src/Core/util/ForwardDeclarations.h +++ b/Eigen/src/Core/util/ForwardDeclarations.h @@ -50,8 +50,6 @@ template class Map; template class Part; template class Extract; template class Cwise; -template class PartialRedux; -template class PartialReduxExpr; template class WithFormat; template struct ei_product_mode; @@ -94,6 +92,11 @@ void ei_cache_friendly_product( bool _rhsRowMajor, const Scalar* _rhs, int _rhsStride, bool resRowMajor, Scalar* res, int resStride); +// Array module +template class Select; +template class PartialReduxExpr; +template class PartialRedux; + template class LU; template class QR; template class SVD; diff --git a/test/array.cpp b/test/array.cpp index f0b09051f..0fa13b65b 100644 --- a/test/array.cpp +++ b/test/array.cpp @@ -25,7 +25,7 @@ #include "main.h" #include -template void scalarAdd(const MatrixType& m) +template void array(const MatrixType& m) { /* this test covers the following files: Array.cpp @@ -45,6 +45,7 @@ template void scalarAdd(const MatrixType& m) Scalar s1 = ei_random(), s2 = ei_random(); + // scalar addition VERIFY_IS_APPROX(m1.cwise() + s1, s1 + m1.cwise()); VERIFY_IS_APPROX(m1.cwise() + s1, MatrixType::Constant(rows,cols,s1) + m1); VERIFY_IS_APPROX((m1*Scalar(2)).cwise() - s2, (m1+m1) - MatrixType::Constant(rows,cols,s2) ); @@ -55,6 +56,7 @@ template void scalarAdd(const MatrixType& m) m3.cwise() -= s1; VERIFY_IS_APPROX(m3, m1.cwise() - s1); + // reductions VERIFY_IS_APPROX(m1.colwise().sum().sum(), m1.sum()); VERIFY_IS_APPROX(m1.rowwise().sum().sum(), m1.sum()); if (!ei_isApprox(m1.sum(), (m1+m2).sum())) @@ -86,17 +88,31 @@ template void comparisons(const MatrixType& m) VERIFY(! (m1.cwise() < m3).all() ); VERIFY(! (m1.cwise() > m3).all() ); } + + // test Select + VERIFY_IS_APPROX( (m1.cwise()m2).select(m1,m2), m1.cwise().max(m2) ); + Scalar mid = (m1.cwise().abs().minCoeff() + m1.cwise().abs().maxCoeff())/Scalar(2); + for (int j=0; j=MatrixType::Constant(rows,cols,mid)) + .select(m1,0), m3); } void test_array() { for(int i = 0; i < g_repeat; i++) { - CALL_SUBTEST( scalarAdd(Matrix()) ); - CALL_SUBTEST( scalarAdd(Matrix2f()) ); - CALL_SUBTEST( scalarAdd(Matrix4d()) ); - CALL_SUBTEST( scalarAdd(MatrixXcf(3, 3)) ); - CALL_SUBTEST( scalarAdd(MatrixXf(8, 12)) ); - CALL_SUBTEST( scalarAdd(MatrixXi(8, 12)) ); + CALL_SUBTEST( array(Matrix()) ); + CALL_SUBTEST( array(Matrix2f()) ); + CALL_SUBTEST( array(Matrix4d()) ); + CALL_SUBTEST( array(MatrixXcf(3, 3)) ); + CALL_SUBTEST( array(MatrixXf(8, 12)) ); + CALL_SUBTEST( array(MatrixXi(8, 12)) ); } for(int i = 0; i < g_repeat; i++) { CALL_SUBTEST( comparisons(Matrix()) );