diff --git a/salt/loader.py b/salt/loader.py index 939edca45dd1..2b949cafc478 100644 --- a/salt/loader.py +++ b/salt/loader.py @@ -35,6 +35,7 @@ import salt.utils.odict import salt.utils.platform import salt.utils.stringutils +import salt.utils.thread_local_proxy as thread_local_proxy import salt.utils.versions from salt.exceptions import LoaderError @@ -1089,6 +1090,75 @@ def _mod_type(module_path): return "ext" +def _inject_into_mod(mod, name, value, force_lock=False): + """ + Inject a variable into a module. This is used to inject "globals" like + ``__salt__``, ``__pillar``, or ``grains``. + + Instead of injecting the value directly, a ``ThreadLocalProxy`` is created. + If such a proxy is already present under the specified name, it is updated + with the new value. This update only affects the current thread, so that + the same name can refer to different values depending on the thread of + execution. + + This is important for data that is not truly global. For example, pillar + data might be dynamically overriden through function parameters and thus + the actual values available in pillar might depend on the thread that is + calling a module. + + mod: + module object into which the value is going to be injected. + + name: + name of the variable that is injected into the module. + + value: + value that is injected into the variable. The value is not injected + directly, but instead set as the new reference of the proxy that has + been created for the variable. + + force_lock: + whether the lock should be acquired before checking whether a proxy + object for the specified name has already been injected into the + module. If ``False`` (the default), this function checks for the + module's variable without acquiring the lock and only acquires the lock + if a new proxy has to be created and injected. + """ + old_value = getattr(mod, name, None) + # We use a double-checked locking scheme in order to avoid taking the lock + # when a proxy object has already been injected. + # In most programming languages, double-checked locking is considered + # unsafe when used without explicit memory barriers because one might read + # an uninitialized value. In CPython it is safe due to the global + # interpreter lock (GIL). In Python implementations that do not have the + # GIL, it could be unsafe, but at least Jython also guarantees that (for + # Python objects) memory is not corrupted when writing and reading without + # explicit synchronization + # (http://www.jython.org/jythonbook/en/1.0/Concurrency.html). + # Please note that in order to make this code safe in a runtime environment + # that does not make this guarantees, it is not sufficient. The + # ThreadLocalProxy must also be created with fallback_to_shared set to + # False or a lock must be added to the ThreadLocalProxy. + if force_lock: + with _inject_into_mod.lock: + if isinstance(old_value, thread_local_proxy.ThreadLocalProxy): + thread_local_proxy.ThreadLocalProxy.set_reference(old_value, value) + else: + setattr(mod, name, thread_local_proxy.ThreadLocalProxy(value, True)) + else: + if isinstance(old_value, thread_local_proxy.ThreadLocalProxy): + thread_local_proxy.ThreadLocalProxy.set_reference(old_value, value) + else: + _inject_into_mod(mod, name, value, True) + + +# Lock used when injecting globals. This is needed to avoid a race condition +# when two threads try to load the same module concurrently. This must be +# outside the loader because there might be more than one loader for the same +# namespace. +_inject_into_mod.lock = threading.RLock() + + # TODO: move somewhere else? class FilterDictWrapper(MutableMapping): """ @@ -1185,7 +1255,11 @@ def __init__( for k, v in six.iteritems(self.pack): if v is None: # if the value of a pack is None, lets make an empty dict - self.context_dict.setdefault(k, {}) + value = thread_local_proxy.ThreadLocalProxy.unproxy( + self.context_dict.get(k, {}) + ) + + self.context_dict[k] = value self.pack[k] = salt.utils.context.NamespacedDictWrapper( self.context_dict, k ) @@ -1468,13 +1542,19 @@ def __prep_mod_opts(self, opts): Strip out of the opts any logger instance """ if "__grains__" not in self.pack: - self.context_dict["grains"] = opts.get("grains", {}) + _grains = thread_local_proxy.ThreadLocalProxy.unproxy( + opts.get("grains", {}) + ) + + self.context_dict["grains"] = _grains self.pack["__grains__"] = salt.utils.context.NamespacedDictWrapper( self.context_dict, "grains" ) if "__pillar__" not in self.pack: - self.context_dict["pillar"] = opts.get("pillar", {}) + pillar = thread_local_proxy.ThreadLocalProxy.unproxy(opts.get("pillar", {})) + + self.context_dict["pillar"] = pillar self.pack["__pillar__"] = salt.utils.context.NamespacedDictWrapper( self.context_dict, "pillar" ) @@ -1670,7 +1750,7 @@ def _load_module(self, name): # pack whatever other globals we were asked to for p_name, p_value in six.iteritems(self.pack): - setattr(mod, p_name, p_value) + _inject_into_mod(mod, p_name, p_value) module_name = mod.__name__.rsplit(".", 1)[-1] diff --git a/salt/utils/json.py b/salt/utils/json.py index 18cbac8a759e..6a596c2f7548 100644 --- a/salt/utils/json.py +++ b/salt/utils/json.py @@ -13,6 +13,7 @@ # Import Salt libs import salt.utils.data import salt.utils.stringutils +import salt.utils.thread_local_proxy as thread_local_proxy # Import 3rd-party libs from salt.ext import six @@ -119,11 +120,18 @@ def dump(obj, fp, **kwargs): using the _json_module argument) """ json_module = kwargs.pop("_json_module", json) + orig_enc_func = kwargs.pop("default", lambda x: x) + + def _enc_func(_obj): + return orig_enc_func(thread_local_proxy.ThreadLocalProxy.unproxy(_obj)) + if "ensure_ascii" not in kwargs: kwargs["ensure_ascii"] = False if six.PY2: obj = salt.utils.data.encode(obj) - return json_module.dump(obj, fp, **kwargs) # future lint: blacklisted-function + return json_module.dump( + obj, fp, default=_enc_func, **kwargs + ) # future lint: blacklisted-function def dumps(obj, **kwargs): @@ -142,8 +150,15 @@ def dumps(obj, **kwargs): using the _json_module argument) """ json_module = kwargs.pop("_json_module", json) + orig_enc_func = kwargs.pop("default", lambda x: x) + + def _enc_func(_obj): + return orig_enc_func(thread_local_proxy.ThreadLocalProxy.unproxy(_obj)) + if "ensure_ascii" not in kwargs: kwargs["ensure_ascii"] = False if six.PY2: obj = salt.utils.data.encode(obj) - return json_module.dumps(obj, **kwargs) # future lint: blacklisted-function + return json_module.dumps( + obj, default=_enc_func, **kwargs + ) # future lint: blacklisted-function diff --git a/salt/utils/msgpack.py b/salt/utils/msgpack.py index a93763ffb7fa..f6a2c4250c64 100644 --- a/salt/utils/msgpack.py +++ b/salt/utils/msgpack.py @@ -8,6 +8,9 @@ import logging +# Import Salt libs +import salt.utils.thread_local_proxy as thread_local_proxy + log = logging.getLogger(__name__) # Import 3rd party libs @@ -94,8 +97,13 @@ def pack(o, stream, **kwargs): By default, this function uses the msgpack module and falls back to msgpack_pure, if the msgpack is not available. """ + orig_enc_func = kwargs.pop("default", lambda x: x) + + def _enc_func(obj): + return orig_enc_func(thread_local_proxy.ThreadLocalProxy.unproxy(obj)) + # Writes to a stream, there is no return - msgpack.pack(o, stream, **_sanitize_msgpack_kwargs(kwargs)) + msgpack.pack(o, stream, default=_enc_func, **_sanitize_msgpack_kwargs(kwargs)) def packb(o, **kwargs): @@ -108,7 +116,12 @@ def packb(o, **kwargs): By default, this function uses the msgpack module and falls back to msgpack_pure, if the msgpack is not available. """ - return msgpack.packb(o, **_sanitize_msgpack_kwargs(kwargs)) + orig_enc_func = kwargs.pop("default", lambda x: x) + + def _enc_func(obj): + return orig_enc_func(thread_local_proxy.ThreadLocalProxy.unproxy(obj)) + + return msgpack.packb(o, default=_enc_func, **_sanitize_msgpack_kwargs(kwargs)) def unpack(stream, **kwargs): diff --git a/salt/utils/thread_local_proxy.py b/salt/utils/thread_local_proxy.py new file mode 100644 index 000000000000..df2db08cc91e --- /dev/null +++ b/salt/utils/thread_local_proxy.py @@ -0,0 +1,601 @@ +# -*- coding: utf-8 -*- +""" +Proxy object that can reference different values depending on the current +thread of execution. + +..versionadded:: 2018.3.4 + +""" + +# Import python libs +from __future__ import absolute_import + +import threading + +# Import 3rd-party libs +from salt.ext import six + + +class ThreadLocalProxy(object): + """ + Proxy that delegates all operations to its referenced object. The referenced + object is hold through a thread-local variable, so that this proxy may refer + to different objects in different threads of execution. + + For all practical purposes (operators, attributes, `isinstance`), the proxy + acts like the referenced object. Thus, code receiving the proxy object + instead of the reference object typically does not have to be changed. The + only exception is code that explicitly uses the ``type()`` function for + checking the proxy's type. While `isinstance(proxy, ...)` will yield the + expected results (based on the actual type of the referenced object), using + something like ``issubclass(type(proxy), ...)`` will not work, because + these tests will be made on the type of the proxy object instead of the + type of the referenced object. In order to avoid this, such code must be + changed to use ``issubclass(type(ThreadLocalProxy.unproxy(proxy)), ...)``. + + If an instance of this class is created with the ``fallback_to_shared`` flag + set and a thread uses the instance without setting the reference explicitly, + the reference for this thread is initialized with the latest reference set + by any thread. + + This class has primarily been designed for use by the Salt loader, but it + might also be useful in other places. + """ + + __slots__ = ["_thread_local", "_last_reference", "_fallback_to_shared"] + + @staticmethod + def get_reference(proxy): + """ + Return the object that is referenced by the specified proxy. + + If the proxy has not been bound to a reference for the current thread, + the behavior depends on th the ``fallback_to_shared`` flag that has + been specified when creating the proxy. If the flag has been set, the + last reference that has been set by any thread is returned (and + silently set as the reference for the current thread). If the flag has + not been set, an ``AttributeError`` is raised. + + If the object references by this proxy is itself a proxy, that proxy is + returned. Use ``unproxy`` for unwrapping the referenced object until it + is not a proxy. + + proxy: + proxy object for which the reference shall be returned. If the + specified object is not an instance of `ThreadLocalProxy`, the + behavior is unspecified. Typically, an ``AttributeError`` is + going to be raised. + """ + thread_local = object.__getattribute__(proxy, "_thread_local") + try: + return thread_local.reference + except AttributeError: + fallback_to_shared = object.__getattribute__(proxy, "_fallback_to_shared") + if fallback_to_shared: + # If the reference has never been set in the current thread of + # execution, we use the reference that has been last set by any + # thread. + reference = object.__getattribute__(proxy, "_last_reference") + # We save the reference in the thread local so that future + # calls to get_reference will have consistent results. + ThreadLocalProxy.set_reference(proxy, reference) + return reference + else: + # We could simply return None, but this would make it hard to + # debug situations where the reference has not been set (the + # problem might go unnoticed until some code tries to do + # something with the returned object and it might not be easy to + # find out from where the None value originates). + # For this reason, we raise an AttributeError with an error + # message explaining the problem. + raise AttributeError( + "The proxy object has not been bound to a reference in this thread of execution." + ) + + @staticmethod + def set_reference(proxy, new_reference): + """ + Set the reference to be used the current thread of execution. + + After calling this function, the specified proxy will act like it was + the referenced object. + + proxy: + proxy object for which the reference shall be set. If the specified + object is not an instance of `ThreadLocalProxy`, the behavior is + unspecified. Typically, an ``AttributeError`` is going to be + raised. + + new_reference: + reference the proxy should point to for the current thread after + calling this function. + """ + # If the new reference is itself a proxy, we have to ensure that it does + # not refer to this proxy. If it does, we simply return because updating + # the reference would result in an inifite loop when trying to use the + # proxy. + possible_proxy = new_reference + while isinstance(possible_proxy, ThreadLocalProxy): + if possible_proxy is proxy: + return + possible_proxy = ThreadLocalProxy.get_reference(possible_proxy) + thread_local = object.__getattribute__(proxy, "_thread_local") + thread_local.reference = new_reference + object.__setattr__(proxy, "_last_reference", new_reference) + + @staticmethod + def unset_reference(proxy): + """ + Unset the reference to be used by the current thread of execution. + + After calling this function, the specified proxy will act like the + reference had never been set for the current thread. + + proxy: + proxy object for which the reference shall be unset. If the + specified object is not an instance of `ThreadLocalProxy`, the + behavior is unspecified. Typically, an ``AttributeError`` is going + to be raised. + """ + thread_local = object.__getattribute__(proxy, "_thread_local") + del thread_local.reference + + @staticmethod + def unproxy(possible_proxy): + """ + Unwrap and return the object referenced by a proxy. + + This function is very similar to :func:`get_reference`, but works for + both proxies and regular objects. If the specified object is a proxy, + its reference is extracted with ``get_reference`` and returned. If it + is not a proxy, it is returned as is. + + If the object references by the proxy is itself a proxy, the unwrapping + is repeated until a regular (non-proxy) object is found. + + possible_proxy: + object that might or might not be a proxy. + """ + while isinstance(possible_proxy, ThreadLocalProxy): + possible_proxy = ThreadLocalProxy.get_reference(possible_proxy) + return possible_proxy + + def __init__(self, initial_reference, fallback_to_shared=False): + """ + Create a proxy object that references the specified object. + + initial_reference: + object this proxy should initially reference (for the current + thread of execution). The :func:`set_reference` function is called + for the newly created proxy, passing this object. + + fallback_to_shared: + flag indicating what should happen when the proxy is used in a + thread where the reference has not been set explicitly. If + ``True``, the thread's reference is silently initialized to use the + reference last set by any thread. If ``False`` (the default), an + exception is raised when the proxy is used in a thread without + first initializing the reference in this thread. + """ + object.__setattr__(self, "_thread_local", threading.local()) + object.__setattr__(self, "_fallback_to_shared", fallback_to_shared) + ThreadLocalProxy.set_reference(self, initial_reference) + + def __repr__(self): + reference = ThreadLocalProxy.get_reference(self) + return repr(reference) + + def __str__(self): + reference = ThreadLocalProxy.get_reference(self) + return str(reference) + + def __lt__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference < other + + def __le__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference <= other + + def __eq__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference == other + + def __ne__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference != other + + def __gt__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference > other + + def __ge__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference >= other + + def __hash__(self): + reference = ThreadLocalProxy.get_reference(self) + return hash(reference) + + def __nonzero__(self): + reference = ThreadLocalProxy.get_reference(self) + return bool(reference) + + def __getattr__(self, name): + reference = ThreadLocalProxy.get_reference(self) + # Old-style classes might not have a __getattr__ method, but using + # getattr(...) will still work. + try: + original_method = reference.__getattr__ + except AttributeError: + return getattr(reference, name) + return reference.__getattr__(name) + + def __setattr__(self, name, value): + reference = ThreadLocalProxy.get_reference(self) + reference.__setattr__(name, value) + + def __delattr__(self, name): + reference = ThreadLocalProxy.get_reference(self) + reference.__delattr__(name) + + def __getattribute__(self, name): + reference = ThreadLocalProxy.get_reference(self) + return reference.__getattribute__(name) + + def __call__(self, *args, **kwargs): + reference = ThreadLocalProxy.get_reference(self) + return reference(*args, **kwargs) + + def __len__(self): + reference = ThreadLocalProxy.get_reference(self) + return len(reference) + + def __getitem__(self, key): + reference = ThreadLocalProxy.get_reference(self) + return reference[key] + + def __setitem__(self, key, value): + reference = ThreadLocalProxy.get_reference(self) + reference[key] = value + + def __delitem__(self, key): + reference = ThreadLocalProxy.get_reference(self) + del reference[key] + + def __iter__(self): + reference = ThreadLocalProxy.get_reference(self) + return reference.__iter__() + + def __reversed__(self): + reference = ThreadLocalProxy.get_reference(self) + return reversed(reference) + + def __contains__(self, item): + reference = ThreadLocalProxy.get_reference(self) + return item in reference + + def __add__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference + other + + def __sub__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference - other + + def __mul__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference * other + + def __floordiv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference // other + + def __mod__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference % other + + def __divmod__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return divmod(reference, other) + + def __pow__(self, other, modulo=None): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + modulo = ThreadLocalProxy.unproxy(modulo) + if modulo is None: + return pow(reference, other) + else: + return pow(reference, other, modulo) + + def __lshift__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference << other + + def __rshift__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference >> other + + def __and__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference & other + + def __xor__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference ^ other + + def __or__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference | other + + def __div__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__div__ + except AttributeError: + return NotImplemented + return func(other) + + def __truediv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__truediv__ + except AttributeError: + return NotImplemented + return func(other) + + def __radd__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other + reference + + def __rsub__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other - reference + + def __rmul__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other * reference + + def __rdiv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__rdiv__ + except AttributeError: + return NotImplemented + return func(other) + + def __rtruediv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__rtruediv__ + except AttributeError: + return NotImplemented + return func(other) + + def __rfloordiv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other // reference + + def __rmod__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other % reference + + def __rdivmod__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return divmod(other, reference) + + def __rpow__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other ** reference + + def __rlshift__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other << reference + + def __rrshift__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other >> reference + + def __rand__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other & reference + + def __rxor__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other ^ reference + + def __ror__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other | reference + + def __iadd__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference += other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __isub__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference -= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __imul__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference *= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __idiv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__idiv__ + except AttributeError: + return NotImplemented + reference = func(other) + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __itruediv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__itruediv__ + except AttributeError: + return NotImplemented + reference = func(other) + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __ifloordiv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference //= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __imod__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference %= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __ipow__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference **= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __ilshift__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference <<= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __irshift__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference >>= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __iand__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference &= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __ixor__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference ^= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __ior__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference |= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __neg__(self): + reference = ThreadLocalProxy.get_reference(self) + return -reference + + def __pos__(self): + reference = ThreadLocalProxy.get_reference(self) + return +reference + + def __abs__(self): + reference = ThreadLocalProxy.get_reference(self) + return abs(reference) + + def __invert__(self): + reference = ThreadLocalProxy.get_reference(self) + return ~reference + + def __complex__(self): + reference = ThreadLocalProxy.get_reference(self) + return complex(reference) + + def __int__(self): + reference = ThreadLocalProxy.get_reference(self) + return int(reference) + + def __float__(self): + reference = ThreadLocalProxy.get_reference(self) + return float(reference) + + def __oct__(self): + reference = ThreadLocalProxy.get_reference(self) + return oct(reference) + + def __hex__(self): + reference = ThreadLocalProxy.get_reference(self) + return hex(reference) + + def __index__(self): + reference = ThreadLocalProxy.get_reference(self) + try: + func = reference.__index__ + except AttributeError: + return NotImplemented + return func() + + def __coerce__(self, other): + # `coerce` isn't available on python 3.6, 3.7, or 3.8 + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return coerce(reference, other) # pylint: disable=undefined-variable + + if six.PY2: + # pylint: disable=incompatible-py3-code + def __unicode__(self): + reference = ThreadLocalProxy.get_reference(self) + return unicode(reference) # pylint: disable=undefined-variable + + def __long__(self): + reference = ThreadLocalProxy.get_reference(self) + return long(reference) # pylint: disable=undefined-variable diff --git a/tests/unit/test_loader.py b/tests/unit/test_loader.py index bbe633a321dc..ed3648eebf08 100644 --- a/tests/unit/test_loader.py +++ b/tests/unit/test_loader.py @@ -26,6 +26,7 @@ import salt.loader import salt.utils.files import salt.utils.stringutils +import salt.utils.thread_local_proxy as thread_local_proxy # pylint: disable=import-error,no-name-in-module,redefined-builtin from salt.ext import six @@ -1488,3 +1489,34 @@ def test_osrelease_info_has_correct_type(self): grains = salt.loader.grains(self.opts) osrelease_info = grains["osrelease_info"] assert isinstance(osrelease_info, tuple), osrelease_info + + +class ThreadLocalProxyLoaderTest(TestCase): + @classmethod + def setUpClass(cls): + cls.opts = salt.config.minion_config(None) + + def test__inject_into_mod(self): + class test_module(object): + name = "inject_into_mod.test.module" + + # First path, Force is not true, proxy doesn't exist -- also takes the path of Force True and proxy not exist + salt.loader._inject_into_mod(test_module, "__opts__", self.opts) + self.assertTrue(hasattr(test_module, "__opts__")) + self.assertIsInstance(test_module.__opts__, thread_local_proxy.ThreadLocalProxy) + self.assertEqual(self.opts, test_module.__opts__) + foo = test_module.__opts__ + + # Second path, Force is not true, proxy exists + salt.loader._inject_into_mod(test_module, "__opts__", self.opts) + self.assertIsInstance(test_module.__opts__, thread_local_proxy.ThreadLocalProxy) + self.assertEqual(self.opts, test_module.__opts__) + bar = test_module.__opts__ + + self.assertIs(foo, bar) + self.assertEqual(foo, bar) + foo["yes"] = "no" + self.assertIn("yes", bar) + + # Final path, Force is true, proxy exists + salt.loader._inject_into_mod(test_module, "__opts__", self.opts, True) diff --git a/tests/unit/utils/test_thread_local_proxy.py b/tests/unit/utils/test_thread_local_proxy.py new file mode 100644 index 000000000000..ebcfeb07d147 --- /dev/null +++ b/tests/unit/utils/test_thread_local_proxy.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- + +# Import python libs +from __future__ import absolute_import + +# Import Salt Libs +from salt.utils import thread_local_proxy + +# Import Salt Testing Libs +from tests.support.unit import TestCase + + +class ThreadLocalProxyTestCase(TestCase): + """ + Test case for salt.utils.thread_local_proxy module. + """ + + def test_set_reference_avoid_loop(self): + """ + Test that passing another proxy (or the same proxy) to set_reference + does not results in a recursive proxy loop. + """ + test_obj1 = 1 + test_obj2 = 2 + proxy1 = thread_local_proxy.ThreadLocalProxy(test_obj1) + proxy2 = thread_local_proxy.ThreadLocalProxy(proxy1) + self.assertEqual(test_obj1, proxy1) + self.assertEqual(test_obj1, proxy2) + self.assertEqual(proxy1, proxy2) + thread_local_proxy.ThreadLocalProxy.set_reference(proxy1, test_obj2) + self.assertEqual(test_obj2, proxy1) + self.assertEqual(test_obj2, proxy2) + self.assertEqual(proxy1, proxy2) + thread_local_proxy.ThreadLocalProxy.set_reference(proxy1, proxy2) + self.assertEqual(test_obj2, proxy1) + self.assertEqual(test_obj2, proxy2) + self.assertEqual(proxy1, proxy2)