mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-20 08:39:37 +08:00
Fix cuda 9+ builds
Fix removed `shfl_` intrinsics, disable warnings, update CUDA header inclusion.
This commit is contained in:
parent
43b7aa2412
commit
d2ce4faa5a
@ -254,7 +254,7 @@
|
||||
#endif
|
||||
|
||||
#if defined EIGEN_HAS_CUDA_FP16
|
||||
#include <host_defines.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cuda_fp16.h>
|
||||
#endif
|
||||
|
||||
|
@ -626,25 +626,71 @@ struct hash<Eigen::half> {
|
||||
} // end namespace std
|
||||
|
||||
|
||||
// Add the missing shfl_xor intrinsic
|
||||
#if defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300
|
||||
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor(Eigen::half var, int laneMask, int width=warpSize) {
|
||||
#if EIGEN_CUDACC_VER < 90000
|
||||
return static_cast<Eigen::half>(__shfl_xor(static_cast<float>(var), laneMask, width));
|
||||
#else
|
||||
return static_cast<Eigen::half>(__shfl_xor_sync(0xFFFFFFFF, static_cast<float>(var), laneMask, width));
|
||||
#endif
|
||||
// Add the missing shfl* intrinsics.
|
||||
// The __shfl* functions are only valid on HIP or _CUDA_ARCH_ >= 300.
|
||||
// CUDA defines them for (__CUDA_ARCH__ >= 300 || !defined(__CUDA_ARCH__))
|
||||
//
|
||||
// HIP and CUDA prior to SDK 9.0 define
|
||||
// __shfl, __shfl_up, __shfl_down, __shfl_xor for int and float
|
||||
// CUDA since 9.0 deprecates those and instead defines
|
||||
// __shfl_sync, __shfl_up_sync, __shfl_down_sync, __shfl_xor_sync,
|
||||
// with native support for __half and __nv_bfloat16
|
||||
//
|
||||
// Note that the following are __device__ - only functions.
|
||||
#if defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 300)
|
||||
#if defined(EIGEN_HAS_CUDA_FP16) && EIGEN_CUDACC_VER >= 90000
|
||||
|
||||
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_sync(unsigned mask, Eigen::half var, int srcLane,
|
||||
int width = warpSize) {
|
||||
const __half h = var;
|
||||
return static_cast<Eigen::half>(__shfl_sync(mask, h, srcLane, width));
|
||||
}
|
||||
|
||||
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up_sync(unsigned mask, Eigen::half var, unsigned int delta,
|
||||
int width = warpSize) {
|
||||
const __half h = var;
|
||||
return static_cast<Eigen::half>(__shfl_up_sync(mask, h, delta, width));
|
||||
}
|
||||
|
||||
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down_sync(unsigned mask, Eigen::half var, unsigned int delta,
|
||||
int width = warpSize) {
|
||||
const __half h = var;
|
||||
return static_cast<Eigen::half>(__shfl_down_sync(mask, h, delta, width));
|
||||
}
|
||||
|
||||
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor_sync(unsigned mask, Eigen::half var, int laneMask,
|
||||
int width = warpSize) {
|
||||
const __half h = var;
|
||||
return static_cast<Eigen::half>(__shfl_xor_sync(mask, h, laneMask, width));
|
||||
}
|
||||
|
||||
#else // CUDA SDK < 9.0
|
||||
|
||||
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl(Eigen::half var, int srcLane, int width = warpSize) {
|
||||
return static_cast<Eigen::half>(__shfl(static_cast<float>(var), laneMask, width));
|
||||
}
|
||||
|
||||
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up(Eigen::half var, unsigned int delta, int width = warpSize) {
|
||||
return static_cast<Eigen::half>(__shfl_up(static_cast<float>(var), laneMask, width));
|
||||
}
|
||||
|
||||
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down(Eigen::half var, unsigned int delta, int width = warpSize) {
|
||||
return static_cast<Eigen::half>(__shfl_down(static_cast<float>(var), laneMask, width));
|
||||
}
|
||||
|
||||
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor(Eigen::half var, int laneMask, int width = warpSize) {
|
||||
return static_cast<Eigen::half>(__shfl_xor(static_cast<float>(var), laneMask, width));
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif // __shfl*
|
||||
|
||||
// ldg() has an overload for __half_raw, but we also need one for Eigen::half.
|
||||
#if defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 350
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half __ldg(const Eigen::half* ptr) {
|
||||
return Eigen::half_impl::raw_uint16_to_half(
|
||||
__ldg(reinterpret_cast<const unsigned short*>(ptr)));
|
||||
#if defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 350)
|
||||
EIGEN_STRONG_INLINE __device__ Eigen::half __ldg(const Eigen::half* ptr) {
|
||||
return Eigen::half_impl::raw_uint16_to_half(__ldg(reinterpret_cast<const Eigen::numext::uint16_t*>(ptr)));
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // __ldg
|
||||
|
||||
#if defined(EIGEN_CUDA_ARCH)
|
||||
namespace Eigen {
|
||||
|
@ -1,94 +1,146 @@
|
||||
#ifndef EIGEN_WARNINGS_DISABLED
|
||||
#define EIGEN_WARNINGS_DISABLED
|
||||
|
||||
#ifdef _MSC_VER
|
||||
// 4100 - unreferenced formal parameter (occurred e.g. in aligned_allocator::destroy(pointer p))
|
||||
// 4101 - unreferenced local variable
|
||||
// 4127 - conditional expression is constant
|
||||
// 4181 - qualifier applied to reference type ignored
|
||||
// 4211 - nonstandard extension used : redefined extern to static
|
||||
// 4244 - 'argument' : conversion from 'type1' to 'type2', possible loss of data
|
||||
// 4273 - QtAlignedMalloc, inconsistent DLL linkage
|
||||
// 4324 - structure was padded due to declspec(align())
|
||||
// 4503 - decorated name length exceeded, name was truncated
|
||||
// 4512 - assignment operator could not be generated
|
||||
// 4522 - 'class' : multiple assignment operators specified
|
||||
// 4700 - uninitialized local variable 'xyz' used
|
||||
// 4714 - function marked as __forceinline not inlined
|
||||
// 4717 - 'function' : recursive on all control paths, function will cause runtime stack overflow
|
||||
// 4800 - 'type' : forcing value to bool 'true' or 'false' (performance warning)
|
||||
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
|
||||
#pragma warning( push )
|
||||
#endif
|
||||
#pragma warning( disable : 4100 4101 4127 4181 4211 4244 4273 4324 4503 4512 4522 4700 4714 4717 4800)
|
||||
|
||||
#elif defined __INTEL_COMPILER
|
||||
// 2196 - routine is both "inline" and "noinline" ("noinline" assumed)
|
||||
// ICC 12 generates this warning even without any inline keyword, when defining class methods 'inline' i.e. inside of class body
|
||||
// typedef that may be a reference type.
|
||||
// 279 - controlling expression is constant
|
||||
// ICC 12 generates this warning on assert(constant_expression_depending_on_template_params) and frankly this is a legitimate use case.
|
||||
// 1684 - conversion from pointer to same-sized integral type (potential portability problem)
|
||||
// 2259 - non-pointer conversion from "Eigen::Index={ptrdiff_t={long}}" to "int" may lose significant bits
|
||||
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
|
||||
#pragma warning push
|
||||
#endif
|
||||
#pragma warning disable 2196 279 1684 2259
|
||||
|
||||
#elif defined __clang__
|
||||
// -Wconstant-logical-operand - warning: use of logical && with constant operand; switch to bitwise & or remove constant
|
||||
// this is really a stupid warning as it warns on compile-time expressions involving enums
|
||||
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
|
||||
#pragma clang diagnostic push
|
||||
#endif
|
||||
#pragma clang diagnostic ignored "-Wconstant-logical-operand"
|
||||
|
||||
#elif defined __GNUC__
|
||||
|
||||
#if (!defined(EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS)) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))
|
||||
#pragma GCC diagnostic push
|
||||
#endif
|
||||
// g++ warns about local variables shadowing member functions, which is too strict
|
||||
#pragma GCC diagnostic ignored "-Wshadow"
|
||||
#if __GNUC__ == 4 && __GNUC_MINOR__ < 8
|
||||
// Until g++-4.7 there are warnings when comparing unsigned int vs 0, even in templated functions:
|
||||
#pragma GCC diagnostic ignored "-Wtype-limits"
|
||||
#endif
|
||||
#if __GNUC__>=6
|
||||
#pragma GCC diagnostic ignored "-Wignored-attributes"
|
||||
#endif
|
||||
#if __GNUC__==7
|
||||
// See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89325
|
||||
#pragma GCC diagnostic ignored "-Wattributes"
|
||||
#endif
|
||||
#if defined(_MSC_VER)
|
||||
// 4100 - unreferenced formal parameter (occurred e.g. in aligned_allocator::destroy(pointer p))
|
||||
// 4101 - unreferenced local variable
|
||||
// 4127 - conditional expression is constant
|
||||
// 4181 - qualifier applied to reference type ignored
|
||||
// 4211 - nonstandard extension used : redefined extern to static
|
||||
// 4244 - 'argument' : conversion from 'type1' to 'type2', possible loss of data
|
||||
// 4273 - QtAlignedMalloc, inconsistent DLL linkage
|
||||
// 4324 - structure was padded due to declspec(align())
|
||||
// 4503 - decorated name length exceeded, name was truncated
|
||||
// 4512 - assignment operator could not be generated
|
||||
// 4522 - 'class' : multiple assignment operators specified
|
||||
// 4700 - uninitialized local variable 'xyz' used
|
||||
// 4714 - function marked as __forceinline not inlined
|
||||
// 4717 - 'function' : recursive on all control paths, function will cause runtime stack overflow
|
||||
// 4800 - 'type' : forcing value to bool 'true' or 'false' (performance warning)
|
||||
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
|
||||
#pragma warning(push)
|
||||
#endif
|
||||
#pragma warning(disable : 4100 4101 4127 4181 4211 4244 4273 4324 4503 4512 4522 4700 4714 4717 4800)
|
||||
// We currently rely on has_denorm in tests, and need it defined correctly for half/bfloat16.
|
||||
#ifndef _SILENCE_CXX23_DENORM_DEPRECATION_WARNING
|
||||
#define EIGEN_REENABLE_CXX23_DENORM_DEPRECATION_WARNING 1
|
||||
#define _SILENCE_CXX23_DENORM_DEPRECATION_WARNING
|
||||
#endif
|
||||
|
||||
#if defined __NVCC__
|
||||
// Disable the "statement is unreachable" message
|
||||
#pragma diag_suppress code_is_unreachable
|
||||
// Disable the "dynamic initialization in unreachable code" message
|
||||
#pragma diag_suppress initialization_not_reachable
|
||||
// Disable the "invalid error number" message that we get with older versions of nvcc
|
||||
#pragma diag_suppress 1222
|
||||
// Disable the "calling a __host__ function from a __host__ __device__ function is not allowed" messages (yes, there are many of them and they seem to change with every version of the compiler)
|
||||
#pragma diag_suppress 2527
|
||||
#pragma diag_suppress 2529
|
||||
#pragma diag_suppress 2651
|
||||
#pragma diag_suppress 2653
|
||||
#pragma diag_suppress 2668
|
||||
#pragma diag_suppress 2669
|
||||
#pragma diag_suppress 2670
|
||||
#pragma diag_suppress 2671
|
||||
#pragma diag_suppress 2735
|
||||
#pragma diag_suppress 2737
|
||||
#elif defined __INTEL_COMPILER
|
||||
// 2196 - routine is both "inline" and "noinline" ("noinline" assumed)
|
||||
// ICC 12 generates this warning even without any inline keyword, when defining class methods 'inline' i.e.
|
||||
// inside of class body typedef that may be a reference type.
|
||||
// 279 - controlling expression is constant
|
||||
// ICC 12 generates this warning on assert(constant_expression_depending_on_template_params) and frankly this is
|
||||
// a legitimate use case.
|
||||
// 1684 - conversion from pointer to same-sized integral type (potential portability problem)
|
||||
// 2259 - non-pointer conversion from "Eigen::Index={ptrdiff_t={long}}" to "int" may lose significant bits
|
||||
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
|
||||
#pragma warning push
|
||||
#endif
|
||||
#pragma warning disable 2196 279 1684 2259
|
||||
|
||||
#elif defined __clang__
|
||||
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
|
||||
#pragma clang diagnostic push
|
||||
#endif
|
||||
#if defined(__has_warning)
|
||||
// -Wconstant-logical-operand - warning: use of logical && with constant operand; switch to bitwise & or remove constant
|
||||
// this is really a stupid warning as it warns on compile-time expressions involving enums
|
||||
#if __has_warning("-Wconstant-logical-operand")
|
||||
#pragma clang diagnostic ignored "-Wconstant-logical-operand"
|
||||
#endif
|
||||
#if __has_warning("-Wimplicit-int-float-conversion")
|
||||
#pragma clang diagnostic ignored "-Wimplicit-int-float-conversion"
|
||||
#endif
|
||||
#if (defined(__ALTIVEC__) || defined(__VSX__)) && (!defined(__STDC_VERSION__) || (__STDC_VERSION__ < 201112L))
|
||||
// warning: generic selections are a C11-specific feature
|
||||
// ignoring warnings thrown at vec_ctf in Altivec/PacketMath.h
|
||||
#if __has_warning("-Wc11-extensions")
|
||||
#pragma clang diagnostic ignored "-Wc11-extensions"
|
||||
#endif
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#elif defined __GNUC__ && !defined(__FUJITSU)
|
||||
|
||||
#if (!defined(EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS)) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))
|
||||
#pragma GCC diagnostic push
|
||||
#endif
|
||||
// g++ warns about local variables shadowing member functions, which is too strict
|
||||
#pragma GCC diagnostic ignored "-Wshadow"
|
||||
#if __GNUC__ == 4 && __GNUC_MINOR__ < 8
|
||||
// Until g++-4.7 there are warnings when comparing unsigned int vs 0, even in templated functions:
|
||||
#pragma GCC diagnostic ignored "-Wtype-limits"
|
||||
#endif
|
||||
#if __GNUC__ >= 6
|
||||
#pragma GCC diagnostic ignored "-Wignored-attributes"
|
||||
#endif
|
||||
#if __GNUC__ == 7
|
||||
// See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89325
|
||||
#pragma GCC diagnostic ignored "-Wattributes"
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if defined __NVCC__ && defined __CUDACC__
|
||||
// MSVC 14.16 (required by CUDA 9.*) does not support the _Pragma keyword, so
|
||||
// we instead use Microsoft's __pragma extension.
|
||||
#if defined _MSC_VER
|
||||
#define EIGEN_MAKE_PRAGMA(X) __pragma(#X)
|
||||
#else
|
||||
#define EIGEN_MAKE_PRAGMA(X) _Pragma(#X)
|
||||
#endif
|
||||
#if defined __NVCC_DIAG_PRAGMA_SUPPORT__
|
||||
#define EIGEN_NV_DIAG_SUPPRESS(X) EIGEN_MAKE_PRAGMA(nv_diag_suppress X)
|
||||
#else
|
||||
#define EIGEN_NV_DIAG_SUPPRESS(X) EIGEN_MAKE_PRAGMA(diag_suppress X)
|
||||
#endif
|
||||
|
||||
EIGEN_NV_DIAG_SUPPRESS(boolean_controlling_expr_is_constant)
|
||||
// Disable the "statement is unreachable" message
|
||||
EIGEN_NV_DIAG_SUPPRESS(code_is_unreachable)
|
||||
// Disable the "dynamic initialization in unreachable code" message
|
||||
EIGEN_NV_DIAG_SUPPRESS(initialization_not_reachable)
|
||||
// Disable the "invalid error number" message that we get with older versions of nvcc
|
||||
EIGEN_NV_DIAG_SUPPRESS(1222)
|
||||
// Disable the "calling a __host__ function from a __host__ __device__ function is not allowed" messages (yes, there are
|
||||
// many of them and they seem to change with every version of the compiler)
|
||||
EIGEN_NV_DIAG_SUPPRESS(2527)
|
||||
EIGEN_NV_DIAG_SUPPRESS(2529)
|
||||
EIGEN_NV_DIAG_SUPPRESS(2651)
|
||||
EIGEN_NV_DIAG_SUPPRESS(2653)
|
||||
EIGEN_NV_DIAG_SUPPRESS(2668)
|
||||
EIGEN_NV_DIAG_SUPPRESS(2669)
|
||||
EIGEN_NV_DIAG_SUPPRESS(2670)
|
||||
EIGEN_NV_DIAG_SUPPRESS(2671)
|
||||
EIGEN_NV_DIAG_SUPPRESS(2735)
|
||||
EIGEN_NV_DIAG_SUPPRESS(2737)
|
||||
EIGEN_NV_DIAG_SUPPRESS(2739)
|
||||
EIGEN_NV_DIAG_SUPPRESS(2885)
|
||||
EIGEN_NV_DIAG_SUPPRESS(2888)
|
||||
EIGEN_NV_DIAG_SUPPRESS(2976)
|
||||
EIGEN_NV_DIAG_SUPPRESS(2979)
|
||||
EIGEN_NV_DIAG_SUPPRESS(20011)
|
||||
EIGEN_NV_DIAG_SUPPRESS(20014)
|
||||
// Disable the "// __device__ annotation is ignored on a function(...) that is
|
||||
// explicitly defaulted on its first declaration" message.
|
||||
// The __device__ annotation seems to actually be needed in some cases,
|
||||
// otherwise resulting in kernel runtime errors.
|
||||
EIGEN_NV_DIAG_SUPPRESS(2886)
|
||||
EIGEN_NV_DIAG_SUPPRESS(2929)
|
||||
EIGEN_NV_DIAG_SUPPRESS(2977)
|
||||
EIGEN_NV_DIAG_SUPPRESS(20012)
|
||||
#undef EIGEN_NV_DIAG_SUPPRESS
|
||||
#undef EIGEN_MAKE_PRAGMA
|
||||
#endif
|
||||
|
||||
#else
|
||||
// warnings already disabled:
|
||||
# ifndef EIGEN_WARNINGS_DISABLED_2
|
||||
# define EIGEN_WARNINGS_DISABLED_2
|
||||
# elif defined(EIGEN_INTERNAL_DEBUGGING)
|
||||
# error "Do not include \"DisableStupidWarnings.h\" recursively more than twice!"
|
||||
# endif
|
||||
#ifndef EIGEN_WARNINGS_DISABLED_2
|
||||
#define EIGEN_WARNINGS_DISABLED_2
|
||||
#elif defined(EIGEN_INTERNAL_DEBUGGING)
|
||||
#error "Do not include \"DisableStupidWarnings.h\" recursively more than twice!"
|
||||
#endif
|
||||
|
||||
#endif // not EIGEN_WARNINGS_DISABLED
|
||||
#endif // not EIGEN_WARNINGS_DISABLED
|
||||
|
@ -388,7 +388,11 @@ EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
||||
// the sum across all big k blocks of the product of little k block of index (x, y)
|
||||
// with block of index (y, z). To compute the final output, we need to reduce
|
||||
// the 8 threads over y by summation.
|
||||
#if EIGEN_CUDACC_VER < 90000
|
||||
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
|
||||
#else
|
||||
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask)
|
||||
#endif
|
||||
|
||||
#define reduceRow(i, mask) \
|
||||
shuffleInc(i, 0, mask); \
|
||||
@ -543,12 +547,12 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
|
||||
#define prefetch_lhs(reg, row, col) \
|
||||
if (!CHECK_LHS_BOUNDARY) { \
|
||||
if (col < k_size) { \
|
||||
reg =lhs.loadPacket<Unaligned>(row, col); \
|
||||
reg =lhs.template loadPacket<Unaligned>(row, col); \
|
||||
} \
|
||||
} else { \
|
||||
if (col < k_size) { \
|
||||
if (row + 3 < m_size) { \
|
||||
reg =lhs.loadPacket<Unaligned>(row, col); \
|
||||
reg =lhs.template loadPacket<Unaligned>(row, col); \
|
||||
} else if (row + 2 < m_size) { \
|
||||
reg.x =lhs(row + 0, col); \
|
||||
reg.y =lhs(row + 1, col); \
|
||||
@ -578,7 +582,7 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
|
||||
if (!CHECK_RHS_BOUNDARY) {
|
||||
if ((rhs_vert + 3) < k_size) {
|
||||
// just CHECK_RHS_BOUNDARY
|
||||
rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||
} else if (rhs_vert + 2 < k_size) {
|
||||
// just CHECK_RHS_BOUNDARY
|
||||
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
||||
@ -593,7 +597,7 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
|
||||
} else {
|
||||
if (rhs_horiz0 < n_size) {
|
||||
if ((rhs_vert + 3) < k_size) {
|
||||
rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||
} else if ((rhs_vert + 2) < k_size) {
|
||||
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
||||
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
|
||||
@ -615,8 +619,13 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
|
||||
x1 = rhs_pf0.x;
|
||||
x2 = rhs_pf0.z;
|
||||
}
|
||||
#if EIGEN_CUDACC_VER < 90000
|
||||
x1 = __shfl_xor(x1, 4);
|
||||
x2 = __shfl_xor(x2, 4);
|
||||
#else
|
||||
x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
|
||||
x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
|
||||
#endif
|
||||
if((threadIdx.x%8) < 4) {
|
||||
rhs_pf0.y = x1;
|
||||
rhs_pf0.w = x2;
|
||||
@ -790,37 +799,37 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
||||
|
||||
if (!CHECK_LHS_BOUNDARY) {
|
||||
if ((threadIdx.y/4+k+24) < k_size) {
|
||||
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf2 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||
lhs_pf3 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
|
||||
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||
lhs_pf3 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
|
||||
} else if ((threadIdx.y/4+k+16) < k_size) {
|
||||
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf2 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||
} else if ((threadIdx.y/4+k+8) < k_size) {
|
||||
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
} else if ((threadIdx.y/4+k) < k_size) {
|
||||
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
}
|
||||
} else {
|
||||
// just CHECK_LHS_BOUNDARY
|
||||
if (lhs_vert + 3 < m_size) {
|
||||
if ((threadIdx.y/4+k+24) < k_size) {
|
||||
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf2 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||
lhs_pf3 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
|
||||
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||
lhs_pf3 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
|
||||
} else if ((threadIdx.y/4+k+16) < k_size) {
|
||||
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf2 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||
} else if ((threadIdx.y/4+k+8) < k_size) {
|
||||
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||
} else if ((threadIdx.y/4+k) < k_size) {
|
||||
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||
}
|
||||
} else if (lhs_vert + 2 < m_size) {
|
||||
if ((threadIdx.y/4+k+24) < k_size) {
|
||||
@ -909,8 +918,8 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
||||
if (!CHECK_RHS_BOUNDARY) {
|
||||
if ((rhs_vert + 3) < k_size) {
|
||||
// just CHECK_RHS_BOUNDARY
|
||||
rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||
rhs_pf1 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
|
||||
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||
rhs_pf1 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
|
||||
} else if (rhs_vert + 2 < k_size) {
|
||||
// just CHECK_RHS_BOUNDARY
|
||||
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
||||
@ -932,8 +941,8 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
||||
if (rhs_horiz1 < n_size) {
|
||||
if ((rhs_vert + 3) < k_size) {
|
||||
// just CHECK_RHS_BOUNDARY
|
||||
rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||
rhs_pf1 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
|
||||
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||
rhs_pf1 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
|
||||
} else if (rhs_vert + 2 < k_size) {
|
||||
// just CHECK_RHS_BOUNDARY
|
||||
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
||||
@ -954,7 +963,7 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
||||
} else if (rhs_horiz0 < n_size) {
|
||||
if ((rhs_vert + 3) < k_size) {
|
||||
// just CHECK_RHS_BOUNDARY
|
||||
rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||
} else if ((rhs_vert + 2) < k_size) {
|
||||
// just CHECK_RHS_BOUNDARY
|
||||
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
||||
|
Loading…
x
Reference in New Issue
Block a user