Skip to content

Commit df87a86

Browse files
Lukasadnadoba
andauthored
Limit max recursion depth delivering body parts (#611)
Motivation When receiving certain patterns of response body parts, we can end up recursing almost indefinitely to deliver them to the application. This can lead to crashes, so we might politely describe it as "sub-optimal". Modifications Keep track of our stack depth and avoid creating too many stack frames. Added some unit tests. Result We no longer explode when handling bodies with lots of tiny parts. Co-authored-by: David Nadoba <[email protected]>
1 parent 46d1c76 commit df87a86

File tree

6 files changed

+81
-3
lines changed

6 files changed

+81
-3
lines changed

Sources/AsyncHTTPClient/RequestBag.swift

+38-3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ import NIOHTTP1
1919
import NIOSSL
2020

2121
final class RequestBag<Delegate: HTTPClientResponseDelegate> {
22+
/// Defends against the call stack getting too large when consuming body parts.
23+
///
24+
/// If the response body comes in lots of tiny chunks, we'll deliver those tiny chunks to users
25+
/// one at a time.
26+
private static var maxConsumeBodyPartStackDepth: Int {
27+
50
28+
}
29+
2230
let task: HTTPClient.Task<Delegate.Response>
2331
var eventLoop: EventLoop {
2432
self.task.eventLoop
@@ -30,6 +38,9 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
3038
// the request state is synchronized on the task eventLoop
3139
private var state: StateMachine
3240

41+
// the consume body part stack depth is synchronized on the task event loop.
42+
private var consumeBodyPartStackDepth: Int
43+
3344
// MARK: HTTPClientTask properties
3445

3546
var logger: Logger {
@@ -55,6 +66,7 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
5566
self.eventLoopPreference = eventLoopPreference
5667
self.task = task
5768
self.state = .init(redirectHandler: redirectHandler)
69+
self.consumeBodyPartStackDepth = 0
5870
self.request = request
5971
self.connectionDeadline = connectionDeadline
6072
self.requestOptions = requestOptions
@@ -290,16 +302,39 @@ final class RequestBag<Delegate: HTTPClientResponseDelegate> {
290302
private func consumeMoreBodyData0(resultOfPreviousConsume result: Result<Void, Error>) {
291303
self.task.eventLoop.assertInEventLoop()
292304

305+
// We get defensive here about the maximum stack depth. It's possible for the `didReceiveBodyPart`
306+
// future to be returned to us completed. If it is, we will recurse back into this method. To
307+
// break that recursion we have a max stack depth which we increment and decrement in this method:
308+
// if it gets too large, instead of recurring we'll insert an `eventLoop.execute`, which will
309+
// manually break the recursion and unwind the stack.
310+
//
311+
// Note that we don't bother starting this at the various other call sites that _begin_ stacks
312+
// that risk ending up in this loop. That's because we don't need an accurate count: our limit is
313+
// a best-effort target anyway, one stack frame here or there does not put us at risk. We're just
314+
// trying to prevent ourselves looping out of control.
315+
self.consumeBodyPartStackDepth += 1
316+
defer {
317+
self.consumeBodyPartStackDepth -= 1
318+
assert(self.consumeBodyPartStackDepth >= 0)
319+
}
320+
293321
let consumptionAction = self.state.consumeMoreBodyData(resultOfPreviousConsume: result)
294322

295323
switch consumptionAction {
296324
case .consume(let byteBuffer):
297325
self.delegate.didReceiveBodyPart(task: self.task, byteBuffer)
298326
.hop(to: self.task.eventLoop)
299-
.whenComplete {
300-
switch $0 {
327+
.whenComplete { result in
328+
switch result {
301329
case .success:
302-
self.consumeMoreBodyData0(resultOfPreviousConsume: $0)
330+
if self.consumeBodyPartStackDepth < Self.maxConsumeBodyPartStackDepth {
331+
self.consumeMoreBodyData0(resultOfPreviousConsume: result)
332+
} else {
333+
// We need to unwind the stack, let's take a break.
334+
self.task.eventLoop.execute {
335+
self.consumeMoreBodyData0(resultOfPreviousConsume: result)
336+
}
337+
}
303338
case .failure(let error):
304339
self.fail(error)
305340
}

Tests/AsyncHTTPClientTests/HTTP2ClientTests+XCTest.swift

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ extension HTTP2ClientTests {
3737
("testH2CanHandleRequestsThatHaveAlreadyHitTheDeadline", testH2CanHandleRequestsThatHaveAlreadyHitTheDeadline),
3838
("testStressCancelingRunningRequestFromDifferentThreads", testStressCancelingRunningRequestFromDifferentThreads),
3939
("testPlatformConnectErrorIsForwardedOnTimeout", testPlatformConnectErrorIsForwardedOnTimeout),
40+
("testMassiveDownload", testMassiveDownload),
4041
]
4142
}
4243
}

Tests/AsyncHTTPClientTests/HTTP2ClientTests.swift

+13
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,19 @@ class HTTP2ClientTests: XCTestCase {
432432
)
433433
}
434434
}
435+
436+
func testMassiveDownload() {
437+
let bin = HTTPBin(.http2(compress: false))
438+
defer { XCTAssertNoThrow(try bin.shutdown()) }
439+
let client = self.makeDefaultHTTPClient()
440+
defer { XCTAssertNoThrow(try client.syncShutdown()) }
441+
var response: HTTPClient.Response?
442+
XCTAssertNoThrow(response = try client.get(url: "https://localhost:\(bin.port)/mega-chunked").wait())
443+
444+
XCTAssertEqual(.ok, response?.status)
445+
XCTAssertEqual(response?.version, .http2)
446+
XCTAssertEqual(response?.body?.readableBytes, 10_000)
447+
}
435448
}
436449

437450
private final class HeadReceivedCallback: HTTPClientResponseDelegate {

Tests/AsyncHTTPClientTests/HTTPClientTestUtils.swift

+19
Original file line numberDiff line numberDiff line change
@@ -745,6 +745,22 @@ internal final class HTTPBinHandler: ChannelInboundHandler {
745745
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
746746
}
747747

748+
func writeManyChunks(context: ChannelHandlerContext) {
749+
// This tests receiving a lot of tiny chunks: they must all be sent in a single flush or the test doesn't work.
750+
let headers = HTTPHeaders([("Transfer-Encoding", "chunked")])
751+
752+
context.write(self.wrapOutboundOut(.head(HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok, headers: headers))), promise: nil)
753+
let message = ByteBuffer(integer: UInt8(ascii: "a"))
754+
755+
// This number (10k) is load-bearing and a bit magic: it has been experimentally verified as being sufficient to blow the stack
756+
// in the old implementation on all testing platforms. Please don't change it without good reason.
757+
for _ in 0..<10_000 {
758+
context.write(wrapOutboundOut(.body(.byteBuffer(message))), promise: nil)
759+
}
760+
761+
context.writeAndFlush(self.wrapOutboundOut(.end(nil)), promise: nil)
762+
}
763+
748764
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
749765
self.isServingRequest = true
750766
switch self.unwrapInboundIn(data) {
@@ -863,6 +879,9 @@ internal final class HTTPBinHandler: ChannelInboundHandler {
863879
case "/chunked":
864880
self.writeChunked(context: context)
865881
return
882+
case "/mega-chunked":
883+
self.writeManyChunks(context: context)
884+
return
866885
case "/close-on-response":
867886
var headers = self.responseHeaders
868887
headers.replaceOrAdd(name: "connection", value: "close")

Tests/AsyncHTTPClientTests/HTTPClientTests+XCTest.swift

+1
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ extension HTTPClientTests {
142142
("testRequestSpecificTLS", testRequestSpecificTLS),
143143
("testConnectionPoolSizeConfigValueIsRespected", testConnectionPoolSizeConfigValueIsRespected),
144144
("testRequestWithHeaderTransferEncodingIdentityDoesNotFail", testRequestWithHeaderTransferEncodingIdentityDoesNotFail),
145+
("testMassiveDownload", testMassiveDownload),
145146
]
146147
}
147148
}

Tests/AsyncHTTPClientTests/HTTPClientTests.swift

+9
Original file line numberDiff line numberDiff line change
@@ -3454,4 +3454,13 @@ class HTTPClientTests: XCTestCase {
34543454

34553455
XCTAssertNoThrow(try client.execute(request: request).wait())
34563456
}
3457+
3458+
func testMassiveDownload() {
3459+
var response: HTTPClient.Response?
3460+
XCTAssertNoThrow(response = try self.defaultClient.get(url: "\(self.defaultHTTPBinURLPrefix)mega-chunked").wait())
3461+
3462+
XCTAssertEqual(.ok, response?.status)
3463+
XCTAssertEqual(response?.version, .http1_1)
3464+
XCTAssertEqual(response?.body?.readableBytes, 10_000)
3465+
}
34573466
}

0 commit comments

Comments
 (0)