mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
bug #1517: fix triangular product with unit diagonal and nested scaling factor: (s*A).triangularView<UpperUnit>()*B
This commit is contained in:
parent
12efc7d41b
commit
5deeb19e7b
@ -400,7 +400,9 @@ struct triangular_product_impl<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
|
||||
{
|
||||
template<typename Dest> static void run(Dest& dst, const Lhs &a_lhs, const Rhs &a_rhs, const typename Dest::Scalar& alpha)
|
||||
{
|
||||
typedef typename Dest::Scalar Scalar;
|
||||
typedef typename Lhs::Scalar LhsScalar;
|
||||
typedef typename Rhs::Scalar RhsScalar;
|
||||
typedef typename Dest::Scalar Scalar;
|
||||
|
||||
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhsType;
|
||||
@ -412,8 +414,9 @@ struct triangular_product_impl<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
|
||||
typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(a_lhs);
|
||||
typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(a_rhs);
|
||||
|
||||
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(a_lhs)
|
||||
* RhsBlasTraits::extractScalarFactor(a_rhs);
|
||||
LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(a_lhs);
|
||||
RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(a_rhs);
|
||||
Scalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
|
||||
|
||||
typedef internal::gemm_blocking_space<(Dest::Flags&RowMajorBit) ? RowMajor : ColMajor,Scalar,Scalar,
|
||||
Lhs::MaxRowsAtCompileTime, Rhs::MaxColsAtCompileTime, Lhs::MaxColsAtCompileTime,4> BlockingType;
|
||||
@ -438,6 +441,21 @@ struct triangular_product_impl<Mode,LhsIsTriangular,Lhs,false,Rhs,false>
|
||||
&dst.coeffRef(0,0), dst.outerStride(), // result info
|
||||
actualAlpha, blocking
|
||||
);
|
||||
|
||||
// Apply correction if the diagonal is unit and a scalar factor was nested:
|
||||
if ((Mode&UnitDiag)==UnitDiag)
|
||||
{
|
||||
if (LhsIsTriangular && lhs_alpha!=LhsScalar(1))
|
||||
{
|
||||
Index diagSize = (std::min)(lhs.rows(),lhs.cols());
|
||||
dst.topRows(diagSize) -= ((lhs_alpha-LhsScalar(1))*a_rhs).topRows(diagSize);
|
||||
}
|
||||
else if ((!LhsIsTriangular) && rhs_alpha!=RhsScalar(1))
|
||||
{
|
||||
Index diagSize = (std::min)(rhs.rows(),rhs.cols());
|
||||
dst.leftCols(diagSize) -= (rhs_alpha-RhsScalar(1))*a_lhs.leftCols(diagSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -221,8 +221,9 @@ template<int Mode> struct trmv_selector<Mode,ColMajor>
|
||||
typename internal::add_const_on_value_type<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
|
||||
typename internal::add_const_on_value_type<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
|
||||
|
||||
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
|
||||
* RhsBlasTraits::extractScalarFactor(rhs);
|
||||
LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
|
||||
RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
|
||||
ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
|
||||
|
||||
enum {
|
||||
// FIXME find a way to allow an inner stride on the result if packet_traits<Scalar>::size==1
|
||||
@ -274,6 +275,12 @@ template<int Mode> struct trmv_selector<Mode,ColMajor>
|
||||
else
|
||||
dest = MappedDest(actualDestPtr, dest.size());
|
||||
}
|
||||
|
||||
if ( ((Mode&UnitDiag)==UnitDiag) && (lhs_alpha!=LhsScalar(1)) )
|
||||
{
|
||||
Index diagSize = (std::min)(lhs.rows(),lhs.cols());
|
||||
dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@ -295,8 +302,9 @@ template<int Mode> struct trmv_selector<Mode,RowMajor>
|
||||
typename add_const<ActualLhsType>::type actualLhs = LhsBlasTraits::extract(lhs);
|
||||
typename add_const<ActualRhsType>::type actualRhs = RhsBlasTraits::extract(rhs);
|
||||
|
||||
ResScalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(lhs)
|
||||
* RhsBlasTraits::extractScalarFactor(rhs);
|
||||
LhsScalar lhs_alpha = LhsBlasTraits::extractScalarFactor(lhs);
|
||||
RhsScalar rhs_alpha = RhsBlasTraits::extractScalarFactor(rhs);
|
||||
ResScalar actualAlpha = alpha * lhs_alpha * rhs_alpha;
|
||||
|
||||
enum {
|
||||
DirectlyUseRhs = ActualRhsTypeCleaned::InnerStrideAtCompileTime==1
|
||||
@ -326,6 +334,12 @@ template<int Mode> struct trmv_selector<Mode,RowMajor>
|
||||
actualRhsPtr,1,
|
||||
dest.data(),dest.innerStride(),
|
||||
actualAlpha);
|
||||
|
||||
if ( ((Mode&UnitDiag)==UnitDiag) && (lhs_alpha!=LhsScalar(1)) )
|
||||
{
|
||||
Index diagSize = (std::min)(lhs.rows(),lhs.cols());
|
||||
dest.head(diagSize) -= (lhs_alpha-LhsScalar(1))*rhs.head(diagSize);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -29,7 +29,7 @@ void trmm(int rows=get_random_size<Scalar>(),
|
||||
typedef Matrix<Scalar,Dynamic,OtherCols,OtherCols==1?ColMajor:ResOrder> ResXS;
|
||||
typedef Matrix<Scalar,OtherCols,Dynamic,OtherCols==1?RowMajor:ResOrder> ResSX;
|
||||
|
||||
TriMatrix mat(rows,cols), tri(rows,cols), triTr(cols,rows);
|
||||
TriMatrix mat(rows,cols), tri(rows,cols), triTr(cols,rows), s1tri(rows,cols), s1triTr(cols,rows);
|
||||
|
||||
OnTheRight ge_right(cols,otherCols);
|
||||
OnTheLeft ge_left(otherCols,rows);
|
||||
@ -42,6 +42,8 @@ void trmm(int rows=get_random_size<Scalar>(),
|
||||
mat.setRandom();
|
||||
tri = mat.template triangularView<Mode>();
|
||||
triTr = mat.transpose().template triangularView<Mode>();
|
||||
s1tri = (s1*mat).template triangularView<Mode>();
|
||||
s1triTr = (s1*mat).transpose().template triangularView<Mode>();
|
||||
ge_right.setRandom();
|
||||
ge_left.setRandom();
|
||||
|
||||
@ -51,19 +53,29 @@ void trmm(int rows=get_random_size<Scalar>(),
|
||||
VERIFY_IS_APPROX( ge_xs.noalias() = mat.template triangularView<Mode>() * ge_right, tri * ge_right);
|
||||
VERIFY_IS_APPROX( ge_sx.noalias() = ge_left * mat.template triangularView<Mode>(), ge_left * tri);
|
||||
|
||||
VERIFY_IS_APPROX( ge_xs.noalias() = (s1*mat.adjoint()).template triangularView<Mode>() * (s2*ge_left.transpose()), s1*triTr.conjugate() * (s2*ge_left.transpose()));
|
||||
VERIFY_IS_APPROX( ge_sx.noalias() = ge_right.transpose() * mat.adjoint().template triangularView<Mode>(), ge_right.transpose() * triTr.conjugate());
|
||||
if((Mode&UnitDiag)==0)
|
||||
VERIFY_IS_APPROX( ge_xs.noalias() = (s1*mat.adjoint()).template triangularView<Mode>() * (s2*ge_left.transpose()), s1*triTr.conjugate() * (s2*ge_left.transpose()));
|
||||
|
||||
VERIFY_IS_APPROX( ge_xs.noalias() = (s1*mat.adjoint()).template triangularView<Mode>() * (s2*ge_left.adjoint()), s1*triTr.conjugate() * (s2*ge_left.adjoint()));
|
||||
VERIFY_IS_APPROX( ge_sx.noalias() = ge_right.adjoint() * mat.adjoint().template triangularView<Mode>(), ge_right.adjoint() * triTr.conjugate());
|
||||
VERIFY_IS_APPROX( ge_xs.noalias() = (s1*mat.transpose()).template triangularView<Mode>() * (s2*ge_left.transpose()), s1triTr * (s2*ge_left.transpose()));
|
||||
VERIFY_IS_APPROX( ge_sx.noalias() = (s2*ge_left) * (s1*mat).template triangularView<Mode>(), (s2*ge_left)*s1tri);
|
||||
|
||||
VERIFY_IS_APPROX( ge_sx.noalias() = ge_right.transpose() * mat.adjoint().template triangularView<Mode>(), ge_right.transpose() * triTr.conjugate());
|
||||
VERIFY_IS_APPROX( ge_sx.noalias() = ge_right.adjoint() * mat.adjoint().template triangularView<Mode>(), ge_right.adjoint() * triTr.conjugate());
|
||||
|
||||
ge_xs_save = ge_xs;
|
||||
VERIFY_IS_APPROX( (ge_xs_save + s1*triTr.conjugate() * (s2*ge_left.adjoint())).eval(), ge_xs.noalias() += (s1*mat.adjoint()).template triangularView<Mode>() * (s2*ge_left.adjoint()) );
|
||||
if((Mode&UnitDiag)==0)
|
||||
VERIFY_IS_APPROX( (ge_xs_save + s1*triTr.conjugate() * (s2*ge_left.adjoint())).eval(), ge_xs.noalias() += (s1*mat.adjoint()).template triangularView<Mode>() * (s2*ge_left.adjoint()) );
|
||||
ge_xs_save = ge_xs;
|
||||
VERIFY_IS_APPROX( (ge_xs_save + s1triTr * (s2*ge_left.adjoint())).eval(), ge_xs.noalias() += (s1*mat.transpose()).template triangularView<Mode>() * (s2*ge_left.adjoint()) );
|
||||
ge_sx.setRandom();
|
||||
ge_sx_save = ge_sx;
|
||||
VERIFY_IS_APPROX( ge_sx_save - (ge_right.adjoint() * (-s1 * triTr).conjugate()).eval(), ge_sx.noalias() -= (ge_right.adjoint() * (-s1 * mat).adjoint().template triangularView<Mode>()).eval());
|
||||
if((Mode&UnitDiag)==0)
|
||||
VERIFY_IS_APPROX( ge_sx_save - (ge_right.adjoint() * (-s1 * triTr).conjugate()).eval(), ge_sx.noalias() -= (ge_right.adjoint() * (-s1 * mat).adjoint().template triangularView<Mode>()).eval());
|
||||
|
||||
VERIFY_IS_APPROX( ge_xs = (s1*mat).adjoint().template triangularView<Mode>() * ge_left.adjoint(), numext::conj(s1) * triTr.conjugate() * ge_left.adjoint());
|
||||
if((Mode&UnitDiag)==0)
|
||||
VERIFY_IS_APPROX( ge_xs = (s1*mat).adjoint().template triangularView<Mode>() * ge_left.adjoint(), numext::conj(s1) * triTr.conjugate() * ge_left.adjoint());
|
||||
VERIFY_IS_APPROX( ge_xs = (s1*mat).transpose().template triangularView<Mode>() * ge_left.adjoint(), s1triTr * ge_left.adjoint());
|
||||
|
||||
|
||||
// TODO check with sub-matrix expressions ?
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user