Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrapper Update #1084

Merged
merged 2 commits into from
Feb 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions wrap/gtwrap/matlab_wrapper/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,30 @@ def _has_serialization(self, cls):
return True
return False

def can_be_pointer(self, arg_type: parser.Type):
"""
Determine if the `arg_type` can have a pointer to it.

E.g. `Pose3` can have `Pose3*` but
`Matrix` should not have `Matrix*`.
"""
return (arg_type.typename.name not in self.not_ptr_type
and arg_type.typename.name not in self.ignore_namespace
and arg_type.typename.name != 'string')

def is_shared_ptr(self, arg_type: parser.Type):
"""
Determine if the `interface_parser.Type` should be treated as a
shared pointer in the wrapper.
"""
return arg_type.is_shared_ptr or (
arg_type.typename.name not in self.not_ptr_type
and arg_type.typename.name not in self.ignore_namespace
and arg_type.typename.name != 'string')
return arg_type.is_shared_ptr

def is_ptr(self, arg_type: parser.Type):
"""
Determine if the `interface_parser.Type` should be treated as a
raw pointer in the wrapper.
"""
return arg_type.is_ptr or (
arg_type.typename.name not in self.not_ptr_type
and arg_type.typename.name not in self.ignore_namespace
and arg_type.typename.name != 'string')
return arg_type.is_ptr

def is_ref(self, arg_type: parser.Type):
"""
Expand Down
96 changes: 53 additions & 43 deletions wrap/gtwrap/matlab_wrapper/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,13 @@ def _expand_default_arguments(method, save_backup=True):
"""
def args_copy(args):
return ArgumentList([copy.copy(arg) for arg in args.list()])

def method_copy(method):
method2 = copy.copy(method)
method2.args = args_copy(method.args)
method2.args.backup = method.args.backup
return method2

if save_backup:
method.args.backup = args_copy(method.args)
method = method_copy(method)
Expand All @@ -162,7 +164,8 @@ def method_copy(method):
method.args.list().remove(arg)
return [
methodWithArg,
*MatlabWrapper._expand_default_arguments(method, save_backup=False)
*MatlabWrapper._expand_default_arguments(method,
save_backup=False)
]
break
assert all(arg.default is None for arg in method.args.list()), \
Expand All @@ -180,9 +183,12 @@ def _group_methods(self, methods):

if method_index is None:
method_map[method.name] = len(method_out)
method_out.append(MatlabWrapper._expand_default_arguments(method))
method_out.append(
MatlabWrapper._expand_default_arguments(method))
else:
method_out[method_index] += MatlabWrapper._expand_default_arguments(method)
method_out[
method_index] += MatlabWrapper._expand_default_arguments(
method)

return method_out

Expand Down Expand Up @@ -337,43 +343,42 @@ def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False):
body_args = ''

for arg in args.list():
ctype_camel = self._format_type_name(arg.ctype.typename,
separator='')
ctype_sep = self._format_type_name(arg.ctype.typename)

if self.is_ref(arg.ctype): # and not constructor:
ctype_camel = self._format_type_name(arg.ctype.typename,
separator='')
body_args += textwrap.indent(textwrap.dedent('''\
{ctype}& {name} = *unwrap_shared_ptr< {ctype} >(in[{id}], "ptr_{ctype_camel}");
'''.format(ctype=self._format_type_name(arg.ctype.typename),
ctype_camel=ctype_camel,
name=arg.name,
id=arg_id)),
prefix=' ')

elif (self.is_shared_ptr(arg.ctype) or self.is_ptr(arg.ctype)) and \
arg_type = "{ctype}&".format(ctype=ctype_sep)
unwrap = '*unwrap_shared_ptr< {ctype} >(in[{id}], "ptr_{ctype_camel}");'.format(
ctype=ctype_sep, ctype_camel=ctype_camel, id=arg_id)

elif self.is_ptr(arg.ctype) and \
arg.ctype.typename.name not in self.ignore_namespace:
if arg.ctype.is_shared_ptr:
call_type = arg.ctype.is_shared_ptr
else:
call_type = arg.ctype.is_ptr

body_args += textwrap.indent(textwrap.dedent('''\
{std_boost}::shared_ptr<{ctype_sep}> {name} = unwrap_shared_ptr< {ctype_sep} >(in[{id}], "ptr_{ctype}");
'''.format(std_boost='boost' if constructor else 'boost',
ctype_sep=self._format_type_name(
arg.ctype.typename),
ctype=self._format_type_name(arg.ctype.typename,
separator=''),
name=arg.name,
id=arg_id)),
prefix=' ')
arg_type = "{ctype_sep}*".format(ctype_sep=ctype_sep)
unwrap = 'unwrap_ptr< {ctype_sep} >(in[{id}], "ptr_{ctype}");'.format(
ctype_sep=ctype_sep, ctype=ctype_camel, id=arg_id)

else:
body_args += textwrap.indent(textwrap.dedent('''\
{ctype} {name} = unwrap< {ctype} >(in[{id}]);
'''.format(ctype=arg.ctype.typename.name,
name=arg.name,
id=arg_id)),
prefix=' ')
elif (self.is_shared_ptr(arg.ctype) or self.can_be_pointer(arg.ctype)) and \
arg.ctype.typename.name not in self.ignore_namespace:
call_type = arg.ctype.is_shared_ptr

arg_type = "{std_boost}::shared_ptr<{ctype_sep}>".format(
std_boost='boost' if constructor else 'boost',
ctype_sep=ctype_sep)
unwrap = 'unwrap_shared_ptr< {ctype_sep} >(in[{id}], "ptr_{ctype}");'.format(
ctype_sep=ctype_sep, ctype=ctype_camel, id=arg_id)

else:
arg_type = "{ctype}".format(ctype=arg.ctype.typename.name)
unwrap = 'unwrap< {ctype} >(in[{id}]);'.format(
ctype=arg.ctype.typename.name, id=arg_id)

body_args += textwrap.indent(textwrap.dedent('''\
{arg_type} {name} = {unwrap}
'''.format(arg_type=arg_type, name=arg.name,
unwrap=unwrap)),
prefix=' ')
arg_id += 1

params = ''
Expand All @@ -383,12 +388,14 @@ def _wrapper_unwrap_arguments(self, args, arg_id=0, constructor=False):
if params != '':
params += ','

if (arg.default is not None) and (arg.name not in explicit_arg_names):
if (arg.default is not None) and (arg.name
not in explicit_arg_names):
params += arg.default
continue

if (not self.is_ref(arg.ctype)) and (self.is_shared_ptr(arg.ctype)) and (self.is_ptr(
arg.ctype)) and (arg.ctype.typename.name not in self.ignore_namespace):
if not self.is_ref(arg.ctype) and (self.is_shared_ptr(arg.ctype) or \
self.is_ptr(arg.ctype) or self.can_be_pointer(arg.ctype))and \
arg.ctype.typename.name not in self.ignore_namespace:
if arg.ctype.is_shared_ptr:
call_type = arg.ctype.is_shared_ptr
else:
Expand Down Expand Up @@ -601,7 +608,8 @@ def wrap_class_constructors(self, namespace_name, inst_class, parent_name,
if not isinstance(ctors, Iterable):
ctors = [ctors]

ctors = sum((MatlabWrapper._expand_default_arguments(ctor) for ctor in ctors), [])
ctors = sum((MatlabWrapper._expand_default_arguments(ctor)
for ctor in ctors), [])

methods_wrap = textwrap.indent(textwrap.dedent("""\
methods
Expand Down Expand Up @@ -885,10 +893,10 @@ def wrap_static_methods(self, namespace_name, instantiated_class,
wrapper=self._wrapper_name(),
id=self._update_wrapper_id(
(namespace_name, instantiated_class,
static_overload.name, static_overload)),
static_overload.name, static_overload)),
class_name=instantiated_class.name,
end_statement=end_statement),
prefix=' ')
prefix=' ')

# If the arguments don't match any of the checks above,
# throw an error with the class and method name.
Expand Down Expand Up @@ -1079,7 +1087,8 @@ def wrap_collector_function_return_types(self, return_type, func_id):
pair_value = 'first' if func_id == 0 else 'second'
new_line = '\n' if func_id == 0 else ''

if self.is_shared_ptr(return_type) or self.is_ptr(return_type):
if self.is_shared_ptr(return_type) or self.is_ptr(return_type) or \
self.can_be_pointer(return_type):
shared_obj = 'pairResult.' + pair_value

if not (return_type.is_shared_ptr or return_type.is_ptr):
Expand Down Expand Up @@ -1145,7 +1154,8 @@ def wrap_collector_function_return(self, method):

if return_1_name != 'void':
if return_count == 1:
if self.is_shared_ptr(return_1) or self.is_ptr(return_1):
if self.is_shared_ptr(return_1) or self.is_ptr(return_1) or \
self.can_be_pointer(return_1):
sep_method_name = partial(self._format_type_name,
return_1.typename,
include_namespace=True)
Expand Down
8 changes: 8 additions & 0 deletions wrap/matlab.h
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,14 @@ boost::shared_ptr<Class> unwrap_shared_ptr(const mxArray* obj, const string& pro
return *spp;
}

template <typename Class>
Class* unwrap_ptr(const mxArray* obj, const string& propertyName) {

mxArray* mxh = mxGetProperty(obj,0, propertyName.c_str());
Class* x = reinterpret_cast<Class*> (mxGetData(mxh));
return x;
}

//// throw an error if unwrap_shared_ptr is attempted for an Eigen Vector
//template <>
//Vector unwrap_shared_ptr<Vector>(const mxArray* obj, const string& propertyName) {
Expand Down
6 changes: 3 additions & 3 deletions wrap/tests/expected/matlab/ForwardKinematicsFactor.m
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
if nargin == 2
my_ptr = varargin{2};
else
my_ptr = inheritance_wrapper(36, varargin{2});
my_ptr = inheritance_wrapper(52, varargin{2});
end
base_ptr = inheritance_wrapper(35, my_ptr);
base_ptr = inheritance_wrapper(51, my_ptr);
else
error('Arguments do not match any overload of ForwardKinematicsFactor constructor');
end
Expand All @@ -22,7 +22,7 @@
end

function delete(obj)
inheritance_wrapper(37, obj.ptr_ForwardKinematicsFactor);
inheritance_wrapper(53, obj.ptr_ForwardKinematicsFactor);
end

function display(obj), obj.print(''); end
Expand Down
2 changes: 1 addition & 1 deletion wrap/tests/expected/matlab/functions_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ void load2D_2(int nargout, mxArray *out[], int nargin, const mxArray *in[])
{
checkArguments("load2D",nargout,nargin,2);
string filename = unwrap< string >(in[0]);
boost::shared_ptr<gtsam::noiseModel::Diagonal> model = unwrap_shared_ptr< gtsam::noiseModel::Diagonal >(in[1], "ptr_gtsamnoiseModelDiagonal");
gtsam::noiseModel::Diagonal* model = unwrap_ptr< gtsam::noiseModel::Diagonal >(in[1], "ptr_gtsamnoiseModelDiagonal");
auto pairResult = load2D(filename,model);
out[0] = wrap_shared_ptr(pairResult.first,"gtsam.NonlinearFactorGraph", false);
out[1] = wrap_shared_ptr(pairResult.second,"gtsam.Values", false);
Expand Down
4 changes: 2 additions & 2 deletions wrap/tests/expected/matlab/geometry_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ void gtsamPoint2_argChar_7(int nargout, mxArray *out[], int nargin, const mxArra
{
checkArguments("argChar",nargout,nargin-1,1);
auto obj = unwrap_shared_ptr<gtsam::Point2>(in[0], "ptr_gtsamPoint2");
boost::shared_ptr<char> a = unwrap_shared_ptr< char >(in[1], "ptr_char");
char* a = unwrap_ptr< char >(in[1], "ptr_char");
obj->argChar(a);
}

Expand All @@ -175,7 +175,7 @@ void gtsamPoint2_argChar_10(int nargout, mxArray *out[], int nargin, const mxArr
{
checkArguments("argChar",nargout,nargin-1,1);
auto obj = unwrap_shared_ptr<gtsam::Point2>(in[0], "ptr_gtsamPoint2");
boost::shared_ptr<char> a = unwrap_shared_ptr< char >(in[1], "ptr_char");
char* a = unwrap_ptr< char >(in[1], "ptr_char");
obj->argChar(a);
}

Expand Down
Loading