Made it possible to compare tensor dimensions inside a CUDA kernel.

This commit is contained in:
Benoit Steiner 2016-01-15 11:22:16 -08:00
parent aed4cb1269
commit 0461f0153e

View File

@ -405,20 +405,20 @@ template <std::size_t n, std::size_t V1, std::size_t V2, std::size_t V3, std::si
template <typename Dims1, typename Dims2, size_t n, size_t m> template <typename Dims1, typename Dims2, size_t n, size_t m>
struct sizes_match_below_dim { struct sizes_match_below_dim {
static inline bool run(Dims1&, Dims2&) { static EIGEN_DEVICE_FUNC inline bool run(Dims1&, Dims2&) {
return false; return false;
} }
}; };
template <typename Dims1, typename Dims2, size_t n> template <typename Dims1, typename Dims2, size_t n>
struct sizes_match_below_dim<Dims1, Dims2, n, n> { struct sizes_match_below_dim<Dims1, Dims2, n, n> {
static inline bool run(Dims1& dims1, Dims2& dims2) { static EIGEN_DEVICE_FUNC inline bool run(Dims1& dims1, Dims2& dims2) {
return (array_get<n-1>(dims1) == array_get<n-1>(dims2)) & return (array_get<n-1>(dims1) == array_get<n-1>(dims2)) &
sizes_match_below_dim<Dims1, Dims2, n-1, n-1>::run(dims1, dims2); sizes_match_below_dim<Dims1, Dims2, n-1, n-1>::run(dims1, dims2);
} }
}; };
template <typename Dims1, typename Dims2> template <typename Dims1, typename Dims2>
struct sizes_match_below_dim<Dims1, Dims2, 0, 0> { struct sizes_match_below_dim<Dims1, Dims2, 0, 0> {
static inline bool run(Dims1&, Dims2&) { static EIGEN_DEVICE_FUNC inline bool run(Dims1&, Dims2&) {
return true; return true;
} }
}; };
@ -427,7 +427,7 @@ struct sizes_match_below_dim<Dims1, Dims2, 0, 0> {
template <typename Dims1, typename Dims2> template <typename Dims1, typename Dims2>
bool dimensions_match(Dims1& dims1, Dims2& dims2) { EIGEN_DEVICE_FUNC bool dimensions_match(Dims1& dims1, Dims2& dims2) {
return internal::sizes_match_below_dim<Dims1, Dims2, internal::array_size<Dims1>::value, internal::array_size<Dims2>::value>::run(dims1, dims2); return internal::sizes_match_below_dim<Dims1, Dims2, internal::array_size<Dims1>::value, internal::array_size<Dims2>::value>::run(dims1, dims2);
} }