28
28
#include <sys/un.h>
29
29
#endif
30
30
31
+ #include <openssl/ssl.h>
32
+
31
33
#include "char_buffer.h"
32
34
#include "socket_manager.h"
33
35
#include "hash_table.h"
@@ -50,6 +52,8 @@ struct sm_private {
50
52
fd_set * server_fds ; // can on_accept, i.e. "is_server"
51
53
fd_set * send_fds ; // blocked sends, same as fd_to_sendq.keys
52
54
fd_set * recv_fds ; // can recv, same as all_fds - sendq.recv_fd's
55
+ // fd to ssl_session
56
+ ht_t fd_to_ssl ;
53
57
// fd to on_* callback
54
58
ht_t fd_to_value ;
55
59
// fd to blocked sm_sendq_t, often empty
@@ -301,7 +305,8 @@ sm_status sm_on_debug(sm_t self, const char *format, ...) {
301
305
return SM_SUCCESS ;
302
306
}
303
307
304
- sm_status sm_add_fd (sm_t self , int fd , void * value , bool is_server ) {
308
+ sm_status sm_add_fd (sm_t self , int fd , void * ssl_session , void * value ,
309
+ bool is_server ) {
305
310
sm_private_t my = self -> private_state ;
306
311
if (FD_ISSET (fd , my -> all_fds )) {
307
312
return SM_ERROR ;
@@ -310,6 +315,9 @@ sm_status sm_add_fd(sm_t self, int fd, void *value, bool is_server) {
310
315
// The above FD_ISSET(..master..) should prevent this
311
316
return SM_ERROR ;
312
317
}
318
+ if (ssl_session != NULL && ht_put (my -> fd_to_ssl , HT_KEY (fd ), ssl_session )) {
319
+ return SM_ERROR ;
320
+ }
313
321
// is_server == getsockopt(..., SO_ACCEPTCONN, ...)?
314
322
sm_on_debug (self , "ss.add%s_fd(%d)" , (is_server ? "_server" : "" ), fd );
315
323
FD_SET (fd , my -> all_fds );
@@ -332,6 +340,7 @@ sm_status sm_remove_fd(sm_t self, int fd) {
332
340
if (!FD_ISSET (fd , my -> all_fds )) {
333
341
return SM_ERROR ;
334
342
}
343
+ ht_put (my -> fd_to_ssl , HT_KEY (fd ), NULL );
335
344
void * value = ht_put (my -> fd_to_value , HT_KEY (fd ), NULL );
336
345
bool is_server = FD_ISSET (fd , my -> server_fds );
337
346
sm_on_debug (self , "ss.remove%s_fd(%d)" , (is_server ? "_server" : "" ), fd );
@@ -380,20 +389,35 @@ sm_status sm_send(sm_t self, int fd, const char *data, size_t length,
380
389
const char * head = data ;
381
390
const char * tail = data + length ;
382
391
if (!sendq ) {
392
+ void * ssl_session = ht_get_value (my -> fd_to_ssl , HT_KEY (fd ));
383
393
// send as much as we can without blocking
384
394
while (1 ) {
385
- ssize_t sent_bytes = send (fd , (void * )head , (tail - head ), 0 );
386
- if (sent_bytes <= 0 ) {
395
+ ssize_t sent_bytes ;
396
+ if (ssl_session == NULL ) {
397
+ sent_bytes = send (fd , (void * )head , (tail - head ), 0 );
398
+ if (sent_bytes <= 0 ) {
387
399
#ifdef WIN32
388
- if (sent_bytes && WSAGetLastError () != WSAEWOULDBLOCK ) {
400
+ if (sent_bytes && WSAGetLastError () != WSAEWOULDBLOCK ) {
389
401
#else
390
- if (sent_bytes && errno != EWOULDBLOCK ) {
402
+ if (sent_bytes && errno != EWOULDBLOCK ) {
391
403
#endif
392
- sm_on_debug (self , "ss.failed fd=%d" , fd );
393
- perror ("send failed" );
394
- return SM_ERROR ;
404
+ sm_on_debug (self , "ss.failed fd=%d" , fd );
405
+ perror ("send failed" );
406
+ return SM_ERROR ;
407
+ }
408
+ break ;
409
+ }
410
+ } else {
411
+ sent_bytes = SSL_write ((SSL * )ssl_session , (void * )head , tail - head );
412
+ if (sent_bytes <= 0 ) {
413
+ if (SSL_get_error (ssl_session , sent_bytes ) != SSL_ERROR_WANT_READ &&
414
+ SSL_get_error (ssl_session , sent_bytes ) != SSL_ERROR_WANT_WRITE ) {
415
+ sm_on_debug (self , "ss.failed fd=%d" , fd );
416
+ perror ("ssl send failed" );
417
+ return SM_ERROR ;
418
+ }
419
+ break ;
395
420
}
396
- break ;
397
421
}
398
422
head += sent_bytes ;
399
423
if (head >= tail ) {
@@ -454,7 +478,7 @@ void sm_accept(sm_t self, int fd) {
454
478
#else
455
479
close (new_fd );
456
480
#endif
457
- } else if (self -> add_fd (self , new_fd , new_value , false)) {
481
+ } else if (self -> add_fd (self , new_fd , NULL , new_value , false)) {
458
482
self -> on_close (self , new_fd , new_value , false);
459
483
#ifdef WIN32
460
484
closesocket (new_fd );
@@ -468,27 +492,42 @@ void sm_accept(sm_t self, int fd) {
468
492
void sm_resend (sm_t self , int fd ) {
469
493
sm_private_t my = self -> private_state ;
470
494
sm_sendq_t sendq = ht_get_value (my -> fd_to_sendq , HT_KEY (fd ));
495
+ void * ssl_session = ht_get_value (my -> fd_to_ssl , HT_KEY (fd ));
471
496
while (sendq ) {
472
497
char * head = sendq -> head ;
473
498
char * tail = sendq -> tail ;
474
499
// send as much as we can without blocking
475
500
sm_on_debug (self , "ss.sendq<%p> resume send to fd=%d len=%zd" , sendq , fd ,
476
501
(tail - head ));
477
502
while (head < tail ) {
478
- ssize_t sent_bytes = send (fd , (void * )head , (tail - head ), 0 );
479
- if (sent_bytes <= 0 ) {
503
+ ssize_t sent_bytes ;
504
+ if (ssl_session == NULL ) {
505
+ sent_bytes = send (fd , (void * )head , (tail - head ), 0 );
506
+ if (sent_bytes <= 0 ) {
480
507
#ifdef WIN32
481
- if (sent_bytes && WSAGetLastError () != WSAEWOULDBLOCK ) {
482
- fprintf (stderr , "sendq retry failed with error: %d\n" ,
483
- WSAGetLastError ());
508
+ if (sent_bytes && WSAGetLastError () != WSAEWOULDBLOCK ) {
509
+ fprintf (stderr , "sendq retry failed with error: %d\n" ,
510
+ WSAGetLastError ());
484
511
#else
485
- if (sent_bytes && errno != EWOULDBLOCK ) {
486
- perror ("sendq retry failed" );
512
+ if (sent_bytes && errno != EWOULDBLOCK ) {
513
+ perror ("sendq retry failed" );
487
514
#endif
488
- self -> remove_fd (self , fd );
489
- return ;
515
+ self -> remove_fd (self , fd );
516
+ return ;
517
+ }
518
+ break ;
519
+ }
520
+ } else {
521
+ sent_bytes = SSL_write ((SSL * )ssl_session , (void * )head , tail - head );
522
+ if (sent_bytes <= 0 ) {
523
+ if (SSL_get_error (ssl_session , sent_bytes ) != SSL_ERROR_WANT_READ &&
524
+ SSL_get_error (ssl_session , sent_bytes ) != SSL_ERROR_WANT_WRITE ) {
525
+ perror ("ssl sendq retry failed" );
526
+ self -> remove_fd (self , fd );
527
+ return ;
528
+ }
529
+ break ;
490
530
}
491
- break ;
492
531
}
493
532
head += sent_bytes ;
494
533
}
@@ -535,19 +574,33 @@ void sm_resend(sm_t self, int fd) {
535
574
void sm_recv (sm_t self , int fd ) {
536
575
sm_private_t my = self -> private_state ;
537
576
my -> curr_recv_fd = fd ;
577
+ void * ssl_session = ht_get_value (my -> fd_to_ssl , HT_KEY (fd ));
538
578
while (1 ) {
539
- ssize_t read_bytes = recv (fd , my -> tmp_buf , my -> tmp_buf_length , RECV_FLAGS );
540
- if (read_bytes < 0 ) {
579
+ ssize_t read_bytes ;
580
+ if (ssl_session == NULL ) {
581
+ read_bytes = recv (fd , my -> tmp_buf , my -> tmp_buf_length , RECV_FLAGS );
582
+ if (read_bytes < 0 ) {
541
583
#ifdef WIN32
542
- if (WSAGetLastError () != WSAEWOULDBLOCK ) {
543
- fprintf (stderr , "recv failed with error %d\n" , WSAGetLastError ());
584
+ if (WSAGetLastError () != WSAEWOULDBLOCK ) {
585
+ fprintf (stderr , "recv failed with error %d\n" , WSAGetLastError ());
544
586
#else
545
- if (errno != EWOULDBLOCK ) {
546
- perror ("recv failed" );
587
+ if (errno != EWOULDBLOCK ) {
588
+ perror ("recv failed" );
547
589
#endif
548
- self -> remove_fd (self , fd );
590
+ self -> remove_fd (self , fd );
591
+ }
592
+ break ;
593
+ }
594
+ } else {
595
+ read_bytes = SSL_read ((SSL * )ssl_session , my -> tmp_buf , my -> tmp_buf_length );
596
+ if (read_bytes <= 0 ) {
597
+ if (SSL_get_error (ssl_session , read_bytes ) != SSL_ERROR_WANT_READ &&
598
+ SSL_get_error (ssl_session , read_bytes ) != SSL_ERROR_WANT_WRITE ) {
599
+ perror ("ssl recv failed" );
600
+ self -> remove_fd (self , fd );
601
+ }
602
+ break ;
549
603
}
550
- break ;
551
604
}
552
605
sm_on_debug (self , "ss.recv fd=%d len=%zd" , fd , read_bytes );
553
606
void * value = ht_get_value (my -> fd_to_value , HT_KEY (fd ));
@@ -648,6 +701,7 @@ void sm_private_free(sm_private_t my) {
648
701
free (my -> tmp_send_fds );
649
702
free (my -> tmp_recv_fds );
650
703
free (my -> tmp_fail_fds );
704
+ ht_free (my -> fd_to_ssl );
651
705
ht_free (my -> fd_to_value );
652
706
ht_free (my -> fd_to_sendq );
653
707
free (my -> tmp_buf );
@@ -669,13 +723,14 @@ sm_private_t sm_private_new(size_t buf_length) {
669
723
my -> tmp_send_fds = (fd_set * )malloc (SIZEOF_FD_SET );
670
724
my -> tmp_recv_fds = (fd_set * )malloc (SIZEOF_FD_SET );
671
725
my -> tmp_fail_fds = (fd_set * )malloc (SIZEOF_FD_SET );
726
+ my -> fd_to_ssl = ht_new (HT_INT_KEYS );
672
727
my -> fd_to_value = ht_new (HT_INT_KEYS );
673
728
my -> fd_to_sendq = ht_new (HT_INT_KEYS );
674
729
my -> tmp_buf = (char * )calloc (buf_length , sizeof (char * ));
675
730
if (!my -> tmp_buf || !my -> all_fds || !my -> server_fds ||
676
731
!my -> send_fds || !my -> recv_fds ||
677
732
!my -> tmp_send_fds || !my -> tmp_recv_fds || !my -> tmp_fail_fds ||
678
- !my -> fd_to_value || !my -> fd_to_sendq ) {
733
+ !my -> fd_to_ssl || ! my -> fd_to_value || !my -> fd_to_sendq ) {
679
734
sm_private_free (my );
680
735
return NULL ;
681
736
}
0 commit comments