mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 03:39:01 +08:00
Updated the contraction code to ensure that full contraction return a tensor of rank 0
This commit is contained in:
parent
b300a84989
commit
06d774bf58
@ -37,7 +37,7 @@ struct traits<TensorContractionOp<Dimensions, LhsXprType, RhsXprType> >
|
|||||||
typedef typename remove_reference<RhsNested>::type _RhsNested;
|
typedef typename remove_reference<RhsNested>::type _RhsNested;
|
||||||
|
|
||||||
// From NumDims below.
|
// From NumDims below.
|
||||||
static const int NumDimensions = max_n_1<traits<RhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value>::size;
|
static const int NumDimensions = traits<RhsXprType>::NumDimensions + traits<RhsXprType>::NumDimensions - 2 * array_size<Dimensions>::value;
|
||||||
static const int Layout = traits<LhsXprType>::Layout;
|
static const int Layout = traits<LhsXprType>::Layout;
|
||||||
|
|
||||||
enum {
|
enum {
|
||||||
@ -65,7 +65,7 @@ struct traits<TensorEvaluator<const TensorContractionOp<Indices_, LeftArgType_,
|
|||||||
typedef Device_ Device;
|
typedef Device_ Device;
|
||||||
|
|
||||||
// From NumDims below.
|
// From NumDims below.
|
||||||
static const int NumDimensions = max_n_1<traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value>::size;
|
static const int NumDimensions = traits<LeftArgType_>::NumDimensions + traits<RightArgType_>::NumDimensions - 2 * array_size<Indices_>::value;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // end namespace internal
|
} // end namespace internal
|
||||||
@ -140,7 +140,7 @@ struct TensorContractionEvaluatorBase
|
|||||||
static const int RDims =
|
static const int RDims =
|
||||||
internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
|
internal::array_size<typename TensorEvaluator<EvalRightArgType, Device>::Dimensions>::value;
|
||||||
static const int ContractDims = internal::array_size<Indices>::value;
|
static const int ContractDims = internal::array_size<Indices>::value;
|
||||||
static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;
|
static const int NumDims = LDims + RDims - 2 * ContractDims;
|
||||||
|
|
||||||
typedef array<Index, ContractDims> contract_t;
|
typedef array<Index, ContractDims> contract_t;
|
||||||
typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
|
typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
|
||||||
@ -218,11 +218,9 @@ struct TensorContractionEvaluatorBase
|
|||||||
rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i];
|
rhs_strides[i+1] = rhs_strides[i] * eval_right_dims[i];
|
||||||
}
|
}
|
||||||
|
|
||||||
m_i_strides[0] = 1;
|
if (m_i_strides.size() > 0) m_i_strides[0] = 1;
|
||||||
m_j_strides[0] = 1;
|
if (m_j_strides.size() > 0) m_j_strides[0] = 1;
|
||||||
if(ContractDims) {
|
if (m_k_strides.size() > 0) m_k_strides[0] = 1;
|
||||||
m_k_strides[0] = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
m_i_size = 1;
|
m_i_size = 1;
|
||||||
m_j_size = 1;
|
m_j_size = 1;
|
||||||
@ -318,11 +316,6 @@ struct TensorContractionEvaluatorBase
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scalar case. We represent the result as a 1d tensor of size 1.
|
|
||||||
if (LDims + RDims == 2 * ContractDims) {
|
|
||||||
m_dimensions[0] = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the layout is RowMajor, we need to reverse the m_dimensions
|
// If the layout is RowMajor, we need to reverse the m_dimensions
|
||||||
if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) {
|
if (static_cast<int>(Layout) == static_cast<int>(RowMajor)) {
|
||||||
for (int i = 0, j = NumDims - 1; i < j; i++, j--) {
|
for (int i = 0, j = NumDims - 1; i < j; i++, j--) {
|
||||||
@ -607,15 +600,14 @@ struct TensorEvaluator<const TensorContractionOp<Indices, LeftArgType, RightArgT
|
|||||||
static const int ContractDims = internal::array_size<Indices>::value;
|
static const int ContractDims = internal::array_size<Indices>::value;
|
||||||
|
|
||||||
typedef array<Index, ContractDims> contract_t;
|
typedef array<Index, ContractDims> contract_t;
|
||||||
typedef array<Index, max_n_1<LDims - ContractDims>::size> left_nocontract_t;
|
typedef array<Index, LDims - ContractDims> left_nocontract_t;
|
||||||
typedef array<Index, max_n_1<RDims - ContractDims>::size> right_nocontract_t;
|
typedef array<Index, RDims - ContractDims> right_nocontract_t;
|
||||||
|
|
||||||
static const int NumDims = max_n_1<LDims + RDims - 2 * ContractDims>::size;
|
static const int NumDims = LDims + RDims - 2 * ContractDims;
|
||||||
|
|
||||||
// Could we use NumDimensions here?
|
// Could we use NumDimensions here?
|
||||||
typedef DSizes<Index, NumDims> Dimensions;
|
typedef DSizes<Index, NumDims> Dimensions;
|
||||||
|
|
||||||
|
|
||||||
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
|
EIGEN_DEVICE_FUNC TensorEvaluator(const XprType& op, const Device& device) :
|
||||||
Base(op, device) { }
|
Base(op, device) { }
|
||||||
|
|
||||||
|
@ -87,19 +87,14 @@ static void test_scalar()
|
|||||||
vec1.setRandom();
|
vec1.setRandom();
|
||||||
vec2.setRandom();
|
vec2.setRandom();
|
||||||
|
|
||||||
Tensor<float, 1, DataLayout> scalar(1);
|
|
||||||
scalar.setZero();
|
|
||||||
Eigen::array<DimPair, 1> dims = {{DimPair(0, 0)}};
|
Eigen::array<DimPair, 1> dims = {{DimPair(0, 0)}};
|
||||||
typedef TensorEvaluator<decltype(vec1.contract(vec2, dims)), DefaultDevice> Evaluator;
|
Tensor<float, 0, DataLayout> scalar = vec1.contract(vec2, dims);
|
||||||
Evaluator eval(vec1.contract(vec2, dims), DefaultDevice());
|
|
||||||
eval.evalTo(scalar.data());
|
|
||||||
EIGEN_STATIC_ASSERT(Evaluator::NumDims==1ul, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
|
||||||
|
|
||||||
float expected = 0.0f;
|
float expected = 0.0f;
|
||||||
for (int i = 0; i < 6; ++i) {
|
for (int i = 0; i < 6; ++i) {
|
||||||
expected += vec1(i) * vec2(i);
|
expected += vec1(i) * vec2(i);
|
||||||
}
|
}
|
||||||
VERIFY_IS_APPROX(scalar(0), expected);
|
VERIFY_IS_APPROX(scalar(), expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<int DataLayout>
|
template<int DataLayout>
|
||||||
|
Loading…
x
Reference in New Issue
Block a user