diff options
| -rw-r--r-- | src/bun.js/webcore/response.zig | 14 | ||||
| -rw-r--r-- | src/http_client_async.zig | 9 | ||||
| -rw-r--r-- | test/bun.js/fetch.test.js | 42 | 
3 files changed, 64 insertions, 1 deletions
| diff --git a/src/bun.js/webcore/response.zig b/src/bun.js/webcore/response.zig index 6744df60f..d68d4b6d2 100644 --- a/src/bun.js/webcore/response.zig +++ b/src/bun.js/webcore/response.zig @@ -732,6 +732,11 @@ pub const Fetch = struct {                      fetch_tasklet,                  ),              ); + +            if (!fetch_options.follow_redirects) { +                fetch_tasklet.http.?.client.remaining_redirect_count = 0; +            } +              fetch_tasklet.http.?.client.disable_timeout = fetch_options.disable_timeout;              fetch_tasklet.http.?.client.verbose = fetch_options.verbose;              fetch_tasklet.http.?.client.disable_keepalive = fetch_options.disable_keepalive; @@ -747,6 +752,7 @@ pub const Fetch = struct {              disable_keepalive: bool,              url: ZigURL,              verbose: bool = false, +            follow_redirects: bool = true,          };          pub fn queue( @@ -807,6 +813,7 @@ pub const Fetch = struct {          var disable_timeout = false;          var disable_keepalive = false;          var verbose = false; +        var follow_redirects = true;          if (first_arg.as(Request)) |request| {              url = ZigURL.parse(getAllocator(ctx).dupe(u8, request.url) catch unreachable);              method = request.method; @@ -865,6 +872,12 @@ pub const Fetch = struct {                          }                      } +                    if (options.get(ctx, "redirect")) |redirect_value| { +                        if (redirect_value.getZigString(globalThis).eqlComptime("manual")) { +                            follow_redirects = false; +                        } +                    } +                      if (options.get(ctx, "keepalive")) |keepalive_value| {                          if (keepalive_value.isBoolean()) {                              disable_keepalive = !keepalive_value.asBoolean(); @@ -906,6 +919,7 @@ pub const Fetch = struct {                  .timeout = std.time.ns_per_hour,                  .disable_keepalive = disable_keepalive,                  .disable_timeout = disable_timeout, +                .follow_redirects = follow_redirects,                  .verbose = verbose,              },              JSC.JSValue.fromRef(deferred_promise), diff --git a/src/http_client_async.zig b/src/http_client_async.zig index 5f6c68cf0..11150ef50 100644 --- a/src/http_client_async.zig +++ b/src/http_client_async.zig @@ -750,6 +750,7 @@ allocator: std.mem.Allocator,  verbose: bool = Environment.isTest,  remaining_redirect_count: i8 = default_redirect_count,  allow_retry: bool = false, +follow_redirects: bool = true,  redirect: ?*URLBufferPool.Node = null,  timeout: usize = 0,  progress_node: ?*std.Progress.Node = null, @@ -1208,6 +1209,7 @@ pub fn buildRequest(this: *HTTPClient, body_len: usize) picohttp.Request {  pub fn doRedirect(this: *HTTPClient) void {      var body_out_str = this.state.body_out_str.?;      this.remaining_redirect_count -|= 1; +    std.debug.assert(this.follow_redirects);      if (this.remaining_redirect_count == 0) {          this.fail(error.TooManyRedirects); @@ -1963,7 +1965,9 @@ pub fn handleResponseMetadata(          this.state.pending_response.status_code = 304;      } -    if (location.len > 0 and this.remaining_redirect_count > 0) { +    const is_redirect = this.state.pending_response.status_code >= 300 and this.state.pending_response.status_code <= 399; + +    if (location.len > 0 and this.remaining_redirect_count > 0 and this.follow_redirects) {          switch (this.state.pending_response.status_code) {              302, 301, 307, 308, 303 => {                  if (strings.indexOf(location, "://")) |i| { @@ -2058,6 +2062,9 @@ pub fn handleResponseMetadata(      this.state.response_stage = if (this.state.transfer_encoding == .chunked) .body_chunk else .body; +    if (is_redirect and !this.follow_redirects) +        return true; +      return this.method.hasBody() and (this.state.body_size > 0 or this.state.transfer_encoding == .chunked);  } diff --git a/test/bun.js/fetch.test.js b/test/bun.js/fetch.test.js index a703c955a..519d6bbdd 100644 --- a/test/bun.js/fetch.test.js +++ b/test/bun.js/fetch.test.js @@ -60,6 +60,48 @@ describe("fetch", () => {        expect(exampleFixture).toBe(text);      });    } + +  it(`"redirect: "manual"`, async () => { +    const server = Bun.serve({ +      port: 4082, +      fetch(req) { +        return new Response(null, { +          status: 302, +          headers: { +            Location: "https://example.com", +          }, +        }); +      }, +    }); +    const response = await fetch(`http://${server.hostname}:${server.port}`, { +      redirect: "manual", +    }); +    expect(response.status).toBe(302); +    expect(response.headers.get("location")).toBe("https://example.com"); +    expect(response.redirected).toBe(true); +    server.stop(); +  }); + +  it(`"redirect: "follow"`, async () => { +    const server = Bun.serve({ +      port: 4083, +      fetch(req) { +        return new Response(null, { +          status: 302, +          headers: { +            Location: "https://example.com", +          }, +        }); +      }, +    }); +    const response = await fetch(`http://${server.hostname}:${server.port}`, { +      redirect: "follow", +    }); +    expect(response.status).toBe(200); +    expect(response.headers.get("location")).toBe(null); +    expect(response.redirected).toBe(true); +    server.stop(); +  });  });  it("simultaneous HTTPS fetch", async () => { | 
