3 GTSAM Copyright 2010-2020, Georgia Tech Research Corporation, 4 Atlanta, Georgia 30332-0415 7 See LICENSE for the license information 9 Code generator for wrapping a C++ module with Pybind11 10 Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert 23 Class to generate binding code for Pybind11 specifically. 28 top_module_namespaces=
'',
45 """Set the argument names in Pybind11 format.""" 46 names = args_list.args_names()
49 for arg
in args_list.args_list:
50 if isinstance(arg.default, str)
and arg.default
is not None:
52 arg.default =
' = "{arg.default}"'.format(arg=arg)
54 arg.default =
' = {arg.default}'.format(arg=arg)
57 argument =
'py::arg("{name}"){default}'.format(
58 name=arg.name, default=
'{0}'.format(arg.default))
59 py_args.append(argument)
60 return ", " +
", ".join(py_args)
65 """Define the method signature types with the argument names.""" 66 cpp_types = args_list.to_cpp(self.
use_boost)
67 names = args_list.args_names()
69 "{} {}".format(ctype, name)
70 for ctype, name
in zip(cpp_types, names)
73 return ', '.join(types_names)
76 """Wrap the constructors.""" 78 for ctor
in my_class.ctors:
81 '{py_args_names})'.format(
82 args_cpp_types=
", ".join(ctor.args.to_cpp(self.
use_boost)),
93 py_method = method.name + method_suffix
94 cpp_method = method.to_cpp()
96 if cpp_method
in [
"serialize",
"serializable"]:
98 self._serializing_classes.append(cpp_class)
100 ".def(\"serialize\", []({class_inst} self){{ return gtsam::serialize(*self); }})".format(class_inst=cpp_class +
'*')
102 ".def(\"deserialize\", []({class_inst} self, string serialized){{ gtsam::deserialize(serialized, *self); }}, py::arg(\"serialized\"))" \
103 .format(class_inst=cpp_class +
'*')
104 return serialize_method + deserialize_method
106 if cpp_method ==
"pickle":
109 "Cannot pickle a class which is not serializable")
111 ".def(py::pickle({indent} [](const {cpp_class} &a){{ /* __getstate__: Returns a string that encodes the state of the object */ return py::make_tuple(gtsam::serialize(a)); }},{indent} [](py::tuple t){{ /* __setstate__ */ {cpp_class} obj; gtsam::deserialize(t[0].cast<std::string>(), obj); return obj; }}))" 112 return pickle_method.format(cpp_class=cpp_class,
115 is_method =
isinstance(method, instantiator.InstantiatedMethod)
116 is_static =
isinstance(method, parser.StaticMethod)
117 return_void = method.return_type.is_void()
118 args_names = method.args.args_names()
123 caller = cpp_class +
"::" if not is_method
else "self->" 124 function_call = (
'{opt_return} {caller}{function_name}' 125 '({args_names});'.format(
126 opt_return=
'return' if not return_void
else '',
128 function_name=cpp_method,
129 args_names=
', '.join(args_names),
132 ret = (
'{prefix}.{cdef}("{py_method}",' 133 '[]({opt_self}{opt_comma}{args_signature_with_names}){{' 136 '{py_args_names}){suffix}'.format(
138 cdef=
"def_static" if is_static
else "def",
140 else py_method +
"_",
141 opt_self=
"{cpp_class}* self".format(
142 cpp_class=cpp_class)
if is_method
else "",
143 opt_comma=
', ' if is_method
and args_names
else '',
144 args_signature_with_names=args_signature_with_names,
145 function_call=function_call,
146 py_args_names=py_args_names,
152 if method.name ==
'print':
157 'py::scoped_ostream_redirect output; self->print')
160 ret +=
'''{prefix}.def("__repr__", 161 [](const {cpp_class}& self{opt_comma}{args_signature_with_names}){{ 162 gtsam::RedirectCout redirect; 163 self.{method_name}({method_args}); 164 return redirect.str(); 165 }}{py_args_names}){suffix}'''.format(
168 opt_comma=
', ' if args_names
else '',
169 args_signature_with_names=args_signature_with_names,
170 method_name=method.name,
171 method_args=
", ".join(args_names)
if args_names
else '',
172 py_args_names=py_args_names,
180 prefix=
'\n' +
' ' * 8,
183 Wrap all the methods in the `cpp_class`. 185 This function is also used to wrap global functions. 188 for method
in methods:
191 if method.name ==
'insert' and cpp_class ==
'gtsam::Values':
192 name_list = method.args.args_names()
193 type_list = method.args.to_cpp(self.
use_boost)
195 if type_list[0].strip() ==
'size_t':
196 method_suffix =
'_' + name_list[1].strip()
201 method_suffix=method_suffix)
216 prefix=
'\n' +
' ' * 8):
217 """Wrap a variable that's not part of a class (i.e. global) 220 if variable.default
is None:
221 variable_value = variable.name
223 variable_value = variable.default
225 return '{prefix}{module_var}.attr("{variable_name}") = {namespace}{variable_value};'.format(
227 module_var=module_var,
228 variable_name=variable.name,
230 variable_value=variable_value)
233 """Wrap all the properties in the `cpp_class`.""" 235 for prop
in properties:
236 res += (
'{prefix}.def_{property}("{property_name}", ' 237 '&{cpp_class}::{property_name})'.format(
240 if prop.ctype.is_const
else "readwrite",
242 property_name=prop.name,
247 """Wrap all the overloaded operators in the `cpp_class`.""" 249 template =
"{prefix}.def({{0}})".format(prefix=prefix)
251 if op.operator ==
"[]":
252 res +=
"{prefix}.def(\"__getitem__\", &{cpp_class}::operator[])".format(
253 prefix=prefix, cpp_class=cpp_class)
254 elif op.operator ==
"()":
255 res +=
"{prefix}.def(\"__call__\", &{cpp_class}::operator())".format(
256 prefix=prefix, cpp_class=cpp_class)
258 res += template.format(
"{0}py::self".format(op.operator))
260 res += template.format(
"py::self {0} py::self".format(
264 def wrap_enum(self, enum, class_name='', module=None, prefix=' ' * 4):
269 enum: The parsed enum to wrap. 270 class_name: The class under which the enum is defined. 271 prefix: The amount of indentation. 276 cpp_class = enum.cpp_typename().to_cpp()
279 cpp_class = class_name +
"::" + cpp_class
281 res =
'{prefix}py::enum_<{cpp_class}>({module}, "{enum.name}", py::arithmetic())'.format(
282 prefix=prefix, module=module, enum=enum, cpp_class=cpp_class)
283 for enumerator
in enum.enumerators:
284 res +=
'\n{prefix} .value("{enumerator.name}", {cpp_class}::{enumerator.name})'.format(
285 prefix=prefix, enumerator=enumerator, cpp_class=cpp_class)
289 def wrap_enums(self, enums, instantiated_class, prefix=' ' * 4):
290 """Wrap multiple enums defined in a class.""" 291 cpp_class = instantiated_class.cpp_class()
292 module_var = instantiated_class.name.lower()
298 class_name=cpp_class,
304 self, instantiated_class: instantiator.InstantiatedClass):
305 """Wrap the class.""" 307 cpp_class = instantiated_class.cpp_class()
310 if instantiated_class.parent_class:
311 class_parent =
"{instantiated_class.parent_class}, ".format(
312 instantiated_class=instantiated_class)
316 if instantiated_class.enums:
318 instance_name = instantiated_class.name.lower()
319 class_declaration = (
320 '\n py::class_<{cpp_class}, {class_parent}' 321 '{shared_ptr_type}::shared_ptr<{cpp_class}>> ' 322 '{instance_name}({module_var}, "{class_name}");' 323 '\n {instance_name}').format(
324 shared_ptr_type=(
'boost' if self.
use_boost else 'std'),
326 class_name=instantiated_class.name,
327 class_parent=class_parent,
328 instance_name=instance_name,
329 module_var=module_var)
330 module_var = instance_name
333 class_declaration = (
334 '\n py::class_<{cpp_class}, {class_parent}' 335 '{shared_ptr_type}::shared_ptr<{cpp_class}>>({module_var}, "{class_name}")' 336 ).format(shared_ptr_type=(
'boost' if self.
use_boost else 'std'),
338 class_name=instantiated_class.name,
339 class_parent=class_parent,
340 module_var=module_var)
342 return (
'{class_declaration}' 345 '{wrapped_static_methods}' 346 '{wrapped_properties}' 347 '{wrapped_operators};\n'.format(
348 class_declaration=class_declaration,
349 wrapped_ctors=self.
wrap_ctors(instantiated_class),
351 instantiated_class.methods, cpp_class),
353 instantiated_class.static_methods, cpp_class),
355 instantiated_class.properties, cpp_class),
357 instantiated_class.operators, cpp_class)))
360 """Wrap STL containers.""" 362 cpp_class = stl_class.cpp_class()
367 '\n py::class_<{cpp_class}, {class_parent}' 368 '{shared_ptr_type}::shared_ptr<{cpp_class}>>({module_var}, "{class_name}")' 371 '{wrapped_static_methods}' 372 '{wrapped_properties};\n'.format(
373 shared_ptr_type=(
'boost' if self.
use_boost else 'std'),
375 class_name=stl_class.name,
376 class_parent=
str(stl_class.parent_class) +
377 (
', ' if stl_class.parent_class
else ''),
378 module_var=module_var,
383 stl_class.static_methods, cpp_class),
385 stl_class.properties, cpp_class),
389 for i
in range(
min(
len(namespaces1),
len(namespaces2))):
390 if namespaces1[i] != namespaces2[i]:
395 """Get the Pybind11 module name from the namespaces.""" 398 return "m_{}".format(
'_'.join(sub_module_namespaces))
403 idx = 1
if not namespaces[0]
else 0
404 return '::'.join(namespaces[idx:] + [name])
409 """Wrap the complete `namespace`.""" 413 namespaces = namespace.full_namespaces()
418 for element
in namespace.content:
420 include =
"{}\n".format(element)
422 include = include.replace(
'<',
'"').replace(
'>',
'"')
430 wrapped += wrapped_namespace
431 includes += includes_namespace
437 ' ' * 4 +
'pybind11::module {module_var} = ' 438 '{parent_module_var}.def_submodule("{namespace}", "' 439 '{namespace} submodule");\n'.format(
440 module_var=module_var,
441 namespace=namespace.name,
447 for element
in namespace.content:
449 include =
"{}\n".format(element)
451 include = include.replace(
'<',
'"').replace(
'>',
'"')
456 wrapped += wrapped_namespace
457 includes += includes_namespace
459 elif isinstance(element, instantiator.InstantiatedClass):
461 wrapped += self.
wrap_enums(element.enums, element)
466 module_var=module_var,
468 prefix=
'\n' +
' ' * 4)
475 func
for func
in namespace.content
477 instantiator.InstantiatedGlobalFunction))
482 prefix=
'\n' +
' ' * 4 + module_var,
485 return wrapped, includes
488 """Wrap the code in the interface file.""" 492 boost_class_export =
"" 497 new_name = re.sub(
"[,:<> ]",
"", cpp_class)
498 boost_class_export +=
"typedef {cpp_class} {new_name};\n".format(
502 boost_class_export +=
"BOOST_CLASS_EXPORT({new_name})\n".format(
505 holder_type =
"PYBIND11_DECLARE_HOLDER_TYPE(TYPE_PLACEHOLDER_DONOTUSE, " \
506 "{shared_ptr_type}::shared_ptr<TYPE_PLACEHOLDER_DONOTUSE>);" 507 include_boost =
"#include <boost/shared_ptr.hpp>" if self.
use_boost else "" 509 return self.module_template.format(
510 include_boost=include_boost,
513 holder_type=holder_type.format(
514 shared_ptr_type=(
'boost' if self.
use_boost else 'std'))
516 wrapped_namespace=wrapped_namespace,
517 boost_class_export=boost_class_export,
def wrap_methods(self, methods, cpp_class, prefix='\n'+ ' '*8, suffix='')
def wrap_instantiated_class
def __init__(self, module, module_name, top_module_namespaces='', use_boost=False, ignore_classes=(), module_template="")
def _gen_module_var(self, namespaces)
def wrap_ctors(self, my_class)
def _py_args_names(self, args_list)
bool isinstance(handle obj)
def _partial_match(self, namespaces1, namespaces2)
def wrap_stl_class(self, stl_class)
def wrap_namespace(self, namespace)
def _add_namespaces(self, name, namespaces)
def wrap_properties(self, properties, cpp_class, prefix='\n'+ ' '*8)
def wrap_variable(self, namespace, module_var, variable, prefix='\n'+ ' '*8)
def wrap_enums(self, enums, instantiated_class, prefix=' '*4)
def wrap_operators(self, operators, cpp_class, prefix='\n'+ ' '*8)
def wrap_enum(self, enum, class_name='', module=None, prefix=' '*4)
def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix="")
def _method_args_signature_with_names(self, args_list)