From 40bb98e76acbe6e077903e15896c100ee6cced39 Mon Sep 17 00:00:00 2001 From: Benoit Steiner Date: Thu, 10 Jul 2014 11:29:51 -0700 Subject: [PATCH] Added primitives to compare tensor dimensions --- .../Eigen/CXX11/src/Tensor/TensorDimensions.h | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h index 3e5687915..3b169a06f 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorDimensions.h @@ -210,6 +210,60 @@ struct DSizes : array { }; +namespace internal { + +template struct array_size > { + static const size_t value = NumDims; +}; +template struct array_size > { + static const size_t value = NumDims; +}; +#ifndef EIGEN_EMULATE_CXX11_META_H +template struct array_size > { +static const size_t value = Sizes::count; +}; +template struct array_size > { +static const size_t value = Sizes::count; +}; +#else +template struct array_size > { + static const size_t value = Sizes::count; +}; +template struct array_size > { + static const size_t value = Sizes::count; +}; +template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::size_t array_get(const Sizes& a) { + return get::Base>::value; +}; + +#endif + + +template +struct sizes_match_up_to_dim { + static inline bool run(Dims1& dims1, Dims2& dims2) { + return (array_get(dims1) == array_get(dims2)) & + sizes_match_up_to_dim::run(dims1, dims2); + } +}; +template +struct sizes_match_up_to_dim { + static inline bool run(Dims1& dims1, Dims2& dims2) { + return (array_get<0>(dims1) == array_get<0>(dims2)); + } +}; + +template +bool dimensions_match(Dims1& dims1, Dims2& dims2) { + if (array_size::value != array_size::value) { + return false; + } + return sizes_match_up_to_dim::value-1>::run(dims1, dims2); +} + +} // end namespace internal + + } // end namespace Eigen #endif // EIGEN_CXX11_TENSOR_TENSOR_DIMENSIONS_H