Refactoring: move all symbolic stuff into its own namespace

This commit is contained in:
Gael Guennebaud 2017-01-10 10:57:08 +01:00
parent acd08900c9
commit 9eaab4f9e0

View File

@ -19,93 +19,103 @@ namespace Eigen {
struct all_t { all_t() {} }; struct all_t { all_t() {} };
static const all_t all; static const all_t all;
//--------------------------------------------------------------------------------
// minimalistic symbolic scalar type
//--------------------------------------------------------------------------------
namespace Symbolic {
template<typename Tag> class Symbol;
template<typename Arg0> class NegateExpr;
template<typename Arg1,typename Arg2> class AddExpr;
template<typename Arg1,typename Arg2> class ProductExpr;
template<typename Arg1,typename Arg2> class QuotientExpr;
// A simple wrapper around an Index to provide the eval method. // A simple wrapper around an Index to provide the eval method.
// We could also use a free-function symbolic_eval... // We could also use a free-function symbolic_eval...
class symbolic_value_wrapper { class ValueExpr {
public: public:
symbolic_value_wrapper(Index val) : m_value(val) {} ValueExpr(Index val) : m_value(val) {}
template<typename T> template<typename T>
Index eval(const T&) const { return m_value; } Index eval(const T&) const { return m_value; }
protected: protected:
Index m_value; Index m_value;
}; };
//--------------------------------------------------------------------------------
// minimalistic symbolic scalar type
//--------------------------------------------------------------------------------
template<typename Tag> class symbolic_symbol;
template<typename Arg0> class symbolic_negate;
template<typename Arg1,typename Arg2> class symbolic_add;
template<typename Arg1,typename Arg2> class symbolic_product;
template<typename Arg1,typename Arg2> class symbolic_quotient;
template<typename Derived> template<typename Derived>
class symbolic_index_base class BaseExpr
{ {
public: public:
const Derived& derived() const { return *static_cast<const Derived*>(this); } const Derived& derived() const { return *static_cast<const Derived*>(this); }
symbolic_negate<Derived> operator-() const { return symbolic_negate<Derived>(derived()); } NegateExpr<Derived> operator-() const { return NegateExpr<Derived>(derived()); }
symbolic_add<Derived,symbolic_value_wrapper> operator+(Index b) const AddExpr<Derived,ValueExpr> operator+(Index b) const
{ return symbolic_add<Derived,symbolic_value_wrapper >(derived(), b); } { return AddExpr<Derived,ValueExpr >(derived(), b); }
symbolic_add<Derived,symbolic_value_wrapper> operator-(Index a) const AddExpr<Derived,ValueExpr> operator-(Index a) const
{ return symbolic_add<Derived,symbolic_value_wrapper >(derived(), -a); } { return AddExpr<Derived,ValueExpr >(derived(), -a); }
symbolic_quotient<Derived,symbolic_value_wrapper> operator/(Index a) const QuotientExpr<Derived,ValueExpr> operator/(Index a) const
{ return symbolic_quotient<Derived,symbolic_value_wrapper>(derived(),a); } { return QuotientExpr<Derived,ValueExpr>(derived(),a); }
friend symbolic_add<Derived,symbolic_value_wrapper> operator+(Index a, const symbolic_index_base& b) friend AddExpr<Derived,ValueExpr> operator+(Index a, const BaseExpr& b)
{ return symbolic_add<Derived,symbolic_value_wrapper>(b.derived(), a); } { return AddExpr<Derived,ValueExpr>(b.derived(), a); }
friend symbolic_add<symbolic_negate<Derived>,symbolic_value_wrapper> operator-(Index a, const symbolic_index_base& b) friend AddExpr<NegateExpr<Derived>,ValueExpr> operator-(Index a, const BaseExpr& b)
{ return symbolic_add<symbolic_negate<Derived>,symbolic_value_wrapper>(-b.derived(), a); } { return AddExpr<NegateExpr<Derived>,ValueExpr>(-b.derived(), a); }
friend symbolic_add<symbolic_value_wrapper,Derived> operator/(Index a, const symbolic_index_base& b) friend AddExpr<ValueExpr,Derived> operator/(Index a, const BaseExpr& b)
{ return symbolic_add<symbolic_value_wrapper,Derived>(a,b.derived()); } { return AddExpr<ValueExpr,Derived>(a,b.derived()); }
template<typename OtherDerived> template<typename OtherDerived>
symbolic_add<Derived,OtherDerived> operator+(const symbolic_index_base<OtherDerived> &b) const AddExpr<Derived,OtherDerived> operator+(const BaseExpr<OtherDerived> &b) const
{ return symbolic_add<Derived,OtherDerived>(derived(), b.derived()); } { return AddExpr<Derived,OtherDerived>(derived(), b.derived()); }
template<typename OtherDerived> template<typename OtherDerived>
symbolic_add<Derived,symbolic_negate<OtherDerived> > operator-(const symbolic_index_base<OtherDerived> &b) const AddExpr<Derived,NegateExpr<OtherDerived> > operator-(const BaseExpr<OtherDerived> &b) const
{ return symbolic_add<Derived,symbolic_negate<OtherDerived> >(derived(), -b.derived()); } { return AddExpr<Derived,NegateExpr<OtherDerived> >(derived(), -b.derived()); }
template<typename OtherDerived> template<typename OtherDerived>
symbolic_add<Derived,OtherDerived> operator/(const symbolic_index_base<OtherDerived> &b) const AddExpr<Derived,OtherDerived> operator/(const BaseExpr<OtherDerived> &b) const
{ return symbolic_quotient<Derived,OtherDerived>(derived(), b.derived()); } { return QuotientExpr<Derived,OtherDerived>(derived(), b.derived()); }
}; };
template<typename T> template<typename T>
struct is_symbolic { struct is_symbolic {
enum { value = internal::is_convertible<T,symbolic_index_base<T> >::value }; // BaseExpr has no conversion ctor, so we only to check whether T can be staticaly cast to its base class BaseExpr<T>.
enum { value = internal::is_convertible<T,BaseExpr<T> >::value };
}; };
template<typename Tag> template<typename Tag>
class symbolic_value_pair class SymbolValue
{ {
public: public:
symbolic_value_pair(Index val) : m_value(val) {} SymbolValue(Index val) : m_value(val) {}
Index value() const { return m_value; } Index value() const { return m_value; }
protected: protected:
Index m_value; Index m_value;
}; };
template<typename Tag> template<typename TagT>
class symbolic_value : public symbolic_index_base<symbolic_value<Tag> > class SymbolExpr : public BaseExpr<SymbolExpr<TagT> >
{ {
public: public:
symbolic_value() {} typedef TagT Tag;
SymbolExpr() {}
Index eval(const symbolic_value_pair<Tag> &values) const { return values.value(); } Index eval(const SymbolValue<Tag> &values) const { return values.value(); }
// TODO add a c++14 eval taking a tuple of symbolic_value_pair and getting the value with std::get<symbolic_value_pair<Tag> >...
// TODO add a c++14 eval taking a tuple of SymbolValue and getting the value with std::get<SymbolValue<Tag> >...
}; };
template<typename Tag>
SymbolValue<Tag> defineValue(SymbolExpr<Tag>,Index val) {
return SymbolValue<Tag>(val);
}
template<typename Arg0> template<typename Arg0>
class symbolic_negate : public symbolic_index_base<symbolic_negate<Arg0> > class NegateExpr : public BaseExpr<NegateExpr<Arg0> >
{ {
public: public:
symbolic_negate(const Arg0& arg0) : m_arg0(arg0) {} NegateExpr(const Arg0& arg0) : m_arg0(arg0) {}
template<typename T> template<typename T>
Index eval(const T& values) const { return -m_arg0.eval(values); } Index eval(const T& values) const { return -m_arg0.eval(values); }
@ -114,10 +124,10 @@ protected:
}; };
template<typename Arg0, typename Arg1> template<typename Arg0, typename Arg1>
class symbolic_add : public symbolic_index_base<symbolic_add<Arg0,Arg1> > class AddExpr : public BaseExpr<AddExpr<Arg0,Arg1> >
{ {
public: public:
symbolic_add(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} AddExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
template<typename T> template<typename T>
Index eval(const T& values) const { return m_arg0.eval(values) + m_arg1.eval(values); } Index eval(const T& values) const { return m_arg0.eval(values) + m_arg1.eval(values); }
@ -127,10 +137,10 @@ protected:
}; };
template<typename Arg0, typename Arg1> template<typename Arg0, typename Arg1>
class symbolic_product : public symbolic_index_base<symbolic_product<Arg0,Arg1> > class ProductExpr : public BaseExpr<ProductExpr<Arg0,Arg1> >
{ {
public: public:
symbolic_product(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} ProductExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
template<typename T> template<typename T>
Index eval(const T& values) const { return m_arg0.eval(values) * m_arg1.eval(values); } Index eval(const T& values) const { return m_arg0.eval(values) * m_arg1.eval(values); }
@ -140,10 +150,10 @@ protected:
}; };
template<typename Arg0, typename Arg1> template<typename Arg0, typename Arg1>
class symbolic_quotient : public symbolic_index_base<symbolic_quotient<Arg0,Arg1> > class QuotientExpr : public BaseExpr<QuotientExpr<Arg0,Arg1> >
{ {
public: public:
symbolic_quotient(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {} QuotientExpr(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
template<typename T> template<typename T>
Index eval(const T& values) const { return m_arg0.eval(values) / m_arg1.eval(values); } Index eval(const T& values) const { return m_arg0.eval(values) / m_arg1.eval(values); }
@ -152,12 +162,16 @@ protected:
Arg1 m_arg1; Arg1 m_arg1;
}; };
struct symb_last_tag {}; } // end namespace Symbolic
namespace placeholders { namespace placeholders {
static const symbolic_value<symb_last_tag> last; namespace internal {
static const symbolic_add<symbolic_value<symb_last_tag>,symbolic_value_wrapper> end(last+1); struct symbolic_last_tag {};
}
static const Symbolic::SymbolExpr<internal::symbolic_last_tag> last;
static const Symbolic::AddExpr<Symbolic::SymbolExpr<internal::symbolic_last_tag>,Symbolic::ValueExpr> end(last+1);
} // end namespace placeholders } // end namespace placeholders
@ -256,7 +270,7 @@ auto seq(FirstType f, LastType l, IncrType incr)
} }
#else #else
template<typename FirstType,typename LastType> template<typename FirstType,typename LastType>
typename internal::enable_if<!(is_symbolic<FirstType>::value || is_symbolic<LastType>::value), typename internal::enable_if<!(Symbolic::is_symbolic<FirstType>::value || Symbolic::is_symbolic<LastType>::value),
ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,Index> >::type ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,Index> >::type
seq(FirstType f, LastType l) seq(FirstType f, LastType l)
{ {
@ -264,31 +278,34 @@ seq(FirstType f, LastType l)
} }
template<typename FirstTypeDerived,typename LastType> template<typename FirstTypeDerived,typename LastType>
typename internal::enable_if<!is_symbolic<LastType>::value, typename internal::enable_if<!Symbolic::is_symbolic<LastType>::value,
ArithemeticSequence<FirstTypeDerived,symbolic_add<symbolic_add<symbolic_negate<FirstTypeDerived>,symbolic_value_wrapper>,symbolic_value_wrapper> > >::type ArithemeticSequence<FirstTypeDerived, Symbolic::AddExpr<Symbolic::AddExpr<Symbolic::NegateExpr<FirstTypeDerived>,Symbolic::ValueExpr>,
seq(const symbolic_index_base<FirstTypeDerived> &f, LastType l) Symbolic::ValueExpr> > >::type
seq(const Symbolic::BaseExpr<FirstTypeDerived> &f, LastType l)
{ {
return seqN(f.derived(),(l-f.derived()+1)); return seqN(f.derived(),(l-f.derived()+1));
} }
template<typename FirstType,typename LastTypeDerived> template<typename FirstType,typename LastTypeDerived>
typename internal::enable_if<!is_symbolic<FirstType>::value, typename internal::enable_if<!Symbolic::is_symbolic<FirstType>::value,
ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,symbolic_add<symbolic_add<LastTypeDerived,symbolic_value_wrapper>,symbolic_value_wrapper> > >::type ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,
seq(FirstType f, const symbolic_index_base<LastTypeDerived> &l) Symbolic::AddExpr<Symbolic::AddExpr<LastTypeDerived,Symbolic::ValueExpr>,Symbolic::ValueExpr> > >::type
seq(FirstType f, const Symbolic::BaseExpr<LastTypeDerived> &l)
{ {
return seqN(f,(l.derived()-f+1)); return seqN(f,(l.derived()-f+1));
} }
template<typename FirstTypeDerived,typename LastTypeDerived> template<typename FirstTypeDerived,typename LastTypeDerived>
ArithemeticSequence<FirstTypeDerived,symbolic_add<symbolic_add<LastTypeDerived,symbolic_negate<FirstTypeDerived> >,symbolic_value_wrapper> > ArithemeticSequence<FirstTypeDerived,
seq(const symbolic_index_base<FirstTypeDerived> &f, const symbolic_index_base<LastTypeDerived> &l) Symbolic::AddExpr<Symbolic::AddExpr<LastTypeDerived,Symbolic::NegateExpr<FirstTypeDerived> >,Symbolic::ValueExpr> >
seq(const Symbolic::BaseExpr<FirstTypeDerived> &f, const Symbolic::BaseExpr<LastTypeDerived> &l)
{ {
return seqN(f.derived(),(l.derived()-f.derived()+1)); return seqN(f.derived(),(l.derived()-f.derived()+1));
} }
template<typename FirstType,typename LastType, typename IncrType> template<typename FirstType,typename LastType, typename IncrType>
typename internal::enable_if<!(is_symbolic<FirstType>::value || is_symbolic<LastType>::value), typename internal::enable_if<!(Symbolic::is_symbolic<FirstType>::value || Symbolic::is_symbolic<LastType>::value),
ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,Index,typename cleanup_seq_type<IncrType>::type> >::type ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,Index,typename cleanup_seq_type<IncrType>::type> >::type
seq(FirstType f, LastType l, IncrType incr) seq(FirstType f, LastType l, IncrType incr)
{ {
@ -297,22 +314,27 @@ seq(FirstType f, LastType l, IncrType incr)
} }
template<typename FirstTypeDerived,typename LastType, typename IncrType> template<typename FirstTypeDerived,typename LastType, typename IncrType>
typename internal::enable_if<!is_symbolic<LastType>::value, typename internal::enable_if<!Symbolic::is_symbolic<LastType>::value,
ArithemeticSequence<FirstTypeDerived, ArithemeticSequence<FirstTypeDerived,
symbolic_quotient<symbolic_add<symbolic_add<symbolic_negate<FirstTypeDerived>,symbolic_value_wrapper>,symbolic_value_wrapper>,symbolic_value_wrapper>, Symbolic::QuotientExpr<Symbolic::AddExpr<Symbolic::AddExpr<Symbolic::NegateExpr<FirstTypeDerived>,
typename cleanup_seq_type<IncrType>::type> >::type Symbolic::ValueExpr>,
seq(const symbolic_index_base<FirstTypeDerived> &f, LastType l, IncrType incr) Symbolic::ValueExpr>,
Symbolic::ValueExpr>,
typename cleanup_seq_type<IncrType>::type> >::type
seq(const Symbolic::BaseExpr<FirstTypeDerived> &f, LastType l, IncrType incr)
{ {
typedef typename cleanup_seq_type<IncrType>::type CleanedIncrType; typedef typename cleanup_seq_type<IncrType>::type CleanedIncrType;
return seqN(f.derived(),(l-f.derived()+CleanedIncrType(incr))/CleanedIncrType(incr), incr); return seqN(f.derived(),(l-f.derived()+CleanedIncrType(incr))/CleanedIncrType(incr), incr);
} }
template<typename FirstType,typename LastTypeDerived, typename IncrType> template<typename FirstType,typename LastTypeDerived, typename IncrType>
typename internal::enable_if<!is_symbolic<FirstType>::value, typename internal::enable_if<!Symbolic::is_symbolic<FirstType>::value,
ArithemeticSequence<typename cleanup_seq_type<FirstType>::type, ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,
symbolic_quotient<symbolic_add<symbolic_add<LastTypeDerived,symbolic_value_wrapper>,symbolic_value_wrapper>,symbolic_value_wrapper>, Symbolic::QuotientExpr<Symbolic::AddExpr<Symbolic::AddExpr<LastTypeDerived,Symbolic::ValueExpr>,
Symbolic::ValueExpr>,
Symbolic::ValueExpr>,
typename cleanup_seq_type<IncrType>::type> >::type typename cleanup_seq_type<IncrType>::type> >::type
seq(FirstType f, const symbolic_index_base<LastTypeDerived> &l, IncrType incr) seq(FirstType f, const Symbolic::BaseExpr<LastTypeDerived> &l, IncrType incr)
{ {
typedef typename cleanup_seq_type<IncrType>::type CleanedIncrType; typedef typename cleanup_seq_type<IncrType>::type CleanedIncrType;
return seqN(f,(l.derived()-f+CleanedIncrType(incr))/CleanedIncrType(incr), incr); return seqN(f,(l.derived()-f+CleanedIncrType(incr))/CleanedIncrType(incr), incr);
@ -320,9 +342,12 @@ seq(FirstType f, const symbolic_index_base<LastTypeDerived> &l, IncrType incr)
template<typename FirstTypeDerived,typename LastTypeDerived, typename IncrType> template<typename FirstTypeDerived,typename LastTypeDerived, typename IncrType>
ArithemeticSequence<FirstTypeDerived, ArithemeticSequence<FirstTypeDerived,
symbolic_quotient<symbolic_add<symbolic_add<LastTypeDerived,symbolic_negate<FirstTypeDerived> >,symbolic_value_wrapper>,symbolic_value_wrapper>, Symbolic::QuotientExpr<Symbolic::AddExpr<Symbolic::AddExpr<LastTypeDerived,
Symbolic::NegateExpr<FirstTypeDerived> >,
Symbolic::ValueExpr>,
Symbolic::ValueExpr>,
typename cleanup_seq_type<IncrType>::type> typename cleanup_seq_type<IncrType>::type>
seq(const symbolic_index_base<FirstTypeDerived> &f, const symbolic_index_base<LastTypeDerived> &l, IncrType incr) seq(const Symbolic::BaseExpr<FirstTypeDerived> &f, const Symbolic::BaseExpr<LastTypeDerived> &l, IncrType incr)
{ {
typedef typename cleanup_seq_type<IncrType>::type CleanedIncrType; typedef typename cleanup_seq_type<IncrType>::type CleanedIncrType;
return seqN(f.derived(),(l.derived()-f.derived()+CleanedIncrType(incr))/CleanedIncrType(incr), incr); return seqN(f.derived(),(l.derived()-f.derived()+CleanedIncrType(incr))/CleanedIncrType(incr), incr);
@ -404,15 +429,15 @@ template<int N>
fix_t<N> symbolic2value(fix_t<N> x, Index /*size*/) { return x; } fix_t<N> symbolic2value(fix_t<N> x, Index /*size*/) { return x; }
template<typename Derived> template<typename Derived>
Index symbolic2value(const symbolic_index_base<Derived> &x, Index size) Index symbolic2value(const Symbolic::BaseExpr<Derived> &x, Index size)
{ {
return x.derived().eval(symbolic_value_pair<symb_last_tag>(size-1)); return x.derived().eval(Symbolic::defineValue(placeholders::last,size-1));
} }
// Convert a symbolic span into a usable one (i.e., remove last/end "keywords") // Convert a symbolic span into a usable one (i.e., remove last/end "keywords")
template<typename T> template<typename T>
struct make_size_type { struct make_size_type {
typedef typename internal::conditional<is_symbolic<T>::value, Index, T>::type type; typedef typename internal::conditional<Symbolic::is_symbolic<T>::value, Index, T>::type type;
}; };
template<typename FirstType,typename SizeType,typename IncrType> template<typename FirstType,typename SizeType,typename IncrType>