aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGravatar Ciro Spaciari <ciro.spaciari@gmail.com> 2023-05-11 03:55:39 -0300
committerGravatar GitHub <noreply@github.com> 2023-05-10 23:55:39 -0700
commitf9831e1f6f9cbc29559baa15ae87a229cb91d798 (patch)
tree51883211427624ef63531b888ab2a889c9a7d6b2
parent8f4a5903abb20c5c874365d6ac552497ea89cd04 (diff)
downloadbun-f9831e1f6f9cbc29559baa15ae87a229cb91d798.tar.gz
bun-f9831e1f6f9cbc29559baa15ae87a229cb91d798.tar.zst
bun-f9831e1f6f9cbc29559baa15ae87a229cb91d798.zip
feat(fetch) add redirect: 'error' support (#2845)
* add redirect: 'error' support * fix typo * fix typo * refactor FetchRedirect enum * fix FetchRedirect * updated
-rw-r--r--misctools/fetch.zig12
-rw-r--r--src/bun.js/webcore/response.zig27
-rw-r--r--src/cli/create_command.zig8
-rw-r--r--src/cli/upgrade_command.zig4
-rw-r--r--src/http_client_async.zig211
-rw-r--r--src/install/install.zig2
-rw-r--r--test/js/web/fetch/fetch.test.ts21
7 files changed, 158 insertions, 127 deletions
diff --git a/misctools/fetch.zig b/misctools/fetch.zig
index 991491011..38a948f52 100644
--- a/misctools/fetch.zig
+++ b/misctools/fetch.zig
@@ -187,17 +187,7 @@ pub fn main() anyerror!void {
var ctx = try default_allocator.create(HTTP.HTTPChannelContext);
ctx.* = .{
.channel = channel,
- .http = try HTTP.AsyncHTTP.init(
- default_allocator,
- args.method,
- args.url,
- args.headers,
- args.headers_buf,
- response_body_string,
- args.body,
-
- 0,
- ),
+ .http = try HTTP.AsyncHTTP.init(default_allocator, args.method, args.url, args.headers, args.headers_buf, response_body_string, args.body, 0, HTTP.FetchRedirect.follow),
};
ctx.http.callback = HTTP.HTTPChannelContext.callback;
var batch = HTTPThread.Batch{};
diff --git a/src/bun.js/webcore/response.zig b/src/bun.js/webcore/response.zig
index ed658b12b..d283df5a0 100644
--- a/src/bun.js/webcore/response.zig
+++ b/src/bun.js/webcore/response.zig
@@ -5,6 +5,7 @@ const RequestContext = @import("../../http.zig").RequestContext;
const MimeType = @import("../../http.zig").MimeType;
const ZigURL = @import("../../url.zig").URL;
const HTTPClient = @import("root").bun.HTTP;
+const FetchRedirect = HTTPClient.FetchRedirect;
const NetworkThread = HTTPClient.NetworkThread;
const AsyncIO = NetworkThread.AsyncIO;
const JSC = @import("root").bun.JSC;
@@ -847,9 +848,9 @@ pub const Fetch = struct {
FetchTasklet.callback,
).init(
fetch_tasklet,
- ), proxy, if (fetch_tasklet.signal != null) &fetch_tasklet.aborted else null, fetch_options.hostname);
+ ), proxy, if (fetch_tasklet.signal != null) &fetch_tasklet.aborted else null, fetch_options.hostname, fetch_options.redirect_type);
- if (!fetch_options.follow_redirects) {
+ if (fetch_options.redirect_type != FetchRedirect.follow) {
fetch_tasklet.http.?.client.remaining_redirect_count = 0;
}
@@ -884,7 +885,7 @@ pub const Fetch = struct {
disable_keepalive: bool,
url: ZigURL,
verbose: bool = false,
- follow_redirects: bool = true,
+ redirect_type: FetchRedirect = FetchRedirect.follow,
proxy: ?ZigURL = null,
url_proxy_buffer: []const u8 = "",
signal: ?*JSC.WebCore.AbortSignal = null,
@@ -960,7 +961,7 @@ pub const Fetch = struct {
var disable_keepalive = false;
var verbose = script_ctx.log.level.atLeast(.debug);
var proxy: ?ZigURL = null;
- var follow_redirects = true;
+ var redirect_type: FetchRedirect = FetchRedirect.follow;
var signal: ?*JSC.WebCore.AbortSignal = null;
// Custom Hostname
var hostname: ?[]u8 = null;
@@ -1028,10 +1029,10 @@ pub const Fetch = struct {
}
}
- if (options.get(ctx, "redirect")) |redirect_value| {
- if (redirect_value.getZigString(globalThis).eqlComptime("manual")) {
- follow_redirects = false;
- }
+ if (options.getOptionalEnum(ctx, "redirect", FetchRedirect) catch {
+ return .zero;
+ }) |redirect_value| {
+ redirect_type = redirect_value;
}
if (options.get(ctx, "keepalive")) |keepalive_value| {
@@ -1158,10 +1159,10 @@ pub const Fetch = struct {
}
}
- if (options.get(ctx, "redirect")) |redirect_value| {
- if (redirect_value.getZigString(globalThis).eqlComptime("manual")) {
- follow_redirects = false;
- }
+ if (options.getOptionalEnum(ctx, "redirect", FetchRedirect) catch {
+ return .zero;
+ }) |redirect_value| {
+ redirect_type = redirect_value;
}
if (options.get(ctx, "keepalive")) |keepalive_value| {
@@ -1331,7 +1332,7 @@ pub const Fetch = struct {
.timeout = std.time.ns_per_hour,
.disable_keepalive = disable_keepalive,
.disable_timeout = disable_timeout,
- .follow_redirects = follow_redirects,
+ .redirect_type = redirect_type,
.verbose = verbose,
.proxy = proxy,
.url_proxy_buffer = url_proxy_buffer,
diff --git a/src/cli/create_command.zig b/src/cli/create_command.zig
index 60261a446..fcbee9dbb 100644
--- a/src/cli/create_command.zig
+++ b/src/cli/create_command.zig
@@ -1852,7 +1852,7 @@ pub const Example = struct {
// ensure very stable memory address
var async_http: *HTTP.AsyncHTTP = ctx.allocator.create(HTTP.AsyncHTTP) catch unreachable;
- async_http.* = HTTP.AsyncHTTP.initSync(ctx.allocator, .GET, api_url, header_entries, headers_buf, mutable, "", 60 * std.time.ns_per_min, http_proxy, null);
+ async_http.* = HTTP.AsyncHTTP.initSync(ctx.allocator, .GET, api_url, header_entries, headers_buf, mutable, "", 60 * std.time.ns_per_min, http_proxy, null, HTTP.FetchRedirect.follow);
async_http.client.progress_node = progress;
const response = try async_http.sendSync(true);
@@ -1916,7 +1916,7 @@ pub const Example = struct {
// ensure very stable memory address
var async_http: *HTTP.AsyncHTTP = ctx.allocator.create(HTTP.AsyncHTTP) catch unreachable;
- async_http.* = HTTP.AsyncHTTP.initSync(ctx.allocator, .GET, url, .{}, "", mutable, "", 60 * std.time.ns_per_min, http_proxy, null);
+ async_http.* = HTTP.AsyncHTTP.initSync(ctx.allocator, .GET, url, .{}, "", mutable, "", 60 * std.time.ns_per_min, http_proxy, null, HTTP.FetchRedirect.follow);
async_http.client.progress_node = progress;
var response = try async_http.sendSync(true);
@@ -1992,7 +1992,7 @@ pub const Example = struct {
http_proxy = env_loader.getHttpProxy(parsed_tarball_url);
- async_http.* = HTTP.AsyncHTTP.initSync(ctx.allocator, .GET, parsed_tarball_url, .{}, "", mutable, "", 60 * std.time.ns_per_min, http_proxy, null);
+ async_http.* = HTTP.AsyncHTTP.initSync(ctx.allocator, .GET, parsed_tarball_url, .{}, "", mutable, "", 60 * std.time.ns_per_min, http_proxy, null, HTTP.FetchRedirect.follow);
async_http.client.progress_node = progress;
refresher.maybeRefresh();
@@ -2022,7 +2022,7 @@ pub const Example = struct {
var mutable = try ctx.allocator.create(MutableString);
mutable.* = try MutableString.init(ctx.allocator, 2048);
- async_http.* = HTTP.AsyncHTTP.initSync(ctx.allocator, .GET, url, .{}, "", mutable, "", 60 * std.time.ns_per_min, http_proxy, null);
+ async_http.* = HTTP.AsyncHTTP.initSync(ctx.allocator, .GET, url, .{}, "", mutable, "", 60 * std.time.ns_per_min, http_proxy, null, HTTP.FetchRedirect.follow);
if (Output.enable_ansi_colors) {
async_http.client.progress_node = progress_node;
diff --git a/src/cli/upgrade_command.zig b/src/cli/upgrade_command.zig
index 665508833..2f06fe674 100644
--- a/src/cli/upgrade_command.zig
+++ b/src/cli/upgrade_command.zig
@@ -223,7 +223,7 @@ pub const UpgradeCommand = struct {
// ensure very stable memory address
var async_http: *HTTP.AsyncHTTP = allocator.create(HTTP.AsyncHTTP) catch unreachable;
- async_http.* = HTTP.AsyncHTTP.initSync(allocator, .GET, api_url, header_entries, headers_buf, &metadata_body, "", 60 * std.time.ns_per_min, http_proxy, null);
+ async_http.* = HTTP.AsyncHTTP.initSync(allocator, .GET, api_url, header_entries, headers_buf, &metadata_body, "", 60 * std.time.ns_per_min, http_proxy, null, HTTP.FetchRedirect.follow);
if (!silent) async_http.client.progress_node = progress;
const response = try async_http.sendSync(true);
@@ -454,7 +454,7 @@ pub const UpgradeCommand = struct {
var zip_file_buffer = try ctx.allocator.create(MutableString);
zip_file_buffer.* = try MutableString.init(ctx.allocator, @max(version.size, 1024));
- async_http.* = HTTP.AsyncHTTP.initSync(ctx.allocator, .GET, zip_url, .{}, "", zip_file_buffer, "", timeout, http_proxy, null);
+ async_http.* = HTTP.AsyncHTTP.initSync(ctx.allocator, .GET, zip_url, .{}, "", zip_file_buffer, "", timeout, http_proxy, null, HTTP.FetchRedirect.follow);
async_http.client.timeout = timeout;
async_http.client.progress_node = progress;
const response = try async_http.sendSync(true);
diff --git a/src/http_client_async.zig b/src/http_client_async.zig
index c2998f6c6..3f8b6998f 100644
--- a/src/http_client_async.zig
+++ b/src/http_client_async.zig
@@ -59,6 +59,18 @@ var shared_response_headers_buf: [256]picohttp.Header = undefined;
const end_of_chunked_http1_1_encoding_response_body = "0\r\n\r\n";
+pub const FetchRedirect = enum(u8) {
+ follow,
+ manual,
+ @"error",
+
+ pub const Map = bun.ComptimeStringMap(FetchRedirect, .{
+ .{ "follow", .follow },
+ .{ "manual", .manual },
+ .{ "error", .@"error" },
+ });
+};
+
const ProxySSLData = struct {
buffer: std.ArrayList(u8),
partial: bool,
@@ -1011,7 +1023,7 @@ allocator: std.mem.Allocator,
verbose: bool = Environment.isTest,
remaining_redirect_count: i8 = default_redirect_count,
allow_retry: bool = false,
-follow_redirects: bool = true,
+redirect_type: FetchRedirect = FetchRedirect.follow,
redirect: ?*URLBufferPool.Node = null,
timeout: usize = 0,
progress_node: ?*std.Progress.Node = null,
@@ -1264,6 +1276,7 @@ pub const AsyncHTTP = struct {
http_proxy: ?URL,
signal: ?*std.atomic.Atomic(bool),
hostname: ?[]u8,
+ redirect_type: FetchRedirect,
) AsyncHTTP {
var this = AsyncHTTP{ .allocator = allocator, .url = url, .method = method, .request_headers = headers, .request_header_buf = headers_buf, .request_body = request_body, .response_buffer = response_buffer, .completion_callback = callback, .http_proxy = http_proxy, .async_http_id = if (signal != null) async_http_id.fetchAdd(1, .Monotonic) else 0 };
@@ -1271,6 +1284,7 @@ pub const AsyncHTTP = struct {
this.client.async_http_id = this.async_http_id;
this.client.timeout = timeout;
this.client.http_proxy = this.http_proxy;
+ this.client.redirect_type = redirect_type;
this.timeout = timeout;
if (http_proxy) |proxy| {
@@ -1337,8 +1351,8 @@ pub const AsyncHTTP = struct {
return this;
}
- pub fn initSync(allocator: std.mem.Allocator, method: Method, url: URL, headers: Headers.Entries, headers_buf: string, response_buffer: *MutableString, request_body: []const u8, timeout: usize, http_proxy: ?URL, hostname: ?[]u8) AsyncHTTP {
- return @This().init(allocator, method, url, headers, headers_buf, response_buffer, request_body, timeout, undefined, http_proxy, null, hostname);
+ pub fn initSync(allocator: std.mem.Allocator, method: Method, url: URL, headers: Headers.Entries, headers_buf: string, response_buffer: *MutableString, request_body: []const u8, timeout: usize, http_proxy: ?URL, hostname: ?[]u8, redirect_type: FetchRedirect) AsyncHTTP {
+ return @This().init(allocator, method, url, headers, headers_buf, response_buffer, request_body, timeout, undefined, http_proxy, null, hostname, redirect_type);
}
fn reset(this: *AsyncHTTP) !void {
@@ -1617,7 +1631,7 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request {
pub fn doRedirect(this: *HTTPClient) void {
var body_out_str = this.state.body_out_str.?;
this.remaining_redirect_count -|= 1;
- std.debug.assert(this.follow_redirects);
+ std.debug.assert(this.redirect_type == FetchRedirect.follow);
if (this.remaining_redirect_count == 0) {
this.fail(error.TooManyRedirects);
@@ -2713,119 +2727,122 @@ pub fn handleResponseMetadata(
}
const is_redirect = this.state.pending_response.status_code >= 300 and this.state.pending_response.status_code <= 399;
+ if (is_redirect) {
+ if (this.redirect_type == FetchRedirect.follow and location.len > 0 and this.remaining_redirect_count > 0) {
+ switch (this.state.pending_response.status_code) {
+ 302, 301, 307, 308, 303 => {
+ if (strings.indexOf(location, "://")) |i| {
+ var url_buf = URLBufferPool.get(default_allocator);
+
+ const is_protocol_relative = i == 0;
+ const protocol_name = if (is_protocol_relative) this.url.displayProtocol() else location[0..i];
+ const is_http = strings.eqlComptime(protocol_name, "http");
+ if (is_http or strings.eqlComptime(protocol_name, "https")) {} else {
+ return error.UnsupportedRedirectProtocol;
+ }
- if (is_redirect and this.follow_redirects and location.len > 0 and this.remaining_redirect_count > 0) {
- switch (this.state.pending_response.status_code) {
- 302, 301, 307, 308, 303 => {
- if (strings.indexOf(location, "://")) |i| {
- var url_buf = URLBufferPool.get(default_allocator);
-
- const is_protocol_relative = i == 0;
- const protocol_name = if (is_protocol_relative) this.url.displayProtocol() else location[0..i];
- const is_http = strings.eqlComptime(protocol_name, "http");
- if (is_http or strings.eqlComptime(protocol_name, "https")) {} else {
- return error.UnsupportedRedirectProtocol;
- }
-
- if ((protocol_name.len * @as(usize, @boolToInt(is_protocol_relative))) + location.len > url_buf.data.len) {
- return error.RedirectURLTooLong;
- }
+ if ((protocol_name.len * @as(usize, @boolToInt(is_protocol_relative))) + location.len > url_buf.data.len) {
+ return error.RedirectURLTooLong;
+ }
- deferred_redirect.* = this.redirect;
- var url_buf_len = location.len;
- if (is_protocol_relative) {
- if (is_http) {
- url_buf.data[0.."http".len].* = "http".*;
- bun.copy(u8, url_buf.data["http".len..], location);
- url_buf_len += "http".len;
+ deferred_redirect.* = this.redirect;
+ var url_buf_len = location.len;
+ if (is_protocol_relative) {
+ if (is_http) {
+ url_buf.data[0.."http".len].* = "http".*;
+ bun.copy(u8, url_buf.data["http".len..], location);
+ url_buf_len += "http".len;
+ } else {
+ url_buf.data[0.."https".len].* = "https".*;
+ bun.copy(u8, url_buf.data["https".len..], location);
+ url_buf_len += "https".len;
+ }
} else {
- url_buf.data[0.."https".len].* = "https".*;
- bun.copy(u8, url_buf.data["https".len..], location);
- url_buf_len += "https".len;
+ bun.copy(u8, &url_buf.data, location);
}
- } else {
- bun.copy(u8, &url_buf.data, location);
- }
- this.url = URL.parse(url_buf.data[0..url_buf_len]);
- this.redirect = url_buf;
- } else if (strings.hasPrefixComptime(location, "//")) {
- var url_buf = URLBufferPool.get(default_allocator);
+ this.url = URL.parse(url_buf.data[0..url_buf_len]);
+ this.redirect = url_buf;
+ } else if (strings.hasPrefixComptime(location, "//")) {
+ var url_buf = URLBufferPool.get(default_allocator);
- const protocol_name = this.url.displayProtocol();
+ const protocol_name = this.url.displayProtocol();
- if (protocol_name.len + 1 + location.len > url_buf.data.len) {
- return error.RedirectURLTooLong;
- }
+ if (protocol_name.len + 1 + location.len > url_buf.data.len) {
+ return error.RedirectURLTooLong;
+ }
- deferred_redirect.* = this.redirect;
- var url_buf_len = location.len;
+ deferred_redirect.* = this.redirect;
+ var url_buf_len = location.len;
- if (strings.eqlComptime(protocol_name, "http")) {
- url_buf.data[0.."http:".len].* = "http:".*;
- bun.copy(u8, url_buf.data["http:".len..], location);
- url_buf_len += "http:".len;
- } else {
- url_buf.data[0.."https:".len].* = "https:".*;
- bun.copy(u8, url_buf.data["https:".len..], location);
- url_buf_len += "https:".len;
- }
+ if (strings.eqlComptime(protocol_name, "http")) {
+ url_buf.data[0.."http:".len].* = "http:".*;
+ bun.copy(u8, url_buf.data["http:".len..], location);
+ url_buf_len += "http:".len;
+ } else {
+ url_buf.data[0.."https:".len].* = "https:".*;
+ bun.copy(u8, url_buf.data["https:".len..], location);
+ url_buf_len += "https:".len;
+ }
- this.url = URL.parse(url_buf.data[0..url_buf_len]);
- this.redirect = url_buf;
- } else {
- var url_buf = URLBufferPool.get(default_allocator);
- const original_url = this.url;
- const port = original_url.getPortAuto();
-
- if (port == original_url.getDefaultPort()) {
- this.url = URL.parse(std.fmt.bufPrint(
- &url_buf.data,
- "{s}://{s}{s}",
- .{ original_url.displayProtocol(), original_url.displayHostname(), location },
- ) catch return error.RedirectURLTooLong);
+ this.url = URL.parse(url_buf.data[0..url_buf_len]);
+ this.redirect = url_buf;
} else {
- this.url = URL.parse(std.fmt.bufPrint(
- &url_buf.data,
- "{s}://{s}:{d}{s}",
- .{ original_url.displayProtocol(), original_url.displayHostname(), port, location },
- ) catch return error.RedirectURLTooLong);
- }
+ var url_buf = URLBufferPool.get(default_allocator);
+ const original_url = this.url;
+ const port = original_url.getPortAuto();
+
+ if (port == original_url.getDefaultPort()) {
+ this.url = URL.parse(std.fmt.bufPrint(
+ &url_buf.data,
+ "{s}://{s}{s}",
+ .{ original_url.displayProtocol(), original_url.displayHostname(), location },
+ ) catch return error.RedirectURLTooLong);
+ } else {
+ this.url = URL.parse(std.fmt.bufPrint(
+ &url_buf.data,
+ "{s}://{s}:{d}{s}",
+ .{ original_url.displayProtocol(), original_url.displayHostname(), port, location },
+ ) catch return error.RedirectURLTooLong);
+ }
- deferred_redirect.* = this.redirect;
- this.redirect = url_buf;
- }
+ deferred_redirect.* = this.redirect;
+ this.redirect = url_buf;
+ }
- // Note: RFC 1945 and RFC 2068 specify that the client is not allowed to change
- // the method on the redirected request. However, most existing user agent
- // implementations treat 302 as if it were a 303 response, performing a GET on
- // the Location field-value regardless of the original request method. The
- // status codes 303 and 307 have been added for servers that wish to make
- // unambiguously clear which kind of reaction is expected of the client.
- if (response.status_code == 302) {
- switch (this.method) {
- .GET, .HEAD => {},
- else => {
- this.method = .GET;
- },
+ // Note: RFC 1945 and RFC 2068 specify that the client is not allowed to change
+ // the method on the redirected request. However, most existing user agent
+ // implementations treat 302 as if it were a 303 response, performing a GET on
+ // the Location field-value regardless of the original request method. The
+ // status codes 303 and 307 have been added for servers that wish to make
+ // unambiguously clear which kind of reaction is expected of the client.
+ if (response.status_code == 302) {
+ switch (this.method) {
+ .GET, .HEAD => {},
+ else => {
+ this.method = .GET;
+ },
+ }
}
- }
- // https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/303
- if (response.status_code == 303 and this.method != .HEAD) {
- this.method = .GET;
- }
+ // https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/303
+ if (response.status_code == 303 and this.method != .HEAD) {
+ this.method = .GET;
+ }
- return error.Redirect;
- },
- else => {},
+ return error.Redirect;
+ },
+ else => {},
+ }
+ } else if (this.redirect_type == FetchRedirect.@"error") {
+ return error.UnexpectedRedirect;
+ } else if (this.redirect_type == FetchRedirect.manual) {
+ this.state.response_stage = if (this.state.transfer_encoding == .chunked) .body_chunk else .body;
+ return false;
}
}
this.state.response_stage = if (this.state.transfer_encoding == .chunked) .body_chunk else .body;
- if (is_redirect and !this.follow_redirects)
- return true;
-
return this.method.hasBody() and (this.state.body_size > 0 or this.state.transfer_encoding == .chunked);
}
diff --git a/src/install/install.zig b/src/install/install.zig
index 810bc4afb..0d0b8243d 100644
--- a/src/install/install.zig
+++ b/src/install/install.zig
@@ -340,6 +340,7 @@ const NetworkTask = struct {
this.package_manager.httpProxy(url),
null,
null,
+ HTTP.FetchRedirect.follow,
);
this.callback = .{
.package_manifest = .{
@@ -417,6 +418,7 @@ const NetworkTask = struct {
this.package_manager.httpProxy(url),
null,
null,
+ HTTP.FetchRedirect.follow,
);
this.callback = .{ .extract = tarball };
}
diff --git a/test/js/web/fetch/fetch.test.ts b/test/js/web/fetch/fetch.test.ts
index 9419629d1..c3327f37e 100644
--- a/test/js/web/fetch/fetch.test.ts
+++ b/test/js/web/fetch/fetch.test.ts
@@ -343,6 +343,27 @@ describe("fetch", () => {
expect(response.redirected).toBe(true);
});
+ it('redirect: "error" #2819', async () => {
+ startServer({
+ fetch(req) {
+ return new Response(null, {
+ status: 302,
+ headers: {
+ Location: "https://example.com",
+ },
+ });
+ },
+ });
+ try {
+ const response = await fetch(`http://${server.hostname}:${server.port}`, {
+ redirect: "error",
+ });
+ expect(response).toBeUndefined();
+ } catch (err: any) {
+ expect(err.code).toBe("UnexpectedRedirect");
+ }
+ });
+
it("provide body", async () => {
startServer({
fetch(req) {