mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Vectorize fp16 tanh and logistic functions on Neon
Activates vectorization of the Eigen::half versions of the tanh and logistic functions when they run on Neon. Both functions convert their inputs to float before computing the output, and as a result of this commit, the conversions and the computation in float are vectorized.
This commit is contained in:
parent
185ad0e610
commit
6bb6a6bf53
@ -263,6 +263,11 @@ using std::ptrdiff_t;
|
||||
#include "src/Core/arch/GPU/Complex.h"
|
||||
#endif
|
||||
|
||||
// Specializations of vectorized activation functions for NEON.
|
||||
#ifdef EIGEN_VECTORIZE_NEON
|
||||
#include "src/Core/arch/NEON/UnaryFunctors.h"
|
||||
#endif
|
||||
|
||||
#include "src/Core/util/IndexedViewHelper.h"
|
||||
#include "src/Core/util/ReshapedHelper.h"
|
||||
#include "src/Core/ArithmeticSequence.h"
|
||||
|
@ -40,6 +40,25 @@ template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Pack
|
||||
template<> EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS EIGEN_UNUSED Packet4f ptanh<Packet4f>(const Packet4f& x)
|
||||
{ return internal::generic_fast_tanh_float(x); }
|
||||
|
||||
#if EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED
|
||||
Packet4hf ptanh<Packet4hf>(const Packet4hf& x) {
|
||||
// Convert to float, call the float ptanh, and then convert back.
|
||||
return vcvt_f16_f32(ptanh<Packet4f>(vcvt_f32_f16(x)));
|
||||
}
|
||||
|
||||
template <>
|
||||
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC EIGEN_UNUSED
|
||||
Packet8hf ptanh<Packet8hf>(const Packet8hf& x) {
|
||||
// Convert each 4 halfs to float, call the float ptanh, and then convert back.
|
||||
return vcombine_f16(
|
||||
vcvt_f16_f32(ptanh<Packet4f>(vcvt_f32_f16(vget_low_f16(x)))),
|
||||
vcvt_f16_f32(ptanh<Packet4f>(vcvt_high_f32_f16(x))));
|
||||
}
|
||||
#endif // EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, psin)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, pcos)
|
||||
BF16_PACKET_FUNCTION(Packet4f, Packet4bf, plog)
|
||||
|
@ -4028,6 +4028,7 @@ struct packet_traits<Eigen::half> : default_packet_traits {
|
||||
HasCos = 0,
|
||||
HasLog = 0,
|
||||
HasExp = 0,
|
||||
HasTanh = packet_traits<float>::HasTanh, // tanh<half> calls tanh<float>
|
||||
HasSqrt = 1,
|
||||
HasRsqrt = 1,
|
||||
HasErf = EIGEN_FAST_MATH,
|
||||
|
64
Eigen/src/Core/arch/NEON/UnaryFunctors.h
Normal file
64
Eigen/src/Core/arch/NEON/UnaryFunctors.h
Normal file
@ -0,0 +1,64 @@
|
||||
// This file is part of Eigen, a lightweight C++ template library
|
||||
// for linear algebra.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla
|
||||
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
#ifndef EIGEN_NEON_UNARY_FUNCTORS_H
|
||||
#define EIGEN_NEON_UNARY_FUNCTORS_H
|
||||
|
||||
#include "../../InternalHeaderCheck.h"
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
namespace internal {
|
||||
|
||||
#if EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
|
||||
/** \internal
|
||||
* \brief Template specialization of the logistic function for Eigen::half.
|
||||
*/
|
||||
template <>
|
||||
struct scalar_logistic_op<Eigen::half> {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_logistic_op)
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
Eigen::half operator()(const Eigen::half& x) const {
|
||||
// Convert to float and call scalar_logistic_op<float>.
|
||||
const scalar_logistic_op<float> float_op;
|
||||
return Eigen::half(float_op(float(x)));
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
Eigen::half packetOp(const Eigen::half& x) const {
|
||||
return this->operator()(x);
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
Packet4hf packetOp(const Packet4hf& x) const {
|
||||
const scalar_logistic_op<float> float_op;
|
||||
return vcvt_f16_f32(float_op.packetOp(vcvt_f32_f16(x)));
|
||||
}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||
Packet8hf packetOp(const Packet8hf& x) const {
|
||||
const scalar_logistic_op<float> float_op;
|
||||
return vcombine_f16(
|
||||
vcvt_f16_f32(float_op.packetOp(vcvt_f32_f16(vget_low_f16(x)))),
|
||||
vcvt_f16_f32(float_op.packetOp(vcvt_high_f32_f16(x))));
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct functor_traits<scalar_logistic_op<Eigen::half>> {
|
||||
enum {
|
||||
Cost = functor_traits<scalar_logistic_op<float>>::Cost,
|
||||
PacketAccess = functor_traits<scalar_logistic_op<float>>::PacketAccess,
|
||||
};
|
||||
};
|
||||
#endif // EIGEN_HAS_ARM64_FP16_VECTOR_ARITHMETIC
|
||||
|
||||
} // end namespace internal
|
||||
|
||||
} // end namespace Eigen
|
||||
|
||||
#endif // EIGEN_NEON_UNARY_FUNCTORS_H
|
Loading…
x
Reference in New Issue
Block a user