diff --git a/locust/argument_parser.py b/locust/argument_parser.py index 231c0c8815..001444f704 100644 --- a/locust/argument_parser.py +++ b/locust/argument_parser.py @@ -173,6 +173,13 @@ def setup_parser_arguments(parser): action='store_true', help="Disable the web interface, and instead start the load test immediately. Requires -c and -t to be specified." ) + web_ui_group.add_argument( + '--web-auth', + type=str, + dest='web_auth', + default=None, + help='Turn on Basic Auth for the web interface. Should be supplied in the following format: username:password' + ) master_group = parser.add_argument_group( "Master options", @@ -358,7 +365,6 @@ def setup_parser_arguments(parser): help="Optionally specify which Locust classes that should be used (available Locust classes can be listed with -l or --list)", ) - def get_parser(default_config_files=DEFAULT_CONFIG_FILES): # get a parser that is only able to parse the -f argument parser = get_empty_argument_parser(add_help=True, default_config_files=default_config_files) diff --git a/locust/exception.py b/locust/exception.py index a6925a941c..4004eac70f 100644 --- a/locust/exception.py +++ b/locust/exception.py @@ -43,4 +43,11 @@ class RPCError(Exception): Exception that shows bad or broken network. When raised from zmqrpc, RPC should be reestablished. - """ \ No newline at end of file + """ + +class AuthCredentialsError(ValueError): + """ + Exception when the auth credentials provided + are not in the correct format + """ + pass \ No newline at end of file diff --git a/locust/main.py b/locust/main.py index 5a53b2aa78..d6cd2c7fa0 100644 --- a/locust/main.py +++ b/locust/main.py @@ -22,6 +22,7 @@ stats_printer, stats_writer, write_csv_files) from .util.timespan import parse_timespan from .web import WebUI +from .exception import AuthCredentialsError _internals = [Locust, HttpLocust] version = locust.__version__ @@ -227,8 +228,13 @@ def timelimit_stop(): if not options.headless and not options.worker: # spawn web greenlet logger.info("Starting web monitor at http://%s:%s" % (options.web_host or "*", options.web_port)) - web_ui = WebUI(environment=environment) - main_greenlet = gevent.spawn(web_ui.start, host=options.web_host, port=options.web_port) + try: + web_ui = WebUI(environment=environment, auth_credentials=options.web_auth) + except AuthCredentialsError: + logger.error("Credentials supplied with --web-auth should have the format: username:password") + sys.exit(1) + else: + main_greenlet = gevent.spawn(web_ui.start, host=options.web_host, port=options.web_port) else: web_ui = None diff --git a/locust/test/test_parser.py b/locust/test/test_parser.py index 639044aca4..3e1ffacfdf 100644 --- a/locust/test/test_parser.py +++ b/locust/test/test_parser.py @@ -45,6 +45,14 @@ def test_parameter_parsing(self): self.assertEqual(options.locustfile, 'locustfile_from_env') self.assertEqual(options.host, 'host_from_args') # overridden + def test_web_auth(self): + args = [ + "--web-auth", "hello:bye" + ] + opts = self.parser.parse_args(args) + self.assertEqual(opts.web_auth, "hello:bye") + + class TestArgumentParser(LocustTestCase): def test_parse_options(self): @@ -140,4 +148,4 @@ def _(parser, **kw): out.seek(0) stdout = out.read() self.assertIn("Custom boolean flag", stdout) - self.assertIn("Custom string arg", stdout) + self.assertIn("Custom string arg", stdout) \ No newline at end of file diff --git a/locust/test/test_web.py b/locust/test/test_web.py index 2bdfb9d095..b64f834c7a 100644 --- a/locust/test/test_web.py +++ b/locust/test/test_web.py @@ -7,6 +7,7 @@ import gevent import requests +from flask_basicauth import BasicAuth from locust import constant from locust.argument_parser import get_parser @@ -21,15 +22,15 @@ class TestWebUI(LocustTestCase): def setUp(self): super(TestWebUI, self).setUp() - + parser = get_parser(default_config_files=[]) self.environment.options = parser.parse_args([]) self.runner = LocustRunner(self.environment, []) self.stats = self.runner.stats - + self.web_ui = WebUI(self.environment) self.web_ui.app.view_functions["request_stats"].clear_cache() - + self.web_ui.app.config["BASIC_AUTH_ENABLED"] = False gevent.spawn(lambda: self.web_ui.start("127.0.0.1", 0)) gevent.sleep(0.01) self.web_port = self.web_ui.server.server_port @@ -252,3 +253,33 @@ def my_task(self): ) self.assertEqual(200, response.status_code) self.assertIn("Step Load Mode", response.text) + + +class TestWebUIAuth(LocustTestCase): + def setUp(self): + super(TestWebUIAuth, self).setUp() + + parser = get_parser(default_config_files=[]) + self.environment.options = parser.parse_args(["--web-auth", "john:doe"]) + self.runner = LocustRunner(self.environment, []) + self.stats = self.runner.stats + self.web_ui = WebUI(self.environment, self.environment.options.web_auth) + self.web_ui.app.view_functions["request_stats"].clear_cache() + gevent.spawn(lambda: self.web_ui.start("127.0.0.1", 0)) + gevent.sleep(0.01) + self.web_port = self.web_ui.server.server_port + + def tearDown(self): + super(TestWebUIAuth, self).tearDown() + self.web_ui.stop() + self.runner.quit() + + def test_index_with_basic_auth_enabled_correct_credentials(self): + self.assertEqual(200, requests.get("http://127.0.0.1:%i/?ele=phino" % self.web_port, auth=('john', 'doe')).status_code) + + def test_index_with_basic_auth_enabled_incorrect_credentials(self): + self.assertEqual(401, requests.get("http://127.0.0.1:%i/?ele=phino" % self.web_port, + auth=('john', 'invalid')).status_code) + + def test_index_with_basic_auth_enabled_blank_credentials(self): + self.assertEqual(401, requests.get("http://127.0.0.1:%i/?ele=phino" % self.web_port).status_code) \ No newline at end of file diff --git a/locust/web.py b/locust/web.py index 3c468a8a36..190675138c 100644 --- a/locust/web.py +++ b/locust/web.py @@ -5,8 +5,11 @@ import logging import os.path from collections import defaultdict +from functools import wraps from itertools import chain from time import time +from flask_basicauth import BasicAuth +from .exception import AuthCredentialsError try: # >= Py3.2 @@ -36,17 +39,35 @@ class WebUI: server = None - """Refernce to pyqsgi.WSGIServer once it's started""" + """Reference to pyqsgi.WSGIServer once it's started""" - def __init__(self, environment): + def __init__(self, environment, auth_credentials=None): + """ + If auth_credentials is provided, it will enable basic auth with all the routes protected by default. + Should be supplied in the format: "user:pass". + """ environment.web_ui = self self.environment = environment app = Flask(__name__) self.app = app app.debug = True app.root_path = os.path.dirname(os.path.abspath(__file__)) - + self.app.config["BASIC_AUTH_ENABLED"] = False + self.auth = None + + if auth_credentials is not None: + credentials = auth_credentials.split(':') + if len(credentials) == 2: + self.app.config["BASIC_AUTH_USERNAME"] = credentials[0] + self.app.config["BASIC_AUTH_PASSWORD"] = credentials[1] + self.app.config["BASIC_AUTH_ENABLED"] = True + self.auth = BasicAuth() + self.auth.init_app(self.app) + else: + raise AuthCredentialsError("Invalid auth_credentials. It should be a string in the following format: 'user.pass'") + @app.route('/') + @self.auth_required_if_enabled def index(): if not environment.runner: return make_response("Error: Locust Environment does not have any runner", 500) @@ -84,6 +105,7 @@ def index(): ) @app.route('/swarm', methods=["POST"]) + @self.auth_required_if_enabled def swarm(): assert request.method == "POST" locust_count = int(request.form["locust_count"]) @@ -101,17 +123,20 @@ def swarm(): return jsonify({'success': True, 'message': 'Swarming started', 'host': environment.host}) @app.route('/stop') + @self.auth_required_if_enabled def stop(): environment.runner.stop() return jsonify({'success':True, 'message': 'Test stopped'}) @app.route("/stats/reset") + @self.auth_required_if_enabled def reset_stats(): environment.runner.stats.reset_all() environment.runner.exceptions = {} return "ok" @app.route("/stats/requests/csv") + @self.auth_required_if_enabled def request_stats_csv(): response = make_response(requests_csv(self.environment.runner.stats)) file_name = "requests_{0}.csv".format(time()) @@ -121,6 +146,7 @@ def request_stats_csv(): return response @app.route("/stats/failures/csv") + @self.auth_required_if_enabled def failures_stats_csv(): response = make_response(failures_csv(self.environment.runner.stats)) file_name = "failures_{0}.csv".format(time()) @@ -130,6 +156,7 @@ def failures_stats_csv(): return response @app.route('/stats/requests') + @self.auth_required_if_enabled @memoize(timeout=DEFAULT_CACHE_TIME, dynamic_timeout=True) def request_stats(): stats = [] @@ -184,6 +211,7 @@ def request_stats(): return jsonify(report) @app.route("/exceptions") + @self.auth_required_if_enabled def exceptions(): return jsonify({ 'exceptions': [ @@ -197,6 +225,7 @@ def exceptions(): }) @app.route("/exceptions/csv") + @self.auth_required_if_enabled def exceptions_csv(): data = StringIO() writer = csv.writer(data) @@ -219,3 +248,16 @@ def start(self, host, port): def stop(self): self.server.stop() + + def auth_required_if_enabled(self, view_func): + @wraps(view_func) + def wrapper(*args, **kwargs): + if self.app.config["BASIC_AUTH_ENABLED"]: + if self.auth.authenticate(): + return view_func(*args, **kwargs) + else: + return self.auth.challenge() + else: + return view_func(*args, **kwargs) + + return wrapper diff --git a/setup.py b/setup.py index 679e28d387..388a627356 100644 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ "geventhttpclient-wheels==1.3.1.dev3", "ConfigArgParse>=1.0", "psutil>=5.6.7", + "Flask-BasicAuth==0.2.0" ], test_suite="locust.test", tests_require=['mock'],