diff --git a/kubernetes/base/stream/stream.py b/kubernetes/base/stream/stream.py index 115a899b..e34dedfc 100644 --- a/kubernetes/base/stream/stream.py +++ b/kubernetes/base/stream/stream.py @@ -30,9 +30,18 @@ def _websocket_request(websocket_request, force_kwargs, api_method, *args, **kwa except AttributeError: configuration = api_client.config prev_request = api_client.request + binary = kwargs.pop('binary', False) try: - api_client.request = functools.partial(websocket_request, configuration) - return api_method(*args, **kwargs) + api_client.request = functools.partial(websocket_request, configuration, binary=binary) + out = api_method(*args, **kwargs) + # The api_client insists on converting this to a string using its representation, so we have + # to do this dance to strip it of the b' prefix and ' suffix, encode it byte-per-byte (latin1), + # escape all of the unicode \x*'s, then encode it back byte-by-byte + # However, if _preload_content=False is passed, then the entire WSClient is returned instead + # of a response, and we want to leave it alone + if binary and kwargs.get('_preload_content', True): + out = out[2:-1].encode('latin1').decode('unicode_escape').encode('latin1') + return out finally: api_client.request = prev_request diff --git a/kubernetes/base/stream/ws_client.py b/kubernetes/base/stream/ws_client.py index 5ec8e7d4..3c854ea7 100644 --- a/kubernetes/base/stream/ws_client.py +++ b/kubernetes/base/stream/ws_client.py @@ -26,8 +26,9 @@ import six import yaml + from six.moves.urllib.parse import urlencode, urlparse, urlunparse -from six import StringIO +from six import StringIO, BytesIO from websocket import WebSocket, ABNF, enableTrace from base64 import urlsafe_b64decode @@ -48,7 +49,7 @@ def getvalue(self): class WSClient: - def __init__(self, configuration, url, headers, capture_all): + def __init__(self, configuration, url, headers, capture_all, binary=False): """A websocket client with support for channels. Exec command uses different channels for different streams. for @@ -58,8 +59,10 @@ def __init__(self, configuration, url, headers, capture_all): """ self._connected = False self._channels = {} + self.binary = binary + self.newline = '\n' if not self.binary else b'\n' if capture_all: - self._all = StringIO() + self._all = StringIO() if not self.binary else BytesIO() else: self._all = _IgnoredIO() self.sock = create_websocket(configuration, url, headers) @@ -92,8 +95,8 @@ def readline_channel(self, channel, timeout=None): while self.is_open() and time.time() - start < timeout: if channel in self._channels: data = self._channels[channel] - if "\n" in data: - index = data.find("\n") + if self.newline in data: + index = data.find(self.newline) ret = data[:index] data = data[index+1:] if data: @@ -197,10 +200,12 @@ def update(self, timeout=0): return elif op_code == ABNF.OPCODE_BINARY or op_code == ABNF.OPCODE_TEXT: data = frame.data - if six.PY3: + if six.PY3 and not self.binary: data = data.decode("utf-8", "replace") if len(data) > 1: - channel = ord(data[0]) + channel = data[0] + if six.PY3 and not self.binary: + channel = ord(channel) data = data[1:] if data: if channel in [STDOUT_CHANNEL, STDERR_CHANNEL]: @@ -518,13 +523,17 @@ def websocket_call(configuration, _method, url, **kwargs): _request_timeout = kwargs.get("_request_timeout", 60) _preload_content = kwargs.get("_preload_content", True) capture_all = kwargs.get("capture_all", True) - + binary = kwargs.get('binary', False) try: - client = WSClient(configuration, url, headers, capture_all) + client = WSClient(configuration, url, headers, capture_all, binary=binary) if not _preload_content: return client client.run_forever(timeout=_request_timeout) - return WSResponse('%s' % ''.join(client.read_all())) + all = client.read_all() + if binary: + return WSResponse(all) + else: + return WSResponse('%s' % ''.join(all)) except (Exception, KeyboardInterrupt, SystemExit) as e: raise ApiException(status=0, reason=str(e)) diff --git a/kubernetes/e2e_test/test_client.py b/kubernetes/e2e_test/test_client.py index 2c28dedc..15689291 100644 --- a/kubernetes/e2e_test/test_client.py +++ b/kubernetes/e2e_test/test_client.py @@ -20,6 +20,8 @@ import unittest import uuid import six +import io +import gzip from kubernetes.client import api_client from kubernetes.client.api import core_v1_api @@ -118,15 +120,28 @@ def test_pod_apis(self): command=exec_command, stderr=False, stdin=False, stdout=True, tty=False) - print('EXEC response : %s' % resp) + print('EXEC response : %s (%s)' % (repr(resp), type(resp))) + self.assertIsInstance(resp, str) self.assertEqual(3, len(resp.splitlines())) + exec_command = ['/bin/sh', + '-c', + 'echo -n "This is a test string" | gzip'] + resp = stream(api.connect_get_namespaced_pod_exec, name, 'default', + command=exec_command, + stderr=False, stdin=False, + stdout=True, tty=False, + binary=True) + print('EXEC response : %s (%s)' % (repr(resp), type(resp))) + self.assertIsInstance(resp, bytes) + self.assertEqual("This is a test string", gzip.decompress(resp).decode('utf-8')) + exec_command = 'uptime' resp = stream(api.connect_post_namespaced_pod_exec, name, 'default', command=exec_command, stderr=False, stdin=False, stdout=True, tty=False) - print('EXEC response : %s' % resp) + print('EXEC response : %s' % repr(resp)) self.assertEqual(1, len(resp.splitlines())) resp = stream(api.connect_post_namespaced_pod_exec, name, 'default', @@ -154,6 +169,32 @@ def test_pod_apis(self): resp.update(timeout=5) self.assertFalse(resp.is_open()) + resp = stream(api.connect_post_namespaced_pod_exec, name, 'default', + command='/bin/sh', + stderr=True, stdin=True, + stdout=True, tty=False, + binary=True, + _preload_content=False) + resp.write_stdin(b"echo test string 1\n") + line = resp.readline_stdout(timeout=5) + self.assertFalse(resp.peek_stderr()) + self.assertEqual(b"test string 1", line) + resp.write_stdin(b"echo test string 2 >&2\n") + line = resp.readline_stderr(timeout=5) + self.assertFalse(resp.peek_stdout()) + self.assertEqual(b"test string 2", line) + resp.write_stdin(b"exit\n") + resp.update(timeout=5) + while True: + line = resp.read_channel(ERROR_CHANNEL) + if len(line) != 0: + break + time.sleep(1) + status = json.loads(line) + self.assertEqual(status['status'], 'Success') + resp.update(timeout=5) + self.assertFalse(resp.is_open()) + number_of_pods = len(api.list_pod_for_all_namespaces().items) self.assertTrue(number_of_pods > 0)