aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Jarred Sumner <jarred@jarredsumner.com> 2021-06-11 17:05:15 -0700
committerGravatar Jarred Sumner <jarred@jarredsumner.com> 2021-06-11 17:05:15 -0700
commit223410eab3c18be989e520a7aeac470bf941dd6b (patch)
tree7ccdc2fb77c82b7aba3db84f1916637b043fcf4b
parentdc3309d130c171f75d4e416f3569faa9c985d123 (diff)
downloadbun-223410eab3c18be989e520a7aeac470bf941dd6b.tar.gz
bun-223410eab3c18be989e520a7aeac470bf941dd6b.tar.zst
bun-223410eab3c18be989e520a7aeac470bf941dd6b.zip
lil websocket server
Former-commit-id: 19f4c59b4281bd24130a6898e3b43d9d6d7e6603
-rw-r--r--src/api/schema.peechy28
-rw-r--r--src/http.zig220
-rw-r--r--src/http/websocket.zig310
3 files changed, 533 insertions, 25 deletions
diff --git a/src/api/schema.peechy b/src/api/schema.peechy
index 87f44dbaf..863ed6aa8 100644
--- a/src/api/schema.peechy
+++ b/src/api/schema.peechy
@@ -74,21 +74,27 @@ struct JavascriptBundle {
// These are sorted alphabetically so you can do binary search
JavascriptBundledModule[] modules;
JavascriptBundledPackage[] packages;
-
+
+ // This is ASCII-encoded so you can send it directly over HTTP
byte[] etag;
- // this will be a u64, but that data type doesn't exist in this schema format
+
uint32 generated_at;
+
// generated by hashing all ${name}@${version} in sorted order
byte[] app_package_json_dependencies_hash;
+
+
byte[] import_from_name;
-
+
// This is what StringPointer refers to
byte[] manifest_string;
}
message JavascriptBundleContainer {
uint32 bundle_format_version = 1;
+
JavascriptBundle bundle = 2;
+
// Don't technically need to store this, but it may be helpful as a sanity check
uint32 code_length = 3;
}
@@ -226,4 +232,20 @@ struct Log {
uint32 warnings;
uint32 errors;
Message[] msgs;
+}
+
+
+// The WebSocket protocol
+// Server: "hey, this file changed. Do you want it?"
+// Client: *checks hash table* "uhh yeah, ok. rebuild that for me"
+// Server: "here u go"
+// This makes the client responsible for tracking which files it needs to listen for.
+smol WebsocketMessageType {
+ visibility_status_change = 1,
+ build_status_update = 2,
+
+}
+
+message WebsocketMessageContainer {
+
} \ No newline at end of file
diff --git a/src/http.zig b/src/http.zig
index 2d54a45c0..7ec448f20 100644
--- a/src/http.zig
+++ b/src/http.zig
@@ -20,7 +20,7 @@ const Response = picohttp.Response;
const Headers = picohttp.Headers;
const MimeType = @import("http/mime_type.zig");
const Bundler = bundler.ServeBundler;
-
+const Websocket = @import("./http/websocket.zig");
const js_printer = @import("js_printer.zig");
const SOCKET_FLAGS = os.SOCK_CLOEXEC;
@@ -34,7 +34,7 @@ pub fn println(comptime fmt: string, args: anytype) void {
// }
}
-const HTTPStatusCode = u9;
+const HTTPStatusCode = u10;
pub const URLPath = struct {
extname: string = "",
@@ -160,6 +160,7 @@ pub const RequestContext = struct {
url: URLPath,
conn: *tcp.Connection,
allocator: *std.mem.Allocator,
+ arena: std.heap.ArenaAllocator,
log: logger.Log,
bundler: *Bundler,
keep_alive: bool = true,
@@ -167,15 +168,24 @@ pub const RequestContext = struct {
has_written_last_header: bool = false,
has_called_done: bool = false,
mime_type: MimeType = MimeType.other,
+ controlled: bool = false,
res_headers_count: usize = 0,
pub const bundle_prefix = "__speedy";
pub fn header(ctx: *RequestContext, comptime name: anytype) ?Header {
- for (ctx.request.headers) |head| {
- if (strings.eqlComptime(head.name, name)) {
- return head;
+ if (name.len < 17) {
+ for (ctx.request.headers) |head| {
+ if (strings.eqlComptime(head.name, name)) {
+ return head;
+ }
+ }
+ } else {
+ for (ctx.request.headers) |head| {
+ if (strings.eql(head.name, name)) {
+ return head;
+ }
}
}
@@ -184,6 +194,7 @@ pub const RequestContext = struct {
pub fn printStatusLine(comptime code: HTTPStatusCode) []const u8 {
const status_text = switch (code) {
+ 101 => "ACTIVATING WEBSOCKET",
200...299 => "OK",
300...399 => "=>",
400...499 => "DID YOU KNOW YOU CAN MAKE THIS SAY WHATEVER YOU WANT",
@@ -258,16 +269,19 @@ pub const RequestContext = struct {
ctx.status = code;
}
- pub fn init(req: Request, allocator: *std.mem.Allocator, conn: *tcp.Connection, bundler_: *Bundler) !RequestContext {
- return RequestContext{
+ pub fn init(req: Request, arena: std.heap.ArenaAllocator, conn: *tcp.Connection, bundler_: *Bundler) !RequestContext {
+ var ctx = RequestContext{
.request = req,
- .allocator = allocator,
+ .arena = arena,
.bundler = bundler_,
.url = URLPath.parse(req.path),
- .log = logger.Log.init(allocator),
+ .log = undefined,
.conn = conn,
+ .allocator = undefined,
.method = Method.which(req.method) orelse return error.InvalidMethod,
};
+
+ return ctx;
}
pub fn sendNotFound(req: *RequestContext) !void {
@@ -362,10 +376,161 @@ pub const RequestContext = struct {
);
}
+ pub const WebsocketHandler = struct {
+ accept_key: [28]u8 = undefined,
+ ctx: RequestContext,
+
+ pub fn handle(self: WebsocketHandler) void {
+ var this = self;
+ _handle(&this) catch {};
+ }
+
+ fn _handle(handler: *WebsocketHandler) !void {
+ var ctx = &handler.ctx;
+ defer ctx.arena.deinit();
+ defer ctx.conn.deinit();
+ defer Output.flush();
+
+ handler.checkUpgradeHeaders() catch |err| {
+ switch (err) {
+ error.BadRequest => {
+ try ctx.sendBadRequest();
+ ctx.done();
+ },
+ else => {
+ return err;
+ },
+ }
+ };
+
+ switch (try handler.getWebsocketVersion()) {
+ 7, 8, 13 => {},
+ else => {
+ // Unsupported version
+ // Set header to indicate to the client which versions are supported
+ ctx.appendHeader("Sec-WebSocket-Version", "7,8,13");
+ try ctx.writeStatus(426);
+ try ctx.flushHeaders();
+ ctx.done();
+ return;
+ },
+ }
+
+ const key = try handler.getWebsocketAcceptKey();
+
+ ctx.appendHeader("Connection", "Upgrade");
+ ctx.appendHeader("Upgrade", "websocket");
+ ctx.appendHeader("Sec-WebSocket-Accept", key);
+ try ctx.writeStatus(101);
+ try ctx.flushHeaders();
+ Output.println("101 - Websocket connected.", .{});
+ Output.flush();
+
+ var websocket = Websocket.Websocket.create(ctx, SOCKET_FLAGS);
+ _ = try websocket.writeText("Hello!");
+
+ while (true) {
+ defer Output.flush();
+ var frame = websocket.read() catch |err| {
+ switch (err) {
+ error.ConnectionClosed => {
+ Output.prettyln("Websocket closed.", .{});
+ return;
+ },
+ else => {
+ Output.prettyErrorln("<r><red>ERR:<r> <b>{s}<r>", .{err});
+ },
+ }
+ return;
+ };
+ switch (frame.header.opcode) {
+ .Close => {
+ Output.prettyln("Websocket closed.", .{});
+ return;
+ },
+ .Text => {
+ Output.print("Data: {s}", .{frame.data});
+ _ = try websocket.writeText(frame.data);
+ },
+ .Ping => {
+ var pong = frame;
+ pong.header.opcode = .Pong;
+ _ = try websocket.writeDataFrame(pong);
+ },
+ else => {
+ Output.prettyErrorln("Websocket unknown opcode: {s}", .{@tagName(frame.header.opcode)});
+ },
+ }
+ }
+ }
+
+ fn checkUpgradeHeaders(
+ self: *WebsocketHandler,
+ ) !void {
+ var request: *RequestContext = &self.ctx;
+ const upgrade_header = request.header("Upgrade") orelse return error.BadRequest;
+
+ if (!std.ascii.eqlIgnoreCase(upgrade_header.value, "websocket")) {
+ return error.BadRequest; // Can only upgrade to websocket
+ }
+
+ // Some proxies/load balancers will mess with the connection header
+ // and browsers also send multiple values here
+ const connection_header = request.header("Connection") orelse return error.BadRequest;
+ var it = std.mem.split(connection_header.value, ",");
+ while (it.next()) |part| {
+ const conn = std.mem.trim(u8, part, " ");
+ if (std.ascii.eqlIgnoreCase(conn, "upgrade")) {
+ return;
+ }
+ }
+ return error.BadRequest; // Connection must be upgrade
+ }
+
+ fn getWebsocketVersion(
+ self: *WebsocketHandler,
+ ) !u8 {
+ var request: *RequestContext = &self.ctx;
+ const v = request.header("Sec-WebSocket-Version") orelse return error.BadRequest;
+ return std.fmt.parseInt(u8, v.value, 10) catch error.BadRequest;
+ }
+
+ fn getWebsocketAcceptKey(
+ self: *WebsocketHandler,
+ ) ![]const u8 {
+ var request: *RequestContext = &self.ctx;
+ const key = (request.header("Sec-WebSocket-Key") orelse return error.BadRequest).value;
+ if (key.len < 8) {
+ return error.BadRequest;
+ }
+
+ var hash = std.crypto.hash.Sha1.init(.{});
+ var out: [20]u8 = undefined;
+ hash.update(key);
+ hash.update("258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
+ hash.final(&out);
+
+ // Encode it
+ return std.base64.standard_encoder.encode(&self.accept_key, &out);
+ }
+ };
+
+ pub fn handleWebsocket(ctx: *RequestContext) anyerror!void {
+ ctx.controlled = true;
+ var handler = WebsocketHandler{ .ctx = ctx.* };
+ _ = try std.Thread.spawn(WebsocketHandler.handle, handler);
+ }
+
pub fn handleGet(ctx: *RequestContext) !void {
if (strings.eqlComptime(ctx.url.extname, "jsb") and ctx.bundler.options.node_modules_bundle != null) {
return try ctx.sendJSB();
}
+
+ if (strings.eqlComptime(ctx.url.path, "_api")) {
+ try ctx.handleWebsocket();
+ return;
+ }
+
const result = try ctx.bundler.buildFile(
&ctx.log,
ctx.allocator,
@@ -787,7 +952,8 @@ pub const Server = struct {
try listener.listen(1280);
const addr = try listener.getLocalAddress();
- Output.println("Started Speedy at http://{s}", .{addr});
+ Output.prettyln("<r>Started Speedy at <b><cyan>http://{s}<r>", .{addr});
+ Output.flush();
// var listener_handle = try std.os.kqueue();
// var change_list = std.mem.zeroes([2]os.Kevent);
@@ -796,6 +962,7 @@ pub const Server = struct {
// var eventlist: [128]os.Kevent = undefined;
while (true) {
+ defer Output.flush();
var conn = listener.accept(.{ .close_on_exec = true }) catch |err| {
continue;
};
@@ -831,14 +998,21 @@ pub const Server = struct {
};
var request_arena = std.heap.ArenaAllocator.init(server.allocator);
- defer request_arena.deinit();
-
- var req_ctx = RequestContext.init(req, &request_arena.allocator, conn, &server.bundler) catch |err| {
- Output.printErrorln("FAIL [{s}] - {s}: {s}", .{ @errorName(err), req.method, req.path });
+ var req_ctx: RequestContext = undefined;
+ defer {
+ if (!req_ctx.controlled) {
+ req_ctx.arena.deinit();
+ }
+ }
+ req_ctx = RequestContext.init(req, request_arena, conn, &server.bundler) catch |err| {
+ Output.printErrorln("<r>[<red>{s}<r>] - <b>{s}<r>: {s}", .{ @errorName(err), req.method, req.path });
conn.client.deinit();
return;
};
+ req_ctx.allocator = &req_ctx.arena.allocator;
+ req_ctx.log = logger.Log.init(req_ctx.allocator);
+
if (FeatureFlags.keep_alive) {
if (req_ctx.header("Connection")) |connection| {
req_ctx.keep_alive = strings.eqlInsensitive(connection.value, "keep-alive");
@@ -861,16 +1035,18 @@ pub const Server = struct {
}
};
- const status = req_ctx.status orelse @intCast(HTTPStatusCode, 500);
+ if (!req_ctx.controlled) {
+ const status = req_ctx.status orelse @intCast(HTTPStatusCode, 500);
- if (req_ctx.log.msgs.items.len == 0) {
- println("{d} – {s} {s} as {s}", .{ status, @tagName(req_ctx.method), req.path, req_ctx.mime_type.value });
- } else {
- println("{s} {s}", .{ @tagName(req_ctx.method), req.path });
- for (req_ctx.log.msgs.items) |msg| {
- msg.writeFormat(Output.errorWriter()) catch continue;
+ if (req_ctx.log.msgs.items.len == 0) {
+ println("{d} – {s} {s} as {s}", .{ status, @tagName(req_ctx.method), req.path, req_ctx.mime_type.value });
+ } else {
+ println("{s} {s}", .{ @tagName(req_ctx.method), req.path });
+ for (req_ctx.log.msgs.items) |msg| {
+ msg.writeFormat(Output.errorWriter()) catch continue;
+ }
+ req_ctx.log.deinit();
}
- req_ctx.log.deinit();
}
}
diff --git a/src/http/websocket.zig b/src/http/websocket.zig
new file mode 100644
index 000000000..9fb38a92f
--- /dev/null
+++ b/src/http/websocket.zig
@@ -0,0 +1,310 @@
+// This code is based on https://github.com/frmdstryr/zhp/blob/a4b5700c289c3619647206144e10fb414113a888/src/websocket.zig
+// Thank you @frmdstryr.
+const std = @import("std");
+const native_endian = std.Target.current.cpu.arch.endian();
+
+const tcp = std.x.net.tcp;
+const ip = std.x.net.ip;
+
+const IPv4 = std.x.os.IPv4;
+const IPv6 = std.x.os.IPv6;
+const Socket = std.x.os.Socket;
+const os = std.os;
+usingnamespace @import("../http.zig");
+usingnamespace @import("../global.zig");
+
+pub const Opcode = enum(u4) {
+ Continue = 0x0,
+ Text = 0x1,
+ Binary = 0x2,
+ Res3 = 0x3,
+ Res4 = 0x4,
+ Res5 = 0x5,
+ Res6 = 0x6,
+ Res7 = 0x7,
+ Close = 0x8,
+ Ping = 0x9,
+ Pong = 0xA,
+ ResB = 0xB,
+ ResC = 0xC,
+ ResD = 0xD,
+ ResE = 0xE,
+ ResF = 0xF,
+
+ pub fn isControl(opcode: Opcode) bool {
+ return @enumToInt(opcode) & 0x8 != 0;
+ }
+};
+
+pub const WebsocketHeader = packed struct {
+ len: u7,
+ mask: bool,
+ opcode: Opcode,
+ rsv3: u1 = 0,
+ rsv2: u1 = 0,
+ compressed: bool = false, // rsv1
+ final: bool = true,
+
+ pub fn packLength(length: usize) u7 {
+ return switch (length) {
+ 0...126 => @truncate(u7, length),
+ 127...0xFFFF => 126,
+ else => 127,
+ };
+ }
+};
+
+pub const WebsocketDataFrame = struct {
+ header: WebsocketHeader,
+ mask: [4]u8 = undefined,
+ data: []const u8,
+
+ pub fn isValid(dataframe: WebsocketDataFrame) bool {
+ // Validate control frame
+ if (dataframe.header.opcode.isControl()) {
+ if (!dataframe.header.final) {
+ return false; // Control frames cannot be fragmented
+ }
+ if (dataframe.data.len > 125) {
+ return false; // Control frame payloads cannot exceed 125 bytes
+ }
+ }
+
+ // Validate header len field
+ const expected = switch (dataframe.data.len) {
+ 0...126 => dataframe.data.len,
+ 127...0xFFFF => 126,
+ else => 127,
+ };
+ return dataframe.header.len == expected;
+ }
+};
+
+// Create a buffered writer
+// TODO: This will still split packets
+pub fn Writer(comptime size: usize, comptime opcode: Opcode) type {
+ const WriterType = switch (opcode) {
+ .Text => Websocket.TextFrameWriter,
+ .Binary => Websocket.BinaryFrameWriter,
+ else => @compileError("Unsupported writer opcode"),
+ };
+ return std.io.BufferedWriter(size, WriterType);
+}
+
+const ReadStream = std.io.FixedBufferStream([]u8);
+
+pub const Websocket = struct {
+ pub const WriteError = error{
+ InvalidMessage,
+ MessageTooLarge,
+ EndOfStream,
+ } || std.fs.File.WriteError;
+
+ request: *RequestContext,
+
+ err: ?anyerror = null,
+ buf: [4096]u8 = undefined,
+ read_stream: ReadStream,
+ reader: ReadStream.Reader,
+ flags: u32 = 0,
+ pub fn create(
+ ctx: *RequestContext,
+ comptime flags: u32,
+ ) Websocket {
+ var stream = ReadStream{
+ .buffer = &[_]u8{},
+ .pos = 0,
+ };
+ var socket = Websocket{
+ .read_stream = undefined,
+ .reader = undefined,
+ .request = ctx,
+ .flags = flags,
+ };
+
+ socket.read_stream = stream;
+ socket.reader = socket.read_stream.reader();
+ return socket;
+ }
+
+ // ------------------------------------------------------------------------
+ // Stream API
+ // ------------------------------------------------------------------------
+ pub const TextFrameWriter = std.io.Writer(*Websocket, WriteError, Websocket.writeText);
+ pub const BinaryFrameWriter = std.io.Writer(*Websocket, WriteError, Websocket.writeBinary);
+
+ // A buffered writer that will buffer up to size bytes before writing out
+ pub fn newWriter(self: *Websocket, comptime size: usize, comptime opcode: Opcode) Writer(size, opcode) {
+ const BufferedWriter = Writer(size, opcode);
+ const frame_writer = switch (opcode) {
+ .Text => TextFrameWriter{ .context = self },
+ .Binary => BinaryFrameWriter{ .context = self },
+ else => @compileError("Unsupported writer type"),
+ };
+ return BufferedWriter{ .unbuffered_writer = frame_writer };
+ }
+
+ // Close and send the status
+ pub fn close(self: *Websocket, code: u16) !void {
+ const c = if (native_endian == .Big) code else @byteSwap(u16, code);
+ const data = @bitCast([2]u8, c);
+ _ = try self.writeMessage(.Close, &data);
+ }
+
+ // ------------------------------------------------------------------------
+ // Low level API
+ // ------------------------------------------------------------------------
+
+ // Flush any buffered data out the underlying stream
+ pub fn flush(self: *Websocket) !void {
+ try self.io.flush();
+ }
+
+ pub fn writeText(self: *Websocket, data: []const u8) !usize {
+ return self.writeMessage(.Text, data);
+ }
+
+ pub fn writeBinary(self: *Websocket, data: []const u8) !usize {
+ return self.writeMessage(.Binary, data);
+ }
+
+ // Write a final message packet with the given opcode
+ pub fn writeMessage(self: *Websocket, opcode: Opcode, message: []const u8) !usize {
+ return self.writeSplitMessage(opcode, true, message);
+ }
+
+ // Write a message packet with the given opcode and final flag
+ pub fn writeSplitMessage(self: *Websocket, opcode: Opcode, final: bool, message: []const u8) !usize {
+ return self.writeDataFrame(WebsocketDataFrame{
+ .header = WebsocketHeader{
+ .final = final,
+ .opcode = opcode,
+ .mask = false, // Server to client is not masked
+ .len = WebsocketHeader.packLength(message.len),
+ },
+ .data = message,
+ });
+ }
+
+ // Write a raw data frame
+ pub fn writeDataFrame(self: *Websocket, dataframe: WebsocketDataFrame) !usize {
+ var stream = self.request.conn.client.writer(self.flags);
+
+ if (!dataframe.isValid()) return error.InvalidMessage;
+
+ try stream.writeIntBig(u16, @bitCast(u16, dataframe.header));
+
+ // Write extended length if needed
+ const n = dataframe.data.len;
+ switch (n) {
+ 0...126 => {}, // Included in header
+ 127...0xFFFF => try stream.writeIntBig(u16, @truncate(u16, n)),
+ else => try stream.writeIntBig(u64, n),
+ }
+
+ // TODO: Handle compression
+ if (dataframe.header.compressed) return error.InvalidMessage;
+
+ if (dataframe.header.mask) {
+ const mask = &dataframe.mask;
+ try stream.writeAll(mask);
+
+ // Encode
+ for (dataframe.data) |c, i| {
+ try stream.writeByte(c ^ mask[i % 4]);
+ }
+ } else {
+ try stream.writeAll(dataframe.data);
+ }
+
+ // try self.io.flush();
+
+ return dataframe.data.len;
+ }
+
+ pub fn read(self: *Websocket) !WebsocketDataFrame {
+ @memset(&self.buf, 0, self.buf.len);
+
+ // Read and retry if we hit the end of the stream buffer
+ var start = try self.request.conn.client.read(&self.buf, self.flags);
+ if (start == 0) {
+ return error.ConnectionClosed;
+ }
+
+ self.read_stream.pos = start;
+ return try self.readDataFrameInBuffer();
+ }
+
+ pub fn eatAt(self: *Websocket, offset: usize, _len: usize) []u8 {
+ const len = std.math.min(self.read_stream.buffer.len, _len);
+ self.read_stream.pos = len;
+ return self.read_stream.buffer[offset..len];
+ }
+
+ // Read assuming everything can fit before the stream hits the end of
+ // it's buffer
+ pub fn readDataFrameInBuffer(
+ self: *Websocket,
+ ) !WebsocketDataFrame {
+ var buf: []u8 = self.buf[0..];
+
+ const header_bytes = buf[0..2];
+ var header = std.mem.zeroes(WebsocketHeader);
+ header.final = header_bytes[0] & 0x80 == 0x80;
+ // header.rsv1 = header_bytes[0] & 0x40 == 0x40;
+ // header.rsv2 = header_bytes[0] & 0x20;
+ // header.rsv3 = header_bytes[0] & 0x10;
+ header.opcode = @intToEnum(Opcode, @truncate(u4, header_bytes[0]));
+ header.mask = header_bytes[1] & 0x80 == 0x80;
+ header.len = @truncate(u7, header_bytes[1]);
+
+ // Decode length
+ var length: u64 = header.len;
+
+ switch (header.len) {
+ 126 => {
+ length = std.mem.readIntBig(u16, buf[2..4]);
+ buf = buf[4..];
+ },
+ 127 => {
+ length = std.mem.readIntBig(u64, buf[2..10]);
+ // Most significant bit must be 0
+ if (length >> 63 == 1) {
+ return error.InvalidMessage;
+ }
+ buf = buf[10..];
+ },
+ else => {
+ buf = buf[2..];
+ },
+ }
+
+ const start: usize = if (header.mask) 4 else 0;
+
+ const end = start + length;
+
+ if (end > self.read_stream.pos) {
+ var extend_length = try self.request.conn.client.read(self.buf[self.read_stream.pos..], self.flags);
+ if (self.read_stream.pos + extend_length > self.buf.len) {
+ return error.MessageTooLarge;
+ }
+ self.read_stream.pos += extend_length;
+ }
+
+ var data = buf[start..end];
+
+ if (header.mask) {
+ const mask = buf[0..4];
+ // Decode data in place
+ for (data) |c, i| {
+ data[i] ^= mask[i % 4];
+ }
+ }
+
+ return WebsocketDataFrame{
+ .header = header,
+ .mask = if (header.mask) buf[0..4].* else undefined,
+ .data = data,
+ };
+ }
+};