mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-09-22 22:33:15 +08:00
Fixed contractions of 16 bit floats
This commit is contained in:
parent
8ef3181f15
commit
f9ad25e4d8
@ -73,8 +73,6 @@ struct half : public __half {
|
|||||||
: __half(internal::raw_uint16_to_half(b ? 0x3c00 : 0)) {}
|
: __half(internal::raw_uint16_to_half(b ? 0x3c00 : 0)) {}
|
||||||
EIGEN_DEVICE_FUNC half(const __half& h) : __half(h) {}
|
EIGEN_DEVICE_FUNC half(const __half& h) : __half(h) {}
|
||||||
EIGEN_DEVICE_FUNC half(const half& h) : __half(h) {}
|
EIGEN_DEVICE_FUNC half(const half& h) : __half(h) {}
|
||||||
EIGEN_DEVICE_FUNC half(const volatile half& h)
|
|
||||||
: __half(internal::raw_uint16_to_half(h.x)) {}
|
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
|
EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
|
||||||
return internal::half_to_float(*this);
|
return internal::half_to_float(*this);
|
||||||
@ -87,14 +85,6 @@ struct half : public __half {
|
|||||||
x = other.x;
|
x = other.x;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
EIGEN_DEVICE_FUNC half& operator=(const volatile half& other) {
|
|
||||||
x = other.x;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
EIGEN_DEVICE_FUNC volatile half& operator=(const half& other) volatile {
|
|
||||||
x = other.x;
|
|
||||||
return *this;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
|
#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
|
||||||
@ -341,4 +331,14 @@ static inline EIGEN_DEVICE_FUNC Eigen::half log(const Eigen::half& a) {
|
|||||||
|
|
||||||
} // end namespace std
|
} // end namespace std
|
||||||
|
|
||||||
|
|
||||||
|
// Add the missing __shfl_xor overload for Eigen::half; the value is shuffled as a float and converted back to half
|
||||||
|
#if defined(EIGEN_HAS_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 300
|
||||||
|
__device__ inline Eigen::half __shfl_xor(Eigen::half var, int laneMask, int width=warpSize) {
|
||||||
|
return static_cast<Eigen::half>(__shfl_xor(static_cast<float>(var), laneMask, width));
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
#endif // EIGEN_HALF_CUDA_H
|
#endif // EIGEN_HALF_CUDA_H
|
||||||
|
@ -20,7 +20,7 @@ template<typename Scalar, typename Index, typename LhsMapper,
|
|||||||
typename RhsMapper, typename OutputMapper, bool needs_edge_check>
|
typename RhsMapper, typename OutputMapper, bool needs_edge_check>
|
||||||
__device__ EIGEN_STRONG_INLINE void
|
__device__ EIGEN_STRONG_INLINE void
|
||||||
EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
||||||
const OutputMapper output, volatile Scalar* lhs_shmem, volatile Scalar* rhs_shmem,
|
const OutputMapper output, Scalar* lhs_shmem, Scalar* rhs_shmem,
|
||||||
const Index m_size, const Index n_size, const Index k_size) {
|
const Index m_size, const Index n_size, const Index k_size) {
|
||||||
|
|
||||||
const Index m_block_idx = blockIdx.x;
|
const Index m_block_idx = blockIdx.x;
|
||||||
@ -319,8 +319,8 @@ EigenContractionKernelInternal(const LhsMapper lhs, const RhsMapper rhs,
|
|||||||
Scalar rrow(7);
|
Scalar rrow(7);
|
||||||
|
|
||||||
// Now x corresponds to k, y to m, and z to n
|
// Now x corresponds to k, y to m, and z to n
|
||||||
const volatile Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
|
const Scalar* lhs_block = &lhs_shmem[threadIdx.x + 9 * threadIdx.y];
|
||||||
const volatile Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];
|
const Scalar* rhs_block = &rhs_shmem[threadIdx.x + 8 * threadIdx.z];
|
||||||
|
|
||||||
#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
|
#define lhs_element(i, j) lhs_block[72 * ((i) + 8 * (j))]
|
||||||
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]
|
#define rhs_element(i, j) rhs_block[72 * ((i) + 8 * (j))]
|
||||||
@ -503,8 +503,8 @@ __launch_bounds__(512)
|
|||||||
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
|
EigenContractionKernel(const LhsMapper lhs, const RhsMapper rhs,
|
||||||
const OutputMapper output,
|
const OutputMapper output,
|
||||||
const Index m_size, const Index n_size, const Index k_size) {
|
const Index m_size, const Index n_size, const Index k_size) {
|
||||||
__shared__ volatile Scalar lhs_shmem[72 * 64];
|
__shared__ Scalar lhs_shmem[72 * 64];
|
||||||
__shared__ volatile Scalar rhs_shmem[72 * 64];
|
__shared__ Scalar rhs_shmem[72 * 64];
|
||||||
|
|
||||||
const Index m_block_idx = blockIdx.x;
|
const Index m_block_idx = blockIdx.x;
|
||||||
const Index n_block_idx = blockIdx.y;
|
const Index n_block_idx = blockIdx.y;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user