mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-03 03:35:11 +08:00
Fix cuda 9+ builds
Fix removed `shfl_` intrinsics, disable warnings, update CUDA header inclusion.
This commit is contained in:
parent
43b7aa2412
commit
d2ce4faa5a
@ -254,7 +254,7 @@
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined EIGEN_HAS_CUDA_FP16
|
#if defined EIGEN_HAS_CUDA_FP16
|
||||||
#include <host_defines.h>
|
#include <cuda_runtime_api.h>
|
||||||
#include <cuda_fp16.h>
|
#include <cuda_fp16.h>
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -626,25 +626,71 @@ struct hash<Eigen::half> {
|
|||||||
} // end namespace std
|
} // end namespace std
|
||||||
|
|
||||||
|
|
||||||
// Add the missing shfl_xor intrinsic
|
// Add the missing shfl* intrinsics.
|
||||||
#if defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300
|
// The __shfl* functions are only valid on HIP or _CUDA_ARCH_ >= 300.
|
||||||
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor(Eigen::half var, int laneMask, int width=warpSize) {
|
// CUDA defines them for (__CUDA_ARCH__ >= 300 || !defined(__CUDA_ARCH__))
|
||||||
#if EIGEN_CUDACC_VER < 90000
|
//
|
||||||
return static_cast<Eigen::half>(__shfl_xor(static_cast<float>(var), laneMask, width));
|
// HIP and CUDA prior to SDK 9.0 define
|
||||||
#else
|
// __shfl, __shfl_up, __shfl_down, __shfl_xor for int and float
|
||||||
return static_cast<Eigen::half>(__shfl_xor_sync(0xFFFFFFFF, static_cast<float>(var), laneMask, width));
|
// CUDA since 9.0 deprecates those and instead defines
|
||||||
#endif
|
// __shfl_sync, __shfl_up_sync, __shfl_down_sync, __shfl_xor_sync,
|
||||||
|
// with native support for __half and __nv_bfloat16
|
||||||
|
//
|
||||||
|
// Note that the following are __device__ - only functions.
|
||||||
|
#if defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 300)
|
||||||
|
#if defined(EIGEN_HAS_CUDA_FP16) && EIGEN_CUDACC_VER >= 90000
|
||||||
|
|
||||||
|
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_sync(unsigned mask, Eigen::half var, int srcLane,
|
||||||
|
int width = warpSize) {
|
||||||
|
const __half h = var;
|
||||||
|
return static_cast<Eigen::half>(__shfl_sync(mask, h, srcLane, width));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up_sync(unsigned mask, Eigen::half var, unsigned int delta,
|
||||||
|
int width = warpSize) {
|
||||||
|
const __half h = var;
|
||||||
|
return static_cast<Eigen::half>(__shfl_up_sync(mask, h, delta, width));
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down_sync(unsigned mask, Eigen::half var, unsigned int delta,
|
||||||
|
int width = warpSize) {
|
||||||
|
const __half h = var;
|
||||||
|
return static_cast<Eigen::half>(__shfl_down_sync(mask, h, delta, width));
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor_sync(unsigned mask, Eigen::half var, int laneMask,
|
||||||
|
int width = warpSize) {
|
||||||
|
const __half h = var;
|
||||||
|
return static_cast<Eigen::half>(__shfl_xor_sync(mask, h, laneMask, width));
|
||||||
|
}
|
||||||
|
|
||||||
|
#else // CUDA SDK < 9.0
|
||||||
|
|
||||||
|
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl(Eigen::half var, int srcLane, int width = warpSize) {
|
||||||
|
return static_cast<Eigen::half>(__shfl(static_cast<float>(var), laneMask, width));
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up(Eigen::half var, unsigned int delta, int width = warpSize) {
|
||||||
|
return static_cast<Eigen::half>(__shfl_up(static_cast<float>(var), laneMask, width));
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down(Eigen::half var, unsigned int delta, int width = warpSize) {
|
||||||
|
return static_cast<Eigen::half>(__shfl_down(static_cast<float>(var), laneMask, width));
|
||||||
|
}
|
||||||
|
|
||||||
|
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor(Eigen::half var, int laneMask, int width = warpSize) {
|
||||||
|
return static_cast<Eigen::half>(__shfl_xor(static_cast<float>(var), laneMask, width));
|
||||||
|
}
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
#endif // __shfl*
|
||||||
|
|
||||||
// ldg() has an overload for __half_raw, but we also need one for Eigen::half.
|
// ldg() has an overload for __half_raw, but we also need one for Eigen::half.
|
||||||
#if defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 350
|
#if defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 350)
|
||||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half __ldg(const Eigen::half* ptr) {
|
EIGEN_STRONG_INLINE __device__ Eigen::half __ldg(const Eigen::half* ptr) {
|
||||||
return Eigen::half_impl::raw_uint16_to_half(
|
return Eigen::half_impl::raw_uint16_to_half(__ldg(reinterpret_cast<const Eigen::numext::uint16_t*>(ptr)));
|
||||||
__ldg(reinterpret_cast<const unsigned short*>(ptr)));
|
|
||||||
}
|
}
|
||||||
#endif
|
#endif // __ldg
|
||||||
|
|
||||||
|
|
||||||
#if defined(EIGEN_CUDA_ARCH)
|
#if defined(EIGEN_CUDA_ARCH)
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
|
@ -1,94 +1,146 @@
|
|||||||
#ifndef EIGEN_WARNINGS_DISABLED
|
#ifndef EIGEN_WARNINGS_DISABLED
|
||||||
#define EIGEN_WARNINGS_DISABLED
|
#define EIGEN_WARNINGS_DISABLED
|
||||||
|
|
||||||
#ifdef _MSC_VER
|
#if defined(_MSC_VER)
|
||||||
// 4100 - unreferenced formal parameter (occurred e.g. in aligned_allocator::destroy(pointer p))
|
// 4100 - unreferenced formal parameter (occurred e.g. in aligned_allocator::destroy(pointer p))
|
||||||
// 4101 - unreferenced local variable
|
// 4101 - unreferenced local variable
|
||||||
// 4127 - conditional expression is constant
|
// 4127 - conditional expression is constant
|
||||||
// 4181 - qualifier applied to reference type ignored
|
// 4181 - qualifier applied to reference type ignored
|
||||||
// 4211 - nonstandard extension used : redefined extern to static
|
// 4211 - nonstandard extension used : redefined extern to static
|
||||||
// 4244 - 'argument' : conversion from 'type1' to 'type2', possible loss of data
|
// 4244 - 'argument' : conversion from 'type1' to 'type2', possible loss of data
|
||||||
// 4273 - QtAlignedMalloc, inconsistent DLL linkage
|
// 4273 - QtAlignedMalloc, inconsistent DLL linkage
|
||||||
// 4324 - structure was padded due to declspec(align())
|
// 4324 - structure was padded due to declspec(align())
|
||||||
// 4503 - decorated name length exceeded, name was truncated
|
// 4503 - decorated name length exceeded, name was truncated
|
||||||
// 4512 - assignment operator could not be generated
|
// 4512 - assignment operator could not be generated
|
||||||
// 4522 - 'class' : multiple assignment operators specified
|
// 4522 - 'class' : multiple assignment operators specified
|
||||||
// 4700 - uninitialized local variable 'xyz' used
|
// 4700 - uninitialized local variable 'xyz' used
|
||||||
// 4714 - function marked as __forceinline not inlined
|
// 4714 - function marked as __forceinline not inlined
|
||||||
// 4717 - 'function' : recursive on all control paths, function will cause runtime stack overflow
|
// 4717 - 'function' : recursive on all control paths, function will cause runtime stack overflow
|
||||||
// 4800 - 'type' : forcing value to bool 'true' or 'false' (performance warning)
|
// 4800 - 'type' : forcing value to bool 'true' or 'false' (performance warning)
|
||||||
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
|
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
|
||||||
#pragma warning( push )
|
#pragma warning(push)
|
||||||
#endif
|
#endif
|
||||||
#pragma warning( disable : 4100 4101 4127 4181 4211 4244 4273 4324 4503 4512 4522 4700 4714 4717 4800)
|
#pragma warning(disable : 4100 4101 4127 4181 4211 4244 4273 4324 4503 4512 4522 4700 4714 4717 4800)
|
||||||
|
// We currently rely on has_denorm in tests, and need it defined correctly for half/bfloat16.
|
||||||
#elif defined __INTEL_COMPILER
|
#ifndef _SILENCE_CXX23_DENORM_DEPRECATION_WARNING
|
||||||
// 2196 - routine is both "inline" and "noinline" ("noinline" assumed)
|
#define EIGEN_REENABLE_CXX23_DENORM_DEPRECATION_WARNING 1
|
||||||
// ICC 12 generates this warning even without any inline keyword, when defining class methods 'inline' i.e. inside of class body
|
#define _SILENCE_CXX23_DENORM_DEPRECATION_WARNING
|
||||||
// typedef that may be a reference type.
|
|
||||||
// 279 - controlling expression is constant
|
|
||||||
// ICC 12 generates this warning on assert(constant_expression_depending_on_template_params) and frankly this is a legitimate use case.
|
|
||||||
// 1684 - conversion from pointer to same-sized integral type (potential portability problem)
|
|
||||||
// 2259 - non-pointer conversion from "Eigen::Index={ptrdiff_t={long}}" to "int" may lose significant bits
|
|
||||||
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
|
|
||||||
#pragma warning push
|
|
||||||
#endif
|
|
||||||
#pragma warning disable 2196 279 1684 2259
|
|
||||||
|
|
||||||
#elif defined __clang__
|
|
||||||
// -Wconstant-logical-operand - warning: use of logical && with constant operand; switch to bitwise & or remove constant
|
|
||||||
// this is really a stupid warning as it warns on compile-time expressions involving enums
|
|
||||||
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
|
|
||||||
#pragma clang diagnostic push
|
|
||||||
#endif
|
|
||||||
#pragma clang diagnostic ignored "-Wconstant-logical-operand"
|
|
||||||
|
|
||||||
#elif defined __GNUC__
|
|
||||||
|
|
||||||
#if (!defined(EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS)) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))
|
|
||||||
#pragma GCC diagnostic push
|
|
||||||
#endif
|
|
||||||
// g++ warns about local variables shadowing member functions, which is too strict
|
|
||||||
#pragma GCC diagnostic ignored "-Wshadow"
|
|
||||||
#if __GNUC__ == 4 && __GNUC_MINOR__ < 8
|
|
||||||
// Until g++-4.7 there are warnings when comparing unsigned int vs 0, even in templated functions:
|
|
||||||
#pragma GCC diagnostic ignored "-Wtype-limits"
|
|
||||||
#endif
|
|
||||||
#if __GNUC__>=6
|
|
||||||
#pragma GCC diagnostic ignored "-Wignored-attributes"
|
|
||||||
#endif
|
|
||||||
#if __GNUC__==7
|
|
||||||
// See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89325
|
|
||||||
#pragma GCC diagnostic ignored "-Wattributes"
|
|
||||||
#endif
|
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined __NVCC__
|
#elif defined __INTEL_COMPILER
|
||||||
// Disable the "statement is unreachable" message
|
// 2196 - routine is both "inline" and "noinline" ("noinline" assumed)
|
||||||
#pragma diag_suppress code_is_unreachable
|
// ICC 12 generates this warning even without any inline keyword, when defining class methods 'inline' i.e.
|
||||||
// Disable the "dynamic initialization in unreachable code" message
|
// inside of class body typedef that may be a reference type.
|
||||||
#pragma diag_suppress initialization_not_reachable
|
// 279 - controlling expression is constant
|
||||||
// Disable the "invalid error number" message that we get with older versions of nvcc
|
// ICC 12 generates this warning on assert(constant_expression_depending_on_template_params) and frankly this is
|
||||||
#pragma diag_suppress 1222
|
// a legitimate use case.
|
||||||
// Disable the "calling a __host__ function from a __host__ __device__ function is not allowed" messages (yes, there are many of them and they seem to change with every version of the compiler)
|
// 1684 - conversion from pointer to same-sized integral type (potential portability problem)
|
||||||
#pragma diag_suppress 2527
|
// 2259 - non-pointer conversion from "Eigen::Index={ptrdiff_t={long}}" to "int" may lose significant bits
|
||||||
#pragma diag_suppress 2529
|
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
|
||||||
#pragma diag_suppress 2651
|
#pragma warning push
|
||||||
#pragma diag_suppress 2653
|
#endif
|
||||||
#pragma diag_suppress 2668
|
#pragma warning disable 2196 279 1684 2259
|
||||||
#pragma diag_suppress 2669
|
|
||||||
#pragma diag_suppress 2670
|
#elif defined __clang__
|
||||||
#pragma diag_suppress 2671
|
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
|
||||||
#pragma diag_suppress 2735
|
#pragma clang diagnostic push
|
||||||
#pragma diag_suppress 2737
|
#endif
|
||||||
|
#if defined(__has_warning)
|
||||||
|
// -Wconstant-logical-operand - warning: use of logical && with constant operand; switch to bitwise & or remove constant
|
||||||
|
// this is really a stupid warning as it warns on compile-time expressions involving enums
|
||||||
|
#if __has_warning("-Wconstant-logical-operand")
|
||||||
|
#pragma clang diagnostic ignored "-Wconstant-logical-operand"
|
||||||
|
#endif
|
||||||
|
#if __has_warning("-Wimplicit-int-float-conversion")
|
||||||
|
#pragma clang diagnostic ignored "-Wimplicit-int-float-conversion"
|
||||||
|
#endif
|
||||||
|
#if (defined(__ALTIVEC__) || defined(__VSX__)) && (!defined(__STDC_VERSION__) || (__STDC_VERSION__ < 201112L))
|
||||||
|
// warning: generic selections are a C11-specific feature
|
||||||
|
// ignoring warnings thrown at vec_ctf in Altivec/PacketMath.h
|
||||||
|
#if __has_warning("-Wc11-extensions")
|
||||||
|
#pragma clang diagnostic ignored "-Wc11-extensions"
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#elif defined __GNUC__ && !defined(__FUJITSU)
|
||||||
|
|
||||||
|
#if (!defined(EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS)) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))
|
||||||
|
#pragma GCC diagnostic push
|
||||||
|
#endif
|
||||||
|
// g++ warns about local variables shadowing member functions, which is too strict
|
||||||
|
#pragma GCC diagnostic ignored "-Wshadow"
|
||||||
|
#if __GNUC__ == 4 && __GNUC_MINOR__ < 8
|
||||||
|
// Until g++-4.7 there are warnings when comparing unsigned int vs 0, even in templated functions:
|
||||||
|
#pragma GCC diagnostic ignored "-Wtype-limits"
|
||||||
|
#endif
|
||||||
|
#if __GNUC__ >= 6
|
||||||
|
#pragma GCC diagnostic ignored "-Wignored-attributes"
|
||||||
|
#endif
|
||||||
|
#if __GNUC__ == 7
|
||||||
|
// See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89325
|
||||||
|
#pragma GCC diagnostic ignored "-Wattributes"
|
||||||
|
#endif
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#if defined __NVCC__ && defined __CUDACC__
|
||||||
|
// MSVC 14.16 (required by CUDA 9.*) does not support the _Pragma keyword, so
|
||||||
|
// we instead use Microsoft's __pragma extension.
|
||||||
|
#if defined _MSC_VER
|
||||||
|
#define EIGEN_MAKE_PRAGMA(X) __pragma(#X)
|
||||||
|
#else
|
||||||
|
#define EIGEN_MAKE_PRAGMA(X) _Pragma(#X)
|
||||||
|
#endif
|
||||||
|
#if defined __NVCC_DIAG_PRAGMA_SUPPORT__
|
||||||
|
#define EIGEN_NV_DIAG_SUPPRESS(X) EIGEN_MAKE_PRAGMA(nv_diag_suppress X)
|
||||||
|
#else
|
||||||
|
#define EIGEN_NV_DIAG_SUPPRESS(X) EIGEN_MAKE_PRAGMA(diag_suppress X)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(boolean_controlling_expr_is_constant)
|
||||||
|
// Disable the "statement is unreachable" message
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(code_is_unreachable)
|
||||||
|
// Disable the "dynamic initialization in unreachable code" message
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(initialization_not_reachable)
|
||||||
|
// Disable the "invalid error number" message that we get with older versions of nvcc
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(1222)
|
||||||
|
// Disable the "calling a __host__ function from a __host__ __device__ function is not allowed" messages (yes, there are
|
||||||
|
// many of them and they seem to change with every version of the compiler)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2527)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2529)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2651)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2653)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2668)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2669)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2670)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2671)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2735)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2737)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2739)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2885)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2888)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2976)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2979)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(20011)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(20014)
|
||||||
|
// Disable the "// __device__ annotation is ignored on a function(...) that is
|
||||||
|
// explicitly defaulted on its first declaration" message.
|
||||||
|
// The __device__ annotation seems to actually be needed in some cases,
|
||||||
|
// otherwise resulting in kernel runtime errors.
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2886)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2929)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(2977)
|
||||||
|
EIGEN_NV_DIAG_SUPPRESS(20012)
|
||||||
|
#undef EIGEN_NV_DIAG_SUPPRESS
|
||||||
|
#undef EIGEN_MAKE_PRAGMA
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#else
|
#else
|
||||||
// warnings already disabled:
|
// warnings already disabled:
|
||||||
# ifndef EIGEN_WARNINGS_DISABLED_2
|
#ifndef EIGEN_WARNINGS_DISABLED_2
|
||||||
# define EIGEN_WARNINGS_DISABLED_2
|
#define EIGEN_WARNINGS_DISABLED_2
|
||||||
# elif defined(EIGEN_INTERNAL_DEBUGGING)
|
#elif defined(EIGEN_INTERNAL_DEBUGGING)
|
||||||
# error "Do not include \"DisableStupidWarnings.h\" recursively more than twice!"
|
#error "Do not include \"DisableStupidWarnings.h\" recursively more than twice!"
|
||||||
# endif
|
#endif
|
||||||
|
|
||||||
#endif // not EIGEN_WARNINGS_DISABLED
|
#endif // not EIGEN_WARNINGS_DISABLED
|
||||||
|
@ -388,7 +388,11 @@ EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
|||||||
// the sum across all big k blocks of the product of little k block of index (x, y)
|
// the sum across all big k blocks of the product of little k block of index (x, y)
|
||||||
// with block of index (y, z). To compute the final output, we need to reduce
|
// with block of index (y, z). To compute the final output, we need to reduce
|
||||||
// the 8 threads over y by summation.
|
// the 8 threads over y by summation.
|
||||||
|
#if EIGEN_CUDACC_VER < 90000
|
||||||
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
|
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
|
||||||
|
#else
|
||||||
|
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask)
|
||||||
|
#endif
|
||||||
|
|
||||||
#define reduceRow(i, mask) \
|
#define reduceRow(i, mask) \
|
||||||
shuffleInc(i, 0, mask); \
|
shuffleInc(i, 0, mask); \
|
||||||
@ -543,12 +547,12 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
|
|||||||
#define prefetch_lhs(reg, row, col) \
|
#define prefetch_lhs(reg, row, col) \
|
||||||
if (!CHECK_LHS_BOUNDARY) { \
|
if (!CHECK_LHS_BOUNDARY) { \
|
||||||
if (col < k_size) { \
|
if (col < k_size) { \
|
||||||
reg =lhs.loadPacket<Unaligned>(row, col); \
|
reg =lhs.template loadPacket<Unaligned>(row, col); \
|
||||||
} \
|
} \
|
||||||
} else { \
|
} else { \
|
||||||
if (col < k_size) { \
|
if (col < k_size) { \
|
||||||
if (row + 3 < m_size) { \
|
if (row + 3 < m_size) { \
|
||||||
reg =lhs.loadPacket<Unaligned>(row, col); \
|
reg =lhs.template loadPacket<Unaligned>(row, col); \
|
||||||
} else if (row + 2 < m_size) { \
|
} else if (row + 2 < m_size) { \
|
||||||
reg.x =lhs(row + 0, col); \
|
reg.x =lhs(row + 0, col); \
|
||||||
reg.y =lhs(row + 1, col); \
|
reg.y =lhs(row + 1, col); \
|
||||||
@ -578,7 +582,7 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
|
|||||||
if (!CHECK_RHS_BOUNDARY) {
|
if (!CHECK_RHS_BOUNDARY) {
|
||||||
if ((rhs_vert + 3) < k_size) {
|
if ((rhs_vert + 3) < k_size) {
|
||||||
// just CHECK_RHS_BOUNDARY
|
// just CHECK_RHS_BOUNDARY
|
||||||
rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||||
} else if (rhs_vert + 2 < k_size) {
|
} else if (rhs_vert + 2 < k_size) {
|
||||||
// just CHECK_RHS_BOUNDARY
|
// just CHECK_RHS_BOUNDARY
|
||||||
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
||||||
@ -593,7 +597,7 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
|
|||||||
} else {
|
} else {
|
||||||
if (rhs_horiz0 < n_size) {
|
if (rhs_horiz0 < n_size) {
|
||||||
if ((rhs_vert + 3) < k_size) {
|
if ((rhs_vert + 3) < k_size) {
|
||||||
rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||||
} else if ((rhs_vert + 2) < k_size) {
|
} else if ((rhs_vert + 2) < k_size) {
|
||||||
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
||||||
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
|
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
|
||||||
@ -615,8 +619,13 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
|
|||||||
x1 = rhs_pf0.x;
|
x1 = rhs_pf0.x;
|
||||||
x2 = rhs_pf0.z;
|
x2 = rhs_pf0.z;
|
||||||
}
|
}
|
||||||
|
#if EIGEN_CUDACC_VER < 90000
|
||||||
x1 = __shfl_xor(x1, 4);
|
x1 = __shfl_xor(x1, 4);
|
||||||
x2 = __shfl_xor(x2, 4);
|
x2 = __shfl_xor(x2, 4);
|
||||||
|
#else
|
||||||
|
x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
|
||||||
|
x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
|
||||||
|
#endif
|
||||||
if((threadIdx.x%8) < 4) {
|
if((threadIdx.x%8) < 4) {
|
||||||
rhs_pf0.y = x1;
|
rhs_pf0.y = x1;
|
||||||
rhs_pf0.w = x2;
|
rhs_pf0.w = x2;
|
||||||
@ -790,37 +799,37 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
|||||||
|
|
||||||
if (!CHECK_LHS_BOUNDARY) {
|
if (!CHECK_LHS_BOUNDARY) {
|
||||||
if ((threadIdx.y/4+k+24) < k_size) {
|
if ((threadIdx.y/4+k+24) < k_size) {
|
||||||
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||||
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||||
lhs_pf2 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||||
lhs_pf3 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
|
lhs_pf3 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
|
||||||
} else if ((threadIdx.y/4+k+16) < k_size) {
|
} else if ((threadIdx.y/4+k+16) < k_size) {
|
||||||
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||||
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||||
lhs_pf2 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||||
} else if ((threadIdx.y/4+k+8) < k_size) {
|
} else if ((threadIdx.y/4+k+8) < k_size) {
|
||||||
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||||
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||||
} else if ((threadIdx.y/4+k) < k_size) {
|
} else if ((threadIdx.y/4+k) < k_size) {
|
||||||
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// just CHECK_LHS_BOUNDARY
|
// just CHECK_LHS_BOUNDARY
|
||||||
if (lhs_vert + 3 < m_size) {
|
if (lhs_vert + 3 < m_size) {
|
||||||
if ((threadIdx.y/4+k+24) < k_size) {
|
if ((threadIdx.y/4+k+24) < k_size) {
|
||||||
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||||
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||||
lhs_pf2 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||||
lhs_pf3 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
|
lhs_pf3 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
|
||||||
} else if ((threadIdx.y/4+k+16) < k_size) {
|
} else if ((threadIdx.y/4+k+16) < k_size) {
|
||||||
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||||
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||||
lhs_pf2 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
|
||||||
} else if ((threadIdx.y/4+k+8) < k_size) {
|
} else if ((threadIdx.y/4+k+8) < k_size) {
|
||||||
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||||
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
|
||||||
} else if ((threadIdx.y/4+k) < k_size) {
|
} else if ((threadIdx.y/4+k) < k_size) {
|
||||||
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
|
||||||
}
|
}
|
||||||
} else if (lhs_vert + 2 < m_size) {
|
} else if (lhs_vert + 2 < m_size) {
|
||||||
if ((threadIdx.y/4+k+24) < k_size) {
|
if ((threadIdx.y/4+k+24) < k_size) {
|
||||||
@ -909,8 +918,8 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
|||||||
if (!CHECK_RHS_BOUNDARY) {
|
if (!CHECK_RHS_BOUNDARY) {
|
||||||
if ((rhs_vert + 3) < k_size) {
|
if ((rhs_vert + 3) < k_size) {
|
||||||
// just CHECK_RHS_BOUNDARY
|
// just CHECK_RHS_BOUNDARY
|
||||||
rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||||
rhs_pf1 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
|
rhs_pf1 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
|
||||||
} else if (rhs_vert + 2 < k_size) {
|
} else if (rhs_vert + 2 < k_size) {
|
||||||
// just CHECK_RHS_BOUNDARY
|
// just CHECK_RHS_BOUNDARY
|
||||||
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
||||||
@ -932,8 +941,8 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
|||||||
if (rhs_horiz1 < n_size) {
|
if (rhs_horiz1 < n_size) {
|
||||||
if ((rhs_vert + 3) < k_size) {
|
if ((rhs_vert + 3) < k_size) {
|
||||||
// just CHECK_RHS_BOUNDARY
|
// just CHECK_RHS_BOUNDARY
|
||||||
rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||||
rhs_pf1 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
|
rhs_pf1 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
|
||||||
} else if (rhs_vert + 2 < k_size) {
|
} else if (rhs_vert + 2 < k_size) {
|
||||||
// just CHECK_RHS_BOUNDARY
|
// just CHECK_RHS_BOUNDARY
|
||||||
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
||||||
@ -954,7 +963,7 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
|||||||
} else if (rhs_horiz0 < n_size) {
|
} else if (rhs_horiz0 < n_size) {
|
||||||
if ((rhs_vert + 3) < k_size) {
|
if ((rhs_vert + 3) < k_size) {
|
||||||
// just CHECK_RHS_BOUNDARY
|
// just CHECK_RHS_BOUNDARY
|
||||||
rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
|
||||||
} else if ((rhs_vert + 2) < k_size) {
|
} else if ((rhs_vert + 2) < k_size) {
|
||||||
// just CHECK_RHS_BOUNDARY
|
// just CHECK_RHS_BOUNDARY
|
||||||
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user