This commit is contained in:
Gael Guennebaud 2015-12-11 10:06:38 +01:00
commit c684a07eba
8 changed files with 43 additions and 26 deletions

View File

@ -676,8 +676,13 @@ struct scalar_sign_op<Scalar,true> {
EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_sign_op)
EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const EIGEN_DEVICE_FUNC inline const Scalar operator() (const Scalar& a) const
{ {
typename NumTraits<Scalar>::Real aa = std::abs(a); using std::abs;
return (aa==0) ? Scalar(0) : (a/aa); typedef typename NumTraits<Scalar>::Real real_type;
real_type aa = abs(a);
if (aa==0)
return Scalar(0);
aa = 1./aa;
return Scalar(real(a)*aa, imag(a)*aa );
} }
//TODO //TODO
//template <typename Packet> //template <typename Packet>

View File

@ -78,7 +78,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
IsAligned = bool(EIGEN_MAX_ALIGN_BYTES>0) & !(Options_&DontAlign), IsAligned = bool(EIGEN_MAX_ALIGN_BYTES>0) & !(Options_&DontAlign),
PacketAccess = (internal::packet_traits<Scalar>::size > 1), PacketAccess = (internal::packet_traits<Scalar>::size > 1),
Layout = Options_ & RowMajor ? RowMajor : ColMajor, Layout = Options_ & RowMajor ? RowMajor : ColMajor,
CoordAccess = true, CoordAccess = true
}; };
static const int Options = Options_; static const int Options = Options_;
@ -368,7 +368,7 @@ class Tensor : public TensorBase<Tensor<Scalar_, NumIndices_, Options_, IndexTyp
EIGEN_STATIC_ASSERT(4 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) EIGEN_STATIC_ASSERT(4 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
} }
inline explicit Tensor(Index dim1, Index dim2, Index dim3, Index dim4, Index dim5) inline explicit Tensor(Index dim1, Index dim2, Index dim3, Index dim4, Index dim5)
: m_storage(dim1*dim2*dim3*dim4*dim5, array<Index, 4>(dim1, dim2, dim3, dim4, dim5)) : m_storage(dim1*dim2*dim3*dim4*dim5, array<Index, 5>(dim1, dim2, dim3, dim4, dim5))
{ {
EIGEN_STATIC_ASSERT(5 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE) EIGEN_STATIC_ASSERT(5 == NumIndices, YOU_MADE_A_PROGRAMMING_MISTAKE)
} }

View File

@ -49,7 +49,7 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
IsAligned = ((int(Options_)&Aligned)==Aligned), IsAligned = ((int(Options_)&Aligned)==Aligned),
PacketAccess = (internal::packet_traits<Scalar>::size > 1), PacketAccess = (internal::packet_traits<Scalar>::size > 1),
Layout = PlainObjectType::Layout, Layout = PlainObjectType::Layout,
CoordAccess = true, CoordAccess = true
}; };
EIGEN_DEVICE_FUNC EIGEN_DEVICE_FUNC
@ -158,7 +158,7 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1) const EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1) const
{ {
if (PlainObjectType::Options&RowMajor) { if (PlainObjectType::Options&RowMajor) {
const Index index = i1 + i0 * m_dimensions[0]; const Index index = i1 + i0 * m_dimensions[1];
return m_data[index]; return m_data[index];
} else { } else {
const Index index = i0 + i1 * m_dimensions[0]; const Index index = i0 + i1 * m_dimensions[0];
@ -169,7 +169,7 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2) const EIGEN_STRONG_INLINE const Scalar& operator()(Index i0, Index i1, Index i2) const
{ {
if (PlainObjectType::Options&RowMajor) { if (PlainObjectType::Options&RowMajor) {
const Index index = i2 + m_dimensions[1] * (i1 + m_dimensions[0] * i0); const Index index = i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0);
return m_data[index]; return m_data[index];
} else { } else {
const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2); const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2);
@ -245,7 +245,7 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1) EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1)
{ {
if (PlainObjectType::Options&RowMajor) { if (PlainObjectType::Options&RowMajor) {
const Index index = i1 + i0 * m_dimensions[0]; const Index index = i1 + i0 * m_dimensions[1];
return m_data[index]; return m_data[index];
} else { } else {
const Index index = i0 + i1 * m_dimensions[0]; const Index index = i0 + i1 * m_dimensions[0];
@ -256,7 +256,7 @@ template<typename PlainObjectType, int Options_> class TensorMap : public Tensor
EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2) EIGEN_STRONG_INLINE Scalar& operator()(Index i0, Index i1, Index i2)
{ {
if (PlainObjectType::Options&RowMajor) { if (PlainObjectType::Options&RowMajor) {
const Index index = i2 + m_dimensions[1] * (i1 + m_dimensions[0] * i0); const Index index = i2 + m_dimensions[2] * (i1 + m_dimensions[1] * i0);
return m_data[index]; return m_data[index];
} else { } else {
const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2); const Index index = i0 + m_dimensions[0] * (i1 + m_dimensions[1] * i2);

View File

@ -29,8 +29,8 @@ static void test_1d()
int row_major[6]; int row_major[6];
memset(col_major, 0, 6*sizeof(int)); memset(col_major, 0, 6*sizeof(int));
memset(row_major, 0, 6*sizeof(int)); memset(row_major, 0, 6*sizeof(int));
TensorMap<Tensor<int, 1>> vec3(col_major, 6); TensorMap<Tensor<int, 1> > vec3(col_major, 6);
TensorMap<Tensor<int, 1, RowMajor>> vec4(row_major, 6); TensorMap<Tensor<int, 1, RowMajor> > vec4(row_major, 6);
vec3 = vec1; vec3 = vec1;
vec4 = vec2; vec4 = vec2;
@ -92,8 +92,8 @@ static void test_2d()
int row_major[6]; int row_major[6];
memset(col_major, 0, 6*sizeof(int)); memset(col_major, 0, 6*sizeof(int));
memset(row_major, 0, 6*sizeof(int)); memset(row_major, 0, 6*sizeof(int));
TensorMap<Tensor<int, 2>> mat3(row_major, 2, 3); TensorMap<Tensor<int, 2> > mat3(row_major, 2, 3);
TensorMap<Tensor<int, 2, RowMajor>> mat4(col_major, 2, 3); TensorMap<Tensor<int, 2, RowMajor> > mat4(col_major, 2, 3);
mat3 = mat1; mat3 = mat1;
mat4 = mat2; mat4 = mat2;
@ -152,8 +152,8 @@ static void test_3d()
int row_major[2*3*7]; int row_major[2*3*7];
memset(col_major, 0, 2*3*7*sizeof(int)); memset(col_major, 0, 2*3*7*sizeof(int));
memset(row_major, 0, 2*3*7*sizeof(int)); memset(row_major, 0, 2*3*7*sizeof(int));
TensorMap<Tensor<int, 3>> mat3(col_major, 2, 3, 7); TensorMap<Tensor<int, 3> > mat3(col_major, 2, 3, 7);
TensorMap<Tensor<int, 3, RowMajor>> mat4(row_major, 2, 3, 7); TensorMap<Tensor<int, 3, RowMajor> > mat4(row_major, 2, 3, 7);
mat3 = mat1; mat3 = mat1;
mat4 = mat2; mat4 = mat2;

View File

@ -24,12 +24,12 @@ static void test_simple_cast()
cplextensor.setRandom(); cplextensor.setRandom();
chartensor = ftensor.cast<char>(); chartensor = ftensor.cast<char>();
cplextensor = ftensor.cast<std::complex<float>>(); cplextensor = ftensor.cast<std::complex<float> >();
for (int i = 0; i < 20; ++i) { for (int i = 0; i < 20; ++i) {
for (int j = 0; j < 30; ++j) { for (int j = 0; j < 30; ++j) {
VERIFY_IS_EQUAL(chartensor(i,j), static_cast<char>(ftensor(i,j))); VERIFY_IS_EQUAL(chartensor(i,j), static_cast<char>(ftensor(i,j)));
VERIFY_IS_EQUAL(cplextensor(i,j), static_cast<std::complex<float>>(ftensor(i,j))); VERIFY_IS_EQUAL(cplextensor(i,j), static_cast<std::complex<float> >(ftensor(i,j)));
} }
} }
} }

View File

@ -25,7 +25,9 @@ struct InsertZeros {
template <typename Output, typename Device> template <typename Output, typename Device>
void eval(const Tensor<float, 2>& input, Output& output, const Device& device) const void eval(const Tensor<float, 2>& input, Output& output, const Device& device) const
{ {
array<DenseIndex, 2> strides{{2, 2}}; array<DenseIndex, 2> strides;
strides[0] = 2;
strides[1] = 2;
output.stride(strides).device(device) = input; output.stride(strides).device(device) = input;
Eigen::DSizes<DenseIndex, 2> offsets(1,1); Eigen::DSizes<DenseIndex, 2> offsets(1,1);
@ -70,7 +72,8 @@ struct BatchMatMul {
Output& output, const Device& device) const Output& output, const Device& device) const
{ {
typedef Tensor<float, 3>::DimensionPair DimPair; typedef Tensor<float, 3>::DimensionPair DimPair;
array<DimPair, 1> dims({{DimPair(1, 0)}}); array<DimPair, 1> dims;
dims[0] = DimPair(1, 0);
for (int i = 0; i < output.dimension(2); ++i) { for (int i = 0; i < output.dimension(2); ++i) {
output.template chip<2>(i).device(device) = input1.chip<2>(i).contract(input2.chip<2>(i), dims); output.template chip<2>(i).device(device) = input1.chip<2>(i).contract(input2.chip<2>(i), dims);
} }
@ -88,9 +91,10 @@ static void test_custom_binary_op()
Tensor<float, 3> result = tensor1.customOp(tensor2, BatchMatMul()); Tensor<float, 3> result = tensor1.customOp(tensor2, BatchMatMul());
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
typedef Tensor<float, 3>::DimensionPair DimPair; typedef Tensor<float, 3>::DimensionPair DimPair;
array<DimPair, 1> dims({{DimPair(1, 0)}}); array<DimPair, 1> dims;
dims[0] = DimPair(1, 0);
Tensor<float, 2> reference = tensor1.chip<2>(i).contract(tensor2.chip<2>(i), dims); Tensor<float, 2> reference = tensor1.chip<2>(i).contract(tensor2.chip<2>(i), dims);
TensorRef<Tensor<float, 2>> val = result.chip<2>(i); TensorRef<Tensor<float, 2> > val = result.chip<2>(i);
for (int j = 0; j < 2; ++j) { for (int j = 0; j < 2; ++j) {
for (int k = 0; k < 7; ++k) { for (int k = 0; k < 7; ++k) {
VERIFY_IS_APPROX(val(j, k), reference(j, k)); VERIFY_IS_APPROX(val(j, k), reference(j, k));

View File

@ -114,10 +114,18 @@ static void test_expr_reverse(bool LValue)
Tensor<float, 4, DataLayout> result(2,3,5,7); Tensor<float, 4, DataLayout> result(2,3,5,7);
array<ptrdiff_t, 4> src_slice_dim{{2,3,1,7}}; array<ptrdiff_t, 4> src_slice_dim;
array<ptrdiff_t, 4> src_slice_start{{0,0,0,0}}; src_slice_dim[0] = 2;
array<ptrdiff_t, 4> dst_slice_dim{{2,3,1,7}}; src_slice_dim[1] = 3;
array<ptrdiff_t, 4> dst_slice_start{{0,0,0,0}}; src_slice_dim[2] = 1;
src_slice_dim[3] = 7;
array<ptrdiff_t, 4> src_slice_start;
src_slice_start[0] = 0;
src_slice_start[1] = 0;
src_slice_start[2] = 0;
src_slice_start[3] = 0;
array<ptrdiff_t, 4> dst_slice_dim = src_slice_dim;
array<ptrdiff_t, 4> dst_slice_start = src_slice_start;
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 5; ++i) {
if (LValue) { if (LValue) {

View File

@ -18,7 +18,7 @@ static void test_comparison_sugar() {
#define TEST_TENSOR_EQUAL(e1, e2) \ #define TEST_TENSOR_EQUAL(e1, e2) \
b = ((e1) == (e2)).all(); \ b = ((e1) == (e2)).all(); \
VERIFY(b(0)) VERIFY(b())
#define TEST_OP(op) TEST_TENSOR_EQUAL(t op 0, t op t.constant(0)) #define TEST_OP(op) TEST_TENSOR_EQUAL(t op 0, t op t.constant(0))