mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-12 03:39:01 +08:00
Improved the contraction mapper to properly support tensor products
This commit is contained in:
parent
0bc020be9d
commit
3a2dd352ae
@ -130,19 +130,19 @@ class SimpleTensorContractionMapper {
|
||||
}
|
||||
|
||||
Index contract_val = left ? col : row;
|
||||
for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
|
||||
const Index idx = contract_val / m_k_strides[i];
|
||||
linidx += idx * m_contract_strides[i];
|
||||
contract_val -= idx * m_k_strides[i];
|
||||
}
|
||||
|
||||
if(array_size<contract_t>::value > 0) {
|
||||
if (side == Rhs && inner_dim_contiguous) {
|
||||
eigen_assert(m_contract_strides[0] == 1);
|
||||
linidx += contract_val;
|
||||
} else {
|
||||
linidx += contract_val * m_contract_strides[0];
|
||||
}
|
||||
for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
|
||||
const Index idx = contract_val / m_k_strides[i];
|
||||
linidx += idx * m_contract_strides[i];
|
||||
contract_val -= idx * m_k_strides[i];
|
||||
}
|
||||
|
||||
if (side == Rhs && inner_dim_contiguous) {
|
||||
eigen_assert(m_contract_strides[0] == 1);
|
||||
linidx += contract_val;
|
||||
} else {
|
||||
linidx += contract_val * m_contract_strides[0];
|
||||
}
|
||||
}
|
||||
|
||||
return linidx;
|
||||
@ -153,15 +153,15 @@ class SimpleTensorContractionMapper {
|
||||
const bool left = (side == Lhs);
|
||||
Index nocontract_val[2] = {left ? row : col, left ? row + distance : col};
|
||||
Index linidx[2] = {0, 0};
|
||||
for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
|
||||
const Index idx0 = nocontract_val[0] / m_ij_strides[i];
|
||||
const Index idx1 = nocontract_val[1] / m_ij_strides[i];
|
||||
linidx[0] += idx0 * m_nocontract_strides[i];
|
||||
linidx[1] += idx1 * m_nocontract_strides[i];
|
||||
nocontract_val[0] -= idx0 * m_ij_strides[i];
|
||||
nocontract_val[1] -= idx1 * m_ij_strides[i];
|
||||
}
|
||||
if (array_size<typename Tensor::Dimensions>::value > array_size<contract_t>::value) {
|
||||
for (int i = static_cast<int>(array_size<nocontract_t>::value) - 1; i > 0; i--) {
|
||||
const Index idx0 = nocontract_val[0] / m_ij_strides[i];
|
||||
const Index idx1 = nocontract_val[1] / m_ij_strides[i];
|
||||
linidx[0] += idx0 * m_nocontract_strides[i];
|
||||
linidx[1] += idx1 * m_nocontract_strides[i];
|
||||
nocontract_val[0] -= idx0 * m_ij_strides[i];
|
||||
nocontract_val[1] -= idx1 * m_ij_strides[i];
|
||||
}
|
||||
if (side == Lhs && inner_dim_contiguous) {
|
||||
eigen_assert(m_nocontract_strides[0] == 1);
|
||||
linidx[0] += nocontract_val[0];
|
||||
@ -173,22 +173,24 @@ class SimpleTensorContractionMapper {
|
||||
}
|
||||
|
||||
Index contract_val[2] = {left ? col : row, left ? col : row + distance};
|
||||
for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
|
||||
const Index idx0 = contract_val[0] / m_k_strides[i];
|
||||
const Index idx1 = contract_val[1] / m_k_strides[i];
|
||||
linidx[0] += idx0 * m_contract_strides[i];
|
||||
linidx[1] += idx1 * m_contract_strides[i];
|
||||
contract_val[0] -= idx0 * m_k_strides[i];
|
||||
contract_val[1] -= idx1 * m_k_strides[i];
|
||||
}
|
||||
if (array_size<contract_t>::value> 0) {
|
||||
for (int i = static_cast<int>(array_size<contract_t>::value) - 1; i > 0; i--) {
|
||||
const Index idx0 = contract_val[0] / m_k_strides[i];
|
||||
const Index idx1 = contract_val[1] / m_k_strides[i];
|
||||
linidx[0] += idx0 * m_contract_strides[i];
|
||||
linidx[1] += idx1 * m_contract_strides[i];
|
||||
contract_val[0] -= idx0 * m_k_strides[i];
|
||||
contract_val[1] -= idx1 * m_k_strides[i];
|
||||
}
|
||||
|
||||
if (side == Rhs && inner_dim_contiguous) {
|
||||
eigen_assert(m_contract_strides[0] == 1);
|
||||
linidx[0] += contract_val[0];
|
||||
linidx[1] += contract_val[1];
|
||||
} else {
|
||||
linidx[0] += contract_val[0] * m_contract_strides[0];
|
||||
linidx[1] += contract_val[1] * m_contract_strides[0];
|
||||
if (side == Rhs && inner_dim_contiguous) {
|
||||
eigen_assert(m_contract_strides[0] == 1);
|
||||
linidx[0] += contract_val[0];
|
||||
linidx[1] += contract_val[1];
|
||||
} else {
|
||||
linidx[0] += contract_val[0] * m_contract_strides[0];
|
||||
linidx[1] += contract_val[1] * m_contract_strides[0];
|
||||
}
|
||||
}
|
||||
return IndexPair<Index>(linidx[0], linidx[1]);
|
||||
}
|
||||
@ -200,7 +202,7 @@ class SimpleTensorContractionMapper {
|
||||
return (Alignment == Aligned) && (side == Lhs) && inner_dim_contiguous ? 0 : size;
|
||||
}
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index stride() const {
|
||||
return ((side == Lhs) && inner_dim_contiguous) ? m_contract_strides[0] : 1;
|
||||
return ((side == Lhs) && inner_dim_contiguous && array_size<contract_t>::value > 0) ? m_contract_strides[0] : 1;
|
||||
}
|
||||
|
||||
protected:
|
||||
|
Loading…
x
Reference in New Issue
Block a user