Skip to content

RFC: add debug functionality to test with mock server #69

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions Package.swift
Original file line number Diff line number Diff line change
@@ -36,12 +36,11 @@ let package = Package(
.product(name: "NIO", package: "swift-nio"),
]),
.testTarget(name: "AWSLambdaTestingTests", dependencies: ["AWSLambdaTesting"]),
// samples
.target(name: "StringSample", dependencies: ["AWSLambdaRuntime"]),
.target(name: "CodableSample", dependencies: ["AWSLambdaRuntime"]),
// perf tests
// for perf testing
.target(name: "MockServer", dependencies: [
.product(name: "NIOHTTP1", package: "swift-nio"),
]),
.target(name: "StringSample", dependencies: ["AWSLambdaRuntime"]),
.target(name: "CodableSample", dependencies: ["AWSLambdaRuntime"]),
]
)
268 changes: 268 additions & 0 deletions Sources/AWSLambdaRuntime/Lambda+LocalServer.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftAWSLambdaRuntime open source project
//
// Copyright (c) 2020 Apple Inc. and the SwiftAWSLambdaRuntime project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftAWSLambdaRuntime project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

#if DEBUG
import Dispatch
import Logging
import NIO
import NIOConcurrencyHelpers
import NIOHTTP1

// This functionality is designed for local testing hence beind a #if DEBUG flag.
// For example:
//
// try Lambda.withLocalServer {
// Lambda.run { (context: Lambda.Context, payload: String, callback: @escaping (Result<String, Error>) -> Void) in
// callback(.success("Hello, \(payload)!"))
// }
// }
extension Lambda {
/// Execute code in the context of a mock Lambda server.
///
/// - parameters:
/// - invocationEndpoint: The endpoint to post payloads to.
/// - body: Code to run within the context of the mock server. Typically this would be a Lambda.run function call.
///
/// - note: This API is designed stricly for local testing and is behind a DEBUG flag
public static func withLocalServer(invocationEndpoint: String? = nil, _ body: @escaping () -> Void) throws {
let server = LocalLambda.Server(invocationEndpoint: invocationEndpoint)
try server.start().wait()
defer { try! server.stop() } // FIXME:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👀

body()
}
}

// MARK: - Local Mock Server

private enum LocalLambda {
struct Server {
Copy link
Contributor Author

@tomerd tomerd May 5, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should call this MockServer to make it even clearer?

private let logger: Logger
private let group: EventLoopGroup
private let host: String
private let port: Int
private let invocationEndpoint: String

public init(invocationEndpoint: String?) {
let configuration = Lambda.Configuration()
var logger = Logger(label: "LocalLambdaServer")
logger.logLevel = configuration.general.logLevel
self.logger = logger
self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
self.host = configuration.runtimeEngine.ip
self.port = configuration.runtimeEngine.port
self.invocationEndpoint = invocationEndpoint ?? "/invoke"
}

func start() -> EventLoopFuture<Void> {
let bootstrap = ServerBootstrap(group: group)
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
.childChannelInitializer { channel in
channel.pipeline.configureHTTPServerPipeline(withErrorHandling: true).flatMap { _ in
channel.pipeline.addHandler(HTTPHandler(logger: self.logger, invocationEndpoint: self.invocationEndpoint))
}
}
return bootstrap.bind(host: self.host, port: self.port).flatMap { channel -> EventLoopFuture<Void> in
guard channel.localAddress != nil else {
return channel.eventLoop.makeFailedFuture(ServerError.cantBind)
}
self.logger.info("LocalLambdaServer started and listening on \(self.host):\(self.port), receiving payloads on \(self.invocationEndpoint)")
return channel.eventLoop.makeSucceededFuture(())
}
}

func stop() throws {
try self.group.syncShutdownGracefully()
}
}

final class HTTPHandler: ChannelInboundHandler {
public typealias InboundIn = HTTPServerRequestPart
public typealias OutboundOut = HTTPServerResponsePart

private var pending = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>()

private static var invocations = CircularBuffer<Invocation>()
private static var invocationState = InvocationState.waitingForLambdaRequest

private let logger: Logger
private let invocationEndpoint: String

init(logger: Logger, invocationEndpoint: String) {
self.logger = logger
self.invocationEndpoint = invocationEndpoint
}

func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let requestPart = unwrapInboundIn(data)

switch requestPart {
case .head(let head):
self.pending.append((head: head, body: nil))
case .body(var buffer):
var request = self.pending.removeFirst()
if request.body == nil {
request.body = buffer
} else {
request.body!.writeBuffer(&buffer)
}
self.pending.prepend(request)
case .end:
let request = self.pending.removeFirst()
self.processRequest(context: context, request: request)
}
}

func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) {
switch (request.head.method, request.head.uri) {
// this endpoint is called by the client invoking the lambda
case (.POST, let url) where url.hasSuffix(self.invocationEndpoint):
guard let work = request.body else {
return self.writeResponse(context: context, response: .init(status: .badRequest))
}
let requestID = "\(DispatchTime.now().uptimeNanoseconds)" // FIXME:
let promise = context.eventLoop.makePromise(of: Response.self)
promise.futureResult.whenComplete { result in
switch result {
case .failure(let error):
self.logger.error("invocation error: \(error)")
self.writeResponse(context: context, response: .init(status: .internalServerError))
case .success(let response):
self.writeResponse(context: context, response: response)
}
}
let invocation = Invocation(requestID: requestID, request: work, responsePromise: promise)
switch Self.invocationState {
case .waitingForInvocation(let promise):
promise.succeed(invocation)
case .waitingForLambdaRequest, .waitingForLambdaResponse:
Self.invocations.append(invocation)
}
// /next endpoint is called by the lambda polling for work
case (.GET, let url) where url.hasSuffix(Consts.requestWorkURLSuffix):
// check if our server is in the correct state
guard case .waitingForLambdaRequest = Self.invocationState else {
self.logger.error("invalid invocation state \(Self.invocationState)")
self.writeResponse(context: context, response: .init(status: .unprocessableEntity))
return
}

// pop the first task from the queue
switch Self.invocations.popFirst() {
case .none:
// if there is nothing in the queue,
// create a promise that we can fullfill when we get a new task
let promise = context.eventLoop.makePromise(of: Invocation.self)
promise.futureResult.whenComplete { result in
switch result {
case .failure(let error):
self.logger.error("invocation error: \(error)")
self.writeResponse(context: context, response: .init(status: .internalServerError))
case .success(let invocation):
Self.invocationState = .waitingForLambdaResponse(invocation)
self.writeResponse(context: context, response: invocation.makeResponse())
}
}
Self.invocationState = .waitingForInvocation(promise)
case .some(let invocation):
// if there is a task pending, we can immediatly respond with it.
Self.invocationState = .waitingForLambdaResponse(invocation)
self.writeResponse(context: context, response: invocation.makeResponse())
}
// :requestID/response endpoint is called by the lambda posting the response
case (.POST, let url) where url.hasSuffix(Consts.postResponseURLSuffix):
let parts = request.head.uri.split(separator: "/")
guard let requestID = parts.count > 2 ? String(parts[parts.count - 2]) : nil else {
// the request is malformed, since we were expecting a requestId in the path
return self.writeResponse(context: context, response: .init(status: .badRequest))
}
guard case .waitingForLambdaResponse(let invocation) = Self.invocationState else {
// a response was send, but we did not expect to receive one
self.logger.error("invalid invocation state \(Self.invocationState)")
return self.writeResponse(context: context, response: .init(status: .unprocessableEntity))
}
guard requestID == invocation.requestID else {
// the request's requestId is not matching the one we are expecting
self.logger.error("invalid invocation state request ID \(requestID) does not match expected \(invocation.requestID)")
return self.writeResponse(context: context, response: .init(status: .badRequest))
}

invocation.responsePromise.succeed(.init(status: .ok, body: request.body))
self.writeResponse(context: context, response: .init(status: .accepted))
Self.invocationState = .waitingForLambdaRequest
// unknown call
default:
self.writeResponse(context: context, response: .init(status: .notFound))
}
}

func writeResponse(context: ChannelHandlerContext, response: Response) {
var headers = HTTPHeaders(response.headers ?? [])
headers.add(name: "content-length", value: "\(response.body?.readableBytes ?? 0)")
let head = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: response.status, headers: headers)

context.write(wrapOutboundOut(.head(head))).whenFailure { error in
self.logger.error("\(self) write error \(error)")
}

if let buffer = response.body {
context.write(wrapOutboundOut(.body(.byteBuffer(buffer)))).whenFailure { error in
self.logger.error("\(self) write error \(error)")
}
}

context.writeAndFlush(wrapOutboundOut(.end(nil))).whenComplete { result in
if case .failure(let error) = result {
self.logger.error("\(self) write error \(error)")
}
}
}

struct Response {
var status: HTTPResponseStatus = .ok
var headers: [(String, String)]?
var body: ByteBuffer?
}

struct Invocation {
let requestID: String
let request: ByteBuffer
let responsePromise: EventLoopPromise<Response>

func makeResponse() -> Response {
var response = Response()
response.body = self.request
// required headers
response.headers = [
(AmazonHeaders.requestID, self.requestID),
(AmazonHeaders.invokedFunctionARN, "arn:aws:lambda:us-east-1:\(Int16.random(in: Int16.min ... Int16.max)):function:custom-runtime"),
(AmazonHeaders.traceID, "Root=\(Int16.random(in: Int16.min ... Int16.max));Parent=\(Int16.random(in: Int16.min ... Int16.max));Sampled=1"),
(AmazonHeaders.deadline, "\(DispatchWallTime.distantFuture.millisSinceEpoch)"),
]
return response
}
}

enum InvocationState {
case waitingForInvocation(EventLoopPromise<Invocation>)
case waitingForLambdaRequest
case waitingForLambdaResponse(Invocation)
}
}

enum ServerError: Error {
case notReady
case cantBind
}
}
#endif
6 changes: 5 additions & 1 deletion Sources/AWSLambdaRuntime/LambdaContext.swift
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@ import NIO
extension Lambda {
/// Lambda runtime context.
/// The Lambda runtime generates and passes the `Context` to the Lambda handler as an argument.
public final class Context {
public final class Context: CustomDebugStringConvertible {
/// The request ID, which identifies the request that triggered the function invocation.
public let requestId: String

@@ -85,5 +85,9 @@ extension Lambda {
let remaining = deadline - now
return .milliseconds(remaining)
}

public var debugDescription: String {
"\(Self.self)(requestId: \(self.requestId), traceId: \(self.traceId), invokedFunctionArn: \(self.invokedFunctionArn), cognitoIdentity: \(self.cognitoIdentity ?? "nil"), clientContext: \(self.clientContext ?? "nil"), deadline: \(self.deadline))"
}
}
}
20 changes: 19 additions & 1 deletion Sources/AWSLambdaTesting/Lambda+Testing.swift
Original file line number Diff line number Diff line change
@@ -12,9 +12,27 @@
//
//===----------------------------------------------------------------------===//

// this is designed to only work for testing
// This functionality is designed to help with Lambda unit testing with XCTest
// #if filter required for release builds which do not support @testable import
// @testable is used to access of internal functions
// For exmaple:
//
// func test() {
// struct MyLambda: EventLoopLambdaHandler {
// typealias In = String
// typealias Out = String
//
// func handle(context: Lambda.Context, payload: String) -> EventLoopFuture<String> {
// return context.eventLoop.makeSucceededFuture("echo" + payload)
// }
// }
//
// let input = UUID().uuidString
// var result: String?
// XCTAssertNoThrow(result = try Lambda.test(MyLambda(), with: input))
// XCTAssertEqual(result, "echo" + input)
// }

#if DEBUG
@testable import AWSLambdaRuntime
import Dispatch
2 changes: 1 addition & 1 deletion docker/docker-compose.1804.53.yaml
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@ services:
image: swift-aws-lambda:18.04-5.3
build:
args:
base_image: "swiftlang/swift:nightly-bionic"
base_image: "swiftlang/swift:nightly-5.3-bionic"

test:
image: swift-aws-lambda:18.04-5.3