diff --git a/source/secure_tunneling.c b/source/secure_tunneling.c index c14bb0c2..a70132bd 100644 --- a/source/secure_tunneling.c +++ b/source/secure_tunneling.c @@ -892,6 +892,7 @@ static void s_secure_tunneling_websocket_on_send_data_complete_callback( } aws_secure_tunnel_data_tunnel_pair_destroy(pair); secure_tunnel->pending_write_completion = false; + s_reevaluate_service_task(secure_tunnel); } static bool secure_tunneling_websocket_stream_outgoing_payload( diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index cbac1041..1e5d7547 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -32,6 +32,7 @@ add_net_test_case(secure_tunneling_session_reset_test) add_net_test_case(secure_tunneling_serializer_data_message_test) add_net_test_case(secure_tunneling_max_payload_test) add_net_test_case(secure_tunneling_max_payload_exceed_test) +add_net_test_case(secure_tunneling_subsequent_writes) add_net_test_case(secure_tunneling_receive_connection_start_test) add_net_test_case(secure_tunneling_ignore_inactive_stream_message_test) add_net_test_case(secure_tunneling_ignore_inactive_connection_id_message_test) diff --git a/tests/secure_tunnel_tests.c b/tests/secure_tunnel_tests.c index 0bc2cecc..2a0b2832 100644 --- a/tests/secure_tunnel_tests.c +++ b/tests/secure_tunnel_tests.c @@ -20,7 +20,6 @@ #include #include #include -#include #define PAYLOAD_BYTE_LENGTH_PREFIX 2 AWS_STATIC_STRING_FROM_LITERAL(s_access_token, "IAmAnAccessToken"); @@ -85,6 +84,10 @@ typedef int(aws_secure_tunnel_mock_test_fixture_header_check_fn)( const struct aws_http_headers *request_headers, void *user_data); +typedef void(aws_secure_tunnel_mock_test_fixture_on_message_received_fn)( + struct aws_secure_tunnel *secure_tunnel, + struct aws_secure_tunnel_message_view *message_view); + struct aws_secure_tunnel_mock_test_fixture { struct aws_allocator *allocator; @@ -101,7 +104,7 @@ struct aws_secure_tunnel_mock_test_fixture { struct aws_secure_tunnel_vtable secure_tunnel_vtable; aws_secure_tunnel_mock_test_fixture_header_check_fn *header_check; - + aws_secure_tunnel_mock_test_fixture_on_message_received_fn *on_server_message_received; struct aws_mutex lock; struct aws_condition_variable signal; bool listener_destroyed; @@ -120,6 +123,7 @@ struct aws_secure_tunnel_mock_test_fixture { struct aws_byte_buf last_message_payload_buf; + /* The following fields are intended to validate things from the mocked secure tunnel perspective. */ int secure_tunnel_message_received_count; int secure_tunnel_message_sent_count; int secure_tunnel_stream_started_count; @@ -130,8 +134,11 @@ struct aws_secure_tunnel_mock_test_fixture { int secure_tunnel_message_sent_count_target; int secure_tunnel_message_sent_connection_reset_count; int secure_tunnel_message_sent_data_count; + int secure_tunnel_message_previous_data_value; + bool secure_tunnel_messages_received_in_order; bool on_send_message_complete_fired; + int on_send_message_complete_fired_cnt; struct { enum aws_secure_tunnel_message_type type; int error_code; @@ -251,6 +258,7 @@ static void s_on_test_secure_tunnel_send_message_complete( aws_mutex_lock(&test_fixture->lock); test_fixture->on_send_message_complete_fired = true; + test_fixture->on_send_message_complete_fired_cnt++; test_fixture->on_send_message_complete_result.type = type; test_fixture->on_send_message_complete_result.error_code = error_code; aws_condition_variable_notify_all(&test_fixture->signal); @@ -575,7 +583,7 @@ void aws_secure_tunnel_send_mock_message( &receive_task->task, s_secure_tunneling_mock_websocket_receive_frame_payload_task_fn, (void *)receive_task, - "MockWebsocketSendMessage"); + "MockWebSocketSendMessageFromServer"); receive_task->test_fixture = test_fixture; @@ -635,6 +643,7 @@ int aws_websocket_client_connect_mock_fn(const struct aws_websocket_client_conne return AWS_OP_SUCCESS; } +/* Mock for a server-side code receiving WebSocket frames. */ void aws_secure_tunnel_test_on_message_received( struct aws_secure_tunnel *secure_tunnel, struct aws_secure_tunnel_message_view *message_view) { @@ -657,6 +666,70 @@ void aws_secure_tunnel_test_on_message_received( aws_mutex_unlock(&test_fixture->lock); } +void aws_secure_tunnel_test_on_message_received_with_order_validation( + struct aws_secure_tunnel *secure_tunnel, + struct aws_secure_tunnel_message_view *message_view) { + (void)message_view; + struct aws_secure_tunnel_mock_test_fixture *test_fixture = secure_tunnel->config->user_data; + + aws_mutex_lock(&test_fixture->lock); + test_fixture->secure_tunnel_message_sent_count++; + int data_value; + switch (message_view->type) { + case AWS_SECURE_TUNNEL_MT_DATA: + test_fixture->secure_tunnel_message_sent_data_count++; + data_value = (int)strtol((const char *)message_view->payload->ptr, NULL, 10); + if (test_fixture->secure_tunnel_message_previous_data_value > 0 && + data_value != test_fixture->secure_tunnel_message_previous_data_value + 1) { + /* We cannot assert in this callback, log error and set corresponding fail flag instead. */ + fprintf( + stderr, + "ERROR: secure tunnel expected %d, received %d\n", + test_fixture->secure_tunnel_message_previous_data_value + 1, + data_value); + test_fixture->secure_tunnel_messages_received_in_order = false; + } + test_fixture->secure_tunnel_message_previous_data_value = data_value; + break; + case AWS_SECURE_TUNNEL_MT_CONNECTION_RESET: + test_fixture->secure_tunnel_message_sent_connection_reset_count++; + break; + default: + break; + } + aws_condition_variable_notify_all(&test_fixture->signal); + aws_mutex_unlock(&test_fixture->lock); +} + +struct aws_secure_tunnel_mock_websocket_send_frame_task { + struct aws_task task; + struct aws_secure_tunnel_mock_test_fixture *test_fixture; + struct data_tunnel_pair *pair; + aws_websocket_outgoing_frame_complete_fn *on_complete; +}; + +static void s_secure_tunneling_mock_websocket_send_frame_task_fn( + struct aws_task *task, + void *arg, + enum aws_task_status status) { + + (void)task; + + struct aws_secure_tunnel_mock_websocket_send_frame_task *send_task = arg; + if (status != AWS_TASK_STATUS_RUN_READY) { + return; + } + + struct aws_secure_tunnel_mock_test_fixture *test_fixture = send_task->test_fixture; + + aws_secure_tunnel_deserialize_message_from_cursor( + test_fixture->secure_tunnel, &send_task->pair->cur, test_fixture->on_server_message_received); + + send_task->on_complete((struct aws_websocket *)test_fixture, AWS_OP_SUCCESS, send_task->pair); + + aws_mem_release(test_fixture->allocator, send_task); +} + int aws_websocket_send_frame_mock_fn( struct aws_websocket *websocket, const struct aws_websocket_send_frame_options *options) { @@ -668,11 +741,21 @@ int aws_websocket_send_frame_mock_fn( void *pointer = websocket; struct aws_secure_tunnel_mock_test_fixture *test_fixture = pointer; - struct data_tunnel_pair *pair = options->user_data; - aws_secure_tunnel_deserialize_message_from_cursor( - test_fixture->secure_tunnel, &pair->cur, &aws_secure_tunnel_test_on_message_received); + struct aws_secure_tunnel_mock_websocket_send_frame_task *send_task = aws_mem_calloc( + test_fixture->secure_tunnel->allocator, 1, sizeof(struct aws_secure_tunnel_mock_websocket_send_frame_task)); + + aws_task_init( + &send_task->task, + s_secure_tunneling_mock_websocket_send_frame_task_fn, + (void *)send_task, + "MockWebSocketSendMessageFromClient"); + + send_task->test_fixture = test_fixture; + send_task->pair = options->user_data; + send_task->on_complete = options->on_complete; - options->on_complete(websocket, AWS_OP_SUCCESS, options->user_data); + /* TODO Schedule in 10 ms. */ + aws_event_loop_schedule_task_now(test_fixture->secure_tunnel->loop, &send_task->task); return AWS_OP_SUCCESS; } @@ -772,6 +855,9 @@ int aws_secure_tunnel_mock_test_fixture_init( test_fixture->secure_tunnel_vtable.aws_websocket_close_fn = aws_websocket_close_mock_fn; test_fixture->secure_tunnel_vtable.vtable_user_data = test_fixture; + test_fixture->on_server_message_received = aws_secure_tunnel_test_on_message_received; + test_fixture->secure_tunnel_messages_received_in_order = true; + aws_secure_tunnel_set_vtable(test_fixture->secure_tunnel, &test_fixture->secure_tunnel_vtable); return AWS_OP_SUCCESS; @@ -1374,6 +1460,69 @@ static int s_secure_tunneling_max_payload_exceed_test_fn(struct aws_allocator *a AWS_TEST_CASE(secure_tunneling_max_payload_exceed_test, s_secure_tunneling_max_payload_exceed_test_fn) +/* Test that messages sent by a user one after another without delay are actually being sent to server. */ +static int s_secure_tunneling_subsequent_writes_test_fn(struct aws_allocator *allocator, void *ctx) { + (void)ctx; + struct secure_tunnel_test_options test_options; + struct aws_secure_tunnel_mock_test_fixture test_fixture; + aws_secure_tunnel_mock_test_init(allocator, &test_options, &test_fixture, AWS_SECURE_TUNNELING_DESTINATION_MODE); + struct aws_secure_tunnel *secure_tunnel = test_fixture.secure_tunnel; + + test_fixture.on_server_message_received = aws_secure_tunnel_test_on_message_received_with_order_validation; + + ASSERT_SUCCESS(aws_secure_tunnel_start(secure_tunnel)); + s_wait_for_connected_successfully(&test_fixture); + + /* Create and send a stream start message from the server to the destination client */ + struct aws_byte_cursor service_1 = aws_byte_cursor_from_string(s_service_id_1); + struct aws_secure_tunnel_message_view stream_start_message_view = { + .type = AWS_SECURE_TUNNEL_MT_STREAM_START, + .service_id = &service_1, + .stream_id = 1, + }; + aws_secure_tunnel_send_mock_message(&test_fixture, &stream_start_message_view); + + /* Wait and confirm that a stream has been started */ + s_wait_for_stream_started(&test_fixture); + ASSERT_TRUE(s_secure_tunnel_check_active_stream_id(secure_tunnel, &service_1, 1)); + + int total_messages = 100; + for (int i = 0; i < total_messages; ++i) { + uint8_t buf[16]; + struct aws_byte_cursor s_payload_buf = { + .ptr = buf, + .len = 16, + }; + + snprintf((char *)buf, sizeof(buf), "%d", i); + + struct aws_secure_tunnel_message_view data_message_view = { + .type = AWS_SECURE_TUNNEL_MT_DATA, + .stream_id = 0, + .service_id = &service_1, + .payload = &s_payload_buf, + }; + + int result = aws_secure_tunnel_send_message(secure_tunnel, &data_message_view); + ASSERT_INT_EQUALS(result, AWS_OP_SUCCESS); + } + + /* 1 second must be enough to send few messages. */ + aws_thread_current_sleep(aws_timestamp_convert(1, AWS_TIMESTAMP_SECS, AWS_TIMESTAMP_NANOS, NULL)); + + ASSERT_INT_EQUALS(test_fixture.on_send_message_complete_fired_cnt, total_messages); + ASSERT_TRUE(test_fixture.secure_tunnel_messages_received_in_order); + + ASSERT_SUCCESS(aws_secure_tunnel_stop(secure_tunnel)); + s_wait_for_connection_shutdown(&test_fixture); + + aws_secure_tunnel_mock_test_clean_up(&test_fixture); + + return AWS_OP_SUCCESS; +} + +AWS_TEST_CASE(secure_tunneling_subsequent_writes, s_secure_tunneling_subsequent_writes_test_fn) + static int s_secure_tunneling_receive_connection_start_test_fn(struct aws_allocator *allocator, void *ctx) { (void)ctx; struct secure_tunnel_test_options test_options;