From 382a5feabaf5d2901be73fdd98c4cb8fc96e22f9 Mon Sep 17 00:00:00 2001 From: Matthew Newville Date: Tue, 14 Jan 2025 09:33:40 -0600 Subject: [PATCH 1/2] make all procedure attributes private to curb access to AST nodes, which can be exploited --- asteval/astutils.py | 90 ++++++++++++++++++++++----------------------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/asteval/astutils.py b/asteval/astutils.py index c5f9181..4dab4fc 100644 --- a/asteval/astutils.py +++ b/asteval/astutils.py @@ -506,59 +506,59 @@ def __init__(self, name, interp, doc=None, lineno=0, self.name = name self.__name__ = self.name self.__asteval__ = interp - self.raise_exc = self.__asteval__.raise_exception + self.__raise_exc__ = self.__asteval__.raise_exception self.__doc__ = doc - self.body = body - self.argnames = args - self.kwargs = kwargs - self.vararg = vararg - self.varkws = varkws + self.__body__ = body + self.__argnames__ = args + self.__kwargs__ = kwargs + self.__vararg__ = vararg + self.__varkws__ = varkws self.lineno = lineno self.__ininit__ = False def __setattr__(self, attr, val): if not getattr(self, '__ininit__', True): - self.raise_exc(None, exc=TypeError, + self.__raise_exc__(None, exc=TypeError, msg="procedure is read-only") self.__dict__[attr] = val def __dir__(self): - return ['_getdoc', 'argnames', 'kwargs', 'name', 'vararg', 'varkws'] + return ['__getdoc__', '__argnames__', 'kwargs', 'name', 'vararg', 'varkws'] - def _getdoc(self): + def __getdoc__(self): doc = self.__doc__ if isinstance(doc, ast.Constant): doc = doc.value return doc def __repr__(self): - """TODO: docstring in magic method.""" - sig = self._signature() + """Procedure repr""" + sig = self.__signature__() rep = f"" - doc = self._getdoc() + doc = self.__getdoc__() if doc is not None: rep = f"{rep}\n {doc}" return rep - def _signature(self): - "call signature" + def __signature__(self): + "return the procedure's call signature" sig = "" - if len(self.argnames) > 0: - sig = sig + ', '.join(self.argnames) - if self.vararg is not None: - sig = sig + f"*{self.vararg}" - if len(self.kwargs) > 0: + if len(self.__argnames__) > 0: + sig = sig + ', '.join(self.__argnames__) + if self.__vararg__ is not None: + sig = sig + f"*{self.__vararg__}" + if len(self.__kwargs__) > 0: if len(sig) > 0: sig = f"{sig}, " - _kw = [f"{k}={v}" for k, v in self.kwargs] + _kw = [f"{k}={v}" for k, v in self.__kwargs__] sig = f"{sig}{', '.join(_kw)}" - if self.varkws is not None: - sig = f"{sig}, **{self.varkws}" + if self.__varkws__ is not None: + sig = f"{sig}, **{self.__varkws__}" return f"{self.name}({sig})" def __call__(self, *args, **kwargs): - """TODO: docstring in public method.""" + """call the Procedure""" topsym = self.__asteval__.symtable if self.__asteval__.config.get('nested_symtable', False): sargs = {'_main': topsym} @@ -576,27 +576,27 @@ def __call__(self, *args, **kwargs): args = list(args) nargs = len(args) nkws = len(kwargs) - nargs_expected = len(self.argnames) + nargs_expected = len(self.__argnames__) # check for too few arguments, but the correct keyword given if (nargs < nargs_expected) and nkws > 0: - for name in self.argnames[nargs:]: + for name in self.__argnames__[nargs:]: if name in kwargs: args.append(kwargs.pop(name)) nargs = len(args) - nargs_expected = len(self.argnames) + nargs_expected = len(self.__argnames__) nkws = len(kwargs) if nargs < nargs_expected: msg = f"{self.name}() takes at least" msg = f"{msg} {nargs_expected} arguments, got {nargs}" - self.raise_exc(None, exc=TypeError, msg=msg) + self.__raise_exc__(None, exc=TypeError, msg=msg) # check for multiple values for named argument - if len(self.argnames) > 0 and kwargs is not None: + if len(self.__argnames__) > 0 and kwargs is not None: msg = "multiple values for keyword argument" - for targ in self.argnames: + for targ in self.__argnames__: if targ in kwargs: msg = f"{msg} '{targ}' in Procedure {self.name}" - self.raise_exc(None, exc=TypeError, msg=msg, lineno=self.lineno) + self.__raise_exc__(None, exc=TypeError, msg=msg, lineno=self.lineno) # check more args given than expected, varargs not given if nargs != nargs_expected: @@ -604,44 +604,44 @@ def __call__(self, *args, **kwargs): if nargs < nargs_expected: msg = f"not enough arguments for Procedure {self.name}()" msg = f"{msg} (expected {nargs_expected}, got {nargs}" - self.raise_exc(None, exc=TypeError, msg=msg) + self.__raise_exc__(None, exc=TypeError, msg=msg) - if nargs > nargs_expected and self.vararg is None: - if nargs - nargs_expected > len(self.kwargs): + if nargs > nargs_expected and self.__vararg__ is None: + if nargs - nargs_expected > len(self.__kwargs__): msg = f"too many arguments for {self.name}() expected at most" - msg = f"{msg} {len(self.kwargs)+nargs_expected}, got {nargs}" - self.raise_exc(None, exc=TypeError, msg=msg) + msg = f"{msg} {len(self.__kwargs__)+nargs_expected}, got {nargs}" + self.__raise_exc__(None, exc=TypeError, msg=msg) for i, xarg in enumerate(args[nargs_expected:]): - kw_name = self.kwargs[i][0] + kw_name = self.__kwargs__[i][0] if kw_name not in kwargs: kwargs[kw_name] = xarg - for argname in self.argnames: + for argname in self.__argnames__: symlocals[argname] = args.pop(0) try: - if self.vararg is not None: - symlocals[self.vararg] = tuple(args) + if self.__vararg__ is not None: + symlocals[self.__vararg__] = tuple(args) - for key, val in self.kwargs: + for key, val in self.__kwargs__: if key in kwargs: val = kwargs.pop(key) symlocals[key] = val - if self.varkws is not None: - symlocals[self.varkws] = kwargs + if self.__varkws__ is not None: + symlocals[self.__varkws__] = kwargs elif len(kwargs) > 0: msg = f"extra keyword arguments for Procedure {self.name}: " msg = msg + ','.join(list(kwargs.keys())) - self.raise_exc(None, msg=msg, exc=TypeError, + self.__raise_exc__(None, msg=msg, exc=TypeError, lineno=self.lineno) except (ValueError, LookupError, TypeError, NameError, AttributeError): msg = f"incorrect arguments for Procedure {self.name}" - self.raise_exc(None, msg=msg, lineno=self.lineno) + self.__raise_exc__(None, msg=msg, lineno=self.lineno) if self.__asteval__.config.get('nested_symtable', False): save_symtable = self.__asteval__.symtable @@ -655,7 +655,7 @@ def __call__(self, *args, **kwargs): retval = None # evaluate script of function - for node in self.body: + for node in self.__body__: self.__asteval__.run(node, expr='<>', lineno=self.lineno) if len(self.__asteval__.error) > 0: break From 3ba2d51e4287073ef615b3e1828302f7704ca897 Mon Sep 17 00:00:00 2001 From: Matthew Newville Date: Tue, 14 Jan 2025 09:34:01 -0600 Subject: [PATCH 2/2] add test of accessing procedure attributes --- tests/test_asteval.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/test_asteval.py b/tests/test_asteval.py index 75c3517..3519217 100644 --- a/tests/test_asteval.py +++ b/tests/test_asteval.py @@ -1568,5 +1568,25 @@ def test_delete_slice(nested): assert interp("g.dlist") == [1, 3, 5, 7, 15, 17, 19, 21] +@pytest.mark.parametrize("nested", [False, True]) +def test_unsafe_procedure_access(nested): + """ + addressing https://github.com/lmfit/asteval/security/advisories/GHSA-vp47-9734-prjw + """ + interp = make_interpreter(nested_symtable=nested) + interp(textwrap.dedent(""" + def my_func(x, y): + return x+y + + my_func.__body__[0] = 'something else' + + """), raise_errors=False) + + error = interp.error[0] + etype, fullmsg = error.get_error() + assert 'no safe attribute' in error.msg + assert etype == 'AttributeError' + + if __name__ == '__main__': pytest.main(['-v', '-x', '-s'])