Skip to content

Commit

Permalink
Merge pull request #1313 from anuj-ssharma/master
Browse files Browse the repository at this point in the history
Add basic auth for webui
  • Loading branch information
heyman authored Apr 8, 2020
2 parents aa16edc + d744e75 commit eecfb12
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 11 deletions.
8 changes: 7 additions & 1 deletion locust/argument_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,13 @@ def setup_parser_arguments(parser):
action='store_true',
help="Disable the web interface, and instead start the load test immediately. Requires -c and -t to be specified."
)
web_ui_group.add_argument(
'--web-auth',
type=str,
dest='web_auth',
default=None,
help='Turn on Basic Auth for the web interface. Should be supplied in the following format: username:password'
)

master_group = parser.add_argument_group(
"Master options",
Expand Down Expand Up @@ -358,7 +365,6 @@ def setup_parser_arguments(parser):
help="Optionally specify which Locust classes that should be used (available Locust classes can be listed with -l or --list)",
)


def get_parser(default_config_files=DEFAULT_CONFIG_FILES):
# get a parser that is only able to parse the -f argument
parser = get_empty_argument_parser(add_help=True, default_config_files=default_config_files)
Expand Down
9 changes: 8 additions & 1 deletion locust/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,11 @@ class RPCError(Exception):
Exception that shows bad or broken network.
When raised from zmqrpc, RPC should be reestablished.
"""
"""

class AuthCredentialsError(ValueError):
"""
Exception when the auth credentials provided
are not in the correct format
"""
pass
10 changes: 8 additions & 2 deletions locust/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
stats_printer, stats_writer, write_csv_files)
from .util.timespan import parse_timespan
from .web import WebUI
from .exception import AuthCredentialsError

_internals = [Locust, HttpLocust]
version = locust.__version__
Expand Down Expand Up @@ -227,8 +228,13 @@ def timelimit_stop():
if not options.headless and not options.worker:
# spawn web greenlet
logger.info("Starting web monitor at http://%s:%s" % (options.web_host or "*", options.web_port))
web_ui = WebUI(environment=environment)
main_greenlet = gevent.spawn(web_ui.start, host=options.web_host, port=options.web_port)
try:
web_ui = WebUI(environment=environment, auth_credentials=options.web_auth)
except AuthCredentialsError:
logger.error("Credentials supplied with --web-auth should have the format: username:password")
sys.exit(1)
else:
main_greenlet = gevent.spawn(web_ui.start, host=options.web_host, port=options.web_port)
else:
web_ui = None

Expand Down
10 changes: 9 additions & 1 deletion locust/test/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ def test_parameter_parsing(self):
self.assertEqual(options.locustfile, 'locustfile_from_env')
self.assertEqual(options.host, 'host_from_args') # overridden

def test_web_auth(self):
args = [
"--web-auth", "hello:bye"
]
opts = self.parser.parse_args(args)
self.assertEqual(opts.web_auth, "hello:bye")



class TestArgumentParser(LocustTestCase):
def test_parse_options(self):
Expand Down Expand Up @@ -140,4 +148,4 @@ def _(parser, **kw):
out.seek(0)
stdout = out.read()
self.assertIn("Custom boolean flag", stdout)
self.assertIn("Custom string arg", stdout)
self.assertIn("Custom string arg", stdout)
37 changes: 34 additions & 3 deletions locust/test/test_web.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import gevent
import requests
from flask_basicauth import BasicAuth

from locust import constant
from locust.argument_parser import get_parser
Expand All @@ -21,15 +22,15 @@
class TestWebUI(LocustTestCase):
def setUp(self):
super(TestWebUI, self).setUp()

parser = get_parser(default_config_files=[])
self.environment.options = parser.parse_args([])
self.runner = LocustRunner(self.environment, [])
self.stats = self.runner.stats

self.web_ui = WebUI(self.environment)
self.web_ui.app.view_functions["request_stats"].clear_cache()

self.web_ui.app.config["BASIC_AUTH_ENABLED"] = False
gevent.spawn(lambda: self.web_ui.start("127.0.0.1", 0))
gevent.sleep(0.01)
self.web_port = self.web_ui.server.server_port
Expand Down Expand Up @@ -252,3 +253,33 @@ def my_task(self):
)
self.assertEqual(200, response.status_code)
self.assertIn("Step Load Mode", response.text)


class TestWebUIAuth(LocustTestCase):
def setUp(self):
super(TestWebUIAuth, self).setUp()

parser = get_parser(default_config_files=[])
self.environment.options = parser.parse_args(["--web-auth", "john:doe"])
self.runner = LocustRunner(self.environment, [])
self.stats = self.runner.stats
self.web_ui = WebUI(self.environment, self.environment.options.web_auth)
self.web_ui.app.view_functions["request_stats"].clear_cache()
gevent.spawn(lambda: self.web_ui.start("127.0.0.1", 0))
gevent.sleep(0.01)
self.web_port = self.web_ui.server.server_port

def tearDown(self):
super(TestWebUIAuth, self).tearDown()
self.web_ui.stop()
self.runner.quit()

def test_index_with_basic_auth_enabled_correct_credentials(self):
self.assertEqual(200, requests.get("http://127.0.0.1:%i/?ele=phino" % self.web_port, auth=('john', 'doe')).status_code)

def test_index_with_basic_auth_enabled_incorrect_credentials(self):
self.assertEqual(401, requests.get("http://127.0.0.1:%i/?ele=phino" % self.web_port,
auth=('john', 'invalid')).status_code)

def test_index_with_basic_auth_enabled_blank_credentials(self):
self.assertEqual(401, requests.get("http://127.0.0.1:%i/?ele=phino" % self.web_port).status_code)
48 changes: 45 additions & 3 deletions locust/web.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
import logging
import os.path
from collections import defaultdict
from functools import wraps
from itertools import chain
from time import time
from flask_basicauth import BasicAuth
from .exception import AuthCredentialsError

try:
# >= Py3.2
Expand Down Expand Up @@ -36,17 +39,35 @@

class WebUI:
server = None
"""Refernce to pyqsgi.WSGIServer once it's started"""
"""Reference to pyqsgi.WSGIServer once it's started"""

def __init__(self, environment):
def __init__(self, environment, auth_credentials=None):
"""
If auth_credentials is provided, it will enable basic auth with all the routes protected by default.
Should be supplied in the format: "user:pass".
"""
environment.web_ui = self
self.environment = environment
app = Flask(__name__)
self.app = app
app.debug = True
app.root_path = os.path.dirname(os.path.abspath(__file__))

self.app.config["BASIC_AUTH_ENABLED"] = False
self.auth = None

if auth_credentials is not None:
credentials = auth_credentials.split(':')
if len(credentials) == 2:
self.app.config["BASIC_AUTH_USERNAME"] = credentials[0]
self.app.config["BASIC_AUTH_PASSWORD"] = credentials[1]
self.app.config["BASIC_AUTH_ENABLED"] = True
self.auth = BasicAuth()
self.auth.init_app(self.app)
else:
raise AuthCredentialsError("Invalid auth_credentials. It should be a string in the following format: 'user.pass'")

@app.route('/')
@self.auth_required_if_enabled
def index():
if not environment.runner:
return make_response("Error: Locust Environment does not have any runner", 500)
Expand Down Expand Up @@ -84,6 +105,7 @@ def index():
)

@app.route('/swarm', methods=["POST"])
@self.auth_required_if_enabled
def swarm():
assert request.method == "POST"
locust_count = int(request.form["locust_count"])
Expand All @@ -101,17 +123,20 @@ def swarm():
return jsonify({'success': True, 'message': 'Swarming started', 'host': environment.host})

@app.route('/stop')
@self.auth_required_if_enabled
def stop():
environment.runner.stop()
return jsonify({'success':True, 'message': 'Test stopped'})

@app.route("/stats/reset")
@self.auth_required_if_enabled
def reset_stats():
environment.runner.stats.reset_all()
environment.runner.exceptions = {}
return "ok"

@app.route("/stats/requests/csv")
@self.auth_required_if_enabled
def request_stats_csv():
response = make_response(requests_csv(self.environment.runner.stats))
file_name = "requests_{0}.csv".format(time())
Expand All @@ -121,6 +146,7 @@ def request_stats_csv():
return response

@app.route("/stats/failures/csv")
@self.auth_required_if_enabled
def failures_stats_csv():
response = make_response(failures_csv(self.environment.runner.stats))
file_name = "failures_{0}.csv".format(time())
Expand All @@ -130,6 +156,7 @@ def failures_stats_csv():
return response

@app.route('/stats/requests')
@self.auth_required_if_enabled
@memoize(timeout=DEFAULT_CACHE_TIME, dynamic_timeout=True)
def request_stats():
stats = []
Expand Down Expand Up @@ -184,6 +211,7 @@ def request_stats():
return jsonify(report)

@app.route("/exceptions")
@self.auth_required_if_enabled
def exceptions():
return jsonify({
'exceptions': [
Expand All @@ -197,6 +225,7 @@ def exceptions():
})

@app.route("/exceptions/csv")
@self.auth_required_if_enabled
def exceptions_csv():
data = StringIO()
writer = csv.writer(data)
Expand All @@ -219,3 +248,16 @@ def start(self, host, port):

def stop(self):
self.server.stop()

def auth_required_if_enabled(self, view_func):
@wraps(view_func)
def wrapper(*args, **kwargs):
if self.app.config["BASIC_AUTH_ENABLED"]:
if self.auth.authenticate():
return view_func(*args, **kwargs)
else:
return self.auth.challenge()
else:
return view_func(*args, **kwargs)

return wrapper
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
"geventhttpclient-wheels==1.3.1.dev3",
"ConfigArgParse>=1.0",
"psutil>=5.6.7",
"Flask-BasicAuth==0.2.0"
],
test_suite="locust.test",
tests_require=['mock'],
Expand Down

0 comments on commit eecfb12

Please sign in to comment.