Add a minimalistic symbolic scalar type with expression template and make use of it to define the last placeholder and to unify the return type of seq and seqN.

This commit is contained in:
Gael Guennebaud 2017-01-09 23:42:16 +01:00
parent 68064e14fa
commit b50c3e967e
2 changed files with 197 additions and 28 deletions

View File

@ -34,7 +34,7 @@ struct last_t {
int operator- (last_t) const { return 0; }
int operator- (shifted_last x) const { return -x.offset; }
};
static const last_t last;
static const last_t last_legacy;
struct shifted_end {
@ -52,7 +52,145 @@ struct end_t {
int operator- (end_t) const { return 0; }
int operator- (shifted_end x) const { return -x.offset; }
};
static const end_t end;
static const end_t end_legacy;
// A simple wrapper around an Index to provide the eval method.
// We could also use a free-function symbolic_eval...
class symbolic_value_wrapper {
public:
symbolic_value_wrapper(Index val) : m_value(val) {}
template<typename T>
Index eval(const T&) const { return m_value; }
protected:
Index m_value;
};
//--------------------------------------------------------------------------------
// minimalistic symbolic scalar type
//--------------------------------------------------------------------------------
template<typename Tag> class symbolic_symbol;
template<typename Arg0> class symbolic_negate;
template<typename Arg1,typename Arg2> class symbolic_add;
template<typename Arg1,typename Arg2> class symbolic_product;
template<typename Arg1,typename Arg2> class symbolic_quotient;
template<typename Derived>
class symbolic_index_base
{
public:
const Derived& derived() const { return *static_cast<const Derived*>(this); }
symbolic_negate<Derived> operator-() const { return symbolic_negate<Derived>(derived()); }
symbolic_add<Derived,symbolic_value_wrapper> operator+(Index b) const
{ return symbolic_add<Derived,symbolic_value_wrapper >(derived(), b); }
symbolic_add<Derived,symbolic_value_wrapper> operator-(Index a) const
{ return symbolic_add<Derived,symbolic_value_wrapper >(derived(), -a); }
symbolic_quotient<Derived,symbolic_value_wrapper> operator/(Index a) const
{ return symbolic_quotient<Derived,symbolic_value_wrapper>(derived(),a); }
friend symbolic_add<Derived,symbolic_value_wrapper> operator+(Index a, const symbolic_index_base& b)
{ return symbolic_add<Derived,symbolic_value_wrapper>(b.derived(), a); }
friend symbolic_add<symbolic_negate<Derived>,symbolic_value_wrapper> operator-(Index a, const symbolic_index_base& b)
{ return symbolic_add<symbolic_negate<Derived>,symbolic_value_wrapper>(-b.derived(), a); }
friend symbolic_add<symbolic_value_wrapper,Derived> operator/(Index a, const symbolic_index_base& b)
{ return symbolic_add<symbolic_value_wrapper,Derived>(a,b.derived()); }
template<typename OtherDerived>
symbolic_add<Derived,OtherDerived> operator+(const symbolic_index_base<OtherDerived> &b) const
{ return symbolic_add<Derived,OtherDerived>(derived(), b.derived()); }
template<typename OtherDerived>
symbolic_add<Derived,symbolic_negate<OtherDerived> > operator-(const symbolic_index_base<OtherDerived> &b) const
{ return symbolic_add<Derived,symbolic_negate<OtherDerived> >(derived(), -b.derived()); }
template<typename OtherDerived>
symbolic_add<Derived,OtherDerived> operator/(const symbolic_index_base<OtherDerived> &b) const
{ return symbolic_quotient<Derived,OtherDerived>(derived(), b.derived()); }
};
template<typename T>
struct is_symbolic {
enum { value = internal::is_convertible<T,symbolic_index_base<T> >::value };
};
template<typename Tag>
class symbolic_value_pair
{
public:
symbolic_value_pair(Index val) : m_value(val) {}
Index value() const { return m_value; }
protected:
Index m_value;
};
template<typename Tag>
class symbolic_value : public symbolic_index_base<symbolic_value<Tag> >
{
public:
symbolic_value() {}
Index eval(const symbolic_value_pair<Tag> &values) const { return values.value(); }
// TODO add a c++14 eval taking a tuple of symbolic_value_pair and getting the value with std::get<symbolic_value_pair<Tag> >...
};
template<typename Arg0>
class symbolic_negate : public symbolic_index_base<symbolic_negate<Arg0> >
{
public:
symbolic_negate(const Arg0& arg0) : m_arg0(arg0) {}
template<typename T>
Index eval(const T& values) const { return -m_arg0.eval(values); }
protected:
Arg0 m_arg0;
};
template<typename Arg0, typename Arg1>
class symbolic_add : public symbolic_index_base<symbolic_add<Arg0,Arg1> >
{
public:
symbolic_add(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
template<typename T>
Index eval(const T& values) const { return m_arg0.eval(values) + m_arg1.eval(values); }
protected:
Arg0 m_arg0;
Arg1 m_arg1;
};
template<typename Arg0, typename Arg1>
class symbolic_product : public symbolic_index_base<symbolic_product<Arg0,Arg1> >
{
public:
symbolic_product(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
template<typename T>
Index eval(const T& values) const { return m_arg0.eval(values) * m_arg1.eval(values); }
protected:
Arg0 m_arg0;
Arg1 m_arg1;
};
template<typename Arg0, typename Arg1>
class symbolic_quotient : public symbolic_index_base<symbolic_quotient<Arg0,Arg1> >
{
public:
symbolic_quotient(const Arg0& arg0, const Arg1& arg1) : m_arg0(arg0), m_arg1(arg1) {}
template<typename T>
Index eval(const T& values) const { return m_arg0.eval(values) / m_arg1.eval(values); }
protected:
Arg0 m_arg0;
Arg1 m_arg1;
};
struct symb_last_tag {};
static const symbolic_value<symb_last_tag> last;
static const symbolic_add<symbolic_value<symb_last_tag>,symbolic_value_wrapper> end(last+1);
//--------------------------------------------------------------------------------
// integral constant
@ -116,34 +254,30 @@ protected:
IncrType m_incr;
};
template<typename T> struct cleanup_slice_type { typedef Index type; };
template<> struct cleanup_slice_type<last_t> { typedef last_t type; };
template<> struct cleanup_slice_type<shifted_last> { typedef shifted_last type; };
template<> struct cleanup_slice_type<end_t> { typedef end_t type; };
template<> struct cleanup_slice_type<shifted_end> { typedef shifted_end type; };
template<int N> struct cleanup_slice_type<fix_t<N> > { typedef fix_t<N> type; };
template<int N> struct cleanup_slice_type<fix_t<N> (*)() > { typedef fix_t<N> type; };
template<typename T> struct cleanup_seq_type { typedef T type; };
template<int N> struct cleanup_seq_type<fix_t<N> > { typedef fix_t<N> type; };
template<int N> struct cleanup_seq_type<fix_t<N> (*)() > { typedef fix_t<N> type; };
template<typename FirstType,typename LastType>
ArithemeticSequenceProxyWithBounds<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<LastType>::type >
seq(FirstType f, LastType l) {
return ArithemeticSequenceProxyWithBounds<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<LastType>::type>(f,l);
ArithemeticSequenceProxyWithBounds<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<LastType>::type >
seq_legacy(FirstType f, LastType l) {
return ArithemeticSequenceProxyWithBounds<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<LastType>::type>(f,l);
}
template<typename FirstType,typename LastType,typename IncrType>
ArithemeticSequenceProxyWithBounds<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<LastType>::type,typename cleanup_slice_type<IncrType>::type >
seq(FirstType f, LastType l, IncrType s) {
return ArithemeticSequenceProxyWithBounds<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<LastType>::type,typename cleanup_slice_type<IncrType>::type>(f,l,typename cleanup_slice_type<IncrType>::type(s));
ArithemeticSequenceProxyWithBounds<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<LastType>::type,typename cleanup_seq_type<IncrType>::type >
seq_legacy(FirstType f, LastType l, IncrType s) {
return ArithemeticSequenceProxyWithBounds<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<LastType>::type,typename cleanup_seq_type<IncrType>::type>(f,l,typename cleanup_seq_type<IncrType>::type(s));
}
template<typename FirstType=Index,typename SizeType=Index,typename IncrType=fix_t<1> >
class ArithemeticSequenceProxyWithSize
class ArithemeticSequence
{
public:
ArithemeticSequenceProxyWithSize(FirstType first, SizeType size) : m_first(first), m_size(size) {}
ArithemeticSequenceProxyWithSize(FirstType first, SizeType size, IncrType incr) : m_first(first), m_size(size), m_incr(incr) {}
ArithemeticSequence(FirstType first, SizeType size) : m_first(first), m_size(size) {}
ArithemeticSequence(FirstType first, SizeType size, IncrType incr) : m_first(first), m_size(size), m_incr(incr) {}
enum {
SizeAtCompileTime = get_compile_time<SizeType>::value,
@ -165,18 +299,30 @@ protected:
template<typename FirstType,typename SizeType,typename IncrType>
ArithemeticSequenceProxyWithSize<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<SizeType>::type,typename cleanup_slice_type<IncrType>::type >
ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<SizeType>::type,typename cleanup_seq_type<IncrType>::type >
seqN(FirstType first, SizeType size, IncrType incr) {
return ArithemeticSequenceProxyWithSize<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<SizeType>::type,typename cleanup_slice_type<IncrType>::type>(first,size,incr);
return ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<SizeType>::type,typename cleanup_seq_type<IncrType>::type>(first,size,incr);
}
template<typename FirstType,typename SizeType>
ArithemeticSequenceProxyWithSize<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<SizeType>::type >
ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<SizeType>::type >
seqN(FirstType first, SizeType size) {
return ArithemeticSequenceProxyWithSize<typename cleanup_slice_type<FirstType>::type,typename cleanup_slice_type<SizeType>::type>(first,size);
return ArithemeticSequence<typename cleanup_seq_type<FirstType>::type,typename cleanup_seq_type<SizeType>::type>(first,size);
}
template<typename FirstType,typename LastType>
auto seq(FirstType f, LastType l) -> decltype(seqN(f,(l-f+1)))
{
return seqN(f,(l-f+1));
}
template<typename FirstType,typename LastType, typename IncrType>
auto seq(FirstType f, LastType l, IncrType incr)
-> decltype(seqN(f,(l-f+typename cleanup_seq_type<IncrType>::type(incr))/typename cleanup_seq_type<IncrType>::type(incr),typename cleanup_seq_type<IncrType>::type(incr)))
{
typedef typename cleanup_seq_type<IncrType>::type CleanedIncrType;
return seqN(f,(l-f+CleanedIncrType(incr))/CleanedIncrType(incr),CleanedIncrType(incr));
}
namespace internal {
@ -214,7 +360,7 @@ struct get_compile_time_incr<ArithemeticSequenceProxyWithBounds<FirstType,LastTy
};
template<typename FirstType,typename SizeType,typename IncrType>
struct get_compile_time_incr<ArithemeticSequenceProxyWithSize<FirstType,SizeType,IncrType> > {
struct get_compile_time_incr<ArithemeticSequence<FirstType,SizeType,IncrType> > {
enum { value = get_compile_time<IncrType,DynamicIndex>::value };
};
@ -258,6 +404,17 @@ Index symbolic2value(shifted_last x, Index size) { return size+x.offset-1; }
Index symbolic2value(end_t, Index size) { return size; }
Index symbolic2value(shifted_end x, Index size) { return size+x.offset; }
template<int N>
fix_t<N> symbolic2value(fix_t<N> x, Index /*size*/) { return x; }
template<typename Derived>
Index symbolic2value(const symbolic_index_base<Derived> &x, Index size)
{
Index h=x.derived().eval(symbolic_value_pair<symb_last_tag>(size-1));
return x.derived().eval(symbolic_value_pair<symb_last_tag>(size-1));
}
// Convert a symbolic range into a usable one (i.e., remove last/end "keywords")
template<typename FirstType,typename LastType,typename IncrType>
struct MakeIndexing<ArithemeticSequenceProxyWithBounds<FirstType,LastType,IncrType> > {
@ -270,14 +427,21 @@ ArithemeticSequenceProxyWithBounds<Index,Index,IncrType> make_indexing(const Ari
}
// Convert a symbolic span into a usable one (i.e., remove last/end "keywords")
template<typename FirstType,typename SizeType,typename IncrType>
struct MakeIndexing<ArithemeticSequenceProxyWithSize<FirstType,SizeType,IncrType> > {
typedef ArithemeticSequenceProxyWithSize<Index,SizeType,IncrType> type;
template<typename T>
struct make_size_type {
typedef typename internal::conditional<is_symbolic<T>::value, Index, T>::type type;
};
template<typename FirstType,typename SizeType,typename IncrType>
ArithemeticSequenceProxyWithSize<Index,SizeType,IncrType> make_indexing(const ArithemeticSequenceProxyWithSize<FirstType,SizeType,IncrType>& ids, Index size) {
return ArithemeticSequenceProxyWithSize<Index,SizeType,IncrType>(symbolic2value(ids.firstObject(),size),ids.sizeObject(),ids.incrObject());
struct MakeIndexing<ArithemeticSequence<FirstType,SizeType,IncrType> > {
typedef ArithemeticSequence<Index,typename make_size_type<SizeType>::type,IncrType> type;
};
template<typename FirstType,typename SizeType,typename IncrType>
ArithemeticSequence<Index,typename make_size_type<SizeType>::type,IncrType>
make_indexing(const ArithemeticSequence<FirstType,SizeType,IncrType>& ids, Index size) {
return ArithemeticSequence<Index,typename make_size_type<SizeType>::type,IncrType>(
symbolic2value(ids.firstObject(),size),symbolic2value(ids.sizeObject(),size),ids.incrObject());
}
// Convert a symbolic 'all' into a usable range

View File

@ -139,6 +139,11 @@ void check_indexed_view()
VERIFY_IS_EQUAL( (A(eii, eii)).InnerStrideAtCompileTime, 0);
VERIFY_IS_EQUAL( (A(eii, eii)).OuterStrideAtCompileTime, 0);
VERIFY_IS_APPROX( A(seq(n-1,2,-2), seqN(n-1-6,4)), A(seq(last,2,-2), seqN(last-6,4)) );
VERIFY_IS_APPROX( A(seq(n-1-6,n-1-2), seqN(n-1-6,4)), A(seq(last-6,last-2), seqN(6+last-6-6,4)) );
VERIFY_IS_APPROX( A(seq((n-1)/2,(n)/2+3), seqN(2,4)), A(seq(last/2,(last+1)/2+3), seqN(last+2-last,4)) );
VERIFY_IS_APPROX( A(seq(n-2,2,-2), seqN(n-8,4)), A(seq(end-2,2,-2), seqN(end-8,4)) );
#if EIGEN_HAS_CXX11
VERIFY( (A(all, std::array<int,4>{{1,3,2,4}})).ColsAtCompileTime == 4);