test_interpreter.cpp
Go to the documentation of this file.
1 #include <pybind11/embed.h>
2 
3 #ifdef _MSC_VER
4 // Silence MSVC C++17 deprecation warning from Catch regarding std::uncaught_exceptions (up to catch
5 // 2.0.1; this should be fixed in the next catch release after 2.0.1).
6 # pragma warning(disable: 4996)
7 #endif
8 
9 #include <catch.hpp>
10 
11 #include <thread>
12 #include <fstream>
13 #include <functional>
14 
15 namespace py = pybind11;
16 using namespace py::literals;
17 
/// Abstract C++ base class that the `widget_module` bindings expose to Python.
///
/// Holds a message fixed at construction time and leaves `the_answer()` to be
/// supplied by a derived class.
class Widget {
public:
    /// Stores `message`. The constructor is `explicit` so a string never converts
    /// to a Widget implicitly, and the by-value argument is moved rather than copied.
    explicit Widget(std::string message) : message(std::move(message)) { }
    virtual ~Widget() = default;

    /// Returns a copy of the stored message.
    std::string the_message() const { return message; }
    /// Pure virtual: every concrete widget must provide its own answer.
    virtual int the_answer() const = 0;

private:
    std::string message;
};
29 
30 class PyWidget final : public Widget {
31  using Widget::Widget;
32 
33  int the_answer() const override { PYBIND11_OVERRIDE_PURE(int, Widget, the_answer); }
34 };
35 
36 PYBIND11_EMBEDDED_MODULE(widget_module, m) {
37  py::class_<Widget, PyWidget>(m, "Widget")
38  .def(py::init<std::string>())
39  .def_property_readonly("the_message", &Widget::the_message);
40 
41  m.def("add", [](int i, int j) { return i + j; });
42 }
43 
44 PYBIND11_EMBEDDED_MODULE(throw_exception, ) {
45  throw std::runtime_error("C++ Error");
46 }
47 
48 PYBIND11_EMBEDDED_MODULE(throw_error_already_set, ) {
49  auto d = py::dict();
50  d["missing"].cast<py::object>();
51 }
52 
53 TEST_CASE("Pass classes and data between modules defined in C++ and Python") {
54  auto module = py::module::import("test_interpreter");
55  REQUIRE(py::hasattr(module, "DerivedWidget"));
56 
57  auto locals = py::dict("hello"_a="Hello, World!", "x"_a=5, **module.attr("__dict__"));
58  py::exec(R"(
59  widget = DerivedWidget("{} - {}".format(hello, x))
60  message = widget.the_message
61  )", py::globals(), locals);
62  REQUIRE(locals["message"].cast<std::string>() == "Hello, World! - 5");
63 
64  auto py_widget = module.attr("DerivedWidget")("The question");
65  auto message = py_widget.attr("the_message");
66  REQUIRE(message.cast<std::string>() == "The question");
67 
68  const auto &cpp_widget = py_widget.cast<const Widget &>();
69  REQUIRE(cpp_widget.the_answer() == 42);
70 }
71 
72 TEST_CASE("Import error handling") {
73  REQUIRE_NOTHROW(py::module::import("widget_module"));
74  REQUIRE_THROWS_WITH(py::module::import("throw_exception"),
75  "ImportError: C++ Error");
76  REQUIRE_THROWS_WITH(py::module::import("throw_error_already_set"),
77  Catch::Contains("ImportError: KeyError"));
78 }
79 
80 TEST_CASE("There can be only one interpreter") {
85 
86  REQUIRE_THROWS_WITH(py::initialize_interpreter(), "The interpreter is already running");
87  REQUIRE_THROWS_WITH(py::scoped_interpreter(), "The interpreter is already running");
88 
90  REQUIRE_NOTHROW(py::scoped_interpreter());
91  {
92  auto pyi1 = py::scoped_interpreter();
93  auto pyi2 = std::move(pyi1);
94  }
96 }
97 
99  auto builtins = py::handle(PyEval_GetBuiltins());
100  return builtins.contains(PYBIND11_INTERNALS_ID);
101 };
102 
104  auto **&ipp = py::detail::get_internals_pp();
105  return ipp && *ipp;
106 }
107 
108 TEST_CASE("Restart the interpreter") {
109  // Verify pre-restart state.
110  REQUIRE(py::module::import("widget_module").attr("add")(1, 2).cast<int>() == 3);
113  REQUIRE(py::module::import("external_module").attr("A")(123).attr("value").cast<int>() == 123);
114 
115  // local and foreign module internals should point to the same internals:
116  REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp()) ==
117  py::module::import("external_module").attr("internals_at")().cast<uintptr_t>());
118 
119  // Restart the interpreter.
121  REQUIRE(Py_IsInitialized() == 0);
122 
124  REQUIRE(Py_IsInitialized() == 1);
125 
126  // Internals are deleted after a restart.
127  REQUIRE_FALSE(has_pybind11_internals_builtin());
128  REQUIRE_FALSE(has_pybind11_internals_static());
132  REQUIRE(reinterpret_cast<uintptr_t>(*py::detail::get_internals_pp()) ==
133  py::module::import("external_module").attr("internals_at")().cast<uintptr_t>());
134 
135  // Make sure that an interpreter with no get_internals() created until finalize still gets the
136  // internals destroyed
139  bool ran = false;
140  py::module::import("__main__").attr("internals_destroy_test") =
141  py::capsule(&ran, [](void *ran) { py::detail::get_internals(); *static_cast<bool *>(ran) = true; });
142  REQUIRE_FALSE(has_pybind11_internals_builtin());
143  REQUIRE_FALSE(has_pybind11_internals_static());
144  REQUIRE_FALSE(ran);
146  REQUIRE(ran);
148  REQUIRE_FALSE(has_pybind11_internals_builtin());
149  REQUIRE_FALSE(has_pybind11_internals_static());
150 
151  // C++ modules can be reloaded.
152  auto cpp_module = py::module::import("widget_module");
153  REQUIRE(cpp_module.attr("add")(1, 2).cast<int>() == 3);
154 
155  // C++ type information is reloaded and can be used in python modules.
156  auto py_module = py::module::import("test_interpreter");
157  auto py_widget = py_module.attr("DerivedWidget")("Hello after restart");
158  REQUIRE(py_widget.attr("the_message").cast<std::string>() == "Hello after restart");
159 }
160 
161 TEST_CASE("Subinterpreter") {
162  // Add tags to the modules in the main interpreter and test the basics.
163  py::module::import("__main__").attr("main_tag") = "main interpreter";
164  {
165  auto m = py::module::import("widget_module");
166  m.attr("extension_module_tag") = "added to module in main interpreter";
167 
168  REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
169  }
172 
174  auto main_tstate = PyThreadState_Get();
175  auto sub_tstate = Py_NewInterpreter();
176 
177  // Subinterpreters get their own copy of builtins. detail::get_internals() still
178  // works by returning from the static variable, i.e. all interpreters share a single
179  // global pybind11::internals;
180  REQUIRE_FALSE(has_pybind11_internals_builtin());
182 
183  // Modules tags should be gone.
184  REQUIRE_FALSE(py::hasattr(py::module::import("__main__"), "tag"));
185  {
186  auto m = py::module::import("widget_module");
187  REQUIRE_FALSE(py::hasattr(m, "extension_module_tag"));
188 
189  // Function bindings should still work.
190  REQUIRE(m.attr("add")(1, 2).cast<int>() == 3);
191  }
192 
193  // Restore main interpreter.
194  Py_EndInterpreter(sub_tstate);
195  PyThreadState_Swap(main_tstate);
196 
197  REQUIRE(py::hasattr(py::module::import("__main__"), "main_tag"));
198  REQUIRE(py::hasattr(py::module::import("widget_module"), "extension_module_tag"));
199 }
200 
201 TEST_CASE("Execution frame") {
202  // When the interpreter is embedded, there is no execution frame, but `py::exec`
203  // should still function by using reasonable globals: `__main__.__dict__`.
204  py::exec("var = dict(number=42)");
205  REQUIRE(py::globals()["var"]["number"].cast<int>() == 42);
206 }
207 
208 TEST_CASE("Threads") {
209  // Restart interpreter to ensure threads are not initialized
212  REQUIRE_FALSE(has_pybind11_internals_static());
213 
214  constexpr auto num_threads = 10;
215  auto locals = py::dict("count"_a=0);
216 
217  {
218  py::gil_scoped_release gil_release{};
220 
221  auto threads = std::vector<std::thread>();
222  for (auto i = 0; i < num_threads; ++i) {
223  threads.emplace_back([&]() {
224  py::gil_scoped_acquire gil{};
225  locals["count"] = locals["count"].cast<int>() + 1;
226  });
227  }
228 
229  for (auto &thread : threads) {
230  thread.join();
231  }
232  }
233 
234  REQUIRE(locals["count"].cast<int>() == num_threads);
235 }
236 
// Scope exit utility https://stackoverflow.com/a/36644501/7255855
//
// Stores a callable and invokes it from the destructor, provided the stored
// std::function is non-empty.
struct scope_exit {
    std::function<void()> f_;
    explicit scope_exit(std::function<void()> f) noexcept : f_(std::move(f)) {}
    // A copy would hold the same callable and run the cleanup a second time,
    // so the guard is not copyable.
    scope_exit(const scope_exit &) = delete;
    scope_exit &operator=(const scope_exit &) = delete;
    ~scope_exit() { if (f_) f_(); }
};
243 
244 TEST_CASE("Reload module from file") {
245  // Disable generation of cached bytecode (.pyc files) for this test, otherwise
246  // Python might pick up an old version from the cache instead of the new versions
247  // of the .py files generated below
248  auto sys = py::module::import("sys");
249  bool dont_write_bytecode = sys.attr("dont_write_bytecode").cast<bool>();
250  sys.attr("dont_write_bytecode") = true;
251  // Reset the value at scope exit
252  scope_exit reset_dont_write_bytecode([&]() {
253  sys.attr("dont_write_bytecode") = dont_write_bytecode;
254  });
255 
256  std::string module_name = "test_module_reload";
257  std::string module_file = module_name + ".py";
258 
259  // Create the module .py file
260  std::ofstream test_module(module_file);
261  test_module << "def test():\n";
262  test_module << " return 1\n";
263  test_module.close();
264  // Delete the file at scope exit
265  scope_exit delete_module_file([&]() {
266  std::remove(module_file.c_str());
267  });
268 
269  // Import the module from file
270  auto module = py::module::import(module_name.c_str());
271  int result = module.attr("test")().cast<int>();
272  REQUIRE(result == 1);
273 
274  // Update the module .py file with a small change
275  test_module.open(module_file);
276  test_module << "def test():\n";
277  test_module << " return 2\n";
278  test_module.close();
279 
280  // Reload the module
281  module.reload();
282  result = module.attr("test")().cast<int>();
283  REQUIRE(result == 2);
284 }
int the_answer() const override
Matrix3f m
std::string message
void initialize_interpreter(bool init_signal_handlers=true)
Definition: embed.h:105
bool hasattr(handle obj, handle name)
Definition: pytypes.h:403
bool has_pybind11_internals_static()
PYBIND11_NOINLINE internals & get_internals()
Return a reference to the current internals data.
Definition: internals.h:245
void finalize_interpreter()
Definition: embed.h:150
Wrapper for Python extension modules.
Definition: pybind11.h:855
std::string the_message() const
TEST_CASE("Pass classes and data between modules defined in C++ and Python")
Definition: Half.h:150
PYBIND11_EMBEDDED_MODULE(widget_module, m)
bool has_pybind11_internals_builtin()
Values result
scope_exit(std::function< void()> f) noexcept
Widget(std::string message)
Point2(* f)(const Point3 &, OptionalJacobian< 2, 3 >)
#define PYBIND11_INTERNALS_ID
Definition: internals.h:202
#define PYBIND11_OVERRIDE_PURE(ret_type, cname, fn,...)
Definition: pybind11.h:2265
std::function< void()> f_
void exec(str expr, object global=globals(), object local=object())
Definition: eval.h:60
dict globals()
Definition: pybind11.h:948
internals **& get_internals_pp()
Definition: internals.h:210
std::ptrdiff_t j


gtsam
Author(s):
autogenerated on Sat May 8 2021 02:46:03