Mirror of https://gitlab.com/libeigen/eigen.git, synced 2025-07-22 12:54:26 +08:00
Add syntactic sugar to Eigen tensors to allow more natural syntax.
Specifically, this enables expressions involving:
scalar + tensor
scalar * tensor
scalar / tensor
scalar - tensor
This commit is contained in:
parent 6021c90fdf
commit 811aadbe00
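For context, the sugar added here lets a scalar appear on the left-hand side of +, -, * and / with a tensor. A minimal usage sketch follows (not part of the commit; it assumes Eigen's unsupported Tensor module header and a C++11 compiler):

#include <unsupported/Eigen/CXX11/Tensor>
#include <iostream>

int main() {
  Eigen::Tensor<float, 2> t(2, 3);
  t.setConstant(2.0f);

  // With this commit, the scalar may sit on either side of the operator.
  Eigen::Tensor<float, 2> a = 1.0f + t;  // scalar + tensor
  Eigen::Tensor<float, 2> b = 3.0f * t;  // scalar * tensor
  Eigen::Tensor<float, 2> c = 6.0f / t;  // scalar / tensor
  Eigen::Tensor<float, 2> d = 5.0f - t;  // scalar - tensor

  std::cout << a(0, 0) << " " << b(0, 0) << " "
            << c(0, 0) << " " << d(0, 0) << std::endl;  // prints: 3 6 3 3
  return 0;
}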
@@ -215,6 +215,13 @@ class TensorBase<Derived, ReadOnlyAccessors>
       return unaryExpr(internal::scalar_add_op<Scalar>(rhs));
     }
 
+    EIGEN_DEVICE_FUNC
+    EIGEN_STRONG_INLINE friend
+    const TensorCwiseUnaryOp<internal::scalar_add_op<Scalar>, const Derived>
+    operator+ (Scalar lhs, const Derived& rhs) {
+      return rhs + lhs;
+    }
+
     EIGEN_DEVICE_FUNC
     EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_sub_op<Scalar>, const Derived>
     operator- (Scalar rhs) const {
@@ -222,18 +229,41 @@ class TensorBase<Derived, ReadOnlyAccessors>
       return unaryExpr(internal::scalar_sub_op<Scalar>(rhs));
     }
 
+    EIGEN_DEVICE_FUNC
+    EIGEN_STRONG_INLINE friend
+    const TensorCwiseUnaryOp<internal::scalar_add_op<Scalar>,
+          const TensorCwiseUnaryOp<internal::scalar_opposite_op<Scalar>, const Derived> >
+    operator- (Scalar lhs, const Derived& rhs) {
+      return -rhs + lhs;
+    }
+
     EIGEN_DEVICE_FUNC
     EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const Derived>
     operator* (Scalar rhs) const {
       return unaryExpr(internal::scalar_multiple_op<Scalar>(rhs));
     }
 
+    EIGEN_DEVICE_FUNC
+    EIGEN_STRONG_INLINE friend
+    const TensorCwiseUnaryOp<internal::scalar_multiple_op<Scalar>, const Derived>
+    operator* (Scalar lhs, const Derived& rhs) {
+      return rhs * lhs;
+    }
+
     EIGEN_DEVICE_FUNC
     EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_quotient1_op<Scalar>, const Derived>
     operator/ (Scalar rhs) const {
       return unaryExpr(internal::scalar_quotient1_op<Scalar>(rhs));
     }
 
+    EIGEN_DEVICE_FUNC
+    EIGEN_STRONG_INLINE friend
+    const TensorCwiseUnaryOp<internal::scalar_multiple_op<Scalar>,
+          const TensorCwiseUnaryOp<internal::scalar_inverse_op<Scalar>, const Derived> >
+    operator/ (Scalar lhs, const Derived& rhs) {
+      return rhs.inverse() * lhs;
+    }
+
     EIGEN_DEVICE_FUNC
     EIGEN_STRONG_INLINE const TensorCwiseUnaryOp<internal::scalar_mod_op<Scalar>, const Derived>
     operator% (Scalar rhs) const {
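Note that the left-scalar overloads above are built purely from existing expressions: scalar - tensor is rewritten as -tensor + scalar, and scalar / tensor as tensor.inverse() * scalar, so no new functors are introduced. A rough sketch of that equivalence (illustration only, not part of the commit; it reuses the tensor setup from the tests below and assumes the unsupported Tensor module header):

#include <unsupported/Eigen/CXX11/Tensor>
#include <cassert>
#include <cmath>

int main() {
  Eigen::Tensor<float, 3> A(6, 7, 5);
  A.setRandom();
  const float alpha = 0.43f;

  Eigen::Tensor<float, 3> sugar  = alpha / A;            // new friend operator/
  Eigen::Tensor<float, 3> manual = A.inverse() * alpha;  // what it expands to

  for (int i = 0; i < 6 * 7 * 5; ++i) {
    // Both forms build the same expression, so the results agree.
    assert(std::abs(sugar(i) - manual(i)) <= 1e-6f * std::abs(manual(i)));
  }
  return 0;
}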
@@ -33,7 +33,7 @@ static void test_comparison_sugar() {
 }
 
 
-static void test_scalar_sugar() {
+static void test_scalar_sugar_add_mul() {
   Tensor<float, 3> A(6, 7, 5);
   Tensor<float, 3> B(6, 7, 5);
   A.setRandom();
@@ -41,21 +41,41 @@ static void test_scalar_sugar() {
 
   const float alpha = 0.43f;
   const float beta = 0.21f;
+  const float gamma = 0.14f;
 
-  Tensor<float, 3> R = A * A.constant(alpha) + B * B.constant(beta);
-  Tensor<float, 3> S = A * alpha + B * beta;
+  Tensor<float, 3> R = A.constant(gamma) + A * A.constant(alpha) + B * B.constant(beta);
+  Tensor<float, 3> S = A * alpha + B * beta + gamma;
+  Tensor<float, 3> T = gamma + alpha * A + beta * B;
 
-  // TODO: add enough syntactic sugar to support this
-  // Tensor<float, 3> T = alpha * A + beta * B;
+  for (int i = 0; i < 6*7*5; ++i) {
+    VERIFY_IS_APPROX(R(i), S(i));
+    VERIFY_IS_APPROX(R(i), T(i));
+  }
+}
+
+static void test_scalar_sugar_sub_div() {
+  Tensor<float, 3> A(6, 7, 5);
+  Tensor<float, 3> B(6, 7, 5);
+  A.setRandom();
+  B.setRandom();
+
+  const float alpha = 0.43f;
+  const float beta = 0.21f;
+  const float gamma = 0.14f;
+  const float delta = 0.32f;
+
+  Tensor<float, 3> R = A.constant(gamma) - A / A.constant(alpha)
+      - B.constant(beta) / B - A.constant(delta);
+  Tensor<float, 3> S = gamma - A / alpha - beta / B - delta;
 
   for (int i = 0; i < 6*7*5; ++i) {
     VERIFY_IS_APPROX(R(i), S(i));
   }
 }
 
 
 void test_cxx11_tensor_sugar()
 {
   CALL_SUBTEST(test_comparison_sugar());
-  CALL_SUBTEST(test_scalar_sugar());
+  CALL_SUBTEST(test_scalar_sugar_add_mul());
+  CALL_SUBTEST(test_scalar_sugar_sub_div());
 }