|
| 1 | +//===----------------------------------------------------------------------===// |
| 2 | +// |
| 3 | +// This source file is part of the SwiftAWSLambdaRuntime open source project |
| 4 | +// |
| 5 | +// Copyright (c) 2020 Apple Inc. and the SwiftAWSLambdaRuntime project authors |
| 6 | +// Licensed under Apache License v2.0 |
| 7 | +// |
| 8 | +// See LICENSE.txt for license information |
| 9 | +// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors |
| 10 | +// |
| 11 | +// SPDX-License-Identifier: Apache-2.0 |
| 12 | +// |
| 13 | +//===----------------------------------------------------------------------===// |
| 14 | + |
| 15 | +import Dispatch |
| 16 | +import Logging |
| 17 | +import NIO |
| 18 | +import NIOConcurrencyHelpers |
| 19 | +import NIOHTTP1 |
| 20 | + |
| 21 | +// This functionality is designed for local testing hence beind a #if DEBUG flag. |
| 22 | +// For example: |
| 23 | +// |
| 24 | +// try Lambda.withLocalServer { |
| 25 | +// Lambda.run { (context: Lambda.Context, payload: String, callback: @escaping (Result<String, Error>) -> Void) in |
| 26 | +// callback(.success("Hello, \(payload)!")) |
| 27 | +// } |
| 28 | +// } |
| 29 | + |
| 30 | +#if DEBUG |
| 31 | +extension Lambda { |
| 32 | + /// Execute code in the context of a mock Lambda server. |
| 33 | + /// |
| 34 | + /// - parameters: |
| 35 | + /// - invocationEndpoint: The endpoint to post payloads to. |
| 36 | + /// - body: Code to run within the context of the mock server. Typically this would be a Lambda.run function call. |
| 37 | + /// |
| 38 | + /// - note: This API is designed stricly for local testing and is behind a DEBUG flag |
| 39 | + public static func withLocalServer(invocationEndpoint: String? = nil, _ body: @escaping () -> Void) throws { |
| 40 | + let server = LocalLambda.Server(invocationEndpoint: invocationEndpoint) |
| 41 | + try server.start().wait() |
| 42 | + defer { try! server.stop() } // FIXME: |
| 43 | + body() |
| 44 | + } |
| 45 | +} |
| 46 | + |
| 47 | +// MARK: - Local Mock Server |
| 48 | + |
| 49 | +private enum LocalLambda { |
| 50 | + struct Server { |
| 51 | + private let logger: Logger |
| 52 | + private let group: EventLoopGroup |
| 53 | + private let host: String |
| 54 | + private let port: Int |
| 55 | + private let invocationEndpoint: String |
| 56 | + |
| 57 | + public init(invocationEndpoint: String?) { |
| 58 | + let configuration = Lambda.Configuration() |
| 59 | + var logger = Logger(label: "LocalLambdaServer") |
| 60 | + logger.logLevel = configuration.general.logLevel |
| 61 | + self.logger = logger |
| 62 | + self.group = MultiThreadedEventLoopGroup(numberOfThreads: System.coreCount) |
| 63 | + self.host = configuration.runtimeEngine.ip |
| 64 | + self.port = configuration.runtimeEngine.port |
| 65 | + self.invocationEndpoint = invocationEndpoint ?? "/invoke" |
| 66 | + } |
| 67 | + |
| 68 | + func start() -> EventLoopFuture<Void> { |
| 69 | + let bootstrap = ServerBootstrap(group: group) |
| 70 | + .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1) |
| 71 | + .childChannelInitializer { channel in |
| 72 | + channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { _ in |
| 73 | + channel.pipeline.addHandler(HTTPHandler(logger: self.logger, invocationEndpoint: self.invocationEndpoint)) |
| 74 | + } |
| 75 | + } |
| 76 | + return bootstrap.bind(host: self.host, port: self.port).flatMap { channel -> EventLoopFuture<Void> in |
| 77 | + guard channel.localAddress != nil else { |
| 78 | + return channel.eventLoop.makeFailedFuture(ServerError.cantBind) |
| 79 | + } |
| 80 | + self.logger.info("LocalLambdaServer started and listening on \(self.host):\(self.port), receiving payloads on \(self.invocationEndpoint)") |
| 81 | + return channel.eventLoop.makeSucceededFuture(()) |
| 82 | + } |
| 83 | + } |
| 84 | + |
| 85 | + func stop() throws { |
| 86 | + try self.group.syncShutdownGracefully() |
| 87 | + } |
| 88 | + } |
| 89 | + |
| 90 | + final class HTTPHandler: ChannelInboundHandler { |
| 91 | + public typealias InboundIn = HTTPServerRequestPart |
| 92 | + public typealias OutboundOut = HTTPServerResponsePart |
| 93 | + |
| 94 | + private static let queueLock = Lock() |
| 95 | + private static var queue = [String: Pending]() |
| 96 | + |
| 97 | + private var processing = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>() |
| 98 | + |
| 99 | + private let logger: Logger |
| 100 | + private let invocationEndpoint: String |
| 101 | + |
| 102 | + init(logger: Logger, invocationEndpoint: String) { |
| 103 | + self.logger = logger |
| 104 | + self.invocationEndpoint = invocationEndpoint |
| 105 | + } |
| 106 | + |
| 107 | + func channelRead(context: ChannelHandlerContext, data: NIOAny) { |
| 108 | + let requestPart = unwrapInboundIn(data) |
| 109 | + |
| 110 | + switch requestPart { |
| 111 | + case .head(let head): |
| 112 | + self.processing.append((head: head, body: nil)) |
| 113 | + case .body(var buffer): |
| 114 | + var request = self.processing.removeFirst() |
| 115 | + if request.body == nil { |
| 116 | + request.body = buffer |
| 117 | + } else { |
| 118 | + request.body!.writeBuffer(&buffer) |
| 119 | + } |
| 120 | + self.processing.prepend(request) |
| 121 | + case .end: |
| 122 | + let request = self.processing.removeFirst() |
| 123 | + self.processRequest(context: context, request: request) |
| 124 | + } |
| 125 | + } |
| 126 | + |
| 127 | + func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) { |
| 128 | + if request.head.uri.hasSuffix(self.invocationEndpoint) { |
| 129 | + if let work = request.body { |
| 130 | + let requestId = "\(DispatchTime.now().uptimeNanoseconds)" // FIXME: |
| 131 | + let promise = context.eventLoop.makePromise(of: Response.self) |
| 132 | + promise.futureResult.whenComplete { result in |
| 133 | + switch result { |
| 134 | + case .success(let response): |
| 135 | + self.writeResponse(context: context, response: response) |
| 136 | + case .failure: |
| 137 | + self.writeResponse(context: context, response: .init(status: .internalServerError)) |
| 138 | + } |
| 139 | + } |
| 140 | + Self.queueLock.withLock { |
| 141 | + Self.queue[requestId] = Pending(requestId: requestId, request: work, responsePromise: promise) |
| 142 | + } |
| 143 | + } |
| 144 | + } else if request.head.uri.hasSuffix("/next") { |
| 145 | + switch (Self.queueLock.withLock { Self.queue.popFirst() }) { |
| 146 | + case .none: |
| 147 | + self.writeResponse(context: context, response: .init(status: .noContent)) |
| 148 | + case .some(let pending): |
| 149 | + var response = Response() |
| 150 | + response.body = pending.value.request |
| 151 | + // required headers |
| 152 | + response.headers = [ |
| 153 | + (AmazonHeaders.requestID, pending.key), |
| 154 | + (AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:\(Int16.random(in: Int16.min ... Int16.max)):function:custom-runtime"), |
| 155 | + (AmazonHeaders.traceID, "Root=\(Int16.random(in: Int16.min ... Int16.max));Parent=\(Int16.random(in: Int16.min ... Int16.max));Sampled=1"), |
| 156 | + (AmazonHeaders.deadline, "\(DispatchWallTime.distantFuture.millisSinceEpoch)"), |
| 157 | + ] |
| 158 | + Self.queueLock.withLock { |
| 159 | + Self.queue[pending.key] = pending.value |
| 160 | + } |
| 161 | + self.writeResponse(context: context, response: response) |
| 162 | + } |
| 163 | + |
| 164 | + } else if request.head.uri.hasSuffix("/response") { |
| 165 | + let parts = request.head.uri.split(separator: "/") |
| 166 | + guard let requestId = parts.count > 2 ? String(parts[parts.count - 2]) : nil else { |
| 167 | + return self.writeResponse(context: context, response: .init(status: .badRequest)) |
| 168 | + } |
| 169 | + switch (Self.queueLock.withLock { Self.queue[requestId] }) { |
| 170 | + case .none: |
| 171 | + self.writeResponse(context: context, response: .init(status: .badRequest)) |
| 172 | + case .some(let pending): |
| 173 | + pending.responsePromise.succeed(.init(status: .ok, body: request.body)) |
| 174 | + self.writeResponse(context: context, response: .init(status: .accepted)) |
| 175 | + Self.queueLock.withLock { Self.queue[requestId] = nil } |
| 176 | + } |
| 177 | + } else { |
| 178 | + self.writeResponse(context: context, response: .init(status: .notFound)) |
| 179 | + } |
| 180 | + } |
| 181 | + |
| 182 | + func writeResponse(context: ChannelHandlerContext, response: Response) { |
| 183 | + var headers = HTTPHeaders(response.headers ?? []) |
| 184 | + headers.add(name: "content-length", value: "\(response.body?.readableBytes ?? 0)") |
| 185 | + let head = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: response.status, headers: headers) |
| 186 | + |
| 187 | + context.write(wrapOutboundOut(.head(head))).whenFailure { error in |
| 188 | + self.logger.error("\(self) write error \(error)") |
| 189 | + } |
| 190 | + |
| 191 | + if let buffer = response.body { |
| 192 | + context.write(wrapOutboundOut(.body(.byteBuffer(buffer)))).whenFailure { error in |
| 193 | + self.logger.error("\(self) write error \(error)") |
| 194 | + } |
| 195 | + } |
| 196 | + |
| 197 | + context.writeAndFlush(wrapOutboundOut(.end(nil))).whenComplete { result in |
| 198 | + if case .failure(let error) = result { |
| 199 | + self.logger.error("\(self) write error \(error)") |
| 200 | + } |
| 201 | + } |
| 202 | + } |
| 203 | + |
| 204 | + struct Response { |
| 205 | + var status: HTTPResponseStatus = .ok |
| 206 | + var headers: [(String, String)]? |
| 207 | + var body: ByteBuffer? |
| 208 | + } |
| 209 | + |
| 210 | + struct Pending { |
| 211 | + let requestId: String |
| 212 | + let request: ByteBuffer |
| 213 | + let responsePromise: EventLoopPromise<Response> |
| 214 | + } |
| 215 | + } |
| 216 | + |
| 217 | + enum ServerError: Error { |
| 218 | + case notReady |
| 219 | + case cantBind |
| 220 | + } |
| 221 | +} |
| 222 | +#endif |
0 commit comments