pybind_wrapper.py
Go to the documentation of this file.
1 #!/usr/bin/env python3
2 """
3 GTSAM Copyright 2010-2020, Georgia Tech Research Corporation,
4 Atlanta, Georgia 30332-0415
5 All Rights Reserved
6 
7 See LICENSE for the license information
8 
9 Code generator for wrapping a C++ module with Pybind11
10 Author: Duy Nguyen Ta, Fan Jiang, Matthew Sklar, Varun Agrawal, and Frank Dellaert
11 """
12 
13 # pylint: disable=too-many-arguments, too-many-instance-attributes, no-self-use, no-else-return, too-many-arguments, unused-format-string-argument, line-too-long
14 
15 import re
16 
17 import gtwrap.interface_parser as parser
18 import gtwrap.template_instantiator as instantiator
19 
20 
22  """
23  Class to generate binding code for Pybind11 specifically.
24  """
def __init__(self,
             module,
             module_name,
             top_module_namespaces='',
             use_boost=False,
             ignore_classes=(),
             module_template=""):
    """Store the code-generator configuration.

    Args:
        module: Parsed interface module; `wrap` passes it to `wrap_namespace`.
        module_name: Value substituted for ``{module_name}`` in the template.
        top_module_namespaces: Leading namespaces that map onto the top-level
            Python module; their count is stripped when naming submodules.
        use_boost: Emit ``boost::shared_ptr`` holders instead of ``std``.
        ignore_classes: C++ class names that are skipped when wrapping.
        module_template: ``str.format`` template for the generated file.
    """
    self.module = module
    self.module_name = module_name
    self.top_module_namespaces = top_module_namespaces
    self.use_boost = use_boost
    self.ignore_classes = ignore_classes
    # Classes that expose serialize/serializable. `_wrap_method` appends to this
    # list and `wrap` reads it to emit BOOST_CLASS_EXPORT lines, so it must exist
    # before any wrapping happens.
    self._serializing_classes = []
    self.module_template = module_template
    # Method names that get a trailing underscore on the Python side.
    self.python_keywords = ['print', 'lambda']

    # amount of indentation to add before each function/method declaration.
    self.method_indent = '\n' + (' ' * 8)
43 
def _py_args_names(self, args_list):
    """Set the argument names in Pybind11 format.

    Args:
        args_list: Parsed argument list providing ``args_names()`` and an
            ``args_list`` attribute whose items have ``name`` and ``default``.

    Returns:
        ``''`` when there are no arguments; otherwise ``", "`` followed by
        comma-separated ``py::arg("name")`` entries, each with
        `` = <default>`` appended when the argument has a default.
    """
    if not args_list.args_names():
        return ''

    py_args = []
    for arg in args_list.args_list:
        # Build the default text locally instead of writing it back onto `arg`.
        # The same parsed arguments can be formatted more than once (e.g. the
        # extra gtsam::Values::insert binding in wrap_methods), and mutating
        # them made each later call nest the " = ..." prefix again.
        if isinstance(arg.default, str):
            # string default arg
            default = ' = "{0}"'.format(arg.default)
        elif arg.default:  # Other types
            default = ' = {0}'.format(arg.default)
        else:
            # NOTE(review): falsy non-string defaults (e.g. 0, False) end up
            # here and are emitted without a default; confirm intended.
            default = ''
        py_args.append('py::arg("{name}"){default}'.format(name=arg.name,
                                                          default=default))
    return ", " + ", ".join(py_args)
63 
def _method_args_signature_with_names(self, args_list):
    """Return the C++ lambda parameter list ``"<type> <name>, ..."``.

    C++ types come from ``args_list.to_cpp(self.use_boost)``; names come from
    ``args_list.args_names()``. The two are paired positionally.
    """
    return ', '.join(
        "{} {}".format(param_type, param_name)
        for param_type, param_name in zip(args_list.to_cpp(self.use_boost),
                                          args_list.args_names()))
74 
def wrap_ctors(self, my_class):
    """Emit one ``.def(py::init<...>()...)`` binding per constructor.

    Each binding is preceded by ``self.method_indent``; the template arguments
    are the constructor's C++ argument types and the trailing part is the
    ``py::arg`` list from `_py_args_names`.
    """
    ctor_template = '.def(py::init<{args_cpp_types}>(){py_args_names})'
    return "".join(
        self.method_indent + ctor_template.format(
            args_cpp_types=", ".join(ctor.args.to_cpp(self.use_boost)),
            py_args_names=self._py_args_names(ctor.args),
        ) for ctor in my_class.ctors)
86 
def _wrap_method(self,
                 method,
                 cpp_class,
                 prefix,
                 suffix,
                 method_suffix=""):
    """Generate the pybind11 binding for a single method or function.

    Special cases, decided by ``method.to_cpp()`` / ``method.name``:

    * ``serialize`` / ``serializable``: records `cpp_class` in
      ``self._serializing_classes`` and emits ``serialize`` and
      ``deserialize`` bindings instead of a direct call.
    * ``pickle``: emits a ``py::pickle`` getstate/setstate pair built on
      gtsam serialization.
    * ``print``: wraps the call in ``py::scoped_ostream_redirect`` and
      appends a ``__repr__`` binding that captures the printed text.

    Args:
        method: Parsed method or global function.
        cpp_class: C++ class (or namespace, for global functions) used as the
            call qualifier.
        prefix: Text placed before ``.def`` (indentation and/or module var).
        suffix: Text appended after the binding, e.g. ``';'``.
        method_suffix: Extra text appended to the Python-side name.

    Returns:
        The generated C++ source.

    Raises:
        ValueError: ``pickle`` requested for a class that has not been
            recorded as serializable.
    """
    py_method = method.name + method_suffix
    cpp_method = method.to_cpp()

    if cpp_method in ["serialize", "serializable"]:
        if cpp_class not in self._serializing_classes:
            self._serializing_classes.append(cpp_class)
        pointer_type = cpp_class + '*'
        serialize_tpl = ".def(\"serialize\", []({class_inst} self){{ return gtsam::serialize(*self); }})"
        deserialize_tpl = ".def(\"deserialize\", []({class_inst} self, string serialized){{ gtsam::deserialize(serialized, *self); }}, py::arg(\"serialized\"))"
        return (self.method_indent + serialize_tpl.format(class_inst=pointer_type)
                + self.method_indent + deserialize_tpl.format(class_inst=pointer_type))

    if cpp_method == "pickle":
        if cpp_class not in self._serializing_classes:
            raise ValueError(
                "Cannot pickle a class which is not serializable")
        # The indent is prepended before formatting, as the placeholders
        # live in the same template.
        pickle_tpl = self.method_indent + (
            ".def(py::pickle({indent} [](const {cpp_class} &a){{ /* __getstate__: Returns a string that encodes the state of the object */ return py::make_tuple(gtsam::serialize(a)); }},"
            "{indent} [](py::tuple t){{ /* __setstate__ */ {cpp_class} obj; gtsam::deserialize(t[0].cast<std::string>(), obj); return obj; }}))")
        return pickle_tpl.format(cpp_class=cpp_class, indent=self.method_indent)

    is_method = isinstance(method, instantiator.InstantiatedMethod)
    is_static = isinstance(method, parser.StaticMethod)
    return_void = method.return_type.is_void()
    arg_names = method.args.args_names()
    py_arg_list = self._py_args_names(method.args)
    typed_params = self._method_args_signature_with_names(method.args)

    # Instance methods are called through the lambda's `self` pointer;
    # everything else is called with a `Class::` / `namespace::` qualifier.
    if is_method:
        caller = "self->"
        self_param = "{cpp_class}* self".format(cpp_class=cpp_class)
    else:
        caller = cpp_class + "::"
        self_param = ""
    param_sep = ', ' if is_method and arg_names else ''

    call_stmt = '{opt_return} {caller}{function_name}({args_names});'.format(
        opt_return='' if return_void else 'return',
        caller=caller,
        function_name=cpp_method,
        args_names=', '.join(arg_names),
    )
    exposed_name = py_method + "_" if py_method in self.python_keywords else py_method

    binding = '{prefix}.{cdef}("{name}",[]({params}){{{body}}}{py_args}){suffix}'.format(
        prefix=prefix,
        cdef="def_static" if is_static else "def",
        name=exposed_name,
        params=self_param + param_sep + typed_params,
        body=call_stmt,
        py_args=py_arg_list,
        suffix=suffix,
    )

    if method.name != 'print':
        return binding

    # Create __repr__ override.
    # Redirect stdout - see pybind docs for why this is a good idea:
    # https://pybind11.readthedocs.io/en/stable/advanced/pycpp/utilities.html#capturing-standard-output-from-ostream
    binding = binding.replace(
        'self->print',
        'py::scoped_ostream_redirect output; self->print')

    # __repr__ forwards all print() arguments; type mismatches are left to the
    # C++ compiler.
    repr_tpl = ('{prefix}.def("__repr__",\n'
                ' [](const {cpp_class}& self{opt_comma}{args_signature_with_names}){{\n'
                ' gtsam::RedirectCout redirect;\n'
                ' self.{method_name}({method_args});\n'
                ' return redirect.str();\n'
                ' }}{py_args_names}){suffix}')
    return binding + repr_tpl.format(
        prefix=prefix,
        cpp_class=cpp_class,
        opt_comma=', ' if arg_names else '',
        args_signature_with_names=typed_params,
        method_name=method.name,
        method_args=", ".join(arg_names) if arg_names else '',
        py_args_names=py_arg_list,
        suffix=suffix)
176 
def wrap_methods(self,
                 methods,
                 cpp_class,
                 prefix='\n' + ' ' * 8,
                 suffix=''):
    """
    Generate bindings for every method in `methods`, in order.

    Also used for global functions, in which case `cpp_class` is the
    namespace qualifier. For ``gtsam::Values::insert`` overloads whose first
    argument is ``size_t``, an additional binding is emitted first, with the
    second argument's name appended to the Python name.
    """
    bindings = []
    for method in methods:
        # To avoid type confusion for insert, currently unused
        if method.name == 'insert' and cpp_class == 'gtsam::Values':
            arg_names = method.args.args_names()
            arg_types = method.args.to_cpp(self.use_boost)
            # inserting non-wrapped value types
            if arg_types[0].strip() == 'size_t':
                bindings.append(
                    self._wrap_method(method=method,
                                      cpp_class=cpp_class,
                                      prefix=prefix,
                                      suffix=suffix,
                                      method_suffix='_' + arg_names[1].strip()))

        bindings.append(
            self._wrap_method(method=method,
                              cpp_class=cpp_class,
                              prefix=prefix,
                              suffix=suffix))
    return "".join(bindings)
211 
def wrap_variable(self,
                  namespace,
                  module_var,
                  variable,
                  prefix='\n' + ' ' * 8):
    """Bind a global (non-class) variable as an attribute of `module_var`.

    The assigned C++ expression is `namespace` followed by the variable's
    default when it has one, otherwise by the variable's own name.
    """
    # NOTE(review): a default value is prefixed with `namespace` too; confirm
    # that is intended when the default is a literal.
    if variable.default is not None:
        rhs_value = variable.default
    else:
        rhs_value = variable.name

    statement = '{prefix}{module_var}.attr("{variable_name}") = {namespace}{variable_value};'
    return statement.format(prefix=prefix,
                            module_var=module_var,
                            variable_name=variable.name,
                            namespace=namespace,
                            variable_value=rhs_value)
231 
def wrap_properties(self, properties, cpp_class, prefix='\n' + ' ' * 8):
    """Bind each property as a pybind11 field on `cpp_class`.

    Const-typed properties use ``def_readonly``; all others ``def_readwrite``.
    """
    field_template = '{prefix}.def_{property}("{property_name}", &{cpp_class}::{property_name})'
    return "".join(
        field_template.format(
            prefix=prefix,
            property="readonly" if prop.ctype.is_const else "readwrite",
            cpp_class=cpp_class,
            property_name=prop.name,
        ) for prop in properties)
245 
def wrap_operators(self, operators, cpp_class, prefix='\n' + ' ' * 8):
    """Bind the overloaded operators of `cpp_class`.

    ``operator[]`` becomes ``__getitem__`` and ``operator()`` becomes
    ``__call__`` (bound to the member pointer). Other operators use the
    ``py::self`` expression syntax, in unary or binary form.
    """
    generic = prefix + ".def({0})"
    chunks = []
    for op in operators:
        symbol = op.operator
        if symbol == "[]" or symbol == "()":
            dunder = "__getitem__" if symbol == "[]" else "__call__"
            chunks.append('{prefix}.def("{dunder}", &{cpp_class}::operator{symbol})'.format(
                prefix=prefix, dunder=dunder, cpp_class=cpp_class, symbol=symbol))
            continue
        if op.is_unary:
            expression = "{0}py::self".format(symbol)
        else:
            expression = "py::self {0} py::self".format(symbol)
        chunks.append(generic.format(expression))
    return "".join(chunks)
263 
def wrap_enum(self, enum, class_name='', module=None, prefix=' ' * 4):
    """
    Emit a ``py::enum_`` declaration with one ``.value(...)`` per enumerator.

    Args:
        enum: The parsed enum to wrap.
        class_name: Enclosing C++ class; when given it qualifies the enum type.
        module: Pybind11 module/scope variable; defaults to the variable
            derived from the enum's namespaces.
        prefix: The amount of indentation.
    """
    scope = module if module is not None else self._gen_module_var(enum.namespaces())

    enum_type = enum.cpp_typename().to_cpp()
    if class_name:
        enum_type = class_name + "::" + enum_type

    parts = ['{prefix}py::enum_<{cpp_class}>({module}, "{name}", py::arithmetic())'.format(
        prefix=prefix, cpp_class=enum_type, module=scope, name=enum.name)]
    for item in enum.enumerators:
        parts.append('\n{prefix} .value("{name}", {cpp_class}::{name})'.format(
            prefix=prefix, name=item.name, cpp_class=enum_type))
    parts.append(";\n\n")
    return "".join(parts)
288 
def wrap_enums(self, enums, instantiated_class, prefix=' ' * 4):
    """Wrap the enums nested in `instantiated_class`.

    Each enum is qualified by the class's C++ name and registered on a scope
    variable named after the lower-cased class name; each result is preceded
    by a newline.
    """
    owner = instantiated_class.cpp_class()
    scope_var = instantiated_class.name.lower()
    return "".join("\n" + self.wrap_enum(enum,
                                         class_name=owner,
                                         module=scope_var,
                                         prefix=prefix)
                   for enum in enums)
302 
def wrap_instantiated_class(
        self, instantiated_class: 'instantiator.InstantiatedClass'):
    """Wrap the class.

    Emits a ``py::class_`` declaration held by a shared pointer (boost or
    std, per ``self.use_boost``). Constructors, methods, static methods,
    properties and operators are then chained onto it, in that order.

    When the class has enums, the declaration binds a named local (the
    lower-cased class name) so `wrap_enums` can register them on it.

    Returns:
        The generated C++ source, or ``""`` if the class is listed in
        ``self.ignore_classes``.
    """
    module_var = self._gen_module_var(instantiated_class.namespaces())
    cpp_class = instantiated_class.cpp_class()
    if cpp_class in self.ignore_classes:
        return ""

    parent = instantiated_class.parent_class
    class_parent = "{0}, ".format(parent) if parent else ''
    shared_ptr_type = 'boost' if self.use_boost else 'std'

    head = ('\n py::class_<{cpp_class}, {class_parent}'
            '{shared_ptr_type}::shared_ptr<{cpp_class}>>').format(
                cpp_class=cpp_class,
                class_parent=class_parent,
                shared_ptr_type=shared_ptr_type)

    if instantiated_class.enums:
        # If class has enums, define an instance so they can be attached to it.
        instance_name = instantiated_class.name.lower()
        class_declaration = head + (
            ' {instance_name}({module_var}, "{class_name}");'
            '\n {instance_name}').format(instance_name=instance_name,
                                         module_var=module_var,
                                         class_name=instantiated_class.name)
    else:
        class_declaration = head + '({module_var}, "{class_name}")'.format(
            module_var=module_var, class_name=instantiated_class.name)

    # Order matters: a pickle method raises unless a serialize method was
    # processed earlier (see _wrap_method).
    pieces = [
        class_declaration,
        self.wrap_ctors(instantiated_class),
        self.wrap_methods(instantiated_class.methods, cpp_class),
        self.wrap_methods(instantiated_class.static_methods, cpp_class),
        self.wrap_properties(instantiated_class.properties, cpp_class),
        self.wrap_operators(instantiated_class.operators, cpp_class),
    ]
    return "".join(pieces) + ';\n'
358 
def wrap_stl_class(self, stl_class):
    """Wrap STL containers.

    Emits a ``py::class_`` declaration held by a shared pointer (boost or
    std, per ``self.use_boost``). Constructors, methods, static methods and
    properties are then chained onto it, in that order.

    Returns:
        The generated C++ source, or ``""`` if the class is listed in
        ``self.ignore_classes``.
    """
    module_var = self._gen_module_var(stl_class.namespaces())
    cpp_class = stl_class.cpp_class()
    if cpp_class in self.ignore_classes:
        return ""

    # Only list a base class when one exists; otherwise str(None) would put
    # the literal text "None" into the py::class_ template arguments.
    parent = stl_class.parent_class
    class_parent = (str(parent) + ', ') if parent else ''

    return (
        '\n py::class_<{cpp_class}, {class_parent}'
        '{shared_ptr_type}::shared_ptr<{cpp_class}>>({module_var}, "{class_name}")'
        '{wrapped_ctors}'
        '{wrapped_methods}'
        '{wrapped_static_methods}'
        '{wrapped_properties};\n'.format(
            shared_ptr_type=('boost' if self.use_boost else 'std'),
            cpp_class=cpp_class,
            class_name=stl_class.name,
            class_parent=class_parent,
            module_var=module_var,
            wrapped_ctors=self.wrap_ctors(stl_class),
            wrapped_methods=self.wrap_methods(stl_class.methods, cpp_class),
            wrapped_static_methods=self.wrap_methods(
                stl_class.static_methods, cpp_class),
            wrapped_properties=self.wrap_properties(
                stl_class.properties, cpp_class),
        ))
387 
def _partial_match(self, namespaces1, namespaces2):
    """Return True when the two namespace sequences agree on every shared position.

    Only the overlapping prefix is compared, so this is True exactly when one
    sequence is a prefix of the other.
    """
    return not any(left != right for left, right in zip(namespaces1, namespaces2))
393 
def _gen_module_var(self, namespaces):
    """Get the Pybind11 module variable name (``m_<a>_<b>...``) for `namespaces`."""
    # Drop as many leading entries as top_module_namespaces has; only the
    # count is used, the names themselves are not compared.
    depth = len(self.top_module_namespaces)
    return "m_" + "_".join(namespaces[depth:])
399 
def _add_namespaces(self, name, namespaces):
    """Qualify `name` with `namespaces`, joined by ``::``.

    A leading empty entry (the global namespace) is skipped. With no
    namespaces, `name` is returned unchanged.
    """
    if not namespaces:
        return name
    start = 0 if namespaces[0] else 1
    return '::'.join(namespaces[start:] + [name])
407 
def wrap_namespace(self, namespace):
    """Wrap the complete `namespace`.

    Returns a ``(wrapped_code, includes)`` pair of strings.

    * A namespace that does not prefix-match ``self.top_module_namespaces``
      yields ``("", "")``.
    * A namespace shallower than the top module only contributes its
      includes and recurses into nested namespaces.
    * Otherwise its includes, nested namespaces, classes (with their enums),
      variables, enums and global functions are all wrapped. A submodule is
      declared when the namespace is deeper than the top module.
    """
    namespaces = namespace.full_namespaces()
    if not self._partial_match(namespaces, self.top_module_namespaces):
        return "", ""

    def include_line(element):
        # Render the include, replacing its angle brackets with quotes.
        return "{}\n".format(element).replace('<', '"').replace('>', '"')

    code_parts = []
    include_parts = []
    depth = len(namespaces)
    top_depth = len(self.top_module_namespaces)

    if depth < top_depth:
        for element in namespace.content:
            if isinstance(element, parser.Include):
                include_parts.append(include_line(element))
            if isinstance(element, parser.Namespace):
                sub_code, sub_includes = self.wrap_namespace(element)
                code_parts.append(sub_code)
                include_parts.append(sub_includes)
        return "".join(code_parts), "".join(include_parts)

    module_var = self._gen_module_var(namespaces)

    if depth > top_depth:
        code_parts.append(
            ' ' * 4 +
            'pybind11::module {module_var} = {parent_module_var}.def_submodule("{namespace}", "{namespace} submodule");\n'
            .format(module_var=module_var,
                    namespace=namespace.name,
                    parent_module_var=self._gen_module_var(namespaces[:-1])))

    # Wrap an include statement, namespace, class, variable or enum.
    for element in namespace.content:
        if isinstance(element, parser.Include):
            include_parts.append(include_line(element))
        elif isinstance(element, parser.Namespace):
            sub_code, sub_includes = self.wrap_namespace(element)
            code_parts.append(sub_code)
            include_parts.append(sub_includes)
        elif isinstance(element, instantiator.InstantiatedClass):
            code_parts.append(self.wrap_instantiated_class(element))
            code_parts.append(self.wrap_enums(element.enums, element))
        elif isinstance(element, parser.Variable):
            code_parts.append(
                self.wrap_variable(namespace=self._add_namespaces('', namespaces),
                                   module_var=module_var,
                                   variable=element,
                                   prefix='\n' + ' ' * 4))
        elif isinstance(element, parser.Enum):
            code_parts.append(self.wrap_enum(element))

    # Global functions, qualified by the namespace path without its trailing "::".
    global_funcs = [
        item for item in namespace.content
        if isinstance(item, (parser.GlobalFunction,
                             instantiator.InstantiatedGlobalFunction))
    ]
    code_parts.append(
        self.wrap_methods(global_funcs,
                          self._add_namespaces('', namespaces)[:-2],
                          prefix='\n' + ' ' * 4 + module_var,
                          suffix=';'))
    return "".join(code_parts), "".join(include_parts)
486 
def wrap(self):
    """Wrap the code in the interface file.

    Wraps ``self.module`` and emits BOOST_CLASS_EXPORT lines for every class
    recorded as serializable during wrapping. The results are substituted
    into ``self.module_template``.
    """
    wrapped_namespace, includes = self.wrap_namespace(self.module)

    # Export classes for serialization.
    export_lines = []
    for serializable in self._serializing_classes:
        export_name = serializable
        # The boost's macro doesn't like commas, so we have to typedef.
        if ',' in serializable:
            export_name = re.sub(r"[,:<> ]", "", serializable)
            export_lines.append("typedef {cpp_class} {new_name};\n".format(
                cpp_class=serializable, new_name=export_name))
        export_lines.append(
            "BOOST_CLASS_EXPORT({new_name})\n".format(new_name=export_name))

    # The custom holder declaration and the boost include are only needed
    # when boost::shared_ptr is the holder type.
    if self.use_boost:
        include_boost = "#include <boost/shared_ptr.hpp>"
        holder_type = ("PYBIND11_DECLARE_HOLDER_TYPE(TYPE_PLACEHOLDER_DONOTUSE, "
                       "{shared_ptr_type}::shared_ptr<TYPE_PLACEHOLDER_DONOTUSE>);"
                       ).format(shared_ptr_type='boost')
    else:
        include_boost = ""
        holder_type = ""

    return self.module_template.format(
        include_boost=include_boost,
        module_name=self.module_name,
        includes=includes,
        holder_type=holder_type,
        wrapped_namespace=wrapped_namespace,
        boost_class_export="".join(export_lines),
    )
def wrap_methods(self, methods, cpp_class, prefix='\n'+ ' '*8, suffix='')
#define min(a, b)
Definition: datatypes.h:19
def __init__(self, module, module_name, top_module_namespaces='', use_boost=False, ignore_classes=(), module_template="")
def _gen_module_var(self, namespaces)
def _py_args_names(self, args_list)
bool isinstance(handle obj)
Definition: pytypes.h:384
def _partial_match(self, namespaces1, namespaces2)
Definition: pytypes.h:928
def _add_namespaces(self, name, namespaces)
def wrap_properties(self, properties, cpp_class, prefix='\n'+ ' '*8)
def wrap_variable(self, namespace, module_var, variable, prefix='\n'+ ' '*8)
Definition: pytypes.h:1301
def wrap_enums(self, enums, instantiated_class, prefix=' '*4)
def wrap_operators(self, operators, cpp_class, prefix='\n'+ ' '*8)
size_t len(handle h)
Definition: pytypes.h:1514
def wrap_enum(self, enum, class_name='', module=None, prefix=' '*4)
def _wrap_method(self, method, cpp_class, prefix, suffix, method_suffix="")
def _method_args_signature_with_names(self, args_list)


gtsam
Author(s):
autogenerated on Sat May 8 2021 02:43:44