diff options
Diffstat (limited to 'packages/bun-usockets/src/crypto/openssl.c')
-rw-r--r-- | packages/bun-usockets/src/crypto/openssl.c | 49 |
1 files changed, 19 insertions, 30 deletions
diff --git a/packages/bun-usockets/src/crypto/openssl.c b/packages/bun-usockets/src/crypto/openssl.c index b6466bcf9..0b55ca866 100644 --- a/packages/bun-usockets/src/crypto/openssl.c +++ b/packages/bun-usockets/src/crypto/openssl.c @@ -86,7 +86,6 @@ struct us_internal_ssl_socket_context_t { /* Pointer to sni tree, created when the context is created and freed likewise when freed */ void *sni; - int pending_handshake; void (*on_handshake)(struct us_internal_ssl_socket_t *, int success, struct us_bun_verify_error_t verify_error, void* custom_data); void* handshake_data; }; @@ -97,6 +96,7 @@ struct us_internal_ssl_socket_t { SSL *ssl; int ssl_write_wants_read; // we use this for now int ssl_read_wants_write; + int pending_handshake; }; int passphrase_cb(char *buf, int size, int rwflag, void *u) { @@ -164,6 +164,7 @@ int BIO_s_custom_read(BIO *bio, char *dst, int length) { struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s, int is_client, char *ip, int ip_length) { struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s); + struct us_loop_t *loop = us_socket_context_loop(0, &context->sc); struct loop_ssl_data *loop_ssl_data = (struct loop_ssl_data *) loop->data.ssl_data; @@ -186,7 +187,9 @@ struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s, struct us_internal_ssl_socket_t * result = (struct us_internal_ssl_socket_t *) context->on_open(s, is_client, ip, ip_length); // Hello Message! - if(context->pending_handshake) { + // always handshake after open if on_handshake is set + if(context->on_handshake || s->pending_handshake) { + s->pending_handshake = 1; us_internal_ssl_handshake(s, context->on_handshake, context->handshake_data); } @@ -195,7 +198,6 @@ struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s, void us_internal_on_ssl_handshake(struct us_internal_ssl_socket_context_t * context, void (*on_handshake)(struct us_internal_ssl_socket_t *, int success, struct us_bun_verify_error_t verify_error, void* custom_data), void* custom_data) { - context->pending_handshake = 1; context->on_handshake = on_handshake; context->handshake_data = custom_data; } @@ -206,7 +208,7 @@ void us_internal_ssl_handshake(struct us_internal_ssl_socket_t *s, void (*on_han // will start on_open, on_writable or on_data if(!s->ssl) { - context->pending_handshake = 1; + s->pending_handshake = 1; context->on_handshake = on_handshake; context->handshake_data = custom_data; return; @@ -218,9 +220,7 @@ void us_internal_ssl_handshake(struct us_internal_ssl_socket_t *s, void (*on_han loop_ssl_data->ssl_socket = &s->s; if (us_socket_is_closed(0, &s->s) || us_internal_ssl_socket_is_shut_down(s)) { - context->pending_handshake = 0; - context->on_handshake = NULL; - context->handshake_data = NULL; + s->pending_handshake = 0; struct us_bun_verify_error_t verify_error = (struct us_bun_verify_error_t) { .error = 0, .code = NULL, .reason = NULL }; if(on_handshake != NULL) { @@ -236,9 +236,7 @@ void us_internal_ssl_handshake(struct us_internal_ssl_socket_t *s, void (*on_han int err = SSL_get_error(s->ssl, result); // as far as I know these are the only errors we want to handle if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE) { - context->pending_handshake = 0; - context->on_handshake = NULL; - context->handshake_data = NULL; + s->pending_handshake = 0; struct us_bun_verify_error_t verify_error = us_internal_verify_error(s); // clear per thread error queue if it may contain something @@ -252,7 +250,7 @@ void us_internal_ssl_handshake(struct us_internal_ssl_socket_t *s, void (*on_han } return; } else { - context->pending_handshake = 1; + s->pending_handshake = 1; context->on_handshake = on_handshake; context->handshake_data = custom_data; // Ensure that we'll cycle through internal openssl's state @@ -262,10 +260,7 @@ void us_internal_ssl_handshake(struct us_internal_ssl_socket_t *s, void (*on_han } } else { - context->pending_handshake = 0; - context->on_handshake = NULL; - context->handshake_data = NULL; - + s->pending_handshake = 0; struct us_bun_verify_error_t verify_error = us_internal_verify_error(s); // success @@ -283,16 +278,16 @@ void us_internal_ssl_handshake(struct us_internal_ssl_socket_t *s, void (*on_han struct us_internal_ssl_socket_t *us_internal_ssl_socket_close(struct us_internal_ssl_socket_t *s, int code, void *reason) { struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s); - if (context->pending_handshake) { - context->pending_handshake = 0; + if (s->pending_handshake) { + s->pending_handshake = 0; } return (struct us_internal_ssl_socket_t *) us_socket_close(0, (struct us_socket_t *) s, code, reason); } struct us_internal_ssl_socket_t *ssl_on_close(struct us_internal_ssl_socket_t *s, int code, void *reason) { struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s); - if (context->pending_handshake) { - context->pending_handshake = 0; + if (s->pending_handshake) { + s->pending_handshake = 0; } SSL_free(s->ssl); @@ -300,11 +295,8 @@ struct us_internal_ssl_socket_t *ssl_on_close(struct us_internal_ssl_socket_t *s } struct us_internal_ssl_socket_t *ssl_on_end(struct us_internal_ssl_socket_t *s) { - if(&s->s) { - struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s); - if (context && context->pending_handshake) { - context->pending_handshake = 0; - } + if(&s->s && s->pending_handshake) { + s->pending_handshake = 0; } // whatever state we are in, a TCP FIN is always an answered shutdown @@ -322,7 +314,7 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s, struct us_loop_t *loop = us_socket_context_loop(0, &context->sc); struct loop_ssl_data *loop_ssl_data = (struct loop_ssl_data *) loop->data.ssl_data; - if(context->pending_handshake) { + if(s->pending_handshake) { us_internal_ssl_handshake(s, context->on_handshake, context->handshake_data); } @@ -477,7 +469,7 @@ struct us_internal_ssl_socket_t *ssl_on_writable(struct us_internal_ssl_socket_t struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s); - if(context->pending_handshake) { + if(s->pending_handshake) { us_internal_ssl_handshake(s, context->on_handshake, context->handshake_data); } @@ -1302,10 +1294,8 @@ struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(s context->ssl_context = ssl_context;//create_ssl_context_from_options(options); context->is_parent = 1; - context->pending_handshake = 0; context->on_handshake = NULL; context->handshake_data = NULL; - /* We, as parent context, may ignore data */ context->sc.is_low_prio = (int (*)(struct us_socket_t *)) ssl_is_low_prio; @@ -1340,10 +1330,9 @@ struct us_internal_ssl_socket_context_t *us_internal_bun_create_ssl_socket_conte /* Then we extend its SSL parts */ context->ssl_context = ssl_context;//create_ssl_context_from_options(options); context->is_parent = 1; - context->pending_handshake = 0; + context->on_handshake = NULL; context->handshake_data = NULL; - /* We, as parent context, may ignore data */ context->sc.is_low_prio = (int (*)(struct us_socket_t *)) ssl_is_low_prio; |