Make new Select implementation backwards compatible.

This commit is contained in:
Rasmus Munk Larsen 2023-03-10 23:07:47 +00:00
parent 394aabb0a3
commit 6bb9609bcb
3 changed files with 56 additions and 43 deletions

View File

@ -569,24 +569,24 @@ template<typename Derived> class DenseBase
template <typename ThenDerived, typename ElseDerived>
inline EIGEN_DEVICE_FUNC
CwiseTernaryOp<internal::scalar_boolean_select_op<Scalar, typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar>,
Derived, ThenDerived, ElseDerived>
CwiseTernaryOp<internal::scalar_boolean_select_op<typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar, Scalar>,
ThenDerived, ElseDerived, Derived>
select(const DenseBase<ThenDerived>& thenMatrix, const DenseBase<ElseDerived>& elseMatrix) const;
template <typename ThenDerived>
inline EIGEN_DEVICE_FUNC
CwiseTernaryOp<internal::scalar_boolean_select_op<Scalar, typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ThenDerived>::Scalar>,
Derived, ThenDerived, typename DenseBase<ThenDerived>::ConstantReturnType>
CwiseTernaryOp<internal::scalar_boolean_select_op<typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ThenDerived>::Scalar, Scalar>,
ThenDerived, typename DenseBase<ThenDerived>::ConstantReturnType, Derived>
select(const DenseBase<ThenDerived>& thenMatrix,
const typename DenseBase<ThenDerived>::Scalar& elseScalar) const;
template <typename ElseDerived>
inline EIGEN_DEVICE_FUNC
CwiseTernaryOp<internal::scalar_boolean_select_op<Scalar, typename DenseBase<ElseDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar>,
Derived, typename DenseBase<ElseDerived>::ConstantReturnType, ElseDerived>
CwiseTernaryOp<internal::scalar_boolean_select_op<typename DenseBase<ElseDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar, Scalar>,
typename DenseBase<ElseDerived>::ConstantReturnType, ElseDerived, Derived>
select(const typename DenseBase<ElseDerived>::Scalar& thenScalar,
const DenseBase<ElseDerived>& elseMatrix) const;

View File

@ -124,14 +124,17 @@ class Select : public internal::dense_xpr_base< Select<ConditionMatrixType, Then
template <typename Derived>
template <typename ThenDerived, typename ElseDerived>
inline EIGEN_DEVICE_FUNC CwiseTernaryOp<
internal::scalar_boolean_select_op<typename DenseBase<Derived>::Scalar, typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar>,
Derived, ThenDerived, ElseDerived>
DenseBase<Derived>::select(const DenseBase<ThenDerived>& thenMatrix, const DenseBase<ElseDerived>& elseMatrix) const {
using Op = internal::scalar_boolean_select_op<Scalar, typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar>;
return CwiseTernaryOp<Op, Derived, ThenDerived, ElseDerived>(derived(), thenMatrix.derived(), elseMatrix.derived(),
Op());
internal::scalar_boolean_select_op<typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar,
typename DenseBase<Derived>::Scalar>,
ThenDerived, ElseDerived, Derived>
DenseBase<Derived>::select(const DenseBase<ThenDerived>& thenMatrix,
const DenseBase<ElseDerived>& elseMatrix) const {
using Op = internal::scalar_boolean_select_op<
typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar, Scalar>;
return CwiseTernaryOp<Op, ThenDerived, ElseDerived, Derived>(
thenMatrix.derived(), elseMatrix.derived(), derived(), Op());
}
/** Version of DenseBase::select(const DenseBase&, const DenseBase&) with
* the \em else expression being a scalar value.
@ -141,16 +144,21 @@ DenseBase<Derived>::select(const DenseBase<ThenDerived>& thenMatrix, const Dense
template <typename Derived>
template <typename ThenDerived>
inline EIGEN_DEVICE_FUNC CwiseTernaryOp<
internal::scalar_boolean_select_op<typename DenseBase<Derived>::Scalar, typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ThenDerived>::Scalar>,
Derived, ThenDerived, typename DenseBase<ThenDerived>::ConstantReturnType>
DenseBase<Derived>::select(const DenseBase<ThenDerived>& thenMatrix,
const typename DenseBase<ThenDerived>::Scalar& elseScalar) const {
using ElseConstantType = typename DenseBase<ThenDerived>::ConstantReturnType;
using Op = internal::scalar_boolean_select_op<Scalar, typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ThenDerived>::Scalar>;
return CwiseTernaryOp<Op, Derived, ThenDerived, ElseConstantType>(derived(), thenMatrix.derived(),
ElseConstantType(rows(), cols(), elseScalar), Op());
internal::scalar_boolean_select_op<typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<Derived>::Scalar>,
ThenDerived, typename DenseBase<ThenDerived>::ConstantReturnType, Derived>
DenseBase<Derived>::select(
const DenseBase<ThenDerived>& thenMatrix,
const typename DenseBase<ThenDerived>::Scalar& elseScalar) const {
using ElseConstantType =
typename DenseBase<ThenDerived>::ConstantReturnType;
using Op = internal::scalar_boolean_select_op<
typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ThenDerived>::Scalar, Scalar>;
return CwiseTernaryOp<Op, ThenDerived, ElseConstantType, Derived>(
thenMatrix.derived(), ElseConstantType(rows(), cols(), elseScalar),
derived(), Op());
}
/** Version of DenseBase::select(const DenseBase&, const DenseBase&) with
* the \em then expression being a scalar value.
@ -160,16 +168,22 @@ DenseBase<Derived>::select(const DenseBase<ThenDerived>& thenMatrix,
template <typename Derived>
template <typename ElseDerived>
inline EIGEN_DEVICE_FUNC CwiseTernaryOp<
internal::scalar_boolean_select_op<typename DenseBase<Derived>::Scalar, typename DenseBase<ElseDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar>,
Derived, typename DenseBase<ElseDerived>::ConstantReturnType, ElseDerived>
DenseBase<Derived>::select(const typename DenseBase<ElseDerived>::Scalar& thenScalar,
const DenseBase<ElseDerived>& elseMatrix) const {
using ThenConstantType = typename DenseBase<ElseDerived>::ConstantReturnType;
using Op = internal::scalar_boolean_select_op<Scalar, typename DenseBase<ElseDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar>;
return CwiseTernaryOp<Op, Derived, ThenConstantType, ElseDerived>(
derived(), ThenConstantType(rows(), cols(), thenScalar), elseMatrix.derived(), Op());
internal::scalar_boolean_select_op<typename DenseBase<ElseDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar,
typename DenseBase<Derived>::Scalar>,
typename DenseBase<ElseDerived>::ConstantReturnType, ElseDerived,
Derived>
DenseBase<Derived>::select(
const typename DenseBase<ElseDerived>::Scalar& thenScalar,
const DenseBase<ElseDerived>& elseMatrix) const {
using ThenConstantType =
typename DenseBase<ElseDerived>::ConstantReturnType;
using Op = internal::scalar_boolean_select_op<
typename DenseBase<ElseDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar, Scalar>;
return CwiseTernaryOp<Op, ThenConstantType, ElseDerived, Derived>(
ThenConstantType(rows(), cols(), thenScalar), elseMatrix.derived(),
derived(), Op());
}
} // end namespace Eigen

View File

@ -18,24 +18,23 @@ namespace internal {
//---------- associative ternary functors ----------
template <typename ConditionScalar, typename ThenScalar, typename ElseScalar>
template <typename ThenScalar, typename ElseScalar, typename ConditionScalar>
struct scalar_boolean_select_op {
static constexpr bool ThenElseAreSame = is_same<ThenScalar, ElseScalar>::value;
EIGEN_STATIC_ASSERT(ThenElseAreSame, THEN AND ELSE MUST BE SAME TYPE)
using Scalar = ThenScalar;
using result_type = Scalar;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const ConditionScalar& cond, const ThenScalar& a,
const ElseScalar& b) const {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const ThenScalar& a, const ElseScalar& b, const ConditionScalar& cond) const {
return cond == ConditionScalar(0) ? b : a;
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& cond, const Packet& a, const Packet& b) const {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, const Packet& b, const Packet& cond) const {
return pselect(pcmp_eq(cond, pzero(cond)), b, a);
}
};
template <typename ConditionScalar, typename ThenScalar, typename ElseScalar>
struct functor_traits<scalar_boolean_select_op<ConditionScalar, ThenScalar, ElseScalar>> {
template <typename ThenScalar, typename ElseScalar, typename ConditionScalar>
struct functor_traits<scalar_boolean_select_op<ThenScalar, ElseScalar, ConditionScalar>> {
using Scalar = ThenScalar;
enum {
Cost = 1,