diff --git a/Package.swift b/Package.swift index eaa2dd8..7b1c06d 100644 --- a/Package.swift +++ b/Package.swift @@ -23,9 +23,14 @@ let package = Package( ], targets: [ .target( - name: "LambdaRuntime", - dependencies: ["AsyncHTTPClient", "NIO", "NIOHTTP1", "NIOFoundationCompat", "Logging", "Base64Kit"] - ), + name: "LambdaRuntime", dependencies: [ + "AsyncHTTPClient", + "NIO", + "NIOHTTP1", + "NIOFoundationCompat", + "Logging", + "Base64Kit" + ]), .target( name: "LambdaRuntimeTestUtils", dependencies: ["NIOHTTP1", "LambdaRuntime"] @@ -33,6 +38,7 @@ let package = Package( .testTarget(name: "LambdaRuntimeTests", dependencies: [ "LambdaRuntime", "LambdaRuntimeTestUtils", + "Base64Kit", "NIOTestUtils", "Logging", ]) diff --git a/Sources/LambdaRuntime/Context.swift b/Sources/LambdaRuntime/Context.swift index d100afc..49f26c6 100644 --- a/Sources/LambdaRuntime/Context.swift +++ b/Sources/LambdaRuntime/Context.swift @@ -12,7 +12,7 @@ public class Context { public let traceId : String public let requestId : String - public let logger : Logger + public let logger : Logger public let eventLoop : EventLoop public let deadlineDate: Date diff --git a/Sources/LambdaRuntime/Events/ALB.swift b/Sources/LambdaRuntime/Events/ALB.swift index 00446ea..50ec93f 100644 --- a/Sources/LambdaRuntime/Events/ALB.swift +++ b/Sources/LambdaRuntime/Events/ALB.swift @@ -8,7 +8,7 @@ import NIOHTTP1 public struct ALB { /// ALBTargetGroupRequest contains data originating from the ALB Lambda target group integration - public struct TargetGroupRequest { + public struct TargetGroupRequest: DecodableBody { /// ALBTargetGroupRequestContext contains the information to identify the load balancer invoking the lambda public struct Context: Codable { @@ -21,7 +21,7 @@ public struct ALB { public let headers: HTTPHeaders public let requestContext: Context public let isBase64Encoded: Bool - public let body: String + public let body: String? } /// ELBContext contains the information to identify the ARN invoking the lambda @@ -143,7 +143,9 @@ extension ALB.TargetGroupRequest: Decodable { self.requestContext = try container.decode(Context.self, forKey: .requestContext) self.isBase64Encoded = try container.decode(Bool.self, forKey: .isBase64Encoded) - self.body = try container.decode(String.self, forKey: .body) + + let body = try container.decode(String.self, forKey: .body) + self.body = body != "" ? body : nil } } diff --git a/Sources/LambdaRuntime/Events/APIGateway.swift b/Sources/LambdaRuntime/Events/APIGateway.swift index 4804a65..7d28ab9 100644 --- a/Sources/LambdaRuntime/Events/APIGateway.swift +++ b/Sources/LambdaRuntime/Events/APIGateway.swift @@ -9,7 +9,7 @@ import Base64Kit public struct APIGateway { /// APIGatewayRequest contains data coming from the API Gateway - public struct Request { + public struct Request: DecodableBody { public struct Context: Codable { @@ -162,18 +162,9 @@ extension APIGateway.Request: Decodable { extension APIGateway.Request { + @available(*, deprecated, renamed: "decodeBody(_:decoder:)") public func payload(_ type: Payload.Type, decoder: JSONDecoder = JSONDecoder()) throws -> Payload { - let body = self.body ?? "" - - let capacity = body.lengthOfBytes(using: .utf8) - - // TBD: I am pretty sure, we don't need this buffer copy here. - // Access the strings buffer directly to get to the data. - var buffer = ByteBufferAllocator().buffer(capacity: capacity) - buffer.setString(body, at: 0) - buffer.moveWriterIndex(to: capacity) - - return try decoder.decode(Payload.self, from: buffer) + return try self.decodeBody(Payload.self, decoder: decoder) } } diff --git a/Sources/LambdaRuntime/Events/DecodableBody.swift b/Sources/LambdaRuntime/Events/DecodableBody.swift new file mode 100644 index 0000000..adb431e --- /dev/null +++ b/Sources/LambdaRuntime/Events/DecodableBody.swift @@ -0,0 +1,41 @@ +import Foundation +import NIO +import NIOFoundationCompat + +protocol DecodableBody { + + var body: String? { get } + var isBase64Encoded: Bool { get } + +} + +extension DecodableBody { + + var isBase64Encoded: Bool { + return false + } + +} + +extension DecodableBody { + + func decodeBody(_ type: T.Type, decoder: JSONDecoder = JSONDecoder()) throws -> T { + + // I would really like to not use Foundation.Data at all, but well + // the NIOFoundationCompat just creates an internal Data as well. + // So let's save one malloc and copy and just use Data. + let payload = self.body ?? "" + + let data: Data + if self.isBase64Encoded { + let bytes = try payload.base64decoded() + data = Data(bytes) + } + else { + // TBD: Can this ever fail? I wouldn't think so... + data = payload.data(using: .utf8)! + } + + return try decoder.decode(T.self, from: data) + } +} diff --git a/Sources/LambdaRuntime/Events/SNS.swift b/Sources/LambdaRuntime/Events/SNS.swift index 1156ad0..5132392 100644 --- a/Sources/LambdaRuntime/Events/SNS.swift +++ b/Sources/LambdaRuntime/Events/SNS.swift @@ -100,20 +100,15 @@ extension SNS.Message: Decodable { } -extension SNS.Message { +extension SNS.Message: DecodableBody { + public var body: String? { + return self.message != "" ? self.message : nil + } + + @available(*, deprecated, renamed: "decodeBody(_:decoder:)") public func payload(decoder: JSONDecoder = JSONDecoder()) throws -> Payload { - let body = self.message - - let capacity = body.lengthOfBytes(using: .utf8) - - // TBD: I am pretty sure, we don't need this buffer copy here. - // Access the strings buffer directly to get to the data. - var buffer = ByteBufferAllocator().buffer(capacity: capacity) - buffer.setString(body, at: 0) - buffer.moveWriterIndex(to: capacity) - - return try decoder.decode(Payload.self, from: buffer) + return try self.decodeBody(Payload.self, decoder: decoder) } } diff --git a/Sources/LambdaRuntime/Events/SQS.swift b/Sources/LambdaRuntime/Events/SQS.swift index 3c6ba4e..3a1ecf8 100644 --- a/Sources/LambdaRuntime/Events/SQS.swift +++ b/Sources/LambdaRuntime/Events/SQS.swift @@ -4,7 +4,7 @@ import NIO /// https://github.com/aws/aws-lambda-go/blob/master/events/sqs.go public struct SQS { - public struct Event: Codable { + public struct Event: Decodable { public let records: [Message] enum CodingKeys: String, CodingKey { @@ -12,7 +12,7 @@ public struct SQS { } } - public struct Message: Codable { + public struct Message: DecodableBody { /// https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_MessageAttributeValue.html public enum Attribute { @@ -23,7 +23,7 @@ public struct SQS { public let messageId : String public let receiptHandle : String - public let body : String + public let body : String? public let md5OfBody : String public let md5OfMessageAttributes : String? public let attributes : [String: String] @@ -31,20 +31,41 @@ public struct SQS { public let eventSourceArn : String public let eventSource : String public let awsRegion : String + } +} + +extension SQS.Message: Decodable { + + enum CodingKeys: String, CodingKey { + case messageId + case receiptHandle + case body + case md5OfBody + case md5OfMessageAttributes + case attributes + case messageAttributes + case eventSourceArn = "eventSourceARN" + case eventSource + case awsRegion + } + + public init(from decoder: Decoder) throws { + + let container = try decoder.container(keyedBy: CodingKeys.self) + self.messageId = try container.decode(String.self, forKey: .messageId) + self.receiptHandle = try container.decode(String.self, forKey: .receiptHandle) + self.md5OfBody = try container.decode(String.self, forKey: .md5OfBody) + self.md5OfMessageAttributes = try container.decodeIfPresent(String.self, forKey: .md5OfMessageAttributes) + self.attributes = try container.decode([String: String].self, forKey: .attributes) + self.messageAttributes = try container.decode([String: Attribute].self, forKey: .messageAttributes) + self.eventSourceArn = try container.decode(String.self, forKey: .eventSourceArn) + self.eventSource = try container.decode(String.self, forKey: .eventSource) + self.awsRegion = try container.decode(String.self, forKey: .awsRegion) - enum CodingKeys: String, CodingKey { - case messageId - case receiptHandle - case body - case md5OfBody - case md5OfMessageAttributes - case attributes - case messageAttributes - case eventSourceArn = "eventSourceARN" - case eventSource - case awsRegion - } + let body = try container.decode(String?.self, forKey: .body) + self.body = body != "" ? body : nil } + } extension SQS.Message.Attribute: Equatable { } diff --git a/Tests/LambdaRuntimeTests/Events/ALBTests.swift b/Tests/LambdaRuntimeTests/Events/ALBTests.swift index 1a0d51b..a05008a 100644 --- a/Tests/LambdaRuntimeTests/Events/ALBTests.swift +++ b/Tests/LambdaRuntimeTests/Events/ALBTests.swift @@ -40,7 +40,7 @@ class ALBTests: XCTestCase { let event = try decoder.decode(ALB.TargetGroupRequest.self, from: data) XCTAssertEqual(event.httpMethod, .GET) - XCTAssertEqual(event.body, "") + XCTAssertEqual(event.body, nil) XCTAssertEqual(event.isBase64Encoded, false) XCTAssertEqual(event.headers.count, 11) XCTAssertEqual(event.path, "/") diff --git a/Tests/LambdaRuntimeTests/Events/APIGatewayTests.swift b/Tests/LambdaRuntimeTests/Events/APIGatewayTests.swift index e8fd769..5f104ea 100644 --- a/Tests/LambdaRuntimeTests/Events/APIGatewayTests.swift +++ b/Tests/LambdaRuntimeTests/Events/APIGatewayTests.swift @@ -3,8 +3,8 @@ import XCTest import NIO import NIOHTTP1 import NIOFoundationCompat -import LambdaRuntimeTestUtils @testable import LambdaRuntime +import LambdaRuntimeTestUtils class APIGatewayTests: XCTestCase { @@ -85,7 +85,7 @@ class APIGatewayTests: XCTestCase { XCTAssertEqual(request.path, "/todos") XCTAssertEqual(request.httpMethod, .POST) - let todo = try request.payload(Todo.self) + let todo = try request.decodeBody(Todo.self) XCTAssertEqual(todo.title, "a todo") } catch { diff --git a/Tests/LambdaRuntimeTests/Events/DecodableBodyTests.swift b/Tests/LambdaRuntimeTests/Events/DecodableBodyTests.swift new file mode 100644 index 0000000..561e0e8 --- /dev/null +++ b/Tests/LambdaRuntimeTests/Events/DecodableBodyTests.swift @@ -0,0 +1,57 @@ +import Foundation +import XCTest +import Base64Kit +@testable import LambdaRuntime + +class DecodableBodyTests: XCTestCase { + + struct TestEvent: DecodableBody { + let body: String? + let isBase64Encoded: Bool + } + + struct TestPayload: Codable { + let hello: String + } + + func testSimplePayloadFromEvent() { + do { + let event = TestEvent(body: "{\"hello\":\"world\"}", isBase64Encoded: false) + let payload = try event.decodeBody(TestPayload.self) + + XCTAssertEqual(payload.hello, "world") + } + catch { + XCTFail("Unexpected error: \(error)") + } + } + + func testBase64PayloadFromEvent() { + do { + let event = TestEvent(body: "eyJoZWxsbyI6IndvcmxkIn0=", isBase64Encoded: true) + let payload = try event.decodeBody(TestPayload.self) + + XCTAssertEqual(payload.hello, "world") + } + catch { + XCTFail("Unexpected error: \(error)") + } + } + + func testNoDataFromEvent() { + do { + let event = TestEvent(body: "", isBase64Encoded: false) + _ = try event.decodeBody(TestPayload.self) + + XCTFail("Did not expect to reach this point") + } + catch DecodingError.dataCorrupted(_) { + return // expected error + } + catch { + XCTFail("Unexpected error: \(error)") + } + + } + +} diff --git a/Tests/LambdaRuntimeTests/Events/SNSTests.swift b/Tests/LambdaRuntimeTests/Events/SNSTests.swift index 04198d6..a852344 100644 --- a/Tests/LambdaRuntimeTests/Events/SNSTests.swift +++ b/Tests/LambdaRuntimeTests/Events/SNSTests.swift @@ -73,7 +73,7 @@ class SNSTests: XCTestCase { XCTAssertEqual(record.sns.messageAttributes["binary"], .binary(binaryBuffer)) XCTAssertEqual(record.sns.messageAttributes["string"], .string("abc123")) - let payload: TestStruct = try record.sns.payload() + let payload = try record.sns.decodeBody(TestStruct.self) XCTAssertEqual(payload.hello, "world") } catch { diff --git a/Tests/LambdaRuntimeTests/Runtime+CodableTests.swift b/Tests/LambdaRuntimeTests/Runtime+CodableTests.swift index 4406409..93f6686 100644 --- a/Tests/LambdaRuntimeTests/Runtime+CodableTests.swift +++ b/Tests/LambdaRuntimeTests/Runtime+CodableTests.swift @@ -2,8 +2,8 @@ import Foundation import XCTest import NIO import NIOHTTP1 -import LambdaRuntimeTestUtils @testable import LambdaRuntime +import LambdaRuntimeTestUtils class RuntimeCodableTests: XCTestCase {