From ec68898de69ba5290ce35025859ab2f8392ffcdd Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Sun, 12 Jan 2025 10:20:11 +0000 Subject: [PATCH 1/4] Add cancellation handling for nextInvocation() --- Sources/AWSLambdaRuntimeCore/Lambda.swift | 38 ++++++++------- .../LambdaRuntimeClient.swift | 46 +++++++++++++++++-- .../LambdaRuntimeClientTests.swift | 20 ++++++++ 3 files changed, 84 insertions(+), 20 deletions(-) diff --git a/Sources/AWSLambdaRuntimeCore/Lambda.swift b/Sources/AWSLambdaRuntimeCore/Lambda.swift index 3ba90e9c..4634fca0 100644 --- a/Sources/AWSLambdaRuntimeCore/Lambda.swift +++ b/Sources/AWSLambdaRuntimeCore/Lambda.swift @@ -37,25 +37,31 @@ public enum Lambda { ) async throws where Handler: StreamingLambdaHandler { var handler = handler - while !Task.isCancelled { - let (invocation, writer) = try await runtimeClient.nextInvocation() + do { + while !Task.isCancelled { + let (invocation, writer) = try await runtimeClient.nextInvocation() - do { - try await handler.handle( - invocation.event, - responseWriter: writer, - context: LambdaContext( - requestID: invocation.metadata.requestID, - traceID: invocation.metadata.traceID, - invokedFunctionARN: invocation.metadata.invokedFunctionARN, - deadline: DispatchWallTime(millisSinceEpoch: invocation.metadata.deadlineInMillisSinceEpoch), - logger: logger + do { + try await handler.handle( + invocation.event, + responseWriter: writer, + context: LambdaContext( + requestID: invocation.metadata.requestID, + traceID: invocation.metadata.traceID, + invokedFunctionARN: invocation.metadata.invokedFunctionARN, + deadline: DispatchWallTime( + millisSinceEpoch: invocation.metadata.deadlineInMillisSinceEpoch + ), + logger: logger + ) ) - ) - } catch { - try await writer.reportError(error) - continue + } catch { + try await writer.reportError(error) + continue + } } + } catch is CancellationError { + // don't allow cancellation error to propagate further } } diff --git a/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift b/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift index 228dc471..a264f1dd 100644 --- a/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift +++ b/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift @@ -410,6 +410,17 @@ private protocol LambdaChannelHandlerDelegate { func connectionErrorHappened(_ error: any Error, channel: any Channel) } +struct UnsafeContext: @unchecked Sendable { + private let _context: ChannelHandlerContext + var context: ChannelHandlerContext { + self._context.eventLoop.preconditionInEventLoop() + return _context + } + init(_ context: ChannelHandlerContext) { + self._context = context + } +} + private final class LambdaChannelHandler { let nextInvocationPath = Consts.invocationURLPrefix + Consts.getNextInvocationURLSuffix @@ -469,10 +480,37 @@ private final class LambdaChannelHandler func nextInvocation(isolation: isolated (any Actor)? = #isolation) async throws -> Invocation { switch self.state { case .connected(let context, .idle): - return try await withCheckedThrowingContinuation { - (continuation: CheckedContinuation) in - self.state = .connected(context, .waitingForNextInvocation(continuation)) - self.sendNextRequest(context: context) + return try await withTaskCancellationHandler { + try Task.checkCancellation() + return try await withCheckedThrowingContinuation { + (continuation: CheckedContinuation) in + self.state = .connected(context, .waitingForNextInvocation(continuation)) + + let unsafeContext = UnsafeContext(context) + context.eventLoop.execute { [nextInvocationPath, defaultHeaders] in + // Send next request. The function `sendNextRequest` requires `self` which is not + // Sendable so just inlined the code instead + let httpRequest = HTTPRequestHead( + version: .http1_1, + method: .GET, + uri: nextInvocationPath, + headers: defaultHeaders + ) + let context = unsafeContext.context + context.write(Self.wrapOutboundOut(.head(httpRequest)), promise: nil) + context.write(Self.wrapOutboundOut(.end(nil)), promise: nil) + context.flush() + } + } + } onCancel: { + switch self.state { + case .connected(_, .waitingForNextInvocation(let continuation)): + continuation.resume(throwing: CancellationError()) + case .connected(_, .idle): + break + default: + fatalError("Invalid state: \(self.state)") + } } case .connected(_, .sendingResponse), diff --git a/Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift b/Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift index e779b931..c9679c6e 100644 --- a/Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift +++ b/Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift @@ -86,4 +86,24 @@ struct LambdaRuntimeClientTests { } } } + + @Test + func testCancellation() async throws { + try await LambdaRuntimeClient.withRuntimeClient( + configuration: .init(ip: "127.0.0.1", port: 7000), + eventLoop: NIOSingletons.posixEventLoopGroup.next(), + logger: self.logger + ) { runtimeClient in + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + while true { + _ = try await runtimeClient.nextInvocation() + } + } + // wait a small amount to ensure we are waiting for continuation + try await Task.sleep(for: .milliseconds(100)) + group.cancelAll() + } + } + } } From 467d0124561f4e0593329d8d18d48f1508594df5 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Fri, 21 Feb 2025 08:16:41 +0000 Subject: [PATCH 2/4] Use NIOLoopBound --- .../LambdaRuntimeClient.swift | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift b/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift index a264f1dd..495627ca 100644 --- a/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift +++ b/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift @@ -410,17 +410,6 @@ private protocol LambdaChannelHandlerDelegate { func connectionErrorHappened(_ error: any Error, channel: any Channel) } -struct UnsafeContext: @unchecked Sendable { - private let _context: ChannelHandlerContext - var context: ChannelHandlerContext { - self._context.eventLoop.preconditionInEventLoop() - return _context - } - init(_ context: ChannelHandlerContext) { - self._context = context - } -} - private final class LambdaChannelHandler { let nextInvocationPath = Consts.invocationURLPrefix + Consts.getNextInvocationURLSuffix @@ -486,7 +475,7 @@ private final class LambdaChannelHandler (continuation: CheckedContinuation) in self.state = .connected(context, .waitingForNextInvocation(continuation)) - let unsafeContext = UnsafeContext(context) + let unsafeContext = NIOLoopBound(context, eventLoop: context.eventLoop) context.eventLoop.execute { [nextInvocationPath, defaultHeaders] in // Send next request. The function `sendNextRequest` requires `self` which is not // Sendable so just inlined the code instead @@ -496,7 +485,7 @@ private final class LambdaChannelHandler uri: nextInvocationPath, headers: defaultHeaders ) - let context = unsafeContext.context + let context = unsafeContext.value context.write(Self.wrapOutboundOut(.head(httpRequest)), promise: nil) context.write(Self.wrapOutboundOut(.end(nil)), promise: nil) context.flush() From 27ac22da9be4904792c67b150208558d22ceb814 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Fri, 21 Feb 2025 09:21:34 +0000 Subject: [PATCH 3/4] Move cancellation to LambdaRuntimeClient.nextInvocation --- .../LambdaRuntimeClient.swift | 75 ++++++++----------- .../LambdaRuntimeClientTests.swift | 56 +++++++++++--- 2 files changed, 74 insertions(+), 57 deletions(-) diff --git a/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift b/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift index 495627ca..809a6a0e 100644 --- a/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift +++ b/Sources/AWSLambdaRuntimeCore/LambdaRuntimeClient.swift @@ -145,22 +145,28 @@ final actor LambdaRuntimeClient: LambdaRuntimeClientProtocol { } func nextInvocation() async throws -> (Invocation, Writer) { - switch self.lambdaState { - case .idle: - self.lambdaState = .waitingForNextInvocation - let handler = try await self.makeOrGetConnection() - let invocation = try await handler.nextInvocation() - guard case .waitingForNextInvocation = self.lambdaState else { + try await withTaskCancellationHandler { + switch self.lambdaState { + case .idle: + self.lambdaState = .waitingForNextInvocation + let handler = try await self.makeOrGetConnection() + let invocation = try await handler.nextInvocation() + guard case .waitingForNextInvocation = self.lambdaState else { + fatalError("Invalid state: \(self.lambdaState)") + } + self.lambdaState = .waitingForResponse(requestID: invocation.metadata.requestID) + return (invocation, Writer(runtimeClient: self)) + + case .waitingForNextInvocation, + .waitingForResponse, + .sendingResponse, + .sentResponse: fatalError("Invalid state: \(self.lambdaState)") } - self.lambdaState = .waitingForResponse(requestID: invocation.metadata.requestID) - return (invocation, Writer(runtimeClient: self)) - - case .waitingForNextInvocation, - .waitingForResponse, - .sendingResponse, - .sentResponse: - fatalError("Invalid state: \(self.lambdaState)") + } onCancel: { + Task { + await self.close() + } } } @@ -469,37 +475,10 @@ private final class LambdaChannelHandler func nextInvocation(isolation: isolated (any Actor)? = #isolation) async throws -> Invocation { switch self.state { case .connected(let context, .idle): - return try await withTaskCancellationHandler { - try Task.checkCancellation() - return try await withCheckedThrowingContinuation { - (continuation: CheckedContinuation) in - self.state = .connected(context, .waitingForNextInvocation(continuation)) - - let unsafeContext = NIOLoopBound(context, eventLoop: context.eventLoop) - context.eventLoop.execute { [nextInvocationPath, defaultHeaders] in - // Send next request. The function `sendNextRequest` requires `self` which is not - // Sendable so just inlined the code instead - let httpRequest = HTTPRequestHead( - version: .http1_1, - method: .GET, - uri: nextInvocationPath, - headers: defaultHeaders - ) - let context = unsafeContext.value - context.write(Self.wrapOutboundOut(.head(httpRequest)), promise: nil) - context.write(Self.wrapOutboundOut(.end(nil)), promise: nil) - context.flush() - } - } - } onCancel: { - switch self.state { - case .connected(_, .waitingForNextInvocation(let continuation)): - continuation.resume(throwing: CancellationError()) - case .connected(_, .idle): - break - default: - fatalError("Invalid state: \(self.state)") - } + return try await withCheckedThrowingContinuation { + (continuation: CheckedContinuation) in + self.state = .connected(context, .waitingForNextInvocation(continuation)) + self.sendNextRequest(context: context) } case .connected(_, .sendingResponse), @@ -846,6 +825,12 @@ extension LambdaChannelHandler: ChannelInboundHandler { func channelInactive(context: ChannelHandlerContext) { // fail any pending responses with last error or assume peer disconnected + switch self.state { + case .connected(_, .waitingForNextInvocation(let continuation)): + continuation.resume(throwing: self.lastError ?? ChannelError.ioOnClosedChannel) + default: + break + } // we don't need to forward channelInactive to the delegate, as the delegate observes the // closeFuture diff --git a/Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift b/Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift index c9679c6e..623c04f4 100644 --- a/Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift +++ b/Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift @@ -89,20 +89,52 @@ struct LambdaRuntimeClientTests { @Test func testCancellation() async throws { - try await LambdaRuntimeClient.withRuntimeClient( - configuration: .init(ip: "127.0.0.1", port: 7000), - eventLoop: NIOSingletons.posixEventLoopGroup.next(), - logger: self.logger - ) { runtimeClient in - try await withThrowingTaskGroup(of: Void.self) { group in - group.addTask { - while true { - _ = try await runtimeClient.nextInvocation() + struct HappyBehavior: LambdaServerBehavior { + let requestId = UUID().uuidString + let event = "hello" + + func getInvocation() -> GetInvocationResult { + .success((self.requestId, self.event)) + } + + func processResponse(requestId: String, response: String?) -> Result { + #expect(self.requestId == requestId) + #expect(self.event == response) + return .success(()) + } + + func processError(requestId: String, error: ErrorResponse) -> Result { + Issue.record("should not report error") + return .failure(.internalServerError) + } + + func processInitError(error: ErrorResponse) -> Result { + Issue.record("should not report init error") + return .failure(.internalServerError) + } + } + + try await withMockServer(behaviour: HappyBehavior()) { port in + try await LambdaRuntimeClient.withRuntimeClient( + configuration: .init(ip: "127.0.0.1", port: port), + eventLoop: NIOSingletons.posixEventLoopGroup.next(), + logger: self.logger + ) { runtimeClient in + try await withThrowingTaskGroup(of: Void.self) { group in + group.addTask { + while true { + print("Waiting") + let (_, writer) = try await runtimeClient.nextInvocation() + try await Task { + try await writer.write(ByteBuffer(string: "hello")) + try await writer.finish() + }.value + } } + // wait a small amount to ensure we are waiting for continuation + try await Task.sleep(for: .milliseconds(100)) + group.cancelAll() } - // wait a small amount to ensure we are waiting for continuation - try await Task.sleep(for: .milliseconds(100)) - group.cancelAll() } } } From a1ee6c1135867a6f14ece1f020f384fb09677628 Mon Sep 17 00:00:00 2001 From: Adam Fowler Date: Fri, 21 Feb 2025 09:28:18 +0000 Subject: [PATCH 4/4] Add comment to test --- Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift b/Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift index 623c04f4..5d430a0f 100644 --- a/Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift +++ b/Tests/AWSLambdaRuntimeCoreTests/LambdaRuntimeClientTests.swift @@ -123,8 +123,8 @@ struct LambdaRuntimeClientTests { try await withThrowingTaskGroup(of: Void.self) { group in group.addTask { while true { - print("Waiting") let (_, writer) = try await runtimeClient.nextInvocation() + // Wrap this is a task so cancellation isn't propagated to the write calls try await Task { try await writer.write(ByteBuffer(string: "hello")) try await writer.finish()