Skip to content
This repository has been archived by the owner on Sep 10, 2024. It is now read-only.

Add rate-limiting for account recovery and registration #3093

Merged
merged 4 commits into from
Aug 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 86 additions & 8 deletions crates/config/src/sections/rate_limiting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,28 @@ use crate::ConfigurationSection;
/// Configuration related to sending emails
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct RateLimitingConfig {
/// Account Recovery-specific rate limits
#[serde(default)]
pub account_recovery: AccountRecoveryRateLimitingConfig,
/// Login-specific rate limits
#[serde(default)]
pub login: LoginRateLimitingConfig,
/// Controls how many registrations attempts are permitted
/// based on source address.
#[serde(default = "default_registration")]
pub registration: RateLimiterConfiguration,
}

#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct LoginRateLimitingConfig {
/// Controls how many login attempts are permitted
/// based on source address.
/// based on source IP address.
/// This can protect against brute force login attempts.
///
/// Note: this limit also applies to password checks when a user attempts to
/// change their own password.
#[serde(default = "default_login_per_address")]
pub per_address: RateLimiterConfiguration,
#[serde(default = "default_login_per_ip")]
pub per_ip: RateLimiterConfiguration,
/// Controls how many login attempts are permitted
/// based on the account that is being attempted to be logged into.
/// This can protect against a distributed brute force attack
Expand All @@ -50,6 +57,24 @@ pub struct LoginRateLimitingConfig {
pub per_account: RateLimiterConfiguration,
}

#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct AccountRecoveryRateLimitingConfig {
/// Controls how many account recovery attempts are permitted
/// based on source IP address.
/// This can protect against causing e-mail spam to many targets.
///
/// Note: this limit also applies to re-sends.
#[serde(default = "default_account_recovery_per_ip")]
pub per_ip: RateLimiterConfiguration,
/// Controls how many account recovery attempts are permitted
/// based on the e-mail address entered into the recovery form.
/// This can protect against causing e-mail spam to one target.
///
/// Note: this limit also applies to re-sends.
#[serde(default = "default_account_recovery_per_address")]
pub per_address: RateLimiterConfiguration,
}

#[derive(Copy, Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
pub struct RateLimiterConfiguration {
/// A one-off burst of actions that the user can perform
Expand All @@ -66,6 +91,13 @@ impl ConfigurationSection for RateLimitingConfig {
fn validate(&self, figment: &figment::Figment) -> Result<(), figment::Error> {
let metadata = figment.find_metadata(Self::PATH.unwrap());

let error_on_field = |mut error: figment::error::Error, field: &'static str| {
error.metadata = metadata.cloned();
error.profile = Some(figment::Profile::Default);
error.path = vec![Self::PATH.unwrap().to_owned(), field.to_owned()];
error
};

let error_on_nested_field =
|mut error: figment::error::Error, container: &'static str, field: &'static str| {
error.metadata = metadata.cloned();
Expand All @@ -92,8 +124,23 @@ impl ConfigurationSection for RateLimitingConfig {
None
};

if let Some(error) = error_on_limiter(&self.login.per_address) {
return Err(error_on_nested_field(error, "login", "per_address"));
if let Some(error) = error_on_limiter(&self.account_recovery.per_ip) {
return Err(error_on_nested_field(error, "account_recovery", "per_ip"));
}
if let Some(error) = error_on_limiter(&self.account_recovery.per_address) {
return Err(error_on_nested_field(
error,
"account_recovery",
"per_address",
));
}

if let Some(error) = error_on_limiter(&self.registration) {
return Err(error_on_field(error, "registration"));
}

if let Some(error) = error_on_limiter(&self.login.per_ip) {
return Err(error_on_nested_field(error, "login", "per_ip"));
}
if let Some(error) = error_on_limiter(&self.login.per_account) {
return Err(error_on_nested_field(error, "login", "per_account"));
Expand All @@ -119,7 +166,7 @@ impl RateLimiterConfiguration {
}
}

fn default_login_per_address() -> RateLimiterConfiguration {
fn default_login_per_ip() -> RateLimiterConfiguration {
RateLimiterConfiguration {
burst: NonZeroU32::new(3).unwrap(),
per_second: 3.0 / 60.0,
Expand All @@ -133,20 +180,51 @@ fn default_login_per_account() -> RateLimiterConfiguration {
}
}

#[allow(clippy::derivable_impls)] // when we add some top-level ratelimiters this will not be derivable anymore
fn default_registration() -> RateLimiterConfiguration {
RateLimiterConfiguration {
burst: NonZeroU32::new(3).unwrap(),
per_second: 3.0 / 3600.0,
}
}

fn default_account_recovery_per_ip() -> RateLimiterConfiguration {
RateLimiterConfiguration {
burst: NonZeroU32::new(3).unwrap(),
per_second: 3.0 / 3600.0,
}
}

fn default_account_recovery_per_address() -> RateLimiterConfiguration {
RateLimiterConfiguration {
burst: NonZeroU32::new(3).unwrap(),
per_second: 1.0 / 3600.0,
}
}

impl Default for RateLimitingConfig {
fn default() -> Self {
RateLimitingConfig {
login: LoginRateLimitingConfig::default(),
registration: default_registration(),
account_recovery: AccountRecoveryRateLimitingConfig::default(),
}
}
}

impl Default for LoginRateLimitingConfig {
fn default() -> Self {
LoginRateLimitingConfig {
per_address: default_login_per_address(),
per_ip: default_login_per_ip(),
per_account: default_login_per_account(),
}
}
}

impl Default for AccountRecoveryRateLimitingConfig {
fn default() -> Self {
AccountRecoveryRateLimitingConfig {
per_ip: default_account_recovery_per_ip(),
per_address: default_account_recovery_per_address(),
}
}
}
74 changes: 73 additions & 1 deletion crates/handlers/src/rate_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ use mas_config::RateLimitingConfig;
use mas_data_model::User;
use ulid::Ulid;

#[derive(Debug, Clone, thiserror::Error)]
pub enum AccountRecoveryLimitedError {
#[error("Too many account recovery requests for requester {0}")]
Requester(RequesterFingerprint),

#[error("Too many account recovery requests for e-mail {0}")]
Email(String),
}

#[derive(Debug, Clone, Copy, thiserror::Error)]
pub enum PasswordCheckLimitedError {
#[error("Too many password checks for requester {0}")]
Expand All @@ -28,6 +37,12 @@ pub enum PasswordCheckLimitedError {
User(Ulid),
}

#[derive(Debug, Clone, thiserror::Error)]
pub enum RegistrationLimitedError {
#[error("Too many account registration requests for requester {0}")]
Requester(RequesterFingerprint),
}

/// Key used to rate limit requests per requester
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequesterFingerprint {
Expand Down Expand Up @@ -66,15 +81,25 @@ type KeyedRateLimiter<K> = RateLimiter<K, DashMapStateStore<K>, QuantaClock>;

#[derive(Debug)]
struct LimiterInner {
account_recovery_per_requester: KeyedRateLimiter<RequesterFingerprint>,
account_recovery_per_email: KeyedRateLimiter<String>,
password_check_for_requester: KeyedRateLimiter<RequesterFingerprint>,
password_check_for_user: KeyedRateLimiter<Ulid>,
registration_per_requester: KeyedRateLimiter<RequesterFingerprint>,
}

impl LimiterInner {
fn new(config: &RateLimitingConfig) -> Option<Self> {
Some(Self {
password_check_for_requester: RateLimiter::keyed(config.login.per_address.to_quota()?),
account_recovery_per_requester: RateLimiter::keyed(
config.account_recovery.per_ip.to_quota()?,
),
account_recovery_per_email: RateLimiter::keyed(
config.account_recovery.per_address.to_quota()?,
),
password_check_for_requester: RateLimiter::keyed(config.login.per_ip.to_quota()?),
password_check_for_user: RateLimiter::keyed(config.login.per_account.to_quota()?),
registration_per_requester: RateLimiter::keyed(config.registration.to_quota()?),
})
}
}
Expand Down Expand Up @@ -105,14 +130,44 @@ impl Limiter {

loop {
// Call the retain_recent method on each rate limiter
this.inner.account_recovery_per_email.retain_recent();
this.inner.account_recovery_per_requester.retain_recent();
this.inner.password_check_for_requester.retain_recent();
this.inner.password_check_for_user.retain_recent();
this.inner.registration_per_requester.retain_recent();

interval.tick().await;
}
});
}

/// Check if an account recovery can be performed
///
/// # Errors
///
/// Returns an error if the operation is rate limited.
pub fn check_account_recovery(
&self,
requester: RequesterFingerprint,
email_address: &str,
) -> Result<(), AccountRecoveryLimitedError> {
self.inner
.account_recovery_per_requester
.check_key(&requester)
.map_err(|_| AccountRecoveryLimitedError::Requester(requester))?;

// Convert to lowercase to prevent bypassing the limit by enumerating different
// case variations.
// A case-folding transformation may be more proper.
let canonical_email = email_address.to_lowercase();
self.inner
.account_recovery_per_email
.check_key(&canonical_email)
.map_err(|_| AccountRecoveryLimitedError::Email(canonical_email))?;

Ok(())
}

/// Check if a password check can be performed
///
/// # Errors
Expand All @@ -135,6 +190,23 @@ impl Limiter {

Ok(())
}

/// Check if an account registration can be performed
///
/// # Errors
///
/// Returns an error if the operation is rate limited.
pub fn check_registration(
&self,
requester: RequesterFingerprint,
) -> Result<(), RegistrationLimitedError> {
self.inner
.registration_per_requester
.check_key(&requester)
.map_err(|_| RegistrationLimitedError::Requester(requester))?;

Ok(())
}
}

#[cfg(test)]
Expand Down
19 changes: 16 additions & 3 deletions crates/handlers/src/views/recovery/progress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use axum::{
response::{Html, IntoResponse, Response},
Form,
};
use hyper::StatusCode;
use mas_axum_utils::{
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
Expand All @@ -31,7 +32,7 @@ use mas_storage::{
use mas_templates::{EmptyContext, RecoveryProgressContext, TemplateContext, Templates};
use ulid::Ulid;

use crate::PreferredLanguage;
use crate::{Limiter, PreferredLanguage, RequesterFingerprint};

pub(crate) async fn get(
mut rng: BoxRng,
Expand Down Expand Up @@ -74,7 +75,7 @@ pub(crate) async fn get(
return Ok((cookie_jar, Html(rendered)).into_response());
}

let context = RecoveryProgressContext::new(recovery_session)
let context = RecoveryProgressContext::new(recovery_session, false)
.with_csrf(csrf_token.form_value())
.with_language(locale);

Expand All @@ -92,6 +93,7 @@ pub(crate) async fn post(
State(site_config): State<SiteConfig>,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
(State(limiter), requester): (State<Limiter>, RequesterFingerprint),
PreferredLanguage(locale): PreferredLanguage,
cookie_jar: CookieJar,
Path(id): Path<Ulid>,
Expand Down Expand Up @@ -130,14 +132,25 @@ pub(crate) async fn post(
// Verify the CSRF token
let () = cookie_jar.verify_form(&clock, form)?;

// Check the rate limit if we are about to process the form
if let Err(e) = limiter.check_account_recovery(requester, &recovery_session.email) {
tracing::warn!(error = &e as &dyn std::error::Error);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I really wish tracing had a sigil (like ? or %) to report as &dyn Error :(

Copy link
Contributor Author

@reivilibre reivilibre Aug 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you're not alone: tokio-rs/tracing#1308

let context = RecoveryProgressContext::new(recovery_session, true)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let rendered = templates.render_recovery_progress(&context)?;

return Ok((StatusCode::TOO_MANY_REQUESTS, (cookie_jar, Html(rendered))).into_response());
}

// Schedule a new batch of emails
repo.job()
.schedule_job(SendAccountRecoveryEmailsJob::new(&recovery_session))
.await?;

repo.save().await?;

let context = RecoveryProgressContext::new(recovery_session)
let context = RecoveryProgressContext::new(recovery_session, false)
.with_csrf(csrf_token.form_value())
.with_language(locale);

Expand Down
13 changes: 11 additions & 2 deletions crates/handlers/src/views/recovery/start.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ use mas_storage::{
BoxClock, BoxRepository, BoxRng,
};
use mas_templates::{
EmptyContext, FieldError, FormState, RecoveryStartContext, RecoveryStartFormField,
EmptyContext, FieldError, FormError, FormState, RecoveryStartContext, RecoveryStartFormField,
TemplateContext, Templates,
};
use serde::{Deserialize, Serialize};

use crate::{BoundActivityTracker, PreferredLanguage};
use crate::{BoundActivityTracker, Limiter, PreferredLanguage, RequesterFingerprint};

#[derive(Deserialize, Serialize)]
pub(crate) struct StartRecoveryForm {
Expand Down Expand Up @@ -90,6 +90,7 @@ pub(crate) async fn post(
State(site_config): State<SiteConfig>,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
(State(limiter), requester): (State<Limiter>, RequesterFingerprint),
PreferredLanguage(locale): PreferredLanguage,
cookie_jar: CookieJar,
Form(form): Form<ProtectedForm<StartRecoveryForm>>,
Expand Down Expand Up @@ -120,6 +121,14 @@ pub(crate) async fn post(
form_state.with_error_on_field(RecoveryStartFormField::Email, FieldError::Invalid);
}

if form_state.is_valid() {
// Check the rate limit if we are about to process the form
if let Err(e) = limiter.check_account_recovery(requester, &form.email) {
tracing::warn!(error = &e as &dyn std::error::Error);
form_state.add_error_on_form(FormError::RateLimitExceeded);
}
}

if !form_state.is_valid() {
repo.save().await?;
let context = RecoveryStartContext::new()
Expand Down
Loading
Loading