Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

switch to using a thread for the live server #112

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 15 additions & 31 deletions flask_testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import gc
import multiprocessing
import socket
import threading
import time

try:
Expand All @@ -32,7 +33,7 @@
# Python 2 urlparse fallback
from urlparse import urlparse, urljoin

from werkzeug import cached_property
from werkzeug import cached_property, serving

# Use Flask's preferred JSON module so that our runtime behavior matches.
from flask import json_available, templating, template_rendered
Expand Down Expand Up @@ -436,7 +437,6 @@ def __call__(self, result=None):
self.app = self.create_app()

self._configured_port = self.app.config.get('LIVESERVER_PORT', 5000)
self._port_value = multiprocessing.Value('i', self._configured_port)

# We need to create a context in order for extensions to catch up
self._ctx = self.app.test_request_context()
Expand All @@ -453,37 +453,16 @@ def get_server_url(self):
"""
Return the url of the test server
"""
return 'http://localhost:%s' % self._port_value.value
return 'http://localhost:%s' % self._port

def _spawn_live_server(self):
self._process = None
port_value = self._port_value

def worker(app, port):
# Based on solution: http://stackoverflow.com/a/27598916
# Monkey-patch the server_bind so we can determine the port bound by Flask.
# This handles the case where the port specified is `0`, which means that
# the OS chooses the port. This is the only known way (currently) of getting
# the port out of Flask once we call `run`.
original_socket_bind = socketserver.TCPServer.server_bind
def socket_bind_wrapper(self):
ret = original_socket_bind(self)

# Get the port and save it into the port_value, so the parent process
# can read it.
(_, port) = self.socket.getsockname()
port_value.value = port
socketserver.TCPServer.server_bind = original_socket_bind
return ret

socketserver.TCPServer.server_bind = socket_bind_wrapper
app.run(port=port, use_reloader=False)

self._process = multiprocessing.Process(
target=worker, args=(self.app, self._configured_port)
self._server = serving.make_server(
'localhost', self._configured_port, self.app,
)
(_, self._port) = self._server.socket.getsockname()

self._process.start()
self._thread = threading.Thread(target=self._server.serve_forever, args=())
self._thread.start()

# We must wait for the server to start listening, but give up
# after a specified maximum timeout
Expand Down Expand Up @@ -548,5 +527,10 @@ def _post_teardown(self):
del self._ctx

def _terminate_live_server(self):
if self._process:
self._process.terminate()
if self._server:
self._server.shutdown()
self._server = None

if self._thread:
self._thread.join()
self._thread = None
12 changes: 6 additions & 6 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,14 +210,14 @@ def test_assert_no_flashed_messages_fail(self):

class BaseTestLiveServer(LiveServerTestCase):

def test_server_process_is_spawned(self):
process = self._process
def test_server_thread_is_spawned(self):
thread = self._thread

# Check the process is spawned
self.assertNotEqual(process, None)
# Check the thread is spawned
self.assertNotEqual(thread, None)

# Check the process is alive
self.assertTrue(process.is_alive())
# Check the thread is alive
self.assertTrue(thread.is_alive())

def test_server_listening(self):
response = urlopen(self.get_server_url())
Expand Down