Misc fixes

This commit is contained in:
Benoit Steiner 2015-01-14 15:30:47 -08:00
parent 0feff6e987
commit 4cdf3fe427

View File

@ -40,6 +40,10 @@ template <typename Index> struct IndexPair {
// Boilerplate code // Boilerplate code
namespace internal { namespace internal {
template<std::size_t n, typename Dimension> struct dget {
static const std::size_t value = get<n, Dimension>::value;
};
template<typename Index, std::size_t NumIndices, std::size_t n, bool RowMajor> template<typename Index, std::size_t NumIndices, std::size_t n, bool RowMajor>
struct fixed_size_tensor_index_linearization_helper struct fixed_size_tensor_index_linearization_helper
@ -49,7 +53,7 @@ struct fixed_size_tensor_index_linearization_helper
const Dimensions& dimensions) const Dimensions& dimensions)
{ {
return array_get<RowMajor ? n : (NumIndices - n - 1)>(indices) + return array_get<RowMajor ? n : (NumIndices - n - 1)>(indices) +
get<RowMajor ? n : (NumIndices - n - 1), Dimensions>::value * dget<RowMajor ? n : (NumIndices - n - 1), Dimensions>::value *
fixed_size_tensor_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions); fixed_size_tensor_index_linearization_helper<Index, NumIndices, n - 1, RowMajor>::run(indices, dimensions);
} }
}; };
@ -75,6 +79,10 @@ struct Sizes : internal::numeric_list<std::size_t, Indices...> {
typedef internal::numeric_list<std::size_t, Indices...> Base; typedef internal::numeric_list<std::size_t, Indices...> Base;
static const std::size_t total_size = internal::arg_prod(Indices...); static const std::size_t total_size = internal::arg_prod(Indices...);
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t rank() const {
return Base::count;
}
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t TotalSize() { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t TotalSize() {
return internal::arg_prod(Indices...); return internal::arg_prod(Indices...);
} }
@ -85,6 +93,7 @@ struct Sizes : internal::numeric_list<std::size_t, Indices...> {
// todo: add assertion // todo: add assertion
} }
#ifdef EIGEN_HAS_VARIADIC_TEMPLATES #ifdef EIGEN_HAS_VARIADIC_TEMPLATES
template <typename... DenseIndex> Sizes(DenseIndex... indices) { }
explicit Sizes(std::initializer_list<std::size_t> /*l*/) { explicit Sizes(std::initializer_list<std::size_t> /*l*/) {
// todo: add assertion // todo: add assertion
} }
@ -121,11 +130,15 @@ struct non_zero_size<0> {
typedef internal::null_type type; typedef internal::null_type type;
}; };
template <std::size_t V1=0, std::size_t V2=0, std::size_t V3=0, std::size_t V4=0, std::size_t V5=0> struct Sizes : typename internal::make_type_list<typename non_zero_size<V1>::type, typename non_zero_size<V2>::type, typename non_zero_size<V3>::type, typename non_zero_size<V4>::type, typename non_zero_size<V5>::type >::type { template <std::size_t V1=0, std::size_t V2=0, std::size_t V3=0, std::size_t V4=0, std::size_t V5=0> struct Sizes {
typedef typename internal::make_type_list<typename non_zero_size<V1>::type, typename non_zero_size<V2>::type, typename non_zero_size<V3>::type, typename non_zero_size<V4>::type, typename non_zero_size<V5>::type >::type Base; typedef typename internal::make_type_list<typename non_zero_size<V1>::type, typename non_zero_size<V2>::type, typename non_zero_size<V3>::type, typename non_zero_size<V4>::type, typename non_zero_size<V5>::type >::type Base;
static const size_t count = Base::count; static const size_t count = Base::count;
static const std::size_t total_size = internal::arg_prod<Base>::value; static const std::size_t total_size = internal::arg_prod<Base>::value;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t rank() const {
return count;
}
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t TotalSize() { static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t TotalSize() {
return internal::arg_prod<Base>::value; return internal::arg_prod<Base>::value;
} }
@ -160,11 +173,11 @@ template <std::size_t V1=0, std::size_t V2=0, std::size_t V3=0, std::size_t V4=0
template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
size_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const { size_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const {
return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count - 1, false>::run(indices, *static_cast<Base*>(this)); return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count - 1, false>::run(indices, *static_cast<const Base*>(this);
} }
template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
size_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const { size_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const {
return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count - 1, true>::run(indices, *static_cast<Base*>(this)); return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count - 1, true>::run(indices, *static_cast<const Base*>(this);
} }
}; };
@ -208,6 +221,10 @@ struct DSizes : array<DenseIndex, NumDims> {
typedef array<DenseIndex, NumDims> Base; typedef array<DenseIndex, NumDims> Base;
static const std::size_t count = NumDims; static const std::size_t count = NumDims;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t rank() const {
return NumDims;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t TotalSize() const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t TotalSize() const {
return internal::array_prod(*static_cast<const Base*>(this)); return internal::array_prod(*static_cast<const Base*>(this));
} }
@ -219,31 +236,44 @@ struct DSizes : array<DenseIndex, NumDims> {
} }
EIGEN_DEVICE_FUNC explicit DSizes(const array<DenseIndex, NumDims>& a) : Base(a) { } EIGEN_DEVICE_FUNC explicit DSizes(const array<DenseIndex, NumDims>& a) : Base(a) { }
#ifdef EIGEN_HAS_VARIADIC_TEMPLATES
template<typename... IndexTypes> EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE explicit DSizes(DenseIndex firstDimension, IndexTypes... otherDimensions) {
EIGEN_STATIC_ASSERT(sizeof...(otherDimensions) + 1 == NumDims, YOU_MADE_A_PROGRAMMING_MISTAKE)
(*this) = array<DenseIndex, NumDims>{{firstDimension, otherDimensions...}};
}
#else
EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0) { EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0) {
eigen_assert(NumDims == 1);
(*this)[0] = i0; (*this)[0] = i0;
} }
EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0, const DenseIndex i1) { EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0, const DenseIndex i1) {
eigen_assert(NumDims == 2);
(*this)[0] = i0; (*this)[0] = i0;
(*this)[1] = i1; (*this)[1] = i1;
} }
EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2) { EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2) {
eigen_assert(NumDims == 3);
(*this)[0] = i0; (*this)[0] = i0;
(*this)[1] = i1; (*this)[1] = i1;
(*this)[2] = i2; (*this)[2] = i2;
} }
EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3) { EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3) {
eigen_assert(NumDims == 4);
(*this)[0] = i0; (*this)[0] = i0;
(*this)[1] = i1; (*this)[1] = i1;
(*this)[2] = i2; (*this)[2] = i2;
(*this)[3] = i3; (*this)[3] = i3;
} }
EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3, const DenseIndex i4) { EIGEN_DEVICE_FUNC explicit DSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3, const DenseIndex i4) {
eigen_assert(NumDims == 5);
(*this)[0] = i0; (*this)[0] = i0;
(*this)[1] = i1; (*this)[1] = i1;
(*this)[2] = i2; (*this)[2] = i2;
(*this)[3] = i3; (*this)[3] = i3;
(*this)[4] = i4; (*this)[4] = i4;
} }
#endif
DSizes& operator = (const array<DenseIndex, NumDims>& other) { DSizes& operator = (const array<DenseIndex, NumDims>& other) {
*static_cast<Base*>(this) = other; *static_cast<Base*>(this) = other;
@ -287,84 +317,6 @@ struct tensor_vsize_index_linearization_helper<Index, NumIndices, 0, RowMajor>
}; };
} // end namespace internal } // end namespace internal
template <typename DenseIndex>
struct VSizes : std::vector<DenseIndex> {
typedef std::vector<DenseIndex> Base;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t TotalSize() const {
return internal::array_prod(*static_cast<const Base*>(this));
}
EIGEN_DEVICE_FUNC VSizes() { }
EIGEN_DEVICE_FUNC explicit VSizes(const std::vector<DenseIndex>& a) : Base(a) { }
template <std::size_t NumDims>
EIGEN_DEVICE_FUNC explicit VSizes(const array<DenseIndex, NumDims>& a) {
this->resize(NumDims);
for (int i = 0; i < NumDims; ++i) {
(*this)[i] = a[i];
}
}
EIGEN_DEVICE_FUNC explicit VSizes(const DenseIndex i0) {
this->resize(1);
(*this)[0] = i0;
}
EIGEN_DEVICE_FUNC explicit VSizes(const DenseIndex i0, const DenseIndex i1) {
this->resize(2);
(*this)[0] = i0;
(*this)[1] = i1;
}
EIGEN_DEVICE_FUNC explicit VSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2) {
this->resize(3);
(*this)[0] = i0;
(*this)[1] = i1;
(*this)[2] = i2;
}
EIGEN_DEVICE_FUNC explicit VSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3) {
this->resize(4);
(*this)[0] = i0;
(*this)[1] = i1;
(*this)[2] = i2;
(*this)[3] = i3;
}
EIGEN_DEVICE_FUNC explicit VSizes(const DenseIndex i0, const DenseIndex i1, const DenseIndex i2, const DenseIndex i3, const DenseIndex i4) {
this->resize(5);
(*this)[0] = i0;
(*this)[1] = i1;
(*this)[2] = i2;
(*this)[3] = i3;
(*this)[4] = i4;
}
VSizes& operator = (const std::vector<DenseIndex>& other) {
*static_cast<Base*>(this) = other;
return *this;
}
// A constexpr would be so much better here
template <std::size_t NumDims>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t IndexOfColMajor(const array<DenseIndex, NumDims>& indices) const {
return internal::tensor_vsize_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, false>::run(indices, *static_cast<const Base*>(this));
}
template <std::size_t NumDims>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE size_t IndexOfRowMajor(const array<DenseIndex, NumDims>& indices) const {
return internal::tensor_vsize_index_linearization_helper<DenseIndex, NumDims, NumDims - 1, true>::run(indices, *static_cast<const Base*>(this));
}
};
// Boilerplate
namespace internal {
template <typename DenseIndex>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE DenseIndex array_prod(const VSizes<DenseIndex>& sizes) {
DenseIndex total_size = 1;
for (int i = 0; i < sizes.size(); ++i) {
total_size *= sizes[i];
}
return total_size;
}
}
namespace internal { namespace internal {
@ -381,8 +333,8 @@ static const size_t value = Sizes<Indices...>::count;
template <typename std::size_t... Indices> struct array_size<Sizes<Indices...> > { template <typename std::size_t... Indices> struct array_size<Sizes<Indices...> > {
static const size_t value = Sizes<Indices...>::count; static const size_t value = Sizes<Indices...>::count;
}; };
template <std::size_t n, typename std::size_t... Indices> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_get(const Sizes<Indices...>) { template <std::size_t n, typename std::size_t... Indices> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_get(const Sizes<Indices...>& a) {
return get<n, typename Sizes<Indices...>::Base>::value; return get<n, internal::numeric_list<std::size_t, Indices...> >::value;
} }
#else #else
template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> struct array_size<const Sizes<V1,V2,V3,V4,V5> > { template <std::size_t V1, std::size_t V2, std::size_t V3, std::size_t V4, std::size_t V5> struct array_size<const Sizes<V1,V2,V3,V4,V5> > {
@ -412,17 +364,17 @@ struct sizes_match_up_to_dim<Dims1, Dims2, 0> {
} }
}; };
template <typename Dims1, typename Dims2>
bool dimensions_match(Dims1& dims1, Dims2& dims2) {
if (array_size<Dims1>::value != array_size<Dims2>::value) {
return false;
}
return sizes_match_up_to_dim<Dims1, Dims2, array_size<Dims1>::value-1>::run(dims1, dims2);
}
} // end namespace internal } // end namespace internal
template <typename Dims1, typename Dims2>
bool dimensions_match(Dims1& dims1, Dims2& dims2) {
if (internal::array_size<Dims1>::value != internal::array_size<Dims2>::value) {
return false;
}
return internal::sizes_match_up_to_dim<Dims1, Dims2, internal::array_size<Dims1>::value-1>::run(dims1, dims2);
}
} // end namespace Eigen } // end namespace Eigen
#endif // EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H #endif // EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H