matlab_wrapper.py
Go to the documentation of this file.
1 """
2 Code to use the parsed results and convert it to a format
3 that Matlab's MEX compiler can use.
4 """
5 
6 # pylint: disable=too-many-lines, no-self-use, too-many-arguments, too-many-branches, too-many-statements
7 
8 import os
9 import os.path as osp
10 import sys
11 import textwrap
12 from functools import partial, reduce
13 from typing import Dict, Iterable, List, Union
14 
15 import gtwrap.interface_parser as parser
16 import gtwrap.template_instantiator as instantiator
17 
18 
""" Wrap the given C++ code into Matlab.

Attributes
    module: the C++ module being wrapped
    module_name: name of the C++ module being wrapped
    top_module_namespace: C++ namespace for the top module (default '')
    ignore_classes: A list of classes to ignore (default ())
"""
# Map the data type to its Matlab class.
# Found in Argument.cpp in old wrapper
data_type = {
    'string': 'char',
    'char': 'char',
    'unsigned char': 'unsigned char',
    'Vector': 'double',
    'Matrix': 'double',
    'int': 'numeric',
    'size_t': 'numeric',
    'bool': 'logical'
}
# Map the data type into the type used in Matlab methods.
# Found in matlab.h in old wrapper
data_type_param = {
    'string': 'char',
    'char': 'char',
    'unsigned char': 'unsigned char',
    'size_t': 'int',
    'int': 'int',
    'double': 'double',
    'Point2': 'double',
    'Point3': 'double',
    'Vector': 'double',
    'Matrix': 'double',
    'bool': 'bool'
}
# Methods that should not be wrapped directly
whitelist = ['serializable', 'serialize']
# Methods that should be ignored
ignore_methods = ['pickle']
# Datatypes that do not need to be checked in methods
not_check_type: list = []
# Data types that are primitive types
not_ptr_type = ['int', 'double', 'bool', 'char', 'unsigned char', 'size_t']
# Ignore the namespace for these datatypes
ignore_namespace = ['Matrix', 'Vector', 'Point2', 'Point3']
# NOTE: the containers below are mutable defaults; per-wrapper state should
# be rebound per instance so separate wrappers do not share it.
# The amount of times the wrapper has created a call to geometry_wrapper
wrapper_id = 0
# Map each wrapper id to its collector function's namespace, class, type,
# and string format
wrapper_map: dict = {}
# Includes in the namespace (a dict used as an insertion-ordered set)
includes: Dict[parser.Include, int] = {}
# All classes in the namespace, in registration order
classes: List[Union[parser.Class, instantiator.InstantiatedClass]] = []
classes_elems: Dict[Union[parser.Class, instantiator.InstantiatedClass], int] = {}
# Id for ordering global functions in the wrapper
global_function_id = 0
# Files and their content
content: List[str] = []

# Ensure the template file is always picked up from the correct directory.
dir_path = osp.dirname(osp.realpath(__file__))
# Read with an explicit encoding so the generated wrapper does not depend on
# the platform's locale-dependent default encoding.
with open(osp.join(dir_path, "matlab_wrapper.tpl"), encoding='utf-8') as f:
    wrapper_file_header = f.read()
83 
def __init__(self,
             module,
             module_name,
             top_module_namespace='',
             ignore_classes=()):
    """Create a Matlab wrapper generator for one parsed C++ module.

    Args:
        module: the parsed C++ module being wrapped
        module_name: name of the C++ module being wrapped
        top_module_namespace: C++ namespace for the top module (default '')
        ignore_classes: classes to skip when wrapping (default ())
    """
    self.module = module
    self.module_name = module_name
    self.top_module_namespace = top_module_namespace
    self.ignore_classes = ignore_classes
    self.verbose = False

    # Per-instance generation state. The class-level defaults are mutable
    # containers that other methods mutate through ``self`` (e.g.
    # ``self.content.append``), which would leak state between wrapper
    # instances; give every instance its own containers and counters.
    self.wrapper_id = 0
    self.wrapper_map = {}
    self.includes = {}
    self.classes = []
    self.classes_elems = {}
    self.global_function_id = 0
    self.content = []
94 
def _debug(self, message):
    """Print ``message`` to stderr, but only when ``self.verbose`` is set."""
    if self.verbose:
        print(message, file=sys.stderr)
99 
def _add_include(self, include):
    """Record ``include`` in ``self.includes`` (a dict used as an ordered set)."""
    self.includes.update({include: 0})
102 
def _add_class(self, instantiated_class):
    """Register ``instantiated_class`` once, preserving first-seen order."""
    already_registered = self.classes_elems.get(instantiated_class) is not None
    if already_registered:
        return
    self.classes_elems[instantiated_class] = 0
    self.classes.append(instantiated_class)
107 
def _update_wrapper_id(self, collector_function=None, id_diff=0):
    """Reserve the next wrapper id and optionally register its collector.

    Args:
        collector_function: optional tuple
            ``(namespace, class_or_function, function_type, extra)``. When
            given, ``self.wrapper_map`` maps the reserved id to
            ``(namespace, class_or_function, function_type, collector_name,
            extra)``.
        id_diff: offset added to the id when building the numeric suffix of
            ``collector_name``.

    Returns:
        The reserved id, i.e. the value of ``self.wrapper_id`` before it is
        incremented. The counter advances even when no collector is given.
    """
    reserved_id = self.wrapper_id

    if collector_function is not None:
        namespace = collector_function[0]
        target = collector_function[1]
        function_type = collector_function[2]

        if isinstance(target, instantiator.InstantiatedClass):
            # Class members: <namespace><ClassName>_<function type>
            base_name = namespace + target.name + '_' + function_type
        else:
            base_name = target.name

        collector_name = base_name + '_' + str(reserved_id + id_diff)
        self.wrapper_map[reserved_id] = (namespace, target, function_type,
                                         collector_name,
                                         collector_function[3])

    self.wrapper_id = reserved_id + 1
    return reserved_id
140 
def _qualified_name(self, names):
    """Return ``names``, or Matlab's ``'handle'`` base when it is ''."""
    if names == '':
        return 'handle'
    return names
143 
def _insert_spaces(self, x, y):
    """Append ``y`` to ``x`` on a new line, indenting it unless it is empty.

    Args:
        x: the statement currently generated
        y: the addition to add to the statement
    """
    indent = ' ' if y != '' else ''
    return ''.join((x, '\n', indent, y))
152 
def _is_shared_ptr(self, arg_type):
    """
    Determine if the `interface_parser.Type` should be treated as a
    shared pointer in the wrapper.

    An explicit shared-pointer marker wins, and its value is returned as-is.
    Otherwise the type qualifies unless it is listed in ``not_ptr_type``,
    is listed in ``ignore_namespace``, or is ``string``.
    """
    if arg_type.is_shared_ptr:
        return arg_type.is_shared_ptr
    type_name = arg_type.typename.name
    return not (type_name in self.not_ptr_type
                or type_name in self.ignore_namespace
                or type_name == 'string')
162 
def _is_ptr(self, arg_type):
    """
    Determine if the `interface_parser.Type` should be treated as a
    raw pointer in the wrapper.

    An explicit raw-pointer marker wins, and its value is returned as-is.
    Otherwise the type qualifies unless it is listed in ``not_ptr_type``,
    is listed in ``ignore_namespace``, or is ``string``.
    """
    if arg_type.is_ptr:
        return arg_type.is_ptr
    type_name = arg_type.typename.name
    return not (type_name in self.not_ptr_type
                or type_name in self.ignore_namespace
                or type_name == 'string')
172 
def _is_ref(self, arg_type):
    """Determine if the interface_parser.Type should be treated as a
    reference in the wrapper.

    Types in ``ignore_namespace`` or ``not_ptr_type`` never are; for any
    other type, the value of its ``is_ref`` marker is returned.
    """
    type_name = arg_type.typename.name
    if type_name in self.ignore_namespace or type_name in self.not_ptr_type:
        return False
    return arg_type.is_ref
180 
def _group_methods(self, methods):
    """Group overloaded methods together.

    Returns:
        A list of lists. Each inner list holds every method sharing one name,
        and groups are ordered by the first appearance of each name.
    """
    group_position = {}
    grouped = []

    for method in methods:
        position = group_position.setdefault(method.name, len(grouped))
        if position == len(grouped):
            grouped.append([method])
            continue
        self._debug("[_group_methods] Merging {} with {}".format(
            position, method.name))
        grouped[position].append(method)

    return grouped
198 
def _clean_class_name(self, instantiated_class):
    """Return the Matlab-facing name of ``instantiated_class``.

    The first constructor's name is preferred. A class without constructors
    falls back to its own ``name``.
    """
    constructors = instantiated_class.ctors
    if not constructors:
        return instantiated_class.name
    return constructors[0].name
207 
@classmethod
def _format_type_name(cls,
                      type_name,
                      separator='::',
                      include_namespace=True,
                      constructor=False,
                      method=False):
    """Format an ``interface_parser.Typename`` as a string.

    Args:
        type_name: an interface_parser.Typename to reformat
        separator: the string placed between namespaces and the type name.
            With ``'::'`` (C++) template arguments are rendered as
            ``<A,B>``; with any other separator the template arguments'
            names are appended directly, without namespaces.
        include_namespace: whether to prefix the (non-empty) namespaces
        constructor: map the name through ``data_type``
        method: map the name through ``data_type_param``

    Raises:
        ValueError: if ``constructor`` and ``method`` are both True.
    """
    if constructor and method:
        raise ValueError(
            'Constructor and method parameters cannot both be True')

    name = type_name.name
    formatted_type_name = ''

    # Types listed in ``ignore_namespace`` are always emitted unqualified.
    if include_namespace and name not in cls.ignore_namespace:
        for namespace in type_name.namespaces:
            if namespace != '':
                formatted_type_name += namespace + separator

    if constructor:
        formatted_type_name += cls.data_type.get(name) or name
    elif method:
        formatted_type_name += cls.data_type_param.get(name) or name
    else:
        formatted_type_name += name

    if separator == "::":  # C++
        templates = [
            cls._format_type_name(instantiation,
                                  include_namespace=include_namespace,
                                  constructor=constructor,
                                  method=method)
            for instantiation in type_name.instantiations
        ]
        if templates:  # Only emit angle brackets for templated types.
            formatted_type_name += '<{}>'.format(','.join(templates))
    else:
        for instantiation in type_name.instantiations:
            formatted_type_name += cls._format_type_name(
                instantiation,
                separator=separator,
                include_namespace=False,
                constructor=constructor,
                method=method)

    return formatted_type_name
269 
@classmethod
def _format_return_type(cls,
                        return_type,
                        include_namespace=False,
                        separator="::"):
    """Format an ``interface_parser.ReturnType`` as a string.

    A single return type is formatted on its own. Two return types are
    rendered as ``pair< type1, type2 >``.

    Args:
        return_type: an interface_parser.ReturnType to reformat
        include_namespace: whether to include namespaces when reformatting
        separator: separator forwarded to ``_format_type_name``
    """
    def format_one(type_entry):
        return cls._format_type_name(type_entry.typename,
                                     separator=separator,
                                     include_namespace=include_namespace)

    if cls._return_count(return_type) != 1:
        return 'pair< {type1}, {type2} >'.format(
            type1=format_one(return_type.type1),
            type2=format_one(return_type.type2))

    return format_one(return_type.type1)
300 
def _format_class_name(self, instantiated_class, separator=''):
    """Format a template_instantiator.InstantiatedClass name.

    The parent's full namespaces are each prefixed by ``separator``,
    followed by a trailing separator. The first ``2 * len(separator)``
    characters are then dropped, presumably to strip the separators around
    a leading empty namespace entry. For namespaces ``['', 'gtsam']`` and
    class ``Point3`` this yields ``gtsamPoint3`` (separator '') or
    ``gtsam.Point3`` (separator '.'). A class whose parent is '' gets no
    prefix.
    """
    namespaces = ([''] if instantiated_class.parent == '' else
                  instantiated_class.parent.full_namespaces())
    prefix = ''.join(separator + namespace for namespace in namespaces)
    prefix += separator
    return prefix[2 * len(separator):] + instantiated_class.name
321 
def _format_static_method(self, static_method, separator=''):
    """Format the class prefix used to reach a static method.

    Example, with separator '.', ``parent.namespaces()`` returning
    ``['', 'gtsam']`` and parent class ``Point3``::

        gtsam.Point3.

    Returns '' when ``static_method`` is not a ``parser.StaticMethod``.
    """
    if not isinstance(static_method, parser.StaticMethod):
        return ''
    parent = static_method.parent
    prefix = ''.join(separator + ns for ns in parent.namespaces())
    prefix += separator + parent.name + separator
    return prefix[2 * len(separator):]
334 
def _format_instance_method(self, instance_method, separator=''):
    """Format the qualified name of a templated instance method.

    Example, with separator '.', class namespaces ``['', 'gtsam']``,
    class ``Point3``, method ``foo`` instantiated with ``double``::

        gtsam.Point3.foo<double>

    Returns '' when ``instance_method`` is not an
    ``instantiator.InstantiatedMethod``.
    """
    if not isinstance(instance_method, instantiator.InstantiatedMethod):
        return ''
    namespaces = instance_method.parent.parent.full_namespaces()
    pieces = [separator + ns for ns in namespaces]
    pieces.append(separator)
    pieces.extend([
        instance_method.parent.name, separator,
        instance_method.original.name,
        '<', instance_method.instantiations.to_cpp(), '>'
    ])
    return ''.join(pieces)[2 * len(separator):]
354 
def _format_global_method(self, static_method, separator=''):
    """Format the namespace prefix of a global function.

    Example, with separator '.' and parent namespaces ``['', 'gtsam']``::

        gtsam.

    Returns '' when ``static_method`` is not a ``parser.GlobalFunction``.
    """
    if not isinstance(static_method, parser.GlobalFunction):
        return ''
    namespaces = static_method.parent.full_namespaces()
    prefix = ''.join(separator + ns for ns in namespaces) + separator
    return prefix[2 * len(separator):]
367 
def _wrap_args(self, args):
    """Render an interface_parser.ArgumentList as a C++-style signature.

    Type names are formatted without namespaces.

    Returns:
        The arguments joined by ', ', e.g. ``'int x, double y'``; '' when
        there are no arguments.
    """
    rendered = (
        '{c_type} {arg_name}'.format(
            c_type=self._format_type_name(arg.ctype.typename,
                                          include_namespace=False),
            arg_name=arg.name)
        for arg in args.args_list)
    return ', '.join(rendered)
387 
def _wrap_variable_arguments(self, args, wrap_datatypes=True):
    """Build the Matlab ``isa``/``size`` checks for an argument list.

    Arguments whose type is in ``not_check_type`` are skipped, but they still
    consume their ``varargin`` position.

    Args:
        args: the interface_parser.ArgumentList to check
        wrap_datatypes: when False, unmapped types are formatted with
            ``constructor=True``

    Returns:
        The conditions to append to an ``if`` statement, for example
        `` && isa(varargin{1},'double') && isa(varargin{2},'numeric')``.
    """
    # Extra shape conditions emitted after the isa() check.
    shape_checks = {
        'Vector': (' && size(varargin{{{num}}},2)==1',),
        'Point2': (' && size(varargin{{{num}}},1)==2',
                   ' && size(varargin{{{num}}},2)==1'),
        'Point3': (' && size(varargin{{{num}}},1)==3',
                   ' && size(varargin{{{num}}},2)==1'),
    }
    conditions = []

    for position, arg in enumerate(args.args_list, 1):
        type_name = arg.ctype.typename.name
        if type_name in self.not_check_type:
            continue

        # C++ type -> method parameter type -> Matlab class, when mapped.
        matlab_type = self.data_type_param.get(type_name)
        if self.data_type.get(matlab_type):
            matlab_type = self.data_type[matlab_type]
        if matlab_type is None:
            matlab_type = self._format_type_name(
                arg.ctype.typename,
                separator='.',
                constructor=not wrap_datatypes)

        conditions.append(" && isa(varargin{{{num}}},'{data_type}')".format(
            num=position, data_type=matlab_type))
        for template in shape_checks.get(type_name, ()):
            conditions.append(template.format(num=position))

    return ''.join(conditions)
432 
def _wrap_list_variable_arguments(self, args):
    """Render one ``varargin{i}`` reference per argument.

    Returns:
        The references joined by ', ', e.g.
        ``'varargin{1}, varargin{2}, varargin{3}'``; '' for no arguments.
    """
    positions = range(1, len(args.args_list) + 1)
    return ', '.join('varargin{{{}}}'.format(i) for i in positions)
453 
def _wrap_method_check_statement(self, args):
    """
    Build the ``if`` statement that selects one method overload.

    The statement checks the argument count and then each argument's type
    (plus shape for Vector/Point2/Point3), e.g.::

        if length(varargin) == 1 && isa(varargin{1},'double')

    The per-argument conditions are exactly what
    ``_wrap_variable_arguments(args)`` generates (type names formatted with
    separator '.' and ``constructor=False``), so that helper is reused
    instead of duplicating its logic.

    Returns:
        The statement followed by a newline. It is never empty.
    """
    check_statement = 'if length(varargin) == {param_count}'.format(
        param_count=len(args.args_list))
    check_statement += self._wrap_variable_arguments(args)
    return check_statement + '\n'
507 
def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False):
    """Format the interface_parser.Arguments.

    Builds the C++ statements that unwrap each MEX input ``in[...]`` into a
    local variable, plus the comma-separated parameter list used to call the
    wrapped function.

    Args:
        args: the interface_parser.ArgumentList to unwrap
        arg_id: index into ``in[]`` of the first argument; incremented per
            argument
        constructor: currently has no effect on the output; both branches of
            the ``std_boost`` selection below yield 'boost'.

    Returns:
        A tuple ``(params, body_args)``.

    Examples:
        ((a), unsigned char a = unwrap< unsigned char >(in[1]);),
        ((a), Test& t = *unwrap_shared_ptr< Test >(in[1], "ptr_Test");),
        ((a), std::shared_ptr<Test> p1 = unwrap_shared_ptr< Test >(in[1], "ptr_Test");)

    NOTE(review): the last example shows ``std::shared_ptr``, but this code
    always emits ``boost::shared_ptr``.
    """
    params = ''
    body_args = ''

    for arg in args.args_list:
        if params != '':
            params += ','

        if self._is_ref(arg.ctype):  # and not constructor:
            # References are bound by dereferencing the unwrapped shared_ptr.
            ctype_camel = self._format_type_name(arg.ctype.typename,
                                                 separator='')
            body_args += textwrap.indent(textwrap.dedent('''\
                {ctype}& {name} = *unwrap_shared_ptr< {ctype} >(in[{id}], "ptr_{ctype_camel}");
                '''.format(ctype=self._format_type_name(arg.ctype.typename),
                           ctype_camel=ctype_camel,
                           name=arg.name,
                           id=arg_id)),
                                         prefix=' ')

        elif (self._is_shared_ptr(arg.ctype) or self._is_ptr(arg.ctype)) and \
                arg.ctype.typename.name not in self.ignore_namespace:
            if arg.ctype.is_shared_ptr:
                call_type = arg.ctype.is_shared_ptr
            else:
                call_type = arg.ctype.is_ptr

            body_args += textwrap.indent(textwrap.dedent('''\
                {std_boost}::shared_ptr<{ctype_sep}> {name} = unwrap_shared_ptr< {ctype_sep} >(in[{id}], "ptr_{ctype}");
                '''.format(std_boost='boost' if constructor else 'boost',
                           ctype_sep=self._format_type_name(
                               arg.ctype.typename),
                           ctype=self._format_type_name(arg.ctype.typename,
                                                        separator=''),
                           name=arg.name,
                           id=arg_id)),
                                         prefix=' ')
            # With no explicit pointer marker, the shared_ptr is dereferenced
            # in the call's parameter list.
            if call_type == "":
                params += "*"

        else:
            # Remaining types are unwrapped by value.
            body_args += textwrap.indent(textwrap.dedent('''\
                {ctype} {name} = unwrap< {ctype} >(in[{id}]);
                '''.format(ctype=arg.ctype.typename.name,
                           name=arg.name,
                           id=arg_id)),
                                         prefix=' ')

        params += arg.name

        arg_id += 1

    return params, body_args
567 
@staticmethod
def _return_count(return_type):
    """Return how many values an interface_parser.ReturnType yields.

    An empty ``type2`` marks a single return value (1). Anything else is a
    pair (2).
    """
    is_pair = return_type.type2 != ''
    return 2 if is_pair else 1
574 
def _wrapper_name(self):
    """Return the wrapper function name: ``<module_name>_wrapper``."""
    return '{}_wrapper'.format(self.module_name)
578 
def class_serialize_comment(self, class_name, static_methods):
    """Generate the Matlab comment block for static and serialization methods.

    Args:
        class_name: shown as the return type of ``string_deserialize``
        static_methods: listed (sorted by name) under a "Static Methods"
            heading. The heading is omitted when there are none.

    Returns:
        ``%``-prefixed comment text ending with the serialization interface.
    """
    static_lines = [
        '%{name}({args}) : returns {return_type}\n'.format(
            name=method.name,
            args=self._wrap_args(method.args),
            return_type=self._format_return_type(method.return_type,
                                                 include_namespace=True))
        for method in sorted(static_methods, key=lambda m: m.name)
    ]
    header = '%-------Static Methods-------\n' if static_lines else ''

    serialization = textwrap.dedent('''\
        %
        %-------Serialization Interface-------
        %string_serialize() : returns string
        %string_deserialize(string serialized) : returns {class_name}
        %
        ''').format(class_name=class_name)

    return header + ''.join(static_lines) + serialization
603 
def class_comment(self, instantiated_class):
    """Generate comments for the given class in Matlab.

    Lists the constructors, then the methods (sorted by name, excluding
    those in ``whitelist`` or ``ignore_methods``), then, when present, the
    static methods and serialization interface via
    ``class_serialize_comment``.

    Args
        instantiated_class: the class being wrapped; its ``name``,
            ``ctors``, ``methods`` and ``static_methods`` are read.

    Returns:
        The ``%``-prefixed Matlab comment text.
    """
    class_name = instantiated_class.name
    ctors = instantiated_class.ctors
    methods = instantiated_class.methods
    static_methods = instantiated_class.static_methods

    comment = textwrap.dedent('''\
        %class {class_name}, see Doxygen page for details
        %at https://gtsam.org/doxygen/
        ''').format(class_name=class_name)

    if len(ctors) != 0:
        comment += '%\n%-------Constructors-------\n'

    # Write constructors
    for ctor in ctors:
        comment += '%{ctor_name}({args})\n'.format(ctor_name=ctor.name,
                                                   args=self._wrap_args(
                                                       ctor.args))

    if len(methods) != 0:
        comment += '%\n' \
                   '%-------Methods-------\n'

    methods = sorted(methods, key=lambda name: name.name)

    # Write methods
    for method in methods:
        if method.name in self.whitelist:
            continue
        if method.name in self.ignore_methods:
            continue

        comment += '%{name}({args})'.format(name=method.name,
                                            args=self._wrap_args(
                                                method.args))

        # Return types use C++ formatting ('::' separator, with namespaces).
        if method.return_type.type2 == '':
            return_type = self._format_type_name(
                method.return_type.type1.typename)
        else:
            return_type = 'pair< {type1}, {type2} >'.format(
                type1=self._format_type_name(
                    method.return_type.type1.typename),
                type2=self._format_type_name(
                    method.return_type.type2.typename))

        comment += ' : returns {return_type}\n'.format(
            return_type=return_type)

    comment += '%\n'

    if len(static_methods) != 0:
        comment += self.class_serialize_comment(class_name, static_methods)

    return comment
667 
def generate_matlab_wrapper(self):
    """Generate the C++ file for the wrapper.

    Returns:
        A ``(file_name, contents)`` tuple. The file name is
        ``self._wrapper_name() + '.cpp'``; the contents are
        ``self.wrapper_file_header`` (the text read from
        ``matlab_wrapper.tpl``).
    """
    return self._wrapper_name() + '.cpp', self.wrapper_file_header
675 
def wrap_method(self, methods):
    """Wrap methods in the body of a class.

    Placeholder: no per-method code is generated here yet, so the result is
    always an empty string, whatever ``methods`` contains.
    """
    return ''
685 
def wrap_methods(self, methods, global_funcs=False, global_ns=None):
    """
    Wrap a sequence of methods, grouping same-named overloads together.

    Groups whose name appears in ``ignore_methods`` are skipped.

    If ``global_funcs`` is True, each group is wrapped with
    ``wrap_global_function``. The result is appended to ``self.content`` as
    its own ``<name>.m`` file, in the Matlab package directory built from
    ``global_ns.full_namespaces()`` minus its first entry
    (e.g. ``+gtsam/+sub``).

    Returns:
        The concatenated ``wrap_method`` output for non-global groups.
    """
    output = ''

    for method in self._group_methods(methods):
        # ``method`` is a list of overloads sharing one name, so compare that
        # name (not the list itself) against the ignore list.
        if method[0].name in self.ignore_methods:
            continue

        if global_funcs:
            self._debug("[wrap_methods] wrapping: {}..{}={}".format(
                method[0].parent.name, method[0].name,
                type(method[0].parent.name)))

            method_text = self.wrap_global_function(method)
            package_path = "".join([
                '+' + x + '/' for x in global_ns.full_namespaces()[1:]
            ])[:-1]
            self.content.append(
                (package_path, [(method[0].name + '.m', method_text)]))
        else:
            # Previously the result was discarded; accumulate it instead.
            output += self.wrap_method(method)

    return output
713 
def wrap_global_function(self, function):
    """Wrap the given global function.

    Generates a Matlab ``function varargout = <name>(varargin)``. It tests
    each overload in turn (argument count plus ``_wrap_variable_arguments``
    checks) and forwards ``varargin`` to ``<module_name>_wrapper`` with the
    wrapper id reserved for that overload. If no overload matches, it
    raises a Matlab error.

    Args:
        function: a global function, or a list of overloads of one function.
            The first overload supplies the name and parent.

    Returns:
        The Matlab function source as a string.
    """
    if not isinstance(function, list):
        function = [function]

    function_name = function[0].name

    # Get all combinations of parameters
    param_wrap = ''

    for i, overload in enumerate(function):
        param_wrap += ' if' if i == 0 else ' elseif'
        param_wrap += ' length(varargin) == '

        if len(overload.args.args_list) == 0:
            param_wrap += '0\n'
        else:
            param_wrap += str(len(overload.args.args_list)) \
                + self._wrap_variable_arguments(overload.args, False) + '\n'

        # Determine format of return and varargout statements
        return_type_formatted = self._format_return_type(
            overload.return_type, include_namespace=True, separator=".")
        varargout = self._format_varargout(overload.return_type,
                                           return_type_formatted)

        # Each overload reserves its own id, registered as a
        # 'global_function' collector.
        param_wrap += textwrap.indent(textwrap.dedent('''\
            {varargout}{module_name}_wrapper({num}, varargin{{:}});
            ''').format(varargout=varargout,
                        module_name=self.module_name,
                        num=self._update_wrapper_id(
                            collector_function=(function[0].parent.name,
                                                function[i], 'global_function',
                                                None))),
                                      prefix=' ')

    param_wrap += textwrap.indent(textwrap.dedent('''\
        else
        error('Arguments do not match any overload of function {func_name}');
        ''').format(func_name=function_name),
                                  prefix=' ')

    global_function = textwrap.indent(textwrap.dedent('''\
        function varargout = {m_method}(varargin)
        {statements} end
        ''').format(m_method=function_name, statements=param_wrap),
                                      prefix='')

    return global_function
763 
def wrap_class_constructors(self, namespace_name, inst_class, parent_name,
                            ctors, is_virtual):
    """Wrap class constructor.

    Generates the Matlab ``methods`` header and the class constructor.

    The first branch adopts an existing pointer passed as ``varargin{2}``,
    guarded by ``varargin{1}`` being the uint64 token 5139824614673773682.
    NOTE(review): the token is presumably a sentinel for internal
    pointer-passing calls; confirm against matlab_wrapper.tpl. Later
    ``elseif`` branches dispatch one wrapper call per constructor overload.

    Args:
        namespace_name: the name of the namespace ('' if it does not exist)
        inst_class: instance of the class
        parent_name: the name of the parent class if it exists
        ctors: the interface_parser.Constructor in the class
        is_virtual: whether the class is part of a virtual inheritance
            chain

    Returns:
        The Matlab source of the constructor section.
    """
    has_parent = parent_name != ''
    class_name = inst_class.name
    if has_parent:
        parent_name = self._format_type_name(parent_name, separator=".")
    if not isinstance(ctors, Iterable):
        ctors = [ctors]

    methods_wrap = textwrap.indent(textwrap.dedent("""\
        methods
        function obj = {class_name}(varargin)
        """).format(class_name=class_name),
                                   prefix='')

    # Virtual classes also accept a third 'void' argument on this branch.
    if is_virtual:
        methods_wrap += " if (nargin == 2 || (nargin == 3 && strcmp(varargin{3}, 'void')))"
    else:
        methods_wrap += ' if nargin == 2'

    methods_wrap += " && isa(varargin{1}, 'uint64')"
    methods_wrap += " && varargin{1} == uint64(5139824614673773682)\n"

    if is_virtual:
        # With the extra 'void' argument, the pointer is first passed through
        # a separate wrapper call whose id is reserved here (+1).
        methods_wrap += textwrap.indent(textwrap.dedent('''\
            if nargin == 2
            my_ptr = varargin{{2}};
            else
            my_ptr = {wrapper_name}({id}, varargin{{2}});
            end
            ''').format(wrapper_name=self._wrapper_name(),
                        id=self._update_wrapper_id() + 1),
                                        prefix=' ')
    else:
        methods_wrap += ' my_ptr = varargin{2};\n'

    # Register the collectorInsertAndMakeBase entry. For virtual classes the
    # emitted id is shifted back by one (id_diff=-1 and the '- 1' below).
    collector_base_id = self._update_wrapper_id(
        (namespace_name, inst_class, 'collectorInsertAndMakeBase', None),
        id_diff=-1 if is_virtual else 0)

    methods_wrap += ' {ptr}{wrapper_name}({id}, my_ptr);\n' \
        .format(
            ptr='base_ptr = ' if has_parent else '',
            wrapper_name=self._wrapper_name(),
            id=collector_base_id - (1 if is_virtual else 0))

    for ctor in ctors:
        wrapper_return = '[ my_ptr, base_ptr ] = ' \
            if has_parent \
            else 'my_ptr = '

        methods_wrap += textwrap.indent(textwrap.dedent('''\
            elseif nargin == {len}{varargin}
            {ptr}{wrapper}({num}{comma}{var_arg});
            ''').format(len=len(ctor.args.args_list),
                        varargin=self._wrap_variable_arguments(
                            ctor.args, False),
                        ptr=wrapper_return,
                        wrapper=self._wrapper_name(),
                        num=self._update_wrapper_id(
                            (namespace_name, inst_class, 'constructor', ctor)),
                        comma='' if len(ctor.args.args_list) == 0 else ', ',
                        var_arg=self._wrap_list_variable_arguments(ctor.args)),
                                        prefix=' ')

    base_obj = ''

    if has_parent:
        self._debug("class: {} ns: {}".format(
            parent_name,
            self._format_class_name(inst_class.parent, separator=".")))

    if has_parent:
        # Subclasses also initialise the parent with the base pointer.
        base_obj = ' obj = obj@{parent_name}(uint64(5139824614673773682), base_ptr);'.format(
            parent_name=parent_name)

    if base_obj:
        base_obj = '\n' + base_obj

    self._debug("class: {}, name: {}".format(
        inst_class.name, self._format_class_name(inst_class,
                                                 separator=".")))
    # NOTE(review): the 'namespace' and 'd' format arguments are unused by
    # this template.
    methods_wrap += textwrap.indent(textwrap.dedent('''\
        else
        error('Arguments do not match any overload of {class_name_doc} constructor');
        end{base_obj}
        obj.ptr_{class_name} = my_ptr;
        end\n
        ''').format(namespace=namespace_name,
                    d='' if namespace_name == '' else '.',
                    class_name_doc=self._format_class_name(inst_class,
                                                           separator="."),
                    class_name=self._format_class_name(inst_class,
                                                       separator=""),
                    base_obj=base_obj),
                                    prefix=' ')

    return methods_wrap
872 
def wrap_class_properties(self, class_name):
    """Generate the Matlab ``properties`` section.

    The section declares ``ptr_<class_name>``, initialised to 0.
    """
    properties_template = textwrap.dedent('''\
        properties
        ptr_{} = 0
        end
        ''')
    return properties_template.format(class_name)
880 
def wrap_class_deconstructor(self, namespace_name, inst_class):
    """Generate the delete function for the Matlab class.

    A new wrapper id is reserved for a 'deconstructor' collector. The
    generated ``delete`` passes that id and the object's
    ``ptr_<namespaces><ClassName>`` handle to the wrapper.
    """
    class_name = inst_class.name
    wrapper_id = self._update_wrapper_id(
        (namespace_name, inst_class, 'deconstructor', None))
    wrapper_name = self._wrapper_name()
    handle_name = "".join(inst_class.parent.full_namespaces()) + class_name

    delete_template = textwrap.dedent("""\
        function delete(obj)
        {wrapper}({num}, obj.ptr_{class_name});
        end\n
        """)
    delete_code = delete_template.format(num=wrapper_id,
                                         wrapper=wrapper_name,
                                         class_name=handle_name)
    return textwrap.indent(delete_code, prefix=' ')
897 
def wrap_class_display(self):
    """Generate the display function for the Matlab class.

    ``display`` calls ``print('')`` on the object and ``disp`` calls
    ``display``.
    """
    display_methods = textwrap.dedent("""\
        function display(obj), obj.print(''); end
        %DISPLAY Calls print on the object
        function disp(obj), obj.display; end
        %DISP Calls print on the object
        """)
    return textwrap.indent(display_methods, prefix=' ')
907 
def _group_class_methods(self, methods):
    """Group overloaded methods together.

    Returns:
        A list of lists of same-named methods, with groups ordered by the
        first appearance of each name. No debug output is produced.
    """
    groups = {}
    for method in methods:
        groups.setdefault(method.name, []).append(method)
    return list(groups.values())
924 
@classmethod
def _format_varargout(cls, return_type, return_type_formatted):
    """Determine format of return and varargout statements.

    Returns the assignment prefix for a wrapper call:
    ``[ varargout{1} varargout{2} ] = `` for a pair, '' for a ``void``
    single return, and ``varargout{1} = `` otherwise.
    """
    if cls._return_count(return_type) != 1:
        return '[ varargout{1} varargout{2} ] = '
    if return_type_formatted == 'void':
        return ''
    return 'varargout{1} = '
936 
def wrap_class_methods(self,
                       namespace_name,
                       inst_class,
                       methods,
                       serialize=(False,)):
    """Wrap the methods in the class.

    Overloads are grouped by name. Each group becomes one Matlab
    ``function varargout = <name>(this, varargin)`` that tries every
    overload's check statement in turn and errors if none matches. A
    ``serialize`` method is delegated to ``wrap_class_serialize_method``.
    Other ``whitelist`` entries and ``ignore_methods`` are skipped.

    Args:
        namespace_name: the name of the class's namespace
        inst_class: the instantiated class whose methods to wrap
        methods: the methods to wrap in the order to wrap them
        serialize: mutable param storing if one of the methods is serialize.
            Pass a one-element list to observe the flag. The default tuple
            is converted to a local list, so the caller cannot see the flag
            in that case.

    Returns:
        The Matlab source for the wrapped methods.
    """
    method_text = ''

    methods = self._group_class_methods(methods)

    # Convert to list so that it is mutable
    if isinstance(serialize, tuple):
        serialize = list(serialize)

    for method in methods:
        method_name = method[0].name
        if method_name in self.whitelist and method_name != 'serialize':
            continue
        if method_name in self.ignore_methods:
            continue

        if method_name == 'serialize':
            serialize[0] = True
            method_text += self.wrap_class_serialize_method(
                namespace_name, inst_class)
        else:
            # Generate method code
            method_text += textwrap.indent(textwrap.dedent("""\
                function varargout = {method_name}(this, varargin)
                """).format(caps_name=method_name.upper(),
                            method_name=method_name),
                                           prefix='')

            for overload in method:
                method_text += textwrap.indent(textwrap.dedent("""\
                    % {caps_name} usage: {method_name}(""").format(
                        caps_name=method_name.upper(),
                        method_name=method_name),
                                               prefix=' ')

                # Determine format of return and varargout statements
                return_type_formatted = self._format_return_type(
                    overload.return_type,
                    include_namespace=True,
                    separator=".")
                varargout = self._format_varargout(overload.return_type,
                                                   return_type_formatted)

                check_statement = self._wrap_method_check_statement(
                    overload.args)
                class_name = namespace_name + ('' if namespace_name == ''
                                               else '.') + inst_class.name

                # A matching overload returns right after its wrapper call.
                end_statement = '' \
                    if check_statement == '' \
                    else textwrap.indent(textwrap.dedent("""\
                        return
                        end
                        """).format(
                        class_name=class_name,
                        method_name=overload.original.name), prefix=' ')

                method_text += textwrap.dedent("""\
                    {method_args}) : returns {return_type}
                    % Doxygen can be found at https://gtsam.org/doxygen/
                    {check_statement}{spacing}{varargout}{wrapper}({num}, this, varargin{{:}});
                    {end_statement}""").format(
                    method_args=self._wrap_args(overload.args),
                    return_type=return_type_formatted,
                    num=self._update_wrapper_id(
                        (namespace_name, inst_class,
                         overload.original.name, overload)),
                    check_statement=check_statement,
                    spacing='' if check_statement == '' else ' ',
                    varargout=varargout,
                    wrapper=self._wrapper_name(),
                    end_statement=end_statement)

            # Reached only when no overload's check matched.
            final_statement = textwrap.indent(textwrap.dedent("""\
                error('Arguments do not match any overload of function {class_name}.{method_name}');
                """.format(class_name=class_name, method_name=method_name)),
                                              prefix=' ')
            method_text += final_statement + 'end\n\n'

    return method_text
1029 
def wrap_static_methods(self, namespace_name, instantiated_class,
                        serialize):
    """
    Wrap the static methods in the class.

    Emits a Matlab ``methods(Static = true)`` section with one dispatcher
    function per group of same-named static methods (the Matlab name has
    its first letter capitalized), followed by ``string_deserialize`` and
    ``loadobj`` when the class is serializable.

    Args:
        namespace_name: Matlab namespace of the class ('' at top level).
        instantiated_class: the class whose static methods are wrapped.
        serialize: truthy if deserialization helpers should be emitted.

    Returns:
        The Matlab source text of the static-methods section.
    """
    class_name = instantiated_class.name
    # Qualified Matlab name. Avoid a leading '.' for top-level classes,
    # using the same rule that wrap_class_methods applies.
    qualified_class_name = namespace_name + ('' if namespace_name == '' else
                                             '.') + class_name

    method_text = 'methods(Static = true)\n'
    static_methods = sorted(instantiated_class.static_methods,
                            key=lambda method: method.name)

    static_methods = self._group_class_methods(static_methods)

    for static_method in static_methods:
        # All overloads in a group share the same name.
        method_name = static_method[0].name
        if method_name in self.ignore_methods:
            continue

        # Static functions are exposed with a capitalized first letter.
        matlab_name = method_name[:1].upper() + method_name[1:]

        method_text += textwrap.indent(textwrap.dedent('''\
            function varargout = {name}(varargin)
            '''.format(name=matlab_name)),
                                       prefix=" ")

        for static_overload in static_method:
            check_statement = self._wrap_method_check_statement(
                static_overload.args)

            # Only argument-checked overloads need an early return.
            end_statement = '' \
                if check_statement == '' \
                else textwrap.indent(textwrap.dedent("""
                    return
                    end
                    """), prefix='')
            method_text += textwrap.indent(textwrap.dedent('''\
                % {name_caps} usage: {name_upper_case}({args}) : returns {return_type}
                % Doxygen can be found at https://gtsam.org/doxygen/
                {check_statement}{spacing}varargout{{1}} = {wrapper}({id}, varargin{{:}});{end_statement}
                ''').format(
                    name_caps=static_overload.name.upper(),
                    name_upper_case=static_overload.name,
                    args=self._wrap_args(static_overload.args),
                    return_type=self._format_return_type(
                        static_overload.return_type,
                        include_namespace=True,
                        separator="."),
                    check_statement=check_statement,
                    spacing='' if check_statement == '' else ' ',
                    wrapper=self._wrapper_name(),
                    id=self._update_wrapper_id(
                        (namespace_name, instantiated_class,
                         static_overload.name, static_overload)),
                    end_statement=end_statement),
                                           prefix=' ')

        # Reached only when no overload matched the call arguments.
        method_text += textwrap.indent(textwrap.dedent("""\
            error('Arguments do not match any overload of function {class_name}.{method_name}');
            """.format(class_name=class_name, method_name=method_name)),
                                       prefix=' ')

        method_text += textwrap.indent(textwrap.dedent("""\
            end\n
            """), prefix=" ")

    if serialize:
        method_text += textwrap.indent(textwrap.dedent("""\
            function varargout = string_deserialize(varargin)
            % STRING_DESERIALIZE usage: string_deserialize() : returns {class_name}
            % Doxygen can be found at https://gtsam.org/doxygen/
            if length(varargin) == 1
            varargout{{1}} = {wrapper}({id}, varargin{{:}});
            else
            error('Arguments do not match any overload of function {class_name}.string_deserialize');
            end
            end\n
            function obj = loadobj(sobj)
            % LOADOBJ Saves the object to a matlab-readable format
            obj = {class_name}.string_deserialize(sobj);
            end
            """).format(
                class_name=qualified_class_name,
                wrapper=self._wrapper_name(),
                id=self._update_wrapper_id(
                    (namespace_name, instantiated_class, 'string_deserialize',
                     'deserialize'))),
                                       prefix=' ')

    return method_text
1126 
def wrap_instantiated_class(self, instantiated_class, namespace_name=''):
    """Generate comments and code for given class.

    Args:
        instantiated_class: template_instantiator.InstantiatedClass
            instance storing the class to wrap
        namespace_name: the name of the namespace if there is one

    Returns:
        A ``(file_name, file_content)`` tuple for the class's ``.m`` file,
        or None when the class is listed in ``self.ignore_classes``.
    """
    file_name = self._clean_class_name(instantiated_class)
    namespace_file_name = namespace_name + file_name

    uninstantiated_name = "::".join(
        instantiated_class.namespaces()[1:]) + "::" + instantiated_class.name
    if uninstantiated_name in self.ignore_classes:
        return None

    def _section(text):
        """Join the lines of a generated section with ``_insert_spaces``."""
        return ' ' + reduce(self._insert_spaces, text.splitlines()) + '\n'

    # Class comment
    content_text = self.class_comment(instantiated_class)
    content_text += self.wrap_methods(instantiated_class.methods)

    # Class definition
    content_text += 'classdef {class_name} < {parent}\n'.format(
        class_name=file_name,
        parent=str(self._qualified_name(
            instantiated_class.parent_class)).replace("::", "."))

    # Class properties
    content_text += _section(self.wrap_class_properties(namespace_file_name))

    # Class constructor
    content_text += _section(
        self.wrap_class_constructors(
            namespace_name,
            instantiated_class,
            instantiated_class.parent_class,
            instantiated_class.ctors,
            instantiated_class.is_virtual,
        ))

    # Delete function
    content_text += _section(
        self.wrap_class_deconstructor(namespace_name, instantiated_class))

    # Display function
    content_text += _section(self.wrap_class_display())

    # Class methods. wrap_class_methods sets serialize[0] to True when it
    # emits the serialize helpers; the static section needs that flag.
    serialize = [False]

    if instantiated_class.methods:
        methods = sorted(instantiated_class.methods,
                         key=lambda method: method.name)
        class_methods_wrapped = self.wrap_class_methods(
            namespace_name,
            instantiated_class,
            methods,
            serialize=serialize).splitlines()
        if class_methods_wrapped:
            # Indent every non-empty line; keep blank lines blank.
            content_text += ' ' + reduce(
                lambda x, y: x + '\n' + ('' if y == '' else ' ') + y,
                class_methods_wrapped) + '\n'

    # Static class methods
    content_text += ' end\n\n ' + reduce(
        self._insert_spaces,
        self.wrap_static_methods(namespace_name, instantiated_class,
                                 serialize[0]).splitlines()) + '\n'

    content_text += textwrap.dedent('''\
        end
        end
        ''')

    return file_name + '.m', content_text
1214 
def wrap_namespace(self, namespace, parent=()):
    """Wrap a namespace by wrapping all of its components.

    Appends generated output to ``self.content``:
    - the Matlab wrapper entry from ``generate_matlab_wrapper()``;
    - top-level classes as ``(file_name, file_content)`` tuples;
    - classes of an inner namespace as one list of
      ``(package_folder, [(file_name, file_content)])`` entries.

    Nested namespaces are wrapped recursively. Classes ignored by
    ``wrap_instantiated_class`` (it returns None) are skipped.

    Args:
        namespace: the interface_parser.namespace instance of the namespace
        parent: parent namespace (only used for debug output)

    Returns:
        An empty list, kept for backward compatibility.
    """
    namespaces = namespace.full_namespaces()
    inner_namespace = namespace.name != ''
    self._debug("wrapping ns: {}, parent: {}".format(namespaces, parent))

    matlab_wrapper = self.generate_matlab_wrapper()
    self.content.append((matlab_wrapper[0], matlab_wrapper[1]))

    current_scope = []
    namespace_scope = []
    # Matlab package folder for this namespace, e.g. '+gtsam/+sub'.
    package_folder = '/'.join('+' + name for name in namespaces[1:])

    for element in namespace.content:
        if isinstance(element, parser.Include):
            self._add_include(element)
        elif isinstance(element, parser.Namespace):
            self.wrap_namespace(element, namespaces)
        elif isinstance(element, instantiator.InstantiatedClass):
            self._add_class(element)

            if inner_namespace:
                class_text = self.wrap_instantiated_class(
                    element, "".join(namespaces))
            else:
                class_text = self.wrap_instantiated_class(element)

            # Ignored classes produce no file.
            if class_text is None:
                continue

            if inner_namespace:
                namespace_scope.append(
                    (package_folder, [(class_text[0], class_text[1])]))
            else:
                current_scope.append((class_text[0], class_text[1]))

    self.content.extend(current_scope)

    if inner_namespace:
        self.content.append(namespace_scope)

    # Global functions
    all_funcs = [
        func for func in namespace.content
        if isinstance(func, parser.GlobalFunction)
    ]

    self.wrap_methods(all_funcs, True, global_ns=namespace)

    return []
1270 
def wrap_collector_function_shared_return(self,
                                          return_type_name,
                                          shared_obj,
                                          func_id,
                                          new_line=True):
    """Wrap the collector function which returns a shared pointer.

    Args:
        return_type_name: typename of the returned object; it is formatted
            without its namespace.
        shared_obj: C++ expression used to construct the shared pointer.
        func_id: index of the MEX ``out`` slot to assign.
        new_line: whether to end the snippet with a newline.

    Returns:
        A brace-scoped C++ snippet that builds a ``boost::shared_ptr`` from
        ``shared_obj`` and stores ``wrap_shared_ptr(...)`` in ``out[func_id]``.
    """
    new_line = '\n' if new_line else ''

    return textwrap.indent(textwrap.dedent('''\
        {{
        boost::shared_ptr<{name}> shared({shared_obj});
        out[{id}] = wrap_shared_ptr(shared,"{name}");
        }}{new_line}''').format(name=self._format_type_name(
            return_type_name, include_namespace=False),
                               shared_obj=shared_obj,
                               id=func_id,
                               new_line=new_line),
                           prefix=' ')
1289 
def wrap_collector_function_return_types(self, return_type, func_id):
    """
    Wrap one element of a pair-returning collector function.

    Args:
        return_type: type of the pair element being wrapped.
        func_id: 0 selects ``pairResult.first`` and ends with a newline;
            any other value selects ``pairResult.second`` with no newline.

    Returns:
        C++ code that assigns the wrapped element to ``out[func_id]``.
    """
    is_first = func_id == 0
    member = 'pairResult.' + ('first' if is_first else 'second')
    trailing = '\n' if is_first else ''
    assignment = ' out[' + str(func_id) + '] = '

    pointer_like = (self._is_shared_ptr(return_type)
                    or self._is_ptr(return_type))
    if not pointer_like:
        # Plain values are copied out with wrap<>.
        type_name = self._format_type_name(return_type.typename,
                                           separator='.')
        return assignment + 'wrap< {} >({});{}'.format(
            type_name, member, trailing)

    if return_type.is_shared_ptr or return_type.is_ptr:
        shared_obj = member
    else:
        # A value type that must be handed out as a shared pointer.
        shared_obj = 'boost::make_shared<{name}>({shared_obj})'.format(
            name=self._format_type_name(return_type.typename),
            shared_obj=member)

    if return_type.typename.name in self.ignore_namespace:
        return self.wrap_collector_function_shared_return(
            return_type.typename, shared_obj, func_id, is_first)

    matlab_name = self._format_type_name(return_type.typename, separator='.')
    return assignment + 'wrap_shared_ptr({0},"{1}", false);{new_line}'.format(
        shared_obj, matlab_name, new_line=trailing)
1321 
def wrap_collector_function_return(self, method):
    """
    Wrap the complete return type of the function.

    Builds the C++ statements that call ``method`` inside a collector
    function and, unless it returns ``void``, store the result in the MEX
    ``out`` array (one slot, or two for pair returns).

    Args:
        method: an instantiated method, a global function or a static
            method.

    Returns:
        The generated C++ code.
    """
    expanded = ''

    params = self._wrapper_unwrap_arguments(method.args, arg_id=1)[0]

    return_1 = method.return_type.type1
    return_count = self._return_count(method.return_type)
    return_1_name = return_1.typename.name
    obj_start = ''

    if isinstance(method, instantiator.InstantiatedMethod):
        # to_cpp() already includes any template instantiation arguments.
        method_name = method.to_cpp()
        obj_start = 'obj->'
    elif isinstance(method, parser.GlobalFunction):
        method_name = self._format_global_method(method, '::')
        method_name += method.name
    else:
        if isinstance(method.parent, instantiator.InstantiatedClass):
            method_name = method.parent.cpp_class() + "::"
        else:
            method_name = self._format_static_method(method, '::')
        method_name += method.name

    obj = ' ' if return_1_name == 'void' else ''
    obj += '{}{}({})'.format(obj_start, method_name, params)

    if return_1_name == 'void':
        return expanded + obj + ';'

    if return_count == 1:
        if self._is_shared_ptr(return_1) or self._is_ptr(return_1):
            sep_method_name = partial(self._format_type_name,
                                      return_1.typename,
                                      include_namespace=True)

            if return_1.typename.name in self.ignore_namespace:
                # These types are returned through a scoped shared_ptr.
                expanded += self.wrap_collector_function_shared_return(
                    return_1.typename, obj, 0, new_line=False)
            else:
                if return_1.is_shared_ptr or return_1.is_ptr:
                    shared_obj = '{obj},"{method_name_sep}"'.format(
                        obj=obj, method_name_sep=sep_method_name('.'))
                else:
                    self._debug("Non-PTR: {}, {}".format(
                        return_1, type(return_1)))
                    self._debug("Inner type is: {}, {}".format(
                        return_1.typename.name, sep_method_name('.')))
                    self._debug("Inner type instantiations: {}".format(
                        return_1.typename.instantiations))
                    # Returned by value: copy into a new shared pointer.
                    shared_obj_template = 'boost::make_shared<{method_name_sep_col}>({obj}),' \
                        '"{method_name_sep_dot}"'
                    shared_obj = shared_obj_template.format(
                        method_name_sep_col=sep_method_name(),
                        method_name_sep_dot=sep_method_name('.'),
                        obj=obj)

                expanded += textwrap.indent(
                    'out[0] = wrap_shared_ptr({}, false);'.format(shared_obj),
                    prefix=' ')
        else:
            expanded += ' out[0] = wrap< {} >({});'.format(
                return_1.typename.name, obj)
    elif return_count == 2:
        return_2 = method.return_type.type2

        expanded += ' auto pairResult = {};\n'.format(obj)
        expanded += self.wrap_collector_function_return_types(return_1, 0)
        expanded += self.wrap_collector_function_return_types(return_2, 1)

    return expanded
1413 
def wrap_collector_function_upcast_from_void(self, class_name, func_id,
                                             cpp_name):
    """
    Add function to upcast type from void type.

    Args:
        class_name: prefix of the generated C++ function name.
        func_id: wrapper id appended to the function name.
        cpp_name: C++ type that the ``shared_ptr<void>`` is cast to.

    Returns:
        C++ source of ``<class_name>_upcastFromVoid_<func_id>``, which
        re-wraps ``in[0]`` as a ``boost::shared_ptr<cpp_name>`` in ``out[0]``.
    """
    template = textwrap.dedent('''\
        void {class_name}_upcastFromVoid_{id}(int nargout, mxArray *out[], int nargin, const mxArray *in[]) {{
        mexAtExit(&_deleteAllObjects);
        typedef boost::shared_ptr<{cpp_name}> Shared;
        boost::shared_ptr<void> *asVoid = *reinterpret_cast<boost::shared_ptr<void>**> (mxGetData(in[0]));
        out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
        Shared *self = new Shared(boost::static_pointer_cast<{cpp_name}>(*asVoid));
        *reinterpret_cast<Shared**>(mxGetData(out[0])) = self;
        }}\n
        ''')
    return template.format(class_name=class_name,
                           cpp_name=cpp_name,
                           id=func_id)
1429 
def generate_collector_function(self, func_id):
    """
    Generate the complete collector function.

    Args:
        func_id: wrapper id to look up in ``self.wrapper_map``.

    Returns:
        C++ source of the collector function registered under ``func_id``,
        or '' when nothing is registered for that id.
    """
    collector_func = self.wrapper_map.get(func_id)

    if collector_func is None:
        return ''

    # NOTE(review): entries appear to be laid out as
    # (namespace, class-or-function, kind, C++ function name, extra);
    # confirm against _update_wrapper_id.
    function_name = collector_func[3]

    collector_function = "void {}" \
        "(int nargout, mxArray *out[], int nargin, const mxArray *in[])\n".format(function_name)

    if isinstance(collector_func[1], instantiator.InstantiatedClass):
        body = '{\n'

        extra = collector_func[4]

        class_name = collector_func[0] + collector_func[1].name
        class_name_separated = collector_func[1].cpp_class()
        is_method = isinstance(extra, parser.Method)
        is_static_method = isinstance(extra, parser.StaticMethod)

        if collector_func[2] == 'collectorInsertAndMakeBase':
            body += textwrap.indent(textwrap.dedent('''\
                mexAtExit(&_deleteAllObjects);
                typedef boost::shared_ptr<{class_name_sep}> Shared;\n
                Shared *self = *reinterpret_cast<Shared**> (mxGetData(in[0]));
                collector_{class_name}.insert(self);
                ''').format(class_name_sep=class_name_separated,
                            class_name=class_name),
                                    prefix=' ')

            if collector_func[1].parent_class:
                # Also return the object as a pointer to its base class.
                body += textwrap.indent(textwrap.dedent('''
                    typedef boost::shared_ptr<{}> SharedBase;
                    out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
                    *reinterpret_cast<SharedBase**>(mxGetData(out[0])) = new SharedBase(*self);
                    ''').format(collector_func[1].parent_class),
                                        prefix=' ')
        elif collector_func[2] == 'constructor':
            base = ''
            params, body_args = self._wrapper_unwrap_arguments(
                extra.args, constructor=True)

            if collector_func[1].parent_class:
                # The second output is the new object seen as its base.
                base += textwrap.indent(textwrap.dedent('''
                    typedef boost::shared_ptr<{}> SharedBase;
                    out[1] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
                    *reinterpret_cast<SharedBase**>(mxGetData(out[1])) = new SharedBase(*self);
                    ''').format(collector_func[1].parent_class),
                                        prefix=' ')

            body += textwrap.dedent('''\
                mexAtExit(&_deleteAllObjects);
                typedef boost::shared_ptr<{class_name_sep}> Shared;\n
                {body_args} Shared *self = new Shared(new {class_name_sep}({params}));
                collector_{class_name}.insert(self);
                out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL);
                *reinterpret_cast<Shared**> (mxGetData(out[0])) = self;
                {base}''').format(class_name_sep=class_name_separated,
                                  body_args=body_args,
                                  params=params,
                                  class_name=class_name,
                                  base=base)
        elif collector_func[2] == 'deconstructor':
            body += textwrap.indent(textwrap.dedent('''\
                typedef boost::shared_ptr<{class_name_sep}> Shared;
                checkArguments("delete_{class_name}",nargout,nargin,1);
                Shared *self = *reinterpret_cast<Shared**>(mxGetData(in[0]));
                Collector_{class_name}::iterator item;
                item = collector_{class_name}.find(self);
                if(item != collector_{class_name}.end()) {{
                delete self;
                collector_{class_name}.erase(item);
                }}
                ''').format(class_name_sep=class_name_separated,
                            class_name=class_name),
                                    prefix=' ')
        elif extra == 'serialize':
            body += self.wrap_collector_function_serialize(
                collector_func[1].name,
                full_name=collector_func[1].cpp_class(),
                namespace=collector_func[0])
        elif extra == 'deserialize':
            body += self.wrap_collector_function_deserialize(
                collector_func[1].name,
                full_name=collector_func[1].cpp_class(),
                namespace=collector_func[0])
        elif is_method or is_static_method:
            # Name reported by checkArguments on argument-count errors.
            check_name = ''

            if is_static_method:
                check_name = self._format_static_method(extra) + '.'

            check_name += extra.name

            return_body = self.wrap_collector_function_return(extra)
            # Instance methods receive the object in in[0]; skip it.
            body_args = self._wrapper_unwrap_arguments(
                extra.args, arg_id=1 if is_method else 0)[1]

            shared_obj = ''

            if is_method:
                shared_obj = ' auto obj = unwrap_shared_ptr<{class_name_sep}>' \
                    '(in[0], "ptr_{class_name}");\n'.format(
                        class_name_sep=class_name_separated,
                        class_name=class_name)

            body += ' checkArguments("{method_name}",nargout,nargin{min1},' \
                '{num_args});\n' \
                '{shared_obj}' \
                '{body_args}' \
                '{return_body}\n'.format(
                    min1='-1' if is_method else '',
                    shared_obj=shared_obj,
                    method_name=check_name,
                    num_args=len(extra.args.args_list),
                    body_args=body_args,
                    return_body=return_body)

        body += '}\n'

        if extra not in ['serialize', 'deserialize']:
            body += '\n'

        collector_function += body

    else:
        # Global function: collector_func[1] is the function itself.
        function = collector_func[1]
        body = textwrap.dedent('''\
            {{
            checkArguments("{function_name}",nargout,nargin,{len});
            ''').format(function_name=function.name,
                        len=len(function.args.args_list))

        body += self._wrapper_unwrap_arguments(function.args)[1]
        body += self.wrap_collector_function_return(function) + '\n}\n'

        collector_function += body

    self.global_function_id += 1

    return collector_function
1578 
def mex_function(self):
    """
    Generate the wrapped MEX function.

    Builds ``mexFunction``, whose ``switch`` dispatches on the integer id in
    ``in[0]``. When an id has no entry in ``self.wrapper_map``, that case
    calls the next id's function, and the next id's case then calls the
    ``<name>_upcastFromVoid_<id>`` helper instead.

    Returns:
        The C++ source of ``mexFunction``.
    """
    case_blocks = []
    pending_upcast = None

    for wrapper_id in range(self.wrapper_id):
        entry = self.wrapper_map.get(wrapper_id)
        borrowed = entry is None
        if borrowed:
            entry = self.wrapper_map.get(wrapper_id + 1)
            if entry is None:
                continue

        callee = pending_upcast if pending_upcast else entry[3]
        case_blocks.append(
            textwrap.indent(textwrap.dedent('''\
                case {}:
                {}(nargout, out, nargin-1, in+1);
                break;
                ''').format(wrapper_id, callee),
                            prefix=' '))

        if borrowed:
            pending_upcast = '{}_upcastFromVoid_{}'.format(
                entry[1].name, wrapper_id + 1)
        else:
            pending_upcast = None

    cases = ''.join(case_blocks)

    return textwrap.dedent('''
        void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[])
        {{
        mstream mout;
        std::streambuf *outbuf = std::cout.rdbuf(&mout);\n
        _{module_name}_RTTIRegister();\n
        int id = unwrap<int>(in[0]);\n
        try {{
        switch(id) {{
        {cases} }}
        }} catch(const std::exception& e) {{
        mexErrMsgTxt(("Exception from gtsam:\\n" + std::string(e.what()) + "\\n").c_str());
        }}\n
        std::cout.rdbuf(outbuf);
        }}
        ''').format(module_name=self.module_name, cases=cases)
1629 
def generate_wrapper(self, namespace):
    """Generate the c++ wrapper.

    Assembles the MEX C++ source: includes, template typedefs, boost export
    GUIDs, object collectors, ``_deleteAllObjects``, RTTI registration, all
    collector functions and ``mexFunction``. The result is appended to
    ``self.content`` as ``(<wrapper name>.cpp, source)``.

    Args:
        namespace: the module namespace; must be truthy.
    """
    # Includes
    wrapper_file = self.wrapper_file_header + textwrap.dedent("""
        #include <boost/archive/text_iarchive.hpp>
        #include <boost/archive/text_oarchive.hpp>
        #include <boost/serialization/export.hpp>\n
        """)

    assert namespace

    includes_list = sorted(self.includes.keys(),
                           key=lambda include: include.header)

    # One include per line followed by a newline; nothing if there are none.
    if includes_list:
        wrapper_file += '\n'.join(
            str(include) for include in includes_list) + '\n'

    typedef_instances = '\n'
    typedef_collectors = ''
    boost_class_export_guid = ''
    delete_objs = textwrap.dedent('''\
        void _deleteAllObjects()
        {
        mstream mout;
        std::streambuf *outbuf = std::cout.rdbuf(&mout);\n
        bool anyDeleted = false;
        ''')
    rtti_reg_start = textwrap.dedent('''\
        void _{module_name}_RTTIRegister() {{
        const mxArray *alreadyCreated = mexGetVariablePtr("global", "gtsam_{module_name}_rttiRegistry_created");
        if(!alreadyCreated) {{
        std::map<std::string, std::string> types;
        ''').format(module_name=self.module_name)
    rtti_reg_mid = ''
    # The "created" flag written here must be the same global that
    # rtti_reg_start checks, otherwise registration re-runs on every call.
    rtti_reg_end = textwrap.indent(
        textwrap.dedent('''
            mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry");
            if(!registry)
            registry = mxCreateStructMatrix(1, 1, 0, NULL);
            typedef std::pair<std::string, std::string> StringPair;
            for(const StringPair& rtti_matlab: types) {
            int fieldId = mxAddField(registry, rtti_matlab.first.c_str());
            if(fieldId < 0)
            mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly");
            mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str());
            mxSetFieldByNumber(registry, 0, fieldId, matlabName);
            }
            if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0)
            mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly");
            mxDestroyArray(registry);
            '''),
        prefix=' ') + ' \n' + textwrap.dedent('''\
            mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL);
            if(mexPutVariable("global", "gtsam_{module_name}_rttiRegistry_created", newAlreadyCreated) != 0)
            mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly");
            mxDestroyArray(newAlreadyCreated);
            }}
            }}
            ''').format(module_name=self.module_name)
    ptr_ctor_frag = ''

    def _has_serialization(cls):
        """Return True if `cls` declares a whitelisted (serialize) method."""
        return any(method.name in self.whitelist for method in cls.methods)

    for cls in self.classes:
        uninstantiated_name = "::".join(
            cls.namespaces()[1:]) + "::" + cls.name
        self._debug("Cls: {} -> {}".format(cls.name, uninstantiated_name))

        if uninstantiated_name in self.ignore_classes:
            self._debug("Ignoring: {} -> {}".format(
                cls.name, uninstantiated_name))
            continue

        if cls.instantiations:
            # Template instantiations are referred to through a typedef.
            typedef_instances += 'typedef {original_class_name} {class_name_sep};\n' \
                .format(original_class_name=cls.cpp_class(),
                        class_name_sep=cls.name)
            class_name_sep = cls.name
        else:
            class_name_sep = cls.cpp_class()
        class_name = self._format_class_name(cls)

        if len(cls.original.namespaces()) > 1 and _has_serialization(cls):
            boost_class_export_guid += 'BOOST_CLASS_EXPORT_GUID({}, "{}");\n'.format(
                class_name_sep, class_name)

        typedef_collectors += textwrap.dedent('''\
            typedef std::set<boost::shared_ptr<{class_name_sep}>*> Collector_{class_name};
            static Collector_{class_name} collector_{class_name};
            ''').format(class_name_sep=class_name_sep, class_name=class_name)
        delete_objs += textwrap.indent(textwrap.dedent('''\
            {{ for(Collector_{class_name}::iterator iter = collector_{class_name}.begin();
            iter != collector_{class_name}.end(); ) {{
            delete *iter;
            collector_{class_name}.erase(iter++);
            anyDeleted = true;
            }} }}
            ''').format(class_name=class_name),
                                       prefix=' ')

        if cls.is_virtual:
            rtti_reg_mid += ' types.insert(std::make_pair(typeid({}).name(), "{}"));\n' \
                .format(class_name_sep, class_name)

    # A missing id borrows the next id's entry. The following id then also
    # gets an upcast-from-void helper (mirrors mex_function's dispatch).
    set_next_case = False

    for idx in range(self.wrapper_id):
        id_val = self.wrapper_map.get(idx)
        queue_set_next_case = set_next_case

        set_next_case = False

        if id_val is None:
            id_val = self.wrapper_map.get(idx + 1)

            if id_val is None:
                continue

            set_next_case = True

        ptr_ctor_frag += self.generate_collector_function(idx)

        if queue_set_next_case:
            ptr_ctor_frag += self.wrap_collector_function_upcast_from_void(
                id_val[1].name, idx, id_val[1].cpp_class())

    wrapper_file += textwrap.dedent('''\
        {typedef_instances}
        {boost_class_export_guid}
        {typedefs_collectors}
        {delete_objs} if(anyDeleted)
        cout <<
        "WARNING: Wrap modules with variables in the workspace have been reloaded due to\\n"
        "calling destructors, call \'clear all\' again if you plan to now recompile a wrap\\n"
        "module, so that your recompiled module is used instead of the old one." << endl;
        std::cout.rdbuf(outbuf);
        }}\n
        {rtti_register}
        {pointer_constructor_fragment}{mex_function}''') \
        .format(typedef_instances=typedef_instances,
                boost_class_export_guid=boost_class_export_guid,
                typedefs_collectors=typedef_collectors,
                delete_objs=delete_objs,
                rtti_register=rtti_reg_start + rtti_reg_mid + rtti_reg_end,
                pointer_constructor_fragment=ptr_ctor_frag,
                mex_function=self.mex_function())

    self.content.append((self._wrapper_name() + '.cpp', wrapper_file))
1807 
def wrap_class_serialize_method(self, namespace_name, inst_class):
    """
    Wrap the serialize method of the class.

    Args:
        namespace_name: Matlab namespace of the class ('' at top level).
        inst_class: the class being wrapped.

    Returns:
        Matlab source for the ``string_serialize`` method and the
        ``saveobj`` hook that calls it.
    """
    # Qualified Matlab name, without a leading '.' for top-level classes.
    class_name = namespace_name + ('' if namespace_name == '' else
                                   '.') + inst_class.name
    wrapper_id = self._update_wrapper_id(
        (namespace_name, inst_class, 'string_serialize', 'serialize'))

    return textwrap.dedent('''\
        function varargout = string_serialize(this, varargin)
        % STRING_SERIALIZE usage: string_serialize() : returns string
        % Doxygen can be found at https://gtsam.org/doxygen/
        if length(varargin) == 0
        varargout{{1}} = {wrapper}({wrapper_id}, this, varargin{{:}});
        else
        error('Arguments do not match any overload of function {class_name}.string_serialize');
        end
        end\n
        function sobj = saveobj(obj)
        % SAVEOBJ Saves the object to a matlab-readable format
        sobj = obj.string_serialize();
        end
        ''').format(wrapper=self._wrapper_name(),
                    wrapper_id=wrapper_id,
                    class_name=class_name)
1833 
def wrap_collector_function_serialize(self,
                                      class_name,
                                      full_name='',
                                      namespace=''):
    """
    Wrap the serialize collector function.

    Args:
        class_name: unqualified class name.
        full_name: fully qualified C++ class name.
        namespace: namespace prefix used in the ``ptr_`` type tag.

    Returns:
        C++ body that unwraps the object from ``in[0]``, writes it to a
        boost text archive and returns the archive string in ``out[0]``.
    """
    return textwrap.indent(textwrap.dedent("""\
        typedef boost::shared_ptr<{full_name}> Shared;
        checkArguments("string_serialize",nargout,nargin-1,0);
        Shared obj = unwrap_shared_ptr<{full_name}>(in[0], "ptr_{namespace}{class_name}");
        ostringstream out_archive_stream;
        boost::archive::text_oarchive out_archive(out_archive_stream);
        out_archive << *obj;
        out[0] = wrap< string >(out_archive_stream.str());
        """).format(class_name=class_name,
                    full_name=full_name,
                    namespace=namespace),
                           prefix=' ')
1853 
def wrap_collector_function_deserialize(self,
                                        class_name,
                                        full_name='',
                                        namespace=''):
    """
    Wrap the deserialize collector function.

    Args:
        class_name: unqualified class name.
        full_name: fully qualified C++ class name.
        namespace: Matlab namespace of the class ('' at top level).

    Returns:
        C++ body that reads a boost text archive from the string in
        ``in[0]`` into a new object and returns it in ``out[0]``.
    """
    # Matlab class name handed to wrap_shared_ptr. Avoid a leading '.'
    # for top-level classes.
    matlab_name = namespace + ('' if namespace == '' else '.') + class_name

    return textwrap.indent(textwrap.dedent("""\
        typedef boost::shared_ptr<{full_name}> Shared;
        checkArguments("{namespace}{class_name}.string_deserialize",nargout,nargin,1);
        string serialized = unwrap< string >(in[0]);
        istringstream in_archive_stream(serialized);
        boost::archive::text_iarchive in_archive(in_archive_stream);
        Shared output(new {full_name}());
        in_archive >> *output;
        out[0] = wrap_shared_ptr(output,"{matlab_name}", false);
        """).format(class_name=class_name,
                    full_name=full_name,
                    namespace=namespace,
                    matlab_name=matlab_name),
                           prefix=' ')
1874 
def wrap(self):
    """High level function to wrap the project.

    First wraps every namespace of ``self.module`` into Matlab files, then
    generates the C++ MEX wrapper source.

    Returns:
        ``self.content``, the accumulated list of generated entries.
    """
    module = self.module
    for step in (self.wrap_namespace, self.generate_wrapper):
        step(module)
    return self.content
1881 
1882 
def generate_content(cc_content, path, verbose=False):
    """
    Generate files and folders from matlab wrapper content.

    Args:
        cc_content: The content to generate. Each entry is one of:
            (file_name, file_content), written directly under ``path``;
            (folder_name, [(file_name, file_content)]), written into
            ``path/folder_name``;
            a list of (folder_name, [(file_name, file_content)]) entries,
            all written into the folder named by the first entry.
        path: The path to the files parent folder within the main folder.
            Missing folders, including intermediate ones, are created.
        verbose: print progress messages to stderr.
    """
    def _debug(message):
        if not verbose:
            return
        print(message, file=sys.stderr)

    for c in cc_content:
        if isinstance(c, list):
            if not c:
                continue
            _debug("c object: {}".format(c[0][0]))
            path_to_folder = osp.join(path, c[0][0])
            os.makedirs(path_to_folder, exist_ok=True)

            for sub_content in c:
                _debug("sub object: {}".format(sub_content[1][0][0]))
                generate_content(sub_content[1], path_to_folder, verbose)

        elif isinstance(c[1], list):
            path_to_folder = osp.join(path, c[0])

            _debug("[generate_content_global]: {}".format(path_to_folder))
            os.makedirs(path_to_folder, exist_ok=True)
            for sub_content in c[1]:
                path_to_file = osp.join(path_to_folder, sub_content[0])
                _debug("[generate_global_method]: {}".format(path_to_file))
                with open(path_to_file, 'w') as f:
                    f.write(sub_content[1])
        else:
            path_to_file = osp.join(path, c[0])

            _debug("[generate_content]: {}".format(path_to_file))
            # os.mkdir cannot create missing intermediate folders; an empty
            # path means the current directory and needs no creation.
            if path:
                os.makedirs(path, exist_ok=True)

            with open(path_to_file, 'w') as f:
                f.write(c[1])
void print(const Matrix &A, const string &s, ostream &stream)
Definition: Matrix.cpp:155
def _format_static_method(self, static_method, separator='')
def wrap_methods(self, methods, global_funcs=False, global_ns=None)
def _update_wrapper_id(self, collector_function=None, id_diff=0)
def class_comment(self, instantiated_class)
def _add_class(self, instantiated_class)
def _format_global_method(self, static_method, separator='')
def class_serialize_comment(self, class_name, static_methods)
def wrap_collector_function_shared_return(self, return_type_name, shared_obj, func_id, new_line=True)
def wrap_collector_function_serialize(self, class_name, full_name='', namespace='')
Tuple< Args..., T > append(Tuple< Args... > t, T a)
the deduction function for append_base that automatically generate the IndexRange ...
bool isinstance(handle obj)
Definition: pytypes.h:384
def wrap_class_serialize_method(self, namespace_name, inst_class)
def _wrap_variable_arguments(self, args, wrap_datatypes=True)
def wrap_global_function(self, function)
def wrap_collector_function_return_types(self, return_type, func_id)
def wrap_class_methods(self, namespace_name, inst_class, methods, serialize=(False,))
def wrap_class_constructors(self, namespace_name, inst_class, parent_name, ctors, is_virtual)
def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False)
Definition: pytypes.h:928
def _format_varargout(cls, return_type, return_type_formatted)
def wrap_collector_function_return(self, method)
def _format_class_name(self, instantiated_class, separator='')
def wrap_namespace(self, namespace, parent=())
def wrap_class_properties(self, class_name)
Definition: pytypes.h:1301
def wrap_instantiated_class(self, instantiated_class, namespace_name='')
def wrap_class_deconstructor(self, namespace_name, inst_class)
def wrap_static_methods(self, namespace_name, instantiated_class, serialize)
def wrap_collector_function_upcast_from_void(self, class_name, func_id, cpp_name)
def _clean_class_name(self, instantiated_class)
def _format_return_type(cls, return_type, include_namespace=False, separator="::")
def __init__(self, module, module_name, top_module_namespace='', ignore_classes=())
def generate_content(cc_content, path, verbose=False)
def _format_type_name(cls, type_name, separator='::', include_namespace=True, constructor=False, method=False)
size_t len(handle h)
Definition: pytypes.h:1514
def wrap_collector_function_deserialize(self, class_name, full_name='', namespace='')
def generate_collector_function(self, func_id)
Definition: pytypes.h:897
def _format_instance_method(self, instance_method, separator='')


gtsam
Author(s):
autogenerated on Sat May 8 2021 02:42:47