test_virtual_functions.py
Go to the documentation of this file.
1 from __future__ import annotations
2 
3 import pytest
4 
5 import env # noqa: F401
6 
7 m = pytest.importorskip("pybind11_tests.virtual_functions")
8 from pybind11_tests import ConstructorStats # noqa: E402
9 
10 
def test_override(capture, msg):
    """Check that Python subclasses of ExampleVirt override its C++ virtuals.

    Three objects are driven through the C++ entry points: a plain
    ``m.ExampleVirt``, a direct Python subclass, and a Python subclass of that
    subclass.  Afterwards ConstructorStats must report all three released.
    """

    class ExtendedExampleVirt(m.ExampleVirt):
        # Adds 1 to the constructor state and 1 to run()'s value before
        # delegating, so the C++ output shows that the override ran.
        def __init__(self, state):
            super().__init__(state + 1)
            self.data = "Hello world"

        def run(self, value):
            print(f"ExtendedExampleVirt::run({value}), calling parent..")
            return super().run(value + 1)

        def run_bool(self):
            print("ExtendedExampleVirt::run_bool()")
            return False

        def get_string1(self):
            return "override1"

        def pure_virtual(self):
            print(f"ExtendedExampleVirt::pure_virtual(): {self.data}")

    class ExtendedExampleVirt2(ExtendedExampleVirt):
        # Second Python level: bumps the state once more and overrides
        # get_string2 as well.
        def __init__(self, state):
            super().__init__(state + 1)

        def get_string2(self):
            return "override2"

    # --- Plain C++ instance: only the default implementations run ----------
    plain = m.ExampleVirt(10)
    with capture:
        assert m.runExampleVirt(plain, 20) == 30
    assert (
        capture
        == """
        Original implementation of ExampleVirt::run(state=10, value=20, str1=default1, str2=default2)
    """
    )

    # No Python override of the pure virtual exists here, so the call fails.
    with pytest.raises(RuntimeError) as excinfo:
        m.runExampleVirtVirtual(plain)
    expected_error = 'Tried to call pure virtual function "ExampleVirt::pure_virtual"'
    assert msg(excinfo.value) == expected_error

    # --- One Python level: state 10 -> 11, run() value 20 -> 21 -------------
    child = ExtendedExampleVirt(10)
    with capture:
        assert m.runExampleVirt(child, 20) == 32
    assert (
        capture
        == """
        ExtendedExampleVirt::run(20), calling parent..
        Original implementation of ExampleVirt::run(state=11, value=21, str1=override1, str2=default2)
    """
    )
    with capture:
        assert m.runExampleVirtBool(child) is False
    assert capture == "ExtendedExampleVirt::run_bool()"
    with capture:
        m.runExampleVirtVirtual(child)
    assert capture == "ExtendedExampleVirt::pure_virtual(): Hello world"

    # --- Two Python levels: state 15 -> 16 -> 17, run() value 50 -> 51 ------
    grandchild = ExtendedExampleVirt2(15)
    with capture:
        assert m.runExampleVirt(grandchild, 50) == 68
    assert (
        capture
        == """
        ExtendedExampleVirt::run(50), calling parent..
        Original implementation of ExampleVirt::run(state=17, value=51, str1=override1, str2=override2)
    """
    )

    # Dropping the last Python references must release every instance.
    cstats = ConstructorStats.get(m.ExampleVirt)
    assert cstats.alive() == 3
    del plain, child, grandchild
    assert cstats.alive() == 0
    assert cstats.values() == ["10", "11", "17"]
    assert cstats.copy_constructions == 0
    assert cstats.move_constructions >= 0
90 
91 
def test_alias_delay_initialization1(capture):
    """`A` builds its trampoline (PyA) only for Python subclasses.

    Creating and using a bare ``m.A`` skips the trampoline entirely (a
    performance choice) and only an ``A`` is constructed.  Subclassing ``m.A``
    from Python is what brings PyA into play.
    """

    class B(m.A):
        def __init__(self):
            super().__init__()

        def f(self):
            print("In python f()")

    # Bound C++ class used directly: no PyA output is expected.
    with capture:
        direct = m.A()
        m.call_f(direct)
        del direct
        pytest.gc_collect()
    assert capture == "A.f()"

    # Python subclass: PyA is constructed, forwards f() to Python, then dies.
    with capture:
        derived = B()
        m.call_f(derived)
        del derived
        pytest.gc_collect()
    assert (
        capture
        == """
        PyA.PyA()
        PyA.f()
        In python f()
        PyA.~PyA()
    """
    )
128  )
129 
130 
def test_alias_delay_initialization2(capture):
    """`A2` always builds its trampoline (PyA2), even without a Python subclass.

    Always constructing the alias costs a little virtual-dispatch speed, but it
    lets the trampoline do real work of its own.  Here it logs construction,
    f() and destruction.
    """

    class B2(m.A2):
        def __init__(self):
            super().__init__()

        def f(self):
            print("In python B2.f()")

    # Bound C++ class used directly, via the no-arg and the int constructor:
    # PyA2 still appears and forwards f() to the C++ A2::f().
    with capture:
        for ctor_args in ((), (1,)):
            instance = m.A2(*ctor_args)
            m.call_f(instance)
            del instance
            pytest.gc_collect()
    assert (
        capture
        == """
        PyA2.PyA2()
        PyA2.f()
        A2.f()
        PyA2.~PyA2()
        PyA2.PyA2()
        PyA2.f()
        A2.f()
        PyA2.~PyA2()
    """
    )

    # Python subclass: PyA2 forwards f() to the Python override instead.
    with capture:
        derived = B2()
        m.call_f(derived)
        del derived
        pytest.gc_collect()
    assert (
        capture
        == """
        PyA2.PyA2()
        PyA2.f()
        In python B2.f()
        PyA2.~PyA2()
    """
    )
185 
186 
# On PyPy a reference count > 1 makes the call with a non-copyable instance
# fail inside the first extension object's print_nc(); hence the xfail.
@pytest.mark.xfail("env.PYPY")
@pytest.mark.skipif(
    not hasattr(m, "NCVirt"), reason="NCVirt does not work on Intel/PGI/NVCC compilers"
)
def test_move_support():
    """Python overrides hand non-copyable and movable return values to C++.

    NCVirtExt returns a NonCopyable it does not keep and a Movable it does
    keep.  NCVirtExt2 does the opposite, and keeping the NonCopyable is
    expected to make print_nc() raise.
    """

    class NCVirtExt(m.NCVirt):
        def get_noncopyable(self, a, b):
            # A brand-new instance with no Python-side owner.
            return m.NonCopyable(a * a, b * b)

        def get_movable(self, a, b):
            # Stored on self, so the returned object is still referenced.
            self.movable = m.Movable(a, b)
            return self.movable

    class NCVirtExt2(m.NCVirt):
        def get_noncopyable(self, a, b):
            # Stored on self; the test expects print_nc() to throw because of it.
            self.nc = m.NonCopyable(a, b)
            return self.nc

        def get_movable(self, a, b):
            # A brand-new instance that is not stored anywhere.
            return m.Movable(a, b)

    fresh_owner = NCVirtExt()
    assert fresh_owner.print_nc(2, 3) == "36"
    assert fresh_owner.print_movable(4, 5) == "9"

    ref_holder = NCVirtExt2()
    assert ref_holder.print_movable(7, 7) == "14"
    # Only the exception type is checked: its text differs between debug
    # and release builds.
    with pytest.raises(RuntimeError):
        ref_holder.print_nc(9, 9)

    nc_stats = ConstructorStats.get(m.NonCopyable)
    mv_stats = ConstructorStats.get(m.Movable)
    # Presumably the survivors are ref_holder.nc and fresh_owner.movable,
    # the two objects still referenced from Python.
    assert nc_stats.alive() == 1
    assert mv_stats.alive() == 1
    del fresh_owner, ref_holder
    assert nc_stats.alive() == 0
    assert mv_stats.alive() == 0
    assert nc_stats.values() == ["4", "9", "9", "9"]
    assert mv_stats.values() == ["4", "5", "7", "7"]
    assert nc_stats.copy_constructions == 0
    assert mv_stats.copy_constructions == 1
    assert nc_stats.move_constructions >= 0
    assert mv_stats.move_constructions >= 0
236 
237 
def test_dispatch_issue(msg):
    """Issue #159: virtual dispatch must pick the right override for similarly named functions."""

    class PyClass1(m.DispatchIssue):
        def dispatch(self):
            return "Yay.."

    class PyClass2(m.DispatchIssue):
        def dispatch(self):
            # The C++ base implementation is pure virtual and must refuse to run.
            with pytest.raises(RuntimeError) as excinfo:
                super().dispatch()
            expected = 'Tried to call pure virtual function "Base::dispatch"'
            assert msg(excinfo.value) == expected

            # Re-enter C++ dispatch from inside an override, using a
            # different Python subclass.
            return m.dispatch_issue_go(PyClass1())

    outer = PyClass2()
    assert m.dispatch_issue_go(outer) == "Yay.."
258 
259 
def test_recursive_dispatch_issue():
    """Issue #3357: recursive dispatch must still locate the Python override."""

    class Data(m.Data):
        def __init__(self, value):
            super().__init__()
            self.value = value

    class Adder(m.Adder):
        def __call__(self, first, second, visitor):
            # The immediately-invoked lambda is a workaround: it adds an
            # extra frame to the current CPython thread.  Calling visitor
            # directly reveals the bug
            # [https://github.com/pybind/pybind11/issues/3357]
            (lambda: visitor(Data(first.value + second.value)))()  # noqa: PLC3002

    class StoreResultVisitor:
        def __init__(self):
            self.result = None

        def __call__(self, data):
            self.result = data.value

    collector = StoreResultVisitor()

    m.add2(Data(1), Data(2), Adder(), collector)
    assert collector.result == 3

    # Without the lambda in Adder.__call__, this call fails with
    # RuntimeError: Tried to call pure virtual function "AdderBase::__call__"
    m.add3(Data(1), Data(2), Data(3), Adder(), collector)
    assert collector.result == 6
291 
292 
def test_override_ref():
    """Issues #392/#397: overriding functions that return by reference."""
    obj = m.OverrideTest("asdf")

    # str_ref() is deliberately not called: the associated .cpp file explains
    # why that usage is not allowed.  Only the by-value variant is checked.
    assert obj.str_value() == "asdf"

    assert obj.A_value().value == "hi"
    ref = obj.A_ref()
    assert ref.value == "hi"
    ref.value = "bye"
    assert ref.value == "bye"
307 
308 
def test_inherited_virtuals():
    """Python overrides through the `*_Repeat` and `*_Tpl` trampoline hierarchies.

    Each object is probed via say_something(3), unlucky_number(),
    lucky_number() (for every class except the A-level ones) and
    say_everything().  Every expected say_everything() value below follows
    ``say_something(1) + " " + str(unlucky_number())``.
    """

    def check(obj, something, unlucky, everything, lucky=None):
        """Assert obj's virtual results; lucky_number() is skipped when lucky is None."""
        assert obj.say_something(3) == something
        assert obj.unlucky_number() == unlucky
        if lucky is not None:
            assert obj.lucky_number() == lucky
        assert obj.say_everything() == everything

    # Level A: only unlucky_number() is overridden from Python.
    class AR(m.A_Repeat):
        def unlucky_number(self):
            return 99

    class AT(m.A_Tpl):
        def unlucky_number(self):
            return 999

    check(AR(), "hihihi", 99, "hi 99")
    check(AT(), "hihihi", 999, "hi 999")

    # Levels B and C with their C++ defaults (no Python subclass).
    for obj in [m.B_Repeat(), m.B_Tpl()]:
        check(obj, "B says hi 3 times", 13, "B says hi 1 times 13", lucky=7.0)

    for obj in [m.C_Repeat(), m.C_Tpl()]:
        check(obj, "B says hi 3 times", 4444, "B says hi 1 times 4444", lucky=888.0)

    # Python subclasses of level C, some calling up into their parent.
    class CR(m.C_Repeat):
        def lucky_number(self):
            return m.C_Repeat.lucky_number(self) + 1.25

    class CT(m.C_Tpl):
        pass

    class CCR(CR):
        def lucky_number(self):
            return CR.lucky_number(self) * 10

    class CCT(CT):
        def lucky_number(self):
            return CT.lucky_number(self) * 1000

    c_everything = "B says hi 1 times 4444"
    check(CR(), "B says hi 3 times", 4444, c_everything, lucky=889.25)
    check(CT(), "B says hi 3 times", 4444, c_everything, lucky=888.0)
    check(CCR(), "B says hi 3 times", 4444, c_everything, lucky=8892.5)
    check(CCT(), "B says hi 3 times", 4444, c_everything, lucky=888000.0)

    # Level D: C++ defaults first, then Python overrides.
    class DR(m.D_Repeat):
        def unlucky_number(self):
            return 123

        def lucky_number(self):
            return 42.0

    for obj in [m.D_Repeat(), m.D_Tpl()]:
        check(obj, "B says hi 3 times", 4444, "B says hi 1 times 4444", lucky=888.0)

    check(DR(), "B says hi 3 times", 123, "B says hi 1 times 123", lucky=42.0)

    class DT(m.D_Tpl):
        def say_something(self, times):
            return "DT says:" + (" quack" * times)

        def unlucky_number(self):
            return 1234

        def lucky_number(self):
            return -4.25

    check(DT(), "DT says: quack quack quack", 1234, "DT says: quack 1234", lucky=-4.25)

    # Two Python levels on a templated trampoline: DT2 overrides two of DT's
    # methods and inherits lucky_number() from DT.
    class DT2(DT):
        def say_something(self, times):
            return "DT2: " + ("QUACK" * times)

        def unlucky_number(self):
            return -3

    check(DT2(), "DT2: QUACKQUACKQUACK", -3, "DT2: QUACK -3", lucky=-4.25)

    # B_Tpl with every method except say_everything() overridden from Python.
    class BT(m.B_Tpl):
        def say_something(self, times):
            return "BT" * times

        def unlucky_number(self):
            return -7

        def lucky_number(self):
            return -1.375

    check(BT(), "BTBTBT", -7, "BT -7", lucky=-1.375)
436 
437 
def test_issue_1454():
    """Regression test for #1454: acquiring/releasing the GIL on another thread.

    The original crash was reported under Python 2.7.
    """
    m.test_gil()
    m.test_gil_from_thread()
442 
443 
def test_python_override():
    """Dispatch repeatedly through freshly created Python subclasses.

    Every helper call defines a brand-new subclass of
    ``m.test_override_cache_helper``.  One variant overrides ``func`` and is
    expected to yield 42; the other does not and is expected to yield 0.
    Repeating this 1500 times stresses whatever caching
    ``m.test_override_cache`` depends on (presumably an override lookup
    cache keyed by type -- confirm against the C++ side).
    """

    def make_overriding():
        class Test(m.test_override_cache_helper):
            def func(self):
                return 42

        return Test()

    def make_plain():
        class Test(m.test_override_cache_helper):
            pass

        return Test()

    for _iteration in range(1500):
        assert m.test_override_cache(make_overriding()) == 42
        assert m.test_override_cache(make_plain()) == 0
Eigen::internal::print
EIGEN_STRONG_INLINE Packet4f print(const Packet4f &a)
Definition: NEON/PacketMath.h:3115
Adder
Definition: test_virtual_functions.cpp:183
test_virtual_functions.test_dispatch_issue
def test_dispatch_issue(msg)
Definition: test_virtual_functions.py:238
test_virtual_functions.test_alias_delay_initialization2
def test_alias_delay_initialization2(capture)
Definition: test_virtual_functions.py:131
test_virtual_functions.test_python_override
def test_python_override()
Definition: test_virtual_functions.py:444
B
Definition: test_numpy_dtypes.cpp:299
hasattr
bool hasattr(handle obj, handle name)
Definition: pytypes.h:870
test_virtual_functions.test_inherited_virtuals
def test_inherited_virtuals()
Definition: test_virtual_functions.py:309
test_trampoline.func2
def func2()
Definition: test_trampoline.py:14
test_virtual_functions.test_override
def test_override(capture, msg)
Definition: test_virtual_functions.py:11
gtsam::range
Double_ range(const Point2_ &p, const Point2_ &q)
Definition: slam/expressions.h:30
test_virtual_functions.test_move_support
def test_move_support()
Definition: test_virtual_functions.py:193
test_virtual_functions.test_alias_delay_initialization1
def test_alias_delay_initialization1(capture)
Definition: test_virtual_functions.py:92
DT
Definition: testDecisionTree.cpp:96
test_virtual_functions.test_issue_1454
def test_issue_1454()
Definition: test_virtual_functions.py:438
func
int func(const int &a)
Definition: testDSF.cpp:221
gtsam.examples.DogLegOptimizerExample.run
def run(args)
Definition: DogLegOptimizerExample.py:21
gtwrap.interface_parser.function.__init__
def __init__(self, Union[Type, TemplatedType] ctype, str name, ParseResults default=None)
Definition: interface_parser/function.py:41
test_virtual_functions.test_override_ref
def test_override_ref()
Definition: test_virtual_functions.py:293
tree::f
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Definition: testExpression.cpp:218
ConstructorStats::get
static ConstructorStats & get(std::type_index type)
Definition: constructor_stats.h:163
Test
Definition: Test.h:30
func
Definition: benchGeometry.cpp:23
test_virtual_functions.test_recursive_dispatch_issue
def test_recursive_dispatch_issue()
Definition: test_virtual_functions.py:260
pybind11.msg
msg
Definition: wrap/pybind11/pybind11/__init__.py:6


gtsam
Author(s):
autogenerated on Thu Dec 19 2024 04:06:00