mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 02:33:59 +08:00
Fix bug #496: generalize internal rank1_update implementation to accept uplo(A) += v * w and make A.triangularView() += v * w uses it.
Update unit tests and blas interface respectively.
This commit is contained in:
parent
08388cc712
commit
04367447ac
@ -12,6 +12,9 @@
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjLhs, bool ConjRhs>
|
||||
struct selfadjoint_rank1_update;
|
||||
|
||||
namespace internal {
|
||||
|
||||
/**********************************************************************
|
||||
@ -180,31 +183,93 @@ struct tribb_kernel
|
||||
|
||||
// high level API
|
||||
|
||||
template<typename MatrixType, typename ProductType, int UpLo, bool IsOuterProduct>
|
||||
struct general_product_to_triangular_selector;
|
||||
|
||||
|
||||
template<typename MatrixType, typename ProductType, int UpLo>
|
||||
struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,true>
|
||||
{
|
||||
static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha)
|
||||
{
|
||||
typedef typename MatrixType::Scalar Scalar;
|
||||
typedef typename MatrixType::Index Index;
|
||||
|
||||
typedef typename internal::remove_all<typename ProductType::LhsNested>::type Lhs;
|
||||
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
|
||||
typedef typename internal::remove_all<ActualLhs>::type _ActualLhs;
|
||||
typename internal::add_const_on_value_type<ActualLhs>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
|
||||
|
||||
typedef typename internal::remove_all<typename ProductType::RhsNested>::type Rhs;
|
||||
typedef internal::blas_traits<Rhs> RhsBlasTraits;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
|
||||
typedef typename internal::remove_all<ActualRhs>::type _ActualRhs;
|
||||
typename internal::add_const_on_value_type<ActualRhs>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
|
||||
|
||||
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
|
||||
|
||||
enum {
|
||||
StorageOrder = (internal::traits<MatrixType>::Flags&RowMajorBit) ? RowMajor : ColMajor,
|
||||
UseLhsDirectly = _ActualLhs::InnerStrideAtCompileTime==1,
|
||||
UseRhsDirectly = _ActualRhs::InnerStrideAtCompileTime==1
|
||||
};
|
||||
|
||||
internal::gemv_static_vector_if<Scalar,Lhs::SizeAtCompileTime,Lhs::MaxSizeAtCompileTime,!UseLhsDirectly> static_lhs;
|
||||
ei_declare_aligned_stack_constructed_variable(Scalar, actualLhsPtr, actualLhs.size(),
|
||||
(UseLhsDirectly ? const_cast<Scalar*>(actualLhs.data()) : static_lhs.data()));
|
||||
if(!UseLhsDirectly) Map<typename _ActualLhs::PlainObject>(actualLhsPtr, actualLhs.size()) = actualLhs;
|
||||
|
||||
internal::gemv_static_vector_if<Scalar,Rhs::SizeAtCompileTime,Rhs::MaxSizeAtCompileTime,!UseRhsDirectly> static_rhs;
|
||||
ei_declare_aligned_stack_constructed_variable(Scalar, actualRhsPtr, actualRhs.size(),
|
||||
(UseRhsDirectly ? const_cast<Scalar*>(actualRhs.data()) : static_rhs.data()));
|
||||
if(!UseRhsDirectly) Map<typename _ActualRhs::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
|
||||
|
||||
|
||||
selfadjoint_rank1_update<Scalar,Index,StorageOrder,UpLo,
|
||||
LhsBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
|
||||
RhsBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex>
|
||||
::run(actualLhs.size(), mat.data(), mat.outerStride(), actualLhsPtr, actualRhsPtr, actualAlpha);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename MatrixType, typename ProductType, int UpLo>
|
||||
struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false>
|
||||
{
|
||||
static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha)
|
||||
{
|
||||
typedef typename MatrixType::Scalar Scalar;
|
||||
typedef typename MatrixType::Index Index;
|
||||
|
||||
typedef typename internal::remove_all<typename ProductType::LhsNested>::type Lhs;
|
||||
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
|
||||
typedef typename internal::remove_all<ActualLhs>::type _ActualLhs;
|
||||
typename internal::add_const_on_value_type<ActualLhs>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
|
||||
|
||||
typedef typename internal::remove_all<typename ProductType::RhsNested>::type Rhs;
|
||||
typedef internal::blas_traits<Rhs> RhsBlasTraits;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
|
||||
typedef typename internal::remove_all<ActualRhs>::type _ActualRhs;
|
||||
typename internal::add_const_on_value_type<ActualRhs>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
|
||||
|
||||
typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
|
||||
|
||||
internal::general_matrix_matrix_triangular_product<Index,
|
||||
typename Lhs::Scalar, _ActualLhs::Flags&RowMajorBit ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
|
||||
typename Rhs::Scalar, _ActualRhs::Flags&RowMajorBit ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
|
||||
MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo>
|
||||
::run(mat.cols(), actualLhs.cols(),
|
||||
&actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(),
|
||||
mat.data(), mat.outerStride(), actualAlpha);
|
||||
}
|
||||
};
|
||||
|
||||
template<typename MatrixType, unsigned int UpLo>
|
||||
template<typename ProductDerived, typename _Lhs, typename _Rhs>
|
||||
TriangularView<MatrixType,UpLo>& TriangularView<MatrixType,UpLo>::assignProduct(const ProductBase<ProductDerived, _Lhs,_Rhs>& prod, const Scalar& alpha)
|
||||
{
|
||||
typedef typename internal::remove_all<typename ProductDerived::LhsNested>::type Lhs;
|
||||
typedef internal::blas_traits<Lhs> LhsBlasTraits;
|
||||
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
|
||||
typedef typename internal::remove_all<ActualLhs>::type _ActualLhs;
|
||||
typename internal::add_const_on_value_type<ActualLhs>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
|
||||
|
||||
typedef typename internal::remove_all<typename ProductDerived::RhsNested>::type Rhs;
|
||||
typedef internal::blas_traits<Rhs> RhsBlasTraits;
|
||||
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
|
||||
typedef typename internal::remove_all<ActualRhs>::type _ActualRhs;
|
||||
typename internal::add_const_on_value_type<ActualRhs>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
|
||||
|
||||
typename ProductDerived::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
|
||||
|
||||
internal::general_matrix_matrix_triangular_product<Index,
|
||||
typename Lhs::Scalar, _ActualLhs::Flags&RowMajorBit ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
|
||||
typename Rhs::Scalar, _ActualRhs::Flags&RowMajorBit ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
|
||||
MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo>
|
||||
::run(m_matrix.cols(), actualLhs.cols(),
|
||||
&actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(),
|
||||
const_cast<Scalar*>(m_matrix.data()), m_matrix.outerStride(), actualAlpha);
|
||||
general_product_to_triangular_selector<MatrixType, ProductDerived, UpLo, (_Lhs::ColsAtCompileTime==1) || (_Rhs::RowsAtCompileTime==1)>::run(m_matrix.const_cast_derived(), prod.derived(), alpha);
|
||||
|
||||
return *this;
|
||||
}
|
||||
|
@ -18,21 +18,19 @@
|
||||
|
||||
namespace Eigen {
|
||||
|
||||
template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjLhs, bool ConjRhs>
|
||||
struct selfadjoint_rank1_update;
|
||||
|
||||
template<typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs>
|
||||
struct selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo,ConjLhs,ConjRhs>
|
||||
{
|
||||
static void run(Index size, Scalar* mat, Index stride, const Scalar* vec, Scalar alpha)
|
||||
static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, Scalar alpha)
|
||||
{
|
||||
internal::conj_if<ConjRhs> cj;
|
||||
typedef Map<const Matrix<Scalar,Dynamic,1> > OtherMap;
|
||||
typedef typename internal::conditional<ConjLhs,typename OtherMap::ConjugateReturnType,const OtherMap&>::type ConjRhsType;
|
||||
typedef typename internal::conditional<ConjLhs,typename OtherMap::ConjugateReturnType,const OtherMap&>::type ConjLhsType;
|
||||
for (Index i=0; i<size; ++i)
|
||||
{
|
||||
Map<Matrix<Scalar,Dynamic,1> >(mat+stride*i+(UpLo==Lower ? i : 0), (UpLo==Lower ? size-i : (i+1)))
|
||||
+= (alpha * cj(vec[i])) * ConjRhsType(OtherMap(vec+(UpLo==Lower ? i : 0),UpLo==Lower ? size-i : (i+1)));
|
||||
+= (alpha * cj(vecY[i])) * ConjLhsType(OtherMap(vecX+(UpLo==Lower ? i : 0),UpLo==Lower ? size-i : (i+1)));
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -40,9 +38,9 @@ struct selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo,ConjLhs,ConjRhs>
|
||||
template<typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs>
|
||||
struct selfadjoint_rank1_update<Scalar,Index,RowMajor,UpLo,ConjLhs,ConjRhs>
|
||||
{
|
||||
static void run(Index size, Scalar* mat, Index stride, const Scalar* vec, Scalar alpha)
|
||||
static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, Scalar alpha)
|
||||
{
|
||||
selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo==Lower?Upper:Lower,ConjRhs,ConjLhs>::run(size,mat,stride,vec,alpha);
|
||||
selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo==Lower?Upper:Lower,ConjRhs,ConjLhs>::run(size,mat,stride,vecY,vecX,alpha);
|
||||
}
|
||||
};
|
||||
|
||||
@ -78,7 +76,7 @@ struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,true>
|
||||
selfadjoint_rank1_update<Scalar,Index,StorageOrder,UpLo,
|
||||
OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
|
||||
(!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex>
|
||||
::run(other.size(), mat.data(), mat.outerStride(), actualOtherPtr, actualAlpha);
|
||||
::run(other.size(), mat.data(), mat.outerStride(), actualOtherPtr, actualOtherPtr, actualAlpha);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -216,7 +216,7 @@ int EIGEN_BLAS_FUNC(hpr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px
|
||||
*/
|
||||
int EIGEN_BLAS_FUNC(her)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *pa, int *lda)
|
||||
{
|
||||
typedef void (*functype)(int, Scalar*, int, const Scalar*, Scalar);
|
||||
typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, Scalar);
|
||||
static functype func[2];
|
||||
|
||||
static bool init = false;
|
||||
@ -252,7 +252,7 @@ int EIGEN_BLAS_FUNC(her)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
|
||||
if(code>=2 || func[code]==0)
|
||||
return 0;
|
||||
|
||||
func[code](*n, a, *lda, x_cpy, alpha);
|
||||
func[code](*n, a, *lda, x_cpy, x_cpy, alpha);
|
||||
|
||||
matrix(a,*n,*n,*lda).diagonal().imag().setZero();
|
||||
|
||||
|
@ -85,7 +85,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
|
||||
|
||||
// init = true;
|
||||
// }
|
||||
typedef void (*functype)(int, Scalar*, int, const Scalar*, Scalar);
|
||||
typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, Scalar);
|
||||
static functype func[2];
|
||||
|
||||
static bool init = false;
|
||||
@ -121,7 +121,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
|
||||
if(code>=2 || func[code]==0)
|
||||
return 0;
|
||||
|
||||
func[code](*n, c, *ldc, x_cpy, alpha);
|
||||
func[code](*n, c, *ldc, x_cpy, x_cpy, alpha);
|
||||
|
||||
if(x_cpy!=x) delete[] x_cpy;
|
||||
|
||||
|
@ -14,6 +14,7 @@ template<typename MatrixType> void syrk(const MatrixType& m)
|
||||
typedef typename MatrixType::Index Index;
|
||||
typedef typename MatrixType::Scalar Scalar;
|
||||
typedef typename NumTraits<Scalar>::Real RealScalar;
|
||||
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, MatrixType::ColsAtCompileTime, RowMajor> RMatrixType;
|
||||
typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, Dynamic> Rhs1;
|
||||
typedef Matrix<Scalar, Dynamic, MatrixType::RowsAtCompileTime> Rhs2;
|
||||
typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, Dynamic,RowMajor> Rhs3;
|
||||
@ -22,10 +23,12 @@ template<typename MatrixType> void syrk(const MatrixType& m)
|
||||
Index cols = m.cols();
|
||||
|
||||
MatrixType m1 = MatrixType::Random(rows, cols),
|
||||
m2 = MatrixType::Random(rows, cols);
|
||||
m2 = MatrixType::Random(rows, cols),
|
||||
m3 = MatrixType::Random(rows, cols);
|
||||
RMatrixType rm2 = MatrixType::Random(rows, cols);
|
||||
|
||||
Rhs1 rhs1 = Rhs1::Random(internal::random<int>(1,320), cols);
|
||||
Rhs2 rhs2 = Rhs2::Random(rows, internal::random<int>(1,320));
|
||||
Rhs1 rhs1 = Rhs1::Random(internal::random<int>(1,320), cols); Rhs1 rhs11 = Rhs1::Random(rhs1.rows(), cols);
|
||||
Rhs2 rhs2 = Rhs2::Random(rows, internal::random<int>(1,320)); Rhs2 rhs22 = Rhs2::Random(rows, rhs2.cols());
|
||||
Rhs3 rhs3 = Rhs3::Random(internal::random<int>(1,320), rows);
|
||||
|
||||
Scalar s1 = internal::random<Scalar>();
|
||||
@ -35,19 +38,34 @@ template<typename MatrixType> void syrk(const MatrixType& m)
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX((m2.template selfadjointView<Lower>().rankUpdate(rhs2,s1)._expression()),
|
||||
((s1 * rhs2 * rhs2.adjoint()).eval().template triangularView<Lower>().toDenseMatrix()));
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX(((m2.template triangularView<Lower>() += s1 * rhs2 * rhs22.adjoint()).nestedExpression()),
|
||||
((s1 * rhs2 * rhs22.adjoint()).eval().template triangularView<Lower>().toDenseMatrix()));
|
||||
|
||||
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX(m2.template selfadjointView<Upper>().rankUpdate(rhs2,s1)._expression(),
|
||||
(s1 * rhs2 * rhs2.adjoint()).eval().template triangularView<Upper>().toDenseMatrix());
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX((m2.template triangularView<Upper>() += s1 * rhs22 * rhs2.adjoint()).nestedExpression(),
|
||||
(s1 * rhs22 * rhs2.adjoint()).eval().template triangularView<Upper>().toDenseMatrix());
|
||||
|
||||
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX(m2.template selfadjointView<Lower>().rankUpdate(rhs1.adjoint(),s1)._expression(),
|
||||
(s1 * rhs1.adjoint() * rhs1).eval().template triangularView<Lower>().toDenseMatrix());
|
||||
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX((m2.template triangularView<Lower>() += s1 * rhs11.adjoint() * rhs1).nestedExpression(),
|
||||
(s1 * rhs11.adjoint() * rhs1).eval().template triangularView<Lower>().toDenseMatrix());
|
||||
|
||||
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX(m2.template selfadjointView<Upper>().rankUpdate(rhs1.adjoint(),s1)._expression(),
|
||||
(s1 * rhs1.adjoint() * rhs1).eval().template triangularView<Upper>().toDenseMatrix());
|
||||
VERIFY_IS_APPROX((m2.template triangularView<Upper>() = s1 * rhs1.adjoint() * rhs11).nestedExpression(),
|
||||
(s1 * rhs1.adjoint() * rhs11).eval().template triangularView<Upper>().toDenseMatrix());
|
||||
|
||||
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX(m2.template selfadjointView<Lower>().rankUpdate(rhs3.adjoint(),s1)._expression(),
|
||||
(s1 * rhs3.adjoint() * rhs3).eval().template triangularView<Lower>().toDenseMatrix());
|
||||
@ -63,6 +81,15 @@ template<typename MatrixType> void syrk(const MatrixType& m)
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX((m2.template selfadjointView<Upper>().rankUpdate(m1.col(c),s1)._expression()),
|
||||
((s1 * m1.col(c) * m1.col(c).adjoint()).eval().template triangularView<Upper>().toDenseMatrix()));
|
||||
rm2.setZero();
|
||||
VERIFY_IS_APPROX((rm2.template selfadjointView<Upper>().rankUpdate(m1.col(c),s1)._expression()),
|
||||
((s1 * m1.col(c) * m1.col(c).adjoint()).eval().template triangularView<Upper>().toDenseMatrix()));
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX((m2.template triangularView<Upper>() += s1 * m3.col(c) * m1.col(c).adjoint()).nestedExpression(),
|
||||
((s1 * m3.col(c) * m1.col(c).adjoint()).eval().template triangularView<Upper>().toDenseMatrix()));
|
||||
rm2.setZero();
|
||||
VERIFY_IS_APPROX((rm2.template triangularView<Upper>() += s1 * m1.col(c) * m3.col(c).adjoint()).nestedExpression(),
|
||||
((s1 * m1.col(c) * m3.col(c).adjoint()).eval().template triangularView<Upper>().toDenseMatrix()));
|
||||
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX((m2.template selfadjointView<Lower>().rankUpdate(m1.col(c).conjugate(),s1)._expression()),
|
||||
@ -72,9 +99,20 @@ template<typename MatrixType> void syrk(const MatrixType& m)
|
||||
VERIFY_IS_APPROX((m2.template selfadjointView<Upper>().rankUpdate(m1.col(c).conjugate(),s1)._expression()),
|
||||
((s1 * m1.col(c).conjugate() * m1.col(c).conjugate().adjoint()).eval().template triangularView<Upper>().toDenseMatrix()));
|
||||
|
||||
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX((m2.template selfadjointView<Lower>().rankUpdate(m1.row(c),s1)._expression()),
|
||||
((s1 * m1.row(c).transpose() * m1.row(c).transpose().adjoint()).eval().template triangularView<Lower>().toDenseMatrix()));
|
||||
rm2.setZero();
|
||||
VERIFY_IS_APPROX((rm2.template selfadjointView<Lower>().rankUpdate(m1.row(c),s1)._expression()),
|
||||
((s1 * m1.row(c).transpose() * m1.row(c).transpose().adjoint()).eval().template triangularView<Lower>().toDenseMatrix()));
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX((m2.template triangularView<Lower>() += s1 * m3.row(c).transpose() * m1.row(c).transpose().adjoint()).nestedExpression(),
|
||||
((s1 * m3.row(c).transpose() * m1.row(c).transpose().adjoint()).eval().template triangularView<Lower>().toDenseMatrix()));
|
||||
rm2.setZero();
|
||||
VERIFY_IS_APPROX((rm2.template triangularView<Lower>() += s1 * m3.row(c).transpose() * m1.row(c).transpose().adjoint()).nestedExpression(),
|
||||
((s1 * m3.row(c).transpose() * m1.row(c).transpose().adjoint()).eval().template triangularView<Lower>().toDenseMatrix()));
|
||||
|
||||
|
||||
m2.setZero();
|
||||
VERIFY_IS_APPROX((m2.template selfadjointView<Upper>().rankUpdate(m1.row(c).adjoint(),s1)._expression()),
|
||||
|
Loading…
x
Reference in New Issue
Block a user