mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-05-08 13:59:05 +08:00
Made it possible to compare tensor dimensions inside a CUDA kernel.
This commit is contained in:
parent
aed4cb1269
commit
0461f0153e
@ -405,20 +405,20 @@ template <std::size_t n, std::size_t V1, std::size_t V2, std::size_t V3, std::si
|
|||||||
|
|
||||||
template <typename Dims1, typename Dims2, size_t n, size_t m>
|
template <typename Dims1, typename Dims2, size_t n, size_t m>
|
||||||
struct sizes_match_below_dim {
|
struct sizes_match_below_dim {
|
||||||
static inline bool run(Dims1&, Dims2&) {
|
static EIGEN_DEVICE_FUNC inline bool run(Dims1&, Dims2&) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
template <typename Dims1, typename Dims2, size_t n>
|
template <typename Dims1, typename Dims2, size_t n>
|
||||||
struct sizes_match_below_dim<Dims1, Dims2, n, n> {
|
struct sizes_match_below_dim<Dims1, Dims2, n, n> {
|
||||||
static inline bool run(Dims1& dims1, Dims2& dims2) {
|
static EIGEN_DEVICE_FUNC inline bool run(Dims1& dims1, Dims2& dims2) {
|
||||||
return (array_get<n-1>(dims1) == array_get<n-1>(dims2)) &
|
return (array_get<n-1>(dims1) == array_get<n-1>(dims2)) &
|
||||||
sizes_match_below_dim<Dims1, Dims2, n-1, n-1>::run(dims1, dims2);
|
sizes_match_below_dim<Dims1, Dims2, n-1, n-1>::run(dims1, dims2);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
template <typename Dims1, typename Dims2>
|
template <typename Dims1, typename Dims2>
|
||||||
struct sizes_match_below_dim<Dims1, Dims2, 0, 0> {
|
struct sizes_match_below_dim<Dims1, Dims2, 0, 0> {
|
||||||
static inline bool run(Dims1&, Dims2&) {
|
static EIGEN_DEVICE_FUNC inline bool run(Dims1&, Dims2&) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -427,7 +427,7 @@ struct sizes_match_below_dim<Dims1, Dims2, 0, 0> {
|
|||||||
|
|
||||||
|
|
||||||
template <typename Dims1, typename Dims2>
|
template <typename Dims1, typename Dims2>
|
||||||
bool dimensions_match(Dims1& dims1, Dims2& dims2) {
|
EIGEN_DEVICE_FUNC bool dimensions_match(Dims1& dims1, Dims2& dims2) {
|
||||||
return internal::sizes_match_below_dim<Dims1, Dims2, internal::array_size<Dims1>::value, internal::array_size<Dims2>::value>::run(dims1, dims2);
|
return internal::sizes_match_below_dim<Dims1, Dims2, internal::array_size<Dims1>::value, internal::array_size<Dims2>::value>::run(dims1, dims2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user