cxx11_tensor_casts.cpp
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
#include "main.h"
#include "random_without_cast_overflow.h"

#include <Eigen/CXX11/Tensor>
14 
15 using Eigen::Tensor;
16 using Eigen::array;
17 
18 static void test_simple_cast()
19 {
20  Tensor<float, 2> ftensor(20,30);
21  ftensor = ftensor.random() * 100.f;
22  Tensor<char, 2> chartensor(20,30);
23  chartensor.setRandom();
24  Tensor<std::complex<float>, 2> cplextensor(20,30);
25  cplextensor.setRandom();
26 
27  chartensor = ftensor.cast<char>();
28  cplextensor = ftensor.cast<std::complex<float> >();
29 
30  for (int i = 0; i < 20; ++i) {
31  for (int j = 0; j < 30; ++j) {
32  VERIFY_IS_EQUAL(chartensor(i,j), static_cast<char>(ftensor(i,j)));
33  VERIFY_IS_EQUAL(cplextensor(i,j), static_cast<std::complex<float> >(ftensor(i,j)));
34  }
35  }
36 }
37 
38 
39 static void test_vectorized_cast()
40 {
41  Tensor<int, 2> itensor(20,30);
42  itensor = itensor.random() / 1000;
43  Tensor<float, 2> ftensor(20,30);
44  ftensor.setRandom();
45  Tensor<double, 2> dtensor(20,30);
46  dtensor.setRandom();
47 
48  ftensor = itensor.cast<float>();
49  dtensor = itensor.cast<double>();
50 
51  for (int i = 0; i < 20; ++i) {
52  for (int j = 0; j < 30; ++j) {
53  VERIFY_IS_EQUAL(itensor(i,j), static_cast<int>(ftensor(i,j)));
54  VERIFY_IS_EQUAL(dtensor(i,j), static_cast<double>(ftensor(i,j)));
55  }
56  }
57 }
58 
59 
61 {
62  Tensor<float, 2> ftensor(20,30);
63  ftensor = ftensor.random() * 1000.0f;
64  Tensor<double, 2> dtensor(20,30);
65  dtensor = dtensor.random() * 1000.0;
66 
67  Tensor<int, 2> i1tensor = ftensor.cast<int>();
68  Tensor<int, 2> i2tensor = dtensor.cast<int>();
69 
70  for (int i = 0; i < 20; ++i) {
71  for (int j = 0; j < 30; ++j) {
72  VERIFY_IS_EQUAL(i1tensor(i,j), static_cast<int>(ftensor(i,j)));
73  VERIFY_IS_EQUAL(i2tensor(i,j), static_cast<int>(dtensor(i,j)));
74  }
75  }
76 }
77 
78 
80 {
81  Tensor<double, 2> dtensor(20, 30);
82  dtensor.setRandom();
83  Tensor<float, 2> ftensor(20, 30);
84  ftensor = dtensor.cast<float>();
85 
86  for (int i = 0; i < 20; ++i) {
87  for (int j = 0; j < 30; ++j) {
88  VERIFY_IS_APPROX(dtensor(i,j), static_cast<double>(ftensor(i,j)));
89  }
90  }
91 }
92 
93 
95 {
96  Tensor<float, 2> ftensor(20, 30);
97  ftensor.setRandom();
98  Tensor<double, 2> dtensor(20, 30);
99  dtensor = ftensor.cast<double>();
100 
101  for (int i = 0; i < 20; ++i) {
102  for (int j = 0; j < 30; ++j) {
103  VERIFY_IS_APPROX(dtensor(i,j), static_cast<double>(ftensor(i,j)));
104  }
105  }
106 }
107 
108 template <typename FromType, typename ToType>
109 static void test_type_cast() {
110  Tensor<FromType, 2> ftensor(100, 200);
111  // Generate random values for a valid cast.
112  for (int i = 0; i < 100; ++i) {
113  for (int j = 0; j < 200; ++j) {
115  }
116  }
117 
118  Tensor<ToType, 2> ttensor(100, 200);
119  ttensor = ftensor.template cast<ToType>();
120 
121  for (int i = 0; i < 100; ++i) {
122  for (int j = 0; j < 200; ++j) {
123  const ToType ref = internal::cast<FromType,ToType>(ftensor(i, j));
124  VERIFY_IS_APPROX(ttensor(i, j), ref);
125  }
126  }
127 }
128 
129 template<typename Scalar, typename EnableIf = void>
130 struct test_cast_runner {
131  static void run() {
132  test_type_cast<Scalar, bool>();
133  test_type_cast<Scalar, int8_t>();
134  test_type_cast<Scalar, int16_t>();
135  test_type_cast<Scalar, int32_t>();
136  test_type_cast<Scalar, int64_t>();
137  test_type_cast<Scalar, uint8_t>();
138  test_type_cast<Scalar, uint16_t>();
139  test_type_cast<Scalar, uint32_t>();
140  test_type_cast<Scalar, uint64_t>();
141  test_type_cast<Scalar, half>();
142  test_type_cast<Scalar, bfloat16>();
143  test_type_cast<Scalar, float>();
144  test_type_cast<Scalar, double>();
145  test_type_cast<Scalar, std::complex<float>>();
146  test_type_cast<Scalar, std::complex<double>>();
147  }
148 };
149 
150 // Only certain types allow cast from std::complex<>.
151 template<typename Scalar>
152 struct test_cast_runner<Scalar, typename internal::enable_if<NumTraits<Scalar>::IsComplex>::type> {
153  static void run() {
154  test_type_cast<Scalar, half>();
155  test_type_cast<Scalar, bfloat16>();
156  test_type_cast<Scalar, std::complex<float>>();
157  test_type_cast<Scalar, std::complex<double>>();
158  }
159 };
160 
161 
162 EIGEN_DECLARE_TEST(cxx11_tensor_casts)
163 {
169 
183  CALL_SUBTEST(test_cast_runner<std::complex<float>>::run());
184  CALL_SUBTEST(test_cast_runner<std::complex<double>>::run());
185 
186 }
Eigen::Tensor
The tensor class.
Definition: Tensor.h:63
array
int array[24]
Definition: Map_general_stride.cpp:1
gtsam.examples.DogLegOptimizerExample.type
type
Definition: DogLegOptimizerExample.py:111
VERIFY_IS_EQUAL
#define VERIFY_IS_EQUAL(a, b)
Definition: main.h:386
test_small_to_big_type_cast
static void test_small_to_big_type_cast()
Definition: cxx11_tensor_casts.cpp:94
test_cast_runner::run
static void run()
Definition: cxx11_tensor_casts.cpp:131
test_vectorized_cast
static void test_vectorized_cast()
Definition: cxx11_tensor_casts.cpp:39
test_cast_runner< Scalar, typename internal::enable_if< NumTraits< Scalar >::IsComplex >::type >::run
static void run()
Definition: cxx11_tensor_casts.cpp:153
test_cast_runner
Definition: packetmath.cpp:195
test_big_to_small_type_cast
static void test_big_to_small_type_cast()
Definition: cxx11_tensor_casts.cpp:79
j
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2
test_float_to_int_cast
static void test_float_to_int_cast()
Definition: cxx11_tensor_casts.cpp:60
gtsam.examples.DogLegOptimizerExample.run
def run(args)
Definition: DogLegOptimizerExample.py:21
test_simple_cast
static void test_simple_cast()
Definition: cxx11_tensor_casts.cpp:18
VERIFY_IS_APPROX
#define VERIFY_IS_APPROX(a, b)
Definition: integer_types.cpp:15
Eigen::TensorBase< Tensor< Scalar_, NumIndices_, Options_, IndexType_ > >::setRandom
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Tensor< Scalar_, NumIndices_, Options_, IndexType_ > & setRandom()
Definition: TensorBase.h:996
test_type_cast
static void test_type_cast()
Definition: cxx11_tensor_casts.cpp:109
main.h
EIGEN_DECLARE_TEST
EIGEN_DECLARE_TEST(cxx11_tensor_casts)
Definition: cxx11_tensor_casts.cpp:162
random_without_cast_overflow.h
ref
Reference counting helper.
Definition: object.h:67
internal
Definition: BandTriangularSolver.h:13
test_callbacks.value
value
Definition: test_callbacks.py:160
i
int i
Definition: BiCGSTAB_step_by_step.cpp:9
Scalar
SCALAR Scalar
Definition: bench_gemm.cpp:46
CALL_SUBTEST
#define CALL_SUBTEST(FUNC)
Definition: main.h:399


gtsam
Author(s):
autogenerated on Wed Jan 1 2025 04:01:23