aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jarred Sumner <jarred@jarredsumner.com> 2022-02-05 00:29:41 -0800
committerGravatar Jarred Sumner <jarred@jarredsumner.com> 2022-02-05 00:30:28 -0800
commit2b45c8dffec45dd838355b71e30f04a2dc117325 (patch)
tree11b33362fc89638728aee295ad15258f8394e896
parent860d7e93c0ca5b6b113dda487a3d5701f0f39572 (diff)
downloadbun-2b45c8dffec45dd838355b71e30f04a2dc117325.tar.gz
bun-2b45c8dffec45dd838355b71e30f04a2dc117325.tar.zst
bun-2b45c8dffec45dd838355b71e30f04a2dc117325.zip
Implement keep-alive but disable it
-rw-r--r--src/http_client_async.zig133
-rw-r--r--src/thread_pool.zig10
2 files changed, 126 insertions, 17 deletions
diff --git a/src/http_client_async.zig b/src/http_client_async.zig
index b1d50c36c..d8a25a424 100644
--- a/src/http_client_async.zig
+++ b/src/http_client_async.zig
@@ -67,7 +67,7 @@ else
pub const OPEN_SOCKET_FLAGS = SOCK.CLOEXEC;
-pub const extremely_verbose = false;
+pub const extremely_verbose = Environment.isDebug;
fn writeRequest(
comptime Writer: type,
@@ -113,6 +113,7 @@ socket: AsyncSocket.SSL = undefined,
socket_loaded: bool = false,
gzip_elapsed: u64 = 0,
stage: Stage = Stage.pending,
+received_keep_alive: bool = false,
/// Some HTTP servers (such as npm) report Last-Modified times but ignore If-Modified-Since.
/// This is a workaround for that.
@@ -135,7 +136,9 @@ pub fn init(
.url = url,
.header_entries = header_entries,
.header_buf = header_buf,
- .socket = undefined,
+ .socket = AsyncSocket.SSL{
+ .socket = undefined,
+ },
};
}
@@ -246,6 +249,62 @@ pub const HTTPChannelContext = struct {
}
};
+// This causes segfaults when resume connect()
+pub const KeepAlive = struct {
+ const limit = 2;
+ pub const disabled = true;
+ fds: [limit]u32 = undefined,
+ hosts: [limit]u64 = undefined,
+ ports: [limit]u16 = undefined,
+ used: u8 = 0,
+
+ pub var instance = KeepAlive{};
+
+ pub fn append(this: *KeepAlive, host: []const u8, port: u16, fd: os.socket_t) bool {
+ if (disabled) return false;
+ if (this.used >= limit or fd > std.math.maxInt(u32)) return false;
+
+ const i = this.used;
+ const hash = std.hash.Wyhash.hash(0, host);
+
+ this.fds[i] = @truncate(u32, @intCast(u64, fd));
+ this.hosts[i] = hash;
+ this.ports[i] = port;
+ this.used += 1;
+ return true;
+ }
+ pub fn find(this: *KeepAlive, host: []const u8, port: u16) ?os.socket_t {
+ if (disabled) return null;
+
+ if (this.used == 0) {
+ return null;
+ }
+
+ const hash = std.hash.Wyhash.hash(0, host);
+ const list = this.hosts[0..this.used];
+ for (list) |host_hash, i| {
+ if (host_hash == hash and this.ports[i] == port) {
+ const fd = this.fds[i];
+ const last = this.used - 1;
+
+ if (i > last) {
+ const end_host = this.hosts[last];
+ const end_fd = this.fds[last];
+ const end_port = this.ports[last];
+ this.hosts[i] = end_host;
+ this.fds[i] = end_fd;
+ this.ports[i] = end_port;
+ }
+ this.used -= 1;
+
+ return @intCast(os.socket_t, fd);
+ }
+ }
+
+ return null;
+ }
+};
+
pub const AsyncHTTP = struct {
request: ?picohttp.Request = null,
response: ?picohttp.Response = null,
@@ -319,6 +378,13 @@ pub const AsyncHTTP = struct {
return this;
}
+ fn reset(this: *AsyncHTTP) !void {
+ const timeout = this.timeout;
+ this.client = try HTTPClient.init(this.allocator, this.method, this.client.url, this.client.header_entries, this.client.header_buf);
+ this.client.timeout = timeout;
+ this.timeout = timeout;
+ }
+
pub fn schedule(this: *AsyncHTTP, _: std.mem.Allocator, batch: *ThreadPool.Batch) void {
std.debug.assert(NetworkThread.global_loaded.load(.Monotonic) == 1);
this.state.store(.scheduled, .Monotonic);
@@ -381,6 +447,10 @@ pub const AsyncHTTP = struct {
};
pub fn do(sender: *HTTPSender, this: *AsyncHTTP) void {
+ defer {
+ NetworkThread.global.pool.schedule(.{ .head = &sender.finisher, .tail = &sender.finisher, .len = 1 });
+ }
+
outer: {
this.err = null;
this.state.store(.sending, .Monotonic);
@@ -394,6 +464,7 @@ pub const AsyncHTTP = struct {
if (this.max_retry_count > this.retries_count) {
this.retries_count += 1;
this.response_buffer.reset();
+
NetworkThread.global.pool.schedule(ThreadPool.Batch.from(&this.task));
return;
}
@@ -408,7 +479,6 @@ pub const AsyncHTTP = struct {
if (this.callback) |callback| {
callback(this);
}
- NetworkThread.global.pool.schedule(.{ .head = &sender.finisher, .tail = &sender.finisher, .len = 1 });
}
};
@@ -534,20 +604,20 @@ pub fn sendAsync(this: *HTTPClient, body: []const u8, body_out_str: *MutableStri
return async this.send(body, body_out_str);
}
-pub fn send(this: *HTTPClient, body: []const u8, body_out_str: *MutableString) !picohttp.Response {
- defer {
- if (this.socket_loaded) {
- this.socket_loaded = false;
- this.socket.deinit();
- }
+fn maybeClearSocket(this: *HTTPClient) void {
+ if (this.socket_loaded) {
+ this.socket_loaded = false;
+
+ this.socket.deinit();
}
+}
+
+pub fn send(this: *HTTPClient, body: []const u8, body_out_str: *MutableString) !picohttp.Response {
+ defer this.maybeClearSocket();
// this prevents stack overflow
redirect: while (this.remaining_redirect_count >= -1) {
- if (this.socket_loaded) {
- this.socket_loaded = false;
- this.socket.deinit();
- }
+ this.maybeClearSocket();
_ = AsyncHTTP.active_requests_count.fetchAdd(1, .Monotonic);
defer {
@@ -596,7 +666,7 @@ pub fn sendHTTP(this: *HTTPClient, body: []const u8, body_out_str: *MutableStrin
var socket = &this.socket.socket;
try this.connect(*AsyncSocket, socket);
this.stage = Stage.request;
- defer this.socket.close();
+ defer this.closeSocket();
var request = buildRequest(this, body.len);
if (this.verbose) {
@@ -673,6 +743,10 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti
var location: string = "";
var pretend_its_304 = false;
+ var maybe_keepalive = false;
+ errdefer {
+ maybe_keepalive = false;
+ }
for (response.headers) |header| {
switch (hashHeaderName(header.name)) {
@@ -707,6 +781,13 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti
location_header_hash => {
location = header.value;
},
+ hashHeaderName("Connection") => {
+ if (response.status_code >= 200 and response.status_code <= 299 and !KeepAlive.disabled) {
+ if (strings.eqlComptime(header.value, "keep-alive")) {
+ maybe_keepalive = true;
+ }
+ }
+ },
hashHeaderName("Last-Modified") => {
if (this.force_last_modified and response.status_code > 199 and response.status_code < 300 and this.if_modified_since.len > 0) {
if (strings.eql(this.if_modified_since, header.value)) {
@@ -774,6 +855,7 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti
if (response.status_code == 304) break :body_getter;
if (transfer_encoding == Encoding.chunked) {
+ maybe_keepalive = false;
var decoder = std.mem.zeroes(picohttp.phr_chunked_decoder);
var buffer_: *MutableString = body_out_str;
@@ -1020,9 +1102,30 @@ pub fn processResponse(this: *HTTPClient, comptime report_progress: bool, compti
this.progress_node.?.context.maybeRefresh();
}
+ if (maybe_keepalive and response.status_code >= 200 and response.status_code < 300) {
+ this.received_keep_alive = true;
+ }
+
return response;
}
+pub fn closeSocket(this: *HTTPClient) void {
+ if (this.received_keep_alive) {
+ this.received_keep_alive = false;
+ if (this.url.hostname.len > 0 and this.socket.socket.socket > 0) {
+ if (!this.socket.connect_frame.wait and
+ (!this.socket.ssl_bio_loaded or
+ (this.socket.ssl_bio.pending_sends == 0 and this.socket.ssl_bio.pending_reads == 0)))
+ {
+ if (KeepAlive.instance.append(this.url.hostname, this.url.getPortAuto(), this.socket.socket.socket)) {
+ this.socket.socket.socket = 0;
+ }
+ }
+ }
+ }
+ this.socket.close();
+}
+
pub fn sendHTTPS(this: *HTTPClient, body_str: []const u8, body_out_str: *MutableString) !picohttp.Response {
this.socket = try AsyncSocket.SSL.init(default_allocator, &AsyncIO.global);
this.socket_loaded = true;
@@ -1031,7 +1134,7 @@ pub fn sendHTTPS(this: *HTTPClient, body_str: []const u8, body_out_str: *Mutable
this.stage = Stage.connect;
try this.connect(*AsyncSocket.SSL, socket);
this.stage = Stage.request;
- defer this.socket.close();
+ defer this.closeSocket();
var request = buildRequest(this, body_str.len);
if (this.verbose) {
diff --git a/src/thread_pool.zig b/src/thread_pool.zig
index c398151ca..dc3ff65db 100644
--- a/src/thread_pool.zig
+++ b/src/thread_pool.zig
@@ -269,8 +269,14 @@ fn _wait(self: *ThreadPool, _is_waking: bool, comptime sleep_on_idle: bool) erro
const end_count = HTTP.AsyncHTTP.active_requests_count.loadUnchecked();
if (end_count > 0) {
- while (HTTP.AsyncHTTP.active_requests_count.loadUnchecked() > HTTP.AsyncHTTP.max_simultaneous_requests) {
- io.run_for_ns(std.time.ns_per_ms) catch {};
+ if (comptime sleep_on_idle) {
+ idle_network_ticks = 0;
+ }
+
+ var remaining_ticks: i32 = 5;
+
+ while (remaining_ticks > 0 and HTTP.AsyncHTTP.active_requests_count.loadUnchecked() > HTTP.AsyncHTTP.max_simultaneous_requests) : (remaining_ticks -= 1) {
+ io.run_for_ns(std.time.ns_per_ms * 2) catch {};
io.tick() catch {};
}
}