Skip to content

Commit

Permalink
Merge pull request #90 from EdinburghGenomics/rest_communication_mult…
Browse files Browse the repository at this point in the history
…iprocess

Rest communication secure multiple process
  • Loading branch information
Timothee Cezard authored Nov 22, 2018
2 parents 8f4fefb + 7c52ac9 commit 5c675a3
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 9 deletions.
29 changes: 20 additions & 9 deletions egcg_core/rest_communication.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import mimetypes
import os
from urllib.parse import urljoin
import requests
from multiprocessing import Lock
Expand All @@ -18,7 +19,7 @@ def __init__(self, auth=None, baseurl=None, retries=5):
self._baseurl = baseurl
self._auth = auth
self.retries = retries
self._session = None
self._sessions = {}
self.lock = Lock()

def begin_session(self):
Expand All @@ -36,9 +37,11 @@ def begin_session(self):

@property
def session(self):
if self._session is None:
self._session = self.begin_session()
return self._session
"""Create and return a session per PID so each sub-processes will use their own session"""
pid = os.getpid()
if pid not in self._sessions:
self._sessions[pid] = self.begin_session()
return self._sessions[pid]

@staticmethod
def serialise(queries):
Expand Down Expand Up @@ -137,23 +140,22 @@ def _req(self, method, url, quiet=False, **kwargs):
kwargs['data'] = kwargs.pop('json')

self.lock.acquire()
r = self.session.request(method, url, **kwargs)
try:
r = self.session.request(method, url, **kwargs)
finally:
self.lock.release()

kwargs.pop('files', None)
# e.g: 'POST <url> ({"some": "args"}) -> {"some": "content"}. Status code 201. Reason: CREATED
report = '%s %s (%s) -> %s. Status code %s. Reason: %s' % (
r.request.method, r.request.path_url, kwargs, r.content.decode('utf-8'), r.status_code, r.reason
)

if r.status_code in self.successful_statuses:
if not quiet:
self.debug(report)

self.lock.release()
return r
else:
self.error(report)
self.lock.release()
raise RestCommunicationError('Encountered a %s status code: %s' % (r.status_code, r.reason))

def get_content(self, endpoint, paginate=True, quiet=False, **query_args):
Expand Down Expand Up @@ -258,6 +260,15 @@ def post_or_patch(self, endpoint, input_json, id_field=None, update_lists=None):
else:
self.post_entry(endpoint, _payload)

def close(self):
for s in self._sessions.values():
s.close()

def __del__(self):
try:
self.close()
except ReferenceError:
pass

default = Communicator()
api_url = default.api_url
Expand Down
48 changes: 48 additions & 0 deletions tests/test_rest_communication.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import multiprocessing
import os
import json

import pytest
from requests import Session
from unittest.mock import Mock, patch, call

from requests.exceptions import SSLError

from tests import FakeRestResponse, TestEGCG
from egcg_core import rest_communication
from egcg_core.util import check_if_nested
Expand Down Expand Up @@ -30,6 +36,7 @@ def fake_request(method, url, **kwargs):


patched_request = patch.object(Session, 'request', side_effect=fake_request)
patched_failed_request = patch.object(Session, 'request', side_effect=SSLError('SSL error'))
auth = ('a_user', 'a_password')


Expand Down Expand Up @@ -97,6 +104,47 @@ def test_req(self, mocked_request):
assert json.loads(response.content.decode('utf-8')) == response.json() == test_nested_request_content
mocked_request.assert_called_with('METHOD', rest_url(test_endpoint), json=json_content)

@patched_failed_request
def test_failed_req(self, mocked_request):
json_content = ['some', {'test': 'json'}]
self.comm.lock = Mock()
self.comm.lock.acquire.assert_not_called()
self.comm.lock.release.assert_not_called()

with pytest.raises(SSLError):
_ = self.comm._req('METHOD', rest_url(test_endpoint), json=json_content)

self.comm.lock.acquire.assert_called_once()
self.comm.lock.release.assert_called_once() # exception raised, but lock still released

@patched_request
def test_multi_session(self, mocked_request):
json_content = ['some', {'test': 'json'}]
with patch('os.getpid', return_value=1):
_ = self.comm._req('METHOD', rest_url(test_endpoint), json=json_content)
with patch('os.getpid', return_value=2):
_ = self.comm._req('METHOD', rest_url(test_endpoint), json=json_content)
assert len(self.comm._sessions) == 2

@patched_request
def test_with_multiprocessing(self, mocked_request):
json_content = ['some', {'test': 'json'}]

def assert_request():
_ = self.comm._req('METHOD', rest_url(test_endpoint), json=json_content)
assert mocked_request.call_count == 2
assert len(self.comm._sessions) == 2

# initiate in the Session in the main thread
self.comm._req('METHOD', rest_url(test_endpoint), json=json_content)
procs = []
for i in range(10):
procs.append(multiprocessing.Process(target=assert_request))
for p in procs:
p.start()
for p in procs:
p.join()

@patch.object(Session, '__exit__')
@patch.object(Session, '__enter__')
@patched_request
Expand Down

0 comments on commit 5c675a3

Please sign in to comment.