From 76e969893e65c37fef0c82e053dc8ff27191a13a Mon Sep 17 00:00:00 2001 From: Sebastian Marsching Date: Wed, 17 Jan 2018 17:16:15 +0100 Subject: [PATCH 1/3] Fix race condition in Salt loader. There was a race condition in the salt loader when injecting global values (e.g. "__pillar__" or "__salt__") into modules. One effect of this race condition was that in a setup with multiple threads, some threads may see pillar data intended for other threads or the pillar data seen by a thread might even change spuriously. There have been earlier attempts to fix this problem (#27937, #29397). These patches tried to fix the problem by storing the dictionary that keeps the relevant data in a thread-local variable and referencing this thread-local variable from the variables that are injected into the modules. These patches did not fix the problem completely because they only work when a module is loaded through a single loader instance only. When there is more than one loader, there is more than one thread-local variable and the variable injected into a module is changed to point to another thread-local variable when the module is loaded again. Thus, the problem resurfaced while working on #39670. This patch attempts to solve the problem from a slightly different angle, complementing the earlier patches: The value injected into the modules now is a proxy that internally uses a thread-local variable to decide to which object it points. This means that when loading a module again through a different loader (possibly passing different pillar data), the data is actually only changed in the thread in which the loader is used. Other threads are not affected by such a change. This means that it will work correctly in the current situation where loaders are possibly created by many different modules and these modules do not necessary know in which context they are executed. Thus it is much more flexible and reliable than the more explicit approach used by the two earlier patches. --- salt/loader.py | 74 +++- salt/utils/thread_local_proxy.py | 691 +++++++++++++++++++++++++++++++ 2 files changed, 764 insertions(+), 1 deletion(-) create mode 100644 salt/utils/thread_local_proxy.py diff --git a/salt/loader.py b/salt/loader.py index 7329f8f9506e..2a1d54796747 100644 --- a/salt/loader.py +++ b/salt/loader.py @@ -13,6 +13,7 @@ import logging import inspect import tempfile +import threading import functools import types from collections import MutableMapping @@ -28,6 +29,7 @@ import salt.utils.lazy import salt.utils.odict import salt.utils.platform +import salt.utils.thread_local_proxy import salt.utils.versions from salt.exceptions import LoaderError from salt.template import check_render_pipe_str @@ -1003,6 +1005,76 @@ def _mod_type(module_path): return 'ext' +def _inject_into_mod(mod, name, value, force_lock=False): + ''' + Inject a variable into a module. This is used to inject "globals" like + ``__salt__``, ``__pillar``, or ``grains``. + + Instead of injecting the value directly, a ``ThreadLocalProxy`` is created. + If such a proxy is already present under the specified name, it is updated + with the new value. This update only affects the current thread, so that + the same name can refer to different values depending on the thread of + execution. + + This is important for data that is not truly global. For example, pillar + data might be dynamically overriden through function parameters and thus + the actual values available in pillar might depend on the thread that is + calling a module. 
+ + mod: + module object into which the value is going to be injected. + + name: + name of the variable that is injected into the module. + + value: + value that is injected into the variable. The value is not injected + directly, but instead set as the new reference of the proxy that has + been created for the variable. + + force_lock: + whether the lock should be acquired before checking whether a proxy + object for the specified name has already been injected into the + module. If ``False`` (the default), this function checks for the + module's variable without acquiring the lock and only acquires the lock + if a new proxy has to be created and injected. + ''' + from salt.utils.thread_local_proxy import ThreadLocalProxy + old_value = getattr(mod, name, None) + # We use a double-checked locking scheme in order to avoid taking the lock + # when a proxy object has already been injected. + # In most programming languages, double-checked locking is considered + # unsafe when used without explicit memory barries because one might read + # an uninitialized value. In CPython it is safe due to the global + # interpreter lock (GIL). In Python implementations that do not have the + # GIL, it could be unsafe, but at least Jython also guarantees that (for + # Python objects) memory is not corrupted when writing and reading without + # explicit synchronization + # (http://www.jython.org/jythonbook/en/1.0/Concurrency.html). + # Please note that in order to make this code safe in a runtime environment + # that does not make this guarantees, it is not sufficient. The + # ThreadLocalProxy must also be created with fallback_to_shared set to + # False or a lock must be added to the ThreadLocalProxy. + if force_lock: + with _inject_into_mod.lock: + if isinstance(old_value, ThreadLocalProxy): + ThreadLocalProxy.set_reference(old_value, value) + else: + setattr(mod, name, ThreadLocalProxy(value, True)) + else: + if isinstance(old_value, ThreadLocalProxy): + ThreadLocalProxy.set_reference(old_value, value) + else: + _inject_into_mod(mod, name, value, True) + + +# Lock used when injecting globals. This is needed to avoid a race condition +# when two threads try to load the same module concurrently. This must be +# outside the loader because there might be more than one loader for the same +# namespace. +_inject_into_mod.lock = threading.RLock() + + # TODO: move somewhere else? class FilterDictWrapper(MutableMapping): ''' @@ -1493,7 +1565,7 @@ def _load_module(self, name): # pack whatever other globals we were asked to for p_name, p_value in six.iteritems(self.pack): - setattr(mod, p_name, p_value) + _inject_into_mod(mod, p_name, p_value) module_name = mod.__name__.rsplit('.', 1)[-1] diff --git a/salt/utils/thread_local_proxy.py b/salt/utils/thread_local_proxy.py new file mode 100644 index 000000000000..71ff656f489d --- /dev/null +++ b/salt/utils/thread_local_proxy.py @@ -0,0 +1,691 @@ +# -*- coding: utf-8 -*- +''' +Proxy object that can reference different values depending on the current +thread of execution. + +..versionadded:: Nitrogen + +''' + +# Import python libs +from __future__ import absolute_import +import threading + +# Import 3rd-party libs +from salt.ext import six + + +# There are certain types which are sequences, but actually represent string +# like objects. We need a list of these types for the recursive unproxy code. +_STRING_LIKE_TYPES = (six.binary_type, six.string_types, six.text_type) + + +class ThreadLocalProxy(object): + ''' + Proxy that delegates all operations to its referenced object. 
The referenced + object is hold through a thread-local variable, so that this proxy may refer + to different objects in different threads of execution. + + For all practical purposes (operators, attributes, `isinstance`), the proxy + acts like the referenced object. Thus, code receiving the proxy object + instead of the reference object typically does not have to be changed. The + only exception is code that explicitly uses the ``type()`` function for + checking the proxy's type. While `isinstance(proxy, ...)` will yield the + expected results (based on the actual type of the referenced object), using + something like ``issubclass(type(proxy), ...)`` will not work, because + these tests will be made on the type of the proxy object instead of the + type of the referenced object. In order to avoid this, such code must be + changed to use ``issubclass(type(ThreadLocalProxy.unproxy(proxy)), ...)``. + + If an instance of this class is created with the ``fallback_to_shared`` flag + set and a thread uses the instance without setting the reference explicitly, + the reference for this thread is initialized with the latest reference set + by any thread. + + This class has primarily been designed for use by the Salt loader, but it + might also be useful in other places. + ''' + + __slots__ = ['_thread_local', '_last_reference', '_fallback_to_shared'] + + @staticmethod + def get_reference(proxy): + ''' + Return the object that is referenced by the specified proxy. + + If the proxy has not been bound to a reference for the current thread, + the behavior depends on th the ``fallback_to_shared`` flag that has + been specified when creating the proxy. If the flag has been set, the + last reference that has been set by any thread is returned (and + silently set as the reference for the current thread). If the flag has + not been set, an ``AttributeError`` is raised. + + proxy: + proxy object for which the reference shall be returned. If the + specified object is not an instance of `ThreadLocalProxy`, the + behavior is unspecified. Typically, an ``AttributeError`` is + going to be raised. + ''' + thread_local = object.__getattribute__(proxy, '_thread_local') + try: + return thread_local.reference + except AttributeError: + fallback_to_shared = object.__getattribute__( + proxy, '_fallback_to_shared') + if fallback_to_shared: + # If the reference has never been set in the current thread of + # execution, we use the reference that has been last set by any + # thread. + reference = object.__getattribute__(proxy, '_last_reference') + # We save the reference in the thread local so that future + # calls to get_reference will have consistent results. + ThreadLocalProxy.set_reference(proxy, reference) + return reference + else: + # We could simply return None, but this would make it hard to + # debug situations where the reference has not been set (the + # problem might go unnoticed until some code tries to do + # something with the returned object and it might not be easy to + # find out from where the None value originates). + # For this reason, we raise an AttributeError with an error + # message explaining the problem. + raise AttributeError( + 'The proxy object has not been bound to a reference in this thread of execution.') + + @staticmethod + def set_reference(proxy, new_reference): + ''' + Set the reference to be used the current thread of execution. + + After calling this function, the specified proxy will act like it was + the referenced object. 
+ + proxy: + proxy object for which the reference shall be set. If the specified + object is not an instance of `ThreadLocalProxy`, the behavior is + unspecified. Typically, an ``AttributeError`` is going to be + raised. + + new_reference: + reference the proxy should point to for the current thread after + calling this function. + ''' + thread_local = object.__getattribute__(proxy, '_thread_local') + thread_local.reference = new_reference + object.__setattr__(proxy, '_last_reference', new_reference) + + @staticmethod + def unset_reference(proxy): + ''' + Unset the reference to be used by the current thread of execution. + + After calling this function, the specified proxy will act like the + reference had never been set for the current thread. + + proxy: + proxy object for which the reference shall be unset. If the + specified object is not an instance of `ThreadLocalProxy`, the + behavior is unspecified. Typically, an ``AttributeError`` is going + to be raised. + ''' + thread_local = object.__getattribute__(proxy, '_thread_local') + del thread_local.reference + + @staticmethod + def unproxy(possible_proxy): + ''' + Unwrap and return the object referenced by a proxy. + + This function is very similar to :func:`get_reference`, but works for + both proxies and regular objects. If the specified object is a proxy, + its reference is extracted with ``get_reference`` and returned. If it + is not a proxy, it is returned as is. + + possible_proxy: + object that might or might not be a proxy. + ''' + if isinstance(possible_proxy, ThreadLocalProxy): + return ThreadLocalProxy.get_reference(possible_proxy) + else: + return possible_proxy + + @staticmethod + def unproxy_recursive(obj): + ''' + Recursively check an object for proxied members and convert it so that + it does not contain any proxies. This is mainly intended for code that + wants to serialize an object that might potentially be a proxy (or + contain proxies) using json or msgpack. + + The passed object is not modified. Instead, a new object is created if + conversion is needed. + + :param obj: object that shall be converted. + ''' + import collections + # If the object is a well-known proxy, we simply unwrap it. We still + # process the unwrapped object like a regular object because the wrapped + # object might actually be of a type that also requires conversion. + # Although unlikely, a proxy might actually wrap another proxy, so we + # unwrap until we find a non-proxy object. + unwrapped_obj = ThreadLocalProxy.unproxy(obj) + if obj is not unwrapped_obj: + return ThreadLocalProxy.unproxy_recursive(unwrapped_obj) + # msgpack's C code does (some) checks on the class of the object instead of + # doing them on the object itself. In addition to that, it only supports + # the actual dict and list types (or sub-classes if not in strict mode). + # This means that we have to convert objects which are mappings but not + # dicts and objects that are sequences but not lists or tuples to a + # supported type. + obj_type = type(obj) + if issubclass(obj_type, memoryview): + # msgpack has special handling for memoryview objects, so we never + # convert such objects. + return obj + elif isinstance(obj, collections.Mapping): + if not issubclass(obj_type, dict): + return { + ThreadLocalProxy.unproxy_recursive(key): + ThreadLocalProxy.unproxy_recursive(value) + for key, value in six.iteritems(obj) + } + else: + # We prefer using the original object. This way we can avoid + # duplicating data structures in memory. 
However, if we have to + # convert one of the elements, we also need a new instance for the + # object that contained the converted element so that it can + # reference the converted element. + key_value_pairs = {} + conversion_happened = False + for key, value in six.iteritems(obj): + converted_key = ThreadLocalProxy.unproxy_recursive(key) + converted_value = ThreadLocalProxy.unproxy_recursive(value) + if ((key is not converted_key) + or (value is not converted_value)): + conversion_happened = True + key_value_pairs[converted_key] = converted_value + if conversion_happened: + return key_value_pairs + else: + return obj + elif isinstance(obj, _STRING_LIKE_TYPES): + # Strings (both unicode and raw) also are sequences, but we do not want + # to handle them as such. If the object is an instance of a string + # type, but its type is not a subclass, it might be a proxy. + if not issubclass(obj_type, _STRING_LIKE_TYPES): + if six.PY3: + if isinstance(obj, bytes): + return bytes(obj) + else: + return str(obj) + else: + # pylint: disable=incompatible-py3-code + if isinstance(obj, unicode): + return unicode(obj) + else: + return str(obj) + else: + return obj + elif isinstance(obj, collections.Sequence): + if not (issubclass(obj_type, list) or issubclass(obj_type, tuple)): + return [ + ThreadLocalProxy.unproxy_recursive(elem) for elem in obj + ] + else: + # We prefer using the original object. This way we can avoid + # duplicating data structures in memory. However, if we have to + # convert one of the elements, we also need a new instance for the + # object that contained the converted element so that it can + # reference the converted element. + elems = [] + conversion_happened = False + for elem in obj: + converted_elem = ThreadLocalProxy.unproxy_recursive(elem) + if elem is not converted_elem: + conversion_happened = True + elems.append(converted_elem) + if conversion_happened: + return elems + else: + return obj + else: + return obj + + def __init__(self, initial_reference, fallback_to_shared=False): + ''' + Create a proxy object that references the specified object. + + initial_reference: + object this proxy should initially reference (for the current + thread of execution). The :func:`set_reference` function is called + for the newly created proxy, passing this object. + + fallback_to_shared: + flag indicating what should happen when the proxy is used in a + thread where the reference has not been set explicitly. If + ``True``, the thread's reference is silently initialized to use the + reference last set by any thread. If ``False`` (the default), an + exception is raised when the proxy is used in a thread without + first initializing the reference in this thread. 
+ ''' + object.__setattr__(self, '_thread_local', threading.local()) + object.__setattr__(self, '_fallback_to_shared', fallback_to_shared) + ThreadLocalProxy.set_reference(self, initial_reference) + + def __repr__(self): + reference = ThreadLocalProxy.get_reference(self) + return repr(reference) + + def __str__(self): + reference = ThreadLocalProxy.get_reference(self) + return str(reference) + + def __lt__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference < other + + def __le__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference <= other + + def __eq__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference == other + + def __ne__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference != other + + def __gt__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference > other + + def __ge__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference >= other + + def __hash__(self): + reference = ThreadLocalProxy.get_reference(self) + return hash(reference) + + def __nonzero__(self): + reference = ThreadLocalProxy.get_reference(self) + return bool(reference) + + def __getattr__(self, name): + reference = ThreadLocalProxy.get_reference(self) + # Old-style classes might not have a __getattr__ method, but using + # getattr(...) will still work. + try: + original_method = reference.__getattr__ + except AttributeError: + return getattr(reference, name) + return reference.__getattr__(name) + + def __setattr__(self, name, value): + reference = ThreadLocalProxy.get_reference(self) + reference.__setattr__(name, value) + + def __delattr__(self, name): + reference = ThreadLocalProxy.get_reference(self) + reference.__delattr__(name) + + def __getattribute__(self, name): + reference = ThreadLocalProxy.get_reference(self) + return reference.__getattribute__(name) + + def __call__(self, *args, **kwargs): + reference = ThreadLocalProxy.get_reference(self) + return reference(*args, **kwargs) + + def __len__(self): + reference = ThreadLocalProxy.get_reference(self) + return len(reference) + + def __getitem__(self, key): + reference = ThreadLocalProxy.get_reference(self) + return reference[key] + + def __setitem__(self, key, value): + reference = ThreadLocalProxy.get_reference(self) + reference[key] = value + + def __delitem__(self, key): + reference = ThreadLocalProxy.get_reference(self) + del reference[key] + + def __iter__(self): + reference = ThreadLocalProxy.get_reference(self) + return reference.__iter__() + + def __reversed__(self): + reference = ThreadLocalProxy.get_reference(self) + return reversed(reference) + + def __contains__(self, item): + reference = ThreadLocalProxy.get_reference(self) + return item in reference + + def __add__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference + other + + def __sub__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference - other + + def __mul__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference * other + + def __floordiv__(self, 
other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference // other + + def __mod__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference % other + + def __divmod__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return divmod(reference, other) + + def __pow__(self, other, modulo=None): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + modulo = ThreadLocalProxy.unproxy(modulo) + if modulo is None: + return pow(reference, other) + else: + return pow(reference, other, modulo) + + def __lshift__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference << other + + def __rshift__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference >> other + + def __and__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference & other + + def __xor__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference ^ other + + def __or__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return reference | other + + def __div__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__div__ + except AttributeError: + return NotImplemented + return func(other) + + def __truediv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__truediv__ + except AttributeError: + return NotImplemented + return func(other) + + def __radd__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other + reference + + def __rsub__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other - reference + + def __rmul__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other * reference + + def __rdiv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__rdiv__ + except AttributeError: + return NotImplemented + return func(other) + + def __rtruediv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__rtruediv__ + except AttributeError: + return NotImplemented + return func(other) + + def __rfloordiv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other // reference + + def __rmod__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other % reference + + def __rdivmod__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return divmod(other, reference) + + def __rpow__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other ** reference + + def __rlshift__(self, other): + reference = 
ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other << reference + + def __rrshift__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other >> reference + + def __rand__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other & reference + + def __rxor__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other ^ reference + + def __ror__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return other | reference + + def __iadd__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference += other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __isub__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference -= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __imul__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference *= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __idiv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__idiv__ + except AttributeError: + return NotImplemented + reference = func(other) + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __itruediv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + try: + func = reference.__itruediv__ + except AttributeError: + return NotImplemented + reference = func(other) + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __ifloordiv__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference //= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __imod__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference %= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __ipow__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference **= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __ilshift__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference <<= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __irshift__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference >>= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __iand__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference &= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __ixor__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference ^= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __ior__(self, other): + reference = 
ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + reference |= other + ThreadLocalProxy.set_reference(self, reference) + return reference + + def __neg__(self): + reference = ThreadLocalProxy.get_reference(self) + return - reference + + def __pos__(self): + reference = ThreadLocalProxy.get_reference(self) + return + reference + + def __abs__(self): + reference = ThreadLocalProxy.get_reference(self) + return abs(reference) + + def __invert__(self): + reference = ThreadLocalProxy.get_reference(self) + return ~ reference + + def __complex__(self): + reference = ThreadLocalProxy.get_reference(self) + return complex(reference) + + def __int__(self): + reference = ThreadLocalProxy.get_reference(self) + return int(reference) + + def __float__(self): + reference = ThreadLocalProxy.get_reference(self) + return float(reference) + + def __oct__(self): + reference = ThreadLocalProxy.get_reference(self) + return oct(reference) + + def __hex__(self): + reference = ThreadLocalProxy.get_reference(self) + return hex(reference) + + def __index__(self): + reference = ThreadLocalProxy.get_reference(self) + try: + func = reference.__index__ + except AttributeError: + return NotImplemented + return func() + + def __coerce__(self, other): + reference = ThreadLocalProxy.get_reference(self) + other = ThreadLocalProxy.unproxy(other) + return coerce(reference, other) + + if six.PY2: + # pylint: disable=incompatible-py3-code + def __unicode__(self): + reference = ThreadLocalProxy.get_reference(self) + return unicode(reference) + + def __long__(self): + reference = ThreadLocalProxy.get_reference(self) + return long(reference) From 97ce1e5aaa61456e39192ccf7ff2318414e678a1 Mon Sep 17 00:00:00 2001 From: Sebastian Marsching Date: Wed, 17 Jan 2018 18:40:13 +0100 Subject: [PATCH 2/3] Unproxy objects before serializing them. The salt.utils.json module now takes care of unwrapping objects that are proxied using the ThreadLocalProxy. The salt.utils.msgpack module has been added and basically provides the same functions as the salt.utils.json module, but for msgpack. Like the json module, it takes care of unwrapping proxies. 
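
A short, self-contained sketch (not part of the patch itself) may help to illustrate how the two pieces fit together. It assumes a tree with both of these patches applied, so that salt.utils.thread_local_proxy and the updated salt.utils.json are importable; the pillar values used here are made up purely for illustration:

    import threading

    import salt.utils.json
    from salt.utils.thread_local_proxy import ThreadLocalProxy

    # The proxy is created in the main thread and initially references the
    # main thread's pillar data.
    pillar = ThreadLocalProxy({'role': 'master'}, fallback_to_shared=True)

    def worker():
        # Re-pointing the proxy only affects the current thread; the main
        # thread keeps seeing its own reference.
        ThreadLocalProxy.set_reference(pillar, {'role': 'worker'})
        # salt.utils.json.dumps() runs unproxy_recursive() first, so the
        # serialized output contains the plain dict, not the proxy object.
        print(salt.utils.json.dumps({'pillar': pillar}))

    thread = threading.Thread(target=worker)
    thread.start()
    thread.join()

    # Back in the main thread the same proxy still resolves to the original
    # reference, so this serializes the 'master' pillar data.
    print(salt.utils.json.dumps({'pillar': pillar}))

Doing the unwrapping centrally in the serializer utilities means that code receiving pillar or grains data does not need to know whether it is holding a proxy or a plain object.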
--- salt/cloud/clouds/ec2.py | 7 +- salt/cloud/clouds/gce.py | 6 +- salt/engines/stalekey.py | 6 +- salt/key.py | 8 +- salt/log/handlers/fluent_mod.py | 6 +- salt/modules/saltcheck.py | 2 +- salt/modules/state.py | 10 +-- salt/modules/win_repo.py | 5 +- salt/payload.py | 60 +++++++++---- salt/renderers/msgpack.py | 6 +- salt/returners/local_cache.py | 6 +- salt/runners/winrepo.py | 7 +- salt/sdb/sqlite3.py | 10 +-- salt/serializers/msgpack.py | 15 +++- salt/state.py | 9 +- salt/states/netsnmp.py | 3 +- salt/states/netusers.py | 2 +- salt/states/pkg.py | 4 - salt/states/probes.py | 2 +- salt/states/zabbix_host.py | 2 +- salt/states/zabbix_user.py | 2 +- salt/transport/frame.py | 8 +- salt/utils/cache.py | 8 +- salt/utils/cloud.py | 25 +++--- salt/utils/http.py | 10 +-- salt/utils/json.py | 2 + salt/utils/msgpack.py | 85 +++++++++++++++++++ .../log_handlers/runtests_log_handler.py | 7 +- tests/packdump.py | 6 +- 29 files changed, 223 insertions(+), 106 deletions(-) create mode 100644 salt/utils/msgpack.py diff --git a/salt/cloud/clouds/ec2.py b/salt/cloud/clouds/ec2.py index 5b47cf63642a..acb935936dd1 100644 --- a/salt/cloud/clouds/ec2.py +++ b/salt/cloud/clouds/ec2.py @@ -82,7 +82,6 @@ import binascii import datetime import base64 -import msgpack import re import decimal @@ -91,6 +90,7 @@ import salt.utils.files import salt.utils.hashutils import salt.utils.json +import salt.utils.msgpack import salt.utils.stringutils import salt.utils.yaml from salt._compat import ElementTree as ET @@ -4828,7 +4828,7 @@ def _parse_pricing(url, name): __opts__['cachedir'], 'ec2-pricing-{0}.p'.format(name) ) with salt.utils.files.fopen(outfile, 'w') as fho: - msgpack.dump(regions, fho) + salt.utils.msgpack.dump(regions, fho) return True @@ -4896,7 +4896,8 @@ def show_pricing(kwargs=None, call=None): update_pricing({'type': name}, 'function') with salt.utils.files.fopen(pricefile, 'r') as fhi: - ec2_price = salt.utils.stringutils.to_unicode(msgpack.load(fhi)) + ec2_price = salt.utils.stringutils.to_unicode( + salt.utils.msgpack.load(fhi)) region = get_location(profile) size = profile.get('size', None) diff --git a/salt/cloud/clouds/gce.py b/salt/cloud/clouds/gce.py index 29f059c7a54c..6a1038902ca3 100644 --- a/salt/cloud/clouds/gce.py +++ b/salt/cloud/clouds/gce.py @@ -53,7 +53,6 @@ import re import pprint import logging -import msgpack from ast import literal_eval from salt.utils.versions import LooseVersion as _LooseVersion @@ -90,6 +89,7 @@ import salt.utils.cloud import salt.utils.files import salt.utils.http +import salt.utils.msgpack import salt.config as config from salt.cloud.libcloudfuncs import * # pylint: disable=redefined-builtin,wildcard-import,unused-wildcard-import from salt.exceptions import ( @@ -2618,7 +2618,7 @@ def update_pricing(kwargs=None, call=None): __opts__['cachedir'], 'gce-pricing.p' ) with salt.utils.files.fopen(outfile, 'w') as fho: - msgpack.dump(price_json['dict'], fho) + salt.utils.msgpack.dump(price_json['dict'], fho) return True @@ -2657,7 +2657,7 @@ def show_pricing(kwargs=None, call=None): update_pricing() with salt.utils.files.fopen(pricefile, 'r') as fho: - sizes = msgpack.load(fho) + sizes = salt.utils.msgpack.load(fho) per_hour = float(sizes['gcp_price_list'][size][region]) diff --git a/salt/engines/stalekey.py b/salt/engines/stalekey.py index f927a25049bf..7b27d1733de1 100644 --- a/salt/engines/stalekey.py +++ b/salt/engines/stalekey.py @@ -28,11 +28,11 @@ import salt.key import salt.utils.files import salt.utils.minions +import salt.utils.msgpack import salt.wheel # 
Import 3rd-party libs from salt.ext import six -import msgpack log = logging.getLogger(__name__) @@ -60,7 +60,7 @@ def start(interval=3600, expire=604800): if os.path.exists(presence_file): try: with salt.utils.files.fopen(presence_file, 'r') as f: - minions = msgpack.load(f) + minions = salt.utils.msgpack.load(f) except IOError as e: log.error('Could not open presence file %s: %s', presence_file, e) time.sleep(interval) @@ -95,7 +95,7 @@ def start(interval=3600, expire=604800): try: with salt.utils.files.fopen(presence_file, 'w') as f: - msgpack.dump(minions, f) + salt.utils.msgpack.dump(minions, f) except IOError as e: log.error('Could not write to presence file %s: %s', presence_file, e) time.sleep(interval) diff --git a/salt/key.py b/salt/key.py index 4163a1274e43..22e7ddd3c2e6 100644 --- a/salt/key.py +++ b/salt/key.py @@ -38,9 +38,10 @@ from salt.ext.six.moves import input # pylint: enable=import-error,no-name-in-module,redefined-builtin -# Import third party libs +# We do not always need msgpack, so we do not want to fail here if msgpack is +# not available. try: - import msgpack + import salt.utils.msgpack except ImportError: pass @@ -1035,7 +1036,8 @@ def check_minion_cache(self, preserve_minions=False): if ext == '.json': data = salt.utils.json.load(fp_) elif ext == '.msgpack': - data = msgpack.load(fp_) + data = salt.utils.msgpack.load(fp_, + _msgpack_module=msgpack) role = salt.utils.stringutils.to_unicode(data['role']) if role not in minions: os.remove(path) diff --git a/salt/log/handlers/fluent_mod.py b/salt/log/handlers/fluent_mod.py index 20724f4ed0e7..1ede2c403894 100644 --- a/salt/log/handlers/fluent_mod.py +++ b/salt/log/handlers/fluent_mod.py @@ -96,15 +96,17 @@ try: # Attempt to import msgpack - import msgpack + import salt.utils.msgpack # There is a serialization issue on ARM and potentially other platforms # for some msgpack bindings, check for it if msgpack.loads(msgpack.dumps([1, 2, 3]), use_list=True) is None: raise ImportError + import salt.utils.msgpack except ImportError: # Fall back to msgpack_pure try: import msgpack_pure as msgpack + import salt.utils.msgpack except ImportError: # TODO: Come up with a sane way to get a configured logfile # and write to the logfile when this error is hit also @@ -456,7 +458,7 @@ def _make_packet(self, label, timestamp, data): packet = (tag, timestamp, data) if self.verbose: print(packet) - return msgpack.packb(packet) + return salt.utils.msgpack.packb(packet, _msgpack_module=msgpack) def _send(self, bytes_): self.lock.acquire() diff --git a/salt/modules/saltcheck.py b/salt/modules/saltcheck.py index 7c3d580c0229..781cd15a6fa7 100644 --- a/salt/modules/saltcheck.py +++ b/salt/modules/saltcheck.py @@ -51,7 +51,7 @@ import logging import os import time -from json import loads, dumps +from salt.utils.json import loads, dumps try: import salt.utils.files import salt.utils.path diff --git a/salt/modules/state.py b/salt/modules/state.py index 4ac4860b2c90..efac79abe077 100644 --- a/salt/modules/state.py +++ b/salt/modules/state.py @@ -34,6 +34,7 @@ import salt.utils.hashutils import salt.utils.jid import salt.utils.json +import salt.utils.msgpack import salt.utils.platform import salt.utils.state import salt.utils.stringutils @@ -45,7 +46,6 @@ # Import 3rd-party libs from salt.ext import six -import msgpack __proxyenabled__ = ['*'] @@ -185,7 +185,7 @@ def _get_pause(jid, state_id=None): data[state_id] = {} if os.path.exists(pause_path): with salt.utils.files.fopen(pause_path, 'rb') as fp_: - data = msgpack.loads(fp_.read()) + data = 
salt.utils.msgpack.loads(fp_.read()) return data, pause_path @@ -256,7 +256,7 @@ def soft_kill(jid, state_id=None): data, pause_path = _get_pause(jid, state_id) data[state_id]['kill'] = True with salt.utils.files.fopen(pause_path, 'wb') as fp_: - fp_.write(msgpack.dumps(data)) + fp_.write(salt.utils.msgpack.dumps(data)) def pause(jid, state_id=None, duration=None): @@ -291,7 +291,7 @@ def pause(jid, state_id=None, duration=None): if duration: data[state_id]['duration'] = int(duration) with salt.utils.files.fopen(pause_path, 'wb') as fp_: - fp_.write(msgpack.dumps(data)) + fp_.write(salt.utils.msgpack.dumps(data)) def resume(jid, state_id=None): @@ -325,7 +325,7 @@ def resume(jid, state_id=None): if state_id == '__all__': data = {} with salt.utils.files.fopen(pause_path, 'wb') as fp_: - fp_.write(msgpack.dumps(data)) + fp_.write(salt.utils.msgpack.dumps(data)) def orchestrate(mods, diff --git a/salt/modules/win_repo.py b/salt/modules/win_repo.py index a087504e096d..d7610d4fce4e 100644 --- a/salt/modules/win_repo.py +++ b/salt/modules/win_repo.py @@ -31,11 +31,8 @@ PER_REMOTE_ONLY ) from salt.ext import six -try: - import msgpack -except ImportError: - import msgpack_pure as msgpack # pylint: disable=import-error import salt.utils.gitfs +import salt.utils.msgpack # pylint: enable=unused-import log = logging.getLogger(__name__) diff --git a/salt/payload.py b/salt/payload.py index 6cade01e141c..fbef758c9dc6 100644 --- a/salt/payload.py +++ b/salt/payload.py @@ -54,6 +54,10 @@ #sys.exit(salt.defaults.exitcodes.EX_GENERIC) +if HAS_MSGPACK: + import salt.utils.msgpack + + if HAS_MSGPACK and not hasattr(msgpack, 'exceptions'): class PackValueError(Exception): ''' @@ -74,14 +78,15 @@ def package(payload): This method for now just wraps msgpack.dumps, but it is here so that we can make the serialization a custom option in the future with ease. ''' - return msgpack.dumps(payload) + return salt.utils.msgpack.dumps(payload, _msgpack_module=msgpack) def unpackage(package_): ''' Unpackages a payload ''' - return msgpack.loads(package_, use_list=True) + return salt.utils.msgpack.loads(package_, use_list=True, + _msgpack_module=msgpack) def format_payload(enc, **kwargs): @@ -134,9 +139,12 @@ def loads(self, msg, encoding=None, raw=False): # Due to this, if we don't need it, don't pass it at all so # that under Python 2 we can still work with older versions # of msgpack. - ret = msgpack.loads(msg, use_list=True, encoding=encoding) + ret = salt.utils.msgpack.loads(msg, use_list=True, + encoding=encoding, + _msgpack_module=msgpack) else: - ret = msgpack.loads(msg, use_list=True) + ret = salt.utils.msgpack.loads(msg, use_list=True, + _msgpack_module=msgpack) if six.PY3 and encoding is None and not raw: ret = salt.transport.frame.decode_embedded_strs(ret) except Exception as exc: @@ -181,9 +189,10 @@ def dumps(self, msg, use_bin_type=False): # Due to this, if we don't need it, don't pass it at all so # that under Python 2 we can still work with older versions # of msgpack. 
- return msgpack.dumps(msg, use_bin_type=use_bin_type) + return salt.utils.msgpack.dumps(msg, use_bin_type=use_bin_type, + _msgpack_module=msgpack) else: - return msgpack.dumps(msg) + return salt.utils.msgpack.dumps(msg, _msgpack_module=msgpack) except (OverflowError, msgpack.exceptions.PackValueError): # msgpack can't handle the very long Python longs for jids # Convert any very long longs to strings @@ -207,9 +216,12 @@ def verylong_encoder(obj): else: return obj if msgpack.version >= (0, 4, 0): - return msgpack.dumps(verylong_encoder(msg), use_bin_type=use_bin_type) + return salt.utils.msgpack.dumps(verylong_encoder(msg), + use_bin_type=use_bin_type, + _msgpack_module=msgpack) else: - return msgpack.dumps(verylong_encoder(msg)) + return salt.utils.msgpack.dumps(verylong_encoder(msg), + _msgpack_module=msgpack) except TypeError as e: # msgpack doesn't support datetime.datetime datatype # So here we have converted datetime.datetime to custom datatype @@ -220,9 +232,14 @@ def default(obj): def dt_encode(obj): datetime_str = obj.strftime("%Y%m%dT%H:%M:%S.%f") if msgpack.version >= (0, 4, 0): - return msgpack.packb(datetime_str, default=default, use_bin_type=use_bin_type) + return salt.utils.msgpack.packb(datetime_str, + default=default, + use_bin_type=use_bin_type, + _msgpack_module=msgpack) else: - return msgpack.packb(datetime_str, default=default) + return salt.utils.msgpack.packb(datetime_str, + default=default, + _msgpack_module=msgpack) def datetime_encoder(obj): if isinstance(obj, dict): @@ -254,14 +271,22 @@ def immutable_encoder(obj): if "datetime.datetime" in six.text_type(e): if msgpack.version >= (0, 4, 0): - return msgpack.dumps(datetime_encoder(msg), use_bin_type=use_bin_type) + return salt.utils.msgpack.dumps(datetime_encoder(msg), + use_bin_type=use_bin_type, + _msgpack_module=msgpack) else: - return msgpack.dumps(datetime_encoder(msg)) + return salt.utils.msgpack.dumps(datetime_encoder(msg), + _msgpack_module=msgpack) elif "Immutable" in six.text_type(e): if msgpack.version >= (0, 4, 0): - return msgpack.dumps(msg, default=immutable_encoder, use_bin_type=use_bin_type) + return salt.utils.msgpack.dumps(msg, + default=immutable_encoder, + use_bin_type=use_bin_type, + _msgpack_module=msgpack) else: - return msgpack.dumps(msg, default=immutable_encoder) + return salt.utils.msgpack.dumps(msg, + default=immutable_encoder, + _msgpack_module=msgpack) if msgpack.version >= (0, 2, 0): # Should support OrderedDict serialization, so, let's @@ -286,9 +311,12 @@ def odict_encoder(obj): return obj return obj if msgpack.version >= (0, 4, 0): - return msgpack.dumps(odict_encoder(msg), use_bin_type=use_bin_type) + return salt.utils.msgpack.dumps(odict_encoder(msg), + use_bin_type=use_bin_type, + _msgpack_module=msgpack) else: - return msgpack.dumps(odict_encoder(msg)) + return salt.utils.msgpack.dumps(odict_encoder(msg), + _msgpack_module=msgpack) except (SystemError, TypeError) as exc: # pylint: disable=W0705 log.critical( 'Unable to serialize message! Consider upgrading msgpack. 
' diff --git a/salt/renderers/msgpack.py b/salt/renderers/msgpack.py index f58d11b85b8d..eceac4f53bb5 100644 --- a/salt/renderers/msgpack.py +++ b/salt/renderers/msgpack.py @@ -1,10 +1,8 @@ # -*- coding: utf-8 -*- from __future__ import absolute_import, print_function, unicode_literals -# Import third party libs -import msgpack - # Import salt libs +import salt.utils.msgpack from salt.ext import six @@ -28,4 +26,4 @@ def render(msgpack_data, saltenv='base', sls='', **kws): msgpack_data = msgpack_data[(msgpack_data.find('\n') + 1):] if not msgpack_data.strip(): return {} - return msgpack.loads(msgpack_data) + return salt.utils.msgpack.loads(msgpack_data) diff --git a/salt/returners/local_cache.py b/salt/returners/local_cache.py index 8df04aa1f2b0..021dd8e6a871 100644 --- a/salt/returners/local_cache.py +++ b/salt/returners/local_cache.py @@ -20,11 +20,11 @@ import salt.utils.files import salt.utils.jid import salt.utils.minions +import salt.utils.msgpack import salt.utils.stringutils import salt.exceptions # Import 3rd-party libs -import msgpack from salt.ext import six @@ -503,7 +503,7 @@ def save_reg(data): raise try: with salt.utils.files.fopen(regfile, 'a') as fh_: - msgpack.dump(data, fh_) + salt.utils.msgpack.dump(data, fh_) except: log.error('Could not write to msgpack file %s', __opts__['outdir']) raise @@ -517,7 +517,7 @@ def load_reg(): regfile = os.path.join(reg_dir, 'register') try: with salt.utils.files.fopen(regfile, 'r') as fh_: - return msgpack.load(fh_) + return salt.utils.msgpack.load(fh_) except: log.error('Could not write to msgpack file %s', __opts__['outdir']) raise diff --git a/salt/runners/winrepo.py b/salt/runners/winrepo.py index 716ba30e9b65..6440cf6dfc7d 100644 --- a/salt/runners/winrepo.py +++ b/salt/runners/winrepo.py @@ -12,15 +12,12 @@ # Import third party libs from salt.ext import six -try: - import msgpack -except ImportError: - import msgpack_pure as msgpack # pylint: disable=import-error # Import salt libs from salt.exceptions import CommandExecutionError, SaltRenderError import salt.utils.files import salt.utils.gitfs +import salt.utils.msgpack import salt.utils.path import logging import salt.minion @@ -123,7 +120,7 @@ def genrepo(opts=None, fire_event=True): ret.setdefault('name_map', {}).update(revmap) with salt.utils.files.fopen( os.path.join(winrepo_dir, winrepo_cachefile), 'w+b') as repo: - repo.write(msgpack.dumps(ret)) + repo.write(salt.utils.msgpack.dumps(ret)) return ret diff --git a/salt/sdb/sqlite3.py b/salt/sdb/sqlite3.py index 540a289d56ae..006d574c283b 100644 --- a/salt/sdb/sqlite3.py +++ b/salt/sdb/sqlite3.py @@ -54,11 +54,9 @@ HAS_SQLITE3 = False # Import salt libs +import salt.utils.msgpack from salt.ext import six -# Import third party libs -import msgpack - DEFAULT_TABLE = 'sdb' @@ -126,9 +124,9 @@ def set_(key, value, profile=None): return False conn, cur, table = _connect(profile) if six.PY2: - value = buffer(msgpack.packb(value)) + value = buffer(salt.utils.msgpack.packb(value)) else: - value = memoryview(msgpack.packb(value)) + value = memoryview(salt.utils.msgpack.packb(value)) q = profile.get('set_query', ('INSERT OR REPLACE INTO {0} VALUES ' '(:key, :value)').format(table)) conn.execute(q, {'key': key, 'value': value}) @@ -149,4 +147,4 @@ def get(key, profile=None): res = res.fetchone() if not res: return None - return msgpack.unpackb(res[0]) + return salt.utils.msgpack.unpackb(res[0]) diff --git a/salt/serializers/msgpack.py b/salt/serializers/msgpack.py index f55fa878b669..90df192ee3c6 100644 --- a/salt/serializers/msgpack.py 
+++ b/salt/serializers/msgpack.py @@ -24,6 +24,7 @@ try: # Attempt to import msgpack import msgpack + import salt.utils.msgpack # There is a serialization issue on ARM and potentially other platforms # for some msgpack bindings, check for it if msgpack.loads(msgpack.dumps([1, 2, 3]), use_list=True) is None: @@ -33,6 +34,7 @@ # Fall back to msgpack_pure try: import msgpack_pure as msgpack # pylint: disable=import-error + import salt.utils.msgpack except ImportError: # TODO: Come up with a sane way to get a configured logfile # and write to the logfile when this error is hit also @@ -60,7 +62,8 @@ def _deserialize(stream_or_string, **options): def _serialize(obj, **options): try: - return msgpack.dumps(obj, **options) + return salt.utils.msgpack.dumps(obj, _msgpack_module=msgpack, + **options) except Exception as error: raise SerializationError(error) @@ -68,7 +71,9 @@ def _deserialize(stream_or_string, **options): try: options.setdefault('use_list', True) options.setdefault('encoding', 'utf-8') - return msgpack.loads(stream_or_string, **options) + return salt.utils.msgpack.loads(stream_or_string, + _msgpack_module=msgpack, + **options) except Exception as error: raise DeserializationError(error) @@ -95,14 +100,16 @@ def _decoder(obj): def _serialize(obj, **options): try: obj = _encoder(obj) - return msgpack.dumps(obj, **options) + return salt.utils.msgpack.dumps(obj, _msgpack_module=msgpack, + **options) except Exception as error: raise SerializationError(error) def _deserialize(stream_or_string, **options): options.setdefault('use_list', True) try: - obj = msgpack.loads(stream_or_string) + obj = salt.utils.msgpack.loads(stream_or_string, + _msgpack_module=msgpack) return _decoder(obj) except Exception as error: raise DeserializationError(error) diff --git a/salt/state.py b/salt/state.py index e9d814a390ab..f2ba8a21197c 100644 --- a/salt/state.py +++ b/salt/state.py @@ -37,6 +37,7 @@ import salt.utils.event import salt.utils.files import salt.utils.immutabletypes as immutabletypes +import salt.utils.msgpack import salt.utils.platform import salt.utils.process import salt.utils.url @@ -1818,7 +1819,7 @@ def _call_parallel_target(self, cdata, low): # and the attempt, we are safe to pass pass with salt.utils.files.fopen(tfile, 'wb+') as fp_: - fp_.write(msgpack.dumps(ret)) + fp_.write(salt.utils.msgpack.dumps(ret, _msgpack_module=msgpack)) def call_parallel(self, cdata, low): ''' @@ -2219,7 +2220,8 @@ def check_pause(self, low): tries = 0 with salt.utils.files.fopen(pause_path, 'rb') as fp_: try: - pdat = msgpack.loads(fp_.read()) + pdat = salt.utils.msgpack.loads( + fp_.read(),_msgpack_module=msgpack) except msgpack.UnpackValueError: # Reading race condition if tries > 10: @@ -2266,7 +2268,8 @@ def reconcile_procs(self, running): 'changes': {}} try: with salt.utils.files.fopen(ret_cache, 'rb') as fp_: - ret = msgpack.loads(fp_.read()) + ret = salt.utils.msgpack.loads( + fp_.read(), _msgpack_module=msgpack) except (OSError, IOError): ret = {'result': False, 'comment': 'Parallel cache failure', diff --git a/salt/states/netsnmp.py b/salt/states/netsnmp.py index 34ca8b894d12..fd3d365da014 100644 --- a/salt/states/netsnmp.py +++ b/salt/states/netsnmp.py @@ -23,9 +23,8 @@ import logging log = logging.getLogger(__name__) -from json import loads, dumps - # salt lib +from salt.utils.json import loads, dumps from salt.ext import six # import NAPALM utils import salt.utils.napalm diff --git a/salt/states/netusers.py b/salt/states/netusers.py index af1f26458f72..04536bd9ca36 100644 --- 
a/salt/states/netusers.py +++ b/salt/states/netusers.py @@ -25,9 +25,9 @@ # Python std lib from copy import deepcopy -from json import loads, dumps # salt lib +from salt.utils.json import loads, dumps from salt.ext import six # import NAPALM utils import salt.utils.napalm diff --git a/salt/states/pkg.py b/salt/states/pkg.py index 7385708736f1..89a99255b939 100644 --- a/salt/states/pkg.py +++ b/salt/states/pkg.py @@ -135,10 +135,6 @@ # The following imports are used by the namespaced win_pkg funcs # and need to be included in their globals. # pylint: disable=import-error,unused-import - try: - import msgpack - except ImportError: - import msgpack_pure as msgpack from salt.utils.versions import LooseVersion # pylint: enable=import-error,unused-import # pylint: enable=invalid-name diff --git a/salt/states/probes.py b/salt/states/probes.py index 18ba05b90da7..eac1ce1e10a9 100644 --- a/salt/states/probes.py +++ b/salt/states/probes.py @@ -25,9 +25,9 @@ log = logging.getLogger(__name__) from copy import deepcopy -from json import loads, dumps # salt modules +from salt.utils.json import loads, dumps from salt.ext import six # import NAPALM utils import salt.utils.napalm diff --git a/salt/states/zabbix_host.py b/salt/states/zabbix_host.py index 7220297db239..3a5574a14b18 100644 --- a/salt/states/zabbix_host.py +++ b/salt/states/zabbix_host.py @@ -7,7 +7,7 @@ ''' from __future__ import absolute_import -from json import loads, dumps +from salt.utils.json import loads, dumps from copy import deepcopy from salt.ext import six diff --git a/salt/states/zabbix_user.py b/salt/states/zabbix_user.py index d60cf2cffa30..61f8492b6040 100644 --- a/salt/states/zabbix_user.py +++ b/salt/states/zabbix_user.py @@ -7,7 +7,7 @@ ''' from __future__ import absolute_import -from json import loads, dumps +from salt.utils.json import loads, dumps from copy import deepcopy diff --git a/salt/transport/frame.py b/salt/transport/frame.py index 33d0c0d91703..88b595184ec7 100644 --- a/salt/transport/frame.py +++ b/salt/transport/frame.py @@ -4,7 +4,7 @@ ''' # Import python libs from __future__ import absolute_import, print_function, unicode_literals -import msgpack +import salt.utils.msgpack from salt.ext import six @@ -18,7 +18,7 @@ def frame_msg(body, header=None, raw_body=False): # pylint: disable=unused-argu framed_msg['head'] = header framed_msg['body'] = body - return msgpack.dumps(framed_msg) + return salt.utils.msgpack.dumps(framed_msg) def frame_msg_ipc(body, header=None, raw_body=False): # pylint: disable=unused-argument @@ -35,9 +35,9 @@ def frame_msg_ipc(body, header=None, raw_body=False): # pylint: disable=unused- framed_msg['head'] = header framed_msg['body'] = body if six.PY2: - return msgpack.dumps(framed_msg) + return salt.utils.msgpack.dumps(framed_msg) else: - return msgpack.dumps(framed_msg, use_bin_type=True) + return salt.utils.msgpack.dumps(framed_msg, use_bin_type=True) def _decode_embedded_list(src): diff --git a/salt/utils/cache.py b/salt/utils/cache.py index 6350b37fe2d7..f1b2211f37de 100644 --- a/salt/utils/cache.py +++ b/salt/utils/cache.py @@ -9,7 +9,7 @@ import time import logging try: - import msgpack + import salt.utils.msgpack HAS_MSGPACK = True except ImportError: HAS_MSGPACK = False @@ -143,7 +143,9 @@ def _read(self): if not HAS_MSGPACK or not os.path.exists(self._path): return with salt.utils.files.fopen(self._path, 'rb') as fp_: - cache = salt.utils.data.decode(msgpack.load(fp_, encoding=__salt_system_encoding__)) + cache = salt.utils.data.decode( + salt.utils.msgpack.load( + fp_, 
encoding=__salt_system_encoding__)) if "CacheDisk_cachetime" in cache: # new format self._dict = cache["CacheDisk_data"] self._key_cache_time = cache["CacheDisk_cachetime"] @@ -168,7 +170,7 @@ def _write(self): "CacheDisk_data": self._dict, "CacheDisk_cachetime": self._key_cache_time } - msgpack.dump(cache, fp_, use_bin_type=True) + salt.utils.msgpack.dump(cache, fp_, use_bin_type=True) class CacheCli(object): diff --git a/salt/utils/cloud.py b/salt/utils/cloud.py index 88daf3d8c0a6..e4144f0edddb 100644 --- a/salt/utils/cloud.py +++ b/salt/utils/cloud.py @@ -18,7 +18,6 @@ import multiprocessing import logging import pipes -import msgpack import traceback import copy import re @@ -51,6 +50,7 @@ import salt.utils.data import salt.utils.event import salt.utils.files +import salt.utils.msgpack import salt.utils.platform import salt.utils.stringutils import salt.utils.versions @@ -2506,7 +2506,7 @@ def cachedir_index_add(minion_id, profile, driver, provider, base=None): if os.path.exists(index_file): mode = 'rb' if six.PY3 else 'r' with salt.utils.files.fopen(index_file, mode) as fh_: - index = salt.utils.data.decode(msgpack.load(fh_)) + index = salt.utils.data.decode(salt.utils.msgpack.load(fh_)) else: index = {} @@ -2523,7 +2523,7 @@ def cachedir_index_add(minion_id, profile, driver, provider, base=None): mode = 'wb' if six.PY3 else 'w' with salt.utils.files.fopen(index_file, mode) as fh_: - msgpack.dump(index, fh_) + salt.utils.msgpack.dump(index, fh_) unlock_file(index_file) @@ -2540,7 +2540,7 @@ def cachedir_index_del(minion_id, base=None): if os.path.exists(index_file): mode = 'rb' if six.PY3 else 'r' with salt.utils.files.fopen(index_file, mode) as fh_: - index = salt.utils.data.decode(msgpack.load(fh_)) + index = salt.utils.data.decode(salt.utils.msgpack.load(fh_)) else: return @@ -2549,7 +2549,7 @@ def cachedir_index_del(minion_id, base=None): mode = 'wb' if six.PY3 else 'w' with salt.utils.files.fopen(index_file, mode) as fh_: - msgpack.dump(index, fh_) + salt.utils.msgpack.dump(index, fh_) unlock_file(index_file) @@ -2606,7 +2606,7 @@ def request_minion_cachedir( fname = '{0}.p'.format(minion_id) path = os.path.join(base, 'requested', fname) with salt.utils.files.fopen(path, 'w') as fh_: - msgpack.dump(data, fh_) + salt.utils.msgpack.dump(data, fh_) def change_minion_cachedir( @@ -2638,12 +2638,12 @@ def change_minion_cachedir( path = os.path.join(base, cachedir, fname) with salt.utils.files.fopen(path, 'r') as fh_: - cache_data = salt.utils.data.decode(msgpack.load(fh_)) + cache_data = salt.utils.data.decode(salt.utils.msgpack.load(fh_)) cache_data.update(data) with salt.utils.files.fopen(path, 'w') as fh_: - msgpack.dump(cache_data, fh_) + salt.utils.msgpack.dump(cache_data, fh_) def activate_minion_cachedir(minion_id, base=None): @@ -2716,7 +2716,8 @@ def list_cache_nodes_full(opts=None, provider=None, base=None): fpath = os.path.join(min_dir, fname) minion_id = fname[:-2] # strip '.p' from end of msgpack filename with salt.utils.files.fopen(fpath, 'r') as fh_: - minions[driver][prov][minion_id] = salt.utils.data.decode(msgpack.load(fh_)) + minions[driver][prov][minion_id] = salt.utils.data.decode( + salt.utils.msgpack.load(fh_)) return minions @@ -2892,7 +2893,7 @@ def cache_node_list(nodes, provider, opts): diff_node_cache(prov_dir, node, nodes[node], opts) path = os.path.join(prov_dir, '{0}.p'.format(node)) with salt.utils.files.fopen(path, 'w') as fh_: - msgpack.dump(nodes[node], fh_) + salt.utils.msgpack.dump(nodes[node], fh_) def cache_node(node, provider, opts): @@ -2917,7 
         os.makedirs(prov_dir)
     path = os.path.join(prov_dir, '{0}.p'.format(node['name']))
     with salt.utils.files.fopen(path, 'w') as fh_:
-        msgpack.dump(node, fh_)
+        salt.utils.msgpack.dump(node, fh_)
 
 
 def missing_node_cache(prov_dir, node_list, provider, opts):
@@ -2992,7 +2993,7 @@ def diff_node_cache(prov_dir, node, new_data, opts):
 
     with salt.utils.files.fopen(path, 'r') as fh_:
         try:
-            cache_data = salt.utils.data.decode(msgpack.load(fh_))
+            cache_data = salt.utils.data.decode(salt.utils.msgpack.load(fh_))
         except ValueError:
             log.warning('Cache for %s was corrupt: Deleting', node)
             cache_data = {}
diff --git a/salt/utils/http.py b/salt/utils/http.py
index a70dbbc29194..b68da7ecff0d 100644
--- a/salt/utils/http.py
+++ b/salt/utils/http.py
@@ -82,7 +82,7 @@
     HAS_REQUESTS = False
 
 try:
-    import msgpack
+    import salt.utils.msgpack
     HAS_MSGPACK = True
 except ImportError:
     HAS_MSGPACK = False
@@ -273,12 +273,12 @@ def query(url,
     # contain expirations, they can't be stored in a proper cookie jar.
     if os.path.isfile(session_cookie_jar):
         with salt.utils.files.fopen(session_cookie_jar, 'rb') as fh_:
-            session_cookies = msgpack.load(fh_)
+            session_cookies = salt.utils.msgpack.load(fh_)
             if isinstance(session_cookies, dict):
                 header_dict.update(session_cookies)
     else:
         with salt.utils.files.fopen(session_cookie_jar, 'wb') as fh_:
-            msgpack.dump('', fh_)
+            salt.utils.msgpack.dump('', fh_)
 
     for header in header_list:
         comps = header.split(':')
@@ -614,9 +614,9 @@ def query(url,
         with salt.utils.files.fopen(session_cookie_jar, 'wb') as fh_:
             session_cookies = result_headers.get('set-cookie', None)
             if session_cookies is not None:
-                msgpack.dump({'Cookie': session_cookies}, fh_)
+                salt.utils.msgpack.dump({'Cookie': session_cookies}, fh_)
             else:
-                msgpack.dump('', fh_)
+                salt.utils.msgpack.dump('', fh_)
 
     if status is True:
         ret['status'] = result_status_code
diff --git a/salt/utils/json.py b/salt/utils/json.py
index 50fe09f314bf..7d10d4ad39c4 100644
--- a/salt/utils/json.py
+++ b/salt/utils/json.py
@@ -100,6 +100,7 @@ def dump(obj, fp, **kwargs):
     json_module = kwargs.pop('_json_module', json)
     if 'ensure_ascii' not in kwargs:
         kwargs['ensure_ascii'] = False
+    obj = salt.utils.thread_local_proxy.ThreadLocalProxy.unproxy_recursive(obj)
     obj = salt.utils.data.encode(obj)
     return json.dump(obj, fp, **kwargs)  # future lint: blacklisted-function
 
@@ -121,5 +122,6 @@ def dumps(obj, **kwargs):
     json_module = kwargs.pop('_json_module', json)
     if 'ensure_ascii' not in kwargs:
         kwargs['ensure_ascii'] = False
+    obj = salt.utils.thread_local_proxy.ThreadLocalProxy.unproxy_recursive(obj)
     obj = salt.utils.data.encode(obj)
     return json_module.dumps(obj, **kwargs)  # future lint: blacklisted-function
diff --git a/salt/utils/msgpack.py b/salt/utils/msgpack.py
new file mode 100644
index 000000000000..5bf3b4949c28
--- /dev/null
+++ b/salt/utils/msgpack.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+'''
+Functions to work with MessagePack
+'''
+
+from __future__ import absolute_import
+
+# Import Python libs
+try:
+    # Attempt to import msgpack
+    import msgpack
+except ImportError:
+    # Fall back to msgpack_pure
+    import msgpack_pure as msgpack  # pylint: disable=import-error
+
+# Import Salt libs
+import salt.utils.thread_local_proxy
+
+
+def pack(o, stream, **kwargs):
+    '''
+    .. versionadded:: Oxygen
+
+    Wraps msgpack.pack and ensures that the passed object is unwrapped if it is
+    a proxy.
+
+    By default, this function uses the msgpack module and falls back to
+    msgpack_pure if msgpack is not available. You can pass an alternate
+    msgpack module using the _msgpack_module argument.
+    '''
+    msgpack_module = kwargs.pop('_msgpack_module', msgpack)
+    o = salt.utils.thread_local_proxy.ThreadLocalProxy.unproxy_recursive(o)
+    return msgpack_module.pack(o, stream, **kwargs)
+
+
+def packb(o, **kwargs):
+    '''
+    .. versionadded:: Oxygen
+
+    Wraps msgpack.packb and ensures that the passed object is unwrapped if it
+    is a proxy.
+
+    By default, this function uses the msgpack module and falls back to
+    msgpack_pure if msgpack is not available. You can pass an alternate
+    msgpack module using the _msgpack_module argument.
+    '''
+    msgpack_module = kwargs.pop('_msgpack_module', msgpack)
+    o = salt.utils.thread_local_proxy.ThreadLocalProxy.unproxy_recursive(o)
+    return msgpack_module.packb(o, **kwargs)
+
+
+def unpack(stream, **kwargs):
+    '''
+    .. versionadded:: Oxygen
+
+    Wraps msgpack.unpack.
+
+    By default, this function uses the msgpack module and falls back to
+    msgpack_pure if msgpack is not available. You can pass an alternate
+    msgpack module using the _msgpack_module argument.
+    '''
+    msgpack_module = kwargs.pop('_msgpack_module', msgpack)
+    return msgpack_module.unpack(stream, **kwargs)
+
+
+def unpackb(packed, **kwargs):
+    '''
+    .. versionadded:: Oxygen
+
+    Wraps msgpack.unpackb.
+
+    By default, this function uses the msgpack module and falls back to
+    msgpack_pure if msgpack is not available. You can pass an alternate
+    msgpack module using the _msgpack_module argument.
+    '''
+    msgpack_module = kwargs.pop('_msgpack_module', msgpack)
+    return msgpack_module.unpackb(packed, **kwargs)
+
+
+# Aliases for compatibility with simplejson/marshal/pickle.
+load = unpack
+loads = unpackb
+
+dump = pack
+dumps = packb
diff --git a/tests/integration/files/log_handlers/runtests_log_handler.py b/tests/integration/files/log_handlers/runtests_log_handler.py
index 82e1509f72f3..6a21ef54ceb1 100644
--- a/tests/integration/files/log_handlers/runtests_log_handler.py
+++ b/tests/integration/files/log_handlers/runtests_log_handler.py
@@ -19,10 +19,8 @@
 import threading
 from multiprocessing import Queue
 
-# Import 3rd-party libs
-import msgpack
-
 # Import Salt libs
+import salt.utils.msgpack
 from salt.ext import six
 import salt.log.setup
 
@@ -85,7 +83,8 @@ def process_queue(port, queue):
                 break
             # Just log everything, filtering will happen on the main process
             # logging handlers
-            sock.sendall(msgpack.dumps(record.__dict__, encoding='utf-8'))
+            sock.sendall(salt.utils.msgpack.dumps(record.__dict__,
+                                                  encoding='utf-8'))
         except (IOError, EOFError, KeyboardInterrupt, SystemExit):
             sock.shutdown(socket.SHUT_RDWR)
             sock.close()
diff --git a/tests/packdump.py b/tests/packdump.py
index 92ed79de29bc..5a230eed946f 100644
--- a/tests/packdump.py
+++ b/tests/packdump.py
@@ -9,8 +9,8 @@
 import sys
 import pprint
 
-# Import third party libs
-import msgpack
+# Import Salt libs
+import salt.utils.msgpack
 
 
 def dump(path):
@@ -21,7 +21,7 @@ def dump(path):
         print('Not a file')
         return
     with open(path, 'rb') as fp_:
-        data = msgpack.loads(fp_.read())
+        data = salt.utils.msgpack.loads(fp_.read())
     pprint.pprint(data)
 
 

From e42b489d3e70127455d9048704ce2d2171d66e30 Mon Sep 17 00:00:00 2001
From: Sebastian Marsching
Date: Wed, 17 Jan 2018 20:40:09 +0100
Subject: [PATCH 3/3] Fix lint problems.
---
 salt/key.py                     | 3 +--
 salt/log/handlers/fluent_mod.py | 1 +
 salt/state.py                   | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/salt/key.py b/salt/key.py
index 22e7ddd3c2e6..0404c2c6dcf1 100644
--- a/salt/key.py
+++ b/salt/key.py
@@ -1036,8 +1036,7 @@ def check_minion_cache(self, preserve_minions=False):
                     if ext == '.json':
                         data = salt.utils.json.load(fp_)
                     elif ext == '.msgpack':
-                        data = salt.utils.msgpack.load(fp_,
-                                                       _msgpack_module=msgpack)
+                        data = salt.utils.msgpack.load(fp_)
                     role = salt.utils.stringutils.to_unicode(data['role'])
                     if role not in minions:
                         os.remove(path)
diff --git a/salt/log/handlers/fluent_mod.py b/salt/log/handlers/fluent_mod.py
index 1ede2c403894..651db6ae299e 100644
--- a/salt/log/handlers/fluent_mod.py
+++ b/salt/log/handlers/fluent_mod.py
@@ -96,6 +96,7 @@
 
 try:
     # Attempt to import msgpack
+    import msgpack
     import salt.utils.msgpack
     # There is a serialization issue on ARM and potentially other platforms
     # for some msgpack bindings, check for it
diff --git a/salt/state.py b/salt/state.py
index f2ba8a21197c..e40b7ba75449 100644
--- a/salt/state.py
+++ b/salt/state.py
@@ -2221,7 +2221,7 @@ def check_pause(self, low):
                 with salt.utils.files.fopen(pause_path, 'rb') as fp_:
                     try:
                         pdat = salt.utils.msgpack.loads(
-                            fp_.read(),_msgpack_module=msgpack)
+                            fp_.read(), _msgpack_module=msgpack)
                     except msgpack.UnpackValueError:
                         # Reading race condition
                         if tries > 10:
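
The snippet below is an illustrative usage sketch, not part of the patch series above. It assumes patches 1/3 through 3/3 have been applied and shows the intended behaviour of the new salt.utils.msgpack wrapper: a value still wrapped in a ThreadLocalProxy is unwrapped before serialization, and the _msgpack_module keyword selects an explicit msgpack implementation. The example pillar payload is made up for illustration.

# Usage sketch (assumes the patches above are applied); the payload is hypothetical.
import msgpack

import salt.utils.msgpack
from salt.utils.thread_local_proxy import ThreadLocalProxy

# A proxy like the ones the loader injects for globals such as __pillar__;
# the second argument enables the shared fallback (see _inject_into_mod).
pillar = ThreadLocalProxy({'roles': ['web']}, True)

# dumps (an alias for packb) unwraps the proxy before packing, so the
# serialized bytes describe the plain dictionary, not the proxy object.
packed = salt.utils.msgpack.dumps(pillar)

# _msgpack_module explicitly selects the implementation; here it is the same
# module the wrapper would pick by default, so the output is identical.
assert packed == salt.utils.msgpack.dumps(pillar, _msgpack_module=msgpack)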