Skip to content

Commit

Permalink
Merge pull request #135 from lmfit/fix_unsafe_procedures
Browse files Browse the repository at this point in the history
Fix unsafe procedures
  • Loading branch information
newville authored Jan 14, 2025
2 parents 8d7326d + 3ba2d51 commit aeea1b7
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 45 deletions.
90 changes: 45 additions & 45 deletions asteval/astutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,59 +506,59 @@ def __init__(self, name, interp, doc=None, lineno=0,
self.name = name
self.__name__ = self.name
self.__asteval__ = interp
self.raise_exc = self.__asteval__.raise_exception
self.__raise_exc__ = self.__asteval__.raise_exception
self.__doc__ = doc
self.body = body
self.argnames = args
self.kwargs = kwargs
self.vararg = vararg
self.varkws = varkws
self.__body__ = body
self.__argnames__ = args
self.__kwargs__ = kwargs
self.__vararg__ = vararg
self.__varkws__ = varkws
self.lineno = lineno
self.__ininit__ = False

def __setattr__(self, attr, val):
if not getattr(self, '__ininit__', True):
self.raise_exc(None, exc=TypeError,
self.__raise_exc__(None, exc=TypeError,
msg="procedure is read-only")
self.__dict__[attr] = val

def __dir__(self):
return ['_getdoc', 'argnames', 'kwargs', 'name', 'vararg', 'varkws']
return ['__getdoc__', '__argnames__', 'kwargs', 'name', 'vararg', 'varkws']

def _getdoc(self):
def __getdoc__(self):
doc = self.__doc__
if isinstance(doc, ast.Constant):
doc = doc.value
return doc

def __repr__(self):
"""TODO: docstring in magic method."""
sig = self._signature()
"""Procedure repr"""
sig = self.__signature__()
rep = f"<Procedure {sig}>"
doc = self._getdoc()
doc = self.__getdoc__()
if doc is not None:
rep = f"{rep}\n {doc}"
return rep

def _signature(self):
"call signature"
def __signature__(self):
"return the procedure's call signature"
sig = ""
if len(self.argnames) > 0:
sig = sig + ', '.join(self.argnames)
if self.vararg is not None:
sig = sig + f"*{self.vararg}"
if len(self.kwargs) > 0:
if len(self.__argnames__) > 0:
sig = sig + ', '.join(self.__argnames__)
if self.__vararg__ is not None:
sig = sig + f"*{self.__vararg__}"
if len(self.__kwargs__) > 0:
if len(sig) > 0:
sig = f"{sig}, "
_kw = [f"{k}={v}" for k, v in self.kwargs]
_kw = [f"{k}={v}" for k, v in self.__kwargs__]
sig = f"{sig}{', '.join(_kw)}"

if self.varkws is not None:
sig = f"{sig}, **{self.varkws}"
if self.__varkws__ is not None:
sig = f"{sig}, **{self.__varkws__}"
return f"{self.name}({sig})"

def __call__(self, *args, **kwargs):
"""TODO: docstring in public method."""
"""call the Procedure"""
topsym = self.__asteval__.symtable
if self.__asteval__.config.get('nested_symtable', False):
sargs = {'_main': topsym}
Expand All @@ -576,72 +576,72 @@ def __call__(self, *args, **kwargs):
args = list(args)
nargs = len(args)
nkws = len(kwargs)
nargs_expected = len(self.argnames)
nargs_expected = len(self.__argnames__)

# check for too few arguments, but the correct keyword given
if (nargs < nargs_expected) and nkws > 0:
for name in self.argnames[nargs:]:
for name in self.__argnames__[nargs:]:
if name in kwargs:
args.append(kwargs.pop(name))
nargs = len(args)
nargs_expected = len(self.argnames)
nargs_expected = len(self.__argnames__)
nkws = len(kwargs)
if nargs < nargs_expected:
msg = f"{self.name}() takes at least"
msg = f"{msg} {nargs_expected} arguments, got {nargs}"
self.raise_exc(None, exc=TypeError, msg=msg)
self.__raise_exc__(None, exc=TypeError, msg=msg)
# check for multiple values for named argument
if len(self.argnames) > 0 and kwargs is not None:
if len(self.__argnames__) > 0 and kwargs is not None:
msg = "multiple values for keyword argument"
for targ in self.argnames:
for targ in self.__argnames__:
if targ in kwargs:
msg = f"{msg} '{targ}' in Procedure {self.name}"
self.raise_exc(None, exc=TypeError, msg=msg, lineno=self.lineno)
self.__raise_exc__(None, exc=TypeError, msg=msg, lineno=self.lineno)

# check more args given than expected, varargs not given
if nargs != nargs_expected:
msg = None
if nargs < nargs_expected:
msg = f"not enough arguments for Procedure {self.name}()"
msg = f"{msg} (expected {nargs_expected}, got {nargs}"
self.raise_exc(None, exc=TypeError, msg=msg)
self.__raise_exc__(None, exc=TypeError, msg=msg)

if nargs > nargs_expected and self.vararg is None:
if nargs - nargs_expected > len(self.kwargs):
if nargs > nargs_expected and self.__vararg__ is None:
if nargs - nargs_expected > len(self.__kwargs__):
msg = f"too many arguments for {self.name}() expected at most"
msg = f"{msg} {len(self.kwargs)+nargs_expected}, got {nargs}"
self.raise_exc(None, exc=TypeError, msg=msg)
msg = f"{msg} {len(self.__kwargs__)+nargs_expected}, got {nargs}"
self.__raise_exc__(None, exc=TypeError, msg=msg)

for i, xarg in enumerate(args[nargs_expected:]):
kw_name = self.kwargs[i][0]
kw_name = self.__kwargs__[i][0]
if kw_name not in kwargs:
kwargs[kw_name] = xarg

for argname in self.argnames:
for argname in self.__argnames__:
symlocals[argname] = args.pop(0)

try:
if self.vararg is not None:
symlocals[self.vararg] = tuple(args)
if self.__vararg__ is not None:
symlocals[self.__vararg__] = tuple(args)

for key, val in self.kwargs:
for key, val in self.__kwargs__:
if key in kwargs:
val = kwargs.pop(key)
symlocals[key] = val

if self.varkws is not None:
symlocals[self.varkws] = kwargs
if self.__varkws__ is not None:
symlocals[self.__varkws__] = kwargs

elif len(kwargs) > 0:
msg = f"extra keyword arguments for Procedure {self.name}: "
msg = msg + ','.join(list(kwargs.keys()))
self.raise_exc(None, msg=msg, exc=TypeError,
self.__raise_exc__(None, msg=msg, exc=TypeError,
lineno=self.lineno)

except (ValueError, LookupError, TypeError,
NameError, AttributeError):
msg = f"incorrect arguments for Procedure {self.name}"
self.raise_exc(None, msg=msg, lineno=self.lineno)
self.__raise_exc__(None, msg=msg, lineno=self.lineno)

if self.__asteval__.config.get('nested_symtable', False):
save_symtable = self.__asteval__.symtable
Expand All @@ -655,7 +655,7 @@ def __call__(self, *args, **kwargs):
retval = None

# evaluate script of function
for node in self.body:
for node in self.__body__:
self.__asteval__.run(node, expr='<>', lineno=self.lineno)
if len(self.__asteval__.error) > 0:
break
Expand Down
20 changes: 20 additions & 0 deletions tests/test_asteval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,5 +1568,25 @@ def test_delete_slice(nested):
assert interp("g.dlist") == [1, 3, 5, 7, 15, 17, 19, 21]


@pytest.mark.parametrize("nested", [False, True])
def test_unsafe_procedure_access(nested):
"""
addressing https://github.com/lmfit/asteval/security/advisories/GHSA-vp47-9734-prjw
"""
interp = make_interpreter(nested_symtable=nested)
interp(textwrap.dedent("""
def my_func(x, y):
return x+y
my_func.__body__[0] = 'something else'
"""), raise_errors=False)

error = interp.error[0]
etype, fullmsg = error.get_error()
assert 'no safe attribute' in error.msg
assert etype == 'AttributeError'


if __name__ == '__main__':
pytest.main(['-v', '-x', '-s'])

0 comments on commit aeea1b7

Please sign in to comment.