test_virtual_functions.py
Go to the documentation of this file.
1 import pytest
2 
3 import env # noqa: F401
4 
5 m = pytest.importorskip("pybind11_tests.virtual_functions")
6 from pybind11_tests import ConstructorStats # noqa: E402
7 
8 
def test_override(capture, msg):
    """Python subclasses of ``m.ExampleVirt`` override its virtual methods.

    Exercises, through the ``m.runExampleVirt*`` helpers:
    * a plain C++ instance (original implementation; pure virtual raises),
    * a one-level Python subclass whose ``run`` chains to ``super().run``,
    * a two-level Python subclass that additionally overrides ``get_string2``,
    then checks the ConstructorStats bookkeeping for all three instances.
    """

    class ExtendedExampleVirt(m.ExampleVirt):
        def __init__(self, state):
            super().__init__(state + 1)
            self.data = "Hello world"

        def run(self, value):
            print(f"ExtendedExampleVirt::run({value}), calling parent..")
            return super().run(value + 1)

        def run_bool(self):
            print("ExtendedExampleVirt::run_bool()")
            return False

        def get_string1(self):
            return "override1"

        def pure_virtual(self):
            print(f"ExtendedExampleVirt::pure_virtual(): {self.data}")

    class ExtendedExampleVirt2(ExtendedExampleVirt):
        def __init__(self, state):
            super().__init__(state + 1)

        def get_string2(self):
            return "override2"

    # --- Plain C++ object: no Python overrides in play.
    base = m.ExampleVirt(10)
    with capture:
        assert m.runExampleVirt(base, 20) == 30
    expected_base = """
        Original implementation of ExampleVirt::run(state=10, value=20, str1=default1, str2=default2)
    """
    assert capture == expected_base

    # The pure virtual has no implementation on the C++ object, so it must raise.
    with pytest.raises(RuntimeError) as exc_info:
        m.runExampleVirtVirtual(base)
    assert (
        msg(exc_info.value)
        == 'Tried to call pure virtual function "ExampleVirt::pure_virtual"'
    )

    # --- One-level Python subclass (constructor bumps state 10 -> 11).
    extended = ExtendedExampleVirt(10)
    with capture:
        assert m.runExampleVirt(extended, 20) == 32
    expected_extended = """
        ExtendedExampleVirt::run(20), calling parent..
        Original implementation of ExampleVirt::run(state=11, value=21, str1=override1, str2=default2)
    """
    assert capture == expected_extended

    with capture:
        assert m.runExampleVirtBool(extended) is False
    assert capture == "ExtendedExampleVirt::run_bool()"

    with capture:
        m.runExampleVirtVirtual(extended)
    assert capture == "ExtendedExampleVirt::pure_virtual(): Hello world"

    # --- Two-level Python subclass (state bumped twice: 15 -> 16 -> 17).
    extended2 = ExtendedExampleVirt2(15)
    with capture:
        assert m.runExampleVirt(extended2, 50) == 68
    expected_extended2 = """
        ExtendedExampleVirt::run(50), calling parent..
        Original implementation of ExampleVirt::run(state=17, value=51, str1=override1, str2=override2)
    """
    assert capture == expected_extended2

    # Lifetime bookkeeping: three live instances, all released after del.
    stats = ConstructorStats.get(m.ExampleVirt)
    assert stats.alive() == 3
    del base, extended, extended2
    assert stats.alive() == 0
    assert stats.values() == ["10", "11", "17"]
    assert stats.copy_constructions == 0
    assert stats.move_constructions >= 0
88 
89 
def test_alias_delay_initialization1(capture):
    """`A` only initializes its trampoline class when we inherit from it

    If we just create and use an A instance directly, the trampoline initialization is
    bypassed and we only initialize an A() instead (for performance reasons).
    """

    class B(m.A):
        def __init__(self):
            super().__init__()

        def f(self):
            print("In python f()")

    # C++ version: only the C++ A.f() output is expected, with no PyA
    # constructor/destructor lines.
    with capture:
        a = m.A()
        m.call_f(a)
        # Destroy inside the capture block so any destructor output would be seen.
        del a
        pytest.gc_collect()
    assert capture == "A.f()"

    # Python version: subclassing forces the trampoline (PyA) to be constructed,
    # and its destructor runs when the object is collected.
    with capture:
        b = B()
        m.call_f(b)
        del b
        pytest.gc_collect()
    assert (
        capture
        == """
        PyA.PyA()
        PyA.f()
        In python f()
        PyA.~PyA()
    """
    )
127 
128 
def test_alias_delay_initialization2(capture):
    """`A2`, unlike the above, is configured to always initialize the alias

    While the extra initialization and extra class layer has small virtual dispatch
    performance penalty, it also allows us to do more things with the trampoline
    class such as defining local variables and performing construction/destruction.
    """

    class B2(m.A2):
        def __init__(self):
            super().__init__()

        def f(self):
            print("In python B2.f()")

    # No python subclass version: the PyA2 trampoline is still constructed and
    # destroyed, for both the default and the int constructor.
    with capture:
        a2 = m.A2()
        m.call_f(a2)
        del a2
        pytest.gc_collect()
        a3 = m.A2(1)
        m.call_f(a3)
        del a3
        pytest.gc_collect()
    assert (
        capture
        == """
        PyA2.PyA2()
        PyA2.f()
        A2.f()
        PyA2.~PyA2()
        PyA2.PyA2()
        PyA2.f()
        A2.f()
        PyA2.~PyA2()
    """
    )

    # Python subclass version: the trampoline dispatches f() to B2.f.
    with capture:
        b2 = B2()
        m.call_f(b2)
        del b2
        pytest.gc_collect()
    assert (
        capture
        == """
        PyA2.PyA2()
        PyA2.f()
        In python B2.f()
        PyA2.~PyA2()
    """
    )
183 
184 
# PyPy: Reference count > 1 causes call with noncopyable instance
# to fail in ncv1.print_nc()
@pytest.mark.xfail("env.PYPY")
@pytest.mark.skipif(
    not hasattr(m, "NCVirt"), reason="NCVirt does not work on Intel/PGI/NVCC compilers"
)
def test_move_support():
    """Values returned from Python overrides are moved (or copied) into C++.

    ``NCVirtExt`` returns a fresh ``NonCopyable`` and a ``Movable`` it also keeps
    a reference to. ``NCVirtExt2`` keeps a reference to its ``NonCopyable``, which
    makes ``print_nc`` raise. ConstructorStats then pins the resulting lifetimes
    and copy counts.
    """

    class NCVirtExt(m.NCVirt):
        def get_noncopyable(self, a, b):
            # Constructs and returns a new instance:
            return m.NonCopyable(a * a, b * b)

        def get_movable(self, a, b):
            # Return a referenced copy
            self.movable = m.Movable(a, b)
            return self.movable

    class NCVirtExt2(m.NCVirt):
        def get_noncopyable(self, a, b):
            # Keep a reference: this is going to throw an exception
            self.nc = m.NonCopyable(a, b)
            return self.nc

        def get_movable(self, a, b):
            # Return a new instance without storing it
            return m.Movable(a, b)

    ncv1 = NCVirtExt()
    assert ncv1.print_nc(2, 3) == "36"
    assert ncv1.print_movable(4, 5) == "9"
    ncv2 = NCVirtExt2()
    assert ncv2.print_movable(7, 7) == "14"
    # Don't check the exception message here because it differs under debug/non-debug mode
    with pytest.raises(RuntimeError):
        ncv2.print_nc(9, 9)

    # Only the objects still referenced by ncv1.movable / ncv2.nc are alive.
    nc_stats = ConstructorStats.get(m.NonCopyable)
    mv_stats = ConstructorStats.get(m.Movable)
    assert nc_stats.alive() == 1
    assert mv_stats.alive() == 1
    del ncv1, ncv2
    assert nc_stats.alive() == 0
    assert mv_stats.alive() == 0
    assert nc_stats.values() == ["4", "9", "9", "9"]
    assert mv_stats.values() == ["4", "5", "7", "7"]
    assert nc_stats.copy_constructions == 0
    # The Movable held by ncv1.movable cannot be stolen, so exactly one copy is made.
    assert mv_stats.copy_constructions == 1
    assert nc_stats.move_constructions >= 0
    assert mv_stats.move_constructions >= 0
234 
235 
def test_dispatch_issue(msg):
    """#159: virtual function dispatch has problems with similar-named functions"""

    class PyClass1(m.DispatchIssue):
        def dispatch(self):
            return "Yay.."

    class PyClass2(m.DispatchIssue):
        def dispatch(self):
            # The C++ base dispatch is pure virtual: calling it through super()
            # must raise the pure-virtual error.
            with pytest.raises(RuntimeError) as excinfo:
                super().dispatch()
            assert (
                msg(excinfo.value)
                == 'Tried to call pure virtual function "Base::dispatch"'
            )

            # Nested dispatch on a different Python subclass from inside an override.
            return m.dispatch_issue_go(PyClass1())

    b = PyClass2()
    assert m.dispatch_issue_go(b) == "Yay.."
256 
257 
def test_recursive_dispatch_issue():
    """#3357: Recursive dispatch fails to find python function override"""

    class Data(m.Data):
        def __init__(self, value):
            super().__init__()
            self.value = value

    class Adder(m.Adder):
        def __call__(self, first, second, visitor):
            # lambda is a workaround, which adds extra frame to the
            # current CPython thread. Removing lambda reveals the bug
            # [https://github.com/pybind/pybind11/issues/3357]
            (lambda: visitor(Data(first.value + second.value)))()  # noqa: PLC3002

    class StoreResultVisitor:
        """Callable visitor that records the ``value`` of the last Data it sees."""

        def __init__(self):
            self.result = None

        def __call__(self, data):
            self.result = data.value

    store = StoreResultVisitor()

    m.add2(Data(1), Data(2), Adder(), store)
    assert store.result == 3

    # without lambda in Adder class, this function fails with
    # RuntimeError: Tried to call pure virtual function "AdderBase::__call__"
    m.add3(Data(1), Data(2), Data(3), Adder(), store)
    assert store.result == 6
289 
290 
def test_override_ref():
    """#392/397: overriding reference-returning functions"""
    o = m.OverrideTest("asdf")

    # Not allowed (see associated .cpp comment)
    # i = o.str_ref()
    # assert o.str_ref() == "asdf"
    assert o.str_value() == "asdf"

    assert o.A_value().value == "hi"
    # The reference-returning accessor yields an object that can be mutated.
    a = o.A_ref()
    assert a.value == "hi"
    a.value = "bye"
    assert a.value == "bye"
305 
306 
def test_inherited_virtuals():
    """Python overrides of virtuals across multi-level C++ class hierarchies.

    Each hierarchy (A -> B -> C -> D) is bound twice, as ``*_Repeat`` and
    ``*_Tpl``. The asserts below expect identical behavior from both variants.
    Presumably they differ only in how the C++ trampolines are written; see the
    .cpp. Python subclasses are checked at several depths, including
    subclasses of Python subclasses.
    """

    class AR(m.A_Repeat):
        def unlucky_number(self):
            return 99

    class AT(m.A_Tpl):
        def unlucky_number(self):
            return 999

    obj = AR()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 99
    assert obj.say_everything() == "hi 99"

    obj = AT()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 999
    assert obj.say_everything() == "hi 999"

    for obj in [m.B_Repeat(), m.B_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 13
        assert obj.lucky_number() == 7.0
        assert obj.say_everything() == "B says hi 1 times 13"

    for obj in [m.C_Repeat(), m.C_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    class CR(m.C_Repeat):
        def lucky_number(self):
            return m.C_Repeat.lucky_number(self) + 1.25

    obj = CR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 889.25
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CT(m.C_Tpl):
        pass

    obj = CT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CCR(CR):
        def lucky_number(self):
            return CR.lucky_number(self) * 10

    obj = CCR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 8892.5
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CCT(CT):
        def lucky_number(self):
            return CT.lucky_number(self) * 1000

    obj = CCT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888000.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    class DR(m.D_Repeat):
        def unlucky_number(self):
            return 123

        def lucky_number(self):
            return 42.0

    for obj in [m.D_Repeat(), m.D_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    obj = DR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 123
    assert obj.lucky_number() == 42.0
    assert obj.say_everything() == "B says hi 1 times 123"

    class DT(m.D_Tpl):
        def say_something(self, times):
            return "DT says:" + (" quack" * times)

        def unlucky_number(self):
            return 1234

        def lucky_number(self):
            return -4.25

    obj = DT()
    assert obj.say_something(3) == "DT says: quack quack quack"
    assert obj.unlucky_number() == 1234
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT says: quack 1234"

    class DT2(DT):
        def say_something(self, times):
            return "DT2: " + ("QUACK" * times)

        def unlucky_number(self):
            return -3

    # DT2 was previously defined but never exercised. It overrides a Python
    # subclass of D_Tpl, and inherits lucky_number() from DT. say_everything()
    # is "<say_something(1)> <unlucky_number()>", as the asserts above show.
    obj = DT2()
    assert obj.say_something(3) == "DT2: QUACKQUACKQUACK"
    assert obj.unlucky_number() == -3
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT2: QUACK -3"

    class BT(m.B_Tpl):
        def say_something(self, times):
            return "BT" * times

        def unlucky_number(self):
            return -7

        def lucky_number(self):
            return -1.375

    obj = BT()
    assert obj.say_something(3) == "BTBTBT"
    assert obj.unlucky_number() == -7
    assert obj.lucky_number() == -1.375
    assert obj.say_everything() == "BT -7"
434 
435 
def test_issue_1454():
    """Regression test for #1454: acquiring/releasing the GIL on another thread.

    Both checks are implemented on the C++ side. The test passes as long as
    neither call crashes or raises.
    """
    # Fix issue #1454 (crash when acquiring/releasing GIL on another thread in Python 2.7)
    m.test_gil()
    m.test_gil_from_thread()
440 
441 
def test_python_override():
    """Override lookup for many short-lived Python subclasses.

    Each call to ``func()`` / ``func2()`` defines a brand-new subclass of
    ``m.test_override_cache_helper`` and returns an instance of it. A subclass
    that overrides ``func`` must yield 42; one that does not must yield 0.
    """

    def func():
        class Test(m.test_override_cache_helper):
            def func(self):
                return 42

        return Test()

    def func2():
        class Test(m.test_override_cache_helper):
            pass

        return Test()

    # NOTE(review): the large iteration count presumably lets freed subclass
    # types have their memory reused. That would expose a stale override-lookup
    # cache keyed on type identity. Confirm against the .cpp.
    for _ in range(1500):
        assert m.test_override_cache(func()) == 42
        assert m.test_override_cache(func2()) == 0
Eigen::internal::print
EIGEN_STRONG_INLINE Packet4f print(const Packet4f &a)
Definition: NEON/PacketMath.h:3115
Adder
Definition: test_virtual_functions.cpp:183
test_virtual_functions.test_dispatch_issue
def test_dispatch_issue(msg)
Definition: test_virtual_functions.py:236
test_virtual_functions.test_alias_delay_initialization2
def test_alias_delay_initialization2(capture)
Definition: test_virtual_functions.py:129
test_virtual_functions.test_python_override
def test_python_override()
Definition: test_virtual_functions.py:442
B
Definition: test_numpy_dtypes.cpp:299
hasattr
bool hasattr(handle obj, handle name)
Definition: pytypes.h:853
test_virtual_functions.test_inherited_virtuals
def test_inherited_virtuals()
Definition: test_virtual_functions.py:307
test_trampoline.func2
def func2()
Definition: test_trampoline.py:12
test_virtual_functions.test_override
def test_override(capture, msg)
Definition: test_virtual_functions.py:9
gtsam::range
Double_ range(const Point2_ &p, const Point2_ &q)
Definition: slam/expressions.h:30
test_virtual_functions.test_move_support
def test_move_support()
Definition: test_virtual_functions.py:191
test_virtual_functions.test_alias_delay_initialization1
def test_alias_delay_initialization1(capture)
Definition: test_virtual_functions.py:90
DT
Definition: testDecisionTree.cpp:95
test_virtual_functions.test_issue_1454
def test_issue_1454()
Definition: test_virtual_functions.py:436
func
int func(const int &a)
Definition: testDSF.cpp:221
gtsam.examples.DogLegOptimizerExample.run
def run(args)
Definition: DogLegOptimizerExample.py:21
gtwrap.interface_parser.function.__init__
def __init__(self, Union[Type, TemplatedType] ctype, str name, ParseResults default=None)
Definition: interface_parser/function.py:41
test_virtual_functions.test_override_ref
def test_override_ref()
Definition: test_virtual_functions.py:291
tree::f
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Definition: testExpression.cpp:218
ConstructorStats::get
static ConstructorStats & get(std::type_index type)
Definition: constructor_stats.h:163
Test
Definition: Test.h:30
func
Definition: benchGeometry.cpp:23
test_virtual_functions.test_recursive_dispatch_issue
def test_recursive_dispatch_issue()
Definition: test_virtual_functions.py:258
pybind11.msg
msg
Definition: wrap/pybind11/pybind11/__init__.py:4


gtsam
Author(s):
autogenerated on Thu Jun 13 2024 03:08:55