Fix cuda 9+ builds

Fix removed `shfl_` intrinsics, disable warnings, update CUDA header inclusion.
This commit is contained in:
Antonio Sanchez 2025-03-02 07:38:47 -08:00
parent 43b7aa2412
commit d2ce4faa5a
4 changed files with 235 additions and 128 deletions

View File

@ -254,7 +254,7 @@
#endif
#if defined EIGEN_HAS_CUDA_FP16
#include <host_defines.h>
#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#endif

View File

@ -626,25 +626,71 @@ struct hash<Eigen::half> {
} // end namespace std
// Add the missing shfl_xor intrinsic
#if defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 300
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor(Eigen::half var, int laneMask, int width=warpSize) {
#if EIGEN_CUDACC_VER < 90000
return static_cast<Eigen::half>(__shfl_xor(static_cast<float>(var), laneMask, width));
#else
return static_cast<Eigen::half>(__shfl_xor_sync(0xFFFFFFFF, static_cast<float>(var), laneMask, width));
#endif
// Add the missing shfl* intrinsics.
// The __shfl* functions are only valid on HIP or _CUDA_ARCH_ >= 300.
// CUDA defines them for (__CUDA_ARCH__ >= 300 || !defined(__CUDA_ARCH__))
//
// HIP and CUDA prior to SDK 9.0 define
// __shfl, __shfl_up, __shfl_down, __shfl_xor for int and float
// CUDA since 9.0 deprecates those and instead defines
// __shfl_sync, __shfl_up_sync, __shfl_down_sync, __shfl_xor_sync,
// with native support for __half and __nv_bfloat16
//
// Note that the following are __device__ - only functions.
#if defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 300)
#if defined(EIGEN_HAS_CUDA_FP16) && EIGEN_CUDACC_VER >= 90000
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_sync(unsigned mask, Eigen::half var, int srcLane,
int width = warpSize) {
const __half h = var;
return static_cast<Eigen::half>(__shfl_sync(mask, h, srcLane, width));
}
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up_sync(unsigned mask, Eigen::half var, unsigned int delta,
int width = warpSize) {
const __half h = var;
return static_cast<Eigen::half>(__shfl_up_sync(mask, h, delta, width));
}
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down_sync(unsigned mask, Eigen::half var, unsigned int delta,
int width = warpSize) {
const __half h = var;
return static_cast<Eigen::half>(__shfl_down_sync(mask, h, delta, width));
}
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor_sync(unsigned mask, Eigen::half var, int laneMask,
int width = warpSize) {
const __half h = var;
return static_cast<Eigen::half>(__shfl_xor_sync(mask, h, laneMask, width));
}
#else // CUDA SDK < 9.0
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl(Eigen::half var, int srcLane, int width = warpSize) {
return static_cast<Eigen::half>(__shfl(static_cast<float>(var), laneMask, width));
}
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_up(Eigen::half var, unsigned int delta, int width = warpSize) {
return static_cast<Eigen::half>(__shfl_up(static_cast<float>(var), laneMask, width));
}
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_down(Eigen::half var, unsigned int delta, int width = warpSize) {
return static_cast<Eigen::half>(__shfl_down(static_cast<float>(var), laneMask, width));
}
__device__ EIGEN_STRONG_INLINE Eigen::half __shfl_xor(Eigen::half var, int laneMask, int width = warpSize) {
return static_cast<Eigen::half>(__shfl_xor(static_cast<float>(var), laneMask, width));
}
#endif
#endif // __shfl*
// ldg() has an overload for __half_raw, but we also need one for Eigen::half.
#if defined(EIGEN_CUDA_ARCH) && EIGEN_CUDA_ARCH >= 350
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half __ldg(const Eigen::half* ptr) {
return Eigen::half_impl::raw_uint16_to_half(
__ldg(reinterpret_cast<const unsigned short*>(ptr)));
#if defined(EIGEN_CUDACC) && (!defined(EIGEN_CUDA_ARCH) || EIGEN_CUDA_ARCH >= 350)
EIGEN_STRONG_INLINE __device__ Eigen::half __ldg(const Eigen::half* ptr) {
return Eigen::half_impl::raw_uint16_to_half(__ldg(reinterpret_cast<const Eigen::numext::uint16_t*>(ptr)));
}
#endif
#endif // __ldg
#if defined(EIGEN_CUDA_ARCH)
namespace Eigen {

View File

@ -1,94 +1,146 @@
#ifndef EIGEN_WARNINGS_DISABLED
#define EIGEN_WARNINGS_DISABLED
#ifdef _MSC_VER
// 4100 - unreferenced formal parameter (occurred e.g. in aligned_allocator::destroy(pointer p))
// 4101 - unreferenced local variable
// 4127 - conditional expression is constant
// 4181 - qualifier applied to reference type ignored
// 4211 - nonstandard extension used : redefined extern to static
// 4244 - 'argument' : conversion from 'type1' to 'type2', possible loss of data
// 4273 - QtAlignedMalloc, inconsistent DLL linkage
// 4324 - structure was padded due to declspec(align())
// 4503 - decorated name length exceeded, name was truncated
// 4512 - assignment operator could not be generated
// 4522 - 'class' : multiple assignment operators specified
// 4700 - uninitialized local variable 'xyz' used
// 4714 - function marked as __forceinline not inlined
// 4717 - 'function' : recursive on all control paths, function will cause runtime stack overflow
// 4800 - 'type' : forcing value to bool 'true' or 'false' (performance warning)
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
#pragma warning( push )
#endif
#pragma warning( disable : 4100 4101 4127 4181 4211 4244 4273 4324 4503 4512 4522 4700 4714 4717 4800)
#elif defined __INTEL_COMPILER
// 2196 - routine is both "inline" and "noinline" ("noinline" assumed)
// ICC 12 generates this warning even without any inline keyword, when defining class methods 'inline' i.e. inside of class body
// typedef that may be a reference type.
// 279 - controlling expression is constant
// ICC 12 generates this warning on assert(constant_expression_depending_on_template_params) and frankly this is a legitimate use case.
// 1684 - conversion from pointer to same-sized integral type (potential portability problem)
// 2259 - non-pointer conversion from "Eigen::Index={ptrdiff_t={long}}" to "int" may lose significant bits
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
#pragma warning push
#endif
#pragma warning disable 2196 279 1684 2259
#elif defined __clang__
// -Wconstant-logical-operand - warning: use of logical && with constant operand; switch to bitwise & or remove constant
// this is really a stupid warning as it warns on compile-time expressions involving enums
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
#pragma clang diagnostic push
#endif
#pragma clang diagnostic ignored "-Wconstant-logical-operand"
#elif defined __GNUC__
#if (!defined(EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS)) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))
#pragma GCC diagnostic push
#endif
// g++ warns about local variables shadowing member functions, which is too strict
#pragma GCC diagnostic ignored "-Wshadow"
#if __GNUC__ == 4 && __GNUC_MINOR__ < 8
// Until g++-4.7 there are warnings when comparing unsigned int vs 0, even in templated functions:
#pragma GCC diagnostic ignored "-Wtype-limits"
#endif
#if __GNUC__>=6
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif
#if __GNUC__==7
// See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89325
#pragma GCC diagnostic ignored "-Wattributes"
#endif
#if defined(_MSC_VER)
// 4100 - unreferenced formal parameter (occurred e.g. in aligned_allocator::destroy(pointer p))
// 4101 - unreferenced local variable
// 4127 - conditional expression is constant
// 4181 - qualifier applied to reference type ignored
// 4211 - nonstandard extension used : redefined extern to static
// 4244 - 'argument' : conversion from 'type1' to 'type2', possible loss of data
// 4273 - QtAlignedMalloc, inconsistent DLL linkage
// 4324 - structure was padded due to declspec(align())
// 4503 - decorated name length exceeded, name was truncated
// 4512 - assignment operator could not be generated
// 4522 - 'class' : multiple assignment operators specified
// 4700 - uninitialized local variable 'xyz' used
// 4714 - function marked as __forceinline not inlined
// 4717 - 'function' : recursive on all control paths, function will cause runtime stack overflow
// 4800 - 'type' : forcing value to bool 'true' or 'false' (performance warning)
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
#pragma warning(push)
#endif
#pragma warning(disable : 4100 4101 4127 4181 4211 4244 4273 4324 4503 4512 4522 4700 4714 4717 4800)
// We currently rely on has_denorm in tests, and need it defined correctly for half/bfloat16.
#ifndef _SILENCE_CXX23_DENORM_DEPRECATION_WARNING
#define EIGEN_REENABLE_CXX23_DENORM_DEPRECATION_WARNING 1
#define _SILENCE_CXX23_DENORM_DEPRECATION_WARNING
#endif
#if defined __NVCC__
// Disable the "statement is unreachable" message
#pragma diag_suppress code_is_unreachable
// Disable the "dynamic initialization in unreachable code" message
#pragma diag_suppress initialization_not_reachable
// Disable the "invalid error number" message that we get with older versions of nvcc
#pragma diag_suppress 1222
// Disable the "calling a __host__ function from a __host__ __device__ function is not allowed" messages (yes, there are many of them and they seem to change with every version of the compiler)
#pragma diag_suppress 2527
#pragma diag_suppress 2529
#pragma diag_suppress 2651
#pragma diag_suppress 2653
#pragma diag_suppress 2668
#pragma diag_suppress 2669
#pragma diag_suppress 2670
#pragma diag_suppress 2671
#pragma diag_suppress 2735
#pragma diag_suppress 2737
#elif defined __INTEL_COMPILER
// 2196 - routine is both "inline" and "noinline" ("noinline" assumed)
// ICC 12 generates this warning even without any inline keyword, when defining class methods 'inline' i.e.
// inside of class body typedef that may be a reference type.
// 279 - controlling expression is constant
// ICC 12 generates this warning on assert(constant_expression_depending_on_template_params) and frankly this is
// a legitimate use case.
// 1684 - conversion from pointer to same-sized integral type (potential portability problem)
// 2259 - non-pointer conversion from "Eigen::Index={ptrdiff_t={long}}" to "int" may lose significant bits
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
#pragma warning push
#endif
#pragma warning disable 2196 279 1684 2259
#elif defined __clang__
#ifndef EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS
#pragma clang diagnostic push
#endif
#if defined(__has_warning)
// -Wconstant-logical-operand - warning: use of logical && with constant operand; switch to bitwise & or remove constant
// this is really a stupid warning as it warns on compile-time expressions involving enums
#if __has_warning("-Wconstant-logical-operand")
#pragma clang diagnostic ignored "-Wconstant-logical-operand"
#endif
#if __has_warning("-Wimplicit-int-float-conversion")
#pragma clang diagnostic ignored "-Wimplicit-int-float-conversion"
#endif
#if (defined(__ALTIVEC__) || defined(__VSX__)) && (!defined(__STDC_VERSION__) || (__STDC_VERSION__ < 201112L))
// warning: generic selections are a C11-specific feature
// ignoring warnings thrown at vec_ctf in Altivec/PacketMath.h
#if __has_warning("-Wc11-extensions")
#pragma clang diagnostic ignored "-Wc11-extensions"
#endif
#endif
#endif
#elif defined __GNUC__ && !defined(__FUJITSU)
#if (!defined(EIGEN_PERMANENTLY_DISABLE_STUPID_WARNINGS)) && (__GNUC__ > 4 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 6))
#pragma GCC diagnostic push
#endif
// g++ warns about local variables shadowing member functions, which is too strict
#pragma GCC diagnostic ignored "-Wshadow"
#if __GNUC__ == 4 && __GNUC_MINOR__ < 8
// Until g++-4.7 there are warnings when comparing unsigned int vs 0, even in templated functions:
#pragma GCC diagnostic ignored "-Wtype-limits"
#endif
#if __GNUC__ >= 6
#pragma GCC diagnostic ignored "-Wignored-attributes"
#endif
#if __GNUC__ == 7
// See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=89325
#pragma GCC diagnostic ignored "-Wattributes"
#endif
#endif
#if defined __NVCC__ && defined __CUDACC__
// MSVC 14.16 (required by CUDA 9.*) does not support the _Pragma keyword, so
// we instead use Microsoft's __pragma extension.
#if defined _MSC_VER
#define EIGEN_MAKE_PRAGMA(X) __pragma(#X)
#else
#define EIGEN_MAKE_PRAGMA(X) _Pragma(#X)
#endif
#if defined __NVCC_DIAG_PRAGMA_SUPPORT__
#define EIGEN_NV_DIAG_SUPPRESS(X) EIGEN_MAKE_PRAGMA(nv_diag_suppress X)
#else
#define EIGEN_NV_DIAG_SUPPRESS(X) EIGEN_MAKE_PRAGMA(diag_suppress X)
#endif
EIGEN_NV_DIAG_SUPPRESS(boolean_controlling_expr_is_constant)
// Disable the "statement is unreachable" message
EIGEN_NV_DIAG_SUPPRESS(code_is_unreachable)
// Disable the "dynamic initialization in unreachable code" message
EIGEN_NV_DIAG_SUPPRESS(initialization_not_reachable)
// Disable the "invalid error number" message that we get with older versions of nvcc
EIGEN_NV_DIAG_SUPPRESS(1222)
// Disable the "calling a __host__ function from a __host__ __device__ function is not allowed" messages (yes, there are
// many of them and they seem to change with every version of the compiler)
EIGEN_NV_DIAG_SUPPRESS(2527)
EIGEN_NV_DIAG_SUPPRESS(2529)
EIGEN_NV_DIAG_SUPPRESS(2651)
EIGEN_NV_DIAG_SUPPRESS(2653)
EIGEN_NV_DIAG_SUPPRESS(2668)
EIGEN_NV_DIAG_SUPPRESS(2669)
EIGEN_NV_DIAG_SUPPRESS(2670)
EIGEN_NV_DIAG_SUPPRESS(2671)
EIGEN_NV_DIAG_SUPPRESS(2735)
EIGEN_NV_DIAG_SUPPRESS(2737)
EIGEN_NV_DIAG_SUPPRESS(2739)
EIGEN_NV_DIAG_SUPPRESS(2885)
EIGEN_NV_DIAG_SUPPRESS(2888)
EIGEN_NV_DIAG_SUPPRESS(2976)
EIGEN_NV_DIAG_SUPPRESS(2979)
EIGEN_NV_DIAG_SUPPRESS(20011)
EIGEN_NV_DIAG_SUPPRESS(20014)
// Disable the "// __device__ annotation is ignored on a function(...) that is
// explicitly defaulted on its first declaration" message.
// The __device__ annotation seems to actually be needed in some cases,
// otherwise resulting in kernel runtime errors.
EIGEN_NV_DIAG_SUPPRESS(2886)
EIGEN_NV_DIAG_SUPPRESS(2929)
EIGEN_NV_DIAG_SUPPRESS(2977)
EIGEN_NV_DIAG_SUPPRESS(20012)
#undef EIGEN_NV_DIAG_SUPPRESS
#undef EIGEN_MAKE_PRAGMA
#endif
#else
// warnings already disabled:
# ifndef EIGEN_WARNINGS_DISABLED_2
# define EIGEN_WARNINGS_DISABLED_2
# elif defined(EIGEN_INTERNAL_DEBUGGING)
# error "Do not include \"DisableStupidWarnings.h\" recursively more than twice!"
# endif
#ifndef EIGEN_WARNINGS_DISABLED_2
#define EIGEN_WARNINGS_DISABLED_2
#elif defined(EIGEN_INTERNAL_DEBUGGING)
#error "Do not include \"DisableStupidWarnings.h\" recursively more than twice!"
#endif
#endif // not EIGEN_WARNINGS_DISABLED
#endif // not EIGEN_WARNINGS_DISABLED

View File

@ -388,7 +388,11 @@ EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
// the sum across all big k blocks of the product of little k block of index (x, y)
// with block of index (y, z). To compute the final output, we need to reduce
// the 8 threads over y by summation.
#if EIGEN_CUDACC_VER < 90000
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor(res(i, j), mask)
#else
#define shuffleInc(i, j, mask) res(i, j) += __shfl_xor_sync(0xFFFFFFFF, res(i, j), mask)
#endif
#define reduceRow(i, mask) \
shuffleInc(i, 0, mask); \
@ -543,12 +547,12 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
#define prefetch_lhs(reg, row, col) \
if (!CHECK_LHS_BOUNDARY) { \
if (col < k_size) { \
reg =lhs.loadPacket<Unaligned>(row, col); \
reg =lhs.template loadPacket<Unaligned>(row, col); \
} \
} else { \
if (col < k_size) { \
if (row + 3 < m_size) { \
reg =lhs.loadPacket<Unaligned>(row, col); \
reg =lhs.template loadPacket<Unaligned>(row, col); \
} else if (row + 2 < m_size) { \
reg.x =lhs(row + 0, col); \
reg.y =lhs(row + 1, col); \
@ -578,7 +582,7 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
if (!CHECK_RHS_BOUNDARY) {
if ((rhs_vert + 3) < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
} else if (rhs_vert + 2 < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
@ -593,7 +597,7 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
} else {
if (rhs_horiz0 < n_size) {
if ((rhs_vert + 3) < k_size) {
rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
} else if ((rhs_vert + 2) < k_size) {
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
rhs_pf0.y = rhs(rhs_vert + 1, rhs_horiz0);
@ -615,8 +619,13 @@ EigenFloatContractionKernelInternal16x16(const LhsMapper lhs, const RhsMapper rh
x1 = rhs_pf0.x;
x2 = rhs_pf0.z;
}
#if EIGEN_CUDACC_VER < 90000
x1 = __shfl_xor(x1, 4);
x2 = __shfl_xor(x2, 4);
#else
x1 = __shfl_xor_sync(0xFFFFFFFF, x1, 4);
x2 = __shfl_xor_sync(0xFFFFFFFF, x2, 4);
#endif
if((threadIdx.x%8) < 4) {
rhs_pf0.y = x1;
rhs_pf0.w = x2;
@ -790,37 +799,37 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
if (!CHECK_LHS_BOUNDARY) {
if ((threadIdx.y/4+k+24) < k_size) {
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
lhs_pf3 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
lhs_pf3 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
} else if ((threadIdx.y/4+k+16) < k_size) {
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
} else if ((threadIdx.y/4+k+8) < k_size) {
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
} else if ((threadIdx.y/4+k) < k_size) {
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
}
} else {
// just CHECK_LHS_BOUNDARY
if (lhs_vert + 3 < m_size) {
if ((threadIdx.y/4+k+24) < k_size) {
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
lhs_pf3 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
lhs_pf3 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+24));
} else if ((threadIdx.y/4+k+16) < k_size) {
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf2 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+16));
} else if ((threadIdx.y/4+k+8) < k_size) {
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf1 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k+8));
} else if ((threadIdx.y/4+k) < k_size) {
lhs_pf0 =lhs.loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
lhs_pf0 =lhs.template loadPacket<Unaligned>(lhs_vert, (threadIdx.y/4+k));
}
} else if (lhs_vert + 2 < m_size) {
if ((threadIdx.y/4+k+24) < k_size) {
@ -909,8 +918,8 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
if (!CHECK_RHS_BOUNDARY) {
if ((rhs_vert + 3) < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf1 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf1 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
} else if (rhs_vert + 2 < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
@ -932,8 +941,8 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
if (rhs_horiz1 < n_size) {
if ((rhs_vert + 3) < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf1 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf1 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz1);
} else if (rhs_vert + 2 < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);
@ -954,7 +963,7 @@ EigenFloatContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
} else if (rhs_horiz0 < n_size) {
if ((rhs_vert + 3) < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0 = rhs.loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
rhs_pf0 = rhs.template loadPacket<Unaligned>(rhs_vert, rhs_horiz0);
} else if ((rhs_vert + 2) < k_size) {
// just CHECK_RHS_BOUNDARY
rhs_pf0.x = rhs(rhs_vert, rhs_horiz0);