Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test client/server mTLS support. #321

Merged
merged 4 commits into from
Jul 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 61 additions & 3 deletions tests/client-server.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,16 @@ def wait_tcp_port(host, port):
print("Connected to {}:{}".format(host, port))


def run_with_maybe_valgrind(args, env, valgrind):
def run_with_maybe_valgrind(args, env, valgrind, expect_error=False):
if valgrind is not None:
args = [valgrind] + args
process_env = os.environ.copy()
process_env.update(env)
subprocess.check_call(args, env=process_env, stdout=subprocess.DEVNULL)
try:
subprocess.check_call(args, env=process_env, stdout=subprocess.DEVNULL)
except subprocess.CalledProcessError as e:
if not expect_error:
raise e


def run_client_tests(client, valgrind):
Expand Down Expand Up @@ -81,6 +85,50 @@ def run_client_tests(client, valgrind):
},
valgrind
)
run_with_maybe_valgrind(
[
client,
HOST,
str(PORT),
"/"
],
{
"CA_FILE": "testdata/minica.pem",
"AUTH_CERT": "testdata/localhost/cert.pem",
"AUTH_KEY": "testdata/localhost/key.pem",
},
valgrind
)


def run_mtls_client_tests(client, valgrind):
run_with_maybe_valgrind(
[
client,
HOST,
str(PORT),
"/"
],
{
"CA_FILE": "testdata/minica.pem",
},
valgrind,
expect_error=True # Client connecting w/o AUTH_CERT/AUTH_KEY should err.
)
run_with_maybe_valgrind(
[
client,
HOST,
str(PORT),
"/"
],
{
"CA_FILE": "testdata/minica.pem",
"AUTH_CERT": "testdata/localhost/cert.pem",
"AUTH_KEY": "testdata/localhost/key.pem",
},
valgrind
)


def run_server(server, valgrind, env):
Expand Down Expand Up @@ -116,17 +164,27 @@ def main():
.format(PORT))
sys.exit(1)

# Standard client/server tests.
server_popen = run_server(server, valgrind, {})
wait_tcp_port(HOST, PORT)
run_client_tests(client, valgrind)
server_popen.kill()
server_popen.wait()

run_server(server, valgrind, {
# Client/server tests w/ vectored I/O.
server_popen = run_server(server, valgrind, {
"VECTORED_IO": ""
})
wait_tcp_port(HOST, PORT)
run_client_tests(client, valgrind)
server_popen.kill()
server_popen.wait()

# Client/server tests w/ mandatory client authentication.
run_server(server, valgrind, {
"AUTH_CERT": "testdata/minica.pem",
})
run_mtls_client_tests(client, valgrind)


if __name__ == "__main__":
Expand Down
15 changes: 15 additions & 0 deletions tests/client.c
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ main(int argc, const char **argv)
rustls_client_config_builder_new();
const struct rustls_client_config *client_config = NULL;
struct rustls_slice_bytes alpn_http11;
const struct rustls_certified_key *certified_key = NULL;

alpn_http11.data = (unsigned char*)"http/1.1";
alpn_http11.len = 8;
Expand All @@ -434,6 +435,19 @@ main(int argc, const char **argv)
goto cleanup;
}

char* auth_cert = getenv("AUTH_CERT");
char* auth_key = getenv("AUTH_KEY");
if((auth_cert && !auth_key) || (!auth_cert && auth_key)) {
fprintf(stderr, "client: must set both AUTH_CERT and AUTH_KEY env vars, or neither\n");
goto cleanup;
} else if (auth_cert && auth_key) {
certified_key = load_cert_and_key(argv[0], auth_cert, auth_key);
if(certified_key == NULL) {
goto cleanup;
}
rustls_client_config_builder_set_certified_key(config_builder, &certified_key, 1);
}

rustls_client_config_builder_set_alpn_protocols(config_builder, &alpn_http11, 1);

client_config = rustls_client_config_builder_build(config_builder);
Expand All @@ -450,6 +464,7 @@ main(int argc, const char **argv)
ret = 0;

cleanup:
rustls_certified_key_free(certified_key);
rustls_client_config_free(client_config);

#ifdef _WIN32
Expand Down
50 changes: 49 additions & 1 deletion tests/common.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include <string.h>
#include <stdlib.h>
#include <errno.h>
#include <limits.h>

#include "rustls.h"
#include "common.h"
Expand Down Expand Up @@ -327,3 +326,52 @@ log_cb(void *userdata, const struct rustls_log_params *params)
fprintf(stderr, "%s[fd %d][%.*s]: %.*s\n", conn->program_name, conn->fd,
(int)level_str.len, level_str.data, (int)params->message.len, params->message.data);
}

enum demo_result
read_file(const char *progname, const char *filename, char *buf, size_t buflen, size_t *n)
{
FILE *f = fopen(filename, "r");
if(f == NULL) {
fprintf(stderr, "%s: opening %s: %s\n", progname, filename, strerror(errno));
return DEMO_ERROR;
}
*n = fread(buf, 1, buflen, f);
if(!feof(f)) {
fprintf(stderr, "%s: reading %s: %s\n", progname, filename, strerror(errno));
fclose(f);
return DEMO_ERROR;
}
fclose(f);
return DEMO_OK;
}

const struct rustls_certified_key *
load_cert_and_key(const char *progname, const char *certfile, const char *keyfile)
{
char certbuf[10000];
size_t certbuf_len;
char keybuf[10000];
size_t keybuf_len;

unsigned int result = read_file(progname, certfile, certbuf, sizeof(certbuf), &certbuf_len);
if(result != DEMO_OK) {
return NULL;
}

result = read_file(progname, keyfile, keybuf, sizeof(keybuf), &keybuf_len);
if(result != DEMO_OK) {
return NULL;
}

const struct rustls_certified_key *certified_key;
result = rustls_certified_key_build((uint8_t *)certbuf,
certbuf_len,
(uint8_t *)keybuf,
keybuf_len,
&certified_key);
if(result != RUSTLS_RESULT_OK) {
print_error(progname, "parsing certificate and key", result);
return NULL;
}
return certified_key;
}
6 changes: 6 additions & 0 deletions tests/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,10 @@ get_first_header_value(const char *headers, size_t headers_len,
void
log_cb(void *userdata, const struct rustls_log_params *params);

enum demo_result
read_file(const char *progname, const char *filename, char *buf, size_t buflen, size_t *n);

const struct rustls_certified_key *
load_cert_and_key(const char *progname, const char *certfile, const char *keyfile);

#endif /* COMMON_H */
71 changes: 21 additions & 50 deletions tests/server.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,6 @@
#include "rustls.h"
#include "common.h"

enum demo_result
read_file(const char *filename, char *buf, size_t buflen, size_t *n)
{
FILE *f = fopen(filename, "r");
if(f == NULL) {
fprintf(stderr, "server: opening %s: %s\n", filename, strerror(errno));
return DEMO_ERROR;
}
*n = fread(buf, 1, buflen, f);
if(!feof(f)) {
fprintf(stderr, "server: reading %s: %s\n", filename, strerror(errno));
fclose(f);
return DEMO_ERROR;
}
fclose(f);
return DEMO_OK;
}

typedef enum exchange_state
{
READING_REQUEST,
Expand Down Expand Up @@ -242,37 +224,6 @@ handle_conn(struct conndata *conn)
free(conn);
}

const struct rustls_certified_key *
load_cert_and_key(const char *certfile, const char *keyfile)
{
char certbuf[10000];
size_t certbuf_len;
char keybuf[10000];
size_t keybuf_len;

unsigned int result = read_file(certfile, certbuf, sizeof(certbuf), &certbuf_len);
if(result != DEMO_OK) {
return NULL;
}

result = read_file(keyfile, keybuf, sizeof(keybuf), &keybuf_len);
if(result != DEMO_OK) {
return NULL;
}

const struct rustls_certified_key *certified_key;
result = rustls_certified_key_build((uint8_t *)certbuf,
certbuf_len,
(uint8_t *)keybuf,
keybuf_len,
&certified_key);
if(result != RUSTLS_RESULT_OK) {
print_error("server", "parsing certificate and key", result);
return NULL;
}
return certified_key;
}

bool shutting_down = false;

void handle_signal(int signo) {
Expand All @@ -294,6 +245,8 @@ main(int argc, const char **argv)
struct rustls_connection *rconn = NULL;
const struct rustls_certified_key *certified_key = NULL;
struct rustls_slice_bytes alpn_http11;
const struct rustls_client_cert_verifier *client_cert_verifier = NULL;
struct rustls_root_cert_store *client_cert_root_store = NULL;

alpn_http11.data = (unsigned char*)"http/1.1";
alpn_http11.len = 8;
Expand All @@ -315,7 +268,7 @@ main(int argc, const char **argv)
goto cleanup;
}

certified_key = load_cert_and_key(argv[1], argv[2]);
certified_key = load_cert_and_key(argv[0], argv[1], argv[2]);
if(certified_key == NULL) {
goto cleanup;
}
Expand All @@ -324,6 +277,22 @@ main(int argc, const char **argv)
config_builder, &certified_key, 1);
rustls_server_config_builder_set_alpn_protocols(config_builder, &alpn_http11, 1);

char* auth_cert = getenv("AUTH_CERT");
if(auth_cert) {
char certbuf[10000];
size_t certbuf_len;
int result = read_file(argv[0], auth_cert, certbuf, sizeof(certbuf), &certbuf_len);
if(result != DEMO_OK) {
goto cleanup;
}

client_cert_root_store = rustls_root_cert_store_new();
rustls_root_cert_store_add_pem(client_cert_root_store, (uint8_t *)certbuf, certbuf_len, true);

client_cert_verifier = rustls_client_cert_verifier_new(client_cert_root_store);
rustls_server_config_builder_set_client_verifier(config_builder, client_cert_verifier);
}

server_config = rustls_server_config_builder_build(config_builder);

#ifdef _WIN32
Expand Down Expand Up @@ -399,6 +368,8 @@ main(int argc, const char **argv)

cleanup:
rustls_certified_key_free(certified_key);
rustls_root_cert_store_free(client_cert_root_store);
rustls_client_cert_verifier_free(client_cert_verifier);
rustls_server_config_free(server_config);
rustls_connection_free(rconn);
if(sockfd>0) {
Expand Down