aboutsummaryrefslogtreecommitdiff
path: root/packages/bun-usockets/src/crypto/openssl.c
diff options
context:
space:
mode:
Diffstat (limited to 'packages/bun-usockets/src/crypto/openssl.c')
-rw-r--r--packages/bun-usockets/src/crypto/openssl.c49
1 files changed, 19 insertions, 30 deletions
diff --git a/packages/bun-usockets/src/crypto/openssl.c b/packages/bun-usockets/src/crypto/openssl.c
index b6466bcf9..0b55ca866 100644
--- a/packages/bun-usockets/src/crypto/openssl.c
+++ b/packages/bun-usockets/src/crypto/openssl.c
@@ -86,7 +86,6 @@ struct us_internal_ssl_socket_context_t {
/* Pointer to sni tree, created when the context is created and freed likewise when freed */
void *sni;
- int pending_handshake;
void (*on_handshake)(struct us_internal_ssl_socket_t *, int success, struct us_bun_verify_error_t verify_error, void* custom_data);
void* handshake_data;
};
@@ -97,6 +96,7 @@ struct us_internal_ssl_socket_t {
SSL *ssl;
int ssl_write_wants_read; // we use this for now
int ssl_read_wants_write;
+ int pending_handshake;
};
int passphrase_cb(char *buf, int size, int rwflag, void *u) {
@@ -164,6 +164,7 @@ int BIO_s_custom_read(BIO *bio, char *dst, int length) {
struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s, int is_client, char *ip, int ip_length) {
struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
+
struct us_loop_t *loop = us_socket_context_loop(0, &context->sc);
struct loop_ssl_data *loop_ssl_data = (struct loop_ssl_data *) loop->data.ssl_data;
@@ -186,7 +187,9 @@ struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s,
struct us_internal_ssl_socket_t * result = (struct us_internal_ssl_socket_t *) context->on_open(s, is_client, ip, ip_length);
// Hello Message!
- if(context->pending_handshake) {
+ // always handshake after open if on_handshake is set
+ if(context->on_handshake || s->pending_handshake) {
+ s->pending_handshake = 1;
us_internal_ssl_handshake(s, context->on_handshake, context->handshake_data);
}
@@ -195,7 +198,6 @@ struct us_internal_ssl_socket_t *ssl_on_open(struct us_internal_ssl_socket_t *s,
void us_internal_on_ssl_handshake(struct us_internal_ssl_socket_context_t * context, void (*on_handshake)(struct us_internal_ssl_socket_t *, int success, struct us_bun_verify_error_t verify_error, void* custom_data), void* custom_data) {
- context->pending_handshake = 1;
context->on_handshake = on_handshake;
context->handshake_data = custom_data;
}
@@ -206,7 +208,7 @@ void us_internal_ssl_handshake(struct us_internal_ssl_socket_t *s, void (*on_han
// will start on_open, on_writable or on_data
if(!s->ssl) {
- context->pending_handshake = 1;
+ s->pending_handshake = 1;
context->on_handshake = on_handshake;
context->handshake_data = custom_data;
return;
@@ -218,9 +220,7 @@ void us_internal_ssl_handshake(struct us_internal_ssl_socket_t *s, void (*on_han
loop_ssl_data->ssl_socket = &s->s;
if (us_socket_is_closed(0, &s->s) || us_internal_ssl_socket_is_shut_down(s)) {
- context->pending_handshake = 0;
- context->on_handshake = NULL;
- context->handshake_data = NULL;
+ s->pending_handshake = 0;
struct us_bun_verify_error_t verify_error = (struct us_bun_verify_error_t) { .error = 0, .code = NULL, .reason = NULL };
if(on_handshake != NULL) {
@@ -236,9 +236,7 @@ void us_internal_ssl_handshake(struct us_internal_ssl_socket_t *s, void (*on_han
int err = SSL_get_error(s->ssl, result);
// as far as I know these are the only errors we want to handle
if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE) {
- context->pending_handshake = 0;
- context->on_handshake = NULL;
- context->handshake_data = NULL;
+ s->pending_handshake = 0;
struct us_bun_verify_error_t verify_error = us_internal_verify_error(s);
// clear per thread error queue if it may contain something
@@ -252,7 +250,7 @@ void us_internal_ssl_handshake(struct us_internal_ssl_socket_t *s, void (*on_han
}
return;
} else {
- context->pending_handshake = 1;
+ s->pending_handshake = 1;
context->on_handshake = on_handshake;
context->handshake_data = custom_data;
// Ensure that we'll cycle through internal openssl's state
@@ -262,10 +260,7 @@ void us_internal_ssl_handshake(struct us_internal_ssl_socket_t *s, void (*on_han
}
} else {
- context->pending_handshake = 0;
- context->on_handshake = NULL;
- context->handshake_data = NULL;
-
+ s->pending_handshake = 0;
struct us_bun_verify_error_t verify_error = us_internal_verify_error(s);
// success
@@ -283,16 +278,16 @@ void us_internal_ssl_handshake(struct us_internal_ssl_socket_t *s, void (*on_han
struct us_internal_ssl_socket_t *us_internal_ssl_socket_close(struct us_internal_ssl_socket_t *s, int code, void *reason) {
struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
- if (context->pending_handshake) {
- context->pending_handshake = 0;
+ if (s->pending_handshake) {
+ s->pending_handshake = 0;
}
return (struct us_internal_ssl_socket_t *) us_socket_close(0, (struct us_socket_t *) s, code, reason);
}
struct us_internal_ssl_socket_t *ssl_on_close(struct us_internal_ssl_socket_t *s, int code, void *reason) {
struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
- if (context->pending_handshake) {
- context->pending_handshake = 0;
+ if (s->pending_handshake) {
+ s->pending_handshake = 0;
}
SSL_free(s->ssl);
@@ -300,11 +295,8 @@ struct us_internal_ssl_socket_t *ssl_on_close(struct us_internal_ssl_socket_t *s
}
struct us_internal_ssl_socket_t *ssl_on_end(struct us_internal_ssl_socket_t *s) {
- if(&s->s) {
- struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
- if (context && context->pending_handshake) {
- context->pending_handshake = 0;
- }
+ if(&s->s && s->pending_handshake) {
+ s->pending_handshake = 0;
}
// whatever state we are in, a TCP FIN is always an answered shutdown
@@ -322,7 +314,7 @@ struct us_internal_ssl_socket_t *ssl_on_data(struct us_internal_ssl_socket_t *s,
struct us_loop_t *loop = us_socket_context_loop(0, &context->sc);
struct loop_ssl_data *loop_ssl_data = (struct loop_ssl_data *) loop->data.ssl_data;
- if(context->pending_handshake) {
+ if(s->pending_handshake) {
us_internal_ssl_handshake(s, context->on_handshake, context->handshake_data);
}
@@ -477,7 +469,7 @@ struct us_internal_ssl_socket_t *ssl_on_writable(struct us_internal_ssl_socket_t
struct us_internal_ssl_socket_context_t *context = (struct us_internal_ssl_socket_context_t *) us_socket_context(0, &s->s);
- if(context->pending_handshake) {
+ if(s->pending_handshake) {
us_internal_ssl_handshake(s, context->on_handshake, context->handshake_data);
}
@@ -1302,10 +1294,8 @@ struct us_internal_ssl_socket_context_t *us_internal_create_ssl_socket_context(s
context->ssl_context = ssl_context;//create_ssl_context_from_options(options);
context->is_parent = 1;
- context->pending_handshake = 0;
context->on_handshake = NULL;
context->handshake_data = NULL;
-
/* We, as parent context, may ignore data */
context->sc.is_low_prio = (int (*)(struct us_socket_t *)) ssl_is_low_prio;
@@ -1340,10 +1330,9 @@ struct us_internal_ssl_socket_context_t *us_internal_bun_create_ssl_socket_conte
/* Then we extend its SSL parts */
context->ssl_context = ssl_context;//create_ssl_context_from_options(options);
context->is_parent = 1;
- context->pending_handshake = 0;
+
context->on_handshake = NULL;
context->handshake_data = NULL;
-
/* We, as parent context, may ignore data */
context->sc.is_low_prio = (int (*)(struct us_socket_t *)) ssl_is_low_prio;