test_virtual_functions.py
Go to the documentation of this file.
from __future__ import annotations

import sys

import pytest

# Imported for the string condition "env.PYPY" used by @pytest.mark.xfail in
# this module; it is otherwise unreferenced, hence the noqa.
import env  # noqa: F401

# Skip this whole test module if the compiled extension submodule is missing.
m = pytest.importorskip("pybind11_tests.virtual_functions")
from pybind11_tests import ConstructorStats  # noqa: E402
11 
12 
def test_override(capture, msg):
    """Python subclasses can override virtual methods of ``m.ExampleVirt``.

    Covers calling the C++ base implementation through ``super()``, overriding
    a method the base leaves pure virtual, two levels of Python inheritance,
    and instance bookkeeping through ``ConstructorStats``.
    """

    class ExtendedExampleVirt(m.ExampleVirt):
        def __init__(self, state):
            super().__init__(state + 1)
            self.data = "Hello world"

        def run(self, value):
            print(f"ExtendedExampleVirt::run({value}), calling parent..")
            return super().run(value + 1)

        def run_bool(self):
            print("ExtendedExampleVirt::run_bool()")
            return False

        def get_string1(self):
            return "override1"

        def pure_virtual(self):
            print(f"ExtendedExampleVirt::pure_virtual(): {self.data}")

    class ExtendedExampleVirt2(ExtendedExampleVirt):
        def __init__(self, state):
            super().__init__(state + 1)

        def get_string2(self):
            return "override2"

    # Plain C++ instance: no Python override is involved.
    ex12 = m.ExampleVirt(10)
    with capture:
        assert m.runExampleVirt(ex12, 20) == 30
    assert (
        capture
        == """
        Original implementation of ExampleVirt::run(state=10, value=20, str1=default1, str2=default2)
    """
    )

    # The base class has no implementation of pure_virtual, so calling it raises.
    with pytest.raises(RuntimeError) as excinfo:
        m.runExampleVirtVirtual(ex12)
    assert (
        msg(excinfo.value)
        == 'Tried to call pure virtual function "ExampleVirt::pure_virtual"'
    )

    # __init__ bumps state 10 -> 11; the run() override bumps value 20 -> 21.
    ex12p = ExtendedExampleVirt(10)
    with capture:
        assert m.runExampleVirt(ex12p, 20) == 32
    assert (
        capture
        == """
        ExtendedExampleVirt::run(20), calling parent..
        Original implementation of ExampleVirt::run(state=11, value=21, str1=override1, str2=default2)
    """
    )
    with capture:
        assert m.runExampleVirtBool(ex12p) is False
    assert capture == "ExtendedExampleVirt::run_bool()"
    with capture:
        m.runExampleVirtVirtual(ex12p)
    assert capture == "ExtendedExampleVirt::pure_virtual(): Hello world"

    # Two Python __init__ levels bump state 15 -> 16 -> 17.
    ex12p2 = ExtendedExampleVirt2(15)
    with capture:
        assert m.runExampleVirt(ex12p2, 50) == 68
    assert (
        capture
        == """
        ExtendedExampleVirt::run(50), calling parent..
        Original implementation of ExampleVirt::run(state=17, value=51, str1=override1, str2=override2)
    """
    )

    cstats = ConstructorStats.get(m.ExampleVirt)
    assert cstats.alive() == 3
    del ex12, ex12p, ex12p2
    assert cstats.alive() == 0
    assert cstats.values() == ["10", "11", "17"]
    assert cstats.copy_constructions == 0
    # Moves are allowed but not required, so only a lower bound is asserted.
    assert cstats.move_constructions >= 0
92 
93 
def test_alias_delay_initialization1(capture):
    """`A` only initializes its trampoline class when we inherit from it.

    If we just create and use an A instance directly, the trampoline initialization is
    bypassed and we only initialize an A() instead (for performance reasons).
    """

    class B(m.A):
        def __init__(self):
            super().__init__()

        def f(self):
            print("In python f()")

    # C++ version: no PyA constructor/destructor output is expected.
    with capture:
        a = m.A()
        m.call_f(a)
        del a
        # Force collection so any destructor output lands inside the capture.
        pytest.gc_collect()
    assert capture == "A.f()"

    # Python version: the PyA trampoline is constructed and destroyed.
    with capture:
        b = B()
        m.call_f(b)
        del b
        pytest.gc_collect()
    assert (
        capture
        == """
        PyA.PyA()
        PyA.f()
        In python f()
        PyA.~PyA()
    """
    )
131 
132 
def test_alias_delay_initialization2(capture):
    """`A2`, unlike `A`, is configured to always initialize the alias.

    While the extra initialization and extra class layer have a small virtual dispatch
    performance penalty, they also allow us to do more things with the trampoline
    class such as defining local variables and performing construction/destruction.
    """

    class B2(m.A2):
        def __init__(self):
            super().__init__()

        def f(self):
            print("In python B2.f()")

    # No python subclass version: PyA2 is still constructed for both constructors.
    with capture:
        a2 = m.A2()
        m.call_f(a2)
        del a2
        # Force collection so the destructor output lands inside the capture.
        pytest.gc_collect()
        a3 = m.A2(1)
        m.call_f(a3)
        del a3
        pytest.gc_collect()
    assert (
        capture
        == """
        PyA2.PyA2()
        PyA2.f()
        A2.f()
        PyA2.~PyA2()
        PyA2.PyA2()
        PyA2.f()
        A2.f()
        PyA2.~PyA2()
    """
    )

    # Python subclass version: the override replaces A2.f().
    with capture:
        b2 = B2()
        m.call_f(b2)
        del b2
        pytest.gc_collect()
    assert (
        capture
        == """
        PyA2.PyA2()
        PyA2.f()
        In python B2.f()
        PyA2.~PyA2()
    """
    )
187 
188 
# PyPy: Reference count > 1 causes call with noncopyable instance
# to fail in ncv1.print_nc()
@pytest.mark.xfail("env.PYPY")
@pytest.mark.skipif(
    not hasattr(m, "NCVirt"), reason="NCVirt does not work on Intel/PGI/NVCC compilers"
)
def test_move_support():
    """Python overrides may return NonCopyable / Movable objects to C++.

    Returning a NonCopyable that Python still references must raise, and the
    instance counts must show exactly which objects were copied.
    """

    class NCVirtExt(m.NCVirt):
        def get_noncopyable(self, a, b):
            # Constructs and returns a new instance:
            return m.NonCopyable(a * a, b * b)

        def get_movable(self, a, b):
            # Return a referenced copy
            self.movable = m.Movable(a, b)
            return self.movable

    class NCVirtExt2(m.NCVirt):
        def get_noncopyable(self, a, b):
            # Keep a reference: this is going to throw an exception
            self.nc = m.NonCopyable(a, b)
            return self.nc

        def get_movable(self, a, b):
            # Return a new instance without storing it
            return m.Movable(a, b)

    ncv1 = NCVirtExt()
    assert ncv1.print_nc(2, 3) == "36"
    assert ncv1.print_movable(4, 5) == "9"
    ncv2 = NCVirtExt2()
    assert ncv2.print_movable(7, 7) == "14"
    # Don't check the exception message here because it differs under debug/non-debug mode
    with pytest.raises(RuntimeError):
        ncv2.print_nc(9, 9)

    nc_stats = ConstructorStats.get(m.NonCopyable)
    mv_stats = ConstructorStats.get(m.Movable)
    # The survivors are the objects still held by ncv2.nc and ncv1.movable.
    assert nc_stats.alive() == 1
    assert mv_stats.alive() == 1
    del ncv1, ncv2
    assert nc_stats.alive() == 0
    assert mv_stats.alive() == 0
    assert nc_stats.values() == ["4", "9", "9", "9"]
    assert mv_stats.values() == ["4", "5", "7", "7"]
    assert nc_stats.copy_constructions == 0
    # Presumably the single copy comes from NCVirtExt.get_movable, which keeps
    # a reference to the object it returns.
    assert mv_stats.copy_constructions == 1
    assert nc_stats.move_constructions >= 0
    assert mv_stats.move_constructions >= 0
238 
239 
def test_dispatch_issue(msg):
    """#159: virtual function dispatch has problems with similar-named functions.

    PyClass2.dispatch first checks that calling the pure-virtual base raises,
    then dispatches on a *different* Python subclass from inside its own
    override; the outer call must return that inner result.
    """

    class PyClass1(m.DispatchIssue):
        def dispatch(self):
            return "Yay.."

    class PyClass2(m.DispatchIssue):
        def dispatch(self):
            with pytest.raises(RuntimeError) as excinfo:
                super().dispatch()
            assert (
                msg(excinfo.value)
                == 'Tried to call pure virtual function "Base::dispatch"'
            )

            return m.dispatch_issue_go(PyClass1())

    b = PyClass2()
    assert m.dispatch_issue_go(b) == "Yay.."
260 
261 
def test_recursive_dispatch_issue():
    """#3357: Recursive dispatch fails to find python function override."""

    class Data(m.Data):
        def __init__(self, value):
            super().__init__()
            self.value = value

    class Adder(m.Adder):
        def __call__(self, first, second, visitor):
            # lambda is a workaround, which adds extra frame to the
            # current CPython thread. Removing lambda reveals the bug
            # [https://github.com/pybind/pybind11/issues/3357]
            (lambda: visitor(Data(first.value + second.value)))()  # noqa: PLC3002

    class StoreResultVisitor:
        def __init__(self):
            self.result = None

        def __call__(self, data):
            self.result = data.value

    store = StoreResultVisitor()

    m.add2(Data(1), Data(2), Adder(), store)
    assert store.result == 3

    # without lambda in Adder class, this function fails with
    # RuntimeError: Tried to call pure virtual function "AdderBase::__call__"
    m.add3(Data(1), Data(2), Data(3), Adder(), store)
    assert store.result == 6
293 
294 
def test_override_ref():
    """#392/397: overriding reference-returning functions."""
    o = m.OverrideTest("asdf")

    # o.str_ref() is deliberately not called: it is not allowed
    # (see the associated .cpp comment).
    assert o.str_value() == "asdf"

    assert o.A_value().value == "hi"
    a = o.A_ref()
    assert a.value == "hi"
    # The object returned by A_ref() must accept attribute writes.
    a.value = "bye"
    assert a.value == "bye"
309 
310 
def test_inherited_virtuals():
    """Virtual dispatch through multi-level hierarchies, with and without Python overrides.

    Each ``*_Repeat`` class is paired with a ``*_Tpl`` class and the paired
    assertions expect identical results from both. ``say_everything()`` is
    expected to render ``say_something(1) + " " + str(unlucky_number())``, as
    every case below demonstrates.
    """

    class AR(m.A_Repeat):
        def unlucky_number(self):
            return 99

    class AT(m.A_Tpl):
        def unlucky_number(self):
            return 999

    obj = AR()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 99
    assert obj.say_everything() == "hi 99"

    obj = AT()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 999
    assert obj.say_everything() == "hi 999"

    for obj in [m.B_Repeat(), m.B_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 13
        assert obj.lucky_number() == 7.0
        assert obj.say_everything() == "B says hi 1 times 13"

    for obj in [m.C_Repeat(), m.C_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    class CR(m.C_Repeat):
        def lucky_number(self):
            return m.C_Repeat.lucky_number(self) + 1.25

    obj = CR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 889.25
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CT(m.C_Tpl):
        pass

    obj = CT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CCR(CR):
        def lucky_number(self):
            return CR.lucky_number(self) * 10

    obj = CCR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 8892.5
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CCT(CT):
        def lucky_number(self):
            return CT.lucky_number(self) * 1000

    obj = CCT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888000.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    class DR(m.D_Repeat):
        def unlucky_number(self):
            return 123

        def lucky_number(self):
            return 42.0

    for obj in [m.D_Repeat(), m.D_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    obj = DR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 123
    assert obj.lucky_number() == 42.0
    assert obj.say_everything() == "B says hi 1 times 123"

    class DT(m.D_Tpl):
        def say_something(self, times):
            return "DT says:" + (" quack" * times)

        def unlucky_number(self):
            return 1234

        def lucky_number(self):
            return -4.25

    obj = DT()
    assert obj.say_something(3) == "DT says: quack quack quack"
    assert obj.unlucky_number() == 1234
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT says: quack 1234"

    class DT2(DT):
        def say_something(self, times):
            return "DT2: " + ("QUACK" * times)

        def unlucky_number(self):
            return -3

    # DT2 was previously defined but never used. Exercise it: its overrides
    # must win over DT's, and lucky_number() is inherited from DT.
    obj = DT2()
    assert obj.say_something(3) == "DT2: QUACKQUACKQUACK"
    assert obj.unlucky_number() == -3
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT2: QUACK -3"

    class BT(m.B_Tpl):
        def say_something(self, times):
            return "BT" * times

        def unlucky_number(self):
            return -7

        def lucky_number(self):
            return -1.375

    obj = BT()
    assert obj.say_something(3) == "BTBTBT"
    assert obj.unlucky_number() == -7
    assert obj.lucky_number() == -1.375
    assert obj.say_everything() == "BT -7"
438 
439 
@pytest.mark.skipif(sys.platform.startswith("emscripten"), reason="Requires threads")
def test_issue_1454():
    """Regression test for issue #1454.

    The issue was a crash when acquiring/releasing the GIL on another thread
    (originally reported under Python 2.7).
    """
    m.test_gil()
    m.test_gil_from_thread()
445 
446 
def test_python_override():
    """Override lookup stays correct across many freshly created Python subclasses.

    Each iteration builds two brand-new subclasses of
    ``m.test_override_cache_helper``: one overriding ``func`` (expected result
    42) and one that does not (expected result 0). The large iteration count
    presumably aims to create and free enough type objects to expose a stale
    per-type override cache; confirm against the .cpp side.
    """

    def make_overriding():
        class Test(m.test_override_cache_helper):
            def func(self):
                return 42

        return Test()

    def make_non_overriding():
        class Test(m.test_override_cache_helper):
            pass

        return Test()

    for _ in range(1500):
        assert m.test_override_cache(make_overriding()) == 42
        assert m.test_override_cache(make_non_overriding()) == 0
Eigen::internal::print
EIGEN_STRONG_INLINE Packet4f print(const Packet4f &a)
Definition: NEON/PacketMath.h:3115
Adder
Definition: test_virtual_functions.cpp:183
test_virtual_functions.test_dispatch_issue
def test_dispatch_issue(msg)
Definition: test_virtual_functions.py:240
test_virtual_functions.test_alias_delay_initialization2
def test_alias_delay_initialization2(capture)
Definition: test_virtual_functions.py:133
test_virtual_functions.test_python_override
def test_python_override()
Definition: test_virtual_functions.py:447
B
Definition: test_numpy_dtypes.cpp:301
hasattr
bool hasattr(handle obj, handle name)
Definition: pytypes.h:870
test_virtual_functions.test_inherited_virtuals
def test_inherited_virtuals()
Definition: test_virtual_functions.py:311
test_trampoline.func2
def func2()
Definition: test_trampoline.py:14
test_virtual_functions.test_override
def test_override(capture, msg)
Definition: test_virtual_functions.py:13
gtsam::range
Double_ range(const Point2_ &p, const Point2_ &q)
Definition: slam/expressions.h:30
test_virtual_functions.test_move_support
def test_move_support()
Definition: test_virtual_functions.py:195
test_virtual_functions.test_alias_delay_initialization1
def test_alias_delay_initialization1(capture)
Definition: test_virtual_functions.py:94
DT
Definition: testDecisionTree.cpp:96
test_virtual_functions.test_issue_1454
def test_issue_1454()
Definition: test_virtual_functions.py:441
func
int func(const int &a)
Definition: testDSF.cpp:221
gtsam.examples.DogLegOptimizerExample.run
def run(args)
Definition: DogLegOptimizerExample.py:21
gtwrap.interface_parser.function.__init__
def __init__(self, Union[Type, TemplatedType] ctype, str name, ParseResults default=None)
Definition: interface_parser/function.py:41
test_virtual_functions.test_override_ref
def test_override_ref()
Definition: test_virtual_functions.py:295
tree::f
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Definition: testExpression.cpp:218
ConstructorStats::get
static ConstructorStats & get(std::type_index type)
Definition: constructor_stats.h:163
Test
Definition: Test.h:30
func
Definition: benchGeometry.cpp:23
test_virtual_functions.test_recursive_dispatch_issue
def test_recursive_dispatch_issue()
Definition: test_virtual_functions.py:262
pybind11.msg
msg
Definition: wrap/pybind11/pybind11/__init__.py:6


gtsam
Author(s):
autogenerated on Wed Mar 19 2025 03:06:21