mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-07-24 05:44:26 +08:00
Merged in tntnatbry/eigen (pull request PR-319)
Tensor Trace op
This commit is contained in:
parent
6795512e59
commit
9daed67952
@ -141,6 +141,7 @@ typedef unsigned __int64 uint64_t;
|
|||||||
#include "src/Tensor/TensorGenerator.h"
|
#include "src/Tensor/TensorGenerator.h"
|
||||||
#include "src/Tensor/TensorAssign.h"
|
#include "src/Tensor/TensorAssign.h"
|
||||||
#include "src/Tensor/TensorScan.h"
|
#include "src/Tensor/TensorScan.h"
|
||||||
|
#include "src/Tensor/TensorTrace.h"
|
||||||
|
|
||||||
#include "src/Tensor/TensorSycl.h"
|
#include "src/Tensor/TensorSycl.h"
|
||||||
#include "src/Tensor/TensorExecutor.h"
|
#include "src/Tensor/TensorExecutor.h"
|
||||||
|
@ -1168,6 +1168,58 @@ Reduce a tensor using a user-defined reduction operator. See ```SumReducer```
|
|||||||
in TensorFunctors.h for information on how to implement a reduction operator.
|
in TensorFunctors.h for information on how to implement a reduction operator.
|
||||||
|
|
||||||
|
|
||||||
|
## Trace
|
||||||
|
|
||||||
|
A *Trace* operation returns a tensor with fewer dimensions than the original
|
||||||
|
tensor. It returns a tensor whose elements are the sum of the elements of the
|
||||||
|
original tensor along the main diagonal for a list of specified dimensions, the
|
||||||
|
"trace dimensions". Similar to the ```Reduction Dimensions```, the trace dimensions
|
||||||
|
are passed as an input parameter to the operation, are of type ```<TensorType>::Dimensions```
|
||||||
|
, and have the same requirements when passed as an input parameter. In addition,
|
||||||
|
the trace dimensions must have the same size.
|
||||||
|
|
||||||
|
Example: Trace along 2 dimensions.
|
||||||
|
|
||||||
|
// Create a tensor of 3 dimensions
|
||||||
|
Eigen::Tensor<int, 3> a(2, 2, 3);
|
||||||
|
a.setValues({{{1, 2, 3}, {4, 5, 6}}, {{7, 8, 9}, {10, 11, 12}}});
|
||||||
|
// Specify the dimensions along which the trace will be computed.
|
||||||
|
// In this example, the trace can only be computed along the dimensions
|
||||||
|
// with indices 0 and 1
|
||||||
|
Eigen::array<int, 2> dims({0, 1});
|
||||||
|
// The output tensor contains all but the trace dimensions.
|
||||||
|
Tensor<int, 1> a_trace = a.trace(dims);
|
||||||
|
cout << "a_trace:" << endl;
|
||||||
|
cout << a_trace << endl;
|
||||||
|
=>
|
||||||
|
a_trace:
|
||||||
|
11
|
||||||
|
13
|
||||||
|
15
|
||||||
|
|
||||||
|
|
||||||
|
### <Operation> trace(const Dimensions& new_dims)
|
||||||
|
### <Operation> trace()
|
||||||
|
|
||||||
|
As a special case, if no parameter is passed to the operation, trace is computed
|
||||||
|
along *all* dimensions of the input tensor.
|
||||||
|
|
||||||
|
Example: Trace along all dimensions.
|
||||||
|
|
||||||
|
// Create a tensor of 3 dimensions, with all dimensions having the same size.
|
||||||
|
Eigen::Tensor<int, 3> a(3, 3, 3);
|
||||||
|
a.setValues({{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}},
|
||||||
|
{{10, 11, 12}, {13, 14, 15}, {16, 17, 18}},
|
||||||
|
{{19, 20, 21}, {22, 23, 24}, {25, 26, 27}}});
|
||||||
|
// Result is a zero dimension tensor
|
||||||
|
Tensor<int, 0> a_trace = a.trace();
|
||||||
|
cout<<"a_trace:"<<endl;
|
||||||
|
cout<<a_trace<<endl;
|
||||||
|
=>
|
||||||
|
a_trace:
|
||||||
|
42
|
||||||
|
|
||||||
|
|
||||||
## Scan Operations
|
## Scan Operations
|
||||||
|
|
||||||
A *Scan* operation returns a tensor with the same dimensions as the original
|
A *Scan* operation returns a tensor with the same dimensions as the original
|
||||||
|
@ -671,6 +671,18 @@ class TensorBase<Derived, ReadOnlyAccessors>
|
|||||||
return TensorReductionOp<Reducer, const Dims, const Derived>(derived(), dims, reducer);
|
return TensorReductionOp<Reducer, const Dims, const Derived>(derived(), dims, reducer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename Dims> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
const TensorTraceOp<const Dims, const Derived>
|
||||||
|
trace(const Dims& dims) const {
|
||||||
|
return TensorTraceOp<const Dims, const Derived>(derived(), dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
const TensorTraceOp<const DimensionList<Index, NumDimensions>, const Derived>
|
||||||
|
trace() const {
|
||||||
|
DimensionList<Index, NumDimensions> in_dims;
|
||||||
|
return TensorTraceOp<const DimensionList<Index, NumDimensions>, const Derived>(derived(), in_dims);
|
||||||
|
}
|
||||||
|
|
||||||
template <typename Broadcast> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
template <typename Broadcast> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
const TensorBroadcastingOp<const Broadcast, const Derived>
|
const TensorBroadcastingOp<const Broadcast, const Derived>
|
||||||
broadcast(const Broadcast& broadcast) const {
|
broadcast(const Broadcast& broadcast) const {
|
||||||
|
@ -70,6 +70,7 @@ template<typename Strides, typename XprType> class TensorInflationOp;
|
|||||||
template<typename Generator, typename XprType> class TensorGeneratorOp;
|
template<typename Generator, typename XprType> class TensorGeneratorOp;
|
||||||
template<typename LeftXprType, typename RightXprType> class TensorAssignOp;
|
template<typename LeftXprType, typename RightXprType> class TensorAssignOp;
|
||||||
template<typename Op, typename XprType> class TensorScanOp;
|
template<typename Op, typename XprType> class TensorScanOp;
|
||||||
|
template<typename Dims, typename XprType> class TensorTraceOp;
|
||||||
|
|
||||||
template<typename CustomUnaryFunc, typename XprType> class TensorCustomUnaryOp;
|
template<typename CustomUnaryFunc, typename XprType> class TensorCustomUnaryOp;
|
||||||
template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> class TensorCustomBinaryOp;
|
template<typename CustomBinaryFunc, typename LhsXprType, typename RhsXprType> class TensorCustomBinaryOp;
|
||||||
|
288
unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h
Normal file
288
unsupported/Eigen/CXX11/src/Tensor/TensorTrace.h
Normal file
@ -0,0 +1,288 @@
|
|||||||
|
// This file is part of Eigen, a lightweight C++ template library
|
||||||
|
// for linear algebra.
|
||||||
|
//
|
||||||
|
// Copyright (C) 2017 Gagan Goel <gagan.nith@gmail.com>
|
||||||
|
// Copyright (C) 2017 Benoit Steiner <benoit.steiner.goog@gmail.com>
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla
|
||||||
|
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||||
|
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
#ifndef EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
|
||||||
|
#define EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
|
||||||
|
|
||||||
|
namespace Eigen {
|
||||||
|
|
||||||
|
/** \class TensorTrace
|
||||||
|
* \ingroup CXX11_Tensor_Module
|
||||||
|
*
|
||||||
|
* \brief Tensor Trace class.
|
||||||
|
*
|
||||||
|
*
|
||||||
|
*/
|
||||||
|
|
||||||
|
namespace internal {
|
||||||
|
template<typename Dims, typename XprType>
|
||||||
|
struct traits<TensorTraceOp<Dims, XprType> > : public traits<XprType>
|
||||||
|
{
|
||||||
|
typedef typename XprType::Scalar Scalar;
|
||||||
|
typedef traits<XprType> XprTraits;
|
||||||
|
typedef typename XprTraits::StorageKind StorageKind;
|
||||||
|
typedef typename XprTraits::Index Index;
|
||||||
|
typedef typename XprType::Nested Nested;
|
||||||
|
typedef typename remove_reference<Nested>::type _Nested;
|
||||||
|
static const int NumDimensions = XprTraits::NumDimensions - array_size<Dims>::value;
|
||||||
|
static const int Layout = XprTraits::Layout;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Dims, typename XprType>
|
||||||
|
struct eval<TensorTraceOp<Dims, XprType>, Eigen::Dense>
|
||||||
|
{
|
||||||
|
typedef const TensorTraceOp<Dims, XprType>& type;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<typename Dims, typename XprType>
|
||||||
|
struct nested<TensorTraceOp<Dims, XprType>, 1, typename eval<TensorTraceOp<Dims, XprType> >::type>
|
||||||
|
{
|
||||||
|
typedef TensorTraceOp<Dims, XprType> type;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // end namespace internal
|
||||||
|
|
||||||
|
|
||||||
|
template<typename Dims, typename XprType>
|
||||||
|
class TensorTraceOp : public TensorBase<TensorTraceOp<Dims, XprType> >
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
typedef typename Eigen::internal::traits<TensorTraceOp>::Scalar Scalar;
|
||||||
|
typedef typename Eigen::NumTraits<Scalar>::Real RealScalar;
|
||||||
|
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||||
|
typedef typename Eigen::internal::nested<TensorTraceOp>::type Nested;
|
||||||
|
typedef typename Eigen::internal::traits<TensorTraceOp>::StorageKind StorageKind;
|
||||||
|
typedef typename Eigen::internal::traits<TensorTraceOp>::Index Index;
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorTraceOp(const XprType& expr, const Dims& dims)
|
||||||
|
: m_xpr(expr), m_dims(dims) {
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
const Dims& dims() const { return m_dims; }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE
|
||||||
|
const typename internal::remove_all<typename XprType::Nested>::type& expression() const { return m_xpr; }
|
||||||
|
|
||||||
|
protected:
|
||||||
|
typename XprType::Nested m_xpr;
|
||||||
|
const Dims m_dims;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
// Eval as rvalue
|
||||||
|
template<typename Dims, typename ArgType, typename Device>
|
||||||
|
struct TensorEvaluator<const TensorTraceOp<Dims, ArgType>, Device>
|
||||||
|
{
|
||||||
|
typedef TensorTraceOp<Dims, ArgType> XprType;
|
||||||
|
static const int NumInputDims = internal::array_size<typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
|
||||||
|
static const int NumReducedDims = internal::array_size<Dims>::value;
|
||||||
|
static const int NumOutputDims = NumInputDims - NumReducedDims;
|
||||||
|
typedef typename XprType::Index Index;
|
||||||
|
typedef DSizes<Index, NumOutputDims> Dimensions;
|
||||||
|
typedef typename XprType::Scalar Scalar;
|
||||||
|
typedef typename XprType::CoeffReturnType CoeffReturnType;
|
||||||
|
typedef typename PacketType<CoeffReturnType, Device>::type PacketReturnType;
|
||||||
|
static const int PacketSize = internal::unpacket_traits<PacketReturnType>::size;
|
||||||
|
|
||||||
|
enum {
|
||||||
|
IsAligned = false,
|
||||||
|
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
|
||||||
|
Layout = TensorEvaluator<ArgType, Device>::Layout,
|
||||||
|
CoordAccess = false,
|
||||||
|
RawAccess = false
|
||||||
|
};
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
|
||||||
|
: m_impl(op.expression(), device), m_device(device)
|
||||||
|
{
|
||||||
|
|
||||||
|
EIGEN_STATIC_ASSERT((NumOutputDims >= 0), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
EIGEN_STATIC_ASSERT((NumReducedDims >= 2) || ((NumReducedDims == 0) && (NumInputDims == 0)), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
|
||||||
|
for (int i = 0; i < NumInputDims; ++i) {
|
||||||
|
m_reduced[i] = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const Dims& op_dims = op.dims();
|
||||||
|
for (int i = 0; i < NumReducedDims; ++i) {
|
||||||
|
eigen_assert(op_dims[i] >= 0);
|
||||||
|
eigen_assert(op_dims[i] < NumInputDims);
|
||||||
|
m_reduced[op_dims[i]] = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// All the dimensions should be distinct to compute the trace
|
||||||
|
int num_distinct_reduce_dims = 0;
|
||||||
|
for (int i = 0; i < NumInputDims; ++i) {
|
||||||
|
if (m_reduced[i]) {
|
||||||
|
++num_distinct_reduce_dims;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
eigen_assert(num_distinct_reduce_dims == NumReducedDims);
|
||||||
|
|
||||||
|
// Compute the dimensions of the result.
|
||||||
|
const typename TensorEvaluator<ArgType, Device>::Dimensions& input_dims = m_impl.dimensions();
|
||||||
|
|
||||||
|
int output_index = 0;
|
||||||
|
int reduced_index = 0;
|
||||||
|
for (int i = 0; i < NumInputDims; ++i) {
|
||||||
|
if (m_reduced[i]) {
|
||||||
|
m_reducedDims[reduced_index] = input_dims[i];
|
||||||
|
if (reduced_index > 0) {
|
||||||
|
// All the trace dimensions must have the same size
|
||||||
|
eigen_assert(m_reducedDims[0] == m_reducedDims[reduced_index]);
|
||||||
|
}
|
||||||
|
++reduced_index;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
m_dimensions[output_index] = input_dims[i];
|
||||||
|
++output_index;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (NumReducedDims != 0) {
|
||||||
|
m_traceDim = m_reducedDims[0];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the output strides
|
||||||
|
if (NumOutputDims > 0) {
|
||||||
|
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
|
||||||
|
m_outputStrides[0] = 1;
|
||||||
|
for (int i = 1; i < NumOutputDims; ++i) {
|
||||||
|
m_outputStrides[i] = m_outputStrides[i - 1] * m_dimensions[i - 1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
m_outputStrides.back() = 1;
|
||||||
|
for (int i = NumOutputDims - 2; i >= 0; --i) {
|
||||||
|
m_outputStrides[i] = m_outputStrides[i + 1] * m_dimensions[i + 1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the input strides
|
||||||
|
if (NumInputDims > 0) {
|
||||||
|
array<Index, NumInputDims> input_strides;
|
||||||
|
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
|
||||||
|
input_strides[0] = 1;
|
||||||
|
for (int i = 1; i < NumInputDims; ++i) {
|
||||||
|
input_strides[i] = input_strides[i - 1] * input_dims[i - 1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
input_strides.back() = 1;
|
||||||
|
for (int i = NumInputDims - 2; i >= 0; --i) {
|
||||||
|
input_strides[i] = input_strides[i + 1] * input_dims[i + 1];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
output_index = 0;
|
||||||
|
reduced_index = 0;
|
||||||
|
for (int i = 0; i < NumInputDims; ++i) {
|
||||||
|
if(m_reduced[i]) {
|
||||||
|
m_reducedStrides[reduced_index] = input_strides[i];
|
||||||
|
++reduced_index;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
m_preservedStrides[output_index] = input_strides[i];
|
||||||
|
++output_index;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Dimensions& dimensions() const {
|
||||||
|
return m_dimensions;
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool evalSubExprsIfNeeded(Scalar* /*data*/) {
|
||||||
|
m_impl.evalSubExprsIfNeeded(NULL);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void cleanup() {
|
||||||
|
m_impl.cleanup();
|
||||||
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const
|
||||||
|
{
|
||||||
|
// Initialize the result
|
||||||
|
CoeffReturnType result = internal::cast<int, CoeffReturnType>(0);
|
||||||
|
Index index_stride = 0;
|
||||||
|
for (int i = 0; i < NumReducedDims; ++i) {
|
||||||
|
index_stride += m_reducedStrides[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
// If trace is requested along all dimensions, starting index would be 0
|
||||||
|
Index cur_index = 0;
|
||||||
|
if (NumOutputDims != 0)
|
||||||
|
cur_index = firstInput(index);
|
||||||
|
for (Index i = 0; i < m_traceDim; ++i) {
|
||||||
|
result += m_impl.coeff(cur_index);
|
||||||
|
cur_index += index_stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<int LoadMode>
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketReturnType packet(Index index) const {
|
||||||
|
|
||||||
|
EIGEN_STATIC_ASSERT((PacketSize > 1), YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||||
|
eigen_assert(index + PacketSize - 1 < dimensions().TotalSize());
|
||||||
|
|
||||||
|
EIGEN_ALIGN_MAX typename internal::remove_const<CoeffReturnType>::type values[PacketSize];
|
||||||
|
for (int i = 0; i < PacketSize; ++i) {
|
||||||
|
values[i] = coeff(index + i);
|
||||||
|
}
|
||||||
|
PacketReturnType result = internal::ploadt<PacketReturnType, LoadMode>(values);
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
// Given the output index, finds the first index in the input tensor used to compute the trace
|
||||||
|
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index firstInput(Index index) const {
|
||||||
|
Index startInput = 0;
|
||||||
|
if (static_cast<int>(Layout) == static_cast<int>(ColMajor)) {
|
||||||
|
for (int i = NumOutputDims - 1; i > 0; --i) {
|
||||||
|
const Index idx = index / m_outputStrides[i];
|
||||||
|
startInput += idx * m_preservedStrides[i];
|
||||||
|
index -= idx * m_outputStrides[i];
|
||||||
|
}
|
||||||
|
startInput += index * m_preservedStrides[0];
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (int i = 0; i < NumOutputDims - 1; ++i) {
|
||||||
|
const Index idx = index / m_outputStrides[i];
|
||||||
|
startInput += idx * m_preservedStrides[i];
|
||||||
|
index -= idx * m_outputStrides[i];
|
||||||
|
}
|
||||||
|
startInput += index * m_preservedStrides[NumOutputDims - 1];
|
||||||
|
}
|
||||||
|
return startInput;
|
||||||
|
}
|
||||||
|
|
||||||
|
Dimensions m_dimensions;
|
||||||
|
TensorEvaluator<ArgType, Device> m_impl;
|
||||||
|
const Device& m_device;
|
||||||
|
array<bool, NumInputDims> m_reduced;
|
||||||
|
array<Index, NumReducedDims> m_reducedDims;
|
||||||
|
// Initialize the size of the trace dimension
|
||||||
|
Index m_traceDim = 1;
|
||||||
|
array<Index, NumOutputDims> m_outputStrides;
|
||||||
|
array<Index, NumReducedDims> m_reducedStrides;
|
||||||
|
array<Index, NumOutputDims> m_preservedStrides;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
} // End namespace Eigen
|
||||||
|
|
||||||
|
#endif // EIGEN_CXX11_TENSOR_TENSOR_TRACE_H
|
@ -227,6 +227,7 @@ if(EIGEN_TEST_CXX11)
|
|||||||
ei_add_test(cxx11_tensor_fft)
|
ei_add_test(cxx11_tensor_fft)
|
||||||
ei_add_test(cxx11_tensor_ifft)
|
ei_add_test(cxx11_tensor_ifft)
|
||||||
ei_add_test(cxx11_tensor_scan)
|
ei_add_test(cxx11_tensor_scan)
|
||||||
|
ei_add_test(cxx11_tensor_trace)
|
||||||
|
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
171
unsupported/test/cxx11_tensor_trace.cpp
Normal file
171
unsupported/test/cxx11_tensor_trace.cpp
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
// This file is part of Eigen, a lightweight C++ template library
|
||||||
|
// for linear algebra.
|
||||||
|
//
|
||||||
|
// Copyright (C) 2017 Gagan Goel <gagan.nith@gmail.com>
|
||||||
|
//
|
||||||
|
// This Source Code Form is subject to the terms of the Mozilla
|
||||||
|
// Public License v. 2.0. If a copy of the MPL was not distributed
|
||||||
|
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||||
|
|
||||||
|
#include "main.h"
|
||||||
|
|
||||||
|
#include <Eigen/CXX11/Tensor>
|
||||||
|
|
||||||
|
using Eigen::Tensor;
|
||||||
|
using Eigen::array;
|
||||||
|
|
||||||
|
template <int DataLayout>
|
||||||
|
static void test_0D_trace() {
|
||||||
|
Tensor<float, 0, DataLayout> tensor;
|
||||||
|
tensor.setRandom();
|
||||||
|
array<ptrdiff_t, 0> dims;
|
||||||
|
Tensor<float, 0, DataLayout> result = tensor.trace(dims);
|
||||||
|
VERIFY_IS_EQUAL(result(), tensor());
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <int DataLayout>
|
||||||
|
static void test_all_dimensions_trace() {
|
||||||
|
Tensor<float, 3, DataLayout> tensor1(5, 5, 5);
|
||||||
|
tensor1.setRandom();
|
||||||
|
Tensor<float, 0, DataLayout> result1 = tensor1.trace();
|
||||||
|
VERIFY_IS_EQUAL(result1.rank(), 0);
|
||||||
|
float sum = 0.0f;
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
sum += tensor1(i, i, i);
|
||||||
|
}
|
||||||
|
VERIFY_IS_EQUAL(result1(), sum);
|
||||||
|
|
||||||
|
Tensor<float, 5, DataLayout> tensor2(7, 7, 7, 7, 7);
|
||||||
|
array<ptrdiff_t, 5> dims({{2, 1, 0, 3, 4}});
|
||||||
|
Tensor<float, 0, DataLayout> result2 = tensor2.trace(dims);
|
||||||
|
VERIFY_IS_EQUAL(result2.rank(), 0);
|
||||||
|
sum = 0.0f;
|
||||||
|
for (int i = 0; i < 7; ++i) {
|
||||||
|
sum += tensor2(i, i, i, i, i);
|
||||||
|
}
|
||||||
|
VERIFY_IS_EQUAL(result2(), sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <int DataLayout>
|
||||||
|
static void test_simple_trace() {
|
||||||
|
Tensor<float, 3, DataLayout> tensor1(3, 5, 3);
|
||||||
|
tensor1.setRandom();
|
||||||
|
array<ptrdiff_t, 2> dims1({{0, 2}});
|
||||||
|
Tensor<float, 1, DataLayout> result1 = tensor1.trace(dims1);
|
||||||
|
VERIFY_IS_EQUAL(result1.rank(), 1);
|
||||||
|
VERIFY_IS_EQUAL(result1.dimension(0), 5);
|
||||||
|
float sum = 0.0f;
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
sum = 0.0f;
|
||||||
|
for (int j = 0; j < 3; ++j) {
|
||||||
|
sum += tensor1(j, i, j);
|
||||||
|
}
|
||||||
|
VERIFY_IS_EQUAL(result1(i), sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor<float, 4, DataLayout> tensor2(5, 5, 7, 7);
|
||||||
|
tensor2.setRandom();
|
||||||
|
array<ptrdiff_t, 2> dims2({{2, 3}});
|
||||||
|
Tensor<float, 2, DataLayout> result2 = tensor2.trace(dims2);
|
||||||
|
VERIFY_IS_EQUAL(result2.rank(), 2);
|
||||||
|
VERIFY_IS_EQUAL(result2.dimension(0), 5);
|
||||||
|
VERIFY_IS_EQUAL(result2.dimension(1), 5);
|
||||||
|
for (int i = 0; i < 5; ++i) {
|
||||||
|
for (int j = 0; j < 5; ++j) {
|
||||||
|
sum = 0.0f;
|
||||||
|
for (int k = 0; k < 7; ++k) {
|
||||||
|
sum += tensor2(i, j, k, k);
|
||||||
|
}
|
||||||
|
VERIFY_IS_EQUAL(result2(i, j), sum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
array<ptrdiff_t, 2> dims3({{1, 0}});
|
||||||
|
Tensor<float, 2, DataLayout> result3 = tensor2.trace(dims3);
|
||||||
|
VERIFY_IS_EQUAL(result3.rank(), 2);
|
||||||
|
VERIFY_IS_EQUAL(result3.dimension(0), 7);
|
||||||
|
VERIFY_IS_EQUAL(result3.dimension(1), 7);
|
||||||
|
for (int i = 0; i < 7; ++i) {
|
||||||
|
for (int j = 0; j < 7; ++j) {
|
||||||
|
sum = 0.0f;
|
||||||
|
for (int k = 0; k < 5; ++k) {
|
||||||
|
sum += tensor2(k, k, i, j);
|
||||||
|
}
|
||||||
|
VERIFY_IS_EQUAL(result3(i, j), sum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor<float, 5, DataLayout> tensor3(3, 7, 3, 7, 3);
|
||||||
|
tensor3.setRandom();
|
||||||
|
array<ptrdiff_t, 3> dims4({{0, 2, 4}});
|
||||||
|
Tensor<float, 2, DataLayout> result4 = tensor3.trace(dims4);
|
||||||
|
VERIFY_IS_EQUAL(result4.rank(), 2);
|
||||||
|
VERIFY_IS_EQUAL(result4.dimension(0), 7);
|
||||||
|
VERIFY_IS_EQUAL(result4.dimension(1), 7);
|
||||||
|
for (int i = 0; i < 7; ++i) {
|
||||||
|
for (int j = 0; j < 7; ++j) {
|
||||||
|
sum = 0.0f;
|
||||||
|
for (int k = 0; k < 3; ++k) {
|
||||||
|
sum += tensor3(k, i, k, j, k);
|
||||||
|
}
|
||||||
|
VERIFY_IS_EQUAL(result4(i, j), sum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor<float, 5, DataLayout> tensor4(3, 7, 4, 7, 5);
|
||||||
|
tensor4.setRandom();
|
||||||
|
array<ptrdiff_t, 2> dims5({{1, 3}});
|
||||||
|
Tensor<float, 3, DataLayout> result5 = tensor4.trace(dims5);
|
||||||
|
VERIFY_IS_EQUAL(result5.rank(), 3);
|
||||||
|
VERIFY_IS_EQUAL(result5.dimension(0), 3);
|
||||||
|
VERIFY_IS_EQUAL(result5.dimension(1), 4);
|
||||||
|
VERIFY_IS_EQUAL(result5.dimension(2), 5);
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
for (int j = 0; j < 4; ++j) {
|
||||||
|
for (int k = 0; k < 5; ++k) {
|
||||||
|
sum = 0.0f;
|
||||||
|
for (int l = 0; l < 7; ++l) {
|
||||||
|
sum += tensor4(i, l, j, l, k);
|
||||||
|
}
|
||||||
|
VERIFY_IS_EQUAL(result5(i, j, k), sum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<int DataLayout>
|
||||||
|
static void test_trace_in_expr() {
|
||||||
|
Tensor<float, 4, DataLayout> tensor(2, 3, 5, 3);
|
||||||
|
tensor.setRandom();
|
||||||
|
array<ptrdiff_t, 2> dims({{1, 3}});
|
||||||
|
Tensor<float, 2, DataLayout> result(2, 5);
|
||||||
|
result = result.constant(1.0f) - tensor.trace(dims);
|
||||||
|
VERIFY_IS_EQUAL(result.rank(), 2);
|
||||||
|
VERIFY_IS_EQUAL(result.dimension(0), 2);
|
||||||
|
VERIFY_IS_EQUAL(result.dimension(1), 5);
|
||||||
|
float sum = 0.0f;
|
||||||
|
for (int i = 0; i < 2; ++i) {
|
||||||
|
for (int j = 0; j < 5; ++j) {
|
||||||
|
sum = 0.0f;
|
||||||
|
for (int k = 0; k < 3; ++k) {
|
||||||
|
sum += tensor(i, k, j, k);
|
||||||
|
}
|
||||||
|
VERIFY_IS_EQUAL(result(i, j), 1.0f - sum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
void test_cxx11_tensor_trace() {
|
||||||
|
CALL_SUBTEST(test_0D_trace<ColMajor>());
|
||||||
|
CALL_SUBTEST(test_0D_trace<RowMajor>());
|
||||||
|
CALL_SUBTEST(test_all_dimensions_trace<ColMajor>());
|
||||||
|
CALL_SUBTEST(test_all_dimensions_trace<RowMajor>());
|
||||||
|
CALL_SUBTEST(test_simple_trace<ColMajor>());
|
||||||
|
CALL_SUBTEST(test_simple_trace<RowMajor>());
|
||||||
|
CALL_SUBTEST(test_trace_in_expr<ColMajor>());
|
||||||
|
CALL_SUBTEST(test_trace_in_expr<RowMajor>());
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user