gtsam
3rdparty
Eigen
unsupported
test
cxx11_tensor_patch.cpp
Go to the documentation of this file.
1
// This file is part of Eigen, a lightweight C++ template library
2
// for linear algebra.
3
//
4
// Copyright (C) 2014 Benoit Steiner <benoit.steiner.goog@gmail.com>
5
//
6
// This Source Code Form is subject to the terms of the Mozilla
7
// Public License v. 2.0. If a copy of the MPL was not distributed
8
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10
#include "
main.h
"
11
12
#include <Eigen/CXX11/Tensor>
13
14
using
Eigen::Tensor
;
15
16
template
<
int
DataLayout>
17
static
void
test_simple_patch
()
18
{
19
Tensor<float, 4, DataLayout>
tensor(2,3,5,7);
20
tensor.
setRandom
();
21
array<ptrdiff_t, 4>
patch_dims;
22
23
patch_dims[0] = 1;
24
patch_dims[1] = 1;
25
patch_dims[2] = 1;
26
patch_dims[3] = 1;
27
28
Tensor<float, 5, DataLayout>
no_patch;
29
no_patch = tensor.extract_patches(patch_dims);
30
31
if
(
DataLayout
==
ColMajor
) {
32
VERIFY_IS_EQUAL
(no_patch.
dimension
(0), 1);
33
VERIFY_IS_EQUAL
(no_patch.
dimension
(1), 1);
34
VERIFY_IS_EQUAL
(no_patch.
dimension
(2), 1);
35
VERIFY_IS_EQUAL
(no_patch.
dimension
(3), 1);
36
VERIFY_IS_EQUAL
(no_patch.
dimension
(4), tensor.
size
());
37
}
else
{
38
VERIFY_IS_EQUAL
(no_patch.
dimension
(0), tensor.
size
());
39
VERIFY_IS_EQUAL
(no_patch.
dimension
(1), 1);
40
VERIFY_IS_EQUAL
(no_patch.
dimension
(2), 1);
41
VERIFY_IS_EQUAL
(no_patch.
dimension
(3), 1);
42
VERIFY_IS_EQUAL
(no_patch.
dimension
(4), 1);
43
}
44
45
for
(
int
i
= 0;
i
< tensor.
size
(); ++
i
) {
46
VERIFY_IS_EQUAL
(tensor.
data
()[
i
], no_patch.
data
()[
i
]);
47
}
48
49
patch_dims[0] = 2;
50
patch_dims[1] = 3;
51
patch_dims[2] = 5;
52
patch_dims[3] = 7;
53
Tensor<float, 5, DataLayout>
single_patch;
54
single_patch = tensor.extract_patches(patch_dims);
55
56
if
(
DataLayout
==
ColMajor
) {
57
VERIFY_IS_EQUAL
(single_patch.
dimension
(0), 2);
58
VERIFY_IS_EQUAL
(single_patch.
dimension
(1), 3);
59
VERIFY_IS_EQUAL
(single_patch.
dimension
(2), 5);
60
VERIFY_IS_EQUAL
(single_patch.
dimension
(3), 7);
61
VERIFY_IS_EQUAL
(single_patch.
dimension
(4), 1);
62
}
else
{
63
VERIFY_IS_EQUAL
(single_patch.
dimension
(0), 1);
64
VERIFY_IS_EQUAL
(single_patch.
dimension
(1), 2);
65
VERIFY_IS_EQUAL
(single_patch.
dimension
(2), 3);
66
VERIFY_IS_EQUAL
(single_patch.
dimension
(3), 5);
67
VERIFY_IS_EQUAL
(single_patch.
dimension
(4), 7);
68
}
69
70
for
(
int
i
= 0;
i
< tensor.
size
(); ++
i
) {
71
VERIFY_IS_EQUAL
(tensor.
data
()[
i
], single_patch.
data
()[
i
]);
72
}
73
74
patch_dims[0] = 1;
75
patch_dims[1] = 2;
76
patch_dims[2] = 2;
77
patch_dims[3] = 1;
78
Tensor<float, 5, DataLayout>
twod_patch;
79
twod_patch = tensor.extract_patches(patch_dims);
80
81
if
(
DataLayout
==
ColMajor
) {
82
VERIFY_IS_EQUAL
(twod_patch.
dimension
(0), 1);
83
VERIFY_IS_EQUAL
(twod_patch.
dimension
(1), 2);
84
VERIFY_IS_EQUAL
(twod_patch.
dimension
(2), 2);
85
VERIFY_IS_EQUAL
(twod_patch.
dimension
(3), 1);
86
VERIFY_IS_EQUAL
(twod_patch.
dimension
(4), 2*2*4*7);
87
}
else
{
88
VERIFY_IS_EQUAL
(twod_patch.
dimension
(0), 2*2*4*7);
89
VERIFY_IS_EQUAL
(twod_patch.
dimension
(1), 1);
90
VERIFY_IS_EQUAL
(twod_patch.
dimension
(2), 2);
91
VERIFY_IS_EQUAL
(twod_patch.
dimension
(3), 2);
92
VERIFY_IS_EQUAL
(twod_patch.
dimension
(4), 1);
93
}
94
95
for
(
int
i
= 0;
i
< 2; ++
i
) {
96
for
(
int
j
= 0;
j
< 2; ++
j
) {
97
for
(
int
k = 0; k < 4; ++k) {
98
for
(
int
l
= 0;
l
< 7; ++
l
) {
99
int
patch_loc;
100
if
(
DataLayout
==
ColMajor
) {
101
patch_loc =
i
+ 2 * (
j
+ 2 * (k + 4 *
l
));
102
}
else
{
103
patch_loc =
l
+ 7 * (k + 4 * (
j
+ 2 *
i
));
104
}
105
for
(
int
x
= 0;
x
< 2; ++
x
) {
106
for
(
int
y
= 0;
y
< 2; ++
y
) {
107
if
(
DataLayout
==
ColMajor
) {
108
VERIFY_IS_EQUAL
(tensor(
i
,
j
+
x
,k+
y
,
l
), twod_patch(0,
x
,
y
,0,patch_loc));
109
}
else
{
110
VERIFY_IS_EQUAL
(tensor(
i
,
j
+
x
,k+
y
,
l
), twod_patch(patch_loc,0,
x
,
y
,0));
111
}
112
}
113
}
114
}
115
}
116
}
117
}
118
119
patch_dims[0] = 1;
120
patch_dims[1] = 2;
121
patch_dims[2] = 3;
122
patch_dims[3] = 5;
123
Tensor<float, 5, DataLayout>
threed_patch;
124
threed_patch = tensor.extract_patches(patch_dims);
125
126
if
(
DataLayout
==
ColMajor
) {
127
VERIFY_IS_EQUAL
(threed_patch.
dimension
(0), 1);
128
VERIFY_IS_EQUAL
(threed_patch.
dimension
(1), 2);
129
VERIFY_IS_EQUAL
(threed_patch.
dimension
(2), 3);
130
VERIFY_IS_EQUAL
(threed_patch.
dimension
(3), 5);
131
VERIFY_IS_EQUAL
(threed_patch.
dimension
(4), 2*2*3*3);
132
}
else
{
133
VERIFY_IS_EQUAL
(threed_patch.
dimension
(0), 2*2*3*3);
134
VERIFY_IS_EQUAL
(threed_patch.
dimension
(1), 1);
135
VERIFY_IS_EQUAL
(threed_patch.
dimension
(2), 2);
136
VERIFY_IS_EQUAL
(threed_patch.
dimension
(3), 3);
137
VERIFY_IS_EQUAL
(threed_patch.
dimension
(4), 5);
138
}
139
140
for
(
int
i
= 0;
i
< 2; ++
i
) {
141
for
(
int
j
= 0;
j
< 2; ++
j
) {
142
for
(
int
k = 0; k < 3; ++k) {
143
for
(
int
l
= 0;
l
< 3; ++
l
) {
144
int
patch_loc;
145
if
(
DataLayout
==
ColMajor
) {
146
patch_loc =
i
+ 2 * (
j
+ 2 * (k + 3 *
l
));
147
}
else
{
148
patch_loc =
l
+ 3 * (k + 3 * (
j
+ 2 *
i
));
149
}
150
for
(
int
x
= 0;
x
< 2; ++
x
) {
151
for
(
int
y
= 0;
y
< 3; ++
y
) {
152
for
(
int
z
= 0;
z
< 5; ++
z
) {
153
if
(
DataLayout
==
ColMajor
) {
154
VERIFY_IS_EQUAL
(tensor(
i
,
j
+
x
,k+
y
,
l
+
z
), threed_patch(0,
x
,
y
,
z
,patch_loc));
155
}
else
{
156
VERIFY_IS_EQUAL
(tensor(
i
,
j
+
x
,k+
y
,
l
+
z
), threed_patch(patch_loc,0,
x
,
y
,
z
));
157
}
158
}
159
}
160
}
161
}
162
}
163
}
164
}
165
}
166
167
// Test entry point: runs the patch-extraction checks for both storage layouts.
EIGEN_DECLARE_TEST(cxx11_tensor_patch)
{
  CALL_SUBTEST(test_simple_patch<ColMajor>());
  CALL_SUBTEST(test_simple_patch<RowMajor>());
}
Eigen::Tensor::dimension
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index dimension(std::size_t n) const
Definition:
Tensor.h:101
Eigen::Tensor
The tensor class.
Definition:
Tensor.h:63
Eigen::array
Definition:
EmulateArray.h:21
VERIFY_IS_EQUAL
#define VERIFY_IS_EQUAL(a, b)
Definition:
main.h:386
DataLayout
static const int DataLayout
Definition:
cxx11_tensor_image_patch_sycl.cpp:24
x
set noclip points set clip one set noclip two set bar set border lt lw set xdata set ydata set zdata set x2data set y2data set boxwidth set dummy x
Definition:
gnuplot_common_settings.hh:12
j
std::ptrdiff_t j
Definition:
tut_arithmetic_redux_minmax.cpp:2
l
static const Line3 l(Rot3(), 1, 1)
pybind_wrapper_test_script.z
z
Definition:
pybind_wrapper_test_script.py:61
y
Scalar * y
Definition:
level1_cplx_impl.h:124
Eigen::TensorBase< Tensor< Scalar_, NumIndices_, Options_, IndexType_ > >::setRandom
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Tensor< Scalar_, NumIndices_, Options_, IndexType_ > & setRandom()
Definition:
TensorBase.h:996
EIGEN_DECLARE_TEST
EIGEN_DECLARE_TEST(cxx11_tensor_patch)
Definition:
cxx11_tensor_patch.cpp:167
main.h
Eigen::Tensor::data
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar * data()
Definition:
Tensor.h:104
Eigen::Tensor::size
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index size() const
Definition:
Tensor.h:103
Eigen::ColMajor
@ ColMajor
Definition:
Constants.h:319
test_simple_patch
static void test_simple_patch()
Definition:
cxx11_tensor_patch.cpp:17
i
int i
Definition:
BiCGSTAB_step_by_step.cpp:9
CALL_SUBTEST
#define CALL_SUBTEST(FUNC)
Definition:
main.h:399
gtsam
Author(s):
autogenerated on Wed Jan 1 2025 04:01:24