Skip to content

Commit

Permalink
Merge pull request #30 from tomplus/feat/watch-forever
Browse files Browse the repository at this point in the history
feat: watch work forever if timeout is not specified
  • Loading branch information
tomplus authored Jul 28, 2018
2 parents 12e3374 + ef5152a commit e375f02
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 23 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pip install -r test-requirements.txt
You can run the style checks and tests with

```bash
flake8 && isort -c
flake8 kubernetes_asyncio/
isort --diff --recursive kubernetes_asyncio/
nosetests
```
Binary file not shown.
67 changes: 47 additions & 20 deletions kubernetes_asyncio/watch/watch.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import json
import pydoc
from functools import partial
Expand Down Expand Up @@ -48,6 +49,7 @@ def __init__(self, return_type=None):
self._raw_return_type = return_type
self._stop = False
self._api_client = client.ApiClient()
self.resource_version = 0

def stop(self):
self._stop = True
Expand Down Expand Up @@ -86,6 +88,19 @@ def unmarshal_event(self, data: str, response_type):
response=SimpleNamespace(data=json.dumps(js['raw_object'])),
response_type=response_type
)

# decode and save resource_version to continue watching
if hasattr(js['object'], 'metadata'):
self.resource_version = js['object'].metadata.resource_version

# For custom objects that we don't have model defined, json
# deserialization results in dictionary
elif (isinstance(js['object'], dict) and
'metadata' in js['object'] and
'resourceVersion' in js['object']['metadata']):

self.resource_version = js['object']['metadata']['resourceVersion']

return js

def __aiter__(self):
Expand All @@ -95,26 +110,38 @@ async def __anext__(self):
return await self.next()

async def next(self):
# Set the response object to the user supplied function (eg
# `list_namespaced_pods`) if this is the first iteration.
if self.resp is None:
self.resp = await self.func()

# Abort at the current iteration if the user has called `stop` on this
# stream instance.
if self._stop:
raise StopAsyncIteration

# Fetch the next K8s response.
line = await self.resp.content.readline()
line = line.decode('utf8')

# Stop the iterator if K8s sends an empty response. This happens when
# eg the supplied timeout has expired.
if line == '':
raise StopAsyncIteration

return self.unmarshal_event(line, self.return_type)

while 1:

# Set the response object to the user supplied function (eg
# `list_namespaced_pods`) if this is the first iteration.
if self.resp is None:
self.resp = await self.func()

# Abort at the current iteration if the user has called `stop` on this
# stream instance.
if self._stop:
raise StopAsyncIteration

# Fetch the next K8s response.
try:
line = await self.resp.content.readline()
except asyncio.TimeoutError:
if 'timeout_seconds' not in self.func.keywords:
self.resp = None
self.func.keywords['resource_version'] = self.resource_version
continue
else:
raise

line = line.decode('utf8')

# Stop the iterator if K8s sends an empty response. This happens when
# eg the supplied timeout has expired.
if line == '':
raise StopAsyncIteration

return self.unmarshal_event(line, self.return_type)

def stream(self, func, *args, **kwargs):
"""Watch an API resource and stream the result back via a generator.
Expand Down
48 changes: 46 additions & 2 deletions kubernetes_asyncio/watch/watch_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import json

from asynctest import CoroutineMock, Mock, TestCase
from asynctest import CoroutineMock, Mock, TestCase, call

import kubernetes_asyncio
from kubernetes_asyncio.watch import Watch
Expand All @@ -29,7 +30,8 @@ async def test_watch_with_decode(self):
{
"type": "ADDED",
"object": {
"metadata": {"name": "test{}".format(uid)},
"metadata": {"name": "test{}".format(uid),
"resourceVersion": str(uid)},
"spec": {}, "status": {}
}
}
Expand All @@ -49,6 +51,9 @@ async def test_watch_with_decode(self):
self.assertEqual("ADDED", e['type'])
# make sure decoder worked and we got a model with the right name
self.assertEqual("test%d" % count, e['object'].metadata.name)
# make sure decoder worked and updated Watch.resource_version
self.assertEqual(e['object'].metadata.resource_version, str(count))
self.assertEqual(watch.resource_version, str(count))

# Stop the watch. This must not return the next event which would
# be an AssertionError exception.
Expand Down Expand Up @@ -127,6 +132,19 @@ async def test_unmarshall_k8s_error_response(self):
self.assertEqual(ret['object'], k8s_err['object'])
self.assertEqual(ret['object'], k8s_err['object'])

def test_unmarshal_with_custom_object(self):
w = Watch()
event = w.unmarshal_event('{"type": "ADDED", "object": {"apiVersion":'
'"test.com/v1beta1","kind":"foo","metadata":'
'{"name": "bar", "resourceVersion": "1"}}}',
'object')
self.assertEqual("ADDED", event['type'])
# make sure decoder deserialized json into dictionary and updated
# Watch.resource_version
self.assertTrue(isinstance(event['object'], dict))
self.assertEqual("1", event['object']['metadata']['resourceVersion'])
self.assertEqual("1", w.resource_version)

async def test_watch_with_exception(self):
fake_resp = CoroutineMock()
fake_resp.content.readline = CoroutineMock()
Expand All @@ -140,6 +158,32 @@ async def test_watch_with_exception(self):
async for e in watch.stream(fake_api.get_namespaces, timeout_seconds=10): # noqa
pass

async def test_watch_timeout(self):
fake_resp = CoroutineMock()
fake_resp.content.readline = CoroutineMock()

mock_event = {"type": "ADDED",
"object": {"metadata": {"name": "test1555",
"resourceVersion": "1555"},
"spec": {},
"status": {}}}

fake_resp.content.readline.side_effect = [json.dumps(mock_event).encode('utf8'),
asyncio.TimeoutError(),
b""]

fake_api = Mock()
fake_api.get_namespaces = CoroutineMock(return_value=fake_resp)
fake_api.get_namespaces.__doc__ = ':return: V1NamespaceList'

watch = kubernetes_asyncio.watch.Watch()
async for e in watch.stream(fake_api.get_namespaces): # noqa
pass

fake_api.get_namespaces.assert_has_calls(
[call(_preload_content=False, watch=True),
call(_preload_content=False, watch=True, resource_version='1555')])


if __name__ == '__main__':
import asynctest
Expand Down

0 comments on commit e375f02

Please sign in to comment.