mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-08-11 19:29:02 +08:00
Add support for replicate in CUDA
This commit is contained in:
parent
6799c26cd6
commit
88e352adac
@ -517,7 +517,9 @@ template<typename Derived> class DenseBase
|
|||||||
template<int p> RealScalar lpNorm() const;
|
template<int p> RealScalar lpNorm() const;
|
||||||
|
|
||||||
template<int RowFactor, int ColFactor>
|
template<int RowFactor, int ColFactor>
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
const Replicate<Derived,RowFactor,ColFactor> replicate() const;
|
const Replicate<Derived,RowFactor,ColFactor> replicate() const;
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
const Replicate<Derived,Dynamic,Dynamic> replicate(Index rowFacor,Index colFactor) const;
|
const Replicate<Derived,Dynamic,Dynamic> replicate(Index rowFacor,Index colFactor) const;
|
||||||
|
|
||||||
typedef Reverse<Derived, BothDirections> ReverseReturnType;
|
typedef Reverse<Derived, BothDirections> ReverseReturnType;
|
||||||
|
@ -69,6 +69,7 @@ template<typename MatrixType,int RowFactor,int ColFactor> class Replicate
|
|||||||
typedef typename internal::remove_all<MatrixType>::type NestedExpression;
|
typedef typename internal::remove_all<MatrixType>::type NestedExpression;
|
||||||
|
|
||||||
template<typename OriginalMatrixType>
|
template<typename OriginalMatrixType>
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
inline explicit Replicate(const OriginalMatrixType& matrix)
|
inline explicit Replicate(const OriginalMatrixType& matrix)
|
||||||
: m_matrix(matrix), m_rowFactor(RowFactor), m_colFactor(ColFactor)
|
: m_matrix(matrix), m_rowFactor(RowFactor), m_colFactor(ColFactor)
|
||||||
{
|
{
|
||||||
@ -78,6 +79,7 @@ template<typename MatrixType,int RowFactor,int ColFactor> class Replicate
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<typename OriginalMatrixType>
|
template<typename OriginalMatrixType>
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
inline Replicate(const OriginalMatrixType& matrix, Index rowFactor, Index colFactor)
|
inline Replicate(const OriginalMatrixType& matrix, Index rowFactor, Index colFactor)
|
||||||
: m_matrix(matrix), m_rowFactor(rowFactor), m_colFactor(colFactor)
|
: m_matrix(matrix), m_rowFactor(rowFactor), m_colFactor(colFactor)
|
||||||
{
|
{
|
||||||
@ -85,9 +87,12 @@ template<typename MatrixType,int RowFactor,int ColFactor> class Replicate
|
|||||||
THE_MATRIX_OR_EXPRESSION_THAT_YOU_PASSED_DOES_NOT_HAVE_THE_EXPECTED_TYPE)
|
THE_MATRIX_OR_EXPRESSION_THAT_YOU_PASSED_DOES_NOT_HAVE_THE_EXPECTED_TYPE)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
inline Index rows() const { return m_matrix.rows() * m_rowFactor.value(); }
|
inline Index rows() const { return m_matrix.rows() * m_rowFactor.value(); }
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
inline Index cols() const { return m_matrix.cols() * m_colFactor.value(); }
|
inline Index cols() const { return m_matrix.cols() * m_colFactor.value(); }
|
||||||
|
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
const _MatrixTypeNested& nestedExpression() const
|
const _MatrixTypeNested& nestedExpression() const
|
||||||
{
|
{
|
||||||
return m_matrix;
|
return m_matrix;
|
||||||
|
@ -461,6 +461,7 @@ template<typename ExpressionType, int Direction> class VectorwiseOp
|
|||||||
*/
|
*/
|
||||||
// NOTE implemented here because of sunstudio's compilation errors
|
// NOTE implemented here because of sunstudio's compilation errors
|
||||||
template<int Factor> const Replicate<ExpressionType,(IsVertical?Factor:1),(IsHorizontal?Factor:1)>
|
template<int Factor> const Replicate<ExpressionType,(IsVertical?Factor:1),(IsHorizontal?Factor:1)>
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
replicate(Index factor = Factor) const
|
replicate(Index factor = Factor) const
|
||||||
{
|
{
|
||||||
return Replicate<ExpressionType,Direction==Vertical?Factor:1,Direction==Horizontal?Factor:1>
|
return Replicate<ExpressionType,Direction==Vertical?Factor:1,Direction==Horizontal?Factor:1>
|
||||||
|
@ -47,6 +47,23 @@ struct coeff_wise {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
struct replicate {
|
||||||
|
EIGEN_DEVICE_FUNC
|
||||||
|
void operator()(int i, const typename T::Scalar* in, typename T::Scalar* out) const
|
||||||
|
{
|
||||||
|
using namespace Eigen;
|
||||||
|
T x1(in+i);
|
||||||
|
int step = x1.size() * 4;
|
||||||
|
int stride = 3 * step;
|
||||||
|
|
||||||
|
typedef Map<Array<typename T::Scalar,Dynamic,Dynamic> > MapType;
|
||||||
|
MapType(out+i*stride+0*step, x1.rows()*2, x1.cols()*2) = x1.replicate(2,2);
|
||||||
|
MapType(out+i*stride+1*step, x1.rows()*3, x1.cols()) = in[i] * x1.colwise().replicate(3);
|
||||||
|
MapType(out+i*stride+2*step, x1.rows(), x1.cols()*3) = in[i] * x1.rowwise().replicate(3);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
struct redux {
|
struct redux {
|
||||||
EIGEN_DEVICE_FUNC
|
EIGEN_DEVICE_FUNC
|
||||||
@ -117,7 +134,7 @@ void test_cuda_basic()
|
|||||||
Eigen::VectorXf in, out;
|
Eigen::VectorXf in, out;
|
||||||
|
|
||||||
#ifndef __CUDA_ARCH__
|
#ifndef __CUDA_ARCH__
|
||||||
int data_size = nthreads * 16;
|
int data_size = nthreads * 512;
|
||||||
in.setRandom(data_size);
|
in.setRandom(data_size);
|
||||||
out.setRandom(data_size);
|
out.setRandom(data_size);
|
||||||
#endif
|
#endif
|
||||||
@ -125,6 +142,9 @@ void test_cuda_basic()
|
|||||||
CALL_SUBTEST( run_and_compare_to_cuda(coeff_wise<Vector3f>(), nthreads, in, out) );
|
CALL_SUBTEST( run_and_compare_to_cuda(coeff_wise<Vector3f>(), nthreads, in, out) );
|
||||||
CALL_SUBTEST( run_and_compare_to_cuda(coeff_wise<Array44f>(), nthreads, in, out) );
|
CALL_SUBTEST( run_and_compare_to_cuda(coeff_wise<Array44f>(), nthreads, in, out) );
|
||||||
|
|
||||||
|
CALL_SUBTEST( run_and_compare_to_cuda(replicate<Array4f>(), nthreads, in, out) );
|
||||||
|
CALL_SUBTEST( run_and_compare_to_cuda(replicate<Array33f>(), nthreads, in, out) );
|
||||||
|
|
||||||
CALL_SUBTEST( run_and_compare_to_cuda(redux<Array4f>(), nthreads, in, out) );
|
CALL_SUBTEST( run_and_compare_to_cuda(redux<Array4f>(), nthreads, in, out) );
|
||||||
CALL_SUBTEST( run_and_compare_to_cuda(redux<Matrix3f>(), nthreads, in, out) );
|
CALL_SUBTEST( run_and_compare_to_cuda(redux<Matrix3f>(), nthreads, in, out) );
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user