diff options
-rw-r--r-- | src/bun.js/api/bun/socket.zig | 142 | ||||
-rw-r--r-- | test/bun.js/socket/echo.js | 70 | ||||
-rw-r--r-- | test/bun.js/socket/socket.test.ts | 34 |
3 files changed, 179 insertions, 67 deletions
diff --git a/src/bun.js/api/bun/socket.zig b/src/bun.js/api/bun/socket.zig index 1b0df92f5..319d1de37 100644 --- a/src/bun.js/api/bun/socket.zig +++ b/src/bun.js/api/bun/socket.zig @@ -543,8 +543,10 @@ pub const Listener = struct { var listener = this.listener orelse return JSValue.jsUndefined(); this.listener = null; listener.close(this.ssl); - if (this.handlers.active_connections == 0) { + if (this.poll_ref.isActive()) { this.poll_ref.unref(this.handlers.vm); + } + if (this.handlers.active_connections == 0) { this.handlers.unprotect(); this.socket_context.?.deinit(this.ssl); this.socket_context = null; @@ -785,6 +787,13 @@ fn NewSocket(comptime ssl: bool) type { const This = @This(); const log = Output.scoped(.Socket, false); + const WriteResult = union(enum) { + fail: void, + success: struct { + wrote: i32 = 0, + total: usize = 0, + }, + }; pub usingnamespace JSSocketType(ssl); @@ -852,12 +861,12 @@ fn NewSocket(comptime ssl: bool) type { JSC.markBinding(@src()); if (this.detached) return; this.detached = true; + defer this.markInactive(); var handlers = this.handlers; this.poll_ref.unref(handlers.vm); var globalObject = handlers.globalObject; const callback = handlers.onTimeout; - this.markInactive(); if (callback == .zero) { return; } @@ -873,10 +882,11 @@ fn NewSocket(comptime ssl: bool) type { } } } - pub fn onConnectError(this: *This, socket: Socket, errno: c_int) void { + pub fn onConnectError(this: *This, _: Socket, errno: c_int) void { JSC.markBinding(@src()); log("onConnectError({d}", .{errno}); this.detached = true; + defer this.markInactive(); var handlers = this.handlers; this.poll_ref.unref(handlers.vm); var err = JSC.SystemError{ @@ -885,10 +895,6 @@ fn NewSocket(comptime ssl: bool) type { .syscall = ZigString.init("connect"), }; _ = handlers.rejectPromise(err.toErrorInstance(handlers.globalObject)); - if (this.reffer.has) { - this.reffer.unref(handlers.vm); - handlers.markInactive(ssl, socket.context()); - } } pub fn markActive(this: *This) void { @@ -943,10 +949,10 @@ fn NewSocket(comptime ssl: bool) type { }); if (!result.isEmptyOrUndefinedOrNull() and result.isAnyError(handlers.globalObject)) { + this.detached = true; + defer this.markInactive(); if (!this.socket.isClosed()) { log("Closing due to error", .{}); - this.detached = true; - this.socket.close(0, null); } else { log("Already closed", .{}); } @@ -978,6 +984,7 @@ fn NewSocket(comptime ssl: bool) type { JSC.markBinding(@src()); log("onEnd", .{}); this.detached = true; + defer this.markInactive(); var handlers = this.handlers; const callback = handlers.onEnd; @@ -1001,6 +1008,7 @@ fn NewSocket(comptime ssl: bool) type { JSC.markBinding(@src()); log("onClose", .{}); this.detached = true; + defer this.markInactive(); var handlers = this.handlers; this.poll_ref.unref(handlers.vm); @@ -1008,7 +1016,6 @@ fn NewSocket(comptime ssl: bool) type { var globalObject = handlers.globalObject; if (callback == .zero) { - this.markInactive(); return; } @@ -1144,7 +1151,10 @@ fn NewSocket(comptime ssl: bool) type { return .zero; } - return this.writeOrEnd(globalObject, args.ptr[0..args.len], false); + return switch (this.writeOrEnd(globalObject, args.ptr[0..args.len], false)) { + .fail => .zero, + .success => |result| JSValue.jsNumber(result.wrote), + }; } pub fn getLocalPort( @@ -1179,19 +1189,15 @@ fn NewSocket(comptime ssl: bool) type { } fn writeMaybeCorked(this: *This, buffer: []const u8, is_end: bool) i32 { + if (this.socket.isShutdown() or this.socket.isClosed()) { + return -1; + } // we don't cork yet but we might later return this.socket.write(buffer, is_end); } - fn writeOrEnd(this: *This, globalObject: *JSC.JSGlobalObject, args: []const JSC.JSValue, is_end: bool) JSValue { - if (args.ptr[0].isEmptyOrUndefinedOrNull()) { - globalObject.throw("Expected an ArrayBufferView, a string, or a Blob", .{}); - return .zero; - } - - if (this.socket.isShutdown() or this.socket.isClosed()) { - return JSValue.jsNumber(@as(i32, -1)); - } + fn writeOrEnd(this: *This, globalObject: *JSC.JSGlobalObject, args: []const JSC.JSValue, is_end: bool) WriteResult { + if (args.len == 0) return .{ .success = .{} }; if (args.ptr[0].asArrayBuffer(globalObject)) |array_buffer| { var slice = array_buffer.slice(); @@ -1199,7 +1205,7 @@ fn NewSocket(comptime ssl: bool) type { if (args.len > 1) { if (!args.ptr[1].isAnyInt()) { globalObject.throw("Expected offset integer, got {any}", .{args.ptr[1].getZigString(globalObject)}); - return .zero; + return .{ .fail = {} }; } const offset = @min(args.ptr[1].toUInt64NoTruncate(), slice.len); @@ -1208,7 +1214,7 @@ fn NewSocket(comptime ssl: bool) type { if (args.len > 2) { if (!args.ptr[2].isAnyInt()) { globalObject.throw("Expected length integer, got {any}", .{args.ptr[2].getZigString(globalObject)}); - return .zero; + return .{ .fail = {} }; } const length = @min(args.ptr[2].toUInt64NoTruncate(), slice.len); @@ -1216,11 +1222,14 @@ fn NewSocket(comptime ssl: bool) type { } } - if (slice.len == 0) { - return JSValue.jsNumber(@as(i32, 0)); - } + if (slice.len == 0) return .{ .success = .{} }; - return JSValue.jsNumber(this.writeMaybeCorked(slice, is_end)); + return .{ + .success = .{ + .wrote = this.writeMaybeCorked(slice, is_end), + .total = slice.len, + }, + }; } else if (args.ptr[0].jsType() == .DOMWrapper) { const blob: JSC.WebCore.AnyBlob = getter: { if (args.ptr[0].as(JSC.WebCore.Blob)) |blob| { @@ -1233,7 +1242,7 @@ fn NewSocket(comptime ssl: bool) type { } globalObject.throw("Only Blob/buffered bodies are supported for now", .{}); - return .zero; + return .{ .fail = {} }; } else if (args.ptr[0].as(JSC.WebCore.Request)) |request| { request.body.toBlobIfPossible(); if (request.body.tryUseAsAnyBlob()) |blob| { @@ -1241,11 +1250,11 @@ fn NewSocket(comptime ssl: bool) type { } globalObject.throw("Only Blob/buffered bodies are supported for now", .{}); - return .zero; + return .{ .fail = {} }; } globalObject.throw("Expected Blob, Request or Response", .{}); - return .zero; + return .{ .fail = {} }; }; if (!blob.needsToReadFile()) { @@ -1254,7 +1263,7 @@ fn NewSocket(comptime ssl: bool) type { if (args.len > 1) { if (!args.ptr[1].isAnyInt()) { globalObject.throw("Expected offset integer, got {any}", .{args.ptr[1].getZigString(globalObject)}); - return .zero; + return .{ .fail = {} }; } const offset = @min(args.ptr[1].toUInt64NoTruncate(), slice.len); @@ -1263,7 +1272,7 @@ fn NewSocket(comptime ssl: bool) type { if (args.len > 2) { if (!args.ptr[2].isAnyInt()) { globalObject.throw("Expected length integer, got {any}", .{args.ptr[2].getZigString(globalObject)}); - return .zero; + return .{ .fail = {} }; } const length = @min(args.ptr[2].toUInt64NoTruncate(), slice.len); @@ -1271,15 +1280,18 @@ fn NewSocket(comptime ssl: bool) type { } } - if (slice.len == 0) { - return JSValue.jsNumber(@as(i32, 0)); - } + if (slice.len == 0) return .{ .success = .{} }; - return JSValue.jsNumber(this.writeMaybeCorked(slice, is_end)); + return .{ + .success = .{ + .wrote = this.writeMaybeCorked(slice, is_end), + .total = slice.len, + }, + }; } globalObject.throw("sendfile() not implemented yet", .{}); - return .zero; + return .{ .fail = {} }; } else if (args.ptr[0].toStringOrNull(globalObject)) |jsstring| { var zig_str = jsstring.toSlice(globalObject, globalObject.bunVM().allocator); defer zig_str.deinit(); @@ -1289,7 +1301,7 @@ fn NewSocket(comptime ssl: bool) type { if (args.len > 1) { if (!args.ptr[1].isAnyInt()) { globalObject.throw("Expected offset integer, got {any}", .{args.ptr[1].getZigString(globalObject)}); - return .zero; + return .{ .fail = {} }; } const offset = @min(args.ptr[1].toUInt64NoTruncate(), slice.len); @@ -1298,7 +1310,7 @@ fn NewSocket(comptime ssl: bool) type { if (args.len > 2) { if (!args.ptr[2].isAnyInt()) { globalObject.throw("Expected length integer, got {any}", .{args.ptr[2].getZigString(globalObject)}); - return .zero; + return .{ .fail = {} }; } const length = @min(args.ptr[2].toUInt64NoTruncate(), slice.len); @@ -1306,10 +1318,17 @@ fn NewSocket(comptime ssl: bool) type { } } - return JSValue.jsNumber(this.writeMaybeCorked(slice, is_end)); + if (slice.len == 0) return .{ .success = .{} }; + + return .{ + .success = .{ + .wrote = this.writeMaybeCorked(slice, is_end), + .total = slice.len, + }, + }; } else { globalObject.throw("Expected ArrayBufferView, a string, or a Blob", .{}); - return .zero; + return .{ .fail = {} }; } } @@ -1352,37 +1371,26 @@ fn NewSocket(comptime ssl: bool) type { const args = callframe.arguments(4); - if (args.len == 0) { - log("end()", .{}); - if (!this.detached) { - if (!this.socket.isClosed()) this.socket.flush(); - this.detached = true; - - if (!this.socket.isClosed()) - this.socket.close(0, null); - this.markInactive(); - } - - return JSValue.jsUndefined(); - } - log("end({d} args)", .{args.len}); if (this.detached) { return JSValue.jsNumber(@as(i32, -1)); } - const result = this.writeOrEnd(globalObject, args.ptr[0..args.len], true); - if (result != .zero and result.toInt32() > 0) { - this.socket.flush(); - this.detached = true; - - if (!this.socket.isClosed()) - this.socket.close(0, null); - this.markInactive(); - } - - return result; + return switch (this.writeOrEnd(globalObject, args.ptr[0..args.len], true)) { + .fail => .zero, + .success => |result| brk: { + if (result.wrote == result.total) { + this.socket.flush(); + this.detached = true; + if (!this.socket.isClosed()) { + this.socket.close(0, null); + } + this.markInactive(); + } + break :brk JSValue.jsNumber(result.wrote); + }, + }; } pub fn ref(this: *This, globalObject: *JSC.JSGlobalObject, _: *JSC.CallFrame) callconv(.C) JSValue { @@ -1400,8 +1408,8 @@ fn NewSocket(comptime ssl: bool) type { pub fn finalize(this: *This) callconv(.C) void { log("finalize()", .{}); - if (!this.detached and !this.socket.isClosed()) { - this.detached = true; + this.detached = true; + if (!this.socket.isClosed()) { this.socket.close(0, null); } this.markInactive(); diff --git a/test/bun.js/socket/echo.js b/test/bun.js/socket/echo.js new file mode 100644 index 000000000..e864f1b2b --- /dev/null +++ b/test/bun.js/socket/echo.js @@ -0,0 +1,70 @@ +function createOptions(type, message, closeOnDone) { + let buffers = []; + let report = function() { + report = function() {}; + const data = new Uint8Array(buffers.reduce(function(sum, buffer) { + return sum + buffer.length; + }, 0)); + buffers.reduce(function(offset, buffer) { + data.set(buffer, offset); + return offset + buffer.length; + }, 0); + console.log(type, "GOT", new TextDecoder().decode(data)); + } + + let done = closeOnDone ? function(socket, sent) { + socket.data[sent ? "sent" : "received"] = true; + if (socket.data.sent && socket.data.received) { + done = function() {}; + closeOnDone(socket); + } + } : function() {}; + + function drain(socket) { + const message = socket.data.message; + const written = socket.write(message); + if (written < message.length) { + socket.data.message = message.slice(written); + } else { + done(socket, true); + } + } + + return { + hostname: "localhost", + port: 12345, + socket: { + close() { + report(); + console.log(type, "CLOSED"); + }, + data(socket, buffer) { + buffers.push(buffer); + done(socket); + }, + drain: drain, + end() { + report(); + console.log(type, "ENDED"); + }, + error(socket, err) { + console.log(type, "ERRED", err); + }, + open(socket) { + console.log(type, "OPENED"); + drain(socket); + }, + }, + data: { + sent: false, + received: false, + message: message, + }, + }; +} + +const server = Bun.listen(createOptions("[Server]", "response", socket => { + server.stop(); + socket.end(); +})); +Bun.connect(createOptions("[Client]", "request")); diff --git a/test/bun.js/socket/socket.test.ts b/test/bun.js/socket/socket.test.ts new file mode 100644 index 000000000..aff001c75 --- /dev/null +++ b/test/bun.js/socket/socket.test.ts @@ -0,0 +1,34 @@ +import { expect, it } from "bun:test"; +import { bunExe } from "../bunExe"; +import { spawn } from "bun"; + +it("should keep process alive only when active", async () => { + const { exited, stdout, stderr } = spawn({ + cmd: [ bunExe(), "echo.js" ], + cwd: import.meta.dir, + stdout: "pipe", + stdin: null, + stderr: "pipe", + env: { + BUN_DEBUG_QUIET_LOGS: 1, + }, + }); + expect(await exited).toBe(0); + expect(await new Response(stderr).text()).toBe(""); + var lines = (await new Response(stdout).text()).split(/\r?\n/); + expect(lines.filter(function(line) { + return line.startsWith("[Server]"); + })).toEqual([ + "[Server] OPENED", + "[Server] GOT request", + "[Server] CLOSED", + ]); + expect(lines.filter(function(line) { + return line.startsWith("[Client]"); + })).toEqual([ + "[Client] OPENED", + "[Client] GOT response", + "[Client] ENDED", + "[Client] CLOSED", + ]); +}); |