diff --git a/Eigen/src/Core/Product.h b/Eigen/src/Core/Product.h index 44fde3dcf..8bd6af1b8 100644 --- a/Eigen/src/Core/Product.h +++ b/Eigen/src/Core/Product.h @@ -129,6 +129,19 @@ template struct ei_product_factor_traits struct ei_product_factor_traits > + : ei_product_factor_traits +{ + typedef typename NestedXpr::Scalar Scalar; + typedef ei_product_factor_traits Base; + typedef NestByValue XprType; + typedef typename Base::ActualXprType ActualXprType; + static inline const ActualXprType& extract(const XprType& x) { return Base::extract(static_cast(x)); } + static inline Scalar extractScalarFactor(const XprType& x) + { return Base::extractScalarFactor(static_cast(x)); } +}; + /* Helper class to determine the type of the product, can be either: * - NormalProduct * - CacheFriendlyProduct diff --git a/Eigen/src/Core/SolveTriangular.h b/Eigen/src/Core/SolveTriangular.h index f3567c96c..b28078fa1 100644 --- a/Eigen/src/Core/SolveTriangular.h +++ b/Eigen/src/Core/SolveTriangular.h @@ -43,123 +43,49 @@ struct ei_triangular_solver_selector0; IsLowerTriangular ? pi+=PanelWidth : pi-=PanelWidth) { int actualPanelWidth = std::min(IsLowerTriangular ? size - pi : pi, PanelWidth); - int startBlock = IsLowerTriangular ? pi : pi-actualPanelWidth; - int endBlock = IsLowerTriangular ? pi + actualPanelWidth : 0; - if (pi > 0) + int r = IsLowerTriangular ? pi : size - pi; // remaining size + if (r > 0) { - int r = IsLowerTriangular ? size - endBlock : startBlock; // remaining size - ei_cache_friendly_product_colmajor_times_vector( - r, - &(lhs.const_cast_derived().coeffRef(endBlock,startBlock)), lhs.stride(), - other.col(c).segment(startBlock, actualPanelWidth), - &(other.coeffRef(endBlock, c)), - Scalar(-1)); + int startRow = IsLowerTriangular ? pi : pi-actualPanelWidth; + int startCol = IsLowerTriangular ? 0 : pi; +// Block target(other,startRow,c,actualPanelWidth,1); + +// ei_cache_friendly_product_rowmajor_times_vector( +// &(lhs.const_cast_derived().coeffRef(startRow,startCol)), lhs.stride(), +// &(other.coeffRef(startCol, c)), r, +// target, Scalar(-1)); + other.col(c).segment(startRow,actualPanelWidth) -= + lhs.block(startRow,startCol,actualPanelWidth,r) + * other.col(c).segment(startCol,r); } for(int k=0; k0) + other.coeffRef(i,c) -= ((lhs.row(i).segment(s,k).transpose()) + .cwise()*(other.col(c).segment(s,k))).sum(); + if(!(Mode & UnitDiagBit)) other.coeffRef(i,c) /= lhs.coeff(i,i); - - int r = actualPanelWidth - k - 1; // remaining size - if (r>0) - { - other.col(c).segment((IsLowerTriangular ? i+1 : i-r), r) -= - other.coeffRef(i,c) - * Block(lhs, (IsLowerTriangular ? i+1 : i-r), i, r, 1); - } } + } } - #else - const bool IsLowerTriangular = (UpLo==LowerTriangular); - const int size = lhs.cols(); - /* We perform the inverse product per block of 4 rows such that we perfectly match - * our optimized matrix * vector product. blockyStart represents the number of rows - * we have process first using the non-block version. - */ - int blockyStart = (std::max(size-5,0)/4)*4; - if (IsLowerTriangular) - blockyStart = size - blockyStart; - else - blockyStart -= 1; - for(int c=0 ; cblockyStart; i += (IsLowerTriangular ? 1 : -1) ) - { - Scalar tmp = other.coeff(i,c) - - (IsLowerTriangular ? ((lhs.row(i).start(i)) * other.col(c).start(i)).coeff(0,0) - : ((lhs.row(i).end(size-i-1)) * other.col(c).end(size-i-1)).coeff(0,0)); - if (Mode & UnitDiagBit) - other.coeffRef(i,c) = tmp; - else - other.coeffRef(i,c) = tmp/lhs.coeff(i,i); - } - - // now let's process the remaining rows 4 at once - for(int i=blockyStart; IsLowerTriangular ? i0; ) - { - int startBlock = i; - int endBlock = startBlock + (IsLowerTriangular ? 4 : -4); - - /* Process the i cols times 4 rows block, and keep the result in a temporary vector */ - // FIXME use fixed size block but take care to small fixed size matrices... - Matrix btmp(4); - if (IsLowerTriangular) - btmp = lhs.block(startBlock,0,4,i) * other.col(c).start(i); - else - btmp = lhs.block(i-3,i+1,4,size-1-i) * other.col(c).end(size-1-i); - - /* Let's process the 4x4 sub-matrix as usual. - * btmp stores the diagonal coefficients used to update the remaining part of the result. - */ - { - Scalar tmp = other.coeff(startBlock,c)-btmp.coeff(IsLowerTriangular?0:3); - if (Mode & UnitDiagBit) - other.coeffRef(i,c) = tmp; - else - other.coeffRef(i,c) = tmp/lhs.coeff(i,i); - } - - i += IsLowerTriangular ? 1 : -1; - for (;IsLowerTriangular ? iendBlock; i += IsLowerTriangular ? 1 : -1) - { - int remainingSize = IsLowerTriangular ? i-startBlock : startBlock-i; - Scalar tmp = other.coeff(i,c) - - btmp.coeff(IsLowerTriangular ? remainingSize : 3-remainingSize) - - ( lhs.row(i).segment(IsLowerTriangular ? startBlock : i+1, remainingSize) - * other.col(c).segment(IsLowerTriangular ? startBlock : i+1, remainingSize)).coeff(0,0); - - if (Mode & UnitDiagBit) - other.coeffRef(i,c) = tmp; - else - other.coeffRef(i,c) = tmp/lhs.coeff(i,i); - } - } - } - #endif } }; @@ -168,15 +94,15 @@ struct ei_triangular_solver_selector -struct ei_triangular_solver_selector +template +struct ei_triangular_solver_selector { typedef typename Rhs::Scalar Scalar; typedef typename ei_packet_traits::type Packet; enum { PacketSize = ei_packet_traits::size }; static void run(const Lhs& lhs, Rhs& other) - { + {//std::cerr << "col maj " << ConjugateLhs << " , " << ConjugateRhs << "\n"; static const int PanelWidth = 4; // TODO make this a user definable constant static const bool IsLowerTriangular = (UpLo==LowerTriangular); const int size = lhs.cols(); @@ -207,12 +133,16 @@ struct ei_triangular_solver_selector int r = IsLowerTriangular ? size - endBlock : startBlock; // remaining size if (r > 0) { - ei_cache_friendly_product_colmajor_times_vector( - r, - &(lhs.const_cast_derived().coeffRef(endBlock,startBlock)), lhs.stride(), - other.col(c).segment(startBlock, actualPanelWidth), - &(other.coeffRef(endBlock, c)), - Scalar(-1)); +// ei_cache_friendly_product_colmajor_times_vector( +// r, +// &(lhs.const_cast_derived().coeffRef(endBlock,startBlock)), lhs.stride(), +// other.col(c).segment(startBlock, actualPanelWidth), +// &(other.coeffRef(endBlock, c)), +// Scalar(-1)); + + other.col(c).segment(endBlock,r) -= + lhs.block(endBlock,startBlock,r,actualPanelWidth) + * other.col(c).segment(startBlock,actualPanelWidth); } } } @@ -238,13 +168,21 @@ void TriangularView::solveInPlace(const MatrixBase& ei_assert(!(Mode & ZeroDiagBit)); ei_assert(Mode & (UpperTriangularBit|LowerTriangularBit)); - enum { copy = ei_traits::Flags & RowMajorBit }; +// typedef ei_product_factor_traits LhsProductTraits; +// typedef ei_product_factor_traits RhsProductTraits; +// typedef typename LhsProductTraits::ActualXprType ActualLhsType; +// typedef typename RhsProductTraits::ActualXprType ActualRhsType; +// const ActualLhsType& actualLhs = LhsProductTraits::extract(_expression()); +// ActualRhsType& actualRhs = const_cast(RhsProductTraits::extract(rhs)); + enum { copy = ei_traits::Flags & RowMajorBit }; +// std::cerr << typeid(MatrixType).name() << "\n"; typedef typename ei_meta_if::type, RhsDerived&>::ret RhsCopy; RhsCopy rhsCopy(rhs); - ei_triangular_solver_selector::type, Mode>::run(_expression(), rhsCopy); + ei_triangular_solver_selector::type, + Mode/*, LhsProductTraits::NeedToConjugate,RhsProductTraits::NeedToConjugate*/>::run(_expression(), rhsCopy); if (copy) rhs = rhsCopy; diff --git a/Eigen/src/Core/products/GeneralMatrixVector.h b/Eigen/src/Core/products/GeneralMatrixVector.h index 61b2cc67c..ccaafb8bd 100644 --- a/Eigen/src/Core/products/GeneralMatrixVector.h +++ b/Eigen/src/Core/products/GeneralMatrixVector.h @@ -307,8 +307,11 @@ static EIGEN_DONT_INLINE void ei_cache_friendly_product_rowmajor_times_vector( skipRows = std::min(skipRows,res.size()); // note that the skiped columns are processed later. } - ei_internal_assert((alignmentPattern==NoneAligned) || PacketSize==1 - || (size_t(lhs+alignedStart+lhsStride*skipRows)%sizeof(Packet))==0); + ei_internal_assert( alignmentPattern==NoneAligned + || PacketSize==1 + || (skipRows + rowsAtOnce >= res.size()) + || PacketSize > rhsSize + || (size_t(lhs+alignedStart+lhsStride*skipRows)%sizeof(Packet))==0); } int offset1 = (FirstAligned && alignmentStep==1?3:1);