GTSAM Copyright 2010-2020, Georgia Tech Research Corporation,
Atlanta, Georgia 30332-0415
See LICENSE for the license information.
Code generator for wrapping a C++ module with Pybind11.
Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert
16 from pathlib
import Path
17 from typing
import List
Class to generate binding code specifically for Pybind11.
31 top_module_namespaces='',
32 use_boost_serialization=False,
43 'lambda',
'False',
'def',
'if',
'raise',
'None',
'del',
'import',
44 'return',
'True',
'elif',
'in',
'try',
'and',
'else',
'is',
45 'while',
'as',
'except',
'lambda',
'with',
'assert',
'finally',
46 'nonlocal',
'yield',
'break',
'for',
'not',
'class',
'from',
'or',
47 'continue',
'global',
'pass'
59 "svg",
"png",
"jpeg",
"html",
"javascript",
"markdown",
"latex"
63 """Set the argument names in Pybind11 format."""
67 for arg
in args.list():
68 if arg.default
is not None:
69 default =
' = {arg.default}'.
format(arg=arg)
72 argument =
'py::arg("{name}"){default}'.
format(
73 name=arg.name, default=
'{0}'.
format(default))
74 py_args.append(argument)
75 return ", " +
", ".join(py_args)
80 """Generate the argument types and names as per the method signature."""
81 cpp_types = args.to_cpp()
84 "{} {}".
format(ctype, name)
85 for ctype, name
in zip(cpp_types, names)
88 return ', '.join(types_names)
91 """Wrap the constructors."""
93 for ctor
in my_class.ctors:
94 res += (self.
method_indent +
'.def(py::init<{args_cpp_types}>()'
96 args_cpp_types=
", ".join(ctor.args.to_cpp()),
102 """Helper method to add serialize, deserialize and pickle methods to the wrapped class."""
107 ".def(\"serialize\", []({class_inst} self){{ return gtsam::serialize(*self); }})".
format(class_inst=cpp_class +
'*')
110 '.def("deserialize", []({class_inst} self, string serialized)' \
111 '{{ gtsam::deserialize(serialized, *self); }}, py::arg("serialized"))' \
112 .
format(class_inst=cpp_class +
'*')
116 ".def(py::pickle({indent} [](const {cpp_class} &a){{ /* __getstate__: Returns a string that encodes the state of the object */ return py::make_tuple(gtsam::serialize(a)); }},{indent} [](py::tuple t){{ /* __setstate__ */ {cpp_class} obj; gtsam::deserialize(t[0].cast<std::string>(), obj); return obj; }}))"
118 return serialize_method + deserialize_method + \
119 pickle_method.format(cpp_class=cpp_class, indent=self.
method_indent)
121 def _wrap_print(self, ret: str, method: parser.Method, cpp_class: str,
122 args_names: List[str], args_signature_with_names: str,
123 py_args_names: str, prefix: str, suffix: str):
125 Update the print method to print to the output stream and append a __repr__ method.
128 ret (str): The result of the parser.
129 method (parser.Method): The method to be wrapped.
130 cpp_class (str): The C++ name of the class to which the method belongs.
131 args_names (List[str]): List of argument variable names passed to the method.
132 args_signature_with_names (str): C++ arguments containing their names and type signatures.
133 py_args_names (str): The pybind11 formatted version of the argument list.
134 prefix (str): Prefix to add to the wrapped method when writing to the cpp file.
135 suffix (str): Suffix to add to the wrapped method when writing to the cpp file.
138 str: The wrapped print method.
142 ret = ret.replace(
'self->print',
143 'py::scoped_ostream_redirect output; self->print')
146 ret +=
'''{prefix}.def("__repr__",
147 [](const {cpp_class}& self{opt_comma}{args_signature_with_names}){{
148 gtsam::RedirectCout redirect;
149 self.{method_name}({method_args});
150 return redirect.str();
151 }}{py_args_names}){suffix}'''.
format(
154 opt_comma=
', ' if args_names
else '',
155 args_signature_with_names=args_signature_with_names,
156 method_name=method.name,
157 method_args=
", ".join(args_names)
if args_names
else '',
158 py_args_names=py_args_names,
169 Wrap a Python double-underscore (dunder) method.
171 E.g. __len__() gets wrapped as `.def("__len__", [](gtsam::KeySet* self) {return self->size();})`
173 Supported methods are:
178 py_method = method.name + method_suffix
179 args_names = method.args.names()
183 if method.name ==
'len':
184 function_call =
"return std::distance(self->begin(), self->end());"
185 elif method.name ==
'contains':
186 function_call = f
"return std::find(self->begin(), self->end(), {method.args.args_list[0].name}) != self->end();"
187 elif method.name ==
'iter':
188 function_call =
"return py::make_iterator(self->begin(), self->end());"
190 ret = (
'{prefix}.def("__{py_method}__",'
191 '[]({self}{opt_comma}{args_signature_with_names}){{'
194 '{py_args_names}){suffix}'.
format(
197 self=f
"{cpp_class}* self",
198 opt_comma=
', ' if args_names
else '',
199 args_signature_with_names=args_signature_with_names,
200 function_call=function_call,
201 py_args_names=py_args_names,
214 Wrap the `method` for the class specified by `cpp_class`.
217 method: The method to wrap.
218 cpp_class: The C++ name of the class to which the method belongs.
219 prefix: Prefix to add to the wrapped method when writing to the cpp file.
220 suffix: Suffix to add to the wrapped method when writing to the cpp file.
221 method_suffix: A string to append to the wrapped method name.
223 py_method = method.name + method_suffix
224 cpp_method = method.to_cpp()
226 args_names = method.args.names()
231 if cpp_method
in [
"serialize",
"serializable"]:
241 py_method = f
"_repr_{self._ipython_special_methods[idx]}_"
245 py_method = py_method +
"_"
248 method, (parser.Method, instantiator.InstantiatedMethod))
251 (parser.StaticMethod, instantiator.InstantiatedStaticMethod))
252 return_void = method.return_type.is_void()
254 caller = cpp_class +
"::" if not is_method
else "self->"
255 function_call = (
'{opt_return} {caller}{method_name}'
257 opt_return=
'return' if not return_void
else '',
259 method_name=cpp_method,
260 args_names=
', '.join(args_names),
263 ret = (
'{prefix}.{cdef}("{py_method}",'
264 '[]({opt_self}{opt_comma}{args_signature_with_names}){{'
267 '{py_args_names}{docstring}){suffix}'.
format(
269 cdef=
"def_static" if is_static
else "def",
271 opt_self=
"{cpp_class}* self".
format(
272 cpp_class=cpp_class)
if is_method
else "",
273 opt_comma=
', ' if is_method
and args_names
else '',
274 args_signature_with_names=args_signature_with_names,
275 function_call=function_call,
276 py_args_names=py_args_names,
282 docstring=
', "' +
repr(self.
xml_parser.extract_docstring(self.
xml_source, cpp_class, cpp_method, method.args.names()))[1:-1].replace(
'"',
r'\"') +
'"'
288 if method.name ==
'print':
289 ret = self.
_wrap_print(ret, method, cpp_class, args_names,
290 args_signature_with_names, py_args_names,
298 prefix='\n' + ' ' * 8,
301 for method
in methods:
312 prefix='\n' + ' ' * 8,
315 Wrap all the methods in the `cpp_class`.
318 for method
in methods:
321 if method.name ==
'insert' and cpp_class ==
'gtsam::Values':
322 name_list = method.args.names()
323 type_list = method.args.to_cpp()
325 if type_list[0].strip() ==
'size_t':
326 method_suffix =
'_' + name_list[1].strip()
331 method_suffix=method_suffix)
346 prefix='\n' + ' ' * 8):
348 Wrap a variable that's not part of a class (i.e. global)
351 if variable.default
is None:
352 variable_value = variable.name
354 variable_value = variable.default
356 return '{prefix}{module_var}.attr("{variable_name}") = {namespace}{variable_value};'.
format(
358 module_var=module_var,
359 variable_name=variable.name,
361 variable_value=variable_value)
364 """Wrap all the properties in the `cpp_class`."""
366 for prop
in properties:
367 res += (
'{prefix}.def_{property}("{property_name}", '
368 '&{cpp_class}::{property_name})'.
format(
371 if prop.ctype.is_const
else "readwrite",
373 property_name=prop.name,
378 """Wrap all the overloaded operators in the `cpp_class`."""
380 template =
"{prefix}.def({{0}})".
format(prefix=prefix)
382 if op.operator ==
"[]":
383 res +=
"{prefix}.def(\"__getitem__\", &{cpp_class}::operator[])".
format(
384 prefix=prefix, cpp_class=cpp_class)
385 elif op.operator ==
"()":
386 res +=
"{prefix}.def(\"__call__\", &{cpp_class}::operator())".
format(
387 prefix=prefix, cpp_class=cpp_class)
389 res += template.format(
"{0}py::self".
format(op.operator))
391 res += template.format(
"py::self {0} py::self".
format(
395 def wrap_enum(self, enum, class_name='', module=None, prefix=' ' * 4):
400 enum: The parsed enum to wrap.
401 class_name: The class under which the enum is defined.
402 prefix: The amount of indentation.
407 cpp_class = enum.cpp_typename().
to_cpp()
410 cpp_class = class_name +
"::" + cpp_class
412 res =
'{prefix}py::enum_<{cpp_class}>({module}, "{enum.name}", py::arithmetic())'.
format(
413 prefix=prefix, module=module, enum=enum, cpp_class=cpp_class)
414 for enumerator
in enum.enumerators:
415 res +=
'\n{prefix} .value("{enumerator.name}", {cpp_class}::{enumerator.name})'.
format(
416 prefix=prefix, enumerator=enumerator, cpp_class=cpp_class)
420 def wrap_enums(self, enums, instantiated_class, prefix=' ' * 4):
421 """Wrap multiple enums defined in a class."""
422 cpp_class = instantiated_class.to_cpp()
423 module_var = instantiated_class.name.lower()
428 enum, class_name=cpp_class, module=module_var, prefix=prefix)
432 self, instantiated_class: instantiator.InstantiatedClass):
433 """Wrap the class."""
435 cpp_class = instantiated_class.to_cpp()
438 if instantiated_class.parent_class:
439 class_parent =
"{instantiated_class.parent_class}, ".
format(
440 instantiated_class=instantiated_class)
444 if instantiated_class.enums:
446 instance_name = instantiated_class.name.lower()
447 class_declaration = (
448 '\n py::class_<{cpp_class}, {class_parent}'
449 'std::shared_ptr<{cpp_class}>> '
450 '{instance_name}({module_var}, "{class_name}");'
451 '\n {instance_name}').
format(
453 class_name=instantiated_class.name,
454 class_parent=class_parent,
455 instance_name=instance_name,
456 module_var=module_var)
457 module_var = instance_name
460 class_declaration = (
461 '\n py::class_<{cpp_class}, {class_parent}'
462 'std::shared_ptr<{cpp_class}>>({module_var}, "{class_name}")'
463 ).
format(cpp_class=cpp_class,
464 class_name=instantiated_class.name,
465 class_parent=class_parent,
466 module_var=module_var)
468 return (
'{class_declaration}'
471 '{wrapped_static_methods}'
472 '{wrapped_dunder_methods}'
473 '{wrapped_properties}'
474 '{wrapped_operators};\n'.
format(
475 class_declaration=class_declaration,
476 wrapped_ctors=self.
wrap_ctors(instantiated_class),
478 instantiated_class.methods, cpp_class),
480 instantiated_class.static_methods, cpp_class),
482 instantiated_class.dunder_methods, cpp_class),
484 instantiated_class.properties, cpp_class),
486 instantiated_class.operators, cpp_class)))
489 self, instantiated_decl: instantiator.InstantiatedDeclaration):
490 """Wrap the forward declaration."""
492 cpp_class = instantiated_decl.to_cpp()
496 res = (
'\n py::class_<{cpp_class}, '
497 'std::shared_ptr<{cpp_class}>>({module_var}, "{class_name}");'
498 ).
format(cpp_class=cpp_class,
499 class_name=instantiated_decl.name,
500 module_var=module_var)
504 """Wrap STL containers."""
506 cpp_class = stl_class.to_cpp()
510 return (
'\n py::class_<{cpp_class}, {class_parent}'
511 'std::shared_ptr<{cpp_class}>>({module_var}, "{class_name}")'
514 '{wrapped_static_methods}'
515 '{wrapped_properties};\n'.
format(
517 class_name=stl_class.name,
518 class_parent=
str(stl_class.parent_class) +
519 (
', ' if stl_class.parent_class
else ''),
520 module_var=module_var,
525 stl_class.static_methods, cpp_class),
527 stl_class.properties, cpp_class),
533 prefix='\n' + ' ' * 8,
536 Wrap all the global functions.
539 for function
in functions:
541 function_name = function.name
545 if function_name
in python_keywords:
546 function_name = function_name +
"_"
548 cpp_method = function.to_cpp()
550 is_static =
isinstance(function, parser.StaticMethod)
551 return_void = function.return_type.is_void()
552 args_names = function.args.names()
556 caller = namespace +
"::"
557 function_call = (
'{opt_return} {caller}{function_name}'
560 if not return_void
else '',
562 function_name=cpp_method,
563 args_names=
', '.join(args_names),
566 ret = (
'{prefix}.{cdef}("{function_name}",'
567 '[]({args_signature}){{'
570 '{py_args_names}){suffix}'.
format(
572 cdef=
"def_static" if is_static
else "def",
573 function_name=function_name,
574 args_signature=args_signature,
575 function_call=function_call,
576 py_args_names=py_args_names,
585 if namespaces1[i] != namespaces2[i]:
590 """Get the Pybind11 module name from the namespaces."""
593 return "m_{}".
format(
'_'.join(sub_module_namespaces))
598 idx = 1
if not namespaces[0]
else 0
599 return '::'.join(namespaces[idx:] + [name])
604 """Wrap the complete `namespace`."""
608 namespaces = namespace.full_namespaces()
613 for element
in namespace.content:
615 include =
"{}\n".
format(element)
617 include = include.replace(
'<',
'"').replace(
'>',
'"')
625 wrapped += wrapped_namespace
626 includes += includes_namespace
632 ' ' * 4 +
'pybind11::module {module_var} = '
633 '{parent_module_var}.def_submodule("{namespace}", "'
634 '{namespace} submodule");\n'.
format(
635 module_var=module_var,
636 namespace=namespace.name,
642 for element
in namespace.content:
644 include =
"{}\n".
format(element)
646 include = include.replace(
'<',
'"').replace(
'>',
'"')
651 wrapped += wrapped_namespace
652 includes += includes_namespace
654 elif isinstance(element, instantiator.InstantiatedClass):
656 wrapped += self.
wrap_enums(element.enums, element)
658 elif isinstance(element, instantiator.InstantiatedDeclaration):
664 module_var=module_var,
666 prefix=
'\n' +
' ' * 4)
673 func
for func
in namespace.content
675 instantiator.InstantiatedGlobalFunction))
680 prefix=
'\n' +
' ' * 4 + module_var,
684 return wrapped, includes
686 def wrap_file(self, content, module_name=None, submodules=None):
688 Wrap the code in the interface file.
691 content: The contents of the interface file.
692 module_name: The name of the module.
693 submodules: List of other interface file names that should be linked to.
696 module = parser.Module.parseString(content)
698 module = instantiator.instantiate_namespace(module)
703 includes +=
"#include <boost/serialization/export.hpp>"
706 boost_class_export =
""
711 new_name = re.sub(
"[,:<> ]",
"", cpp_class)
712 boost_class_export +=
"typedef {cpp_class} {new_name};\n".
format(
713 cpp_class=cpp_class, new_name=new_name)
715 boost_class_export +=
"BOOST_CLASS_EXPORT({new_name})\n".
format(
718 boost_class_export =
""
725 if submodules
is not None:
726 module_def =
"PYBIND11_MODULE({0}, m_)".
format(module_name)
728 for idx, submodule
in enumerate(submodules):
729 submodules[idx] =
"void {0}(py::module_ &);".
format(submodule)
730 submodules_init.append(
"{0}(m_);".
format(submodule))
733 module_def =
"void {0}(py::module_ &m_)".
format(module_name)
737 module_def=module_def,
738 module_name=module_name,
740 wrapped_namespace=wrapped_namespace,
741 boost_class_export=boost_class_export,
742 submodules=
"\n".join(submodules),
743 submodules_init=
"\n".join(submodules_init),
748 Wrap a list of submodule files, i.e. a set of interface files which are
749 in support of a larger wrapping project.
751 E.g. This is used in GTSAM where we have a main gtsam.i, but various smaller .i files
752 which are the submodules.
753 The benefit of this scheme is that it reduces compute and memory usage during compilation.
756 source: Interface file which forms the submodule.
758 filename = Path(source).name
759 module_name = Path(source).stem
762 with open(source,
"r", encoding=
"UTF-8")
as f:
765 cc_content = self.
wrap_file(content, module_name=module_name)
768 with open(filename.replace(
".i",
".cpp"),
"w", encoding=
"UTF-8")
as f:
771 def wrap(self, sources, main_module_name):
773 Wrap all the main interface file.
776 sources: List of all interface files.
777 The first file should be the main module.
778 main_module_name: The name for the main module.
780 main_module = sources[0]
784 for source
in sources[1:]:
785 module_name = Path(source).stem
786 submodules.append(module_name)
788 with open(main_module,
"r", encoding=
"UTF-8")
as f:
792 submodules=submodules)
795 with open(main_module_name,
"w", encoding=
"UTF-8")
as f: