mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
fix a couple a issue with blas (new TRMM api, and enforece column major)
This commit is contained in:
parent
f59226e901
commit
cbd6fe323c
@ -86,15 +86,15 @@ enum
|
||||
Conj = IsComplex
|
||||
};
|
||||
|
||||
typedef Map<Matrix<Scalar,Dynamic,Dynamic>, 0, OuterStride<Dynamic> > MatrixType;
|
||||
typedef Map<Matrix<Scalar,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> > MatrixType;
|
||||
typedef Map<Matrix<Scalar,Dynamic,1>, 0, InnerStride<Dynamic> > StridedVectorType;
|
||||
typedef Map<Matrix<Scalar,Dynamic,1> > CompactVectorType;
|
||||
|
||||
template<typename T>
|
||||
Map<Matrix<T,Dynamic,Dynamic>, 0, OuterStride<Dynamic> >
|
||||
Map<Matrix<T,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> >
|
||||
matrix(T* data, int rows, int cols, int stride)
|
||||
{
|
||||
return Map<Matrix<T,Dynamic,Dynamic>, 0, OuterStride<Dynamic> >(data, rows, cols, OuterStride<Dynamic>(stride));
|
||||
return Map<Matrix<T,Dynamic,Dynamic,ColMajor>, 0, OuterStride<> >(data, rows, cols, OuterStride<>(stride));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
|
@ -63,10 +63,8 @@ int EIGEN_BLAS_FUNC(gemm)(char *opa, char *opb, int *m, int *n, int *k, RealScal
|
||||
|
||||
if(beta!=Scalar(1))
|
||||
{
|
||||
if(beta==Scalar(0))
|
||||
matrix(c, *m, *n, *ldc).setZero();
|
||||
else
|
||||
matrix(c, *m, *n, *ldc) *= beta;
|
||||
if(beta==Scalar(0)) matrix(c, *m, *n, *ldc).setZero();
|
||||
else matrix(c, *m, *n, *ldc) *= beta;
|
||||
}
|
||||
|
||||
ei_gemm_blocking_space<ColMajor,Scalar,Scalar,Dynamic,Dynamic,Dynamic> blocking(*m,*n,*k);
|
||||
@ -207,14 +205,17 @@ int EIGEN_BLAS_FUNC(trmm)(char *side, char *uplo, char *opa, char *diag, int *m,
|
||||
return 0;
|
||||
}
|
||||
|
||||
if(*m==0 || *n==0)
|
||||
return 1;
|
||||
|
||||
// FIXME find a way to avoid this copy
|
||||
Matrix<Scalar,Dynamic,Dynamic> tmp = matrix(b,*m,*n,*ldb);
|
||||
Matrix<Scalar,Dynamic,Dynamic,ColMajor> tmp = matrix(b,*m,*n,*ldb);
|
||||
matrix(b,*m,*n,*ldb).setZero();
|
||||
|
||||
if(SIDE(*side)==LEFT)
|
||||
func[code](*m, *n, *m, a, *lda, tmp.data(), tmp.outerStride(), b, *ldb, alpha);
|
||||
else
|
||||
func[code](*n, *m, *n, tmp.data(), tmp.outerStride(), a, *lda, b, *ldb, alpha);
|
||||
func[code](*m, *n, *n, tmp.data(), tmp.outerStride(), a, *lda, b, *ldb, alpha);
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user