2 Code to use the parsed results and convert it to a format 3 that Matlab's MEX compiler can use. 12 from functools
import partial, reduce
13 from typing
import Dict, Iterable, List, Union
20 """ Wrap the given C++ code into Matlab. 23 module: the C++ module being wrapped 24 module_name: name of the C++ module being wrapped 25 top_module_namespace: C++ namespace for the top module (default '') 26 ignore_classes: A list of classes to ignore (default []) 33 'unsigned char':
'unsigned char',
45 'unsigned char':
'unsigned char',
56 whitelist = [
'serializable',
'serialize']
58 ignore_methods = [
'pickle']
60 not_check_type: list = []
62 not_ptr_type = [
'int',
'double',
'bool',
'char',
'unsigned char',
'size_t']
64 ignore_namespace = [
'Matrix',
'Vector',
'Point2',
'Point3']
68 wrapper_map: dict = {}
70 includes: Dict[parser.Include, int] = {}
72 classes: List[Union[parser.Class, instantiator.InstantiatedClass]] = []
73 classes_elems: Dict[Union[parser.Class, instantiator.InstantiatedClass], int] = {}
75 global_function_id = 0
77 content: List[str] = []
80 dir_path = osp.dirname(osp.realpath(__file__))
81 with open(osp.join(dir_path,
"matlab_wrapper.tpl"))
as f:
82 wrapper_file_header = f.read()
87 top_module_namespace=
'',
98 print(message, file=sys.stderr)
101 self.includes[include] = 0
104 if self.classes_elems.get(instantiated_class)
is None:
105 self.classes_elems[instantiated_class] = 0
106 self.classes.append(instantiated_class)
109 """Get and define wrapper ids. 111 Generates the map of id -> collector function. 114 collector_function: tuple storing info about the wrapper function 115 (namespace, class instance, function type, function name, 117 id_diff: constant to add to the id in the map 120 the current wrapper id 122 if collector_function
is not None:
123 is_instantiated_class =
isinstance(collector_function[1],
124 instantiator.InstantiatedClass)
126 if is_instantiated_class:
127 function_name = collector_function[0] + \
128 collector_function[1].name +
'_' + collector_function[2]
130 function_name = collector_function[1].name
133 collector_function[0], collector_function[1],
134 collector_function[2], function_name +
'_' +
142 return 'handle' if names ==
'' else names
145 """Insert spaces at the beginning of each line 148 x: the statement currently generated 149 y: the addition to add to the statement 151 return x +
'\n' + (
'' if y ==
'' else ' ') + y
155 Determine if the `interface_parser.Type` should be treated as a 156 shared pointer in the wrapper. 158 return arg_type.is_shared_ptr
or (
161 and arg_type.typename.name !=
'string')
165 Determine if the `interface_parser.Type` should be treated as a 166 raw pointer in the wrapper. 168 return arg_type.is_ptr
or (
171 and arg_type.typename.name !=
'string')
174 """Determine if the interface_parser.Type should be treated as a 175 reference in the wrapper. 182 """Group overloaded methods together""" 186 for method
in methods:
187 method_index = method_map.get(method.name)
189 if method_index
is None:
190 method_map[method.name] =
len(method_out)
191 method_out.append([method])
193 self.
_debug(
"[_group_methods] Merging {} with {}".format(
194 method_index, method.name))
195 method_out[method_index].
append(method)
200 """Reformatted the C++ class name to fit Matlab defined naming 203 if len(instantiated_class.ctors) != 0:
204 return instantiated_class.ctors[0].name
206 return instantiated_class.name
212 include_namespace=
True,
217 type_name: an interface_parser.Typename to reformat 218 separator: the statement to add between namespaces and typename 219 include_namespace: whether to include namespaces when reformatting 220 constructor: if the typename will be in a constructor 221 method: if the typename will be in a method 224 constructor and method cannot both be true 226 if constructor
and method:
228 'Constructor and method parameters cannot both be True')
230 formatted_type_name =
'' 231 name = type_name.name
233 if include_namespace:
234 for namespace
in type_name.namespaces:
236 formatted_type_name += namespace + separator
240 formatted_type_name += cls.data_type.get(name)
or name
242 formatted_type_name += cls.data_type_param.get(name)
or name
244 formatted_type_name += name
246 if separator ==
"::":
248 for idx
in range(
len(type_name.instantiations)):
249 template =
'{}'.format(
251 include_namespace=include_namespace,
252 constructor=constructor,
254 templates.append(template)
256 if len(templates) > 0:
257 formatted_type_name +=
'<{}>'.format(
','.join(templates))
260 for idx
in range(
len(type_name.instantiations)):
261 formatted_type_name +=
'{}'.format(
264 include_namespace=
False,
265 constructor=constructor,
268 return formatted_type_name
273 include_namespace=
False,
275 """Format return_type. 278 return_type: an interface_parser.ReturnType to reformat 279 include_namespace: whether to include namespaces when reformatting 285 return_type.type1.typename,
287 include_namespace=include_namespace)
289 return_wrap =
'pair< {type1}, {type2} >'.format(
291 return_type.type1.typename,
293 include_namespace=include_namespace),
295 return_type.type2.typename,
297 include_namespace=include_namespace))
302 """Format a template_instantiator.InstantiatedClass name.""" 303 if instantiated_class.parent ==
'':
304 parent_full_ns = [
'']
306 parent_full_ns = instantiated_class.parent.full_namespaces()
313 parentname =
"".join([separator + x
314 for x
in parent_full_ns]) + separator
316 class_name = parentname[2 *
len(separator):]
318 class_name += instantiated_class.name
325 gtsamPoint3.staticFunction 329 if isinstance(static_method, parser.StaticMethod):
330 method +=
"".join([separator + x
for x
in static_method.parent.namespaces()]) + \
331 separator + static_method.parent.name + separator
333 return method[2 *
len(separator):]
338 gtsamPoint3.staticFunction 342 if isinstance(instance_method, instantiator.InstantiatedMethod):
345 for x
in instance_method.parent.parent.full_namespaces()
347 method +=
"".join(method_list) + separator
349 method += instance_method.parent.name + separator
350 method += instance_method.original.name
351 method +=
"<" + instance_method.instantiations.to_cpp() +
">" 353 return method[2 *
len(separator):]
358 gtsamPoint3.staticFunction 362 if isinstance(static_method, parser.GlobalFunction):
363 method +=
"".join([separator + x
for x
in static_method.parent.full_namespaces()]) + \
366 return method[2 *
len(separator):]
369 """Wrap an interface_parser.ArgumentList into a list of arguments. 372 A string representation of the arguments. For example: 377 for i, arg
in enumerate(args.args_list, 1):
379 include_namespace=
False)
381 arg_wrap +=
'{c_type} {arg_name}{comma}'.format(
384 comma=
'' if i ==
len(args.args_list)
else ', ')
389 """ Wrap an interface_parser.ArgumentList into a statement of argument 393 A string representation of a variable arguments for an if 394 statement. For example: 395 ' && isa(varargin{1},'double') && isa(varargin{2},'numeric')' 399 for i, arg
in enumerate(args.args_list, 1):
400 name = arg.ctype.typename.name
401 if name
in self.not_check_type:
404 check_type = self.data_type_param.get(name)
406 if self.data_type.get(check_type):
409 if check_type
is None:
413 constructor=
not wrap_datatypes)
415 var_arg_wrap +=
" && isa(varargin{{{num}}},'{data_type}')".format(
416 num=i, data_type=check_type)
418 var_arg_wrap +=
' && size(varargin{{{num}}},2)==1'.format(
421 var_arg_wrap +=
' && size(varargin{{{num}}},1)==2'.format(
423 var_arg_wrap +=
' && size(varargin{{{num}}},2)==1'.format(
426 var_arg_wrap +=
' && size(varargin{{{num}}},1)==3'.format(
428 var_arg_wrap +=
' && size(varargin{{{num}}},2)==1'.format(
434 """ Wrap an interface_parser.ArgumentList into a list of argument 438 A string representation of a list of variable arguments. 440 'varargin{1}, varargin{2}, varargin{3}' 445 for i
in range(1,
len(args.args_list) + 1):
447 var_list_wrap +=
'varargin{{{}}}'.format(i)
450 var_list_wrap +=
', varargin{{{}}}'.format(i)
456 Wrap the given arguments into either just a varargout call or a 457 call in an if statement that checks if the parameters are accurate. 462 if check_statement ==
'':
464 'if length(varargin) == {param_count}'.format(
465 param_count=
len(args.args_list))
467 for _, arg
in enumerate(args.args_list):
468 name = arg.ctype.typename.name
470 if name
in self.not_check_type:
474 check_type = self.data_type_param.get(name)
476 if self.data_type.get(check_type):
479 if check_type
is None:
483 check_statement +=
" && isa(varargin{{{id}}},'{ctype}')".format(
484 id=arg_id, ctype=check_type)
487 check_statement +=
' && size(varargin{{{num}}},2)==1'.format(
490 check_statement +=
' && size(varargin{{{num}}},1)==2'.format(
492 check_statement +=
' && size(varargin{{{num}}},2)==1'.format(
495 check_statement +=
' && size(varargin{{{num}}},1)==3'.format(
497 check_statement +=
' && size(varargin{{{num}}},2)==1'.format(
502 check_statement = check_statement \
503 if check_statement ==
'' \
504 else check_statement +
'\n' 506 return check_statement
509 """Format the interface_parser.Arguments. 512 ((a), unsigned char a = unwrap< unsigned char >(in[1]);), 513 ((a), Test& t = *unwrap_shared_ptr< Test >(in[1], "ptr_Test");), 514 ((a), std::shared_ptr<Test> p1 = unwrap_shared_ptr< Test >(in[1], "ptr_Test");) 519 for arg
in args.args_list:
526 body_args += textwrap.indent(textwrap.dedent(
'''\ 527 {ctype}& {name} = *unwrap_shared_ptr< {ctype} >(in[{id}], "ptr_{ctype_camel}"); 529 ctype_camel=ctype_camel,
536 if arg.ctype.is_shared_ptr:
537 call_type = arg.ctype.is_shared_ptr
539 call_type = arg.ctype.is_ptr
541 body_args += textwrap.indent(textwrap.dedent(
'''\ 542 {std_boost}::shared_ptr<{ctype_sep}> {name} = unwrap_shared_ptr< {ctype_sep} >(in[{id}], "ptr_{ctype}"); 543 '''.format(std_boost=
'boost' if constructor
else 'boost',
555 body_args += textwrap.indent(textwrap.dedent(
'''\ 556 {ctype} {name} = unwrap< {ctype} >(in[{id}]); 557 '''.format(ctype=arg.ctype.typename.name,
566 return params, body_args
570 """The amount of objects returned by the given 571 interface_parser.ReturnType. 573 return 1
if return_type.type2 ==
'' else 2
576 """Determine the name of wrapper function.""" 580 """Generate comments for serialize methods.""" 582 static_methods = sorted(static_methods, key=
lambda name: name.name)
584 for static_method
in static_methods:
585 if comment_wrap ==
'':
586 comment_wrap =
'%-------Static Methods-------\n' 588 comment_wrap +=
'%{name}({args}) : returns {return_type}\n'.format(
589 name=static_method.name,
592 include_namespace=
True))
594 comment_wrap += textwrap.dedent(
'''\ 596 %-------Serialization Interface------- 597 %string_serialize() : returns string 598 %string_deserialize(string serialized) : returns {class_name} 600 ''').format(class_name=class_name)
605 """Generate comments for the given class in Matlab. 608 instantiated_class: the class being wrapped 609 ctors: a list of the constructors in the class 610 methods: a list of the methods in the class 612 class_name = instantiated_class.name
613 ctors = instantiated_class.ctors
614 methods = instantiated_class.methods
615 static_methods = instantiated_class.static_methods
617 comment = textwrap.dedent(
'''\ 618 %class {class_name}, see Doxygen page for details 619 %at https://gtsam.org/doxygen/ 620 ''').format(class_name=class_name)
623 comment +=
'%\n%-------Constructors-------\n' 627 comment +=
'%{ctor_name}({args})\n'.format(ctor_name=ctor.name,
631 if len(methods) != 0:
633 '%-------Methods-------\n' 635 methods = sorted(methods, key=
lambda name: name.name)
638 for method
in methods:
644 comment +=
'%{name}({args})'.format(name=method.name,
648 if method.return_type.type2 ==
'':
650 method.return_type.type1.typename)
652 return_type =
'pair< {type1}, {type2} >'.format(
654 method.return_type.type1.typename),
656 method.return_type.type2.typename))
658 comment +=
' : returns {return_type}\n'.format(
659 return_type=return_type)
663 if len(static_methods) != 0:
669 """Generate the C++ file for the wrapper.""" 674 return file_name, wrapper_file
677 """Wrap methods in the body of a class.""" 688 Wrap a sequence of methods. Groups methods with the same names 690 If global_funcs is True then output every method into its own file. 695 for method
in methods:
700 self.
_debug(
"[wrap_methods] wrapping: {}..{}={}".format(
701 method[0].parent.name, method[0].name,
702 type(method[0].parent.name)))
705 self.content.append((
"".join([
706 '+' + x +
'/' for x
in global_ns.full_namespaces()[1:]
707 ])[:-1], [(method[0].name +
'.m', method_text)]))
715 """Wrap the given global function.""" 717 function = [function]
719 function_name = function[0].name
724 for i, overload
in enumerate(function):
725 param_wrap +=
' if' if i == 0
else ' elseif' 726 param_wrap +=
' length(varargin) == ' 728 if len(overload.args.args_list) == 0:
731 param_wrap +=
str(
len(overload.args.args_list)) \
736 overload.return_type, include_namespace=
True, separator=
".")
738 return_type_formatted)
740 param_wrap += textwrap.indent(textwrap.dedent(
'''\ 741 {varargout}{module_name}_wrapper({num}, varargin{{:}}); 742 ''').format(varargout=varargout,
745 collector_function=(function[0].parent.name,
746 function[i],
'global_function',
750 param_wrap += textwrap.indent(textwrap.dedent(
'''\ 752 error('Arguments do not match any overload of function {func_name}'); 753 ''').format(func_name=function_name),
756 global_function = textwrap.indent(textwrap.dedent(
'''\ 757 function varargout = {m_method}(varargin) 759 ''').format(m_method=function_name, statements=param_wrap),
762 return global_function
766 """Wrap class constructor. 769 namespace_name: the name of the namespace ('' if it does not exist) 770 inst_class: instance of the class 771 parent_name: the name of the parent class if it exists 772 ctors: the interface_parser.Constructor in the class 773 is_virtual: whether the class is part of a virtual inheritance 776 has_parent = parent_name !=
'' 777 class_name = inst_class.name
783 methods_wrap = textwrap.indent(textwrap.dedent(
"""\ 785 function obj = {class_name}(varargin) 786 """).format(class_name=class_name),
790 methods_wrap +=
" if (nargin == 2 || (nargin == 3 && strcmp(varargin{3}, 'void')))" 792 methods_wrap +=
' if nargin == 2' 794 methods_wrap +=
" && isa(varargin{1}, 'uint64')" 795 methods_wrap +=
" && varargin{1} == uint64(5139824614673773682)\n" 798 methods_wrap += textwrap.indent(textwrap.dedent(
'''\ 800 my_ptr = varargin{{2}}; 802 my_ptr = {wrapper_name}({id}, varargin{{2}}); 808 methods_wrap +=
' my_ptr = varargin{2};\n' 811 (namespace_name, inst_class,
'collectorInsertAndMakeBase',
None),
812 id_diff=-1
if is_virtual
else 0)
814 methods_wrap +=
' {ptr}{wrapper_name}({id}, my_ptr);\n' \
816 ptr=
'base_ptr = ' if has_parent
else '',
818 id=collector_base_id - (1
if is_virtual
else 0))
821 wrapper_return =
'[ my_ptr, base_ptr ] = ' \
825 methods_wrap += textwrap.indent(textwrap.dedent(
'''\ 826 elseif nargin == {len}{varargin} 827 {ptr}{wrapper}({num}{comma}{var_arg}); 828 ''').format(len=
len(ctor.args.args_list),
834 (namespace_name, inst_class,
'constructor', ctor)),
835 comma=
'' if len(ctor.args.args_list) == 0
else ', ',
842 self.
_debug(
"class: {} ns: {}".format(
847 base_obj =
' obj = obj@{parent_name}(uint64(5139824614673773682), base_ptr);'.format(
848 parent_name=parent_name)
851 base_obj =
'\n' + base_obj
853 self.
_debug(
"class: {}, name: {}".format(
856 methods_wrap += textwrap.indent(textwrap.dedent(
'''\ 858 error('Arguments do not match any overload of {class_name_doc} constructor'); 860 obj.ptr_{class_name} = my_ptr; 862 ''').format(namespace=namespace_name,
863 d=
'' if namespace_name ==
'' else '.',
874 """Generate properties of class.""" 875 return textwrap.dedent(
'''\ 879 ''').format(class_name)
882 """Generate the delete function for the Matlab class.""" 883 class_name = inst_class.name
885 methods_text = textwrap.indent(textwrap.dedent(
"""\ 887 {wrapper}({num}, obj.ptr_{class_name}); 890 (namespace_name, inst_class,
'deconstructor',
None)),
892 class_name=
"".join(inst_class.parent.full_namespaces()) +
899 """Generate the display function for the Matlab class.""" 900 return textwrap.indent(textwrap.dedent(
"""\ 901 function display(obj), obj.print(''); end 902 %DISPLAY Calls print on the object 903 function disp(obj), obj.display; end 904 %DISP Calls print on the object 909 """Group overloaded methods together""" 913 for method
in methods:
914 method_index = method_map.get(method.name)
916 if method_index
is None:
917 method_map[method.name] =
len(method_out)
918 method_out.append([method])
921 method_out[method_index].
append(method)
927 """Determine format of return and varargout statements""" 930 if return_type_formatted ==
'void' \
931 else 'varargout{1} = ' 933 varargout =
'[ varargout{1} varargout{2} ] = ' 942 """Wrap the methods in the class. 945 namespace_name: the name of the class's namespace 946 inst_class: the instantiated class whose methods to wrap 947 methods: the methods to wrap in the order to wrap them 948 serialize: mutable param storing if one of the methods is serialize 956 serialize =
list(serialize)
958 for method
in methods:
959 method_name = method[0].name
960 if method_name
in self.
whitelist and method_name !=
'serialize':
965 if method_name ==
'serialize':
968 namespace_name, inst_class)
971 method_text += textwrap.indent(textwrap.dedent(
"""\ 972 function varargout = {method_name}(this, varargin) 973 """).format(caps_name=method_name.upper(),
974 method_name=method_name),
977 for overload
in method:
978 method_text += textwrap.indent(textwrap.dedent(
"""\ 979 % {caps_name} usage: {method_name}(""").format(
980 caps_name=method_name.upper(),
981 method_name=method_name),
986 overload.return_type,
987 include_namespace=
True,
990 return_type_formatted)
994 class_name = namespace_name + (
'' if namespace_name ==
'' 995 else '.') + inst_class.name
998 if check_statement ==
'' \
999 else textwrap.indent(textwrap.dedent(
"""\ 1003 class_name=class_name,
1004 method_name=overload.original.name), prefix=
' ')
1006 method_text += textwrap.dedent(
"""\ 1007 {method_args}) : returns {return_type} 1008 % Doxygen can be found at https://gtsam.org/doxygen/ 1009 {check_statement}{spacing}{varargout}{wrapper}({num}, this, varargin{{:}}); 1010 {end_statement}""").format(
1012 return_type=return_type_formatted,
1014 (namespace_name, inst_class,
1015 overload.original.name, overload)),
1016 check_statement=check_statement,
1017 spacing=
'' if check_statement ==
'' else ' ',
1018 varargout=varargout,
1020 end_statement=end_statement)
1022 final_statement = textwrap.indent(textwrap.dedent(
"""\ 1023 error('Arguments do not match any overload of function {class_name}.{method_name}'); 1024 """.format(class_name=class_name, method_name=method_name)),
1026 method_text += final_statement +
'end\n\n' 1033 Wrap the static methods in the class. 1035 class_name = instantiated_class.name
1037 method_text =
'methods(Static = true)\n' 1038 static_methods = sorted(instantiated_class.static_methods,
1039 key=
lambda name: name.name)
1043 for static_method
in static_methods:
1044 format_name =
list(static_method[0].name)
1045 format_name[0] = format_name[0].upper()
1050 method_text += textwrap.indent(textwrap.dedent(
'''\ 1051 function varargout = {name}(varargin) 1052 '''.format(name=
''.join(format_name))),
1055 for static_overload
in static_method:
1057 static_overload.args)
1059 end_statement =
'' \
1060 if check_statement ==
'' \
1061 else textwrap.indent(textwrap.dedent(
""" 1065 method_text += textwrap.indent(textwrap.dedent(
'''\ 1066 % {name_caps} usage: {name_upper_case}({args}) : returns {return_type} 1067 % Doxygen can be found at https://gtsam.org/doxygen/ 1068 {check_statement}{spacing}varargout{{1}} = {wrapper}({id}, varargin{{:}});{end_statement} 1070 name=
''.join(format_name),
1071 name_caps=static_overload.name.upper(),
1072 name_upper_case=static_overload.name,
1075 static_overload.return_type,
1076 include_namespace=
True,
1078 length=
len(static_overload.args.args_list),
1080 static_overload.args),
1081 check_statement=check_statement,
1082 spacing=
'' if check_statement ==
'' else ' ',
1085 (namespace_name, instantiated_class,
1086 static_overload.name, static_overload)),
1087 class_name=instantiated_class.name,
1088 end_statement=end_statement),
1092 method_text += textwrap.indent(textwrap.dedent(
"""\ 1093 error('Arguments do not match any overload of function {class_name}.{method_name}'); 1094 """.format(class_name=class_name,
1095 method_name=static_overload.name)),
1098 method_text += textwrap.indent(textwrap.dedent(
"""\ 1103 method_text += textwrap.indent(textwrap.dedent(
"""\ 1104 function varargout = string_deserialize(varargin) 1105 % STRING_DESERIALIZE usage: string_deserialize() : returns {class_name} 1106 % Doxygen can be found at https://gtsam.org/doxygen/ 1107 if length(varargin) == 1 1108 varargout{{1}} = {wrapper}({id}, varargin{{:}}); 1110 error('Arguments do not match any overload of function {class_name}.string_deserialize'); 1113 function obj = loadobj(sobj) 1114 % LOADOBJ Saves the object to a matlab-readable format 1115 obj = {class_name}.string_deserialize(sobj); 1118 class_name=namespace_name +
'.' + instantiated_class.name,
1121 (namespace_name, instantiated_class,
'string_deserialize',
1128 """Generate comments and code for given class. 1131 instantiated_class: template_instantiator.InstantiatedClass 1132 instance storing the class to wrap 1133 namespace_name: the name of the namespace if there is one 1136 namespace_file_name = namespace_name + file_name
1138 uninstantiated_name =
"::".join(instantiated_class.namespaces()
1139 [1:]) +
"::" + instantiated_class.name
1145 content_text += self.
wrap_methods(instantiated_class.methods)
1153 content_text +=
'classdef {class_name} < {parent}\n'.format(
1154 class_name=file_name,
1156 instantiated_class.parent_class)).replace(
"::",
"."))
1159 content_text +=
' ' + reduce(
1162 namespace_file_name).splitlines()) +
'\n' 1165 content_text +=
' ' + reduce(
1170 instantiated_class.parent_class,
1171 instantiated_class.ctors,
1172 instantiated_class.is_virtual,
1173 ).splitlines()) +
'\n' 1176 content_text +=
' ' + reduce(
1179 namespace_name, instantiated_class).splitlines()) +
'\n' 1182 content_text +=
' ' + reduce(
1189 if len(instantiated_class.methods) != 0:
1190 methods = sorted(instantiated_class.methods,
1191 key=
lambda name: name.name)
1196 serialize=serialize).splitlines()
1197 if len(class_methods_wrapped) > 0:
1198 content_text +=
' ' + reduce(
1199 lambda x, y: x +
'\n' + (
'' if y ==
'' else ' ') + y,
1200 class_methods_wrapped) +
'\n' 1203 content_text +=
' end\n\n ' + reduce(
1206 serialize[0]).splitlines()) +
'\n' 1208 content_text += textwrap.dedent(
'''\ 1213 return file_name +
'.m', content_text
1216 """Wrap a namespace by wrapping all of its components. 1219 namespace: the interface_parser.namespace instance of the namespace 1220 parent: parent namespace 1223 namespaces = namespace.full_namespaces()
1224 inner_namespace = namespace.name !=
'' 1226 self.
_debug(
"wrapping ns: {}, parent: {}".format(
1227 namespace.full_namespaces(), parent))
1230 self.content.append((matlab_wrapper[0], matlab_wrapper[1]))
1233 namespace_scope = []
1235 for element
in namespace.content:
1240 elif isinstance(element, instantiator.InstantiatedClass):
1245 element,
"".join(namespace.full_namespaces()))
1247 if not class_text
is None:
1248 namespace_scope.append((
"".join([
1250 for x
in namespace.full_namespaces()[1:]
1251 ])[:-1], [(class_text[0], class_text[1])]))
1254 current_scope.append((class_text[0], class_text[1]))
1256 self.content.extend(current_scope)
1259 self.content.append(namespace_scope)
1263 func
for func
in namespace.content
1267 test_output += self.
wrap_methods(all_funcs,
True, global_ns=namespace)
1276 """Wrap the collector function which returns a shared pointer.""" 1277 new_line =
'\n' if new_line
else '' 1279 return textwrap.indent(textwrap.dedent(
'''\ 1281 boost::shared_ptr<{name}> shared({shared_obj}); 1282 out[{id}] = wrap_shared_ptr(shared,"{name}"); 1284 return_type_name, include_namespace=
False),
1285 shared_obj=shared_obj,
1292 Wrap the return type of the collector function. 1294 return_type_text =
' out[' +
str(func_id) +
'] = ' 1295 pair_value =
'first' if func_id == 0
else 'second' 1296 new_line =
'\n' if func_id == 0
else '' 1299 shared_obj =
'pairResult.' + pair_value
1301 if not (return_type.is_shared_ptr
or return_type.is_ptr):
1302 shared_obj =
'boost::make_shared<{name}>({shared_obj})' \
1304 shared_obj=
'pairResult.' + pair_value)
1308 return_type.typename, shared_obj, func_id, func_id == 0)
1310 return_type_text +=
'wrap_shared_ptr({0},"{1}", false);{new_line}' \
1316 return_type_text +=
'wrap< {0} >(pairResult.{1});{2}'.format(
1318 pair_value, new_line)
1320 return return_type_text
1324 Wrap the complete return type of the function. 1330 return_1 = method.return_type.type1
1332 return_1_name = method.return_type.type1.typename.name
1335 if isinstance(method, instantiator.InstantiatedMethod):
1337 method_name = method.to_cpp()
1340 if method.instantiations:
1344 method = method.to_cpp()
1346 elif isinstance(method, parser.GlobalFunction):
1348 method_name += method.name
1351 if isinstance(method.parent, instantiator.InstantiatedClass):
1352 method_name = method.parent.cpp_class() +
"::" 1355 method_name += method.name
1357 if "MeasureRange" in method_name:
1358 self.
_debug(
"method: {}, method: {}, inst: {}".format(
1359 method_name, method.name, method.parent.cpp_class()))
1361 obj =
' ' if return_1_name ==
'void' else '' 1362 obj +=
'{}{}({})'.format(obj_start, method_name, params)
1364 if return_1_name !=
'void':
1365 if return_count == 1:
1369 include_namespace=
True)
1373 return_1.typename, obj, 0, new_line=
False)
1375 if return_1.is_shared_ptr
or return_1.is_ptr:
1376 shared_obj =
'{obj},"{method_name_sep}"'.format(
1377 obj=obj, method_name_sep=sep_method_name(
'.'))
1379 self.
_debug(
"Non-PTR: {}, {}".format(
1380 return_1,
type(return_1)))
1381 self.
_debug(
"Inner type is: {}, {}".format(
1382 return_1.typename.name, sep_method_name(
'.')))
1383 self.
_debug(
"Inner type instantiations: {}".format(
1384 return_1.typename.instantiations))
1385 method_name_sep_dot = sep_method_name(
'.')
1386 shared_obj_template =
'boost::make_shared<{method_name_sep_col}>({obj}),' \
1387 '"{method_name_sep_dot}"' 1388 shared_obj = shared_obj_template \
1389 .format(method_name_sep_col=sep_method_name(),
1390 method_name_sep_dot=method_name_sep_dot,
1394 expanded += textwrap.indent(
1395 'out[0] = wrap_shared_ptr({}, false);'.format(
1399 expanded +=
' out[0] = wrap< {} >({});'.format(
1400 return_1.typename.name, obj)
1401 elif return_count == 2:
1402 return_2 = method.return_type.type2
1404 expanded +=
' auto pairResult = {};\n'.format(obj)
1410 expanded += obj +
';' 1417 Add function to upcast type from void type. 1419 return textwrap.dedent(
'''\ 1420 void {class_name}_upcastFromVoid_{id}(int nargout, mxArray *out[], int nargin, const mxArray *in[]) {{ 1421 mexAtExit(&_deleteAllObjects); 1422 typedef boost::shared_ptr<{cpp_name}> Shared; 1423 boost::shared_ptr<void> *asVoid = *reinterpret_cast<boost::shared_ptr<void>**> (mxGetData(in[0])); 1424 out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); 1425 Shared *self = new Shared(boost::static_pointer_cast<{cpp_name}>(*asVoid)); 1426 *reinterpret_cast<Shared**>(mxGetData(out[0])) = self; 1428 ''').format(class_name=class_name, cpp_name=cpp_name, id=func_id)
1432 Generate the complete collector function. 1434 collector_func = self.wrapper_map.get(func_id)
1436 if collector_func
is None:
1439 method_name = collector_func[3]
1441 collector_function =
"void {}" \
1442 "(int nargout, mxArray *out[], int nargin, const mxArray *in[])\n".format(method_name)
1444 if isinstance(collector_func[1], instantiator.InstantiatedClass):
1447 extra = collector_func[4]
1449 class_name = collector_func[0] + collector_func[1].name
1450 class_name_separated = collector_func[1].cpp_class()
1452 is_static_method =
isinstance(extra, parser.StaticMethod)
1454 if collector_func[2] ==
'collectorInsertAndMakeBase':
1455 body += textwrap.indent(textwrap.dedent(
'''\ 1456 mexAtExit(&_deleteAllObjects); 1457 typedef boost::shared_ptr<{class_name_sep}> Shared;\n 1458 Shared *self = *reinterpret_cast<Shared**> (mxGetData(in[0])); 1459 collector_{class_name}.insert(self); 1460 ''').format(class_name_sep=class_name_separated,
1461 class_name=class_name),
1464 if collector_func[1].parent_class:
1465 body += textwrap.indent(textwrap.dedent(
''' 1466 typedef boost::shared_ptr<{}> SharedBase; 1467 out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); 1468 *reinterpret_cast<SharedBase**>(mxGetData(out[0])) = new SharedBase(*self); 1469 ''').format(collector_func[1].parent_class),
1471 elif collector_func[2] ==
'constructor':
1474 extra.args, constructor=
True)
1476 if collector_func[1].parent_class:
1477 base += textwrap.indent(textwrap.dedent(
''' 1478 typedef boost::shared_ptr<{}> SharedBase; 1479 out[1] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); 1480 *reinterpret_cast<SharedBase**>(mxGetData(out[1])) = new SharedBase(*self); 1481 ''').format(collector_func[1].parent_class),
1484 body += textwrap.dedent(
'''\ 1485 mexAtExit(&_deleteAllObjects); 1486 typedef boost::shared_ptr<{class_name_sep}> Shared;\n 1487 {body_args} Shared *self = new Shared(new {class_name_sep}({params})); 1488 collector_{class_name}.insert(self); 1489 out[0] = mxCreateNumericMatrix(1, 1, mxUINT32OR64_CLASS, mxREAL); 1490 *reinterpret_cast<Shared**> (mxGetData(out[0])) = self; 1491 {base}''').format(class_name_sep=class_name_separated,
1492 body_args=body_args,
1494 class_name=class_name,
1496 elif collector_func[2] ==
'deconstructor':
1497 body += textwrap.indent(textwrap.dedent(
'''\ 1498 typedef boost::shared_ptr<{class_name_sep}> Shared; 1499 checkArguments("delete_{class_name}",nargout,nargin,1); 1500 Shared *self = *reinterpret_cast<Shared**>(mxGetData(in[0])); 1501 Collector_{class_name}::iterator item; 1502 item = collector_{class_name}.find(self); 1503 if(item != collector_{class_name}.end()) {{ 1505 collector_{class_name}.erase(item); 1507 ''').format(class_name_sep=class_name_separated,
1508 class_name=class_name),
1510 elif extra ==
'serialize':
1512 collector_func[1].name,
1513 full_name=collector_func[1].cpp_class(),
1514 namespace=collector_func[0])
1515 elif extra ==
'deserialize':
1517 collector_func[1].name,
1518 full_name=collector_func[1].cpp_class(),
1519 namespace=collector_func[0])
1520 elif is_method
or is_static_method:
1523 if is_static_method:
1526 method_name += extra.name
1533 extra.args, arg_id=1
if is_method
else 0)
1538 shared_obj =
' auto obj = unwrap_shared_ptr<{class_name_sep}>' \
1539 '(in[0], "ptr_{class_name}");\n'.format(
1540 class_name_sep=class_name_separated,
1541 class_name=class_name)
1543 body +=
' checkArguments("{method_name}",nargout,nargin{min1},' \
1547 '{return_body}\n'.format(
1548 min1=
'-1' if is_method
else '',
1549 shared_obj=shared_obj,
1550 method_name=method_name,
1551 num_args=
len(extra.args.args_list),
1552 body_args=body_args,
1553 return_body=return_body)
1557 if extra
not in [
'serialize',
'deserialize']:
1560 collector_function += body
1563 body = textwrap.dedent(
'''\ 1565 checkArguments("{function_name}",nargout,nargin,{len}); 1566 ''').format(function_name=collector_func[1].name,
1568 len=
len(collector_func[1].args.args_list))
1573 collector_function += body
1577 return collector_function
1581 Generate the wrapped MEX function. 1587 id_val = self.wrapper_map.get(wrapper_id)
1588 set_next_case =
False 1591 id_val = self.wrapper_map.get(wrapper_id + 1)
1596 set_next_case =
True 1598 cases += textwrap.indent(textwrap.dedent(
'''\ 1600 {}(nargout, out, nargin-1, in+1); 1602 ''').format(wrapper_id, next_case
if next_case
else id_val[3]),
1606 next_case =
'{}_upcastFromVoid_{}'.format(
1607 id_val[1].name, wrapper_id + 1)
1611 mex_function = textwrap.dedent(
''' 1612 void mexFunction(int nargout, mxArray *out[], int nargin, const mxArray *in[]) 1615 std::streambuf *outbuf = std::cout.rdbuf(&mout);\n 1616 _{module_name}_RTTIRegister();\n 1617 int id = unwrap<int>(in[0]);\n 1621 }} catch(const std::exception& e) {{ 1622 mexErrMsgTxt(("Exception from gtsam:\\n" + std::string(e.what()) + "\\n").c_str()); 1624 std::cout.rdbuf(outbuf); 1626 ''').format(module_name=self.
module_name, cases=cases)
1631 """Generate the c++ wrapper.""" 1634 #include <boost/archive/text_iarchive.hpp> 1635 #include <boost/archive/text_oarchive.hpp> 1636 #include <boost/serialization/export.hpp>\n 1641 includes_list = sorted(
list(self.includes.keys()),
1642 key=
lambda include: include.header)
1647 if len(includes_list) == 0:
1649 elif len(includes_list) == 1:
1650 wrapper_file += (
str(includes_list[0]) +
'\n')
1652 wrapper_file += reduce(
lambda x, y:
str(x) +
'\n' +
str(y),
1654 wrapper_file +=
'\n' 1656 typedef_instances =
'\n' 1657 typedef_collectors =
'' 1658 boost_class_export_guid =
'' 1659 delete_objs = textwrap.dedent(
'''\ 1660 void _deleteAllObjects() 1663 std::streambuf *outbuf = std::cout.rdbuf(&mout);\n 1664 bool anyDeleted = false; 1666 rtti_reg_start = textwrap.dedent(
'''\ 1667 void _{module_name}_RTTIRegister() {{ 1668 const mxArray *alreadyCreated = mexGetVariablePtr("global", "gtsam_{module_name}_rttiRegistry_created"); 1669 if(!alreadyCreated) {{ 1670 std::map<std::string, std::string> types; 1673 rtti_reg_end = textwrap.indent(
1675 mxArray *registry = mexGetVariable("global", "gtsamwrap_rttiRegistry"); 1677 registry = mxCreateStructMatrix(1, 1, 0, NULL); 1678 typedef std::pair<std::string, std::string> StringPair; 1679 for(const StringPair& rtti_matlab: types) { 1680 int fieldId = mxAddField(registry, rtti_matlab.first.c_str()); 1682 mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); 1683 mxArray *matlabName = mxCreateString(rtti_matlab.second.c_str()); 1684 mxSetFieldByNumber(registry, 0, fieldId, matlabName); 1686 if(mexPutVariable("global", "gtsamwrap_rttiRegistry", registry) != 0) 1687 mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); 1688 mxDestroyArray(registry); 1690 prefix=
' ') +
' \n' + textwrap.dedent(
'''\ 1691 mxArray *newAlreadyCreated = mxCreateNumericMatrix(0, 0, mxINT8_CLASS, mxREAL); 1692 if(mexPutVariable("global", "gtsam_geometry_rttiRegistry_created", newAlreadyCreated) != 0) 1693 mexErrMsgTxt("gtsam wrap: Error indexing RTTI types, inheritance will not work correctly"); 1694 mxDestroyArray(newAlreadyCreated); 1700 for cls
in self.classes:
1701 uninstantiated_name =
"::".join(
1702 cls.namespaces()[1:]) +
"::" + cls.name
1703 self.
_debug(
"Cls: {} -> {}".format(cls.name, uninstantiated_name))
1706 self.
_debug(
"Ignoring: {} -> {}".format(
1707 cls.name, uninstantiated_name))
1710 def _has_serialization(cls):
1711 for m
in cls.methods:
1716 if cls.instantiations:
1719 for i, inst
in enumerate(cls.instantiations):
1725 typedef_instances +=
'typedef {original_class_name} {class_name_sep};\n' \
1726 .format(original_class_name=cls.cpp_class(),
1727 class_name_sep=cls.name)
1729 class_name_sep = cls.name
1732 if len(cls.original.namespaces()) > 1
and _has_serialization(
1734 boost_class_export_guid +=
'BOOST_CLASS_EXPORT_GUID({}, "{}");\n'.format(
1735 class_name_sep, class_name)
1737 class_name_sep = cls.cpp_class()
1740 if len(cls.original.namespaces()) > 1
and _has_serialization(
1742 boost_class_export_guid +=
'BOOST_CLASS_EXPORT_GUID({}, "{}");\n'.format(
1743 class_name_sep, class_name)
1745 typedef_collectors += textwrap.dedent(
'''\ 1746 typedef std::set<boost::shared_ptr<{class_name_sep}>*> Collector_{class_name}; 1747 static Collector_{class_name} collector_{class_name}; 1748 ''').format(class_name_sep=class_name_sep, class_name=class_name)
1749 delete_objs += textwrap.indent(textwrap.dedent(
'''\ 1750 {{ for(Collector_{class_name}::iterator iter = collector_{class_name}.begin(); 1751 iter != collector_{class_name}.end(); ) {{ 1753 collector_{class_name}.erase(iter++); 1756 ''').format(class_name=class_name),
1760 rtti_reg_mid +=
' types.insert(std::make_pair(typeid({}).name(), "{}"));\n' \
1761 .format(class_name_sep, class_name)
1763 set_next_case =
False 1766 id_val = self.wrapper_map.get(idx)
1767 queue_set_next_case = set_next_case
1769 set_next_case =
False 1772 id_val = self.wrapper_map.get(idx + 1)
1777 set_next_case =
True 1781 if queue_set_next_case:
1783 id_val[1].name, idx, id_val[1].cpp_class())
1785 wrapper_file += textwrap.dedent(
'''\ 1787 {boost_class_export_guid} 1788 {typedefs_collectors} 1789 {delete_objs} if(anyDeleted) 1791 "WARNING: Wrap modules with variables in the workspace have been reloaded due to\\n" 1792 "calling destructors, call \'clear all\' again if you plan to now recompile a wrap\\n" 1793 "module, so that your recompiled module is used instead of the old one." << endl; 1794 std::cout.rdbuf(outbuf); 1797 {pointer_constructor_fragment}{mex_function}''') \
1798 .format(typedef_instances=typedef_instances,
1799 boost_class_export_guid=boost_class_export_guid,
1800 typedefs_collectors=typedef_collectors,
1801 delete_objs=delete_objs,
1802 rtti_register=rtti_reg_start + rtti_reg_mid + rtti_reg_end,
1803 pointer_constructor_fragment=ptr_ctor_frag,
1806 self.content.append((self.
_wrapper_name() +
'.cpp', wrapper_file))
1810 Wrap the serizalize method of the class. 1812 class_name = inst_class.name
1814 (namespace_name, inst_class,
'string_serialize',
'serialize'))
1816 return textwrap.dedent(
'''\ 1817 function varargout = string_serialize(this, varargin) 1818 % STRING_SERIALIZE usage: string_serialize() : returns string 1819 % Doxygen can be found at https://gtsam.org/doxygen/ 1820 if length(varargin) == 0 1821 varargout{{1}} = {wrapper}({wrapper_id}, this, varargin{{:}}); 1823 error('Arguments do not match any overload of function {class_name}.string_serialize'); 1826 function sobj = saveobj(obj) 1827 % SAVEOBJ Saves the object to a matlab-readable format 1828 sobj = obj.string_serialize(); 1831 wrapper_id=wrapper_id,
1832 class_name=namespace_name +
'.' + class_name)
1839 Wrap the serizalize collector function. 1841 return textwrap.indent(textwrap.dedent(
"""\ 1842 typedef boost::shared_ptr<{full_name}> Shared; 1843 checkArguments("string_serialize",nargout,nargin-1,0); 1844 Shared obj = unwrap_shared_ptr<{full_name}>(in[0], "ptr_{namespace}{class_name}"); 1845 ostringstream out_archive_stream; 1846 boost::archive::text_oarchive out_archive(out_archive_stream); 1847 out_archive << *obj; 1848 out[0] = wrap< string >(out_archive_stream.str()); 1849 """).format(class_name=class_name,
1850 full_name=full_name,
1851 namespace=namespace),
1859 Wrap the deserizalize collector function. 1861 return textwrap.indent(textwrap.dedent(
"""\ 1862 typedef boost::shared_ptr<{full_name}> Shared; 1863 checkArguments("{namespace}{class_name}.string_deserialize",nargout,nargin,1); 1864 string serialized = unwrap< string >(in[0]); 1865 istringstream in_archive_stream(serialized); 1866 boost::archive::text_iarchive in_archive(in_archive_stream); 1867 Shared output(new {full_name}()); 1868 in_archive >> *output; 1869 out[0] = wrap_shared_ptr(output,"{namespace}.{class_name}", false); 1870 """).format(class_name=class_name,
1871 full_name=full_name,
1872 namespace=namespace),
1876 """High level function to wrap the project.""" 1885 Generate files and folders from matlab wrapper content. 1888 cc_content: The content to generate formatted as 1889 (file_name, file_content) or 1890 (folder_name, [(file_name, file_content)]) 1891 path: The path to the files parent folder within the main folder 1893 def _debug(message):
1896 print(message, file=sys.stderr)
1898 for c
in cc_content:
1902 _debug(
"c object: {}".format(c[0][0]))
1903 path_to_folder = osp.join(path, c[0][0])
1905 if not os.path.isdir(path_to_folder):
1907 os.makedirs(path_to_folder, exist_ok=
True)
1911 for sub_content
in c:
1912 _debug(
"sub object: {}".format(sub_content[1][0][0]))
1916 path_to_folder = osp.join(path, c[0])
1918 _debug(
"[generate_content_global]: {}".format(path_to_folder))
1919 if not os.path.isdir(path_to_folder):
1921 os.makedirs(path_to_folder, exist_ok=
True)
1924 for sub_content
in c[1]:
1925 path_to_file = osp.join(path_to_folder, sub_content[0])
1926 _debug(
"[generate_global_method]: {}".format(path_to_file))
1927 with open(path_to_file,
'w')
as f:
1928 f.write(sub_content[1])
1930 path_to_file = osp.join(path, c[0])
1932 _debug(
"[generate_content]: {}".format(path_to_file))
1933 if not os.path.isdir(path_to_file):
1939 with open(path_to_file,
'w')
as f:
void print(const Matrix &A, const string &s, ostream &stream)
def _format_static_method(self, static_method, separator='')
def generate_matlab_wrapper(self)
def wrap_method(self, methods)
def _wrap_method_check_statement(self, args)
def wrap_methods(self, methods, global_funcs=False, global_ns=None)
def _update_wrapper_id(self, collector_function=None, id_diff=0)
def _qualified_name(self, names)
def _add_include(self, include)
def wrap_class_display(self)
def _is_ptr(self, arg_type)
def class_comment(self, instantiated_class)
def _add_class(self, instantiated_class)
def _format_global_method(self, static_method, separator='')
def class_serialize_comment(self, class_name, static_methods)
def wrap_collector_function_shared_return(self, return_type_name, shared_obj, func_id, new_line=True)
def wrap_collector_function_serialize(self, class_name, full_name='', namespace='')
Tuple< Args..., T > append(Tuple< Args... > t, T a)
the deduction function for append_base that automatically generate the IndexRange ...
bool isinstance(handle obj)
def wrap_class_serialize_method(self, namespace_name, inst_class)
def _wrap_variable_arguments(self, args, wrap_datatypes=True)
def _wrap_list_variable_arguments(self, args)
def wrap_global_function(self, function)
def wrap_collector_function_return_types(self, return_type, func_id)
def generate_wrapper(self, namespace)
def wrap_class_methods(self, namespace_name, inst_class, methods, serialize=(False,))
def wrap_class_constructors(self, namespace_name, inst_class, parent_name, ctors, is_virtual)
def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False)
def _is_shared_ptr(self, arg_type)
def _insert_spaces(self, x, y)
def _format_varargout(cls, return_type, return_type_formatted)
def wrap_collector_function_return(self, method)
def _format_class_name(self, instantiated_class, separator='')
def wrap_namespace(self, namespace, parent=())
def wrap_class_properties(self, class_name)
def _wrap_args(self, args)
def _group_class_methods(self, methods)
def wrap_instantiated_class(self, instantiated_class, namespace_name='')
def wrap_class_deconstructor(self, namespace_name, inst_class)
def wrap_static_methods(self, namespace_name, instantiated_class, serialize)
def _group_methods(self, methods)
def wrap_collector_function_upcast_from_void(self, class_name, func_id, cpp_name)
def _clean_class_name(self, instantiated_class)
def _return_count(return_type)
def _debug(self, message)
def _format_return_type(cls, return_type, include_namespace=False, separator="::")
def __init__(self, module, module_name, top_module_namespace='', ignore_classes=())
def generate_content(cc_content, path, verbose=False)
def _format_type_name(cls, type_name, separator='::', include_namespace=True, constructor=False, method=False)
def _is_ref(self, arg_type)
def wrap_collector_function_deserialize(self, class_name, full_name='', namespace='')
def generate_collector_function(self, func_id)
def _format_instance_method(self, instance_method, separator='')