mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-04-24 02:29:33 +08:00
Fix bug in tensor contraction. The code assumes that contraction axis indices for the LHS (after possibly swapping to ColMajor!) is increasing. Explicitly sort the contraction axis pairs to make it so.
This commit is contained in:
parent
46aa9772fc
commit
f7329619da
@ -193,6 +193,19 @@ struct TensorContractionEvaluatorBase
|
||||
}
|
||||
}
|
||||
|
||||
// Check for duplicate axes and make sure the first index in eval_op_indices
|
||||
// is increasing. Using O(n^2) sorting is OK since ContractDims is small
|
||||
for (int i = 0; i < ContractDims; i++) {
|
||||
for (int j = i + 1; j < ContractDims; j++) {
|
||||
eigen_assert(eval_op_indices[j].first != eval_op_indices[i].first &&
|
||||
eval_op_indices[j].second != eval_op_indices[i].second &&
|
||||
"contraction axes should be unique");
|
||||
if (eval_op_indices[j].first < eval_op_indices[i].first) {
|
||||
numext::swap(eval_op_indices[j], eval_op_indices[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
array<Index, LDims> lhs_strides;
|
||||
lhs_strides[0] = 1;
|
||||
for (int i = 0; i < LDims-1; ++i) {
|
||||
|
@ -138,6 +138,26 @@ static void test_multidims()
|
||||
mat1(1,0,1)*mat2(1,0,0,1) + mat1(1,1,1)*mat2(1,0,1,1));
|
||||
VERIFY_IS_APPROX(mat3(1,1,1), mat1(1,0,0)*mat2(1,1,0,0) + mat1(1,1,0)*mat2(1,1,1,0) +
|
||||
mat1(1,0,1)*mat2(1,1,0,1) + mat1(1,1,1)*mat2(1,1,1,1));
|
||||
|
||||
Tensor<float, 2, DataLayout> mat4(2, 2);
|
||||
Tensor<float, 3, DataLayout> mat5(2, 2, 2);
|
||||
|
||||
mat4.setRandom();
|
||||
mat5.setRandom();
|
||||
|
||||
Tensor<float, 1, DataLayout> mat6(2);
|
||||
mat6.setZero();
|
||||
Eigen::array<DimPair, 2> dims2({{DimPair(0, 1), DimPair(1, 0)}});
|
||||
typedef TensorEvaluator<decltype(mat4.contract(mat5, dims2)), DefaultDevice> Evaluator2;
|
||||
Evaluator2 eval2(mat4.contract(mat5, dims2), DefaultDevice());
|
||||
eval2.evalTo(mat6.data());
|
||||
EIGEN_STATIC_ASSERT(Evaluator2::NumDims==1ul, YOU_MADE_A_PROGRAMMING_MISTAKE);
|
||||
VERIFY_IS_EQUAL(eval2.dimensions()[0], 2);
|
||||
|
||||
VERIFY_IS_APPROX(mat6(0), mat4(0,0)*mat5(0,0,0) + mat4(1,0)*mat5(0,1,0) +
|
||||
mat4(0,1)*mat5(1,0,0) + mat4(1,1)*mat5(1,1,0));
|
||||
VERIFY_IS_APPROX(mat6(1), mat4(0,0)*mat5(0,0,1) + mat4(1,0)*mat5(0,1,1) +
|
||||
mat4(0,1)*mat5(1,0,1) + mat4(1,1)*mat5(1,1,1));
|
||||
}
|
||||
|
||||
template<int DataLayout>
|
||||
|
Loading…
x
Reference in New Issue
Block a user