mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 11:49:02 +08:00
Improved the contraction mapper to properly support tensor products
This commit is contained in:
parent
0bc020be9d
commit
3a2dd352ae
@ -130,19 +130,19 @@ class SimpleTensorContractionMapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Index contract_val = left ? col : row;
|
Index contract_val = left ? col : row;
|
||||||
for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
|
|
||||||
const Index idx = contract_val / m_k_strides[i];
|
|
||||||
linidx += idx * m_contract_strides[i];
|
|
||||||
contract_val -= idx * m_k_strides[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
if(array_size<contract_t>::value > 0) {
|
if(array_size<contract_t>::value > 0) {
|
||||||
if (side == Rhs && inner_dim_contiguous) {
|
for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
|
||||||
eigen_assert(m_contract_strides[0] == 1);
|
const Index idx = contract_val / m_k_strides[i];
|
||||||
linidx += contract_val;
|
linidx += idx * m_contract_strides[i];
|
||||||
} else {
|
contract_val -= idx * m_k_strides[i];
|
||||||
linidx += contract_val * m_contract_strides[0];
|
}
|
||||||
}
|
|
||||||
|
if (side == Rhs && inner_dim_contiguous) {
|
||||||
|
eigen_assert(m_contract_strides[0] == 1);
|
||||||
|
linidx += contract_val;
|
||||||
|
} else {
|
||||||
|
linidx += contract_val * m_contract_strides[0];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return linidx;
|
return linidx;
|
||||||
@ -153,15 +153,15 @@ class SimpleTensorContractionMapper {
|
|||||||
const bool left = (side == Lhs);
|
const bool left = (side == Lhs);
|
||||||
Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
|
Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
|
||||||
Index linidx[2] = {0, 0};
|
Index linidx[2] = {0, 0};
|
||||||
for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
|
|
||||||
const Index idx0 = nocontract_val[0] / m_ij_strides[i];
|
|
||||||
const Index idx1 = nocontract_val[1] / m_ij_strides[i];
|
|
||||||
linidx[0] += idx0 * m_nocontract_strides[i];
|
|
||||||
linidx[1] += idx1 * m_nocontract_strides[i];
|
|
||||||
nocontract_val[0] -= idx0 * m_ij_strides[i];
|
|
||||||
nocontract_val[1] -= idx1 * m_ij_strides[i];
|
|
||||||
}
|
|
||||||
if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
|
if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
|
||||||
|
for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
|
||||||
|
const Index idx0 = nocontract_val[0] / m_ij_strides[i];
|
||||||
|
const Index idx1 = nocontract_val[1] / m_ij_strides[i];
|
||||||
|
linidx[0] += idx0 * m_nocontract_strides[i];
|
||||||
|
linidx[1] += idx1 * m_nocontract_strides[i];
|
||||||
|
nocontract_val[0] -= idx0 * m_ij_strides[i];
|
||||||
|
nocontract_val[1] -= idx1 * m_ij_strides[i];
|
||||||
|
}
|
||||||
if (side == Lhs && inner_dim_contiguous) {
|
if (side == Lhs && inner_dim_contiguous) {
|
||||||
eigen_assert(m_nocontract_strides[0] == 1);
|
eigen_assert(m_nocontract_strides[0] == 1);
|
||||||
linidx[0] += nocontract_val[0];
|
linidx[0] += nocontract_val[0];
|
||||||
@ -173,22 +173,24 @@ class SimpleTensorContractionMapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Index contract_val[2] = {left ? col : row, left ? col : row + distance};
|
Index contract_val[2] = {left ? col : row, left ? col : row + distance};
|
||||||
for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
|
if (array_size<contract_t>::value> 0) {
|
||||||
const Index idx0 = contract_val[0] / m_k_strides[i];
|
for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
|
||||||
const Index idx1 = contract_val[1] / m_k_strides[i];
|
const Index idx0 = contract_val[0] / m_k_strides[i];
|
||||||
linidx[0] += idx0 * m_contract_strides[i];
|
const Index idx1 = contract_val[1] / m_k_strides[i];
|
||||||
linidx[1] += idx1 * m_contract_strides[i];
|
linidx[0] += idx0 * m_contract_strides[i];
|
||||||
contract_val[0] -= idx0 * m_k_strides[i];
|
linidx[1] += idx1 * m_contract_strides[i];
|
||||||
contract_val[1] -= idx1 * m_k_strides[i];
|
contract_val[0] -= idx0 * m_k_strides[i];
|
||||||
}
|
contract_val[1] -= idx1 * m_k_strides[i];
|
||||||
|
}
|
||||||
|
|
||||||
if (side == Rhs && inner_dim_contiguous) {
|
if (side == Rhs && inner_dim_contiguous) {
|
||||||
eigen_assert(m_contract_strides[0] == 1);
|
eigen_assert(m_contract_strides[0] == 1);
|
||||||
linidx[0] += contract_val[0];
|
linidx[0] += contract_val[0];
|
||||||
linidx[1] += contract_val[1];
|
linidx[1] += contract_val[1];
|
||||||
} else {
|
} else {
|
||||||
linidx[0] += contract_val[0] * m_contract_strides[0];
|
linidx[0] += contract_val[0] * m_contract_strides[0];
|
||||||
linidx[1] += contract_val[1] * m_contract_strides[0];
|
linidx[1] += contract_val[1] * m_contract_strides[0];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return IndexPair<Index>(linidx[0], linidx[1]);
|
return IndexPair<Index>(linidx[0], linidx[1]);
|
||||||
}
|
}
|
||||||
@ -200,7 +202,7 @@ class SimpleTensorContractionMapper {
|
|||||||
return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size;
|
return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size;
|
||||||
}
|
}
|
||||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const {
|
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const {
|
||||||
return ((side == Lhs) && inner_dim_contiguous) ? m_contract_strides[0] : 1;
|
return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user