diff --git a/unsupported/Eigen/CXX11/Tensor b/unsupported/Eigen/CXX11/Tensor index 3f9aa1026..b13760d6c 100644 --- a/unsupported/Eigen/CXX11/Tensor +++ b/unsupported/Eigen/CXX11/Tensor @@ -85,6 +85,7 @@ #include "unsupported/Eigen/CXX11/src/Tensor/TensorVolumePatch.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorBroadcasting.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorChipping.h" +#include "unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorLayoutSwap.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorMorphing.h" #include "unsupported/Eigen/CXX11/src/Tensor/TensorPadding.h" diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h index ca712f9c5..71f69b568 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorBase.h @@ -471,6 +471,11 @@ class TensorBase stride(const Strides& strides) const { return TensorStridingOp(derived(), strides); } + template EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + const TensorInflationOp + inflate(const Strides& strides) const { + return TensorInflationOp(derived(), strides); + } // Support for custom unary and binary operations template diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h index f73ce83f4..17b0e6153 100644 --- a/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorForwardDeclarations.h @@ -39,6 +39,7 @@ template class TensorReverseOp; template class TensorPaddingOp; template class TensorShufflingOp; template class TensorStridingOp; +template class TensorInflationOp; template class TensorGeneratorOp; template class TensorAssignOp; diff --git a/unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h b/unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h new file mode 100644 index 000000000..40a50e466 --- /dev/null +++ b/unsupported/Eigen/CXX11/src/Tensor/TensorInflation.h @@ -0,0 +1,219 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Ke Yang +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#ifndef EIGEN_CXX11_TENSOR_TENSOR_INFLATION_H +#define EIGEN_CXX11_TENSOR_TENSOR_INFLATION_H + +namespace Eigen { + +/** \class TensorInflation + * \ingroup CXX11_Tensor_Module + * + * \brief Tensor inflation class. + * + * + */ +namespace internal { +template +struct traits > : public traits +{ + typedef typename XprType::Scalar Scalar; + typedef traits XprTraits; + typedef typename packet_traits::type Packet; + typedef typename XprTraits::StorageKind StorageKind; + typedef typename XprTraits::Index Index; + typedef typename XprType::Nested Nested; + typedef typename remove_reference::type _Nested; + static const int NumDimensions = XprTraits::NumDimensions; + static const int Layout = XprTraits::Layout; +}; + +template +struct eval, Eigen::Dense> +{ + typedef const TensorInflationOp& type; +}; + +template +struct nested, 1, typename eval >::type> +{ + typedef TensorInflationOp type; +}; + +} // end namespace internal + +template +class TensorInflationOp : public TensorBase, ReadOnlyAccessors> +{ + public: + typedef typename Eigen::internal::traits::Scalar Scalar; + typedef typename Eigen::internal::traits::Packet Packet; + typedef typename Eigen::NumTraits::Real RealScalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + typedef typename Eigen::internal::nested::type Nested; + typedef typename Eigen::internal::traits::StorageKind StorageKind; + typedef typename Eigen::internal::traits::Index Index; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorInflationOp(const XprType& expr, const Strides& strides) + : m_xpr(expr), m_strides(strides) {} + + EIGEN_DEVICE_FUNC + const Strides& strides() const { return m_strides; } + + EIGEN_DEVICE_FUNC + const typename internal::remove_all::type& + expression() const { return m_xpr; } + + protected: + typename XprType::Nested m_xpr; + const Strides m_strides; +}; + +// Eval as rvalue +template +struct TensorEvaluator, Device> +{ + typedef TensorInflationOp XprType; + typedef typename XprType::Index Index; + static const int NumDims = internal::array_size::Dimensions>::value; + typedef DSizes Dimensions; + + enum { + IsAligned = /*TensorEvaluator::IsAligned*/ false, + PacketAccess = TensorEvaluator::PacketAccess, + BlockAccess = false, + Layout = TensorEvaluator::Layout, + CoordAccess = false, // to be implemented + }; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device) + : m_impl(op.expression(), device), m_strides(op.strides()) + { + m_dimensions = m_impl.dimensions(); + // Expand each dimension to the inflated dimension. + for (int i = 0; i < NumDims; ++i) { + m_dimensions[i] = (m_dimensions[i] - 1) * op.strides()[i] + 1; + } + + // Remember the strides for fast division. + for (int i = 0; i < NumDims; ++i) { + m_fastStrides[i] = internal::TensorIntDivisor(m_strides[i]); + } + + const typename TensorEvaluator::Dimensions& input_dims = m_impl.dimensions(); + if (static_cast(Layout) == static_cast(ColMajor)) { + m_outputStrides[0] = 1; + m_inputStrides[0] = 1; + for (int i = 1; i < NumDims; ++i) { + m_outputStrides[i] = m_outputStrides[i-1] * m_dimensions[i-1]; + m_inputStrides[i] = m_inputStrides[i-1] * input_dims[i-1]; + } + } else { // RowMajor + m_outputStrides[NumDims-1] = 1; + m_inputStrides[NumDims-1] = 1; + for (int i = NumDims - 2; i >= 0; --i) { + m_outputStrides[i] = m_outputStrides[i+1] * m_dimensions[i+1]; + m_inputStrides[i] = m_inputStrides[i+1] * input_dims[i+1]; + } + } + } + + typedef typename XprType::Scalar Scalar; + typedef typename XprType::CoeffReturnType CoeffReturnType; + typedef typename XprType::PacketReturnType PacketReturnType; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const { return m_dimensions; } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) { + m_impl.evalSubExprsIfNeeded(NULL); + return true; + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() { + m_impl.cleanup(); + } + + // Computes the input index given the output index. Returns true if the output + // index doesn't fall into a hole. + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool getInputIndex(Index index, Index* inputIndex) const + { + eigen_assert(index < dimensions().TotalSize()); + *inputIndex = 0; + if (static_cast(Layout) == static_cast(ColMajor)) { + for (int i = NumDims - 1; i > 0; --i) { + const Index idx = index / m_outputStrides[i]; + if (idx != idx / m_fastStrides[i] * m_strides[i]) { + return false; + } + *inputIndex += idx / m_strides[i] * m_inputStrides[i]; + index -= idx * m_outputStrides[i]; + } + if (index != index / m_fastStrides[0] * m_strides[0]) { + return false; + } + *inputIndex += index / m_strides[0]; + return true; + } else { + for (int i = 0; i < NumDims - 1; ++i) { + const Index idx = index / m_outputStrides[i]; + if (idx != idx / m_fastStrides[i] * m_strides[i]) { + return false; + } + *inputIndex += idx / m_strides[i] * m_inputStrides[i]; + index -= idx * m_outputStrides[i]; + } + if (index != index / m_fastStrides[NumDims-1] * m_strides[NumDims-1]) { + return false; + } + *inputIndex += index / m_strides[NumDims - 1]; + } + return true; + } + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const + { + Index inputIndex = 0; + if (getInputIndex(index, &inputIndex)) { + return m_impl.coeff(inputIndex); + } else { + return Scalar(0); + } + } + + // TODO(yangke): optimize this function so that we can detect and produce + // all-zero packets + template + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const + { + const int packetSize = internal::unpacket_traits::size; + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(index+packetSize-1 < dimensions().TotalSize()); + + EIGEN_ALIGN_DEFAULT typename internal::remove_const::type values[packetSize]; + for (int i = 0; i < packetSize; ++i) { + values[i] = coeff(index+i); + } + PacketReturnType rslt = internal::pload(values); + return rslt; + } + + EIGEN_DEVICE_FUNC Scalar* data() const { return NULL; } + + protected: + Dimensions m_dimensions; + array m_outputStrides; + array m_inputStrides; + TensorEvaluator m_impl; + const Strides m_strides; + array, NumDims> m_fastStrides; +}; + +} // end namespace Eigen + +#endif // EIGEN_CXX11_TENSOR_TENSOR_INFLATION_H diff --git a/unsupported/test/CMakeLists.txt b/unsupported/test/CMakeLists.txt index 845cda8ce..14f524769 100644 --- a/unsupported/test/CMakeLists.txt +++ b/unsupported/test/CMakeLists.txt @@ -121,6 +121,7 @@ if(EIGEN_TEST_CXX11) ei_add_test(cxx11_tensor_broadcasting "-std=c++0x") ei_add_test(cxx11_tensor_chipping "-std=c++0x") ei_add_test(cxx11_tensor_concatenation "-std=c++0x") + ei_add_test(cxx11_tensor_inflation "-std=c++0x") ei_add_test(cxx11_tensor_morphing "-std=c++0x") ei_add_test(cxx11_tensor_padding "-std=c++0x") ei_add_test(cxx11_tensor_patch "-std=c++0x") diff --git a/unsupported/test/cxx11_tensor_inflation.cpp b/unsupported/test/cxx11_tensor_inflation.cpp new file mode 100644 index 000000000..4997935e9 --- /dev/null +++ b/unsupported/test/cxx11_tensor_inflation.cpp @@ -0,0 +1,81 @@ +// This file is part of Eigen, a lightweight C++ template library +// for linear algebra. +// +// Copyright (C) 2015 Ke Yang +// +// This Source Code Form is subject to the terms of the Mozilla +// Public License v. 2.0. If a copy of the MPL was not distributed +// with this file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#include "main.h" + +#include + +using Eigen::Tensor; + +template +static void test_simple_inflation() +{ + Tensor tensor(2,3,5,7); + tensor.setRandom(); + array strides; + + strides[0] = 1; + strides[1] = 1; + strides[2] = 1; + strides[3] = 1; + + Tensor no_stride; + no_stride = tensor.inflate(strides); + + VERIFY_IS_EQUAL(no_stride.dimension(0), 2); + VERIFY_IS_EQUAL(no_stride.dimension(1), 3); + VERIFY_IS_EQUAL(no_stride.dimension(2), 5); + VERIFY_IS_EQUAL(no_stride.dimension(3), 7); + + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 5; ++k) { + for (int l = 0; l < 7; ++l) { + VERIFY_IS_EQUAL(tensor(i,j,k,l), no_stride(i,j,k,l)); + } + } + } + } + + strides[0] = 2; + strides[1] = 4; + strides[2] = 2; + strides[3] = 3; + Tensor inflated; + inflated = tensor.inflate(strides); + + VERIFY_IS_EQUAL(inflated.dimension(0), 3); + VERIFY_IS_EQUAL(inflated.dimension(1), 9); + VERIFY_IS_EQUAL(inflated.dimension(2), 9); + VERIFY_IS_EQUAL(inflated.dimension(3), 19); + + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 9; ++j) { + for (int k = 0; k < 9; ++k) { + for (int l = 0; l < 19; ++l) { + if (i % 2 == 0 && + j % 4 == 0 && + k % 2 == 0 && + l % 3 == 0) { + VERIFY_IS_EQUAL(inflated(i,j,k,l), + tensor(i/2, j/4, k/2, l/3)); + } else { + VERIFY_IS_EQUAL(0, inflated(i,j,k,l)); + } + } + } + } + } +} + +void test_cxx11_tensor_inflation() +{ + CALL_SUBTEST(test_simple_inflation()); + CALL_SUBTEST(test_simple_inflation()); +}