diff --git a/Eigen/src/Core/arch/AVX512/GemmKernel.h b/Eigen/src/Core/arch/AVX512/GemmKernel.h index ee4beb91a..d7198986e 100644 --- a/Eigen/src/Core/arch/AVX512/GemmKernel.h +++ b/Eigen/src/Core/arch/AVX512/GemmKernel.h @@ -10,7 +10,11 @@ #ifndef GEMM_KERNEL_H #define GEMM_KERNEL_H +#if EIGEN_COMP_MSVC +#include +#else #include +#endif #include #include @@ -452,7 +456,7 @@ class gemm_class co2 = co1 + ldc; if (!is_alpha1) alpha_reg = pload1(alpha); if (!is_unit_inc && a_unroll < nelems_in_cache_line) - mask = (umask_t)(1 << a_unroll) - 1; + mask = static_cast((1ull << a_unroll) - 1); static_assert(max_b_unroll <= 8, "Unsupported max_b_unroll"); diff --git a/Eigen/src/Core/arch/AVX512/TrsmKernel.h b/Eigen/src/Core/arch/AVX512/TrsmKernel.h index 4b81bf915..a8b999c31 100644 --- a/Eigen/src/Core/arch/AVX512/TrsmKernel.h +++ b/Eigen/src/Core/arch/AVX512/TrsmKernel.h @@ -28,11 +28,11 @@ namespace Eigen { namespace internal { -#define EIGEN_AVX_MAX_NUM_ACC (24L) -#define EIGEN_AVX_MAX_NUM_ROW (8L) // Denoted L in code. -#define EIGEN_AVX_MAX_K_UNROL (4L) -#define EIGEN_AVX_B_LOAD_SETS (2L) -#define EIGEN_AVX_MAX_A_BCAST (2L) +#define EIGEN_AVX_MAX_NUM_ACC (int64_t(24)) +#define EIGEN_AVX_MAX_NUM_ROW (int64_t(8)) // Denoted L in code. +#define EIGEN_AVX_MAX_K_UNROL (int64_t(4)) +#define EIGEN_AVX_B_LOAD_SETS (int64_t(2)) +#define EIGEN_AVX_MAX_A_BCAST (int64_t(2)) typedef Packet16f vecFullFloat; typedef Packet8d vecFullDouble; typedef Packet8f vecHalfFloat; @@ -882,7 +882,11 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L * The updated row-major copy of B is reused in the GEMM updates. */ sizeBTemp = (((std::min(kB, numRHS) + 15)/16+ 4)*16)*numM; +#if EIGEN_COMP_MSVC + B_temp = (Scalar*) _aligned_malloc(sizeof(Scalar)*sizeBTemp,4096); +#else B_temp = (Scalar*) aligned_alloc(4096,sizeof(Scalar)*sizeBTemp); +#endif } for(int64_t k = 0; k < numRHS; k += kB) { int64_t bK = numRHS - k > kB ? kB : numRHS - k; @@ -1026,7 +1030,11 @@ void triSolve(Scalar *A_arr, Scalar *B_arr, int64_t M, int64_t numRHS, int64_t L } } } +#if EIGEN_COMP_MSVC + EIGEN_IF_CONSTEXPR(!isBRowMajor) _aligned_free(B_temp); +#else EIGEN_IF_CONSTEXPR(!isBRowMajor) free(B_temp); +#endif } template