test_virtual_functions.py
Go to the documentation of this file.
1 import pytest
2 
3 import env # noqa: F401
4 
5 m = pytest.importorskip("pybind11_tests.virtual_functions")
6 from pybind11_tests import ConstructorStats # noqa: E402
7 
8 
def test_override(capture, msg):
    """Check that Python subclasses can override ``ExampleVirt`` virtuals.

    Covers a plain C++ ``ExampleVirt``, a one-level Python subclass (which also
    supplies the pure virtual) and a two-level Python subclass, then uses
    ``ConstructorStats`` to confirm three C++ objects were created and released.
    """

    class ExtendedExampleVirt(m.ExampleVirt):
        def __init__(self, state):
            # The C++ base receives a state one higher than requested.
            super().__init__(state + 1)
            self.data = "Hello world"

        def run(self, value):
            print(f"ExtendedExampleVirt::run({value}), calling parent..")
            return super().run(value + 1)

        def run_bool(self):
            print("ExtendedExampleVirt::run_bool()")
            return False

        def get_string1(self):
            return "override1"

        def pure_virtual(self):
            print(f"ExtendedExampleVirt::pure_virtual(): {self.data}")

    class ExtendedExampleVirt2(ExtendedExampleVirt):
        def __init__(self, state):
            # One more +1 on top of the parent's +1.
            super().__init__(state + 1)

        def get_string2(self):
            return "override2"

    # --- Plain C++ object: only the base implementation is available -------
    plain = m.ExampleVirt(10)
    with capture:
        assert m.runExampleVirt(plain, 20) == 30
    assert (
        capture
        == """
        Original implementation of ExampleVirt::run(state=10, value=20, str1=default1, str2=default2)
    """
    )

    # Nothing implements the pure virtual, so dispatching to it must fail.
    with pytest.raises(RuntimeError) as excinfo:
        m.runExampleVirtVirtual(plain)
    pure_virtual_error = 'Tried to call pure virtual function "ExampleVirt::pure_virtual"'
    assert msg(excinfo.value) == pure_virtual_error

    # --- One level of Python overrides -------------------------------------
    extended = ExtendedExampleVirt(10)
    with capture:
        assert m.runExampleVirt(extended, 20) == 32
    assert (
        capture
        == """
        ExtendedExampleVirt::run(20), calling parent..
        Original implementation of ExampleVirt::run(state=11, value=21, str1=override1, str2=default2)
    """
    )
    with capture:
        assert m.runExampleVirtBool(extended) is False
    assert capture == "ExtendedExampleVirt::run_bool()"
    with capture:
        m.runExampleVirtVirtual(extended)
    assert capture == "ExtendedExampleVirt::pure_virtual(): Hello world"

    # --- Two levels of Python overrides ------------------------------------
    extended2 = ExtendedExampleVirt2(15)
    with capture:
        assert m.runExampleVirt(extended2, 50) == 68
    assert (
        capture
        == """
        ExtendedExampleVirt::run(50), calling parent..
        Original implementation of ExampleVirt::run(state=17, value=51, str1=override1, str2=override2)
    """
    )

    # All three C++ instances must go away once the Python references do; the
    # recorded values match the states each instance was constructed with.
    stats = ConstructorStats.get(m.ExampleVirt)
    assert stats.alive() == 3
    del plain, extended, extended2
    assert stats.alive() == 0
    assert stats.values() == ["10", "11", "17"]
    assert stats.copy_constructions == 0
    assert stats.move_constructions >= 0
87  assert cstats.move_constructions >= 0
88 
89 
91  """`A` only initializes its trampoline class when we inherit from it
92 
93  If we just create and use an A instance directly, the trampoline initialization is
94  bypassed and we only initialize an A() instead (for performance reasons).
95  """
96 
97  class B(m.A):
98  def __init__(self):
99  super().__init__()
100 
101  def f(self):
102  print("In python f()")
103 
104  # C++ version
105  with capture:
106  a = m.A()
107  m.call_f(a)
108  del a
109  pytest.gc_collect()
110  assert capture == "A.f()"
111 
112  # Python version
113  with capture:
114  b = B()
115  m.call_f(b)
116  del b
117  pytest.gc_collect()
118  assert (
119  capture
120  == """
121  PyA.PyA()
122  PyA.f()
123  In python f()
124  PyA.~PyA()
125  """
126  )
127 
128 
def test_alias_delay_initialization2(capture):
    """`A2`, unlike `A`, is configured to always initialize the alias.

    While the extra initialization and extra class layer have a small virtual
    dispatch performance penalty, they also allow us to do more things with the
    trampoline class, such as defining local variables and performing
    construction/destruction.
    """

    class B2(m.A2):
        def __init__(self):
            super().__init__()

        def f(self):
            print("In python B2.f()")

    # No Python subclass: the PyA2 trampoline is still created, both for the
    # no-argument constructor and for the one-argument constructor.
    with capture:
        a2 = m.A2()
        m.call_f(a2)
        del a2
        pytest.gc_collect()
        a3 = m.A2(1)
        m.call_f(a3)
        del a3
        pytest.gc_collect()
    assert (
        capture
        == """
        PyA2.PyA2()
        PyA2.f()
        A2.f()
        PyA2.~PyA2()
        PyA2.PyA2()
        PyA2.f()
        A2.f()
        PyA2.~PyA2()
    """
    )

    # Python subclass: dispatch reaches the Python override through PyA2.
    with capture:
        b2 = B2()
        m.call_f(b2)
        del b2
        pytest.gc_collect()
    assert (
        capture
        == """
        PyA2.PyA2()
        PyA2.f()
        In python B2.f()
        PyA2.~PyA2()
    """
    )
183 
184 
# PyPy: Reference count > 1 causes call with noncopyable instance
# to fail in ncv1.print_nc()
@pytest.mark.xfail("env.PYPY")
@pytest.mark.skipif(
    not hasattr(m, "NCVirt"), reason="NCVirt does not work on Intel/PGI/NVCC compilers"
)
def test_move_support():
    """Python overrides that return non-copyable / movable objects to C++.

    Checks which return paths succeed, that all instances are released, and
    how many copy constructions the round trips cause.
    """

    class NCVirtExt(m.NCVirt):
        def get_noncopyable(self, a, b):
            # Constructs and returns a new instance:
            nc = m.NonCopyable(a * a, b * b)
            return nc

        def get_movable(self, a, b):
            # Return an instance that Python still references (kept on self)
            self.movable = m.Movable(a, b)
            return self.movable

    class NCVirtExt2(m.NCVirt):
        def get_noncopyable(self, a, b):
            # Keep a reference: this is going to throw an exception
            self.nc = m.NonCopyable(a, b)
            return self.nc

        def get_movable(self, a, b):
            # Return a new instance without storing it
            return m.Movable(a, b)

    ncv1 = NCVirtExt()
    assert ncv1.print_nc(2, 3) == "36"
    assert ncv1.print_movable(4, 5) == "9"
    ncv2 = NCVirtExt2()
    assert ncv2.print_movable(7, 7) == "14"
    # Don't check the exception message here because it differs under debug/non-debug mode
    with pytest.raises(RuntimeError):
        ncv2.print_nc(9, 9)

    nc_stats = ConstructorStats.get(m.NonCopyable)
    mv_stats = ConstructorStats.get(m.Movable)
    # Presumably the instances stored on ncv2.nc and ncv1.movable are the
    # survivors here; dropping the owners must release them.
    assert nc_stats.alive() == 1
    assert mv_stats.alive() == 1
    del ncv1, ncv2
    assert nc_stats.alive() == 0
    assert mv_stats.alive() == 0
    assert nc_stats.values() == ["4", "9", "9", "9"]
    assert mv_stats.values() == ["4", "5", "7", "7"]
    assert nc_stats.copy_constructions == 0
    assert mv_stats.copy_constructions == 1
    assert nc_stats.move_constructions >= 0
    assert mv_stats.move_constructions >= 0
235 
236 
def test_dispatch_issue(msg):
    """#159: virtual function dispatch has problems with similar-named functions"""

    class PyClass1(m.DispatchIssue):
        def dispatch(self):
            return "Yay.."

    class PyClass2(m.DispatchIssue):
        def dispatch(self):
            # The base dispatch() is pure virtual, so calling up must raise
            # rather than recurse back into this override.
            with pytest.raises(RuntimeError) as excinfo:
                super().dispatch()
            assert (
                msg(excinfo.value)
                == 'Tried to call pure virtual function "Base::dispatch"'
            )

            # Nested dispatch on a different Python subclass must find its own override.
            return m.dispatch_issue_go(PyClass1())

    b = PyClass2()
    assert m.dispatch_issue_go(b) == "Yay.."
257 
258 
def test_recursive_dispatch_issue():
    """#3357: Recursive dispatch fails to find python function override"""

    class Data(m.Data):
        def __init__(self, value):
            super().__init__()
            self.value = value

    class Adder(m.Adder):
        def __call__(self, first, second, visitor):
            # lambda is a workaround, which adds extra frame to the
            # current CPython thread. Removing lambda reveals the bug
            # [https://github.com/pybind/pybind11/issues/3357]
            (lambda: visitor(Data(first.value + second.value)))()

    class StoreResultVisitor:
        """Visitor that remembers the ``value`` of the last Data it receives."""

        def __init__(self):
            self.result = None

        def __call__(self, data):
            self.result = data.value

    store = StoreResultVisitor()

    m.add2(Data(1), Data(2), Adder(), store)
    assert store.result == 3

    # without lambda in Adder class, this function fails with
    # RuntimeError: Tried to call pure virtual function "AdderBase::__call__"
    m.add3(Data(1), Data(2), Data(3), Adder(), store)
    assert store.result == 6
290 
291 
def test_override_ref():
    """#392/397: overriding reference-returning functions"""
    o = m.OverrideTest("asdf")

    # o.str_ref() is intentionally not exercised: returning a reference to a
    # string from an override is not allowed (see associated .cpp comment).
    assert o.str_value() == "asdf"

    assert o.A_value().value == "hi"
    a = o.A_ref()
    assert a.value == "hi"
    a.value = "bye"
    assert a.value == "bye"
306 
307 
def test_inherited_virtuals():
    """Python overrides of virtuals exposed through the ``*_Repeat`` and
    ``*_Tpl`` bindings, including Python subclasses of Python subclasses.

    Judging by the expected strings below, ``say_everything()`` combines
    ``say_something(1)`` and ``unlucky_number()`` separated by a space, so it
    shows whether virtual dispatch from C++ reaches the Python overrides.
    """

    class AR(m.A_Repeat):
        def unlucky_number(self):
            return 99

    class AT(m.A_Tpl):
        def unlucky_number(self):
            return 999

    obj = AR()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 99
    assert obj.say_everything() == "hi 99"

    obj = AT()
    assert obj.say_something(3) == "hihihi"
    assert obj.unlucky_number() == 999
    assert obj.say_everything() == "hi 999"

    for obj in [m.B_Repeat(), m.B_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 13
        assert obj.lucky_number() == 7.0
        assert obj.say_everything() == "B says hi 1 times 13"

    for obj in [m.C_Repeat(), m.C_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    class CR(m.C_Repeat):
        def lucky_number(self):
            return m.C_Repeat.lucky_number(self) + 1.25

    obj = CR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 889.25
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CT(m.C_Tpl):
        pass

    obj = CT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CCR(CR):
        def lucky_number(self):
            return CR.lucky_number(self) * 10

    obj = CCR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 8892.5
    assert obj.say_everything() == "B says hi 1 times 4444"

    class CCT(CT):
        def lucky_number(self):
            return CT.lucky_number(self) * 1000

    obj = CCT()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 4444
    assert obj.lucky_number() == 888000.0
    assert obj.say_everything() == "B says hi 1 times 4444"

    class DR(m.D_Repeat):
        def unlucky_number(self):
            return 123

        def lucky_number(self):
            return 42.0

    for obj in [m.D_Repeat(), m.D_Tpl()]:
        assert obj.say_something(3) == "B says hi 3 times"
        assert obj.unlucky_number() == 4444
        assert obj.lucky_number() == 888.0
        assert obj.say_everything() == "B says hi 1 times 4444"

    obj = DR()
    assert obj.say_something(3) == "B says hi 3 times"
    assert obj.unlucky_number() == 123
    assert obj.lucky_number() == 42.0
    assert obj.say_everything() == "B says hi 1 times 123"

    class DT(m.D_Tpl):
        def say_something(self, times):
            return "DT says:" + (" quack" * times)

        def unlucky_number(self):
            return 1234

        def lucky_number(self):
            return -4.25

    obj = DT()
    assert obj.say_something(3) == "DT says: quack quack quack"
    assert obj.unlucky_number() == 1234
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT says: quack 1234"

    class DT2(DT):
        def say_something(self, times):
            return "DT2: " + ("QUACK" * times)

        def unlucky_number(self):
            return -3

    # DT2 overrides a Python override; lucky_number() is inherited from DT.
    # (Previously DT2 was defined but never exercised.)
    obj = DT2()
    assert obj.say_something(3) == "DT2: QUACKQUACKQUACK"
    assert obj.unlucky_number() == -3
    assert obj.lucky_number() == -4.25
    assert obj.say_everything() == "DT2: QUACK -3"

    class BT(m.B_Tpl):
        def say_something(self, times):
            return "BT" * times

        def unlucky_number(self):
            return -7

        def lucky_number(self):
            return -1.375

    obj = BT()
    assert obj.say_something(3) == "BTBTBT"
    assert obj.unlucky_number() == -7
    assert obj.lucky_number() == -1.375
    assert obj.say_everything() == "BT -7"
435 
436 
def test_issue_1454():
    """Regression test for #1454: crash when acquiring/releasing the GIL on
    another thread (originally reported on Python 2.7)."""
    m.test_gil()
    m.test_gil_from_thread()
441 
442 
def test_python_override():
    """Alternate between fresh subclasses that do and do not override ``func``.

    Each call must dispatch according to the instance actually passed: 42 from
    the Python override, 0 when no override exists. The many iterations
    presumably stress reuse of an override lookup cache across short-lived
    subclasses — NOTE(review): confirm against the C++ helper.
    """

    def func():
        # Fresh subclass on every call, overriding func().
        class Test(m.test_override_cache_helper):
            def func(self):
                return 42

        return Test()

    def func2():
        # Fresh subclass on every call, without an override.
        class Test(m.test_override_cache_helper):
            pass

        return Test()

    for _ in range(1500):
        assert m.test_override_cache(func()) == 42
        assert m.test_override_cache(func2()) == 0
def test_alias_delay_initialization1(capture)
bool hasattr(handle obj, handle name)
Definition: pytypes.h:728
Definition: Test.h:30
EIGEN_STRONG_INLINE Packet4f print(const Packet4f &a)
static ConstructorStats & get(std::type_index type)
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
def test_alias_delay_initialization2(capture)
int func(const int &a)
Definition: testDSF.cpp:221
Double_ range(const Point2_ &p, const Point2_ &q)
def test_override(capture, msg)


gtsam
Author(s):
autogenerated on Tue Jul 4 2023 02:37:46