cxx11_tensor_casts.cpp
Go to the documentation of this file.
1 // This file is part of Eigen, a lightweight C++ template library
2 // for linear algebra.
3 //
4 // Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5 //
6 // This Source Code Form is subject to the terms of the Mozilla
7 // Public License v. 2.0. If a copy of the MPL was not distributed
8 // with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9 
10 #include "main.h"
12 
13 #include <Eigen/CXX11/Tensor>
14 
15 using Eigen::Tensor;
16 using Eigen::array;
17 
18 static void test_simple_cast()
19 {
20  Tensor<float, 2> ftensor(20,30);
21  ftensor = ftensor.random() * 100.f;
22  Tensor<char, 2> chartensor(20,30);
23  chartensor.setRandom();
24  Tensor<std::complex<float>, 2> cplextensor(20,30);
25  cplextensor.setRandom();
26 
27  chartensor = ftensor.cast<char>();
28  cplextensor = ftensor.cast<std::complex<float> >();
29 
30  for (int i = 0; i < 20; ++i) {
31  for (int j = 0; j < 30; ++j) {
32  VERIFY_IS_EQUAL(chartensor(i,j), static_cast<char>(ftensor(i,j)));
33  VERIFY_IS_EQUAL(cplextensor(i,j), static_cast<std::complex<float> >(ftensor(i,j)));
34  }
35  }
36 }
37 
38 
39 static void test_vectorized_cast()
40 {
41  Tensor<int, 2> itensor(20,30);
42  itensor = itensor.random() / 1000;
43  Tensor<float, 2> ftensor(20,30);
44  ftensor.setRandom();
45  Tensor<double, 2> dtensor(20,30);
46  dtensor.setRandom();
47 
48  ftensor = itensor.cast<float>();
49  dtensor = itensor.cast<double>();
50 
51  for (int i = 0; i < 20; ++i) {
52  for (int j = 0; j < 30; ++j) {
53  VERIFY_IS_EQUAL(itensor(i,j), static_cast<int>(ftensor(i,j)));
54  VERIFY_IS_EQUAL(dtensor(i,j), static_cast<double>(ftensor(i,j)));
55  }
56  }
57 }
58 
59 
61 {
62  Tensor<float, 2> ftensor(20,30);
63  ftensor = ftensor.random() * 1000.0f;
64  Tensor<double, 2> dtensor(20,30);
65  dtensor = dtensor.random() * 1000.0;
66 
67  Tensor<int, 2> i1tensor = ftensor.cast<int>();
68  Tensor<int, 2> i2tensor = dtensor.cast<int>();
69 
70  for (int i = 0; i < 20; ++i) {
71  for (int j = 0; j < 30; ++j) {
72  VERIFY_IS_EQUAL(i1tensor(i,j), static_cast<int>(ftensor(i,j)));
73  VERIFY_IS_EQUAL(i2tensor(i,j), static_cast<int>(dtensor(i,j)));
74  }
75  }
76 }
77 
78 
80 {
81  Tensor<double, 2> dtensor(20, 30);
82  dtensor.setRandom();
83  Tensor<float, 2> ftensor(20, 30);
84  ftensor = dtensor.cast<float>();
85 
86  for (int i = 0; i < 20; ++i) {
87  for (int j = 0; j < 30; ++j) {
88  VERIFY_IS_APPROX(dtensor(i,j), static_cast<double>(ftensor(i,j)));
89  }
90  }
91 }
92 
93 
95 {
96  Tensor<float, 2> ftensor(20, 30);
97  ftensor.setRandom();
98  Tensor<double, 2> dtensor(20, 30);
99  dtensor = ftensor.cast<double>();
100 
101  for (int i = 0; i < 20; ++i) {
102  for (int j = 0; j < 30; ++j) {
103  VERIFY_IS_APPROX(dtensor(i,j), static_cast<double>(ftensor(i,j)));
104  }
105  }
106 }
107 
// Casts a 100x200 tensor of FromType to ToType with Tensor::cast<>() and
// compares every coefficient with the scalar internal::cast<FromType,ToType>
// applied to the same source value.
108 template <typename FromType, typename ToType>
109 static void test_type_cast() {
110  Tensor<FromType, 2> ftensor(100, 200);
111  // Generate random values for a valid cast.
// NOTE(review): in this listing the inner loop below has an empty body -- the
// original line 114, which presumably assigns ftensor(i, j) a value that is
// safe to cast to ToType, is missing. As shown, ftensor would be read
// uninitialized; confirm against the real source.
112  for (int i = 0; i < 100; ++i) {
113  for (int j = 0; j < 200; ++j) {
115  }
116  }
117 
118  Tensor<ToType, 2> ttensor(100, 200);
// `template` keyword is required because ftensor's type depends on FromType.
119  ttensor = ftensor.template cast<ToType>();
120 
121  for (int i = 0; i < 100; ++i) {
122  for (int j = 0; j < 200; ++j) {
// Reference value: the scalar cast helper applied to the same source coefficient.
123  const ToType ref = internal::cast<FromType,ToType>(ftensor(i, j));
// Approximate comparison; presumably so floating-point target types tolerate
// rounding differences between the tensor and scalar cast paths.
124  VERIFY_IS_APPROX(ttensor(i, j), ref);
125  }
126  }
127 }
128 
// Runs test_type_cast from Scalar to a fixed list of target types. This primary
// template is used whenever the std::complex specialization below does not
// match, i.e. when NumTraits<Scalar>::IsComplex is false.
129 template<typename Scalar, typename EnableIf = void>
130 struct test_cast_runner {
131  static void run() {
132  test_type_cast<Scalar, bool>();
// NOTE(review): listing lines 133-142 are missing here; they presumably hold
// further test_type_cast<Scalar, ...> target types -- confirm against source.
143  test_type_cast<Scalar, float>();
144  test_type_cast<Scalar, double>();
145  test_type_cast<Scalar, std::complex<float>>();
146  test_type_cast<Scalar, std::complex<double>>();
147  }
148 };
149 
150 // Partial specialization selected when Scalar is complex: only certain target types allow a cast from std::complex<>.
151 template<typename Scalar>
152 struct test_cast_runner<Scalar, typename internal::enable_if<NumTraits<Scalar>::IsComplex>::type> {
153  static void run() {
// NOTE(review): listing lines 154-155 are missing here; they presumably hold
// additional permitted target types -- confirm against source.
156  test_type_cast<Scalar, std::complex<float>>();
157  test_type_cast<Scalar, std::complex<double>>();
158  }
159 };
160 
161 
// Test entry point for the cxx11_tensor_casts suite; each CALL_SUBTEST runs one
// sub-test.
162 EIGEN_DECLARE_TEST(cxx11_tensor_casts)
163 {
// NOTE(review): listing lines 164-168 and 170-182 are missing here. They
// presumably call the static test_* functions above and test_cast_runner for
// the non-complex scalar types -- confirm against source.
169 
183  CALL_SUBTEST(test_cast_runner<std::complex<float>>::run());
184  CALL_SUBTEST(test_cast_runner<std::complex<double>>::run());
185 
186 }
unsigned char uint8_t
Definition: ms_stdint.h:83
int array[24]
SCALAR Scalar
Definition: bench_gemm.cpp:46
EIGEN_DECLARE_TEST(cxx11_tensor_casts)
static void test_simple_cast()
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Tensor< Scalar_, NumIndices_, Options_, IndexType_ > & setRandom()
Definition: TensorBase.h:996
unsigned short uint16_t
Definition: ms_stdint.h:84
static void test_vectorized_cast()
static void test_small_to_big_type_cast()
signed char int8_t
Definition: ms_stdint.h:80
signed short int16_t
Definition: ms_stdint.h:81
#define VERIFY_IS_APPROX(a, b)
#define VERIFY_IS_EQUAL(a, b)
Definition: main.h:386
signed __int64 int64_t
Definition: ms_stdint.h:94
unsigned __int64 uint64_t
Definition: ms_stdint.h:95
unsigned int uint32_t
Definition: ms_stdint.h:85
signed int int32_t
Definition: ms_stdint.h:82
static void test_big_to_small_type_cast()
Reference counting helper.
Definition: object.h:67
static void test_float_to_int_cast()
#define CALL_SUBTEST(FUNC)
Definition: main.h:399
static void test_type_cast()
std::ptrdiff_t j
The tensor class.
Definition: Tensor.h:63


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:34:07