Fix bug #496: generalize internal rank1_update implementation to accept uplo(A) += v * w and make A.triangularView() += v * w uses it.

Update unit tests and blas interface respectively.
This commit is contained in:
Gael Guennebaud 2013-02-24 23:05:42 +01:00
parent 08388cc712
commit 04367447ac
5 changed files with 138 additions and 37 deletions

View File

@ -12,6 +12,9 @@
namespace Eigen {
template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjLhs, bool ConjRhs>
struct selfadjoint_rank1_update;
namespace internal {
/**********************************************************************
@ -180,31 +183,93 @@ struct tribb_kernel
// high level API
template<typename MatrixType, typename ProductType, int UpLo, bool IsOuterProduct>
struct general_product_to_triangular_selector;
template<typename MatrixType, typename ProductType, int UpLo>
struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,true>
{
static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha)
{
typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::Index Index;
typedef typename internal::remove_all<typename ProductType::LhsNested>::type Lhs;
typedef internal::blas_traits<Lhs> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
typedef typename internal::remove_all<ActualLhs>::type _ActualLhs;
typename internal::add_const_on_value_type<ActualLhs>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
typedef typename internal::remove_all<typename ProductType::RhsNested>::type Rhs;
typedef internal::blas_traits<Rhs> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
typedef typename internal::remove_all<ActualRhs>::type _ActualRhs;
typename internal::add_const_on_value_type<ActualRhs>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
enum {
StorageOrder = (internal::traits<MatrixType>::Flags&RowMajorBit) ? RowMajor : ColMajor,
UseLhsDirectly = _ActualLhs::InnerStrideAtCompileTime==1,
UseRhsDirectly = _ActualRhs::InnerStrideAtCompileTime==1
};
internal::gemv_static_vector_if<Scalar,Lhs::SizeAtCompileTime,Lhs::MaxSizeAtCompileTime,!UseLhsDirectly> static_lhs;
ei_declare_aligned_stack_constructed_variable(Scalar, actualLhsPtr, actualLhs.size(),
(UseLhsDirectly ? const_cast<Scalar*>(actualLhs.data()) : static_lhs.data()));
if(!UseLhsDirectly) Map<typename _ActualLhs::PlainObject>(actualLhsPtr, actualLhs.size()) = actualLhs;
internal::gemv_static_vector_if<Scalar,Rhs::SizeAtCompileTime,Rhs::MaxSizeAtCompileTime,!UseRhsDirectly> static_rhs;
ei_declare_aligned_stack_constructed_variable(Scalar, actualRhsPtr, actualRhs.size(),
(UseRhsDirectly ? const_cast<Scalar*>(actualRhs.data()) : static_rhs.data()));
if(!UseRhsDirectly) Map<typename _ActualRhs::PlainObject>(actualRhsPtr, actualRhs.size()) = actualRhs;
selfadjoint_rank1_update<Scalar,Index,StorageOrder,UpLo,
LhsBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
RhsBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex>
::run(actualLhs.size(), mat.data(), mat.outerStride(), actualLhsPtr, actualRhsPtr, actualAlpha);
}
};
template<typename MatrixType, typename ProductType, int UpLo>
struct general_product_to_triangular_selector<MatrixType,ProductType,UpLo,false>
{
static void run(MatrixType& mat, const ProductType& prod, const typename MatrixType::Scalar& alpha)
{
typedef typename MatrixType::Scalar Scalar;
typedef typename MatrixType::Index Index;
typedef typename internal::remove_all<typename ProductType::LhsNested>::type Lhs;
typedef internal::blas_traits<Lhs> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
typedef typename internal::remove_all<ActualLhs>::type _ActualLhs;
typename internal::add_const_on_value_type<ActualLhs>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
typedef typename internal::remove_all<typename ProductType::RhsNested>::type Rhs;
typedef internal::blas_traits<Rhs> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
typedef typename internal::remove_all<ActualRhs>::type _ActualRhs;
typename internal::add_const_on_value_type<ActualRhs>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
typename ProductType::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
internal::general_matrix_matrix_triangular_product<Index,
typename Lhs::Scalar, _ActualLhs::Flags&RowMajorBit ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
typename Rhs::Scalar, _ActualRhs::Flags&RowMajorBit ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo>
::run(mat.cols(), actualLhs.cols(),
&actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(),
mat.data(), mat.outerStride(), actualAlpha);
}
};
template<typename MatrixType, unsigned int UpLo>
template<typename ProductDerived, typename _Lhs, typename _Rhs>
TriangularView<MatrixType,UpLo>& TriangularView<MatrixType,UpLo>::assignProduct(const ProductBase<ProductDerived, _Lhs,_Rhs>& prod, const Scalar& alpha)
{
typedef typename internal::remove_all<typename ProductDerived::LhsNested>::type Lhs;
typedef internal::blas_traits<Lhs> LhsBlasTraits;
typedef typename LhsBlasTraits::DirectLinearAccessType ActualLhs;
typedef typename internal::remove_all<ActualLhs>::type _ActualLhs;
typename internal::add_const_on_value_type<ActualLhs>::type actualLhs = LhsBlasTraits::extract(prod.lhs());
typedef typename internal::remove_all<typename ProductDerived::RhsNested>::type Rhs;
typedef internal::blas_traits<Rhs> RhsBlasTraits;
typedef typename RhsBlasTraits::DirectLinearAccessType ActualRhs;
typedef typename internal::remove_all<ActualRhs>::type _ActualRhs;
typename internal::add_const_on_value_type<ActualRhs>::type actualRhs = RhsBlasTraits::extract(prod.rhs());
typename ProductDerived::Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(prod.lhs().derived()) * RhsBlasTraits::extractScalarFactor(prod.rhs().derived());
internal::general_matrix_matrix_triangular_product<Index,
typename Lhs::Scalar, _ActualLhs::Flags&RowMajorBit ? RowMajor : ColMajor, LhsBlasTraits::NeedToConjugate,
typename Rhs::Scalar, _ActualRhs::Flags&RowMajorBit ? RowMajor : ColMajor, RhsBlasTraits::NeedToConjugate,
MatrixType::Flags&RowMajorBit ? RowMajor : ColMajor, UpLo>
::run(m_matrix.cols(), actualLhs.cols(),
&actualLhs.coeffRef(0,0), actualLhs.outerStride(), &actualRhs.coeffRef(0,0), actualRhs.outerStride(),
const_cast<Scalar*>(m_matrix.data()), m_matrix.outerStride(), actualAlpha);
general_product_to_triangular_selector<MatrixType, ProductDerived, UpLo, (_Lhs::ColsAtCompileTime==1) || (_Rhs::RowsAtCompileTime==1)>::run(m_matrix.const_cast_derived(), prod.derived(), alpha);
return *this;
}

View File

@ -18,21 +18,19 @@
namespace Eigen {
template<typename Scalar, typename Index, int StorageOrder, int UpLo, bool ConjLhs, bool ConjRhs>
struct selfadjoint_rank1_update;
template<typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs>
struct selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo,ConjLhs,ConjRhs>
{
static void run(Index size, Scalar* mat, Index stride, const Scalar* vec, Scalar alpha)
static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, Scalar alpha)
{
internal::conj_if<ConjRhs> cj;
typedef Map<const Matrix<Scalar,Dynamic,1> > OtherMap;
typedef typename internal::conditional<ConjLhs,typename OtherMap::ConjugateReturnType,const OtherMap&>::type ConjRhsType;
typedef typename internal::conditional<ConjLhs,typename OtherMap::ConjugateReturnType,const OtherMap&>::type ConjLhsType;
for (Index i=0; i<size; ++i)
{
Map<Matrix<Scalar,Dynamic,1> >(mat+stride*i+(UpLo==Lower ? i : 0), (UpLo==Lower ? size-i : (i+1)))
+= (alpha * cj(vec[i])) * ConjRhsType(OtherMap(vec+(UpLo==Lower ? i : 0),UpLo==Lower ? size-i : (i+1)));
+= (alpha * cj(vecY[i])) * ConjLhsType(OtherMap(vecX+(UpLo==Lower ? i : 0),UpLo==Lower ? size-i : (i+1)));
}
}
};
@ -40,9 +38,9 @@ struct selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo,ConjLhs,ConjRhs>
template<typename Scalar, typename Index, int UpLo, bool ConjLhs, bool ConjRhs>
struct selfadjoint_rank1_update<Scalar,Index,RowMajor,UpLo,ConjLhs,ConjRhs>
{
static void run(Index size, Scalar* mat, Index stride, const Scalar* vec, Scalar alpha)
static void run(Index size, Scalar* mat, Index stride, const Scalar* vecX, const Scalar* vecY, Scalar alpha)
{
selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo==Lower?Upper:Lower,ConjRhs,ConjLhs>::run(size,mat,stride,vec,alpha);
selfadjoint_rank1_update<Scalar,Index,ColMajor,UpLo==Lower?Upper:Lower,ConjRhs,ConjLhs>::run(size,mat,stride,vecY,vecX,alpha);
}
};
@ -78,7 +76,7 @@ struct selfadjoint_product_selector<MatrixType,OtherType,UpLo,true>
selfadjoint_rank1_update<Scalar,Index,StorageOrder,UpLo,
OtherBlasTraits::NeedToConjugate && NumTraits<Scalar>::IsComplex,
(!OtherBlasTraits::NeedToConjugate) && NumTraits<Scalar>::IsComplex>
::run(other.size(), mat.data(), mat.outerStride(), actualOtherPtr, actualAlpha);
::run(other.size(), mat.data(), mat.outerStride(), actualOtherPtr, actualOtherPtr, actualAlpha);
}
};

View File

@ -216,7 +216,7 @@ int EIGEN_BLAS_FUNC(hpr2)(char *uplo, int *n, RealScalar *palpha, RealScalar *px
*/
int EIGEN_BLAS_FUNC(her)(char *uplo, int *n, RealScalar *palpha, RealScalar *px, int *incx, RealScalar *pa, int *lda)
{
typedef void (*functype)(int, Scalar*, int, const Scalar*, Scalar);
typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, Scalar);
static functype func[2];
static bool init = false;
@ -252,7 +252,7 @@ int EIGEN_BLAS_FUNC(her)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
if(code>=2 || func[code]==0)
return 0;
func[code](*n, a, *lda, x_cpy, alpha);
func[code](*n, a, *lda, x_cpy, x_cpy, alpha);
matrix(a,*n,*n,*lda).diagonal().imag().setZero();

View File

@ -85,7 +85,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
// init = true;
// }
typedef void (*functype)(int, Scalar*, int, const Scalar*, Scalar);
typedef void (*functype)(int, Scalar*, int, const Scalar*, const Scalar*, Scalar);
static functype func[2];
static bool init = false;
@ -121,7 +121,7 @@ int EIGEN_BLAS_FUNC(syr)(char *uplo, int *n, RealScalar *palpha, RealScalar *px,
if(code>=2 || func[code]==0)
return 0;
func[code](*n, c, *ldc, x_cpy, alpha);
func[code](*n, c, *ldc, x_cpy, x_cpy, alpha);
if(x_cpy!=x) delete[] x_cpy;

View File

@ -14,6 +14,7 @@ template<typename MatrixType> void syrk(const MatrixType& m)
typedef typename MatrixType::Index Index;
typedef typename MatrixType::Scalar Scalar;
typedef typename NumTraits<Scalar>::Real RealScalar;
typedef Matrix<Scalar, MatrixType::RowsAtCompileTime, MatrixType::ColsAtCompileTime, RowMajor> RMatrixType;
typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, Dynamic> Rhs1;
typedef Matrix<Scalar, Dynamic, MatrixType::RowsAtCompileTime> Rhs2;
typedef Matrix<Scalar, MatrixType::ColsAtCompileTime, Dynamic,RowMajor> Rhs3;
@ -22,10 +23,12 @@ template<typename MatrixType> void syrk(const MatrixType& m)
Index cols = m.cols();
MatrixType m1 = MatrixType::Random(rows, cols),
m2 = MatrixType::Random(rows, cols);
m2 = MatrixType::Random(rows, cols),
m3 = MatrixType::Random(rows, cols);
RMatrixType rm2 = MatrixType::Random(rows, cols);
Rhs1 rhs1 = Rhs1::Random(internal::random<int>(1,320), cols);
Rhs2 rhs2 = Rhs2::Random(rows, internal::random<int>(1,320));
Rhs1 rhs1 = Rhs1::Random(internal::random<int>(1,320), cols); Rhs1 rhs11 = Rhs1::Random(rhs1.rows(), cols);
Rhs2 rhs2 = Rhs2::Random(rows, internal::random<int>(1,320)); Rhs2 rhs22 = Rhs2::Random(rows, rhs2.cols());
Rhs3 rhs3 = Rhs3::Random(internal::random<int>(1,320), rows);
Scalar s1 = internal::random<Scalar>();
@ -35,19 +38,34 @@ template<typename MatrixType> void syrk(const MatrixType& m)
m2.setZero();
VERIFY_IS_APPROX((m2.template selfadjointView<Lower>().rankUpdate(rhs2,s1)._expression()),
((s1 * rhs2 * rhs2.adjoint()).eval().template triangularView<Lower>().toDenseMatrix()));
m2.setZero();
VERIFY_IS_APPROX(((m2.template triangularView<Lower>() += s1 * rhs2 * rhs22.adjoint()).nestedExpression()),
((s1 * rhs2 * rhs22.adjoint()).eval().template triangularView<Lower>().toDenseMatrix()));
m2.setZero();
VERIFY_IS_APPROX(m2.template selfadjointView<Upper>().rankUpdate(rhs2,s1)._expression(),
(s1 * rhs2 * rhs2.adjoint()).eval().template triangularView<Upper>().toDenseMatrix());
m2.setZero();
VERIFY_IS_APPROX((m2.template triangularView<Upper>() += s1 * rhs22 * rhs2.adjoint()).nestedExpression(),
(s1 * rhs22 * rhs2.adjoint()).eval().template triangularView<Upper>().toDenseMatrix());
m2.setZero();
VERIFY_IS_APPROX(m2.template selfadjointView<Lower>().rankUpdate(rhs1.adjoint(),s1)._expression(),
(s1 * rhs1.adjoint() * rhs1).eval().template triangularView<Lower>().toDenseMatrix());
m2.setZero();
VERIFY_IS_APPROX((m2.template triangularView<Lower>() += s1 * rhs11.adjoint() * rhs1).nestedExpression(),
(s1 * rhs11.adjoint() * rhs1).eval().template triangularView<Lower>().toDenseMatrix());
m2.setZero();
VERIFY_IS_APPROX(m2.template selfadjointView<Upper>().rankUpdate(rhs1.adjoint(),s1)._expression(),
(s1 * rhs1.adjoint() * rhs1).eval().template triangularView<Upper>().toDenseMatrix());
VERIFY_IS_APPROX((m2.template triangularView<Upper>() = s1 * rhs1.adjoint() * rhs11).nestedExpression(),
(s1 * rhs1.adjoint() * rhs11).eval().template triangularView<Upper>().toDenseMatrix());
m2.setZero();
VERIFY_IS_APPROX(m2.template selfadjointView<Lower>().rankUpdate(rhs3.adjoint(),s1)._expression(),
(s1 * rhs3.adjoint() * rhs3).eval().template triangularView<Lower>().toDenseMatrix());
@ -63,6 +81,15 @@ template<typename MatrixType> void syrk(const MatrixType& m)
m2.setZero();
VERIFY_IS_APPROX((m2.template selfadjointView<Upper>().rankUpdate(m1.col(c),s1)._expression()),
((s1 * m1.col(c) * m1.col(c).adjoint()).eval().template triangularView<Upper>().toDenseMatrix()));
rm2.setZero();
VERIFY_IS_APPROX((rm2.template selfadjointView<Upper>().rankUpdate(m1.col(c),s1)._expression()),
((s1 * m1.col(c) * m1.col(c).adjoint()).eval().template triangularView<Upper>().toDenseMatrix()));
m2.setZero();
VERIFY_IS_APPROX((m2.template triangularView<Upper>() += s1 * m3.col(c) * m1.col(c).adjoint()).nestedExpression(),
((s1 * m3.col(c) * m1.col(c).adjoint()).eval().template triangularView<Upper>().toDenseMatrix()));
rm2.setZero();
VERIFY_IS_APPROX((rm2.template triangularView<Upper>() += s1 * m1.col(c) * m3.col(c).adjoint()).nestedExpression(),
((s1 * m1.col(c) * m3.col(c).adjoint()).eval().template triangularView<Upper>().toDenseMatrix()));
m2.setZero();
VERIFY_IS_APPROX((m2.template selfadjointView<Lower>().rankUpdate(m1.col(c).conjugate(),s1)._expression()),
@ -72,9 +99,20 @@ template<typename MatrixType> void syrk(const MatrixType& m)
VERIFY_IS_APPROX((m2.template selfadjointView<Upper>().rankUpdate(m1.col(c).conjugate(),s1)._expression()),
((s1 * m1.col(c).conjugate() * m1.col(c).conjugate().adjoint()).eval().template triangularView<Upper>().toDenseMatrix()));
m2.setZero();
VERIFY_IS_APPROX((m2.template selfadjointView<Lower>().rankUpdate(m1.row(c),s1)._expression()),
((s1 * m1.row(c).transpose() * m1.row(c).transpose().adjoint()).eval().template triangularView<Lower>().toDenseMatrix()));
rm2.setZero();
VERIFY_IS_APPROX((rm2.template selfadjointView<Lower>().rankUpdate(m1.row(c),s1)._expression()),
((s1 * m1.row(c).transpose() * m1.row(c).transpose().adjoint()).eval().template triangularView<Lower>().toDenseMatrix()));
m2.setZero();
VERIFY_IS_APPROX((m2.template triangularView<Lower>() += s1 * m3.row(c).transpose() * m1.row(c).transpose().adjoint()).nestedExpression(),
((s1 * m3.row(c).transpose() * m1.row(c).transpose().adjoint()).eval().template triangularView<Lower>().toDenseMatrix()));
rm2.setZero();
VERIFY_IS_APPROX((rm2.template triangularView<Lower>() += s1 * m3.row(c).transpose() * m1.row(c).transpose().adjoint()).nestedExpression(),
((s1 * m3.row(c).transpose() * m1.row(c).transpose().adjoint()).eval().template triangularView<Lower>().toDenseMatrix()));
m2.setZero();
VERIFY_IS_APPROX((m2.template selfadjointView<Upper>().rankUpdate(m1.row(c).adjoint(),s1)._expression()),