From cd2300484037155cedb2b2dabfa8edb0fd77baed Mon Sep 17 00:00:00 2001 From: Raymond Penners Date: Tue, 20 Aug 2024 20:50:14 +0200 Subject: [PATCH] refactor(account): Login by code is now a stage --- allauth/account/adapter.py | 1 + .../account/internal/flows/login_by_code.py | 44 ++++++++++++------- allauth/account/models.py | 19 ++++---- allauth/account/stages.py | 14 ++++++ allauth/account/tests/test_login_by_code.py | 12 ++--- allauth/account/views.py | 15 +++++-- .../account/tests/test_login_by_code.py | 31 +++++++++++++ allauth/headless/account/views.py | 15 +++++-- allauth/headless/base/response.py | 15 ++++--- allauth/headless/constants.py | 4 +- allauth/headless/internal/restkit/views.py | 5 ++- .../openapi-specification/openapi.yaml | 9 +++- .../frontend/src/account/ConfirmLoginCode.js | 2 +- 13 files changed, 136 insertions(+), 50 deletions(-) diff --git a/allauth/account/adapter.py b/allauth/account/adapter.py index 67214fb639..0277e1bddc 100644 --- a/allauth/account/adapter.py +++ b/allauth/account/adapter.py @@ -710,6 +710,7 @@ def generate_emailconfirmation_key(self, email): def get_login_stages(self): ret = [] + ret.append("allauth.account.stages.LoginByCodeStage") ret.append("allauth.account.stages.EmailVerificationStage") if allauth_app_settings.MFA_ENABLED: ret.append("allauth.mfa.stages.AuthenticateStage") diff --git a/allauth/account/internal/flows/login_by_code.py b/allauth/account/internal/flows/login_by_code.py index e71344ad12..a8d7e09797 100644 --- a/allauth/account/internal/flows/login_by_code.py +++ b/allauth/account/internal/flows/login_by_code.py @@ -14,11 +14,11 @@ from allauth.account.models import Login -LOGIN_CODE_SESSION_KEY = "account_login_code" +LOGIN_CODE_STATE_KEY = "login_code" def request_login_code(request: HttpRequest, email: str) -> None: - from allauth.account.utils import filter_users_by_email + from allauth.account.utils import filter_users_by_email, stash_login adapter = get_adapter() users = filter_users_by_email(email, is_active=True, prefer_verified=True) @@ -28,6 +28,7 @@ def request_login_code(request: HttpRequest, email: str) -> None: "failed_attempts": 0, } if not users: + user = None send_unknown_account_mail(request, email) else: user = users[0] @@ -40,27 +41,29 @@ def request_login_code(request: HttpRequest, email: str) -> None: pending_login.update( {"code": code, "user_id": user._meta.pk.value_to_string(user)} ) - - request.session[LOGIN_CODE_SESSION_KEY] = pending_login + login = Login(user=user, email=email) + login.state[LOGIN_CODE_STATE_KEY] = pending_login + login.state["stages"] = {"current": "login_by_code"} adapter.add_message( request, messages.SUCCESS, "account/messages/login_code_sent.txt", {"email": email}, ) + stash_login(request, login) def get_pending_login( - request: HttpRequest, peek: bool = False + login: Login, peek: bool = False ) -> Tuple[Optional[AbstractBaseUser], Optional[Dict[str, Any]]]: if peek: - data = request.session.get(LOGIN_CODE_SESSION_KEY) + data = login.state.get(LOGIN_CODE_STATE_KEY) else: - data = request.session.pop(LOGIN_CODE_SESSION_KEY, None) + data = login.state.pop(LOGIN_CODE_STATE_KEY, None) if not data: return None, None if time.time() - data["at"] >= app_settings.LOGIN_BY_CODE_TIMEOUT: - request.session.pop(LOGIN_CODE_SESSION_KEY, None) + login.state.pop(LOGIN_CODE_STATE_KEY, None) return None, None user_id_str = data.get("user_id") user = None @@ -70,30 +73,37 @@ def get_pending_login( return user, data -def record_invalid_attempt(request: HttpRequest, pending_login: Dict[str, Any]) -> bool: +def record_invalid_attempt(request, login: Login) -> bool: + from allauth.account.utils import stash_login, unstash_login + + pending_login = login.state[LOGIN_CODE_STATE_KEY] n = pending_login["failed_attempts"] n += 1 pending_login["failed_attempts"] = n if n >= app_settings.LOGIN_BY_CODE_MAX_ATTEMPTS: - request.session.pop(LOGIN_CODE_SESSION_KEY, None) + unstash_login(request) return False else: - request.session[LOGIN_CODE_SESSION_KEY] = pending_login + login.state[LOGIN_CODE_STATE_KEY] = pending_login + stash_login(request, login) return True def perform_login_by_code( request: HttpRequest, - user: AbstractBaseUser, + stage, redirect_url: Optional[str], - pending_login: Dict[str, Any], ): - request.session.pop(LOGIN_CODE_SESSION_KEY, None) - record_authentication(request, method="code", email=pending_login["email"]) + state = stage.login.state.pop(LOGIN_CODE_STATE_KEY) + email = state["email"] + record_authentication(request, method="code", email=email) + # Just requesting a login code does is not considered to be a real login, + # yet, is needed in order to make the stage machinery work. Now that we've + # completed the code, let's start a real login. login = Login( - user=user, + user=stage.login.user, redirect_url=redirect_url, - email=pending_login["email"], + email=email, ) return perform_login(request, login) diff --git a/allauth/account/models.py b/allauth/account/models.py index 7951a90ef5..2d93d95f00 100644 --- a/allauth/account/models.py +++ b/allauth/account/models.py @@ -4,6 +4,7 @@ from django.conf import settings from django.contrib.auth import get_user_model +from django.contrib.auth.models import AbstractBaseUser from django.core import signing from django.db import models from django.db.models import Q @@ -230,6 +231,9 @@ class Login: case email verification is optional and we are only logging in). """ + # Optional, because we might be prentending logins to prevent user + # enumeration. + user: Optional[AbstractBaseUser] email_verification: app_settings.EmailVerificationMethod signal_kwargs: Optional[Dict] signup: bool @@ -271,7 +275,7 @@ def serialize(self): signal_kwargs["sociallogin"] = sociallogin.serialize() data = { - "user_pk": user_pk_to_url_str(self.user), + "user_pk": user_pk_to_url_str(self.user) if self.user else None, "email_verification": self.email_verification, "signup": self.signup, "redirect_url": self.redirect_url, @@ -286,13 +290,12 @@ def serialize(self): def deserialize(cls, data): from allauth.account.utils import url_str_to_user_pk - user = ( - get_user_model() - .objects.filter(pk=url_str_to_user_pk(data["user_pk"])) - .first() - ) - if user is None: - raise ValueError() + user = None + user_pk = data["user_pk"] + if user_pk is not None: + user = ( + get_user_model().objects.filter(pk=url_str_to_user_pk(user_pk)).first() + ) try: # :-( Knowledge of the `socialaccount` is entering the `account` app. signal_kwargs = data["signal_kwargs"] diff --git a/allauth/account/stages.py b/allauth/account/stages.py index d5ec518bd3..e5c7a1300c 100644 --- a/allauth/account/stages.py +++ b/allauth/account/stages.py @@ -132,3 +132,17 @@ def handle(self): self.request, login.user ) return response, cont + + +class LoginByCodeStage(LoginStage): + key = "login_by_code" + + def handle(self): + from allauth.account.internal.flows import login_by_code + + user, data = login_by_code.get_pending_login(self.login, peek=True) + if data is None: + # No pending login, just continue. + return None, True + response = HttpResponseRedirect(reverse("account_confirm_login_code")) + return response, True diff --git a/allauth/account/tests/test_login_by_code.py b/allauth/account/tests/test_login_by_code.py index b2af89e4b0..d4be1c1db7 100644 --- a/allauth/account/tests/test_login_by_code.py +++ b/allauth/account/tests/test_login_by_code.py @@ -1,12 +1,12 @@ from unittest.mock import ANY -from django.contrib.auth import SESSION_KEY from django.urls import reverse import pytest from allauth.account.authentication import AUTHENTICATION_METHODS_SESSION_KEY -from allauth.account.internal.flows.login_by_code import LOGIN_CODE_SESSION_KEY +from allauth.account.internal.flows.login import LOGIN_SESSION_KEY +from allauth.account.internal.flows.login_by_code import LOGIN_CODE_STATE_KEY @pytest.fixture @@ -23,7 +23,7 @@ def f(client, email): resp["location"] == reverse("account_confirm_login_code") + "?next=%2Ffoo" ) assert len(mailoutbox) == 1 - code = client.session[LOGIN_CODE_SESSION_KEY]["code"] + code = client.session[LOGIN_SESSION_KEY]["state"][LOGIN_CODE_STATE_KEY]["code"] assert len(code) == 6 assert code in mailoutbox[0].body return code @@ -39,7 +39,7 @@ def test_login_by_code(client, user, request_login_by_code): data={"code": code_with_ws, "next": "/foo"}, ) assert resp.status_code == 302 - assert client.session[SESSION_KEY] == str(user.pk) + assert LOGIN_SESSION_KEY not in client.session assert resp["location"] == "/foo" assert client.session[AUTHENTICATION_METHODS_SESSION_KEY][-1] == { "method": "code", @@ -58,10 +58,10 @@ def test_login_by_code_max_attempts(client, user, request_login_by_code, setting if i >= 1: assert resp.status_code == 302 assert resp["location"] == reverse("account_request_login_code") - assert LOGIN_CODE_SESSION_KEY not in client.session + assert LOGIN_SESSION_KEY not in client.session else: assert resp.status_code == 200 - assert LOGIN_CODE_SESSION_KEY in client.session + assert LOGIN_SESSION_KEY in client.session assert resp.context["form"].errors == {"code": ["Incorrect code."]} diff --git a/allauth/account/views.py b/allauth/account/views.py index 8f44ff4f1b..cdbb115c5c 100644 --- a/allauth/account/views.py +++ b/allauth/account/views.py @@ -43,7 +43,11 @@ EmailConfirmation, get_emailconfirmation_model, ) -from allauth.account.stages import EmailVerificationStage, LoginStageController +from allauth.account.stages import ( + EmailVerificationStage, + LoginByCodeStage, + LoginStageController, +) from allauth.account.utils import ( complete_signup, perform_login, @@ -957,8 +961,11 @@ class ConfirmLoginCodeView(RedirectAuthenticatedUserMixin, NextRedirectMixin, Fo @method_decorator(never_cache) def dispatch(self, request, *args, **kwargs): + self.stage = LoginStageController.enter(request, LoginByCodeStage.key) + if not self.stage: + return HttpResponseRedirect(reverse("account_request_login_code")) self.user, self.pending_login = flows.login_by_code.get_pending_login( - request, peek=True + self.stage.login, peek=True ) if not self.pending_login: return HttpResponseRedirect(reverse("account_request_login_code")) @@ -975,12 +982,12 @@ def get_form_kwargs(self): def form_valid(self, form): redirect_url = self.get_next_url() return flows.login_by_code.perform_login_by_code( - self.request, self.user, redirect_url, self.pending_login + self.request, self.stage, redirect_url ) def form_invalid(self, form): attempts_left = flows.login_by_code.record_invalid_attempt( - self.request, self.pending_login + self.request, self.stage.login ) if attempts_left: return super().form_invalid(form) diff --git a/allauth/headless/account/tests/test_login_by_code.py b/allauth/headless/account/tests/test_login_by_code.py index 5a64eeeaed..d2f8e6d7e7 100644 --- a/allauth/headless/account/tests/test_login_by_code.py +++ b/allauth/headless/account/tests/test_login_by_code.py @@ -46,3 +46,34 @@ def test_login_by_code_rate_limit( "param": "email", }, ] + + +def test_login_by_code_max_attemps(headless_reverse, user, client, settings): + settings.ACCOUNT_LOGIN_BY_CODE_MAX_ATTEMPTS = 2 + resp = client.post( + headless_reverse("headless:account:request_login_code"), + data={"email": user.email}, + content_type="application/json", + ) + assert resp.status_code == 401 + for i in range(3): + resp = client.post( + headless_reverse("headless:account:confirm_login_code"), + data={"code": "wrong"}, + content_type="application/json", + ) + session_resp = client.get( + headless_reverse("headless:account:current_session"), + data={"code": "wrong"}, + content_type="application/json", + ) + assert session_resp.status_code == 401 + pending_flows = [ + f for f in session_resp.json()["data"]["flows"] if f.get("is_pending") + ] + if i >= 1: + assert resp.status_code == 409 if i >= 2 else 400 + assert len(pending_flows) == 0 + else: + assert resp.status_code == 400 + assert len(pending_flows) == 1 diff --git a/allauth/headless/account/views.py b/allauth/headless/account/views.py index 9ead680981..f37d19f5ae 100644 --- a/allauth/headless/account/views.py +++ b/allauth/headless/account/views.py @@ -33,6 +33,7 @@ ForbiddenResponse, ) from allauth.headless.base.views import APIView, AuthenticatedAPIView +from allauth.headless.internal import authkit from allauth.headless.internal.restkit.response import ErrorResponse @@ -50,15 +51,17 @@ class ConfirmLoginCodeView(APIView): input_class = ConfirmLoginCodeInput def dispatch(self, request, *args, **kwargs): + auth_status = authkit.AuthenticationStatus(request) + self.stage = auth_status.get_pending_stage() + if not self.stage: + return ConflictResponse(request) self.user, self.pending_login = flows.login_by_code.get_pending_login( - request, peek=True + self.stage.login, peek=True ) return super().dispatch(request, *args, **kwargs) def post(self, request, *args, **kwargs): - flows.login_by_code.perform_login_by_code( - self.request, self.user, None, self.pending_login - ) + flows.login_by_code.perform_login_by_code(self.request, self.stage, None) return AuthenticationResponse(request) def get_input_kwargs(self): @@ -68,6 +71,10 @@ def get_input_kwargs(self): ) return kwargs + def handle_invalid_input(self, input): + flows.login_by_code.record_invalid_attempt(self.request, self.stage.login) + return super().handle_invalid_input(input) + @method_decorator(rate_limit(action="login"), name="handle") class LoginView(APIView): diff --git a/allauth/headless/base/response.py b/allauth/headless/base/response.py index 27cef3e671..88baf0650b 100644 --- a/allauth/headless/base/response.py +++ b/allauth/headless/base/response.py @@ -41,11 +41,7 @@ def _get_flows(self, request, user): if not allauth_settings.SOCIALACCOUNT_ONLY: ret.append({"id": Flow.LOGIN}) if account_settings.LOGIN_BY_CODE_ENABLED: - code_flow = {"id": Flow.LOGIN_BY_CODE} - _, data = flows.login_by_code.get_pending_login(request, peek=True) - if data: - code_flow["is_pending"] = True - ret.append(code_flow) + ret.append({"id": Flow.LOGIN_BY_CODE}) if ( get_account_adapter().is_open_for_signup(request) and not allauth_settings.SOCIALACCOUNT_ONLY @@ -72,9 +68,16 @@ def _get_flows(self, request, user): pending_flow = {"id": stage_key, "is_pending": True} if stage and stage_key == Flow.MFA_AUTHENTICATE: self._enrich_mfa_flow(stage, pending_flow) - ret.append(pending_flow) + self._upsert_pending_flow(ret, pending_flow) return ret + def _upsert_pending_flow(self, flows, pending_flow): + flow = next((flow for flow in flows if flow["id"] == pending_flow["id"]), None) + if flow: + flow.update(pending_flow) + else: + flows.append(pending_flow) + def _enrich_mfa_flow(self, stage, flow: dict) -> None: from allauth.mfa.adapter import get_adapter as get_mfa_adapter from allauth.mfa.models import Authenticator diff --git a/allauth/headless/constants.py b/allauth/headless/constants.py index dca9f28cc1..aa04dc1b0e 100644 --- a/allauth/headless/constants.py +++ b/allauth/headless/constants.py @@ -1,6 +1,6 @@ from enum import Enum -from allauth.account.stages import EmailVerificationStage +from allauth.account.stages import EmailVerificationStage, LoginByCodeStage class Client(str, Enum): @@ -11,7 +11,7 @@ class Client(str, Enum): class Flow(str, Enum): VERIFY_EMAIL = EmailVerificationStage.key LOGIN = "login" - LOGIN_BY_CODE = "login_by_code" + LOGIN_BY_CODE = LoginByCodeStage.key SIGNUP = "signup" PROVIDER_REDIRECT = "provider_redirect" PROVIDER_SIGNUP = "provider_signup" diff --git a/allauth/headless/internal/restkit/views.py b/allauth/headless/internal/restkit/views.py index 6a8e6ab7f8..177f09c629 100644 --- a/allauth/headless/internal/restkit/views.py +++ b/allauth/headless/internal/restkit/views.py @@ -39,7 +39,10 @@ def handle_input(self, data): data = {} self.input = input_class(data=data, **input_kwargs) if not self.input.is_valid(): - return ErrorResponse(self.request, input=self.input) + return self.handle_invalid_input(self.input) + + def handle_invalid_input(self, input): + return ErrorResponse(self.request, input=input) def _parse_json(self, request): if request.method == "GET" or not request.body: diff --git a/docs/headless/openapi-specification/openapi.yaml b/docs/headless/openapi-specification/openapi.yaml index 33a4a516f0..4cfdf28003 100644 --- a/docs/headless/openapi-specification/openapi.yaml +++ b/docs/headless/openapi-specification/openapi.yaml @@ -677,7 +677,7 @@ paths: - "Authentication: Login By Code" summary: Confirm login code description: | - Use this endpoint to input along the received "special" login code. + Use this endpoint to pass along the received "special" login code. parameters: - $ref: "#/components/parameters/Client" requestBody: @@ -704,6 +704,13 @@ paths: examples: unauthenticated_pending_2fa: $ref: "#/components/examples/UnauthenticatedPending2FA" + "409": + description: | + Conflict. The "login by code" flow is not pending. + content: + application/json: + schema: + $ref: "#/components/schemas/ConflictResponse" ###################################################################### # Account: Providers diff --git a/examples/react-spa/frontend/src/account/ConfirmLoginCode.js b/examples/react-spa/frontend/src/account/ConfirmLoginCode.js index 07190f1dc1..51acaba1ff 100644 --- a/examples/react-spa/frontend/src/account/ConfirmLoginCode.js +++ b/examples/react-spa/frontend/src/account/ConfirmLoginCode.js @@ -22,7 +22,7 @@ export default function ConfirmLoginCode () { }) } - if (authInfo.pendingFlow?.id !== Flows.LOGIN_BY_CODE) { + if (response.content?.status === 409 || authInfo.pendingFlow?.id !== Flows.LOGIN_BY_CODE) { return } return (