-
Notifications
You must be signed in to change notification settings - Fork 13.3k
Fix mem leak in SSL server, allow for concurrent client and server connections w/o interference #4305
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix mem leak in SSL server, allow for concurrent client and server connections w/o interference #4305
Changes from 4 commits
8c6ea61
e9102b4
5e438aa
88d5af1
a90a884
5275b5e
2c654bf
03be43a
81e8438
56124a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -74,37 +74,47 @@ typedef std::list<BufferItem> BufferList; | |
class SSLContext | ||
{ | ||
public: | ||
SSLContext() | ||
SSLContext(bool isServer = false) | ||
{ | ||
if (_ssl_ctx_refcnt == 0) { | ||
_ssl_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0); | ||
_isServer = isServer; | ||
if (!_isServer) { | ||
if (_ssl_client_ctx_refcnt == 0) { | ||
_ssl_client_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0); | ||
} | ||
++_ssl_client_ctx_refcnt; | ||
} else { | ||
if (_ssl_svr_ctx_refcnt == 0) { | ||
_ssl_svr_ctx = ssl_ctx_new(SSL_SERVER_VERIFY_LATER | SSL_DEBUG_OPTS | SSL_CONNECT_IN_PARTS | SSL_READ_BLOCKING | SSL_NO_DEFAULT_KEY, 0); | ||
} | ||
++_ssl_svr_ctx_refcnt; | ||
} | ||
++_ssl_ctx_refcnt; | ||
} | ||
|
||
~SSLContext() | ||
{ | ||
if (_ssl) { | ||
ssl_free(_ssl); | ||
_ssl = nullptr; | ||
if (io_ctx) { | ||
io_ctx->unref(); | ||
io_ctx = nullptr; | ||
} | ||
|
||
--_ssl_ctx_refcnt; | ||
if (_ssl_ctx_refcnt == 0) { | ||
ssl_ctx_free(_ssl_ctx); | ||
_ssl = nullptr; | ||
if (!_isServer) { | ||
--_ssl_client_ctx_refcnt; | ||
if (_ssl_client_ctx_refcnt == 0) { | ||
ssl_ctx_free(_ssl_client_ctx); | ||
_ssl_client_ctx = nullptr; | ||
} | ||
} else { | ||
--_ssl_svr_ctx_refcnt; | ||
if (_ssl_svr_ctx_refcnt == 0) { | ||
ssl_ctx_free(_ssl_svr_ctx); | ||
_ssl_svr_ctx = nullptr; | ||
} | ||
} | ||
} | ||
|
||
void ref() | ||
static void _delete_shared_SSL(SSL *_to_del) | ||
{ | ||
++_refcnt; | ||
} | ||
|
||
void unref() | ||
{ | ||
if (--_refcnt == 0) { | ||
delete this; | ||
} | ||
ssl_free(_to_del); | ||
} | ||
|
||
void connect(ClientContext* ctx, const char* hostName, uint32_t timeout_ms) | ||
|
@@ -116,50 +126,64 @@ class SSLContext | |
ssl_free will want to send a close notify alert, but the old TCP connection | ||
is already gone at this point, so reset io_ctx. */ | ||
io_ctx = nullptr; | ||
ssl_free(_ssl); | ||
_ssl = nullptr; | ||
_available = 0; | ||
_read_ptr = nullptr; | ||
} | ||
io_ctx = ctx; | ||
_ssl = ssl_client_new(_ssl_ctx, reinterpret_cast<int>(this), nullptr, 0, ext); | ||
ctx->ref(); | ||
|
||
// Wrap the new SSL with a smart pointer, custom deleter to call ssl_free | ||
SSL *_new_ssl = ssl_client_new(_ssl_client_ctx, reinterpret_cast<int>(this), nullptr, 0, ext); | ||
std::shared_ptr<SSL> _new_ssl_shared(_new_ssl, _delete_shared_SSL); | ||
_ssl = _new_ssl_shared; | ||
|
||
uint32_t t = millis(); | ||
|
||
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl) != SSL_OK) { | ||
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl.get()) != SSL_OK) { | ||
uint8_t* data; | ||
int rc = ssl_read(_ssl, &data); | ||
int rc = ssl_read(_ssl.get(), &data); | ||
if (rc < SSL_OK) { | ||
ssl_display_error(rc); | ||
break; | ||
} | ||
} | ||
} | ||
|
||
void connectServer(ClientContext *ctx) { | ||
void connectServer(ClientContext *ctx, uint32_t timeout_ms) | ||
{ | ||
io_ctx = ctx; | ||
_ssl = ssl_server_new(_ssl_ctx, reinterpret_cast<int>(this)); | ||
_isServer = true; | ||
ctx->ref(); | ||
|
||
// Wrap the new SSL with a smart pointer, custom deleter to call ssl_free | ||
SSL *_new_ssl = ssl_server_new(_ssl_svr_ctx, reinterpret_cast<int>(this)); | ||
std::shared_ptr<SSL> _new_ssl_shared(_new_ssl, _delete_shared_SSL); | ||
_ssl = _new_ssl_shared; | ||
|
||
uint32_t timeout_ms = 5000; | ||
uint32_t t = millis(); | ||
|
||
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl) != SSL_OK) { | ||
while (millis() - t < timeout_ms && ssl_handshake_status(_ssl.get()) != SSL_OK) { | ||
uint8_t* data; | ||
int rc = ssl_read(_ssl, &data); | ||
int rc = ssl_read(_ssl.get(), &data); | ||
if (rc < SSL_OK) { | ||
ssl_display_error(rc); | ||
break; | ||
} | ||
} | ||
} | ||
|
||
void stop() | ||
{ | ||
if (io_ctx) { | ||
io_ctx->unref(); | ||
} | ||
io_ctx = nullptr; | ||
} | ||
|
||
bool connected() | ||
{ | ||
if (_isServer) return _ssl != nullptr; | ||
else return _ssl != nullptr && ssl_handshake_status(_ssl) == SSL_OK; | ||
else return _ssl != nullptr && ssl_handshake_status(_ssl.get()) == SSL_OK; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. indent return line There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done |
||
} | ||
|
||
int read(uint8_t* dst, size_t size) | ||
|
@@ -292,7 +316,7 @@ class SSLContext | |
|
||
bool loadObject(int type, const uint8_t* data, size_t size) | ||
{ | ||
int rc = ssl_obj_memory_load(_ssl_ctx, type, data, static_cast<int>(size), nullptr); | ||
int rc = ssl_obj_memory_load(_isServer?_ssl_svr_ctx:_ssl_client_ctx, type, data, static_cast<int>(size), nullptr); | ||
if (rc != SSL_OK) { | ||
DEBUGV("loadObject: ssl_obj_memory_load returned %d\n", rc); | ||
return false; | ||
|
@@ -302,7 +326,7 @@ class SSLContext | |
|
||
bool verifyCert() | ||
{ | ||
int rc = ssl_verify_cert(_ssl); | ||
int rc = ssl_verify_cert(_ssl.get()); | ||
if (_allowSelfSignedCerts && rc == SSL_X509_ERROR(X509_VFY_ERROR_SELF_SIGNED)) { | ||
DEBUGV("Allowing self-signed certificate\n"); | ||
return true; | ||
|
@@ -321,12 +345,14 @@ class SSLContext | |
|
||
operator SSL*() | ||
{ | ||
return _ssl; | ||
return _ssl.get(); | ||
} | ||
|
||
static ClientContext* getIOContext(int fd) | ||
{ | ||
return reinterpret_cast<SSLContext*>(fd)->io_ctx; | ||
if (!fd) return nullptr; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. indent return to new line There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. I did a 'grep if | grep ";"' and didn't see any other spots. You've got good eyes to have found them! |
||
SSLContext *thisSSL = reinterpret_cast<SSLContext*>(fd); | ||
return thisSSL->io_ctx; | ||
} | ||
|
||
protected: | ||
|
@@ -339,10 +365,9 @@ class SSLContext | |
optimistic_yield(100); | ||
|
||
uint8_t* data; | ||
int rc = ssl_read(_ssl, &data); | ||
int rc = ssl_read(_ssl.get(), &data); | ||
if (rc <= 0) { | ||
if (rc < SSL_OK && rc != SSL_CLOSE_NOTIFY && rc != SSL_ERROR_CONN_LOST) { | ||
ssl_free(_ssl); | ||
_ssl = nullptr; | ||
} | ||
return 0; | ||
|
@@ -359,7 +384,7 @@ class SSLContext | |
return 0; | ||
} | ||
|
||
int rc = ssl_write(_ssl, src, size); | ||
int rc = ssl_write(_ssl.get(), src, size); | ||
if (rc >= 0) { | ||
return rc; | ||
} | ||
|
@@ -404,19 +429,22 @@ class SSLContext | |
} | ||
|
||
bool _isServer = false; | ||
static SSL_CTX* _ssl_ctx; | ||
static int _ssl_ctx_refcnt; | ||
SSL* _ssl = nullptr; | ||
int _refcnt = 0; | ||
static SSL_CTX* _ssl_client_ctx; | ||
static int _ssl_client_ctx_refcnt; | ||
static SSL_CTX* _ssl_svr_ctx; | ||
static int _ssl_svr_ctx_refcnt; | ||
std::shared_ptr<SSL> _ssl = nullptr; | ||
const uint8_t* _read_ptr = nullptr; | ||
size_t _available = 0; | ||
BufferList _writeBuffers; | ||
bool _allowSelfSignedCerts = false; | ||
ClientContext* io_ctx = nullptr; | ||
}; | ||
|
||
SSL_CTX* SSLContext::_ssl_ctx = nullptr; | ||
int SSLContext::_ssl_ctx_refcnt = 0; | ||
SSL_CTX* SSLContext::_ssl_client_ctx = nullptr; | ||
int SSLContext::_ssl_client_ctx_refcnt = 0; | ||
SSL_CTX* SSLContext::_ssl_svr_ctx = nullptr; | ||
int SSLContext::_ssl_svr_ctx_refcnt = 0; | ||
|
||
WiFiClientSecure::WiFiClientSecure() | ||
{ | ||
|
@@ -426,41 +454,25 @@ WiFiClientSecure::WiFiClientSecure() | |
|
||
WiFiClientSecure::~WiFiClientSecure() | ||
{ | ||
if (_ssl) { | ||
_ssl->unref(); | ||
} | ||
} | ||
|
||
WiFiClientSecure::WiFiClientSecure(const WiFiClientSecure& other) | ||
: WiFiClient(static_cast<const WiFiClient&>(other)) | ||
{ | ||
_ssl = other._ssl; | ||
if (_ssl) { | ||
_ssl->ref(); | ||
} | ||
} | ||
|
||
WiFiClientSecure& WiFiClientSecure::operator=(const WiFiClientSecure& rhs) | ||
{ | ||
(WiFiClient&) *this = rhs; | ||
_ssl = rhs._ssl; | ||
if (_ssl) { | ||
_ssl->ref(); | ||
} | ||
return *this; | ||
_ssl = nullptr; | ||
} | ||
|
||
// Only called by the WifiServerSecure, need to get the keys/certs loaded before beginning | ||
WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const uint8_t *rsakey, int rsakeyLen, const uint8_t *cert, int certLen) | ||
WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, | ||
const uint8_t *rsakey, int rsakeyLen, | ||
const uint8_t *cert, int certLen) | ||
{ | ||
// TLS handshake may take more than the 5 second default timeout | ||
_timeout = 15000; | ||
|
||
// We've been given the client context from the available() call | ||
_client = client; | ||
if (_ssl) { | ||
_ssl->unref(); | ||
_ssl = nullptr; | ||
} | ||
_client->ref(); | ||
|
||
_ssl = new SSLContext; | ||
_ssl->ref(); | ||
// Make the "_ssl" SSLContext, in the constructor there should be none yet | ||
SSLContext *_new_ssl = new SSLContext(true); | ||
std::shared_ptr<SSLContext> _new_ssl_shared(_new_ssl); | ||
_ssl = _new_ssl_shared; | ||
|
||
if (usePMEM) { | ||
if (rsakey && rsakeyLen) { | ||
|
@@ -477,8 +489,7 @@ WiFiClientSecure::WiFiClientSecure(ClientContext* client, bool usePMEM, const ui | |
_ssl->loadObject(SSL_OBJ_X509_CERT, cert, certLen); | ||
} | ||
} | ||
_client->ref(); | ||
_ssl->connectServer(client); | ||
_ssl->connectServer(client, _timeout); | ||
} | ||
|
||
int WiFiClientSecure::connect(IPAddress ip, uint16_t port) | ||
|
@@ -510,14 +521,12 @@ int WiFiClientSecure::connect(const String host, uint16_t port) | |
int WiFiClientSecure::_connectSSL(const char* hostName) | ||
{ | ||
if (!_ssl) { | ||
_ssl = new SSLContext; | ||
_ssl->ref(); | ||
_ssl = std::make_shared<SSLContext>(); | ||
} | ||
_ssl->connect(_client, hostName, _timeout); | ||
|
||
auto status = ssl_handshake_status(*_ssl); | ||
if (status != SSL_OK) { | ||
_ssl->unref(); | ||
_ssl = nullptr; | ||
return 0; | ||
} | ||
|
@@ -537,7 +546,6 @@ size_t WiFiClientSecure::write(const uint8_t *buf, size_t size) | |
} | ||
|
||
if (rc != SSL_CLOSE_NOTIFY) { | ||
_ssl->unref(); | ||
_ssl = nullptr; | ||
} | ||
|
||
|
@@ -640,8 +648,6 @@ void WiFiClientSecure::stop() | |
{ | ||
if (_ssl) { | ||
_ssl->stop(); | ||
_ssl->unref(); | ||
_ssl = nullptr; | ||
} | ||
WiFiClient::stop(); | ||
} | ||
|
@@ -723,9 +729,9 @@ bool WiFiClientSecure::_verifyDN(const char* domain_name) | |
String domain_name_str(domain_name); | ||
domain_name_str.toLowerCase(); | ||
|
||
const char* san = NULL; | ||
const char* san = nullptr; | ||
int i = 0; | ||
while ((san = ssl_get_cert_subject_alt_dnsname(*_ssl, i)) != NULL) { | ||
while ((san = ssl_get_cert_subject_alt_dnsname(*_ssl, i)) != nullptr) { | ||
String san_str(san); | ||
san_str.toLowerCase(); | ||
if (matchName(san_str, domain_name_str)) { | ||
|
@@ -759,8 +765,7 @@ bool WiFiClientSecure::verifyCertChain(const char* domain_name) | |
void WiFiClientSecure::_initSSLContext() | ||
{ | ||
if (!_ssl) { | ||
_ssl = new SSLContext; | ||
_ssl->ref(); | ||
_ssl = std::make_shared<SSLContext>(); | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These if(server) else client statements hint at separating SSLContext into SSLServerContext and SSLClientContext. However, I suspect that doing that would require even more changes, and right now we need to improve stability. So let's handle that later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be straightforward to make the SSLServerContext from the unmodified SSLContext, because there's only 3 or spots where you'll see this check or a different codepath dependent on this. That said, I'm not sure it helps understanding by breaking it into another subclass...frankly most of my time was spent going through the inherited classes to see WTH was going on instead of just being able to peek at one file/class and see the inner workings. We can revisit if axtls sticks around. BearSSL doesn't need this kind of abstraction at all, so the code there is simpler.
As for stability, this code spent 5 hours being beaten by a while(true) mosquiit_pub sending it data, it mqtt publishing to a SSL mosquitto server using a Client cert every 5 seconds, and a while(true) wget https://esp8266 loop to make it serve web pages. Not even a hiccup noted...
Wish I could package this setup as a test, but it needs a mosquitto server, fixed IPs (for the certs), and a couple monitors to make sure the mqtt mesages and the wget continue working. I wouldn't know where to start...