mirror of
https://gitlab.com/libeigen/eigen.git
synced 2025-06-04 18:54:00 +08:00
Fix bug in a test + compilation errors
This commit is contained in:
parent
1c8b9e10a7
commit
cfaedb38cd
@ -115,6 +115,7 @@ class TensorExecutor<Expression, DefaultDevice, Vectorizable,
|
||||
const DefaultDevice& device = DefaultDevice()) {
|
||||
typedef TensorBlock<ScalarNoConst, StorageIndex, NumDims, Evaluator::Layout>
|
||||
TensorBlock;
|
||||
typedef typename TensorBlock::Dimensions TensorBlockDimensions;
|
||||
typedef TensorBlockMapper<ScalarNoConst, StorageIndex, NumDims,
|
||||
Evaluator::Layout>
|
||||
TensorBlockMapper;
|
||||
@ -141,8 +142,9 @@ class TensorExecutor<Expression, DefaultDevice, Vectorizable,
|
||||
evaluator.getResourceRequirements(&resources);
|
||||
MergeResourceRequirements(resources, &block_shape, &block_total_size);
|
||||
|
||||
TensorBlockMapper block_mapper(evaluator.dimensions(), block_shape,
|
||||
block_total_size);
|
||||
TensorBlockMapper block_mapper(
|
||||
TensorBlockDimensions(evaluator.dimensions()), block_shape,
|
||||
block_total_size);
|
||||
block_total_size = block_mapper.block_dims_total_size();
|
||||
|
||||
Scalar* data = static_cast<Scalar*>(
|
||||
|
@ -520,6 +520,7 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
|
||||
|
||||
typedef internal::TensorBlock<ScalarNoConst, Index, NumDims, Layout>
|
||||
TensorBlock;
|
||||
typedef typename TensorBlock::Dimensions TensorBlockDimensions;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
|
||||
: m_impl(op.expression(), device), m_device(device), m_dimensions(op.sizes()), m_offsets(op.startIndices())
|
||||
@ -687,7 +688,7 @@ struct TensorEvaluator<const TensorSlicingOp<StartIndices, Sizes, ArgType>, Devi
|
||||
TensorBlock input_block(srcCoeff(output_block->first_coeff_index()),
|
||||
output_block->block_sizes(),
|
||||
output_block->block_strides(),
|
||||
Dimensions(m_inputStrides),
|
||||
TensorBlockDimensions(m_inputStrides),
|
||||
output_block->data());
|
||||
m_impl.block(&input_block);
|
||||
}
|
||||
@ -796,6 +797,7 @@ struct TensorEvaluator<TensorSlicingOp<StartIndices, Sizes, ArgType>, Device>
|
||||
|
||||
typedef internal::TensorBlock<ScalarNoConst, Index, NumDims, Layout>
|
||||
TensorBlock;
|
||||
typedef typename TensorBlock::Dimensions TensorBlockDimensions;
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorEvaluator(const XprType& op, const Device& device)
|
||||
: Base(op, device)
|
||||
@ -862,13 +864,11 @@ struct TensorEvaluator<TensorSlicingOp<StartIndices, Sizes, ArgType>, Device>
|
||||
const TensorBlock& block) {
|
||||
this->m_impl.writeBlock(TensorBlock(
|
||||
this->srcCoeff(block.first_coeff_index()), block.block_sizes(),
|
||||
block.block_strides(), Dimensions(this->m_inputStrides),
|
||||
block.block_strides(), TensorBlockDimensions(this->m_inputStrides),
|
||||
const_cast<ScalarNoConst*>(block.data())));
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
namespace internal {
|
||||
template<typename StartIndices, typename StopIndices, typename Strides, typename XprType>
|
||||
struct traits<TensorStridingSlicingOp<StartIndices, StopIndices, Strides, XprType> > : public traits<XprType>
|
||||
|
@ -317,7 +317,7 @@ static void test_execute_reshape(Device d)
|
||||
|
||||
DSizes<Index, ReshapedDims> reshaped_dims;
|
||||
reshaped_dims[shuffle[0]] = dims[0] * dims[1];
|
||||
for (int i = 2; i < NumDims; ++i) reshaped_dims[shuffle[i]] = dims[i];
|
||||
for (int i = 1; i < ReshapedDims; ++i) reshaped_dims[shuffle[i]] = dims[i + 1];
|
||||
|
||||
Tensor<T, ReshapedDims, Options, Index> golden = src.reshape(reshaped_dims);
|
||||
|
||||
|
@ -83,10 +83,10 @@ static void test_expr_shuffling()
|
||||
|
||||
Tensor<float, 4, DataLayout> result(5,7,3,2);
|
||||
|
||||
array<int, 4> src_slice_dim{{2,3,1,7}};
|
||||
array<int, 4> src_slice_start{{0,0,0,0}};
|
||||
array<int, 4> dst_slice_dim{{1,7,3,2}};
|
||||
array<int, 4> dst_slice_start{{0,0,0,0}};
|
||||
array<ptrdiff_t, 4> src_slice_dim{{2,3,1,7}};
|
||||
array<ptrdiff_t, 4> src_slice_start{{0,0,0,0}};
|
||||
array<ptrdiff_t, 4> dst_slice_dim{{1,7,3,2}};
|
||||
array<ptrdiff_t, 4> dst_slice_start{{0,0,0,0}};
|
||||
|
||||
for (int i = 0; i < 5; ++i) {
|
||||
result.slice(dst_slice_start, dst_slice_dim) =
|
||||
|
Loading…
x
Reference in New Issue
Block a user