5 np = pytest.importorskip(
"numpy")
6 eigen_tensor = pytest.importorskip(
"pybind11_tests.eigen_tensor")
7 submodules = [eigen_tensor.c_style, eigen_tensor.f_style]
9 import eigen_tensor_avoid_stl_array
as avoid
11 submodules += [avoid.c_style, avoid.f_style]
12 except ImportError
as e:
15 "import eigen_tensor_avoid_stl_array FAILED, while "
16 "import pybind11_tests.eigen_tensor succeeded. "
18 "test_eigen_tensor.cpp & "
19 "eigen_tensor_avoid_stl_array.cpp "
20 "are built together (or both are not built if Eigen is not available)."
22 raise RuntimeError(msg)
from e
# Reference tensor shared by the tests in this module.
#
# Shape is (3, 5, 2) with dtype int64, and element [i, j, k] holds
# i * (5 * 2) + j * 2 + k -- i.e. the element's own row-major (C-order)
# flat index.  np.arange(...).reshape(...) enumerates elements in exactly
# that order, so a single vectorised assignment replaces the explicit
# triple loop over i, j, k.
#
# The array is still allocated with np.empty and filled in place so that
# tensor_ref owns its data (it is not a view onto a temporary arange buffer).
tensor_ref = np.empty((3, 5, 2), dtype=np.int64)
tensor_ref[...] = np.arange(tensor_ref.size).reshape(tensor_ref.shape)
34 @pytest.fixture(autouse=
True)
36 for module
in submodules:
41 for module
in submodules:
46 pytest.importorskip(
"eigen_tensor_avoid_stl_array")
47 assert len(submodules) == 4
51 assert mat.flags.writeable == writeable
53 copy = np.array(tensor_ref)
54 if modified
is not None:
55 copy[indices] = modified
57 np.testing.assert_array_equal(mat, copy)
60 @pytest.mark.parametrize(
"m", submodules)
61 @pytest.mark.parametrize(
"member_name", [
"member",
"member_view"])
63 if not hasattr(sys,
"getrefcount"):
64 pytest.skip(
"No reference counting")
65 foo = m.CustomExample()
66 counts = sys.getrefcount(foo)
69 new_counts = sys.getrefcount(foo)
70 assert new_counts == counts + 1
73 assert sys.getrefcount(foo) == counts
76 assert_equal_funcs = [
81 "move_fixed_tensor_copy",
85 "reference_tensor_v2",
86 "reference_fixed_tensor",
87 "reference_view_of_tensor",
88 "reference_view_of_tensor_v3",
89 "reference_view_of_tensor_v5",
90 "reference_view_of_fixed_tensor",
93 assert_equal_const_funcs = [
94 "reference_view_of_tensor_v2",
95 "reference_view_of_tensor_v4",
96 "reference_view_of_tensor_v6",
97 "reference_const_tensor",
98 "reference_const_tensor_v2",
102 @pytest.mark.parametrize(
"m", submodules)
103 @pytest.mark.parametrize(
"func_name", assert_equal_funcs + assert_equal_const_funcs)
105 writeable = func_name
in assert_equal_funcs
109 @pytest.mark.parametrize(
"m", submodules)
112 RuntimeError, match=
"Cannot use reference internal when there is no parent"
114 m.reference_tensor_internal()
116 with pytest.raises(RuntimeError, match=
"Cannot move from a constant reference"):
117 m.move_const_tensor()
120 RuntimeError, match=
"Cannot take ownership of a const reference"
122 m.take_const_tensor()
126 match=
"Invalid return_value_policy for Eigen Map type, must be either reference or reference_internal",
131 @pytest.mark.parametrize(
"m", submodules)
134 TypeError, match=
r"^round_trip_tensor\(\): incompatible function arguments"
136 m.round_trip_tensor(np.zeros((2, 3)))
138 with pytest.raises(TypeError, match=
r"^Cannot cast array data from dtype"):
139 m.round_trip_tensor(np.zeros(dtype=np.str_, shape=(2, 3, 1)))
143 match=
r"^round_trip_tensor_noconvert\(\): incompatible function arguments",
145 m.round_trip_tensor_noconvert(tensor_ref)
148 m.round_trip_tensor_noconvert(tensor_ref.astype(np.float64))
151 bad_options =
"C" if m.needed_options ==
"F" else "F"
154 TypeError, match=
r"^round_trip_view_tensor\(\): incompatible function arguments"
156 m.round_trip_view_tensor(
157 np.zeros((3, 5, 2), dtype=np.float64, order=bad_options)
161 TypeError, match=
r"^round_trip_view_tensor\(\): incompatible function arguments"
163 m.round_trip_view_tensor(
164 np.zeros((3, 5, 2), dtype=np.float32, order=m.needed_options)
168 TypeError, match=
r"^round_trip_view_tensor\(\): incompatible function arguments"
170 m.round_trip_view_tensor(
171 np.zeros((3, 5), dtype=np.float64, order=m.needed_options)
174 temp = np.zeros((3, 5, 2), dtype=np.float64, order=m.needed_options)
176 TypeError, match=
r"^round_trip_view_tensor\(\): incompatible function arguments"
178 m.round_trip_view_tensor(
182 temp = np.zeros((3, 5, 2), dtype=np.float64, order=m.needed_options)
183 temp.setflags(write=
False)
185 TypeError, match=
r"^round_trip_view_tensor\(\): incompatible function arguments"
187 m.round_trip_view_tensor(temp)
190 @pytest.mark.parametrize(
"m", submodules)
192 a = m.reference_tensor()
199 a = m.reference_view_of_tensor()
206 @pytest.mark.parametrize(
"m", submodules)
210 with pytest.raises(TypeError, match=
"^Cannot cast array data from"):
217 copy = np.array(tensor_ref, dtype=np.float64, order=m.needed_options)
221 copy.setflags(write=
False)
224 np.testing.assert_array_equal(
225 tensor_ref[:, ::-1, :], m.round_trip_tensor(tensor_ref[:, ::-1, :])
228 assert m.round_trip_rank_0(np.float64(3.5)) == 3.5
229 assert m.round_trip_rank_0(3.5) == 3.5
233 match=
r"^round_trip_rank_0_noconvert\(\): incompatible function arguments",
235 m.round_trip_rank_0_noconvert(np.float64(3.5))
239 match=
r"^round_trip_rank_0_noconvert\(\): incompatible function arguments",
241 m.round_trip_rank_0_noconvert(3.5)
244 TypeError, match=
r"^round_trip_rank_0_view\(\): incompatible function arguments"
246 m.round_trip_rank_0_view(np.float64(3.5))
249 TypeError, match=
r"^round_trip_rank_0_view\(\): incompatible function arguments"
251 m.round_trip_rank_0_view(3.5)
254 @pytest.mark.parametrize(
"m", submodules)
257 copy = np.array(tensor_ref, dtype=np.float64, order=m.needed_options)
258 a = m.round_trip_view_tensor(copy)
266 @pytest.mark.parametrize(
"m", submodules)
269 doc(m.copy_tensor) ==
"copy_tensor() -> numpy.ndarray[numpy.float64[?, ?, ?]]"
272 doc(m.copy_fixed_tensor)
273 ==
"copy_fixed_tensor() -> numpy.ndarray[numpy.float64[3, 5, 2]]"
276 doc(m.reference_const_tensor)
277 ==
"reference_const_tensor() -> numpy.ndarray[numpy.float64[?, ?, ?]]"
280 order_flag = f
"flags.{m.needed_options.lower()}_contiguous"
281 assert doc(m.round_trip_view_tensor) == (
282 f
"round_trip_view_tensor(arg0: numpy.ndarray[numpy.float64[?, ?, ?], flags.writeable, {order_flag}])"
283 f
" -> numpy.ndarray[numpy.float64[?, ?, ?], flags.writeable, {order_flag}]"
285 assert doc(m.round_trip_const_view_tensor) == (
286 f
"round_trip_const_view_tensor(arg0: numpy.ndarray[numpy.float64[?, ?, ?], {order_flag}])"
287 " -> numpy.ndarray[numpy.float64[?, ?, ?]]"