refactor (#13)

motivation: simpler concurrency, better tests

changes:
* ensure http client is called in a single-threaded manner and remove locks 
* refactor test
* make mock server more robust
This commit is contained in:
tomer doron
2020-03-08 12:41:09 -07:00
committed by GitHub
parent b8a6466577
commit 08f75f51c7
14 changed files with 195 additions and 147 deletions
+14 -13
View File
@@ -64,13 +64,12 @@ internal final class HTTPHandler: ChannelInboundHandler {
private let mode: Mode
private let keepAlive: Bool
private var requestHead: HTTPRequestHead!
private var requestBody: ByteBuffer?
private var pending = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>()
public init(logger: Logger, keepAlive: Bool, mode: Mode) {
self.logger = logger
self.mode = mode
self.keepAlive = keepAlive
self.mode = mode
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
@@ -78,27 +77,29 @@ internal final class HTTPHandler: ChannelInboundHandler {
switch requestPart {
case .head(let head):
self.requestHead = head
self.requestBody?.clear()
self.pending.append((head: head, body: nil))
case .body(var buffer):
if self.requestBody == nil {
self.requestBody = buffer
var request = self.pending.removeFirst()
if request.body == nil {
request.body = buffer
} else {
self.requestBody!.writeBuffer(&buffer)
request.body!.writeBuffer(&buffer)
}
self.pending.prepend(request)
case .end:
self.processRequest(context: context)
let request = self.pending.removeFirst()
self.processRequest(context: context, request: request)
}
}
func processRequest(context: ChannelHandlerContext) {
self.logger.debug("\(self) processing \(self.requestHead.uri)")
func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) {
self.logger.debug("\(self) processing \(request.head.uri)")
var responseStatus: HTTPResponseStatus
var responseBody: String?
var responseHeaders: [(String, String)]?
if self.requestHead.uri.hasSuffix("/next") {
if request.head.uri.hasSuffix("/next") {
let requestId = UUID().uuidString
responseStatus = .ok
switch self.mode {
@@ -108,7 +109,7 @@ internal final class HTTPHandler: ChannelInboundHandler {
responseBody = "{ \"body\": \"\(requestId)\" }"
}
responseHeaders = [(AmazonHeaders.requestID, requestId)]
} else if self.requestHead.uri.hasSuffix("/response") {
} else if request.head.uri.hasSuffix("/response") {
responseStatus = .accepted
} else {
responseStatus = .notFound
+54 -48
View File
@@ -17,13 +17,15 @@ import NIOConcurrencyHelpers
import NIOHTTP1
/// A barebone HTTP client to interact with AWS Runtime Engine which is an HTTP server.
internal class HTTPClient {
/// Note that Lambda Runtime API dictate that only one requests runs at a time.
/// This means we can avoid locks and other concurrency concern we would otherwise need to build into the client
internal final class HTTPClient {
private let eventLoop: EventLoop
private let configuration: Lambda.Configuration.RuntimeEngine
private let targetHost: String
private var state = State.disconnected
private let lock = Lock()
private let executing = NIOAtomic.makeAtomic(value: false)
init(eventLoop: EventLoop, configuration: Lambda.Configuration.RuntimeEngine) {
self.eventLoop = eventLoop
@@ -46,38 +48,34 @@ internal class HTTPClient {
timeout: timeout ?? self.configuration.requestTimeout))
}
private func execute(_ request: Request) -> EventLoopFuture<Response> {
self.lock.lock()
// TODO: cap reconnect attempt
private func execute(_ request: Request, validate: Bool = true) -> EventLoopFuture<Response> {
precondition(!validate || self.executing.compareAndExchange(expected: false, desired: true), "expecting single request at a time")
switch self.state {
case .disconnected:
return self.connect().flatMap { channel -> EventLoopFuture<Response> in
self.state = .connected(channel)
return self.execute(request, validate: false)
}
case .connected(let channel):
guard channel.isActive else {
// attempt to reconnect
self.state = .disconnected
self.lock.unlock()
return self.execute(request)
return self.execute(request, validate: false)
}
self.lock.unlock()
let promise = channel.eventLoop.makePromise(of: Response.self)
promise.futureResult.whenComplete { _ in
precondition(self.executing.compareAndExchange(expected: true, desired: false), "invalid execution state")
}
let wrapper = HTTPRequestWrapper(request: request, promise: promise)
return channel.writeAndFlush(wrapper).flatMap {
promise.futureResult
}
case .disconnected:
return self.connect().flatMap {
self.lock.unlock()
return self.execute(request)
}
default:
preconditionFailure("invalid state \(self.state)")
channel.writeAndFlush(wrapper).cascadeFailure(to: promise)
return promise.futureResult
}
}
private func connect() -> EventLoopFuture<Void> {
guard case .disconnected = self.state else {
preconditionFailure("invalid state \(self.state)")
}
self.state = .connecting
let bootstrap = ClientBootstrap(group: eventLoop)
private func connect() -> EventLoopFuture<Channel> {
let bootstrap = ClientBootstrap(group: self.eventLoop)
.channelInitializer { channel in
channel.pipeline.addHTTPClientHandlers().flatMap {
channel.pipeline.addHandlers([HTTPHandler(keepAlive: self.configuration.keepAlive),
@@ -88,9 +86,7 @@ internal class HTTPClient {
do {
// connect directly via socket address to avoid happy eyeballs (perf)
let address = try SocketAddress(ipAddress: self.configuration.ip, port: self.configuration.port)
return bootstrap.connect(to: address).flatMapThrowing { channel in
self.state = .connected(channel)
}
return bootstrap.connect(to: address)
} catch {
return self.eventLoop.makeFailedFuture(error)
}
@@ -126,13 +122,12 @@ internal class HTTPClient {
}
private enum State {
case connecting
case connected(Channel)
case disconnected
case connected(Channel)
}
}
private class HTTPHandler: ChannelDuplexHandler {
private final class HTTPHandler: ChannelDuplexHandler {
typealias OutboundIn = HTTPClient.Request
typealias InboundOut = HTTPClient.Response
typealias InboundIn = HTTPClientResponsePart
@@ -207,15 +202,15 @@ private class HTTPHandler: ChannelDuplexHandler {
}
}
private class UnaryHandler: ChannelInboundHandler, ChannelOutboundHandler {
// no need in locks since we validate only one request can run at a time
private final class UnaryHandler: ChannelDuplexHandler {
typealias OutboundIn = HTTPRequestWrapper
typealias InboundIn = HTTPClient.Response
typealias OutboundOut = HTTPClient.Request
private let keepAlive: Bool
private let lock = Lock()
private var pendingResponses = CircularBuffer<(EventLoopPromise<HTTPClient.Response>, Scheduled<Void>?)>()
private var pending: (promise: EventLoopPromise<HTTPClient.Response>, timeout: Scheduled<Void>?)?
private var lastError: Error?
init(keepAlive: Bool) {
@@ -223,47 +218,58 @@ private class UnaryHandler: ChannelInboundHandler, ChannelOutboundHandler {
}
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
guard self.pending == nil else {
preconditionFailure("invalid state, outstanding request")
}
let wrapper = unwrapOutboundIn(data)
let timeoutTask = wrapper.request.timeout.map {
context.eventLoop.scheduleTask(in: $0) {
if (self.lock.withLock { !self.pendingResponses.isEmpty }) {
self.errorCaught(context: context, error: HTTPClient.Errors.timeout)
if self.pending != nil {
context.pipeline.fireErrorCaught(HTTPClient.Errors.timeout)
}
}
}
self.lock.withLockVoid { pendingResponses.append((wrapper.promise, timeoutTask)) }
self.pending = (promise: wrapper.promise, timeout: timeoutTask)
context.writeAndFlush(wrapOutboundOut(wrapper.request), promise: promise)
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let response = unwrapInboundIn(data)
if let pending = (self.lock.withLock { self.pendingResponses.popFirst() }) {
let serverKeepAlive = response.headers["connection"].first?.lowercased() == "keep-alive"
let future = self.keepAlive && serverKeepAlive ? context.eventLoop.makeSucceededFuture(()) : context.channel.close()
future.whenComplete { _ in
pending.1?.cancel()
pending.0.succeed(response)
guard let pending = self.pending else {
preconditionFailure("invalid state, no pending request")
}
let serverKeepAlive = response.headers.first(name: "connection")?.lowercased() == "keep-alive"
if !self.keepAlive || !serverKeepAlive {
pending.promise.futureResult.whenComplete { _ in
_ = context.channel.close()
}
}
self.completeWith(.success(response))
}
func errorCaught(context: ChannelHandlerContext, error: Error) {
// pending responses will fail with lastError in channelInactive since we are calling context.close
self.lock.withLockVoid { self.lastError = error }
self.lastError = error
context.channel.close(promise: nil)
}
func channelInactive(context: ChannelHandlerContext) {
// fail any pending responses with last error or assume peer disconnected
self.failPendingResponses(self.lock.withLock { self.lastError } ?? HTTPClient.Errors.connectionResetByPeer)
if self.pending != nil {
let error = self.lastError ?? HTTPClient.Errors.connectionResetByPeer
self.completeWith(.failure(error))
}
context.fireChannelInactive()
}
private func failPendingResponses(_ error: Error) {
while let pending = (self.lock.withLock { pendingResponses.popFirst() }) {
pending.1?.cancel()
pending.0.fail(error)
private func completeWith(_ result: Result<HTTPClient.Response, Error>) {
guard let pending = self.pending else {
preconditionFailure("invalid state, no pending request")
}
self.pending = nil
self.lastError = nil
pending.timeout?.cancel()
pending.promise.completeWith(result)
}
}
+3 -3
View File
@@ -105,7 +105,7 @@ public enum Lambda {
}
set {
self.stateLock.withLockVoid {
precondition(newValue.rawValue > _state.rawValue, "invalid state \(newValue) after \(_state)")
precondition(newValue.rawValue > _state.rawValue, "invalid state \(newValue) after \(self._state)")
self._state = newValue
}
}
@@ -124,12 +124,12 @@ public enum Lambda {
}
func stop() {
self.logger.info("lambda lifecycle stopping")
self.logger.debug("lambda lifecycle stopping")
self.state = .stopping
}
func shutdown() {
self.logger.info("lambda lifecycle shutdown")
self.logger.debug("lambda lifecycle shutdown")
self.state = .shutdown
}
+2 -2
View File
@@ -36,7 +36,7 @@ internal struct LambdaRunner {
///
/// - Returns: An `EventLoopFuture<Void>` fulfilled with the outcome of the initialization.
func initialize(logger: Logger) -> EventLoopFuture<Void> {
logger.info("initializing lambda")
logger.debug("initializing lambda")
// We need to use `flatMap` instead of `whenFailure` to ensure we complete reporting the result before stopping.
return self.lambdaHandler.initialize(eventLoop: self.eventLoop,
lifecycleId: self.lifecycleId,
@@ -69,7 +69,7 @@ internal struct LambdaRunner {
}
}.always { result in
// we are done!
logger.log(level: result.successful ? .info : .warning, "lambda invocation sequence completed \(result.successful ? "successfully" : "with failure")")
logger.log(level: result.successful ? .debug : .warning, "lambda invocation sequence completed \(result.successful ? "successfully" : "with failure")")
}
}
}
@@ -145,7 +145,7 @@ internal struct JsonCodecError: Error, Equatable {
}
static func == (lhs: JsonCodecError, rhs: JsonCodecError) -> Bool {
return lhs.cause.localizedDescription == rhs.cause.localizedDescription
return String(describing: lhs.cause) == String(describing: rhs.cause)
}
}
@@ -25,7 +25,7 @@ import XCTest
extension CodableLambdaTest {
static var allTests: [(String, (CodableLambdaTest) -> () throws -> Void)] {
return [
("testSuceess", testSuceess),
("testSuccess", testSuccess),
("testFailure", testFailure),
("testClosureSuccess", testClosureSuccess),
("testClosureFailure", testClosureFailure),
@@ -16,39 +16,47 @@
import XCTest
class CodableLambdaTest: XCTestCase {
func testSuceess() throws {
func testSuccess() {
let server = MockLambdaServer(behavior: GoodBehavior())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let maxTimes = Int.random(in: 1 ... 10)
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait()
let result = Lambda.run(handler: CodableEchoHandler(), configuration: configuration)
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
}
func testFailure() throws {
let server = try MockLambdaServer(behavior: BadBehavior()).start().wait()
func testFailure() {
let server = MockLambdaServer(behavior: BadBehavior())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let result = Lambda.run(handler: CodableEchoHandler())
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError))
}
func testClosureSuccess() throws {
func testClosureSuccess() {
let server = MockLambdaServer(behavior: GoodBehavior())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let maxTimes = Int.random(in: 1 ... 10)
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait()
let result = Lambda.run(configuration: configuration) { (_, payload: Request, callback) in
callback(.success(Response(requestId: payload.requestId)))
}
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
}
func testClosureFailure() throws {
let server = try MockLambdaServer(behavior: BadBehavior()).start().wait()
func testClosureFailure() {
let server = MockLambdaServer(behavior: BadBehavior())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let result: LambdaLifecycleResult = Lambda.run { (_, payload: Request, callback) in
callback(.success(Response(requestId: payload.requestId)))
}
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError))
}
}
@@ -25,7 +25,7 @@ import XCTest
extension StringLambdaTest {
static var allTests: [(String, (StringLambdaTest) -> () throws -> Void)] {
return [
("testSuceess", testSuceess),
("testSuccess", testSuccess),
("testFailure", testFailure),
("testClosureSuccess", testClosureSuccess),
("testClosureFailure", testClosureFailure),
@@ -16,39 +16,47 @@
import XCTest
class StringLambdaTest: XCTestCase {
func testSuceess() throws {
func testSuccess() {
let server = MockLambdaServer(behavior: GoodBehavior())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let maxTimes = Int.random(in: 1 ... 10)
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait()
let result = Lambda.run(handler: StringEchoHandler(), configuration: configuration)
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
}
func testFailure() throws {
let server = try MockLambdaServer(behavior: BadBehavior()).start().wait()
func testFailure() {
let server = MockLambdaServer(behavior: BadBehavior())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let result = Lambda.run(handler: StringEchoHandler())
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError))
}
func testClosureSuccess() throws {
func testClosureSuccess() {
let server = MockLambdaServer(behavior: GoodBehavior())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let maxTimes = Int.random(in: 1 ... 10)
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait()
let result = Lambda.run(configuration: configuration) { (_, payload: String, callback) in
callback(.success(payload))
}
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
}
func testClosureFailure() throws {
let server = try MockLambdaServer(behavior: BadBehavior()).start().wait()
func testClosureFailure() {
let server = MockLambdaServer(behavior: BadBehavior())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let result: LambdaLifecycleResult = Lambda.run { (_, payload: String, callback) in
callback(.success(payload))
}
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError))
}
}
@@ -12,6 +12,8 @@
//
//===----------------------------------------------------------------------===//
import Logging
import NIO
@testable import SwiftAwsLambda
import XCTest
@@ -25,7 +25,7 @@ import XCTest
extension LambdaTest {
static var allTests: [(String, (LambdaTest) -> () throws -> Void)] {
return [
("testSuceess", testSuceess),
("testSuccess", testSuccess),
("testFailure", testFailure),
("testInitFailure", testInitFailure),
("testInitFailureAndReportErrorFailure", testInitFailureAndReportErrorFailure),
+56 -35
View File
@@ -17,57 +17,69 @@ import NIO
import XCTest
class LambdaTest: XCTestCase {
func testSuceess() throws {
func testSuccess() {
let server = MockLambdaServer(behavior: GoodBehavior())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let maxTimes = Int.random(in: 10 ... 20)
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait()
let handler = EchoHandler()
let result = Lambda.run(handler: handler, configuration: configuration)
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
XCTAssertEqual(handler.initializeCalls, 1)
}
func testFailure() throws {
let server = try MockLambdaServer(behavior: BadBehavior()).start().wait()
func testFailure() {
let server = MockLambdaServer(behavior: BadBehavior())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let result = Lambda.run(handler: EchoHandler())
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError))
}
func testInitFailure() throws {
let server = try MockLambdaServer(behavior: GoodBehaviourWhenInitFails()).start().wait()
func testInitFailure() {
let server = MockLambdaServer(behavior: GoodBehaviourWhenInitFails())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let handler = FailedInitializerHandler("kaboom")
let result = Lambda.run(handler: handler)
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shouldFailWithError: FailedInitializerHandler.Error(description: "kaboom"))
}
func testInitFailureAndReportErrorFailure() throws {
let server = try MockLambdaServer(behavior: BadBehaviourWhenInitFails()).start().wait()
func testInitFailureAndReportErrorFailure() {
let server = MockLambdaServer(behavior: BadBehaviourWhenInitFails())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let handler = FailedInitializerHandler("kaboom")
let result = Lambda.run(handler: handler)
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shouldFailWithError: FailedInitializerHandler.Error(description: "kaboom"))
}
func testClosureSuccess() throws {
func testClosureSuccess() {
let server = MockLambdaServer(behavior: GoodBehavior())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let maxTimes = Int.random(in: 10 ... 20)
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
let server = try MockLambdaServer(behavior: GoodBehavior()).start().wait()
let result = Lambda.run(configuration: configuration) { (_, payload: [UInt8], callback: LambdaCallback) in
callback(.success(payload))
}
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
}
func testClosureFailure() throws {
let server = try MockLambdaServer(behavior: BadBehavior()).start().wait()
func testClosureFailure() {
let server = MockLambdaServer(behavior: BadBehavior())
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let result: LambdaLifecycleResult = Lambda.run { (_, payload: [UInt8], callback: LambdaCallback) in
callback(.success(payload))
}
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.badStatusCode(.internalServerError))
}
@@ -94,49 +106,58 @@ class LambdaTest: XCTestCase {
try eventLoopGroup.syncShutdownGracefully()
}
func testTimeout() throws {
func testTimeout() {
let timeout: Int64 = 100
let server = MockLambdaServer(behavior: GoodBehavior(requestId: "timeout", payload: "\(timeout * 2)"))
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: 1),
runtimeEngine: .init(requestTimeout: .milliseconds(timeout)))
let server = try MockLambdaServer(behavior: GoodBehavior(requestId: "timeout", payload: "\(timeout * 2)")).start().wait()
let result = Lambda.run(handler: EchoHandler(), configuration: configuration)
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.upstreamError("timeout"))
}
func testDisconnect() throws {
func testDisconnect() {
let server = MockLambdaServer(behavior: GoodBehavior(requestId: "disconnect"))
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: 1))
let server = try MockLambdaServer(behavior: GoodBehavior(requestId: "disconnect")).start().wait()
let result = Lambda.run(handler: EchoHandler(), configuration: configuration)
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shouldFailWithError: LambdaRuntimeClientError.upstreamError("connectionResetByPeer"))
}
func testBigPayload() throws {
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: 1))
func testBigPayload() {
let payload = String(repeating: "*", count: 104_448)
let server = try MockLambdaServer(behavior: GoodBehavior(payload: payload)).start().wait()
let server = MockLambdaServer(behavior: GoodBehavior(payload: payload))
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: 1))
let result = Lambda.run(handler: EchoHandler(), configuration: configuration)
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shoudHaveRun: 1)
}
func testKeepAliveServer() throws {
func testKeepAliveServer() {
let server = MockLambdaServer(behavior: GoodBehavior(), keepAlive: true)
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let maxTimes = 10
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
let server = try MockLambdaServer(behavior: GoodBehavior(), keepAlive: true).start().wait()
let result = Lambda.run(handler: EchoHandler(), configuration: configuration)
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
}
func testNoKeepAliveServer() throws {
func testNoKeepAliveServer() {
let server = MockLambdaServer(behavior: GoodBehavior(), keepAlive: false)
XCTAssertNoThrow(try server.start().wait())
defer { XCTAssertNoThrow(try server.stop().wait()) }
let maxTimes = 10
let configuration = Lambda.Configuration(lifecycle: .init(maxTimes: maxTimes))
let server = try MockLambdaServer(behavior: GoodBehavior(), keepAlive: false).start().wait()
let result = Lambda.run(handler: EchoHandler(), configuration: configuration)
try server.stop().wait()
assertLambdaLifecycleResult(result: result, shoudHaveRun: maxTimes)
}
}
@@ -79,8 +79,7 @@ internal final class HTTPHandler: ChannelInboundHandler {
private let keepAlive: Bool
private let behavior: LambdaServerBehavior
private var requestHead: HTTPRequestHead!
private var requestBody: ByteBuffer?
private var pending = CircularBuffer<(head: HTTPRequestHead, body: ByteBuffer?)>()
public init(logger: Logger, keepAlive: Bool, behavior: LambdaServerBehavior) {
self.logger = logger
@@ -93,23 +92,25 @@ internal final class HTTPHandler: ChannelInboundHandler {
switch requestPart {
case .head(let head):
self.requestHead = head
self.requestBody?.clear()
self.pending.append((head: head, body: nil))
case .body(var buffer):
if self.requestBody == nil {
self.requestBody = buffer
var request = self.pending.removeFirst()
if request.body == nil {
request.body = buffer
} else {
self.requestBody!.writeBuffer(&buffer)
request.body!.writeBuffer(&buffer)
}
self.pending.prepend(request)
case .end:
self.processRequest(context: context)
let request = self.pending.removeFirst()
self.processRequest(context: context, request: request)
}
}
func processRequest(context: ChannelHandlerContext) {
self.logger.info("\(self) processing \(self.requestHead.uri)")
func processRequest(context: ChannelHandlerContext, request: (head: HTTPRequestHead, body: ByteBuffer?)) {
self.logger.info("\(self) processing \(request.head.uri)")
let requestBody = self.requestBody.flatMap { (buffer: ByteBuffer) -> String? in
let requestBody = request.body.flatMap { (buffer: ByteBuffer) -> String? in
var buffer = buffer
return buffer.readString(length: buffer.readableBytes)
}
@@ -119,7 +120,7 @@ internal final class HTTPHandler: ChannelInboundHandler {
var responseHeaders: [(String, String)]?
// Handle post-init-error first to avoid matching the less specific post-error suffix.
if self.requestHead.uri.hasSuffix(Consts.postInitErrorURL) {
if request.head.uri.hasSuffix(Consts.postInitErrorURL) {
guard let json = requestBody, let error = ErrorResponse.fromJson(json) else {
return self.writeResponse(context: context, status: .badRequest)
}
@@ -129,7 +130,7 @@ internal final class HTTPHandler: ChannelInboundHandler {
case .failure(let error):
responseStatus = .init(statusCode: error.rawValue)
}
} else if self.requestHead.uri.hasSuffix(Consts.requestWorkURLSuffix) {
} else if request.head.uri.hasSuffix(Consts.requestWorkURLSuffix) {
switch self.behavior.getWork() {
case .success(let (requestId, result)):
if requestId == "timeout" {
@@ -143,8 +144,8 @@ internal final class HTTPHandler: ChannelInboundHandler {
case .failure(let error):
responseStatus = .init(statusCode: error.rawValue)
}
} else if self.requestHead.uri.hasSuffix(Consts.postResponseURLSuffix) {
guard let requestId = requestHead.uri.split(separator: "/").dropFirst(3).first, let response = requestBody else {
} else if request.head.uri.hasSuffix(Consts.postResponseURLSuffix) {
guard let requestId = request.head.uri.split(separator: "/").dropFirst(3).first, let response = requestBody else {
return self.writeResponse(context: context, status: .badRequest)
}
switch self.behavior.processResponse(requestId: String(requestId), response: response) {
@@ -153,8 +154,8 @@ internal final class HTTPHandler: ChannelInboundHandler {
case .failure(let error):
responseStatus = .init(statusCode: error.rawValue)
}
} else if self.requestHead.uri.hasSuffix(Consts.postErrorURLSuffix) {
guard let requestId = requestHead.uri.split(separator: "/").dropFirst(3).first,
} else if request.head.uri.hasSuffix(Consts.postErrorURLSuffix) {
guard let requestId = request.head.uri.split(separator: "/").dropFirst(3).first,
let json = requestBody,
let error = ErrorResponse.fromJson(json)
else {
@@ -169,6 +170,7 @@ internal final class HTTPHandler: ChannelInboundHandler {
} else {
responseStatus = .notFound
}
self.logger.info("\(self) responding to \(request.head.uri)")
self.writeResponse(context: context, status: responseStatus, headers: responseHeaders, body: responseBody)
}
+1 -1
View File
@@ -30,7 +30,7 @@ func runLambda(behavior: LambdaServerBehavior, handler: LambdaHandler) throws {
}.wait()
}
class EchoHandler: LambdaHandler {
final class EchoHandler: LambdaHandler {
var initializeCalls = 0
func initialize(callback: @escaping LambdaInitCallBack) {