Fix build issues with MSVC for AVX512

This commit is contained in:
aaraujom 2022-06-03 14:55:40 +00:00 committed by Antonio Sánchez
parent 4f6354128f
commit 8fbb76a043
2 changed files with 18 additions and 6 deletions

View File

@ -10,7 +10,11 @@
#ifndef GEMM_KERNEL_H #ifndef GEMM_KERNEL_H
#define GEMM_KERNEL_H #define GEMM_KERNEL_H
#if EIGEN_COMP_MSVC
#include <intrin.h>
#else
#include <x86intrin.h> #include <x86intrin.h>
#endif
#include <immintrin.h> #include <immintrin.h>
#include <type_traits> #include <type_traits>
@ -452,7 +456,7 @@ class gemm_class
co2 = co1 + ldc; co2 = co1 + ldc;
if (!is_alpha1) alpha_reg = pload1<vec>(alpha); if (!is_alpha1) alpha_reg = pload1<vec>(alpha);
if (!is_unit_inc && a_unroll < nelems_in_cache_line) if (!is_unit_inc && a_unroll < nelems_in_cache_line)
mask = (umask_t)(1 << a_unroll) - 1; mask = static_cast<umask_t>((1ull << a_unroll) - 1);
static_assert(max_b_unroll <= 8, "Unsupported max_b_unroll"); static_assert(max_b_unroll <= 8, "Unsupported max_b_unroll");

View File

@ -28,11 +28,11 @@
namespace Eigen { namespace Eigen {
namespace internal { namespace internal {
#define EIGEN_AVX_MAX_NUM_ACC (24L) #define EIGEN_AVX_MAX_NUM_ACC (int64_t(24))
#define EIGEN_AVX_MAX_NUM_ROW (8L) // Denoted L in code. #define EIGEN_AVX_MAX_NUM_ROW (int64_t(8)) // Denoted L in code.
#define EIGEN_AVX_MAX_K_UNROL (4L) #define EIGEN_AVX_MAX_K_UNROL (int64_t(4))
#define EIGEN_AVX_B_LOAD_SETS (2L) #define EIGEN_AVX_B_LOAD_SETS (int64_t(2))
#define EIGEN_AVX_MAX_A_BCAST (2L) #define EIGEN_AVX_MAX_A_BCAST (int64_t(2))
typedef Packet16f vecFullFloat; typedef Packet16f vecFullFloat;
typedef Packet8d vecFullDouble; typedef Packet8d vecFullDouble;
typedef Packet8f vecHalfFloat; typedef Packet8f vecHalfFloat;
@ -882,7 +882,11 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
* The updated row-major copy of B is reused in the GEMM updates. * The updated row-major copy of B is reused in the GEMM updates.
*/ */
sizeBTemp = (((std::min(kB, numRHS) + 15)/16+ 4)*16)*numM; sizeBTemp = (((std::min(kB, numRHS) + 15)/16+ 4)*16)*numM;
#if EIGEN_COMP_MSVC
B_temp = (Scalar*) _aligned_malloc(sizeof(Scalar)*sizeBTemp,4096);
#else
B_temp = (Scalar*) aligned_alloc(4096,sizeof(Scalar)*sizeBTemp); B_temp = (Scalar*) aligned_alloc(4096,sizeof(Scalar)*sizeBTemp);
#endif
} }
for(int64_t k = 0; k < numRHS; k += kB) { for(int64_t k = 0; k < numRHS; k += kB) {
int64_t bK = numRHS - k > kB ? kB : numRHS - k; int64_t bK = numRHS - k > kB ? kB : numRHS - k;
@ -1026,7 +1030,11 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L
} }
} }
} }
#if EIGEN_COMP_MSVC
EIGEN_IF_CONSTEXPR(!isBRowMajor) _aligned_free(B_temp);
#else
EIGEN_IF_CONSTEXPR(!isBRowMajor) free(B_temp); EIGEN_IF_CONSTEXPR(!isBRowMajor) free(B_temp);
#endif
} }
template <typename Scalar, bool isARowMajor = true, bool isCRowMajor = true> template <typename Scalar, bool isARowMajor = true, bool isCRowMajor = true>