vectorize comparisons and select by enabling typed comparisons

This commit is contained in:
Charles Schlosser 2023-02-25 20:52:11 +00:00 committed by Rasmus Munk Larsen
parent 2e9b945baf
commit 826627f653
11 changed files with 463 additions and 185 deletions

View File

@ -567,18 +567,28 @@ template<typename Derived> class DenseBase
static const RandomReturnType Random(Index size);
static const RandomReturnType Random();
template<typename ThenDerived,typename ElseDerived>
inline EIGEN_DEVICE_FUNC const Select<Derived,ThenDerived,ElseDerived>
select(const DenseBase<ThenDerived>& thenMatrix,
const DenseBase<ElseDerived>& elseMatrix) const;
template <typename ThenDerived, typename ElseDerived>
inline EIGEN_DEVICE_FUNC
CwiseTernaryOp<internal::scalar_boolean_select_op<Scalar, typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar>,
Derived, ThenDerived, ElseDerived>
select(const DenseBase<ThenDerived>& thenMatrix, const DenseBase<ElseDerived>& elseMatrix) const;
template<typename ThenDerived>
inline EIGEN_DEVICE_FUNC const Select<Derived,ThenDerived, typename ThenDerived::ConstantReturnType>
select(const DenseBase<ThenDerived>& thenMatrix, const typename ThenDerived::Scalar& elseScalar) const;
template <typename ThenDerived>
inline EIGEN_DEVICE_FUNC
CwiseTernaryOp<internal::scalar_boolean_select_op<Scalar, typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ThenDerived>::Scalar>,
Derived, ThenDerived, typename DenseBase<ThenDerived>::ConstantReturnType>
select(const DenseBase<ThenDerived>& thenMatrix,
const typename DenseBase<ThenDerived>::Scalar& elseScalar) const;
template<typename ElseDerived>
inline EIGEN_DEVICE_FUNC const Select<Derived, typename ElseDerived::ConstantReturnType, ElseDerived >
select(const typename ElseDerived::Scalar& thenScalar, const DenseBase<ElseDerived>& elseMatrix) const;
template <typename ElseDerived>
inline EIGEN_DEVICE_FUNC
CwiseTernaryOp<internal::scalar_boolean_select_op<Scalar, typename DenseBase<ElseDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar>,
Derived, typename DenseBase<ElseDerived>::ConstantReturnType, ElseDerived>
select(const typename DenseBase<ElseDerived>::Scalar& thenScalar,
const DenseBase<ElseDerived>& elseMatrix) const;
template<int p> RealScalar lpNorm() const;

View File

@ -26,7 +26,7 @@ struct isApprox_selector
{
typename internal::nested_eval<Derived,2>::type nested(x);
typename internal::nested_eval<OtherDerived,2>::type otherNested(y);
return (nested - otherNested).cwiseAbs2().sum() <= prec * prec * numext::mini(nested.cwiseAbs2().sum(), otherNested.cwiseAbs2().sum());
return (nested.matrix() - otherNested.matrix()).cwiseAbs2().sum() <= prec * prec * numext::mini(nested.cwiseAbs2().sum(), otherNested.cwiseAbs2().sum());
}
};

View File

@ -113,52 +113,63 @@ class Select : public internal::dense_xpr_base< Select<ConditionMatrixType, Then
typename ElseMatrixType::Nested m_else;
};
/** \returns a matrix where each coefficient (i,j) is equal to \a thenMatrix(i,j)
* if \c *this(i,j), and \a elseMatrix(i,j) otherwise.
*
* Example: \include MatrixBase_select.cpp
* Output: \verbinclude MatrixBase_select.out
*
* \sa class Select
*/
template<typename Derived>
template<typename ThenDerived,typename ElseDerived>
inline EIGEN_DEVICE_FUNC const Select<Derived,ThenDerived,ElseDerived>
DenseBase<Derived>::select(const DenseBase<ThenDerived>& thenMatrix,
const DenseBase<ElseDerived>& elseMatrix) const
{
return Select<Derived,ThenDerived,ElseDerived>(derived(), thenMatrix.derived(), elseMatrix.derived());
* if \c *this(i,j) != Scalar(0), and \a elseMatrix(i,j) otherwise.
*
* Example: \include MatrixBase_select.cpp
* Output: \verbinclude MatrixBase_select.out
*
* \sa DenseBase::bitwiseSelect(const DenseBase<ThenDerived>&, const DenseBase<ElseDerived>&)
*/
template <typename Derived>
template <typename ThenDerived, typename ElseDerived>
inline EIGEN_DEVICE_FUNC CwiseTernaryOp<
internal::scalar_boolean_select_op<typename DenseBase<Derived>::Scalar, typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar>,
Derived, ThenDerived, ElseDerived>
DenseBase<Derived>::select(const DenseBase<ThenDerived>& thenMatrix, const DenseBase<ElseDerived>& elseMatrix) const {
using Op = internal::scalar_boolean_select_op<Scalar, typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar>;
return CwiseTernaryOp<Op, Derived, ThenDerived, ElseDerived>(derived(), thenMatrix.derived(), elseMatrix.derived(),
Op());
}
/** Version of DenseBase::select(const DenseBase&, const DenseBase&) with
* the \em else expression being a scalar value.
*
* \sa DenseBase::select(const DenseBase<ThenDerived>&, const DenseBase<ElseDerived>&) const, class Select
*/
template<typename Derived>
template<typename ThenDerived>
inline EIGEN_DEVICE_FUNC const Select<Derived,ThenDerived, typename ThenDerived::ConstantReturnType>
* the \em else expression being a scalar value.
*
* \sa DenseBase::booleanSelect(const DenseBase<ThenDerived>&, const DenseBase<ElseDerived>&) const, class Select
*/
template <typename Derived>
template <typename ThenDerived>
inline EIGEN_DEVICE_FUNC CwiseTernaryOp<
internal::scalar_boolean_select_op<typename DenseBase<Derived>::Scalar, typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ThenDerived>::Scalar>,
Derived, ThenDerived, typename DenseBase<ThenDerived>::ConstantReturnType>
DenseBase<Derived>::select(const DenseBase<ThenDerived>& thenMatrix,
const typename ThenDerived::Scalar& elseScalar) const
{
return Select<Derived,ThenDerived,typename ThenDerived::ConstantReturnType>(
derived(), thenMatrix.derived(), ThenDerived::Constant(rows(),cols(),elseScalar));
const typename DenseBase<ThenDerived>::Scalar& elseScalar) const {
using ElseConstantType = typename DenseBase<ThenDerived>::ConstantReturnType;
using Op = internal::scalar_boolean_select_op<Scalar, typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ThenDerived>::Scalar>;
return CwiseTernaryOp<Op, Derived, ThenDerived, ElseConstantType>(derived(), thenMatrix.derived(),
ElseConstantType(rows(), cols(), elseScalar), Op());
}
/** Version of DenseBase::select(const DenseBase&, const DenseBase&) with
* the \em then expression being a scalar value.
*
* \sa DenseBase::select(const DenseBase<ThenDerived>&, const DenseBase<ElseDerived>&) const, class Select
*/
template<typename Derived>
template<typename ElseDerived>
inline EIGEN_DEVICE_FUNC const Select<Derived, typename ElseDerived::ConstantReturnType, ElseDerived >
DenseBase<Derived>::select(const typename ElseDerived::Scalar& thenScalar,
const DenseBase<ElseDerived>& elseMatrix) const
{
return Select<Derived,typename ElseDerived::ConstantReturnType,ElseDerived>(
derived(), ElseDerived::Constant(rows(),cols(),thenScalar), elseMatrix.derived());
* the \em then expression being a scalar value.
*
* \sa DenseBase::booleanSelect(const DenseBase<ThenDerived>&, const DenseBase<ElseDerived>&) const, class Select
*/
template <typename Derived>
template <typename ElseDerived>
inline EIGEN_DEVICE_FUNC CwiseTernaryOp<
internal::scalar_boolean_select_op<typename DenseBase<Derived>::Scalar, typename DenseBase<ElseDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar>,
Derived, typename DenseBase<ElseDerived>::ConstantReturnType, ElseDerived>
DenseBase<Derived>::select(const typename DenseBase<ElseDerived>::Scalar& thenScalar,
const DenseBase<ElseDerived>& elseMatrix) const {
using ThenConstantType = typename DenseBase<ElseDerived>::ConstantReturnType;
using Op = internal::scalar_boolean_select_op<Scalar, typename DenseBase<ElseDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar>;
return CwiseTernaryOp<Op, Derived, ThenConstantType, ElseDerived>(
derived(), ThenConstantType(rows(), cols(), thenScalar), elseMatrix.derived(), Op());
}
} // end namespace Eigen

View File

@ -191,91 +191,122 @@ struct functor_traits<scalar_max_op<LhsScalar,RhsScalar, NaNPropagation> > {
};
/** \internal
* \brief Template functors for comparison of two scalars
* \todo Implement packet-comparisons
*/
template<typename LhsScalar, typename RhsScalar, ComparisonName cmp> struct scalar_cmp_op;
* \brief Template functors for comparison of two scalars
* \todo Implement packet-comparisons
*/
template <typename LhsScalar, typename RhsScalar, ComparisonName cmp,
bool UseTypedComparators = true>
struct scalar_cmp_op;
template<typename LhsScalar, typename RhsScalar, ComparisonName cmp>
struct functor_traits<scalar_cmp_op<LhsScalar,RhsScalar, cmp> > {
template <typename LhsScalar, typename RhsScalar, ComparisonName cmp, bool UseTypedComparators>
struct functor_traits<scalar_cmp_op<LhsScalar, RhsScalar, cmp, UseTypedComparators>> {
enum {
Cost = (NumTraits<LhsScalar>::AddCost+NumTraits<RhsScalar>::AddCost)/2,
PacketAccess = is_same<LhsScalar, RhsScalar>::value &&
packet_traits<LhsScalar>::HasCmp &&
// Since return type is bool, we currently require the inputs
// to be bool to enable packet access.
is_same<LhsScalar, bool>::value
Cost = (NumTraits<LhsScalar>::AddCost + NumTraits<RhsScalar>::AddCost) / 2,
PacketAccess = (UseTypedComparators || is_same<LhsScalar, bool>::value) && is_same<LhsScalar, RhsScalar>::value &&
packet_traits<LhsScalar>::HasCmp
};
};
template<ComparisonName Cmp, typename LhsScalar, typename RhsScalar>
struct result_of<scalar_cmp_op<LhsScalar, RhsScalar, Cmp>(LhsScalar,RhsScalar)> {
typedef bool type;
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
struct typed_cmp_helper {
static constexpr bool SameType = is_same<LhsScalar, RhsScalar>::value;
static constexpr bool IsNumeric = is_arithmetic<typename NumTraits<LhsScalar>::Real>::value;
static constexpr bool UseTyped = UseTypedComparators && SameType && IsNumeric;
using type = typename conditional<UseTyped, LhsScalar, bool>::type;
};
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
using cmp_return_t = typename typed_cmp_helper<LhsScalar, RhsScalar, UseTypedComparators>::type;
template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_EQ> : binary_op_base<LhsScalar,RhsScalar>
{
typedef bool result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a==b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::pcmp_eq(a,b); }
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_EQ, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
using result_type = cmp_return_t<LhsScalar, RhsScalar, UseTypedComparators>;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
return a == b ? result_type(1) : result_type(0);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const {
const Packet cst_one = pset1<Packet>(result_type(1));
return pand(pcmp_eq(a, b), cst_one);
}
};
template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_LT> : binary_op_base<LhsScalar,RhsScalar>
{
typedef bool result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a<b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::pcmp_lt(a,b); }
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_LT, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
using result_type = cmp_return_t<LhsScalar, RhsScalar, UseTypedComparators>;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
return a < b ? result_type(1) : result_type(0);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const {
const Packet cst_one = pset1<Packet>(result_type(1));
return pand(pcmp_lt(a, b), cst_one);
}
};
template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_LE> : binary_op_base<LhsScalar,RhsScalar>
{
typedef bool result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a<=b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::pcmp_le(a,b); }
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_LE, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
using result_type = cmp_return_t<LhsScalar, RhsScalar, UseTypedComparators>;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
return a <= b ? result_type(1) : result_type(0);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const {
const Packet cst_one = pset1<Packet>(result_type(1));
return pand(cst_one, pcmp_le(a, b));
}
};
template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_GT> : binary_op_base<LhsScalar,RhsScalar>
{
typedef bool result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a>b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::pcmp_lt(b,a); }
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_GT, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
using result_type = cmp_return_t<LhsScalar, RhsScalar, UseTypedComparators>;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
return a > b ? result_type(1) : result_type(0);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const {
const Packet cst_one = pset1<Packet>(result_type(1));
return pand(cst_one, pcmp_lt(b, a));
}
};
template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_GE> : binary_op_base<LhsScalar,RhsScalar>
{
typedef bool result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a>=b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::pcmp_le(b,a); }
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_GE, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
using result_type = cmp_return_t<LhsScalar, RhsScalar, UseTypedComparators>;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
return a >= b ? result_type(1) : result_type(0);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const {
const Packet cst_one = pset1<Packet>(result_type(1));
return pand(cst_one, pcmp_le(b, a));
}
};
template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_UNORD> : binary_op_base<LhsScalar,RhsScalar>
{
typedef bool result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return !(a<=b || b<=a);}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::pcmp_eq(internal::por(internal::pcmp_le(a, b), internal::pcmp_le(b, a)), internal::pzero(a)); }
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_UNORD, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
using result_type = cmp_return_t<LhsScalar, RhsScalar, UseTypedComparators>;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
return !(a <= b || b <= a) ? result_type(1) : result_type(0);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const {
const Packet cst_one = pset1<Packet>(result_type(1));
return pandnot(cst_one, por(pcmp_le(a, b), pcmp_le(b, a)));
}
};
template<typename LhsScalar, typename RhsScalar>
struct scalar_cmp_op<LhsScalar,RhsScalar, cmp_NEQ> : binary_op_base<LhsScalar,RhsScalar>
{
typedef bool result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const LhsScalar& a, const RhsScalar& b) const {return a!=b;}
template<typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const
{ return internal::pcmp_eq(internal::pcmp_eq(a, b), internal::pzero(a)); }
template <typename LhsScalar, typename RhsScalar, bool UseTypedComparators>
struct scalar_cmp_op<LhsScalar, RhsScalar, cmp_NEQ, UseTypedComparators> : binary_op_base<LhsScalar, RhsScalar> {
using result_type = cmp_return_t<LhsScalar, RhsScalar, UseTypedComparators>;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(const LhsScalar& a, const RhsScalar& b) const {
return a != b ? result_type(1) : result_type(0);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const {
const Packet cst_one = pset1<Packet>(result_type(1));
return pandnot(cst_one, pcmp_eq(a, b));
}
};
/** \internal
@ -511,6 +542,50 @@ struct functor_traits<scalar_boolean_xor_op<Scalar>> {
enum { Cost = NumTraits<Scalar>::AddCost, PacketAccess = packet_traits<Scalar>::HasCmp };
};
template <typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex>
struct bitwise_binary_impl {
static constexpr size_t Size = sizeof(Scalar);
using uint_t = typename numext::get_integer_by_size<Size>::unsigned_type;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_and(const Scalar& a, const Scalar& b) {
uint_t a_as_uint = numext::bit_cast<uint_t, Scalar>(a);
uint_t b_as_uint = numext::bit_cast<uint_t, Scalar>(b);
uint_t result = a_as_uint & b_as_uint;
return numext::bit_cast<Scalar, uint_t>(result);
}
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_or(const Scalar& a, const Scalar& b) {
uint_t a_as_uint = numext::bit_cast<uint_t, Scalar>(a);
uint_t b_as_uint = numext::bit_cast<uint_t, Scalar>(b);
uint_t result = a_as_uint | b_as_uint;
return numext::bit_cast<Scalar, uint_t>(result);
}
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_xor(const Scalar& a, const Scalar& b) {
uint_t a_as_uint = numext::bit_cast<uint_t, Scalar>(a);
uint_t b_as_uint = numext::bit_cast<uint_t, Scalar>(b);
uint_t result = a_as_uint ^ b_as_uint;
return numext::bit_cast<Scalar, uint_t>(result);
}
};
template <typename Scalar>
struct bitwise_binary_impl<Scalar, true> {
using Real = typename NumTraits<Scalar>::Real;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_and(const Scalar& a, const Scalar& b) {
Real real_result = bitwise_binary_impl<Real>::run_and(numext::real(a), numext::real(b));
Real imag_result = bitwise_binary_impl<Real>::run_and(numext::imag(a), numext::imag(b));
return Scalar(real_result, imag_result);
}
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_or(const Scalar& a, const Scalar& b) {
Real real_result = bitwise_binary_impl<Real>::run_or(numext::real(a), numext::real(b));
Real imag_result = bitwise_binary_impl<Real>::run_or(numext::imag(a), numext::imag(b));
return Scalar(real_result, imag_result);
}
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_xor(const Scalar& a, const Scalar& b) {
Real real_result = bitwise_binary_impl<Real>::run_xor(numext::real(a), numext::real(b));
Real imag_result = bitwise_binary_impl<Real>::run_xor(numext::imag(a), numext::imag(b));
return Scalar(real_result, imag_result);
}
};
/** \internal
* \brief Template functor to compute the bitwise and of two scalars
*
@ -518,15 +593,12 @@ struct functor_traits<scalar_boolean_xor_op<Scalar>> {
*/
template <typename Scalar>
struct scalar_bitwise_and_op {
EIGEN_STATIC_ASSERT(!NumTraits<Scalar>::RequireInitialization, BITWISE OPERATIONS MAY ONLY BE PERFORMED ON PLAIN DATA TYPES )
EIGEN_STATIC_ASSERT(!NumTraits<Scalar>::RequireInitialization,
BITWISE OPERATIONS MAY ONLY BE PERFORMED ON PLAIN DATA TYPES)
EIGEN_STATIC_ASSERT((!internal::is_same<Scalar, bool>::value), DONT USE BITWISE OPS ON BOOLEAN TYPES)
using result_type = Scalar;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& a, const Scalar& b) const {
Scalar result;
const uint8_t* a_bytes = reinterpret_cast<const uint8_t*>(&a);
const uint8_t* b_bytes = reinterpret_cast<const uint8_t*>(&b);
uint8_t* r_bytes = reinterpret_cast<uint8_t*>(&result);
for (Index i = 0; i < sizeof(Scalar); i++) r_bytes[i] = a_bytes[i] & b_bytes[i];
return result;
return bitwise_binary_impl<Scalar>::run_and(a, b);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const {
@ -545,15 +617,12 @@ struct functor_traits<scalar_bitwise_and_op<Scalar>> {
*/
template <typename Scalar>
struct scalar_bitwise_or_op {
EIGEN_STATIC_ASSERT(!NumTraits<Scalar>::RequireInitialization, BITWISE OPERATIONS MAY ONLY BE PERFORMED ON PLAIN DATA TYPES)
EIGEN_STATIC_ASSERT(!NumTraits<Scalar>::RequireInitialization,
BITWISE OPERATIONS MAY ONLY BE PERFORMED ON PLAIN DATA TYPES)
EIGEN_STATIC_ASSERT((!internal::is_same<Scalar, bool>::value), DONT USE BITWISE OPS ON BOOLEAN TYPES)
using result_type = Scalar;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& a, const Scalar& b) const {
Scalar result;
const uint8_t* a_bytes = reinterpret_cast<const uint8_t*>(&a);
const uint8_t* b_bytes = reinterpret_cast<const uint8_t*>(&b);
uint8_t* r_bytes = reinterpret_cast<uint8_t*>(&result);
for (Index i = 0; i < sizeof(Scalar); i++) r_bytes[i] = a_bytes[i] | b_bytes[i];
return result;
return bitwise_binary_impl<Scalar>::run_or(a, b);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const {
@ -572,15 +641,12 @@ struct functor_traits<scalar_bitwise_or_op<Scalar>> {
*/
template <typename Scalar>
struct scalar_bitwise_xor_op {
EIGEN_STATIC_ASSERT(!NumTraits<Scalar>::RequireInitialization, BITWISE OPERATIONS MAY ONLY BE PERFORMED ON PLAIN DATA TYPES)
EIGEN_STATIC_ASSERT(!NumTraits<Scalar>::RequireInitialization,
BITWISE OPERATIONS MAY ONLY BE PERFORMED ON PLAIN DATA TYPES)
EIGEN_STATIC_ASSERT((!internal::is_same<Scalar, bool>::value), DONT USE BITWISE OPS ON BOOLEAN TYPES)
using result_type = Scalar;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& a, const Scalar& b) const {
Scalar result;
const uint8_t* a_bytes = reinterpret_cast<const uint8_t*>(&a);
const uint8_t* b_bytes = reinterpret_cast<const uint8_t*>(&b);
uint8_t* r_bytes = reinterpret_cast<uint8_t*>(&result);
for (Index i = 0; i < sizeof(Scalar); i++) r_bytes[i] = a_bytes[i] ^ b_bytes[i];
return result;
return bitwise_binary_impl<Scalar>::run_xor(a, b);
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b) const {

View File

@ -18,7 +18,30 @@ namespace internal {
//---------- associative ternary functors ----------
template <typename ConditionScalar, typename ThenScalar, typename ElseScalar>
struct scalar_boolean_select_op {
static constexpr bool ThenElseAreSame = is_same<ThenScalar, ElseScalar>::value;
EIGEN_STATIC_ASSERT(ThenElseAreSame, THEN AND ELSE MUST BE SAME TYPE)
using Scalar = ThenScalar;
using result_type = Scalar;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const ConditionScalar& cond, const ThenScalar& a,
const ElseScalar& b) const {
return cond == ConditionScalar(0) ? b : a;
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& cond, const Packet& a, const Packet& b) const {
return pselect(pcmp_eq(cond, pzero(cond)), b, a);
}
};
template <typename ConditionScalar, typename ThenScalar, typename ElseScalar>
struct functor_traits<scalar_boolean_select_op<ConditionScalar, ThenScalar, ElseScalar>> {
using Scalar = ThenScalar;
enum {
Cost = 1,
PacketAccess = is_same<ThenScalar, ElseScalar>::value && is_same<ConditionScalar, Scalar>::value && packet_traits<Scalar>::HasCmp
};
};
} // end namespace internal

View File

@ -178,6 +178,13 @@ struct scalar_cast_op {
typedef NewType result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const NewType operator() (const Scalar& a) const { return cast<Scalar, NewType>(a); }
};
template <typename Scalar>
struct scalar_cast_op<Scalar, bool> {
typedef bool result_type;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const Scalar& a) const { return a != Scalar(0); }
};
template<typename Scalar, typename NewType>
struct functor_traits<scalar_cast_op<Scalar,NewType> >
{ enum { Cost = is_same<Scalar, NewType>::value ? 0 : NumTraits<NewType>::AddCost, PacketAccess = false }; };
@ -942,6 +949,27 @@ struct functor_traits<scalar_boolean_not_op<Scalar>> {
enum { Cost = NumTraits<Scalar>::AddCost, PacketAccess = packet_traits<Scalar>::HasCmp };
};
template <typename Scalar, bool IsComplex = NumTraits<Scalar>::IsComplex>
struct bitwise_unary_impl {
static constexpr size_t Size = sizeof(Scalar);
using uint_t = typename numext::get_integer_by_size<Size>::unsigned_type;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_not(const Scalar& a) {
uint_t a_as_uint = numext::bit_cast<uint_t, Scalar>(a);
uint_t result = ~a_as_uint;
return numext::bit_cast<Scalar, uint_t>(result);
}
};
template <typename Scalar>
struct bitwise_unary_impl<Scalar, true> {
using Real = typename NumTraits<Scalar>::Real;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_not(const Scalar& a) {
Real real_result = bitwise_unary_impl<Real>::run_not(numext::real(a));
Real imag_result = bitwise_unary_impl<Real>::run_not(numext::imag(a));
return Scalar(real_result, imag_result);
}
};
/** \internal
* \brief Template functor to compute the bitwise not of a scalar
*
@ -950,13 +978,10 @@ struct functor_traits<scalar_boolean_not_op<Scalar>> {
template <typename Scalar>
struct scalar_bitwise_not_op {
EIGEN_STATIC_ASSERT(!NumTraits<Scalar>::RequireInitialization, BITWISE OPERATIONS MAY ONLY BE PERFORMED ON PLAIN DATA TYPES)
EIGEN_STATIC_ASSERT((!internal::is_same<Scalar, bool>::value), DONT USE BITWISE OPS ON BOOLEAN TYPES)
using result_type = Scalar;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& a) const {
Scalar result;
const uint8_t* a_bytes = reinterpret_cast<const uint8_t*>(&a);
uint8_t* r_bytes = reinterpret_cast<uint8_t*>(&result);
for (Index i = 0; i < sizeof(Scalar); i++) r_bytes[i] = ~a_bytes[i];
return result;
return bitwise_unary_impl<Scalar>::run_not(a);
}
template <typename Packet>
EIGEN_STRONG_INLINE Packet packetOp(const Packet& a) const {

View File

@ -39,10 +39,10 @@ cwiseProduct(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
*/
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>
inline const CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ>, const Derived, const OtherDerived>
cwiseEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
{
return CwiseBinaryOp<std::equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
return CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ>, const Derived, const OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise != operator of *this and \a other
@ -59,10 +59,46 @@ cwiseEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
*/
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>
inline const CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ>, const Derived, const OtherDerived>
cwiseNotEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
{
return CwiseBinaryOp<std::not_equal_to<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
return CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ>, const Derived, const OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise < operator of *this and \a other */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT>, const Derived, const OtherDerived>
cwiseLesser(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived>& other) const
{
return CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT>, const Derived, const OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise > operator of *this and \a other */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT>, const Derived, const OtherDerived>
cwiseGreater(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived>& other) const
{
return CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT>, const Derived, const OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise <= operator of *this and \a other */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE>, const Derived, const OtherDerived>
cwiseLesserOrEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived>& other) const
{
return CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE>, const Derived, const OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise >= operator of *this and \a other */
template<typename OtherDerived>
EIGEN_DEVICE_FUNC
inline const CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE>, const Derived, const OtherDerived>
cwiseGreaterOrEqual(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived>& other) const
{
return CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE>, const Derived, const OtherDerived>(derived(), other.derived());
}
/** \returns an expression of the coefficient-wise min of *this and \a other
@ -135,7 +171,12 @@ cwiseQuotient(const EIGEN_CURRENT_STORAGE_BASE_CLASS<OtherDerived> &other) const
return CwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const Derived, const OtherDerived>(derived(), other.derived());
}
typedef CwiseBinaryOp<internal::scalar_cmp_op<Scalar,Scalar,internal::cmp_EQ>, const Derived, const ConstantReturnType> CwiseScalarEqualReturnType;
using CwiseScalarEqualReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar,Scalar,internal::cmp_EQ>, const Derived, const ConstantReturnType>;
using CwiseScalarNotEqualReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ>, const Derived, const ConstantReturnType>;
using CwiseScalarLesserReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT>, const Derived, const ConstantReturnType>;
using CwiseScalarGreaterReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT>, const Derived, const ConstantReturnType>;
using CwiseScalarLesserOrEqualReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE>, const Derived, const ConstantReturnType>;
using CwiseScalarGreaterOrEqualReturnType = CwiseBinaryOp<internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE>, const Derived, const ConstantReturnType>;
/** \returns an expression of the coefficient-wise == operator of \c *this and a scalar \a s
*
@ -152,3 +193,54 @@ cwiseEqual(const Scalar& s) const
{
return CwiseScalarEqualReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op<Scalar,Scalar,internal::cmp_EQ>());
}
/** \returns an expression of the coefficient-wise == operator of \c *this and a scalar \a s
*
* \warning this performs an exact comparison, which is generally a bad idea with floating-point types.
* In order to check for equality between two vectors or matrices with floating-point coefficients, it is
* generally a far better idea to use a fuzzy comparison as provided by isApprox() and
* isMuchSmallerThan().
*
* \sa cwiseEqual(const MatrixBase<OtherDerived> &) const
*/
EIGEN_DEVICE_FUNC
inline const CwiseScalarNotEqualReturnType
cwiseNotEqual(const Scalar& s) const
{
return CwiseScalarNotEqualReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ>());
}
/** \returns an expression of the coefficient-wise < operator of \c *this and a scalar \a s */
EIGEN_DEVICE_FUNC
inline const CwiseScalarLesserReturnType
cwiseLesser(const Scalar& s) const
{
return CwiseScalarLesserReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT>());
}
/** \returns an expression of the coefficient-wise > operator of \c *this and a scalar \a s */
EIGEN_DEVICE_FUNC
inline const CwiseScalarGreaterReturnType
cwiseGreater(const Scalar& s) const
{
return CwiseScalarGreaterReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT>());
}
/** \returns an expression of the coefficient-wise <= operator of \c *this and a scalar \a s */
EIGEN_DEVICE_FUNC
inline const CwiseScalarLesserOrEqualReturnType
cwiseLesserOrEqual(const Scalar& s) const
{
return CwiseScalarLesserOrEqualReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE>());
}
/** \returns an expression of the coefficient-wise >= operator of \c *this and a scalar \a s */
EIGEN_DEVICE_FUNC
inline const CwiseScalarGreaterOrEqualReturnType
cwiseGreaterOrEqual(const Scalar& s) const
{
return CwiseScalarGreaterOrEqualReturnType(derived(), Derived::Constant(rows(), cols(), s), internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE>());
}

View File

@ -590,6 +590,21 @@ template<typename ArrayType> void comparisons(const ArrayType& m)
typedef typename ArrayType::Scalar Scalar;
typedef typename NumTraits<Scalar>::Real RealScalar;
// explicitly test both typed and boolean comparison ops
using typed_eq = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ, true>;
using typed_ne = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ, true>;
using typed_lt = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT, true>;
using typed_le = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE, true>;
using typed_gt = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT, true>;
using typed_ge = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE, true>;
using bool_eq = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_EQ, false>;
using bool_ne = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_NEQ, false>;
using bool_lt = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LT, false>;
using bool_le = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_LE, false>;
using bool_gt = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GT, false>;
using bool_ge = internal::scalar_cmp_op<Scalar, Scalar, internal::cmp_GE, false>;
Index rows = m.rows();
Index cols = m.cols();
@ -603,6 +618,8 @@ template<typename ArrayType> void comparisons(const ArrayType& m)
m4 = (m4.abs()==Scalar(0)).select(1,m4);
// use operator overloads with default return type
VERIFY(((m1 + Scalar(1)) > m1).all());
VERIFY(((m1 - Scalar(1)) < m1).all());
if (rows*cols>1)
@ -627,6 +644,34 @@ template<typename ArrayType> void comparisons(const ArrayType& m)
VERIFY( ( (m1(r,c)+1) > m1).any() );
VERIFY( ( m1(r,c) == m1).any() );
// currently, any() / all() are not vectorized, so use VERIFY_IS_CWISE_EQUAL to test vectorized path
// use typed comparisons, regardless of operator overload behavior
typename ArrayType::ConstantReturnType typed_true = ArrayType::Constant(rows, cols, Scalar(1));
// (m1 + Scalar(1)) > m1).all()
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1, typed_gt()), typed_true);
// (m1 - Scalar(1)) < m1).all()
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, typed_lt()), typed_true);
// (m1 + Scalar(1)) == (m1 + Scalar(1))).all()
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1 + Scalar(1), typed_eq()), typed_true);
// (m1 - Scalar(1)) != m1).all()
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, typed_ne()), typed_true);
// (m1 <= m2 || m1 >= m2).all()
VERIFY_IS_CWISE_EQUAL(m1.binaryExpr(m2, typed_le()) || m1.binaryExpr(m2, typed_ge()), typed_true);
// use boolean comparisons, regardless of operator overload behavior
ArrayXX<bool>::ConstantReturnType bool_true = ArrayXX<bool>::Constant(rows, cols, true);
// (m1 + Scalar(1)) > m1).all()
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1, bool_gt()), bool_true);
// (m1 - Scalar(1)) < m1).all()
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, bool_lt()), bool_true);
// (m1 + Scalar(1)) == (m1 + Scalar(1))).all()
VERIFY_IS_CWISE_EQUAL((m1 + Scalar(1)).binaryExpr(m1 + Scalar(1), bool_eq()), bool_true);
// (m1 - Scalar(1)) != m1).all()
VERIFY_IS_CWISE_EQUAL((m1 - Scalar(1)).binaryExpr(m1, bool_ne()), bool_true);
// (m1 <= m2 || m1 >= m2).all()
VERIFY_IS_CWISE_EQUAL(m1.binaryExpr(m2, bool_le()) || m1.binaryExpr(m2, bool_ge()), bool_true);
// test Select
VERIFY_IS_APPROX( (m1<m2).select(m1,m2), m1.cwiseMin(m2) );
VERIFY_IS_APPROX( (m1>m2).select(m1,m2), m1.cwiseMax(m2) );
@ -642,7 +687,7 @@ template<typename ArrayType> void comparisons(const ArrayType& m)
VERIFY_IS_APPROX( (m1.abs()>=ArrayType::Constant(rows,cols,mid))
.select(m1,0), m3);
// even shorter version:
VERIFY_IS_APPROX( (m1.abs()<mid).select(0,m1), m3);
VERIFY_IS_APPROX( (m1.abs()<mid).select(0,m1), m3);
// count
VERIFY(((m1.abs()+1)>RealScalar(0.1)).count() == rows*cols);
@ -1039,7 +1084,7 @@ struct typed_logicals_test_impl {
using Scalar = typename ArrayType::Scalar;
static bool scalar_to_bool(const Scalar& x) { return x != Scalar(0); }
static Scalar bool_to_scalar(const bool& x) { return x ? Scalar(1) : Scalar(0); }
static Scalar bool_to_scalar(bool x) { return x ? Scalar(1) : Scalar(0); }
static Scalar eval_bool_and(const Scalar& x, const Scalar& y) { return bool_to_scalar(scalar_to_bool(x) && scalar_to_bool(y)); }
static Scalar eval_bool_or(const Scalar& x, const Scalar& y) { return bool_to_scalar(scalar_to_bool(x) || scalar_to_bool(y)); }
@ -1091,40 +1136,45 @@ struct typed_logicals_test_impl {
m4 = (!m1).binaryExpr((!m2), internal::scalar_boolean_xor_op<Scalar>());
VERIFY_IS_CWISE_EQUAL(m3, m4);
const Index bytes = rows * cols * sizeof(Scalar);
const uint8_t* m1_data = reinterpret_cast<const uint8_t*>(m1.data());
const uint8_t* m2_data = reinterpret_cast<const uint8_t*>(m2.data());
uint8_t* m3_data = reinterpret_cast<uint8_t*>(m3.data());
uint8_t* m4_data = reinterpret_cast<uint8_t*>(m4.data());
const size_t bytes = size_t(rows) * size_t(cols) * sizeof(Scalar);
std::vector<uint8_t> m1_buffer(bytes), m2_buffer(bytes), m3_buffer(bytes), m4_buffer(bytes);
std::memcpy(m1_buffer.data(), m1.data(), bytes);
std::memcpy(m2_buffer.data(), m2.data(), bytes);
// test bitwise and
m3 = m1 & m2;
for (Index i = 0; i < bytes; i++) m4_data[i] = m1_data[i] & m2_data[i];
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
std::memcpy(m3_buffer.data(), m3.data(), bytes);
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], uint8_t(m1_buffer[i] & m2_buffer[i]));
// test bitwise or
m3 = m1 | m2;
for (Index i = 0; i < bytes; i++) m4_data[i] = m1_data[i] | m2_data[i];
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
std::memcpy(m3_buffer.data(), m3.data(), bytes);
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], uint8_t(m1_buffer[i] | m2_buffer[i]));
// test bitwise xor
m3 = m1 ^ m2;
for (Index i = 0; i < bytes; i++) m4_data[i] = m1_data[i] ^ m2_data[i];
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
std::memcpy(m3_buffer.data(), m3.data(), bytes);
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], uint8_t(m1_buffer[i] ^ m2_buffer[i]));
// test bitwise not
m3 = ~m1;
for (Index i = 0; i < bytes; i++) m4_data[i] = ~m1_data[i];
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
std::memcpy(m3_buffer.data(), m3.data(), bytes);
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], uint8_t(~m1_buffer[i]));
// test something more complicated
m3 = m1 & m2;
m4 = ~(~m1 | ~m2);
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
std::memcpy(m3_buffer.data(), m3.data(), bytes);
std::memcpy(m4_buffer.data(), m4.data(), bytes);
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], m4_buffer[i]);
m3 = m1 ^ m2;
m4 = (~m1) ^ (~m2);
for (Index i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_data[i], m4_data[i]);
std::memcpy(m3_buffer.data(), m3.data(), bytes);
std::memcpy(m4_buffer.data(), m4.data(), bytes);
for (size_t i = 0; i < bytes; i++) VERIFY_IS_EQUAL(m3_buffer[i], m4_buffer[i]);
}
};
template <typename ArrayType>
@ -1181,7 +1231,6 @@ EIGEN_DECLARE_TEST(array_cwise)
CALL_SUBTEST_8( signbit_tests() );
}
for (int i = 0; i < g_repeat; i++) {
CALL_SUBTEST_1( typed_logicals_test(ArrayX<bool>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_2( typed_logicals_test(ArrayX<int>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_2( typed_logicals_test(ArrayX<float>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))) );
CALL_SUBTEST_3( typed_logicals_test(ArrayX<double>(internal::random<int>(1, EIGEN_TEST_MAX_SIZE))));

View File

@ -496,7 +496,7 @@ typename NumTraits<typename T1::RealScalar>::NonInteger test_relative_error(cons
typedef typename NumTraits<typename T1::RealScalar>::NonInteger RealScalar;
typename internal::nested_eval<T1,2>::type ea(a.derived());
typename internal::nested_eval<T2,2>::type eb(b.derived());
return sqrt(RealScalar((ea-eb).cwiseAbs2().sum()) / RealScalar((std::min)(eb.cwiseAbs2().sum(),ea.cwiseAbs2().sum())));
return sqrt(RealScalar((ea.matrix()-eb.matrix()).cwiseAbs2().sum()) / RealScalar((std::min)(eb.cwiseAbs2().sum(),ea.cwiseAbs2().sum())));
}
template<typename T1,typename T2>

View File

@ -14,14 +14,16 @@
using Eigen::Tensor;
using Eigen::RowMajor;
using Scalar = float;
static void test_orderings()
{
Tensor<float, 3> mat1(2,3,7);
Tensor<float, 3> mat2(2,3,7);
Tensor<bool, 3> lt(2,3,7);
Tensor<bool, 3> le(2,3,7);
Tensor<bool, 3> gt(2,3,7);
Tensor<bool, 3> ge(2,3,7);
Tensor<Scalar, 3> mat1(2,3,7);
Tensor<Scalar, 3> mat2(2,3,7);
Tensor<Scalar, 3> lt(2,3,7);
Tensor<Scalar, 3> le(2,3,7);
Tensor<Scalar, 3> gt(2,3,7);
Tensor<Scalar, 3> ge(2,3,7);
mat1.setRandom();
mat2.setRandom();
@ -46,8 +48,8 @@ static void test_orderings()
static void test_equality()
{
Tensor<float, 3> mat1(2,3,7);
Tensor<float, 3> mat2(2,3,7);
Tensor<Scalar, 3> mat1(2,3,7);
Tensor<Scalar, 3> mat2(2,3,7);
mat1.setRandom();
mat2.setRandom();
@ -61,8 +63,8 @@ static void test_equality()
}
}
Tensor<bool, 3> eq(2,3,7);
Tensor<bool, 3> ne(2,3,7);
Tensor<Scalar, 3> eq(2,3,7);
Tensor<Scalar, 3> ne(2,3,7);
eq = (mat1 == mat2);
ne = (mat1 != mat2);

View File

@ -200,14 +200,14 @@ static void test_boolean()
std::iota(vec.data(), vec.data() + kSize, 0);
// Test ||.
Tensor<bool, 1> bool1 = vec < vec.constant(1) || vec > vec.constant(4);
Tensor<bool, 1> bool1 = (vec < vec.constant(1) || vec > vec.constant(4)).cast<bool>();
for (int i = 0; i < kSize; ++i) {
bool expected = i < 1 || i > 4;
VERIFY_IS_EQUAL(bool1[i], expected);
}
// Test &&, including cast of operand vec.
Tensor<bool, 1> bool2 = vec.cast<bool>() && vec < vec.constant(4);
Tensor<bool, 1> bool2 = vec.cast<bool>() && (vec < vec.constant(4)).cast<bool>();
for (int i = 0; i < kSize; ++i) {
bool expected = bool(i) && i < 4;
VERIFY_IS_EQUAL(bool2[i], expected);
@ -218,7 +218,7 @@ static void test_boolean()
// CoeffReturnType is set to match Op return type of bool for Unary and Binary
// Ops.
Tensor<bool, 1> bool3 = vec.cast<bool>() && bool2;
bool3 = vec < vec.constant(4) && bool2;
bool3 = (vec < vec.constant(4)).cast<bool>() && bool2;
}
static void test_functors()