Skip to content

Commit

Permalink
Merge pull request #2194 from meln5674/feature/binary-wsclient
Browse files Browse the repository at this point in the history
Enable binary support for WSClient
  • Loading branch information
vadym1226 committed Feb 29, 2024
2 parents 356c6d2 + 0f7aa72 commit 9baf270
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 14 deletions.
13 changes: 11 additions & 2 deletions kubernetes/base/stream/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,18 @@ def _websocket_request(websocket_request, force_kwargs, api_method, *args, **kwa
except AttributeError:
configuration = api_client.config
prev_request = api_client.request
binary = kwargs.pop('binary', False)
try:
api_client.request = functools.partial(websocket_request, configuration)
return api_method(*args, **kwargs)
api_client.request = functools.partial(websocket_request, configuration, binary=binary)
out = api_method(*args, **kwargs)
# The api_client insists on converting this to a string using its representation, so we have
# to do this dance to strip it of the b' prefix and ' suffix, encode it byte-per-byte (latin1),
# escape all of the unicode \x*'s, then encode it back byte-by-byte
# However, if _preload_content=False is passed, then the entire WSClient is returned instead
# of a response, and we want to leave it alone
if binary and kwargs.get('_preload_content', True):
out = out[2:-1].encode('latin1').decode('unicode_escape').encode('latin1')
return out
finally:
api_client.request = prev_request

Expand Down
29 changes: 19 additions & 10 deletions kubernetes/base/stream/ws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
import six
import yaml


from six.moves.urllib.parse import urlencode, urlparse, urlunparse
from six import StringIO
from six import StringIO, BytesIO

from websocket import WebSocket, ABNF, enableTrace
from base64 import urlsafe_b64decode
Expand All @@ -48,7 +49,7 @@ def getvalue(self):


class WSClient:
def __init__(self, configuration, url, headers, capture_all):
def __init__(self, configuration, url, headers, capture_all, binary=False):
"""A websocket client with support for channels.
Exec command uses different channels for different streams. for
Expand All @@ -58,8 +59,10 @@ def __init__(self, configuration, url, headers, capture_all):
"""
self._connected = False
self._channels = {}
self.binary = binary
self.newline = '\n' if not self.binary else b'\n'
if capture_all:
self._all = StringIO()
self._all = StringIO() if not self.binary else BytesIO()
else:
self._all = _IgnoredIO()
self.sock = create_websocket(configuration, url, headers)
Expand Down Expand Up @@ -92,8 +95,8 @@ def readline_channel(self, channel, timeout=None):
while self.is_open() and time.time() - start < timeout:
if channel in self._channels:
data = self._channels[channel]
if "\n" in data:
index = data.find("\n")
if self.newline in data:
index = data.find(self.newline)
ret = data[:index]
data = data[index+1:]
if data:
Expand Down Expand Up @@ -197,10 +200,12 @@ def update(self, timeout=0):
return
elif op_code == ABNF.OPCODE_BINARY or op_code == ABNF.OPCODE_TEXT:
data = frame.data
if six.PY3:
if six.PY3 and not self.binary:
data = data.decode("utf-8", "replace")
if len(data) > 1:
channel = ord(data[0])
channel = data[0]
if six.PY3 and not self.binary:
channel = ord(channel)
data = data[1:]
if data:
if channel in [STDOUT_CHANNEL, STDERR_CHANNEL]:
Expand Down Expand Up @@ -518,13 +523,17 @@ def websocket_call(configuration, _method, url, **kwargs):
_request_timeout = kwargs.get("_request_timeout", 60)
_preload_content = kwargs.get("_preload_content", True)
capture_all = kwargs.get("capture_all", True)

binary = kwargs.get('binary', False)
try:
client = WSClient(configuration, url, headers, capture_all)
client = WSClient(configuration, url, headers, capture_all, binary=binary)
if not _preload_content:
return client
client.run_forever(timeout=_request_timeout)
return WSResponse('%s' % ''.join(client.read_all()))
all = client.read_all()
if binary:
return WSResponse(all)
else:
return WSResponse('%s' % ''.join(all))
except (Exception, KeyboardInterrupt, SystemExit) as e:
raise ApiException(status=0, reason=str(e))

Expand Down
45 changes: 43 additions & 2 deletions kubernetes/e2e_test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import unittest
import uuid
import six
import io
import gzip

from kubernetes.client import api_client
from kubernetes.client.api import core_v1_api
Expand Down Expand Up @@ -118,15 +120,28 @@ def test_pod_apis(self):
command=exec_command,
stderr=False, stdin=False,
stdout=True, tty=False)
print('EXEC response : %s' % resp)
print('EXEC response : %s (%s)' % (repr(resp), type(resp)))
self.assertIsInstance(resp, str)
self.assertEqual(3, len(resp.splitlines()))

exec_command = ['/bin/sh',
'-c',
'echo -n "This is a test string" | gzip']
resp = stream(api.connect_get_namespaced_pod_exec, name, 'default',
command=exec_command,
stderr=False, stdin=False,
stdout=True, tty=False,
binary=True)
print('EXEC response : %s (%s)' % (repr(resp), type(resp)))
self.assertIsInstance(resp, bytes)
self.assertEqual("This is a test string", gzip.decompress(resp).decode('utf-8'))

exec_command = 'uptime'
resp = stream(api.connect_post_namespaced_pod_exec, name, 'default',
command=exec_command,
stderr=False, stdin=False,
stdout=True, tty=False)
print('EXEC response : %s' % resp)
print('EXEC response : %s' % repr(resp))
self.assertEqual(1, len(resp.splitlines()))

resp = stream(api.connect_post_namespaced_pod_exec, name, 'default',
Expand Down Expand Up @@ -154,6 +169,32 @@ def test_pod_apis(self):
resp.update(timeout=5)
self.assertFalse(resp.is_open())

resp = stream(api.connect_post_namespaced_pod_exec, name, 'default',
command='/bin/sh',
stderr=True, stdin=True,
stdout=True, tty=False,
binary=True,
_preload_content=False)
resp.write_stdin(b"echo test string 1\n")
line = resp.readline_stdout(timeout=5)
self.assertFalse(resp.peek_stderr())
self.assertEqual(b"test string 1", line)
resp.write_stdin(b"echo test string 2 >&2\n")
line = resp.readline_stderr(timeout=5)
self.assertFalse(resp.peek_stdout())
self.assertEqual(b"test string 2", line)
resp.write_stdin(b"exit\n")
resp.update(timeout=5)
while True:
line = resp.read_channel(ERROR_CHANNEL)
if len(line) != 0:
break
time.sleep(1)
status = json.loads(line)
self.assertEqual(status['status'], 'Success')
resp.update(timeout=5)
self.assertFalse(resp.is_open())

number_of_pods = len(api.list_pod_for_all_namespaces().items)
self.assertTrue(number_of_pods > 0)

Expand Down

0 comments on commit 9baf270

Please sign in to comment.