diff --git a/tests/client-server.py b/tests/client-server.py index 8d3beb1c..4f0d10e4 100755 --- a/tests/client-server.py +++ b/tests/client-server.py @@ -35,12 +35,16 @@ def wait_tcp_port(host, port): print("Connected to {}:{}".format(host, port)) -def run_with_maybe_valgrind(args, env, valgrind): +def run_with_maybe_valgrind(args, env, valgrind, expect_error=False): if valgrind is not None: args = [valgrind] + args process_env = os.environ.copy() process_env.update(env) - subprocess.check_call(args, env=process_env, stdout=subprocess.DEVNULL) + try: + subprocess.check_call(args, env=process_env, stdout=subprocess.DEVNULL) + except subprocess.CalledProcessError as e: + if not expect_error: + raise e def run_client_tests(client, valgrind): @@ -81,6 +85,50 @@ def run_client_tests(client, valgrind): }, valgrind ) + run_with_maybe_valgrind( + [ + client, + HOST, + str(PORT), + "/" + ], + { + "CA_FILE": "testdata/minica.pem", + "AUTH_CERT": "testdata/localhost/cert.pem", + "AUTH_KEY": "testdata/localhost/key.pem", + }, + valgrind + ) + + +def run_mtls_client_tests(client, valgrind): + run_with_maybe_valgrind( + [ + client, + HOST, + str(PORT), + "/" + ], + { + "CA_FILE": "testdata/minica.pem", + }, + valgrind, + expect_error=True # Client connecting w/o AUTH_CERT/AUTH_KEY should err. + ) + run_with_maybe_valgrind( + [ + client, + HOST, + str(PORT), + "/" + ], + { + "CA_FILE": "testdata/minica.pem", + "AUTH_CERT": "testdata/localhost/cert.pem", + "AUTH_KEY": "testdata/localhost/key.pem", + }, + valgrind + ) def run_server(server, valgrind, env): @@ -116,17 +164,27 @@ def main(): .format(PORT)) sys.exit(1) + # Standard client/server tests. server_popen = run_server(server, valgrind, {}) wait_tcp_port(HOST, PORT) run_client_tests(client, valgrind) server_popen.kill() server_popen.wait() - run_server(server, valgrind, { + # Client/server tests w/ vectored I/O. + server_popen = run_server(server, valgrind, { "VECTORED_IO": "" }) wait_tcp_port(HOST, PORT) run_client_tests(client, valgrind) + server_popen.kill() + server_popen.wait() + + # Client/server tests w/ mandatory client authentication. + run_server(server, valgrind, { + "AUTH_CERT": "testdata/minica.pem", + }) + run_mtls_client_tests(client, valgrind) if __name__ == "__main__": diff --git a/tests/client.c b/tests/client.c index 84b78bf4..562ccc63 100644 --- a/tests/client.c +++ b/tests/client.c @@ -409,6 +409,7 @@ main(int argc, const char **argv) rustls_client_config_builder_new(); const struct rustls_client_config *client_config = NULL; struct rustls_slice_bytes alpn_http11; + const struct rustls_certified_key *certified_key = NULL; alpn_http11.data = (unsigned char*)"http/1.1"; alpn_http11.len = 8; @@ -434,6 +435,19 @@ main(int argc, const char **argv) goto cleanup; } + char* auth_cert = getenv("AUTH_CERT"); + char* auth_key = getenv("AUTH_KEY"); + if((auth_cert && !auth_key) || (!auth_cert && auth_key)) { + fprintf(stderr, "client: must set both AUTH_CERT and AUTH_KEY env vars, or neither\n"); + goto cleanup; + } else if (auth_cert && auth_key) { + certified_key = load_cert_and_key(argv[0], auth_cert, auth_key); + if(certified_key == NULL) { + goto cleanup; + } + rustls_client_config_builder_set_certified_key(config_builder, &certified_key, 1); + } + rustls_client_config_builder_set_alpn_protocols(config_builder, &alpn_http11, 1); client_config = rustls_client_config_builder_build(config_builder); @@ -450,6 +464,7 @@ main(int argc, const char **argv) ret = 0; cleanup: + rustls_certified_key_free(certified_key); rustls_client_config_free(client_config); #ifdef _WIN32 diff --git a/tests/common.c b/tests/common.c index d2797bae..ab412422 100644 --- a/tests/common.c +++ b/tests/common.c @@ -16,7 +16,6 @@ #include <string.h> #include <stdlib.h> #include <errno.h> -#include <limits.h> #include "rustls.h" #include "common.h" @@ -327,3 +326,52 @@ log_cb(void *userdata, const struct rustls_log_params *params) fprintf(stderr, "%s[fd %d][%.*s]: %.*s\n", conn->program_name, conn->fd, (int)level_str.len, level_str.data, (int)params->message.len, params->message.data); } + +enum demo_result +read_file(const char *progname, const char *filename, char *buf, size_t buflen, size_t *n) +{ + FILE *f = fopen(filename, "r"); + if(f == NULL) { + fprintf(stderr, "%s: opening %s: %s\n", progname, filename, strerror(errno)); + return DEMO_ERROR; + } + *n = fread(buf, 1, buflen, f); + if(!feof(f)) { + fprintf(stderr, "%s: reading %s: %s\n", progname, filename, strerror(errno)); + fclose(f); + return DEMO_ERROR; + } + fclose(f); + return DEMO_OK; +} + +const struct rustls_certified_key * +load_cert_and_key(const char *progname, const char *certfile, const char *keyfile) +{ + char certbuf[10000]; + size_t certbuf_len; + char keybuf[10000]; + size_t keybuf_len; + + unsigned int result = read_file(progname, certfile, certbuf, sizeof(certbuf), &certbuf_len); + if(result != DEMO_OK) { + return NULL; + } + + result = read_file(progname, keyfile, keybuf, sizeof(keybuf), &keybuf_len); + if(result != DEMO_OK) { + return NULL; + } + + const struct rustls_certified_key *certified_key; + result = rustls_certified_key_build((uint8_t *)certbuf, + certbuf_len, + (uint8_t *)keybuf, + keybuf_len, + &certified_key); + if(result != RUSTLS_RESULT_OK) { + print_error(progname, "parsing certificate and key", result); + return NULL; + } + return certified_key; +} diff --git a/tests/common.h b/tests/common.h index 6efdffc3..37a4a567 100644 --- a/tests/common.h +++ b/tests/common.h @@ -125,4 +125,10 @@ get_first_header_value(const char *headers, size_t headers_len, void log_cb(void *userdata, const struct rustls_log_params *params); +enum demo_result +read_file(const char *progname, const char *filename, char *buf, size_t buflen, size_t *n); + +const struct rustls_certified_key * +load_cert_and_key(const char *progname, const char *certfile, const char *keyfile); + #endif /* COMMON_H */ diff --git a/tests/server.c b/tests/server.c index a394cb34..43c31eeb 100644 --- a/tests/server.c +++ b/tests/server.c @@ -23,24 +23,6 @@ #include "rustls.h" #include "common.h" -enum demo_result -read_file(const char *filename, char *buf, size_t buflen, size_t *n) -{ - FILE *f = fopen(filename, "r"); - if(f == NULL) { - fprintf(stderr, "server: opening %s: %s\n", filename, strerror(errno)); - return DEMO_ERROR; - } - *n = fread(buf, 1, buflen, f); - if(!feof(f)) { - fprintf(stderr, "server: reading %s: %s\n", filename, strerror(errno)); - fclose(f); - return DEMO_ERROR; - } - fclose(f); - return DEMO_OK; -} - typedef enum exchange_state { READING_REQUEST, @@ -242,37 +224,6 @@ handle_conn(struct conndata *conn) free(conn); } -const struct rustls_certified_key * -load_cert_and_key(const char *certfile, const char *keyfile) -{ - char certbuf[10000]; - size_t certbuf_len; - char keybuf[10000]; - size_t keybuf_len; - - unsigned int result = read_file(certfile, certbuf, sizeof(certbuf), &certbuf_len); - if(result != DEMO_OK) { - return NULL; - } - - result = read_file(keyfile, keybuf, sizeof(keybuf), &keybuf_len); - if(result != DEMO_OK) { - return NULL; - } - - const struct rustls_certified_key *certified_key; - result = rustls_certified_key_build((uint8_t *)certbuf, - certbuf_len, - (uint8_t *)keybuf, - keybuf_len, - &certified_key); - if(result != RUSTLS_RESULT_OK) { - print_error("server", "parsing certificate and key", result); - return NULL; - } - return certified_key; -} - bool shutting_down = false; void handle_signal(int signo) { @@ -294,6 +245,8 @@ main(int argc, const char **argv) struct rustls_connection *rconn = NULL; const struct rustls_certified_key *certified_key = NULL; struct rustls_slice_bytes alpn_http11; + const struct rustls_client_cert_verifier *client_cert_verifier = NULL; + struct rustls_root_cert_store *client_cert_root_store = NULL; alpn_http11.data = (unsigned char*)"http/1.1"; alpn_http11.len = 8; @@ -315,7 +268,7 @@ main(int argc, const char **argv) goto cleanup; } - certified_key = load_cert_and_key(argv[1], argv[2]); + certified_key = load_cert_and_key(argv[0], argv[1], argv[2]); if(certified_key == NULL) { goto cleanup; } @@ -324,6 +277,22 @@ main(int argc, const char **argv) config_builder, &certified_key, 1); rustls_server_config_builder_set_alpn_protocols(config_builder, &alpn_http11, 1); + char* auth_cert = getenv("AUTH_CERT"); + if(auth_cert) { + char certbuf[10000]; + size_t certbuf_len; + int result = read_file(argv[0], auth_cert, certbuf, sizeof(certbuf), &certbuf_len); + if(result != DEMO_OK) { + goto cleanup; + } + + client_cert_root_store = rustls_root_cert_store_new(); + rustls_root_cert_store_add_pem(client_cert_root_store, (uint8_t *)certbuf, certbuf_len, true); + + client_cert_verifier = rustls_client_cert_verifier_new(client_cert_root_store); + rustls_server_config_builder_set_client_verifier(config_builder, client_cert_verifier); + } + server_config = rustls_server_config_builder_build(config_builder); #ifdef _WIN32 @@ -399,6 +368,8 @@ main(int argc, const char **argv) cleanup: rustls_certified_key_free(certified_key); + rustls_root_cert_store_free(client_cert_root_store); + rustls_client_cert_verifier_free(client_cert_verifier); rustls_server_config_free(server_config); rustls_connection_free(rconn); if(sockfd>0) {