1 from __future__
import annotations
7 np = pytest.importorskip(
"numpy")
8 eigen_tensor = pytest.importorskip(
"pybind11_tests.eigen_tensor")
9 submodules = [eigen_tensor.c_style, eigen_tensor.f_style]
11 import eigen_tensor_avoid_stl_array
as avoid
13 submodules += [avoid.c_style, avoid.f_style]
14 except ImportError
as e:
17 "import eigen_tensor_avoid_stl_array FAILED, while "
18 "import pybind11_tests.eigen_tensor succeeded. "
20 "test_eigen_tensor.cpp & "
21 "eigen_tensor_avoid_stl_array.cpp "
22 "are built together (or both are not built if Eigen is not available)."
24 raise RuntimeError(msg)
from e
26 tensor_ref = np.empty((3, 5, 2), dtype=np.int64)
28 for i
in range(tensor_ref.shape[0]):
29 for j
in range(tensor_ref.shape[1]):
30 for k
in range(tensor_ref.shape[2]):
31 tensor_ref[i, j, k] = i * (5 * 2) + j * 2 + k
36 @pytest.fixture(autouse=
True)
38 for module
in submodules:
43 for module
in submodules:
48 pytest.importorskip(
"eigen_tensor_avoid_stl_array")
49 assert len(submodules) == 4
53 assert mat.flags.writeable == writeable
55 copy = np.array(tensor_ref)
56 if modified
is not None:
57 copy[indices] = modified
59 np.testing.assert_array_equal(mat, copy)
62 @pytest.mark.parametrize(
"m", submodules)
63 @pytest.mark.parametrize(
"member_name", [
"member",
"member_view"])
65 if not hasattr(sys,
"getrefcount"):
66 pytest.skip(
"No reference counting")
67 foo = m.CustomExample()
68 counts = sys.getrefcount(foo)
71 new_counts = sys.getrefcount(foo)
72 assert new_counts == counts + 1
75 assert sys.getrefcount(foo) == counts
78 assert_equal_funcs = [
83 "move_fixed_tensor_copy",
87 "reference_tensor_v2",
88 "reference_fixed_tensor",
89 "reference_view_of_tensor",
90 "reference_view_of_tensor_v3",
91 "reference_view_of_tensor_v5",
92 "reference_view_of_fixed_tensor",
95 assert_equal_const_funcs = [
96 "reference_view_of_tensor_v2",
97 "reference_view_of_tensor_v4",
98 "reference_view_of_tensor_v6",
99 "reference_const_tensor",
100 "reference_const_tensor_v2",
104 @pytest.mark.parametrize(
"m", submodules)
105 @pytest.mark.parametrize(
"func_name", assert_equal_funcs + assert_equal_const_funcs)
107 writeable = func_name
in assert_equal_funcs
111 @pytest.mark.parametrize(
"m", submodules)
114 RuntimeError, match=
"Cannot use reference internal when there is no parent"
116 m.reference_tensor_internal()
118 with pytest.raises(RuntimeError, match=
"Cannot move from a constant reference"):
119 m.move_const_tensor()
122 RuntimeError, match=
"Cannot take ownership of a const reference"
124 m.take_const_tensor()
128 match=
"Invalid return_value_policy for Eigen Map type, must be either reference or reference_internal",
133 @pytest.mark.parametrize(
"m", submodules)
136 TypeError, match=
r"^round_trip_tensor\(\): incompatible function arguments"
138 m.round_trip_tensor(np.zeros((2, 3)))
140 with pytest.raises(TypeError, match=
r"^Cannot cast array data from dtype"):
141 m.round_trip_tensor(np.zeros(dtype=np.str_, shape=(2, 3, 1)))
145 match=
r"^round_trip_tensor_noconvert\(\): incompatible function arguments",
147 m.round_trip_tensor_noconvert(tensor_ref)
150 m.round_trip_tensor_noconvert(tensor_ref.astype(np.float64))
153 bad_options =
"C" if m.needed_options ==
"F" else "F"
156 TypeError, match=
r"^round_trip_view_tensor\(\): incompatible function arguments"
158 m.round_trip_view_tensor(
159 np.zeros((3, 5, 2), dtype=np.float64, order=bad_options)
163 TypeError, match=
r"^round_trip_view_tensor\(\): incompatible function arguments"
165 m.round_trip_view_tensor(
166 np.zeros((3, 5, 2), dtype=np.float32, order=m.needed_options)
170 TypeError, match=
r"^round_trip_view_tensor\(\): incompatible function arguments"
172 m.round_trip_view_tensor(
173 np.zeros((3, 5), dtype=np.float64, order=m.needed_options)
176 temp = np.zeros((3, 5, 2), dtype=np.float64, order=m.needed_options)
178 TypeError, match=
r"^round_trip_view_tensor\(\): incompatible function arguments"
180 m.round_trip_view_tensor(
184 temp = np.zeros((3, 5, 2), dtype=np.float64, order=m.needed_options)
185 temp.setflags(write=
False)
187 TypeError, match=
r"^round_trip_view_tensor\(\): incompatible function arguments"
189 m.round_trip_view_tensor(temp)
192 @pytest.mark.parametrize(
"m", submodules)
194 a = m.reference_tensor()
201 a = m.reference_view_of_tensor()
208 @pytest.mark.parametrize(
"m", submodules)
212 with pytest.raises(TypeError, match=
"^Cannot cast array data from"):
219 copy = np.array(tensor_ref, dtype=np.float64, order=m.needed_options)
223 copy.setflags(write=
False)
226 np.testing.assert_array_equal(
227 tensor_ref[:, ::-1, :], m.round_trip_tensor(tensor_ref[:, ::-1, :])
230 assert m.round_trip_rank_0(np.float64(3.5)) == 3.5
231 assert m.round_trip_rank_0(3.5) == 3.5
235 match=
r"^round_trip_rank_0_noconvert\(\): incompatible function arguments",
237 m.round_trip_rank_0_noconvert(np.float64(3.5))
241 match=
r"^round_trip_rank_0_noconvert\(\): incompatible function arguments",
243 m.round_trip_rank_0_noconvert(3.5)
246 TypeError, match=
r"^round_trip_rank_0_view\(\): incompatible function arguments"
248 m.round_trip_rank_0_view(np.float64(3.5))
251 TypeError, match=
r"^round_trip_rank_0_view\(\): incompatible function arguments"
253 m.round_trip_rank_0_view(3.5)
256 @pytest.mark.parametrize(
"m", submodules)
259 copy = np.array(tensor_ref, dtype=np.float64, order=m.needed_options)
260 a = m.round_trip_view_tensor(copy)
268 @pytest.mark.parametrize(
"m", submodules)
271 doc(m.copy_tensor) ==
"copy_tensor() -> numpy.ndarray[numpy.float64[?, ?, ?]]"
274 doc(m.copy_fixed_tensor)
275 ==
"copy_fixed_tensor() -> numpy.ndarray[numpy.float64[3, 5, 2]]"
278 doc(m.reference_const_tensor)
279 ==
"reference_const_tensor() -> numpy.ndarray[numpy.float64[?, ?, ?]]"
282 order_flag = f
"flags.{m.needed_options.lower()}_contiguous"
283 assert doc(m.round_trip_view_tensor) == (
284 f
"round_trip_view_tensor(arg0: numpy.ndarray[numpy.float64[?, ?, ?], flags.writeable, {order_flag}])"
285 f
" -> numpy.ndarray[numpy.float64[?, ?, ?], flags.writeable, {order_flag}]"
287 assert doc(m.round_trip_const_view_tensor) == (
288 f
"round_trip_const_view_tensor(arg0: numpy.ndarray[numpy.float64[?, ?, ?], {order_flag}])"
289 " -> numpy.ndarray[numpy.float64[?, ?, ?]]"