Udated the Sizes class to work on AMD gpus without requiring a separate implementation

This commit is contained in:
Benoit Steiner 2016-11-30 19:57:28 -08:00
parent e37c2c52d3
commit f5107010ee

View File

@ -69,12 +69,7 @@ struct fixed_size_tensor_index_extraction_helper
{ {
const Index mult = (index == n-1) ? 1 : 0; const Index mult = (index == n-1) ? 1 : 0;
return return
#ifdef EIGEN_USE_SYCL array_get<n-1>(dimensions) * mult +
utility::tuple::get<n-1>(dimensions)
#else
array_get<n-1>(dimensions)
#endif
* mult +
fixed_size_tensor_index_extraction_helper<Index, n - 1>::run(index, dimensions); fixed_size_tensor_index_extraction_helper<Index, n - 1>::run(index, dimensions);
} }
}; };
@ -96,12 +91,12 @@ struct fixed_size_tensor_index_extraction_helper<Index, 0>
// Fixed size // Fixed size
#ifndef EIGEN_EMULATE_CXX11_META_H #ifndef EIGEN_EMULATE_CXX11_META_H
template <typename std::ptrdiff_t... Indices> template <typename std::ptrdiff_t... Indices>
struct Sizes : internal::numeric_list<std::ptrdiff_t, Indices...> { struct Sizes {
typedef internal::numeric_list<std::ptrdiff_t, Indices...> Base; typedef internal::numeric_list<std::ptrdiff_t, Indices...> Base;
#ifdef EIGEN_USE_SYCL const Base t = Base();
const decltype(utility::tuple::make_tuple(Indices...)) t= utility::tuple::make_tuple(Indices...);
#endif
static const std::ptrdiff_t total_size = internal::arg_prod(Indices...); static const std::ptrdiff_t total_size = internal::arg_prod(Indices...);
static const size_t count = Base::count;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t rank() const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t rank() const {
return Base::count; return Base::count;
@ -129,20 +124,16 @@ struct Sizes : internal::numeric_list<std::ptrdiff_t, Indices...> {
} }
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t operator[] (const std::size_t index) const { EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::ptrdiff_t operator[] (const std::size_t index) const {
#ifdef EIGEN_USE_SYCL
return internal::fixed_size_tensor_index_extraction_helper<std::ptrdiff_t, Base::count>::run(index, t); return internal::fixed_size_tensor_index_extraction_helper<std::ptrdiff_t, Base::count>::run(index, t);
#else
return internal::fixed_size_tensor_index_extraction_helper<std::ptrdiff_t, Base::count>::run(index, *this);
#endif
} }
template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
size_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const { size_t IndexOfColMajor(const array<DenseIndex, Base::count>& indices) const {
return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, false>::run(indices, *static_cast<const Base*>(this)); return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, false>::run(indices, t);
} }
template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE template <typename DenseIndex> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
size_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const { size_t IndexOfRowMajor(const array<DenseIndex, Base::count>& indices) const {
return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, true>::run(indices, *static_cast<const Base*>(this)); return internal::fixed_size_tensor_index_linearization_helper<DenseIndex, Base::count, Base::count, true>::run(indices, t);
} }
}; };