test_interpreter.cpp
Go to the documentation of this file.
1 #include <pybind11/embed.h>
2 
3 // Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to
4 // catch 2.0.1; this should be fixed in the next catch release after 2.0.1).
6 
7 #include <catch.hpp>
8 #include <cstdlib>
9 #include <fstream>
10 #include <functional>
11 #include <thread>
12 #include <utility>
13 
14 namespace py = pybind11;
15 using namespace py::literals;
16 
18  auto sys_path = py::module::import("sys").attr("path");
19  return py::len(sys_path);
20 }
21 
/// Abstract C++ base class that the embedded `widget_module` exposes to Python
/// as `Widget`.
///
/// It holds a message supplied at construction time; concrete subclasses (the
/// `PyWidget` trampoline / Python subclasses) supply `the_answer()` and `argv0()`.
class Widget {
public:
    /// Stores `message` (moved into the member).
    explicit Widget(std::string message) : message_(std::move(message)) {}

    virtual ~Widget() = default;

    /// Returns a copy of the message given to the constructor.
    std::string the_message() const {
        return message_;
    }

    virtual int the_answer() const = 0;
    virtual std::string argv0() const = 0;

private:
    std::string message_;
};
34 
35 class PyWidget final : public Widget {
36  using Widget::Widget;
37 
38  int the_answer() const override { PYBIND11_OVERRIDE_PURE(int, Widget, the_answer); }
39  std::string argv0() const override { PYBIND11_OVERRIDE_PURE(std::string, Widget, argv0); }
40 };
41 
43 
44 public:
45  virtual int func() { return 0; }
46 
47  test_override_cache_helper() = default;
48  virtual ~test_override_cache_helper() = default;
49  // Non-copyable
50  test_override_cache_helper &operator=(test_override_cache_helper const &Right) = delete;
52 };
53 
56 };
57 
58 PYBIND11_EMBEDDED_MODULE(widget_module, m) {
59  py::class_<Widget, PyWidget>(m, "Widget")
60  .def(py::init<std::string>())
61  .def_property_readonly("the_message", &Widget::the_message);
62 
63  m.def("add", [](int i, int j) { return i + j; });
64 }
65 
66 PYBIND11_EMBEDDED_MODULE(trampoline_module, m) {
67  py::class_<test_override_cache_helper,
69  std::shared_ptr<test_override_cache_helper>>(m, "test_override_cache_helper")
70  .def(py::init_alias<>())
71  .def("func", &test_override_cache_helper::func);
72 }
73 
74 PYBIND11_EMBEDDED_MODULE(throw_exception, ) { throw std::runtime_error("C++ Error"); }
75 
76 PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) {
77  auto d = py::dict();
78  d["missing"].cast<py::object>();
79 }
80 
81 TEST_CASE("PYTHONPATH is used to update sys.path") {
82  // The setup for this TEST_CASE is in catch.cpp!
83  auto sys_path = py::str(py::module_::import("sys").attr("path")).cast<std::string>();
84  REQUIRE_THAT(sys_path,
85  Catch::Matchers::Contains("pybind11_test_embed_PYTHONPATH_2099743835476552"));
86 }
87 
88 TEST_CASE("Pass classes and data between modules defined in C++ and Python") {
89  auto module_ = py::module_::import("test_interpreter");
90  REQUIRE(py::hasattr(module_, "DerivedWidget"));
91 
92  auto locals = py::dict("hello"_a = "Hello, World!", "x"_a = 5, **module_.attr("__dict__"));
93  py::exec(R"(
94  widget = DerivedWidget("{} - {}".format(hello, x))
95  message = widget.the_message
96  )",
97  py::globals(),
98  locals);
99  REQUIRE(locals["message"].cast<std::string>() == "Hello, World! - 5");
100 
101  auto py_widget = module_.attr("DerivedWidget")("The question");
102  auto message = py_widget.attr("the_message");
103  REQUIRE(message.cast<std::string>() == "The question");
104 
105  const auto &cpp_widget = py_widget.cast<const Widget &>();
106  REQUIRE(cpp_widget.the_answer() == 42);
107 }
108 
109 TEST_CASE("Override cache") {
110  auto module_ = py::module_::import("test_trampoline");
111  REQUIRE(py::hasattr(module_, "func"));
112  REQUIRE(py::hasattr(module_, "func2"));
113 
114  auto locals = py::dict(**module_.attr("__dict__"));
115 
116  int i = 0;
117  for (; i < 1500; ++i) {
118  std::shared_ptr<test_override_cache_helper> p_obj;
119  std::shared_ptr<test_override_cache_helper> p_obj2;
120 
121  py::object loc_inst = locals["func"]();
122  p_obj = py::cast<std::shared_ptr<test_override_cache_helper>>(loc_inst);
123 
124  int ret = p_obj->func();
125 
126  REQUIRE(ret == 42);
127 
128  loc_inst = locals["func2"]();
129 
130  p_obj2 = py::cast<std::shared_ptr<test_override_cache_helper>>(loc_inst);
131 
132  p_obj2->func();
133  }
134 }
135 
136 TEST_CASE("Import error handling") {
137  REQUIRE_NOTHROW(py::module_::import("widget_module"));
138  REQUIRE_THROWS_WITH(py::module_::import("throw_exception"), "ImportError: C++ Error");
139  REQUIRE_THROWS_WITH(py::module_::import("throw_error_already_set"),
140  Catch::Contains("ImportError: initialization failed"));
141 
142  auto locals = py::dict("is_keyerror"_a = false, "message"_a = "not set");
143  py::exec(R"(
144  try:
145  import throw_error_already_set
146  except ImportError as e:
147  is_keyerror = type(e.__cause__) == KeyError
148  message = str(e.__cause__)
149  )",
150  py::globals(),
151  locals);
152  REQUIRE(locals["is_keyerror"].cast<bool>() == true);
153  REQUIRE(locals["message"].cast<std::string>() == "'missing'");
154 }
155 
156 TEST_CASE("There can be only one interpreter") {
161 
162  REQUIRE_THROWS_WITH(py::initialize_interpreter(), "The interpreter is already running");
163  REQUIRE_THROWS_WITH(py::scoped_interpreter(), "The interpreter is already running");
164 
166  REQUIRE_NOTHROW(py::scoped_interpreter());
167  {
168  auto pyi1 = py::scoped_interpreter();
169  auto pyi2 = std::move(pyi1);
170  }
172 }
173 
174 #if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
175 TEST_CASE("Custom PyConfig") {
177  PyConfig config;
178  PyConfig_InitPythonConfig(&config);
179  REQUIRE_NOTHROW(py::scoped_interpreter{&config});
180  {
181  py::scoped_interpreter p{&config};
182  REQUIRE(py::module_::import("widget_module").attr("add")(1, 41).cast<int>() == 42);
183  }
185 }
186 
187 TEST_CASE("scoped_interpreter with PyConfig_InitIsolatedConfig and argv") {
189  {
190  PyConfig config;
191  PyConfig_InitIsolatedConfig(&config);
192  char *argv[] = {strdup("a.out")};
193  py::scoped_interpreter argv_scope{&config, 1, argv};
194  std::free(argv[0]);
195  auto module = py::module::import("test_interpreter");
196  auto py_widget = module.attr("DerivedWidget")("The question");
197  const auto &cpp_widget = py_widget.cast<const Widget &>();
198  REQUIRE(cpp_widget.argv0() == "a.out");
199  }
201 }
202 
203 TEST_CASE("scoped_interpreter with PyConfig_InitPythonConfig and argv") {
205  {
206  PyConfig config;
207  PyConfig_InitPythonConfig(&config);
208 
209  // `initialize_interpreter() overrides the default value for config.parse_argv (`1`) by
210  // changing it to `0`. This test exercises `scoped_interpreter` with the default config.
211  char *argv[] = {strdup("a.out"), strdup("arg1")};
212  py::scoped_interpreter argv_scope(&config, 2, argv);
213  std::free(argv[0]);
214  std::free(argv[1]);
215  auto module = py::module::import("test_interpreter");
216  auto py_widget = module.attr("DerivedWidget")("The question");
217  const auto &cpp_widget = py_widget.cast<const Widget &>();
218  REQUIRE(cpp_widget.argv0() == "arg1");
219  }
221 }
222 #endif
223 
224 TEST_CASE("Add program dir to path pre-PyConfig") {
226  size_t path_size_add_program_dir_to_path_false = 0;
227  {
228  py::scoped_interpreter scoped_interp{true, 0, nullptr, false};
229  path_size_add_program_dir_to_path_false = get_sys_path_size();
230  }
231  {
232  py::scoped_interpreter scoped_interp{};
233  REQUIRE(get_sys_path_size() == path_size_add_program_dir_to_path_false + 1);
234  }
236 }
237 
238 #if PY_VERSION_HEX >= PYBIND11_PYCONFIG_SUPPORT_PY_VERSION_HEX
239 TEST_CASE("Add program dir to path using PyConfig") {
241  size_t path_size_add_program_dir_to_path_false = 0;
242  {
243  PyConfig config;
244  PyConfig_InitPythonConfig(&config);
245  py::scoped_interpreter scoped_interp{&config, 0, nullptr, false};
246  path_size_add_program_dir_to_path_false = get_sys_path_size();
247  }
248  {
249  PyConfig config;
250  PyConfig_InitPythonConfig(&config);
251  py::scoped_interpreter scoped_interp{&config};
252  REQUIRE(get_sys_path_size() == path_size_add_program_dir_to_path_false + 1);
253  }
255 }
256 #endif
257 
259  return bool(
261 }
262 
264  auto **&ipp = py::detail::get_internals_pp();
265  return (ipp != nullptr) && (*ipp != nullptr);
266 }
267 
268 TEST_CASE("Restart the interpreter") {
269  // Verify pre-restart state.
270  REQUIRE(py::module_::import("widget_module").attr("add")(1, 2).cast<int>() == 3);
271  REQUIRE(has_state_dict_internals_obj());
273  REQUIRE(py::module_::import("external_module").attr("A")(123).attr("value").cast<int>()
274  == 123);
275 
276  // local and foreign module internals should point to the same internals:
277  REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp())
278  == py::module_::import("external_module").attr("internals_at")().cast<uintptr_t>());
279 
280  // Restart the interpreter.
282  REQUIRE(Py_IsInitialized() == 0);
283 
285  REQUIRE(Py_IsInitialized() == 1);
286 
287  // Internals are deleted after a restart.
288  REQUIRE_FALSE(has_state_dict_internals_obj());
289  REQUIRE_FALSE(has_pybind11_internals_static());
291  REQUIRE(has_state_dict_internals_obj());
293  REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp())
294  == py::module_::import("external_module").attr("internals_at")().cast<uintptr_t>());
295 
296  // Make sure that an interpreter with no get_internals() created until finalize still gets the
297  // internals destroyed
300  bool ran = false;
301  py::module_::import("__main__").attr("internals_destroy_test")
302  = py::capsule(&ran, [](void *ran) {
304  *static_cast<bool *>(ran) = true;
305  });
306  REQUIRE_FALSE(has_state_dict_internals_obj());
307  REQUIRE_FALSE(has_pybind11_internals_static());
308  REQUIRE_FALSE(ran);
310  REQUIRE(ran);
312  REQUIRE_FALSE(has_state_dict_internals_obj());
313  REQUIRE_FALSE(has_pybind11_internals_static());
314 
315  // C++ modules can be reloaded.
316  auto cpp_module = py::module_::import("widget_module");
317  REQUIRE(cpp_module.attr("add")(1, 2).cast<int>() == 3);
318 
319  // C++ type information is reloaded and can be used in python modules.
320  auto py_module = py::module_::import("test_interpreter");
321  auto py_widget = py_module.attr("DerivedWidget")("Hello after restart");
322  REQUIRE(py_widget.attr("the_message").cast<std::string>() == "Hello after restart");
323 }
324 
// Checks isolation between the main interpreter and a sub-interpreter created
// with Py_NewInterpreter(): attributes added to __main__ and to widget_module in
// the main interpreter must not be visible inside the sub-interpreter, the C++
// bindings of widget_module must still work there, and the attributes must be
// visible again after switching back to the main thread state.
325 TEST_CASE("Subinterpreter") {
326  // Add tags to the modules in the main interpreter and test the basics.
327  py::module_::import("__main__").attr("main_tag") = "main interpreter";
328  {
329  auto m = py::module_::import("widget_module");
330  m.attr("extension_module_tag") = "added to module in main interpreter";
331 
332  REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
333  }
334  REQUIRE(has_state_dict_internals_obj());
336 
// Remember the main thread state so it can be restored after the sub-interpreter ends.
338  auto *main_tstate = PyThreadState_Get();
339  auto *sub_tstate = Py_NewInterpreter();
340 
341  // Subinterpreters get their own copy of builtins. detail::get_internals() still
342  // works by returning from the static variable, i.e. all interpreters share a single
343  // global pybind11::internals;
344  REQUIRE_FALSE(has_state_dict_internals_obj());
346 
347  // Modules tags should be gone.
// NOTE(review): this checks for "tag", but the attribute added to __main__ above
// is "main_tag", so as written this assertion looks vacuous. Checking "main_tag"
// was presumably intended — confirm before changing.
348  REQUIRE_FALSE(py::hasattr(py::module_::import("__main__"), "tag"));
349  {
350  auto m = py::module_::import("widget_module");
351  REQUIRE_FALSE(py::hasattr(m, "extension_module_tag"));
352 
353  // Function bindings should still work.
354  REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
355  }
356 
357  // Restore main interpreter.
358  Py_EndInterpreter(sub_tstate);
359  PyThreadState_Swap(main_tstate);
360 
// Back in the main interpreter, the tags added at the start must be visible again.
361  REQUIRE(py::hasattr(py::module_::import("__main__"), "main_tag"));
362  REQUIRE(py::hasattr(py::module_::import("widget_module"), "extension_module_tag"));
363 }
364 
365 TEST_CASE("Execution frame") {
366  // When the interpreter is embedded, there is no execution frame, but `py::exec`
367  // should still function by using reasonable globals: `__main__.__dict__`.
368  py::exec("var = dict(number=42)");
369  REQUIRE(py::globals()["var"]["number"].cast<int>() == 42);
370 }
371 
372 TEST_CASE("Threads") {
373  // Restart interpreter to ensure threads are not initialized
376  REQUIRE_FALSE(has_pybind11_internals_static());
377 
378  constexpr auto num_threads = 10;
379  auto locals = py::dict("count"_a = 0);
380 
381  {
382  py::gil_scoped_release gil_release{};
383 
384  auto threads = std::vector<std::thread>();
385  for (auto i = 0; i < num_threads; ++i) {
386  threads.emplace_back([&]() {
387  py::gil_scoped_acquire gil{};
388  locals["count"] = locals["count"].cast<int>() + 1;
389  });
390  }
391 
392  for (auto &thread : threads) {
393  thread.join();
394  }
395  }
396 
397  REQUIRE(locals["count"].cast<int>() == num_threads);
398 }
399 
400 // Scope exit utility https://stackoverflow.com/a/36644501/7255855
401 struct scope_exit {
402  std::function<void()> f_;
403  explicit scope_exit(std::function<void()> f) noexcept : f_(std::move(f)) {}
405  if (f_) {
406  f_();
407  }
408  }
409 };
410 
411 TEST_CASE("Reload module from file") {
412  // Disable generation of cached bytecode (.pyc files) for this test, otherwise
413  // Python might pick up an old version from the cache instead of the new versions
414  // of the .py files generated below
415  auto sys = py::module_::import("sys");
416  bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast<bool>();
417  sys.attr("dont_write_bytecode") = true;
418  // Reset the value at scope exit
419  scope_exit reset_dont_write_bytecode(
420  [&]() { sys.attr("dont_write_bytecode") = dont_write_bytecode; });
421 
422  std::string module_name = "test_module_reload";
423  std::string module_file = module_name + ".py";
424 
425  // Create the module .py file
426  std::ofstream test_module(module_file);
427  test_module << "def test():\n";
428  test_module << " return 1\n";
429  test_module.close();
430  // Delete the file at scope exit
431  scope_exit delete_module_file([&]() { std::remove(module_file.c_str()); });
432 
433  // Import the module from file
434  auto module_ = py::module_::import(module_name.c_str());
435  int result = module_.attr("test")().cast<int>();
436  REQUIRE(result == 1);
437 
438  // Update the module .py file with a small change
439  test_module.open(module_file);
440  test_module << "def test():\n";
441  test_module << " return 2\n";
442  test_module.close();
443 
444  // Reload the module
445  module_.reload();
446  result = module_.attr("test")().cast<int>();
447  REQUIRE(result == 2);
448 }
449 
450 TEST_CASE("sys.argv gets initialized properly") {
452  {
453  py::scoped_interpreter default_scope;
454  auto module = py::module::import("test_interpreter");
455  auto py_widget = module.attr("DerivedWidget")("The question");
456  const auto &cpp_widget = py_widget.cast<const Widget &>();
457  REQUIRE(cpp_widget.argv0().empty());
458  }
459 
460  {
461  char *argv[] = {strdup("a.out")};
462  py::scoped_interpreter argv_scope(true, 1, argv);
463  std::free(argv[0]);
464  auto module = py::module::import("test_interpreter");
465  auto py_widget = module.attr("DerivedWidget")("The question");
466  const auto &cpp_widget = py_widget.cast<const Widget &>();
467  REQUIRE(cpp_widget.argv0() == "a.out");
468  }
470 }
471 
472 TEST_CASE("make_iterator can be called before then after finalizing an interpreter") {
473  // Reproduction of issue #2101 (https://github.com/pybind/pybind11/issues/2101)
475 
476  std::vector<int> container;
477  {
478  pybind11::scoped_interpreter g;
479  auto iter = pybind11::make_iterator(container.begin(), container.end());
480  }
481 
482  REQUIRE_NOTHROW([&]() {
483  pybind11::scoped_interpreter g;
484  auto iter = pybind11::make_iterator(container.begin(), container.end());
485  }());
486 
488 }
has_state_dict_internals_obj
bool has_state_dict_internals_obj()
Definition: test_interpreter.cpp:258
Widget::the_message
std::string the_message() const
Definition: test_interpreter.cpp:27
return_value_policy::move
@ move
TEST_CASE
TEST_CASE("PYTHONPATH is used to update sys.path")
Definition: test_interpreter.cpp:81
scope_exit::f_
std::function< void()> f_
Definition: test_interpreter.cpp:402
PyWidget::argv0
std::string argv0() const override
Definition: test_interpreter.cpp:39
d
static const double d[K][N]
Definition: igam.h:11
finalize_interpreter
void finalize_interpreter()
Definition: embed.h:245
has_pybind11_internals_static
bool has_pybind11_internals_static()
Definition: test_interpreter.cpp:263
ret
DenseIndex ret
Definition: level1_cplx_impl.h:44
hasattr
bool hasattr(handle obj, handle name)
Definition: pytypes.h:853
get_internals
PYBIND11_NOINLINE internals & get_internals()
Return a reference to the current internals data.
Definition: internals.h:467
result
Values result
Definition: OdometryOptimize.cpp:8
PYBIND11_EMBEDDED_MODULE
PYBIND11_EMBEDDED_MODULE(widget_module, m)
Definition: test_interpreter.cpp:58
PYBIND11_OVERRIDE_PURE
#define PYBIND11_OVERRIDE_PURE(ret_type, cname, fn,...)
Definition: pybind11.h:2861
matlab_wrap.module_name
module_name
Definition: matlab_wrap.py:60
test_override_cache_helper_trampoline::func
int func() override
Definition: test_interpreter.cpp:55
j
std::ptrdiff_t j
Definition: tut_arithmetic_redux_minmax.cpp:2
make_iterator
iterator make_iterator(Iterator first, Sentinel last, Extra &&...extra)
Makes a python iterator from a first and past-the-end C++ InputIterator.
Definition: pybind11.h:2409
Widget::Widget
Widget(std::string message)
Definition: test_interpreter.cpp:24
Widget::message
std::string message
Definition: test_interpreter.cpp:32
initialize_interpreter
void initialize_interpreter(bool init_signal_handlers=true, int argc=0, const char *const *argv=nullptr, bool add_program_dir_to_path=true)
Definition: embed.h:192
embed.h
exec
void exec(const str &expr, object global=globals(), object local=object())
Definition: eval.h:88
test_override_cache_helper_trampoline
Definition: test_interpreter.cpp:54
m
Matrix3f m
Definition: AngleAxis_mimic_euler.cpp:1
PyWidget
Definition: test_interpreter.cpp:35
scope_exit::scope_exit
scope_exit(std::function< void()> f) noexcept
Definition: test_interpreter.cpp:403
g
void g(const string &key, int i)
Definition: testBTree.cpp:41
PYBIND11_OVERRIDE
#define PYBIND11_OVERRIDE(ret_type, cname, fn,...)
Definition: pybind11.h:2854
tree::f
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
Definition: testExpression.cpp:218
test_override_cache_helper
Definition: test_interpreter.cpp:42
get_python_state_dict
object get_python_state_dict()
Definition: internals.h:434
module_::reload
void reload()
Reload the module or throws error_already_set.
Definition: pybind11.h:1220
module_
Wrapper for Python extension modules.
Definition: pybind11.h:1153
PYBIND11_WARNING_DISABLE_MSVC
PYBIND11_WARNING_PUSH PYBIND11_WARNING_DISABLE_MSVC(5054) PYBIND11_WARNING_POP static_assert(EIGEN_VERSION_AT_LEAST(3
iter
iterator iter(handle obj)
Definition: pytypes.h:2428
get_internals_obj_from_state_dict
object get_internals_obj_from_state_dict(handle state_dict)
Definition: internals.h:454
std
Definition: BFloat16.h:88
p
float * p
Definition: Tutorial_Map_using.cpp:9
pybind11
Definition: wrap/pybind11/pybind11/__init__.py:1
len
size_t len(handle h)
Get the length of a Python object.
Definition: pytypes.h:2399
object::cast
T cast() const &
Definition: cast.h:1169
scope_exit
Definition: test_interpreter.cpp:401
get_internals_pp
internals **& get_internals_pp()
Definition: internals.h:322
globals
dict globals()
Definition: pybind11.h:1287
func
Definition: benchGeometry.cpp:23
test_override_cache_helper::func
virtual int func()
Definition: test_interpreter.cpp:45
gtsam.examples.ShonanAveragingCLI.str
str
Definition: ShonanAveragingCLI.py:115
test_callbacks.value
value
Definition: test_callbacks.py:158
PyWidget::the_answer
int the_answer() const override
Definition: test_interpreter.cpp:38
i
int i
Definition: BiCGSTAB_step_by_step.cpp:9
uintptr_t
_W64 unsigned int uintptr_t
Definition: ms_stdint.h:124
get_sys_path_size
size_t get_sys_path_size()
Definition: test_interpreter.cpp:17
scope_exit::~scope_exit
~scope_exit()
Definition: test_interpreter.cpp:404


gtsam
Author(s):
autogenerated on Tue Jun 25 2024 03:05:28