diff options
author | 2021-06-11 17:05:15 -0700 | |
---|---|---|
committer | 2021-06-11 17:05:15 -0700 | |
commit | 223410eab3c18be989e520a7aeac470bf941dd6b (patch) | |
tree | 7ccdc2fb77c82b7aba3db84f1916637b043fcf4b | |
parent | dc3309d130c171f75d4e416f3569faa9c985d123 (diff) | |
download | bun-223410eab3c18be989e520a7aeac470bf941dd6b.tar.gz bun-223410eab3c18be989e520a7aeac470bf941dd6b.tar.zst bun-223410eab3c18be989e520a7aeac470bf941dd6b.zip |
lil websocket server
Former-commit-id: 19f4c59b4281bd24130a6898e3b43d9d6d7e6603
-rw-r--r-- | src/api/schema.peechy | 28 | ||||
-rw-r--r-- | src/http.zig | 220 | ||||
-rw-r--r-- | src/http/websocket.zig | 310 |
3 files changed, 533 insertions, 25 deletions
diff --git a/src/api/schema.peechy b/src/api/schema.peechy index 87f44dbaf..863ed6aa8 100644 --- a/src/api/schema.peechy +++ b/src/api/schema.peechy @@ -74,21 +74,27 @@ struct JavascriptBundle { // These are sorted alphabetically so you can do binary search JavascriptBundledModule[] modules; JavascriptBundledPackage[] packages; - + + // This is ASCII-encoded so you can send it directly over HTTP byte[] etag; - // this will be a u64, but that data type doesn't exist in this schema format + uint32 generated_at; + // generated by hashing all ${name}@${version} in sorted order byte[] app_package_json_dependencies_hash; + + byte[] import_from_name; - + // This is what StringPointer refers to byte[] manifest_string; } message JavascriptBundleContainer { uint32 bundle_format_version = 1; + JavascriptBundle bundle = 2; + // Don't technically need to store this, but it may be helpful as a sanity check uint32 code_length = 3; } @@ -226,4 +232,20 @@ struct Log { uint32 warnings; uint32 errors; Message[] msgs; +} + + +// The WebSocket protocol +// Server: "hey, this file changed. Do you want it?" +// Client: *checks hash table* "uhh yeah, ok. rebuild that for me" +// Server: "here u go" +// This makes the client responsible for tracking which files it needs to listen for. +smol WebsocketMessageType { + visibility_status_change = 1, + build_status_update = 2, + +} + +message WebsocketMessageContainer { + }
\ No newline at end of file diff --git a/src/http.zig b/src/http.zig index 2d54a45c0..7ec448f20 100644 --- a/src/http.zig +++ b/src/http.zig @@ -20,7 +20,7 @@ const Response = picohttp.Response; const Headers = picohttp.Headers; const MimeType = @import("http/mime_type.zig"); const Bundler = bundler.ServeBundler; - +const Websocket = @import("./http/websocket.zig"); const js_printer = @import("js_printer.zig"); const SOCKET_FLAGS = os.SOCK_CLOEXEC; @@ -34,7 +34,7 @@ pub fn println(comptime fmt: string, args: anytype) void { // } } -const HTTPStatusCode = u9; +const HTTPStatusCode = u10; pub const URLPath = struct { extname: string = "", @@ -160,6 +160,7 @@ pub const RequestContext = struct { url: URLPath, conn: *tcp.Connection, allocator: *std.mem.Allocator, + arena: std.heap.ArenaAllocator, log: logger.Log, bundler: *Bundler, keep_alive: bool = true, @@ -167,15 +168,24 @@ pub const RequestContext = struct { has_written_last_header: bool = false, has_called_done: bool = false, mime_type: MimeType = MimeType.other, + controlled: bool = false, res_headers_count: usize = 0, pub const bundle_prefix = "__speedy"; pub fn header(ctx: *RequestContext, comptime name: anytype) ?Header { - for (ctx.request.headers) |head| { - if (strings.eqlComptime(head.name, name)) { - return head; + if (name.len < 17) { + for (ctx.request.headers) |head| { + if (strings.eqlComptime(head.name, name)) { + return head; + } + } + } else { + for (ctx.request.headers) |head| { + if (strings.eql(head.name, name)) { + return head; + } } } @@ -184,6 +194,7 @@ pub const RequestContext = struct { pub fn printStatusLine(comptime code: HTTPStatusCode) []const u8 { const status_text = switch (code) { + 101 => "ACTIVATING WEBSOCKET", 200...299 => "OK", 300...399 => "=>", 400...499 => "DID YOU KNOW YOU CAN MAKE THIS SAY WHATEVER YOU WANT", @@ -258,16 +269,19 @@ pub const RequestContext = struct { ctx.status = code; } - pub fn init(req: Request, allocator: *std.mem.Allocator, conn: *tcp.Connection, bundler_: *Bundler) !RequestContext { - return RequestContext{ + pub fn init(req: Request, arena: std.heap.ArenaAllocator, conn: *tcp.Connection, bundler_: *Bundler) !RequestContext { + var ctx = RequestContext{ .request = req, - .allocator = allocator, + .arena = arena, .bundler = bundler_, .url = URLPath.parse(req.path), - .log = logger.Log.init(allocator), + .log = undefined, .conn = conn, + .allocator = undefined, .method = Method.which(req.method) orelse return error.InvalidMethod, }; + + return ctx; } pub fn sendNotFound(req: *RequestContext) !void { @@ -362,10 +376,161 @@ pub const RequestContext = struct { ); } + pub const WebsocketHandler = struct { + accept_key: [28]u8 = undefined, + ctx: RequestContext, + + pub fn handle(self: WebsocketHandler) void { + var this = self; + _handle(&this) catch {}; + } + + fn _handle(handler: *WebsocketHandler) !void { + var ctx = &handler.ctx; + defer ctx.arena.deinit(); + defer ctx.conn.deinit(); + defer Output.flush(); + + handler.checkUpgradeHeaders() catch |err| { + switch (err) { + error.BadRequest => { + try ctx.sendBadRequest(); + ctx.done(); + }, + else => { + return err; + }, + } + }; + + switch (try handler.getWebsocketVersion()) { + 7, 8, 13 => {}, + else => { + // Unsupported version + // Set header to indicate to the client which versions are supported + ctx.appendHeader("Sec-WebSocket-Version", "7,8,13"); + try ctx.writeStatus(426); + try ctx.flushHeaders(); + ctx.done(); + return; + }, + } + + const key = try handler.getWebsocketAcceptKey(); + + ctx.appendHeader("Connection", "Upgrade"); + ctx.appendHeader("Upgrade", "websocket"); + ctx.appendHeader("Sec-WebSocket-Accept", key); + try ctx.writeStatus(101); + try ctx.flushHeaders(); + Output.println("101 - Websocket connected.", .{}); + Output.flush(); + + var websocket = Websocket.Websocket.create(ctx, SOCKET_FLAGS); + _ = try websocket.writeText("Hello!"); + + while (true) { + defer Output.flush(); + var frame = websocket.read() catch |err| { + switch (err) { + error.ConnectionClosed => { + Output.prettyln("Websocket closed.", .{}); + return; + }, + else => { + Output.prettyErrorln("<r><red>ERR:<r> <b>{s}<r>", .{err}); + }, + } + return; + }; + switch (frame.header.opcode) { + .Close => { + Output.prettyln("Websocket closed.", .{}); + return; + }, + .Text => { + Output.print("Data: {s}", .{frame.data}); + _ = try websocket.writeText(frame.data); + }, + .Ping => { + var pong = frame; + pong.header.opcode = .Pong; + _ = try websocket.writeDataFrame(pong); + }, + else => { + Output.prettyErrorln("Websocket unknown opcode: {s}", .{@tagName(frame.header.opcode)}); + }, + } + } + } + + fn checkUpgradeHeaders( + self: *WebsocketHandler, + ) !void { + var request: *RequestContext = &self.ctx; + const upgrade_header = request.header("Upgrade") orelse return error.BadRequest; + + if (!std.ascii.eqlIgnoreCase(upgrade_header.value, "websocket")) { + return error.BadRequest; // Can only upgrade to websocket + } + + // Some proxies/load balancers will mess with the connection header + // and browsers also send multiple values here + const connection_header = request.header("Connection") orelse return error.BadRequest; + var it = std.mem.split(connection_header.value, ","); + while (it.next()) |part| { + const conn = std.mem.trim(u8, part, " "); + if (std.ascii.eqlIgnoreCase(conn, "upgrade")) { + return; + } + } + return error.BadRequest; // Connection must be upgrade + } + + fn getWebsocketVersion( + self: *WebsocketHandler, + ) !u8 { + var request: *RequestContext = &self.ctx; + const v = request.header("Sec-WebSocket-Version") orelse return error.BadRequest; + return std.fmt.parseInt(u8, v.value, 10) catch error.BadRequest; + } + + fn getWebsocketAcceptKey( + self: *WebsocketHandler, + ) ![]const u8 { + var request: *RequestContext = &self.ctx; + const key = (request.header("Sec-WebSocket-Key") orelse return error.BadRequest).value; + if (key.len < 8) { + return error.BadRequest; + } + + var hash = std.crypto.hash.Sha1.init(.{}); + var out: [20]u8 = undefined; + hash.update(key); + hash.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11"); + hash.final(&out); + + // Encode it + return std.base64.standard_encoder.encode(&self.accept_key, &out); + } + }; + + pub fn handleWebsocket(ctx: *RequestContext) anyerror!void { + ctx.controlled = true; + var handler = WebsocketHandler{ .ctx = ctx.* }; + _ = try std.Thread.spawn(WebsocketHandler.handle, handler); + } + pub fn handleGet(ctx: *RequestContext) !void { if (strings.eqlComptime(ctx.url.extname, "jsb") and ctx.bundler.options.node_modules_bundle != null) { return try ctx.sendJSB(); } + + if (strings.eqlComptime(ctx.url.path, "_api")) { + try ctx.handleWebsocket(); + return; + } + const result = try ctx.bundler.buildFile( &ctx.log, ctx.allocator, @@ -787,7 +952,8 @@ pub const Server = struct { try listener.listen(1280); const addr = try listener.getLocalAddress(); - Output.println("Started Speedy at http://{s}", .{addr}); + Output.prettyln("<r>Started Speedy at <b><cyan>http://{s}<r>", .{addr}); + Output.flush(); // var listener_handle = try std.os.kqueue(); // var change_list = std.mem.zeroes([2]os.Kevent); @@ -796,6 +962,7 @@ pub const Server = struct { // var eventlist: [128]os.Kevent = undefined; while (true) { + defer Output.flush(); var conn = listener.accept(.{ .close_on_exec = true }) catch |err| { continue; }; @@ -831,14 +998,21 @@ pub const Server = struct { }; var request_arena = std.heap.ArenaAllocator.init(server.allocator); - defer request_arena.deinit(); - - var req_ctx = RequestContext.init(req, &request_arena.allocator, conn, &server.bundler) catch |err| { - Output.printErrorln("FAIL [{s}] - {s}: {s}", .{ @errorName(err), req.method, req.path }); + var req_ctx: RequestContext = undefined; + defer { + if (!req_ctx.controlled) { + req_ctx.arena.deinit(); + } + } + req_ctx = RequestContext.init(req, request_arena, conn, &server.bundler) catch |err| { + Output.printErrorln("<r>[<red>{s}<r>] - <b>{s}<r>: {s}", .{ @errorName(err), req.method, req.path }); conn.client.deinit(); return; }; + req_ctx.allocator = &req_ctx.arena.allocator; + req_ctx.log = logger.Log.init(req_ctx.allocator); + if (FeatureFlags.keep_alive) { if (req_ctx.header("Connection")) |connection| { req_ctx.keep_alive = strings.eqlInsensitive(connection.value, "keep-alive"); @@ -861,16 +1035,18 @@ pub const Server = struct { } }; - const status = req_ctx.status orelse @intCast(HTTPStatusCode, 500); + if (!req_ctx.controlled) { + const status = req_ctx.status orelse @intCast(HTTPStatusCode, 500); - if (req_ctx.log.msgs.items.len == 0) { - println("{d} – {s} {s} as {s}", .{ status, @tagName(req_ctx.method), req.path, req_ctx.mime_type.value }); - } else { - println("{s} {s}", .{ @tagName(req_ctx.method), req.path }); - for (req_ctx.log.msgs.items) |msg| { - msg.writeFormat(Output.errorWriter()) catch continue; + if (req_ctx.log.msgs.items.len == 0) { + println("{d} – {s} {s} as {s}", .{ status, @tagName(req_ctx.method), req.path, req_ctx.mime_type.value }); + } else { + println("{s} {s}", .{ @tagName(req_ctx.method), req.path }); + for (req_ctx.log.msgs.items) |msg| { + msg.writeFormat(Output.errorWriter()) catch continue; + } + req_ctx.log.deinit(); } - req_ctx.log.deinit(); } } diff --git a/src/http/websocket.zig b/src/http/websocket.zig new file mode 100644 index 000000000..9fb38a92f --- /dev/null +++ b/src/http/websocket.zig @@ -0,0 +1,310 @@ +// This code is based on https://github.com/frmdstryr/zhp/blob/a4b5700c289c3619647206144e10fb414113a888/src/websocket.zig +// Thank you @frmdstryr. +const std = @import("std"); +const native_endian = std.Target.current.cpu.arch.endian(); + +const tcp = std.x.net.tcp; +const ip = std.x.net.ip; + +const IPv4 = std.x.os.IPv4; +const IPv6 = std.x.os.IPv6; +const Socket = std.x.os.Socket; +const os = std.os; +usingnamespace @import("../http.zig"); +usingnamespace @import("../global.zig"); + +pub const Opcode = enum(u4) { + Continue = 0x0, + Text = 0x1, + Binary = 0x2, + Res3 = 0x3, + Res4 = 0x4, + Res5 = 0x5, + Res6 = 0x6, + Res7 = 0x7, + Close = 0x8, + Ping = 0x9, + Pong = 0xA, + ResB = 0xB, + ResC = 0xC, + ResD = 0xD, + ResE = 0xE, + ResF = 0xF, + + pub fn isControl(opcode: Opcode) bool { + return @enumToInt(opcode) & 0x8 != 0; + } +}; + +pub const WebsocketHeader = packed struct { + len: u7, + mask: bool, + opcode: Opcode, + rsv3: u1 = 0, + rsv2: u1 = 0, + compressed: bool = false, // rsv1 + final: bool = true, + + pub fn packLength(length: usize) u7 { + return switch (length) { + 0...126 => @truncate(u7, length), + 127...0xFFFF => 126, + else => 127, + }; + } +}; + +pub const WebsocketDataFrame = struct { + header: WebsocketHeader, + mask: [4]u8 = undefined, + data: []const u8, + + pub fn isValid(dataframe: WebsocketDataFrame) bool { + // Validate control frame + if (dataframe.header.opcode.isControl()) { + if (!dataframe.header.final) { + return false; // Control frames cannot be fragmented + } + if (dataframe.data.len > 125) { + return false; // Control frame payloads cannot exceed 125 bytes + } + } + + // Validate header len field + const expected = switch (dataframe.data.len) { + 0...126 => dataframe.data.len, + 127...0xFFFF => 126, + else => 127, + }; + return dataframe.header.len == expected; + } +}; + +// Create a buffered writer +// TODO: This will still split packets +pub fn Writer(comptime size: usize, comptime opcode: Opcode) type { + const WriterType = switch (opcode) { + .Text => Websocket.TextFrameWriter, + .Binary => Websocket.BinaryFrameWriter, + else => @compileError("Unsupported writer opcode"), + }; + return std.io.BufferedWriter(size, WriterType); +} + +const ReadStream = std.io.FixedBufferStream([]u8); + +pub const Websocket = struct { + pub const WriteError = error{ + InvalidMessage, + MessageTooLarge, + EndOfStream, + } || std.fs.File.WriteError; + + request: *RequestContext, + + err: ?anyerror = null, + buf: [4096]u8 = undefined, + read_stream: ReadStream, + reader: ReadStream.Reader, + flags: u32 = 0, + pub fn create( + ctx: *RequestContext, + comptime flags: u32, + ) Websocket { + var stream = ReadStream{ + .buffer = &[_]u8{}, + .pos = 0, + }; + var socket = Websocket{ + .read_stream = undefined, + .reader = undefined, + .request = ctx, + .flags = flags, + }; + + socket.read_stream = stream; + socket.reader = socket.read_stream.reader(); + return socket; + } + + // ------------------------------------------------------------------------ + // Stream API + // ------------------------------------------------------------------------ + pub const TextFrameWriter = std.io.Writer(*Websocket, WriteError, Websocket.writeText); + pub const BinaryFrameWriter = std.io.Writer(*Websocket, WriteError, Websocket.writeBinary); + + // A buffered writer that will buffer up to size bytes before writing out + pub fn newWriter(self: *Websocket, comptime size: usize, comptime opcode: Opcode) Writer(size, opcode) { + const BufferedWriter = Writer(size, opcode); + const frame_writer = switch (opcode) { + .Text => TextFrameWriter{ .context = self }, + .Binary => BinaryFrameWriter{ .context = self }, + else => @compileError("Unsupported writer type"), + }; + return BufferedWriter{ .unbuffered_writer = frame_writer }; + } + + // Close and send the status + pub fn close(self: *Websocket, code: u16) !void { + const c = if (native_endian == .Big) code else @byteSwap(u16, code); + const data = @bitCast([2]u8, c); + _ = try self.writeMessage(.Close, &data); + } + + // ------------------------------------------------------------------------ + // Low level API + // ------------------------------------------------------------------------ + + // Flush any buffered data out the underlying stream + pub fn flush(self: *Websocket) !void { + try self.io.flush(); + } + + pub fn writeText(self: *Websocket, data: []const u8) !usize { + return self.writeMessage(.Text, data); + } + + pub fn writeBinary(self: *Websocket, data: []const u8) !usize { + return self.writeMessage(.Binary, data); + } + + // Write a final message packet with the given opcode + pub fn writeMessage(self: *Websocket, opcode: Opcode, message: []const u8) !usize { + return self.writeSplitMessage(opcode, true, message); + } + + // Write a message packet with the given opcode and final flag + pub fn writeSplitMessage(self: *Websocket, opcode: Opcode, final: bool, message: []const u8) !usize { + return self.writeDataFrame(WebsocketDataFrame{ + .header = WebsocketHeader{ + .final = final, + .opcode = opcode, + .mask = false, // Server to client is not masked + .len = WebsocketHeader.packLength(message.len), + }, + .data = message, + }); + } + + // Write a raw data frame + pub fn writeDataFrame(self: *Websocket, dataframe: WebsocketDataFrame) !usize { + var stream = self.request.conn.client.writer(self.flags); + + if (!dataframe.isValid()) return error.InvalidMessage; + + try stream.writeIntBig(u16, @bitCast(u16, dataframe.header)); + + // Write extended length if needed + const n = dataframe.data.len; + switch (n) { + 0...126 => {}, // Included in header + 127...0xFFFF => try stream.writeIntBig(u16, @truncate(u16, n)), + else => try stream.writeIntBig(u64, n), + } + + // TODO: Handle compression + if (dataframe.header.compressed) return error.InvalidMessage; + + if (dataframe.header.mask) { + const mask = &dataframe.mask; + try stream.writeAll(mask); + + // Encode + for (dataframe.data) |c, i| { + try stream.writeByte(c ^ mask[i % 4]); + } + } else { + try stream.writeAll(dataframe.data); + } + + // try self.io.flush(); + + return dataframe.data.len; + } + + pub fn read(self: *Websocket) !WebsocketDataFrame { + @memset(&self.buf, 0, self.buf.len); + + // Read and retry if we hit the end of the stream buffer + var start = try self.request.conn.client.read(&self.buf, self.flags); + if (start == 0) { + return error.ConnectionClosed; + } + + self.read_stream.pos = start; + return try self.readDataFrameInBuffer(); + } + + pub fn eatAt(self: *Websocket, offset: usize, _len: usize) []u8 { + const len = std.math.min(self.read_stream.buffer.len, _len); + self.read_stream.pos = len; + return self.read_stream.buffer[offset..len]; + } + + // Read assuming everything can fit before the stream hits the end of + // it's buffer + pub fn readDataFrameInBuffer( + self: *Websocket, + ) !WebsocketDataFrame { + var buf: []u8 = self.buf[0..]; + + const header_bytes = buf[0..2]; + var header = std.mem.zeroes(WebsocketHeader); + header.final = header_bytes[0] & 0x80 == 0x80; + // header.rsv1 = header_bytes[0] & 0x40 == 0x40; + // header.rsv2 = header_bytes[0] & 0x20; + // header.rsv3 = header_bytes[0] & 0x10; + header.opcode = @intToEnum(Opcode, @truncate(u4, header_bytes[0])); + header.mask = header_bytes[1] & 0x80 == 0x80; + header.len = @truncate(u7, header_bytes[1]); + + // Decode length + var length: u64 = header.len; + + switch (header.len) { + 126 => { + length = std.mem.readIntBig(u16, buf[2..4]); + buf = buf[4..]; + }, + 127 => { + length = std.mem.readIntBig(u64, buf[2..10]); + // Most significant bit must be 0 + if (length >> 63 == 1) { + return error.InvalidMessage; + } + buf = buf[10..]; + }, + else => { + buf = buf[2..]; + }, + } + + const start: usize = if (header.mask) 4 else 0; + + const end = start + length; + + if (end > self.read_stream.pos) { + var extend_length = try self.request.conn.client.read(self.buf[self.read_stream.pos..], self.flags); + if (self.read_stream.pos + extend_length > self.buf.len) { + return error.MessageTooLarge; + } + self.read_stream.pos += extend_length; + } + + var data = buf[start..end]; + + if (header.mask) { + const mask = buf[0..4]; + // Decode data in place + for (data) |c, i| { + data[i] ^= mask[i % 4]; + } + } + + return WebsocketDataFrame{ + .header = header, + .mask = if (header.mask) buf[0..4].* else undefined, + .data = data, + }; + } +}; |