mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-14 04:35:57 +08:00
Adjust Tensor module wrt recent change in nullary functor
This commit is contained in:
parent
72a4d49315
commit
46475eff9a
@ -226,7 +226,7 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device>
|
|||||||
|
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
TensorEvaluator(const XprType& op, const Device& device)
|
TensorEvaluator(const XprType& op, const Device& device)
|
||||||
: m_functor(op.functor()), m_argImpl(op.nestedExpression(), device)
|
: m_functor(op.functor()), m_argImpl(op.nestedExpression(), device), m_wrapper()
|
||||||
{ }
|
{ }
|
||||||
|
|
||||||
typedef typename XprType::Index Index;
|
typedef typename XprType::Index Index;
|
||||||
@ -243,13 +243,13 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device>
|
|||||||
|
|
||||||
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
|
EIGEN_DEVICE_FUNC CoeffReturnType coeff(Index index) const
|
||||||
{
|
{
|
||||||
return m_functor(index);
|
return m_wrapper(m_functor,index);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<int LoadMode>
|
template<int LoadMode>
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const
|
||||||
{
|
{
|
||||||
return m_functor.template packetOp<Index, PacketReturnType>(index);
|
return m_wrapper.template packetOp<PacketReturnType>(m_functor,index);
|
||||||
}
|
}
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorOpCost
|
||||||
@ -263,6 +263,7 @@ struct TensorEvaluator<const TensorCwiseNullaryOp<NullaryOp, ArgType>, Device>
|
|||||||
private:
|
private:
|
||||||
const NullaryOp m_functor;
|
const NullaryOp m_functor;
|
||||||
TensorEvaluator<ArgType, Device> m_argImpl;
|
TensorEvaluator<ArgType, Device> m_argImpl;
|
||||||
|
const internal::nullary_wrapper<CoeffReturnType,NullaryOp> m_wrapper;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
@ -460,12 +460,11 @@ template <typename T> class UniformRandomGenerator {
|
|||||||
m_deterministic = other.m_deterministic;
|
m_deterministic = other.m_deterministic;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Index>
|
T operator()() const {
|
||||||
T operator()(Index) const {
|
|
||||||
return random<T>();
|
return random<T>();
|
||||||
}
|
}
|
||||||
template<typename Index, typename PacketType>
|
template<typename PacketType>
|
||||||
PacketType packetOp(Index) const {
|
PacketType packetOp() const {
|
||||||
const int packetSize = internal::unpacket_traits<PacketType>::size;
|
const int packetSize = internal::unpacket_traits<PacketType>::size;
|
||||||
EIGEN_ALIGN_MAX T values[packetSize];
|
EIGEN_ALIGN_MAX T values[packetSize];
|
||||||
for (int i = 0; i < packetSize; ++i) {
|
for (int i = 0; i < packetSize; ++i) {
|
||||||
@ -490,23 +489,22 @@ template <> class UniformRandomGenerator<float> {
|
|||||||
}
|
}
|
||||||
UniformRandomGenerator(const UniformRandomGenerator<float>& other) {
|
UniformRandomGenerator(const UniformRandomGenerator<float>& other) {
|
||||||
m_generator = new std::mt19937();
|
m_generator = new std::mt19937();
|
||||||
m_generator->seed(other(0) * UINT_MAX);
|
m_generator->seed(other() * UINT_MAX);
|
||||||
m_deterministic = other.m_deterministic;
|
m_deterministic = other.m_deterministic;
|
||||||
}
|
}
|
||||||
~UniformRandomGenerator() {
|
~UniformRandomGenerator() {
|
||||||
delete m_generator;
|
delete m_generator;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Index>
|
float operator()() const {
|
||||||
float operator()(Index) const {
|
|
||||||
return m_distribution(*m_generator);
|
return m_distribution(*m_generator);
|
||||||
}
|
}
|
||||||
template<typename Index, typename PacketType>
|
template<typename PacketType>
|
||||||
PacketType packetOp(Index i) const {
|
PacketType packetOp() const {
|
||||||
const int packetSize = internal::unpacket_traits<PacketType>::size;
|
const int packetSize = internal::unpacket_traits<PacketType>::size;
|
||||||
EIGEN_ALIGN_MAX float values[packetSize];
|
EIGEN_ALIGN_MAX float values[packetSize];
|
||||||
for (int k = 0; k < packetSize; ++k) {
|
for (int k = 0; k < packetSize; ++k) {
|
||||||
values[k] = this->operator()(i);
|
values[k] = this->operator()();
|
||||||
}
|
}
|
||||||
return internal::pload<PacketType>(values);
|
return internal::pload<PacketType>(values);
|
||||||
}
|
}
|
||||||
@ -531,23 +529,22 @@ template <> class UniformRandomGenerator<double> {
|
|||||||
}
|
}
|
||||||
UniformRandomGenerator(const UniformRandomGenerator<double>& other) {
|
UniformRandomGenerator(const UniformRandomGenerator<double>& other) {
|
||||||
m_generator = new std::mt19937();
|
m_generator = new std::mt19937();
|
||||||
m_generator->seed(other(0) * UINT_MAX);
|
m_generator->seed(other() * UINT_MAX);
|
||||||
m_deterministic = other.m_deterministic;
|
m_deterministic = other.m_deterministic;
|
||||||
}
|
}
|
||||||
~UniformRandomGenerator() {
|
~UniformRandomGenerator() {
|
||||||
delete m_generator;
|
delete m_generator;
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Index>
|
double operator()() const {
|
||||||
double operator()(Index) const {
|
|
||||||
return m_distribution(*m_generator);
|
return m_distribution(*m_generator);
|
||||||
}
|
}
|
||||||
template<typename Index, typename PacketType>
|
template<typename PacketType>
|
||||||
PacketType packetOp(Index i) const {
|
PacketType packetOp() const {
|
||||||
const int packetSize = internal::unpacket_traits<PacketType>::size;
|
const int packetSize = internal::unpacket_traits<PacketType>::size;
|
||||||
EIGEN_ALIGN_MAX double values[packetSize];
|
EIGEN_ALIGN_MAX double values[packetSize];
|
||||||
for (int k = 0; k < packetSize; ++k) {
|
for (int k = 0; k < packetSize; ++k) {
|
||||||
values[k] = this->operator()(i);
|
values[k] = this->operator()();
|
||||||
}
|
}
|
||||||
return internal::pload<PacketType>(values);
|
return internal::pload<PacketType>(values);
|
||||||
}
|
}
|
||||||
@ -584,12 +581,11 @@ template <> class UniformRandomGenerator<float> {
|
|||||||
curand_init(seed, tid, 0, &m_state);
|
curand_init(seed, tid, 0, &m_state);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename Index>
|
__device__ float operator()() const {
|
||||||
__device__ float operator()(Index) const {
|
|
||||||
return curand_uniform(&m_state);
|
return curand_uniform(&m_state);
|
||||||
}
|
}
|
||||||
template<typename Index, typename PacketType>
|
template<typename PacketType>
|
||||||
__device__ float4 packetOp(Index) const {
|
__device__ float4 packetOp() const {
|
||||||
EIGEN_STATIC_ASSERT((is_same<PacketType, float4>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
EIGEN_STATIC_ASSERT((is_same<PacketType, float4>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
return curand_uniform4(&m_state);
|
return curand_uniform4(&m_state);
|
||||||
}
|
}
|
||||||
@ -614,12 +610,11 @@ template <> class UniformRandomGenerator<double> {
|
|||||||
const int seed = m_deterministic ? 0 : get_random_seed();
|
const int seed = m_deterministic ? 0 : get_random_seed();
|
||||||
curand_init(seed, tid, 0, &m_state);
|
curand_init(seed, tid, 0, &m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
__device__ double operator()() const {
|
||||||
__device__ double operator()(Index) const {
|
|
||||||
return curand_uniform_double(&m_state);
|
return curand_uniform_double(&m_state);
|
||||||
}
|
}
|
||||||
template<typename Index, typename PacketType>
|
template<typename PacketType>
|
||||||
__device__ double2 packetOp(Index) const {
|
__device__ double2 packetOp() const {
|
||||||
EIGEN_STATIC_ASSERT((is_same<PacketType, double2>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
EIGEN_STATIC_ASSERT((is_same<PacketType, double2>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
return curand_uniform2_double(&m_state);
|
return curand_uniform2_double(&m_state);
|
||||||
}
|
}
|
||||||
@ -644,8 +639,7 @@ template <> class UniformRandomGenerator<std::complex<float> > {
|
|||||||
const int seed = m_deterministic ? 0 : get_random_seed();
|
const int seed = m_deterministic ? 0 : get_random_seed();
|
||||||
curand_init(seed, tid, 0, &m_state);
|
curand_init(seed, tid, 0, &m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
__device__ std::complex<float> operator()() const {
|
||||||
__device__ std::complex<float> operator()(Index) const {
|
|
||||||
float4 vals = curand_uniform4(&m_state);
|
float4 vals = curand_uniform4(&m_state);
|
||||||
return std::complex<float>(vals.x, vals.y);
|
return std::complex<float>(vals.x, vals.y);
|
||||||
}
|
}
|
||||||
@ -670,8 +664,7 @@ template <> class UniformRandomGenerator<std::complex<double> > {
|
|||||||
const int seed = m_deterministic ? 0 : get_random_seed();
|
const int seed = m_deterministic ? 0 : get_random_seed();
|
||||||
curand_init(seed, tid, 0, &m_state);
|
curand_init(seed, tid, 0, &m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
__device__ std::complex<double> operator()() const {
|
||||||
__device__ std::complex<double> operator()(Index) const {
|
|
||||||
double2 vals = curand_uniform2_double(&m_state);
|
double2 vals = curand_uniform2_double(&m_state);
|
||||||
return std::complex<double>(vals.x, vals.y);
|
return std::complex<double>(vals.x, vals.y);
|
||||||
}
|
}
|
||||||
@ -707,17 +700,16 @@ template <typename T> class NormalRandomGenerator {
|
|||||||
}
|
}
|
||||||
NormalRandomGenerator(const NormalRandomGenerator& other)
|
NormalRandomGenerator(const NormalRandomGenerator& other)
|
||||||
: m_deterministic(other.m_deterministic), m_distribution(other.m_distribution), m_generator(new std::mt19937()) {
|
: m_deterministic(other.m_deterministic), m_distribution(other.m_distribution), m_generator(new std::mt19937()) {
|
||||||
m_generator->seed(other(0) * UINT_MAX);
|
m_generator->seed(other() * UINT_MAX);
|
||||||
}
|
}
|
||||||
~NormalRandomGenerator() {
|
~NormalRandomGenerator() {
|
||||||
delete m_generator;
|
delete m_generator;
|
||||||
}
|
}
|
||||||
template<typename Index>
|
T operator()() const {
|
||||||
T operator()(Index) const {
|
|
||||||
return m_distribution(*m_generator);
|
return m_distribution(*m_generator);
|
||||||
}
|
}
|
||||||
template<typename Index, typename PacketType>
|
template<typename PacketType>
|
||||||
PacketType packetOp(Index) const {
|
PacketType packetOp() const {
|
||||||
const int packetSize = internal::unpacket_traits<PacketType>::size;
|
const int packetSize = internal::unpacket_traits<PacketType>::size;
|
||||||
EIGEN_ALIGN_MAX T values[packetSize];
|
EIGEN_ALIGN_MAX T values[packetSize];
|
||||||
for (int i = 0; i < packetSize; ++i) {
|
for (int i = 0; i < packetSize; ++i) {
|
||||||
@ -755,12 +747,11 @@ template <> class NormalRandomGenerator<float> {
|
|||||||
const int seed = m_deterministic ? 0 : get_random_seed();
|
const int seed = m_deterministic ? 0 : get_random_seed();
|
||||||
curand_init(seed, tid, 0, &m_state);
|
curand_init(seed, tid, 0, &m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
__device__ float operator()() const {
|
||||||
__device__ float operator()(Index) const {
|
|
||||||
return curand_normal(&m_state);
|
return curand_normal(&m_state);
|
||||||
}
|
}
|
||||||
template<typename Index, typename PacketType>
|
template<typename PacketType>
|
||||||
__device__ float4 packetOp(Index) const {
|
__device__ float4 packetOp() const {
|
||||||
EIGEN_STATIC_ASSERT((is_same<PacketType, float4>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
EIGEN_STATIC_ASSERT((is_same<PacketType, float4>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
return curand_normal4(&m_state);
|
return curand_normal4(&m_state);
|
||||||
}
|
}
|
||||||
@ -785,12 +776,11 @@ template <> class NormalRandomGenerator<double> {
|
|||||||
const int seed = m_deterministic ? 0 : get_random_seed();
|
const int seed = m_deterministic ? 0 : get_random_seed();
|
||||||
curand_init(seed, tid, 0, &m_state);
|
curand_init(seed, tid, 0, &m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
__device__ double operator()() const {
|
||||||
__device__ double operator()(Index) const {
|
|
||||||
return curand_normal_double(&m_state);
|
return curand_normal_double(&m_state);
|
||||||
}
|
}
|
||||||
template<typename Index, typename PacketType>
|
template<typename PacketType>
|
||||||
__device__ double2 packetOp(Index) const {
|
__device__ double2 packetOp() const {
|
||||||
EIGEN_STATIC_ASSERT((is_same<PacketType, double2>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
EIGEN_STATIC_ASSERT((is_same<PacketType, double2>::value), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
return curand_normal2_double(&m_state);
|
return curand_normal2_double(&m_state);
|
||||||
}
|
}
|
||||||
@ -815,8 +805,7 @@ template <> class NormalRandomGenerator<std::complex<float> > {
|
|||||||
const int seed = m_deterministic ? 0 : get_random_seed();
|
const int seed = m_deterministic ? 0 : get_random_seed();
|
||||||
curand_init(seed, tid, 0, &m_state);
|
curand_init(seed, tid, 0, &m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
__device__ std::complex<float> operator()() const {
|
||||||
__device__ std::complex<float> operator()(Index) const {
|
|
||||||
float4 vals = curand_normal4(&m_state);
|
float4 vals = curand_normal4(&m_state);
|
||||||
return std::complex<float>(vals.x, vals.y);
|
return std::complex<float>(vals.x, vals.y);
|
||||||
}
|
}
|
||||||
@ -841,8 +830,7 @@ template <> class NormalRandomGenerator<std::complex<double> > {
|
|||||||
const int seed = m_deterministic ? 0 : get_random_seed();
|
const int seed = m_deterministic ? 0 : get_random_seed();
|
||||||
curand_init(seed, tid, 0, &m_state);
|
curand_init(seed, tid, 0, &m_state);
|
||||||
}
|
}
|
||||||
template<typename Index>
|
__device__ std::complex<double> operator()() const {
|
||||||
__device__ std::complex<double> operator()(Index) const {
|
|
||||||
double2 vals = curand_normal2_double(&m_state);
|
double2 vals = curand_normal2_double(&m_state);
|
||||||
return std::complex<double>(vals.x, vals.y);
|
return std::complex<double>(vals.x, vals.y);
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user