Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/openssl/evp.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
#include <openssl/digest.h>
#include <openssl/nid.h>
#include <openssl/objects.h>
#include <openssl/hmac.h> // Needed by Apache mod_ssl

#if defined(__cplusplus)
extern "C" {
Expand Down
37 changes: 37 additions & 0 deletions include/openssl/ssl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,34 @@ OPENSSL_EXPORT size_t SSL_get0_peer_delegation_algorithms(
// - SSL_CLIENT_HELLO_RETRY (not supported) is handled like SSL_CLIENT_HELLO_ERROR
typedef int (*SSL_client_hello_cb_fn)(SSL *s, int *al, void *arg);

// SSL_client_hello_get1_extensions_present iterates over the extensions in the
// ClientHello. If any are found, it allocates an array of int and sets |*out|
// to point to this array and |*outlen| to the number of extensions. The ints
// in the array correspond to the type of each extension. The caller is
// responsible for releasing the array with OPENSSL_free. If no extensions are
// found, it sets |*out| to NULL and |*outlen| to 0. The function returns 1 on
// success and returns 0 on error.
//
// This function can only be called from within a client hello callback (see
// |SSL_CTX_set_client_hello_cb|) or during server certificate selection (see
// |SSL_CTX_set_select_certificate_cb|).
OPENSSL_EXPORT int SSL_client_hello_get1_extensions_present(SSL *s, int **out, size_t *outlen);

// SSL_client_hello_get_extension_order iterates over the extensions in the
// ClientHello. If |exts| is not null, the type for each extension will be
// stored in |exts| and |*num_exts| should be the size of storage
// allocated for |exts|; the function will return an error if |*num_exts| is
// too small. On success, the function will return 1 and will set |*num_exts| to
// the number of extensions. The caller may pass |exts| as null to obtain the
// number of extensions. If no ClientHello extensions are found, the
// function returns 1 and sets |*num_exts| to 0. The functions returns 0 on
// error.
//
// This function can only be called from within a client hello callback (see
// |SSL_CTX_set_client_hello_cb|) or during server certificate selection (see
// |SSL_CTX_set_select_certificate_cb|).
OPENSSL_EXPORT int SSL_client_hello_get_extension_order(SSL *s, uint16_t *exts, size_t *num_exts);

// SSL_CTX_set_client_hello_cb configures a callback that is called when a
// ClientHello message is received. This can be used to select certificates,
// adjust settings, or otherwise make decisions about the connection before
Expand All @@ -1144,6 +1172,15 @@ OPENSSL_EXPORT void SSL_CTX_set_client_hello_cb(SSL_CTX *c, SSL_client_hello_cb_
// SSL_client_hello_isv2 always returns zero as SSLv2 is not supported.
OPENSSL_EXPORT int SSL_client_hello_isv2(SSL *s);


// SSL_client_hello_get0_legacy_version provides the value of the
// "legacy_version" field in the client hello.
//
// This function can only be called from within a client hello callback (see
// |SSL_CTX_set_client_hello_cb|) or during server certificate selection (see
// |SSL_CTX_set_select_certificate_cb|).
OPENSSL_EXPORT unsigned int SSL_client_hello_get0_legacy_version(SSL *s);

// SSL_client_hello_get0_ext searches the extensions in the ClientHello for an
// extension of the given type. If found, it sets |*out| to point to the
// extension contents (not including the type and length bytes) and |*outlen|
Expand Down
179 changes: 179 additions & 0 deletions ssl/ssl_client_hello_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@

#include <openssl/ssl.h>
#include <openssl/tls1.h>
#include <openssl/mem.h>

#include "ssl_common_test.h"

#include <memory>

BSSL_NAMESPACE_BEGIN

namespace {
Expand Down Expand Up @@ -345,6 +348,182 @@ TEST(SSLClientHelloTest, ClientHelloKnownExtensions) {
EXPECT_GT(results.supported_groups_len, 0u);
}

struct ExtensionsPresentTestArgs {
bool *called;
bool expect_session_ticket;
};

int callback_SSL_client_hello_get1_extensions_present_impl(
SSL *ssl, int *al, void *arg) {
auto *args = static_cast<ExtensionsPresentTestArgs *>(arg);
*(args->called) = true;

int *extensions = nullptr;
size_t extensions_len = 0;
if (!SSL_client_hello_get1_extensions_present(ssl, &extensions,
&extensions_len)) {
ADD_FAILURE() << "SSL_client_hello_get1_extensions_present failed";
return SSL_CLIENT_HELLO_ERROR;
}

EXPECT_GT(extensions_len, 0u);
EXPECT_TRUE(extensions);

unsigned legacy_version = SSL_client_hello_get0_legacy_version(ssl);
EXPECT_EQ(legacy_version, (unsigned)TLS1_2_VERSION);

// Verify a few common extensions are present
bool found_supported_groups = false;
bool found_session_ticket = false;
for (size_t i = 0; i < extensions_len; i++) {
if (extensions[i] == TLSEXT_TYPE_supported_groups) {
found_supported_groups = true;
}
if (extensions[i] == TLSEXT_TYPE_session_ticket) {
found_session_ticket = true;
}
}
EXPECT_TRUE(found_supported_groups);
EXPECT_EQ(found_session_ticket, args->expect_session_ticket);

OPENSSL_free(extensions);

return SSL_CLIENT_HELLO_SUCCESS;
}

// Test SSL_client_hello_get1_extensions_present with a client hello that has
// extensions.
TEST(SSLClientHelloTest, ExtensionsPresent) {
UniquePtr<SSL_CTX> client_ctx(SSL_CTX_new(TLS_method()));
UniquePtr<SSL_CTX> server_ctx =
CreateContextWithTestCertificate(TLS_method());
ASSERT_TRUE(client_ctx);
ASSERT_TRUE(server_ctx);

SSL_CTX_set_info_callback(
client_ctx.get(), [](const SSL *ssl, int type, int val) {
if (type == SSL_CB_HANDSHAKE_START) {
ASSERT_TRUE(
SSL_set_tlsext_host_name(const_cast<SSL *>(ssl), "example.com"));
}
});

bool callback_called = false;
ExtensionsPresentTestArgs args = {&callback_called,
true /* expect_session_ticket */};
SSL_CTX_set_client_hello_cb(
server_ctx.get(), callback_SSL_client_hello_get1_extensions_present_impl,
&args);

UniquePtr<SSL> client, server;
ASSERT_TRUE(ConnectClientAndServer(&client, &server, client_ctx.get(),
server_ctx.get()));
EXPECT_TRUE(callback_called);
}

// Test SSL_client_hello_get1_extensions_present with a client hello that has
// no session ticket extension.
TEST(SSLClientHelloTest, NoTicketExtensionPresent) {
UniquePtr<SSL_CTX> client_ctx(SSL_CTX_new(TLS_method()));
UniquePtr<SSL_CTX> server_ctx =
CreateContextWithTestCertificate(TLS_method());
ASSERT_TRUE(client_ctx);
ASSERT_TRUE(server_ctx);

// Disable all extensions on the client to simulate a "no extensions" scenario
// Note: This is a bit artificial as the library might add some extensions
// by default. We rely on the callback to check the result.
SSL_CTX_set_options(client_ctx.get(), SSL_OP_NO_TICKET);

bool callback_called = false;
ExtensionsPresentTestArgs args = {&callback_called,
false /* expect_session_ticket */};
SSL_CTX_set_client_hello_cb(
server_ctx.get(), callback_SSL_client_hello_get1_extensions_present_impl,
&args);

UniquePtr<SSL> client, server;
ASSERT_TRUE(ConnectClientAndServer(&client, &server, client_ctx.get(),
server_ctx.get()));
EXPECT_TRUE(callback_called);
}

// Test SSL_client_hello_get_extension_order to verify its behavior with
// different buffer sizes and to ensure it correctly reports the number of
// extensions.
TEST(SSLClientHelloTest, GetExtensionOrder) {
UniquePtr<SSL_CTX> client_ctx(SSL_CTX_new(TLS_method()));
UniquePtr<SSL_CTX> server_ctx =
CreateContextWithTestCertificate(TLS_method());
ASSERT_TRUE(client_ctx);
ASSERT_TRUE(server_ctx);

bool callback_called = false;
SSL_CTX_set_client_hello_cb(
server_ctx.get(),
[](SSL *ssl, int *al, void *arg) -> int {
bool *called = static_cast<bool *>(arg);
*called = true;

size_t num_extensions = 0;
// First, call with a null buffer to get the count of extensions.
if (SSL_client_hello_get_extension_order(ssl, nullptr,
&num_extensions) != 1) {
ADD_FAILURE()
<< "Failed initial call to SSL_client_hello_get_extension_order";
return SSL_CLIENT_HELLO_ERROR;
}
EXPECT_GT(num_extensions, 0u);

// Allocate a buffer of the correct size and get the extensions.
uint16_t *exts = static_cast<uint16_t *>(
OPENSSL_zalloc(sizeof(uint16_t) * num_extensions));
if (exts == nullptr) {
ADD_FAILURE() << "Failed to allocate extensions";
return SSL_CLIENT_HELLO_ERROR;
}
if (SSL_client_hello_get_extension_order(ssl, exts, &num_extensions) !=
1) {
ADD_FAILURE()
<< "Failed call to SSL_client_hello_get_extension_order";
OPENSSL_free(exts);
return SSL_CLIENT_HELLO_ERROR;
}

unsigned legacy_version = SSL_client_hello_get0_legacy_version(ssl);
EXPECT_EQ(legacy_version, static_cast<unsigned>(TLS1_2_VERSION));

// Call with a buffer that is too small and confirm it fails.
size_t too_small_num_extensions = num_extensions - 1;
uint16_t *too_small_exts = static_cast<uint16_t *>(
OPENSSL_zalloc(sizeof(uint16_t) * too_small_num_extensions));
if (!too_small_exts) {
OPENSSL_free(exts);
ADD_FAILURE() << "Failed to allocate too small buffer";
return SSL_CLIENT_HELLO_ERROR;
}
// Expect failure
if (SSL_client_hello_get_extension_order(
ssl, too_small_exts, &too_small_num_extensions) != 0) {
OPENSSL_free(exts);
OPENSSL_free(too_small_exts);
ADD_FAILURE()
<< "Failed call to SSL_client_hello_get_extension_order";
return SSL_CLIENT_HELLO_ERROR;
}
OPENSSL_free(exts);
OPENSSL_free(too_small_exts);

return SSL_CLIENT_HELLO_SUCCESS;
},
&callback_called);

UniquePtr<SSL> client, server;
ASSERT_TRUE(ConnectClientAndServer(&client, &server, client_ctx.get(),
server_ctx.get()));
EXPECT_TRUE(callback_called);
}

} // namespace

BSSL_NAMESPACE_END
105 changes: 104 additions & 1 deletion ssl/ssl_lib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3081,6 +3081,110 @@ int SSL_client_hello_get0_ext(SSL *s, unsigned int type, const unsigned char **o
return 1; // Success
}

int SSL_client_hello_get1_extensions_present(SSL *s, int **out,
size_t *outlen) {
GUARD_PTR(s);
GUARD_PTR(out);
GUARD_PTR(outlen);
size_t num_extensions = 0;

// Count the number of extensions so we can allocate
if (1 != SSL_client_hello_get_extension_order(s, nullptr, &num_extensions)) {
return 0;
}

if (num_extensions == 0) {
*out = nullptr;
*outlen = 0;
return 1;
}

// Allocate a uint16_t for each extension
uint16_t *exts =
static_cast<uint16_t *>(OPENSSL_zalloc(sizeof(uint16_t) * num_extensions));
if (exts == nullptr) {
return 0;
}

// Collect the type for each extension
if (1 != SSL_client_hello_get_extension_order(s, exts, &num_extensions)) {
OPENSSL_free(exts);
return 0;
}

// Allocate the int array needed by caller.
int *ext_types =
static_cast<int *>(OPENSSL_zalloc(sizeof(int) * num_extensions));
if (ext_types == nullptr) {
OPENSSL_free(exts);
return 0;
}

// Cast each uint16_t type to an int
for (size_t i = 0; i < num_extensions; i++) {
ext_types[i] = exts[i];
}
OPENSSL_free(exts);

*out = ext_types;
*outlen = num_extensions;

return 1;
}

int SSL_client_hello_get_extension_order(SSL *s, uint16_t *exts, size_t *num_exts) {
GUARD_PTR(s);
GUARD_PTR(s->s3);
SSL_HANDSHAKE *hs = s->s3->hs.get();
GUARD_PTR(hs);

SSLMessage msg_unused;
SSL_CLIENT_HELLO client_hello;
if (!hs->GetClientHello(&msg_unused, &client_hello)) {
return 0;
}

CBS extensions;
CBS_init(&extensions, client_hello.extensions, client_hello.extensions_len);

size_t num_extensions = 0;
while (CBS_len(&extensions) > 0) {
uint16_t type = 0;
CBS body;
if (!CBS_get_u16(&extensions, &type) ||
!CBS_get_u16_length_prefixed(&extensions, &body)) {
OPENSSL_PUT_ERROR(SSL, SSL_R_DECODE_ERROR);
return 0;
}
if (exts != nullptr) {
// num_exts is an in/out param. Return error if insufficient size.
if (num_extensions >= *num_exts) {
return 0;
}
// Store the type for each extension
exts[num_extensions] = type;
}
num_extensions++;
}
*num_exts = num_extensions;

return 1;
}

unsigned int SSL_client_hello_get0_legacy_version(SSL *s) {
GUARD_PTR(s);
GUARD_PTR(s->s3);
SSL_HANDSHAKE *hs = s->s3->hs.get();
GUARD_PTR(hs);

SSLMessage msg_unused;
SSL_CLIENT_HELLO client_hello;
if (!hs->GetClientHello(&msg_unused, &client_hello)) {
return 0;
}
return client_hello.version;
}

void SSL_CTX_set_keylog_callback(SSL_CTX *ctx,
void (*cb)(const SSL *ssl, const char *line)) {
ctx->keylog_callback = cb;
Expand Down Expand Up @@ -3655,4 +3759,3 @@ OPENSSL_EXPORT int SSL_get_write_traffic_secret(
int SSL_verify_client_post_handshake(SSL *ssl) {
return 0;
}

Loading