Skip to content

Commit

Permalink
feat: Use autoendpoint-rs in integration tests (#205)
Browse files Browse the repository at this point in the history
* Fix importing SkipTest from a private module

* Copy the integration tests and remove unneeded tests

These integration tests will be modified to test the Rust autoendpoint
code. Python autopush code will not be tested in this repo anymore.

Unneeded tests include:
- Legacy message ID/sort key tests
- Table rotation
- The "cross" tests which connected to both Rust and Python connection
  servers in the same test.

Some more changes have been made, mostly around removing dependencies on
the Python autopush code and removing table rotation mechanisms. For
example, we only make one message table instead of one for this month
and one for last month.

* Simplify the integration test setup code

* Replace the Python endpoint server with Rust in the new integration test

* Avoid extra newline between log entries

* Respond with 201 on direct deliveries instead of 200

* Remove some unnecessary global variables

* Replace "endpoint_url" setting with "scheme" (generate the URL instead)

Also removed the unused `debug` setting.

* Use a timestamp in milliseconds for the message ID and sort key

* Apply a fix from Python autopush tests after a Stored HTTP code change

* Fix bottle routes getting added to the default bottle instance

This would cause the wrong routes to be used for each test, eventually
causing an assertion error when checking the megaphone token.

* Set cache-control for HTTP 410 errors

This is expected by an integration test

* Show debug logs in integration tests

* Fix repad_base64 function and add tests

Sometimes it would add three padding characters, while base64 has at
most two padding characters.

* Fix JWT validation by decoding the base64 public key before using it

* Don't pass on encryption headers if there is no data

* Don't set notification encryption headers at all if there's no data

* Fix flake8 formatting issues

Closes #168
  • Loading branch information
AzureMarker authored Aug 5, 2020
1 parent b4bb163 commit 31d2d19
Show file tree
Hide file tree
Showing 12 changed files with 1,224 additions and 45 deletions.
8 changes: 7 additions & 1 deletion autoendpoint/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,13 @@ impl ResponseError for ApiError {
}

fn error_response(&self) -> HttpResponse {
HttpResponse::build(self.kind.status()).json(self)
let mut builder = HttpResponse::build(self.kind.status());

if self.status_code() == 410 {
builder.set_header("Cache-Control", "max-age=86400");
}

builder.json(self)
}
}

Expand Down
20 changes: 16 additions & 4 deletions autoendpoint/src/extractors/notification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::server::ServerState;
use actix_web::dev::{Payload, PayloadStream};
use actix_web::web::Data;
use actix_web::{web, FromRequest, HttpRequest};
use autopush_common::util::sec_since_epoch;
use autopush_common::util::{ms_since_epoch, sec_since_epoch};
use cadence::Counted;
use fernet::MultiFernet;
use futures::{future, FutureExt};
Expand All @@ -19,7 +19,10 @@ pub struct Notification {
pub message_id: String,
pub subscription: Subscription,
pub headers: NotificationHeaders,
/// UNIX timestamp in seconds
pub timestamp: u64,
/// UNIX timestamp in milliseconds
pub sort_key_timestamp: u64,
pub data: Option<String>,
}

Expand Down Expand Up @@ -52,12 +55,13 @@ impl FromRequest for Notification {

let headers = NotificationHeaders::from_request(&req, data.is_some())?;
let timestamp = sec_since_epoch();
let sort_key_timestamp = ms_since_epoch();
let message_id = Self::generate_message_id(
&state.fernet,
subscription.user.uaid,
subscription.channel_id,
headers.topic.as_deref(),
timestamp,
sort_key_timestamp,
);

// Record the encoding if we have an encrypted payload
Expand All @@ -75,6 +79,7 @@ impl FromRequest for Notification {
subscription,
headers,
timestamp,
sort_key_timestamp,
data,
})
}
Expand All @@ -91,8 +96,15 @@ impl From<Notification> for autopush_common::notification::Notification {
topic: notification.headers.topic.clone(),
timestamp: notification.timestamp,
data: notification.data,
sortkey_timestamp: Some(notification.timestamp),
headers: Some(notification.headers.into()),
sortkey_timestamp: Some(notification.sort_key_timestamp),
headers: {
let headers: HashMap<String, String> = notification.headers.into();
if headers.is_empty() {
None
} else {
Some(headers)
}
},
}
}
}
Expand Down
39 changes: 21 additions & 18 deletions autoendpoint/src/extractors/notification_headers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ lazy_static! {
Regex::new(r"(?P<head>[0-9A-Za-z\-_]+)=+(?P<tail>[,;]|$)").unwrap();
}

/// 60 days
const MAX_TTL: i64 = 60 * 60 * 24 * 60;

/// Extractor and validator for notification headers
Expand Down Expand Up @@ -51,8 +52,6 @@ impl From<NotificationHeaders> for HashMap<String, String> {
fn from(headers: NotificationHeaders) -> Self {
let mut map = HashMap::new();

map.insert("ttl".to_string(), headers.ttl.to_string());
map.insert_opt("topic", headers.topic);
map.insert_opt("encoding", headers.encoding);
map.insert_opt("encryption", headers.encryption);
map.insert_opt("encryption_key", headers.encryption_key);
Expand All @@ -75,22 +74,26 @@ impl NotificationHeaders {
.map(|ttl| min(ttl, MAX_TTL))
.ok_or(ApiErrorKind::NoTTL)?;
let topic = get_owned_header(req, "topic");
let encoding = get_owned_header(req, "content-encoding");
let encryption = get_owned_header(req, "encryption");
let encryption_key = get_owned_header(req, "encryption-key");
let crypto_key = get_owned_header(req, "crypto-key");

// Strip quotes and padding from some headers
let encryption = encryption.map(Self::strip_header);
let crypto_key = crypto_key.map(Self::strip_header);

let headers = NotificationHeaders {
ttl,
topic,
encoding,
encryption,
encryption_key,
crypto_key,

let headers = if has_data {
NotificationHeaders {
ttl,
topic,
encoding: get_owned_header(req, "content-encoding"),
encryption: get_owned_header(req, "encryption").map(Self::strip_header),
encryption_key: get_owned_header(req, "encryption-key"),
crypto_key: get_owned_header(req, "crypto-key").map(Self::strip_header),
}
} else {
// Messages without a body shouldn't pass along unnecessary headers
NotificationHeaders {
ttl,
topic,
encoding: None,
encryption: None,
encryption_key: None,
crypto_key: None,
}
};

// Validate encryption if there is a message body
Expand Down
2 changes: 1 addition & 1 deletion autoendpoint/src/extractors/routers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ impl FromRequest for Routers {
ddb: state.ddb.clone(),
metrics: state.metrics.clone(),
http: state.http.clone(),
endpoint_url: state.settings.endpoint_url.clone(),
endpoint_url: state.settings.endpoint_url(),
},
fcm: state.fcm_router.clone(),
apns: state.apns_router.clone(),
Expand Down
25 changes: 21 additions & 4 deletions autoendpoint/src/extractors/subscription.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@ impl FromRequest for Subscription {

/// Add back padding to a base64 string
fn repad_base64(data: &str) -> Cow<'_, str> {
let remaining_padding = data.len() % 4;
let trailing_chars = data.len() % 4;

if remaining_padding != 0 {
if trailing_chars != 0 {
let mut data = data.to_string();

for _ in 0..remaining_padding {
for _ in trailing_chars..4 {
data.push('=');
}

Expand Down Expand Up @@ -201,9 +201,11 @@ fn validate_vapid_jwt(vapid: &VapidHeaderWithKey) -> ApiResult<()> {
}

// Check the signature and make sure the expiration is in the future
let public_key = base64::decode_config(public_key, base64::URL_SAFE_NO_PAD)
.map_err(|_| VapidError::InvalidKey)?;
let token_data = jsonwebtoken::decode::<Claims>(
&vapid.token,
&DecodingKey::from_ec_der(public_key.as_bytes()),
&DecodingKey::from_ec_der(&public_key),
&Validation::new(Algorithm::ES256),
)?;

Expand All @@ -218,3 +220,18 @@ fn validate_vapid_jwt(vapid: &VapidHeaderWithKey) -> ApiResult<()> {

Ok(())
}

#[cfg(test)]
mod tests {
use crate::extractors::subscription::repad_base64;

#[test]
fn repad_base64_1_padding() {
assert_eq!(repad_base64("Zm9vYmE"), "Zm9vYmE=")
}

#[test]
fn repad_base64_2_padding() {
assert_eq!(repad_base64("Zm9vYg"), "Zm9vYg==")
}
}
1 change: 1 addition & 0 deletions autoendpoint/src/routers/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ pub mod tests {
crypto_key: Some("test-crypto-key".to_string()),
},
timestamp: 0,
sort_key_timestamp: 0,
data,
}
}
Expand Down
2 changes: 1 addition & 1 deletion autoendpoint/src/routers/webpush.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ impl WebPushRouter {
/// Update metrics and create a response for when a notification has been directly forwarded to
/// an autopush server.
fn make_delivered_response(&self, notification: &Notification) -> RouterResponse {
self.make_response(notification, "Direct", StatusCode::OK)
self.make_response(notification, "Direct", StatusCode::CREATED)
}

/// Update metrics and create a response for when a notification has been stored in the database
Expand Down
4 changes: 2 additions & 2 deletions autoendpoint/src/routes/registration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pub async fn register_uaid_route(
&user.uaid,
&channel_id,
router_data_input.key.as_deref(),
state.settings.endpoint_url.as_str(),
state.settings.endpoint_url().as_str(),
&state.fernet,
)
.map_err(ApiErrorKind::EndpointUrl)?;
Expand Down Expand Up @@ -140,7 +140,7 @@ pub async fn new_channel_route(
&path_args.uaid,
&channel_id,
channel_data.key.as_deref(),
state.settings.endpoint_url.as_str(),
state.settings.endpoint_url().as_str(),
&state.fernet,
)
.map_err(ApiErrorKind::EndpointUrl)?;
Expand Down
5 changes: 3 additions & 2 deletions autoendpoint/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ impl Server {
let metrics = metrics::metrics_from_opts(&settings)?;
let bind_address = format!("{}:{}", settings.host, settings.port);
let fernet = Arc::new(settings.make_fernet());
let endpoint_url = settings.endpoint_url();
let ddb = Box::new(DbClientImpl::new(
metrics.clone(),
settings.router_table_name.clone(),
Expand All @@ -50,7 +51,7 @@ impl Server {
let fcm_router = Arc::new(
FcmRouter::new(
settings.fcm.clone(),
settings.endpoint_url.clone(),
endpoint_url.clone(),
http.clone(),
metrics.clone(),
ddb.clone(),
Expand All @@ -60,7 +61,7 @@ impl Server {
let apns_router = Arc::new(
ApnsRouter::new(
settings.apns.clone(),
settings.endpoint_url.clone(),
endpoint_url,
metrics.clone(),
ddb.clone(),
)
Expand Down
17 changes: 10 additions & 7 deletions autoendpoint/src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,15 @@ use fernet::{Fernet, MultiFernet};
use serde::Deserialize;
use url::Url;

const DEFAULT_PORT: u16 = 8000;
const ENV_PREFIX: &str = "autoend";

#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
#[serde(deny_unknown_fields)]
pub struct Settings {
pub debug: bool,
pub port: u16,
pub scheme: String,
pub host: String,
pub endpoint_url: Url,
pub port: u16,

pub router_table_name: String,
pub message_table_name: String,
Expand All @@ -38,10 +36,9 @@ pub struct Settings {
impl Default for Settings {
fn default() -> Settings {
Settings {
debug: false,
port: DEFAULT_PORT,
scheme: "http".to_string(),
host: "127.0.0.1".to_string(),
endpoint_url: Url::parse("http://127.0.0.1:8000/").unwrap(),
port: 8000,
router_table_name: "router".to_string(),
message_table_name: "message".to_string(),
max_data_bytes: 4096,
Expand Down Expand Up @@ -116,4 +113,10 @@ impl Settings {
pub fn auth_keys(&self) -> Vec<&str> {
Self::read_list_from_str(&self.auth_keys, "Invalid AUTOEND_AUTH_KEYS").collect()
}

/// Get the URL for this endpoint server
pub fn endpoint_url(&self) -> Url {
Url::parse(&format!("{}://{}:{}", self.scheme, self.host, self.port))
.expect("Invalid endpoint URL")
}
}
11 changes: 6 additions & 5 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from contextlib import contextmanager
from functools import wraps
from threading import Event, Thread
from unittest.case import SkipTest
from unittest import SkipTest

import autopush.tests
import autopush.tests as ap_tests
Expand Down Expand Up @@ -47,6 +47,7 @@
from twisted.logger import globalLogPublisher
from twisted.trial import unittest

app = bottle.Bottle()
log = logging.getLogger(__name__)

here_dir = os.path.abspath(os.path.dirname(__file__))
Expand Down Expand Up @@ -148,14 +149,14 @@ def wrapper(self, *args, **kwargs):
return max_logs_decorator


@bottle.get("/v1/broadcasts")
@app.get("/v1/broadcasts")
def broadcast_handler():
assert bottle.request.headers["Authorization"] == MOCK_MP_TOKEN
MOCK_MP_POLLED.set()
return dict(broadcasts=MOCK_MP_SERVICES)


@bottle.post("/api/1/store/")
@app.post("/api/1/store/")
def sentry_handler():
content = bottle.request.json
MOCK_SENTRY_QUEUE.put(content)
Expand Down Expand Up @@ -234,7 +235,7 @@ def setup_module():
CN_QUEUES.extend([out_q, err_q])

MOCK_SERVER_THREAD = Thread(
target=bottle.run,
target=app.run,
kwargs=dict(
port=MOCK_SERVER_PORT, debug=True
))
Expand Down Expand Up @@ -765,7 +766,7 @@ def test_ttl_0_not_connected(self):
data = str(uuid.uuid4())
client = yield self.quick_register()
yield client.disconnect()
yield client.send_notification(data=data, ttl=0)
yield client.send_notification(data=data, ttl=0, status=201)
yield client.connect()
yield client.hello()
result = yield client.get_notification(timeout=0.5)
Expand Down
Loading

0 comments on commit 31d2d19

Please sign in to comment.