mirror of
https://github.com/swift-server/async-http-client.git
synced 2026-05-03 07:32:29 +00:00
4b4d6605aa
* Fail if we get a part when the state is `.endOrError`. * Prevent TaskHandler state change after `.endOrError`. Motivation: Right now, if the task handler encounters an error, it changes its state to `.endOrError`. We gate on that state to make sure that we do not process errors in the pipeline twice. Unfortunately, that state can be reset when we upload the body or receive response parts. Modifications: Adds state validation before the state is updated to a new value. Adds a test. Result: Fixes #297.
1150 lines
43 KiB
Swift
1150 lines
43 KiB
Swift
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This source file is part of the AsyncHTTPClient open source project
|
|
//
|
|
// Copyright (c) 2018-2019 Apple Inc. and the AsyncHTTPClient project authors
|
|
// Licensed under Apache License v2.0
|
|
//
|
|
// See LICENSE.txt for license information
|
|
// See CONTRIBUTORS.txt for the list of AsyncHTTPClient project authors
|
|
//
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
import Foundation
|
|
import Logging
|
|
import NIO
|
|
import NIOConcurrencyHelpers
|
|
import NIOFoundationCompat
|
|
import NIOHTTP1
|
|
import NIOHTTPCompression
|
|
import NIOSSL
|
|
|
|
extension HTTPClient {
|
|
/// Represent request body.
public struct Body {
    /// Chunk provider.
    public struct StreamWriter {
        let closure: (IOData) -> EventLoopFuture<Void>

        /// Create new StreamWriter
        ///
        /// - parameters:
        ///     - closure: function that will be called to write actual bytes to the channel.
        public init(closure: @escaping (IOData) -> EventLoopFuture<Void>) {
            self.closure = closure
        }

        /// Write data to server.
        ///
        /// - parameters:
        ///     - data: `IOData` to write.
        /// - returns: The future returned by the underlying write closure.
        public func write(_ data: IOData) -> EventLoopFuture<Void> {
            return self.closure(data)
        }
    }

    /// Body size. Request validation will be failed with `HTTPClientError.contentLengthMissing` if nil,
    /// unless `Transfer-Encoding: chunked` header is set.
    public var length: Int?
    /// Body chunk provider.
    public var stream: (StreamWriter) -> EventLoopFuture<Void>

    /// Create and stream body using `ByteBuffer`.
    ///
    /// - parameters:
    ///     - buffer: Body `ByteBuffer` representation.
    public static func byteBuffer(_ buffer: ByteBuffer) -> Body {
        return Body(length: buffer.readableBytes) { writer in
            writer.write(.byteBuffer(buffer))
        }
    }

    /// Create and stream body using `StreamWriter`.
    ///
    /// - parameters:
    ///     - length: Body size. Request validation will be failed with `HTTPClientError.contentLengthMissing` if nil,
    ///               unless `Transfer-Encoding: chunked` header is set.
    ///     - stream: Body chunk provider.
    public static func stream(length: Int? = nil, _ stream: @escaping (StreamWriter) -> EventLoopFuture<Void>) -> Body {
        return Body(length: length, stream: stream)
    }

    /// Create and stream body using `Data`.
    ///
    /// - parameters:
    ///     - data: Body `Data` representation.
    public static func data(_ data: Data) -> Body {
        return Body(length: data.count) { writer in
            writer.write(.byteBuffer(ByteBuffer(bytes: data)))
        }
    }

    /// Create and stream body using `String`.
    ///
    /// The declared length is the string's UTF-8 byte count.
    ///
    /// - parameters:
    ///     - string: Body `String` representation.
    public static func string(_ string: String) -> Body {
        return Body(length: string.utf8.count) { writer in
            writer.write(.byteBuffer(ByteBuffer(string: string)))
        }
    }
}
|
|
|
|
/// Represent HTTP request.
public struct Request {
    /// Represent kind of Request
    enum Kind: Equatable {
        /// Flavour of UNIX domain socket URL, selected by the URL scheme.
        enum UnixScheme: Equatable {
            /// `unix` scheme: the socket path is the URL's path (or its base URL's path).
            case baseURL
            /// `http+unix` scheme: the socket path is encoded as the URL's host.
            case http_unix
            /// `https+unix` scheme: like `http_unix`, but the request uses TLS.
            case https_unix
        }

        /// Remote host request.
        case host
        /// UNIX Domain Socket HTTP request.
        case unixSocket(_ scheme: UnixScheme)

        // Lookup tables only; they are never modified, so they are constants.
        private static let hostRestrictedSchemes: Set = ["http", "https"]
        private static let allSupportedSchemes: Set = ["http", "https", "unix", "http+unix", "https+unix"]

        /// Determines the request kind from a URL scheme.
        ///
        /// - parameters:
        ///     - scheme: The URL scheme; callers pass it already lowercased.
        /// - throws: `HTTPClientError.unsupportedScheme` for any scheme not listed below.
        init(forScheme scheme: String) throws {
            switch scheme {
            case "http", "https": self = .host
            case "unix": self = .unixSocket(.baseURL)
            case "http+unix": self = .unixSocket(.http_unix)
            case "https+unix": self = .unixSocket(.https_unix)
            default:
                throw HTTPClientError.unsupportedScheme(scheme)
            }
        }

        /// Returns the remote host for `.host` requests, or an empty string for UNIX socket requests.
        ///
        /// - throws: `HTTPClientError.emptyHost` if a `.host` URL has no host.
        func hostFromURL(_ url: URL) throws -> String {
            switch self {
            case .host:
                guard let host = url.host else {
                    throw HTTPClientError.emptyHost
                }
                return host
            case .unixSocket:
                return ""
            }
        }

        /// Returns the UNIX socket path for socket requests, or an empty string for `.host` requests.
        ///
        /// - throws: `HTTPClientError.missingSocketPath` if an `http+unix`/`https+unix` URL has no host.
        func socketPathFromURL(_ url: URL) throws -> String {
            switch self {
            case .unixSocket(.baseURL):
                return url.baseURL?.path ?? url.path
            case .unixSocket:
                guard let socketPath = url.host else {
                    throw HTTPClientError.missingSocketPath
                }
                return socketPath
            case .host:
                return ""
            }
        }

        /// Returns the request target (path and query) to send for `url`.
        func uriFromURL(_ url: URL) -> String {
            switch self {
            case .host:
                return url.uri
            case .unixSocket(.baseURL):
                // Without a base URL the whole URL path is the socket path, so request "/".
                return url.baseURL != nil ? url.uri : "/"
            case .unixSocket:
                return url.uri
            }
        }

        /// Whether a request of this kind may be redirected to a URL with the given scheme
        /// (compared case-insensitively).
        func supportsRedirects(to scheme: String?) -> Bool {
            guard let scheme = scheme?.lowercased() else { return false }

            switch self {
            case .host:
                return Kind.hostRestrictedSchemes.contains(scheme)
            case .unixSocket:
                return Kind.allSupportedSchemes.contains(scheme)
            }
        }
    }

    /// Request HTTP method, defaults to `GET`.
    public let method: HTTPMethod
    /// Remote URL.
    public let url: URL
    /// Remote HTTP scheme, resolved from `URL`.
    public let scheme: String
    /// Remote host, resolved from `URL`. Empty for UNIX domain socket requests.
    public let host: String
    /// Socket path, resolved from `URL`. Empty for remote host requests.
    let socketPath: String
    /// URI composed of the path and query, resolved from `URL`.
    let uri: String
    /// Request custom HTTP Headers, defaults to no headers.
    public var headers: HTTPHeaders
    /// Request body, defaults to no body.
    public var body: Body?

    /// Bookkeeping for redirects followed on behalf of this request.
    struct RedirectState {
        var count: Int
        var visited: Set<URL>?
    }

    var redirectState: RedirectState?
    let kind: Kind

    /// Create HTTP request.
    ///
    /// - parameters:
    ///     - url: Remote `URL`.
    ///     - method: HTTP method.
    ///     - headers: Custom HTTP headers.
    ///     - body: Request body.
    /// - throws:
    ///     - `invalidURL` if URL cannot be parsed.
    ///     - `emptyScheme` if URL does not contain HTTP scheme.
    ///     - `unsupportedScheme` if URL contains an unsupported HTTP scheme.
    ///     - `emptyHost` if URL does not contain a host.
    public init(url: String, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws {
        guard let url = URL(string: url) else {
            throw HTTPClientError.invalidURL
        }

        try self.init(url: url, method: method, headers: headers, body: body)
    }

    /// Create an HTTP `Request`.
    ///
    /// - parameters:
    ///     - url: Remote `URL`.
    ///     - method: HTTP method.
    ///     - headers: Custom HTTP headers.
    ///     - body: Request body.
    /// - throws:
    ///     - `emptyScheme` if URL does not contain HTTP scheme.
    ///     - `unsupportedScheme` if URL contains an unsupported HTTP scheme.
    ///     - `emptyHost` if URL does not contain a host.
    ///     - `missingSocketPath` if URL does not contain a socketPath as an encoded host.
    public init(url: URL, method: HTTPMethod = .GET, headers: HTTPHeaders = HTTPHeaders(), body: Body? = nil) throws {
        guard let scheme = url.scheme?.lowercased() else {
            throw HTTPClientError.emptyScheme
        }

        self.kind = try Kind(forScheme: scheme)
        self.host = try self.kind.hostFromURL(url)
        self.socketPath = try self.kind.socketPathFromURL(url)
        self.uri = self.kind.uriFromURL(url)

        self.redirectState = nil
        self.url = url
        self.method = method
        self.scheme = scheme
        self.headers = headers
        self.body = body
    }

    /// Whether request will be executed using secure socket.
    public var useTLS: Bool {
        return self.scheme == "https" || self.scheme == "https+unix"
    }

    /// Resolved port: the URL's explicit port, otherwise 443 when `useTLS` is true and 80 when it is not.
    public var port: Int {
        return self.url.port ?? (self.useTLS ? 443 : 80)
    }
}
|
|
|
|
/// Represent HTTP response.
public struct Response {
    /// Remote host of the request.
    public var host: String
    /// Response HTTP status.
    public var status: HTTPResponseStatus
    /// Response HTTP version.
    public var version: HTTPVersion
    /// Response HTTP headers.
    public var headers: HTTPHeaders
    /// Response body.
    public var body: ByteBuffer?

    /// Create HTTP `Response`.
    ///
    /// The response version is set to HTTP/1.1.
    ///
    /// - parameters:
    ///     - host: Remote host of the request.
    ///     - status: Response HTTP status.
    ///     - headers: Response HTTP headers.
    ///     - body: Response body.
    @available(*, deprecated, renamed: "init(host:status:version:headers:body:)")
    public init(host: String, status: HTTPResponseStatus, headers: HTTPHeaders, body: ByteBuffer?) {
        self.host = host
        self.status = status
        self.version = HTTPVersion(major: 1, minor: 1)
        self.headers = headers
        self.body = body
    }

    /// Create HTTP `Response`.
    ///
    /// - parameters:
    ///     - host: Remote host of the request.
    ///     - status: Response HTTP status.
    ///     - version: Response HTTP version.
    ///     - headers: Response HTTP headers.
    ///     - body: Response body.
    public init(host: String, status: HTTPResponseStatus, version: HTTPVersion, headers: HTTPHeaders, body: ByteBuffer?) {
        self.host = host
        self.status = status
        self.version = version
        self.headers = headers
        self.body = body
    }
}
|
|
|
|
/// HTTP authentication credentials, rendered as an `Authorization` header value via `headerValue`.
public struct Authorization {
    /// Supported authentication schemes, each carrying its already-encoded credential string.
    private enum Scheme {
        case Basic(String)
        case Bearer(String)
    }

    private let scheme: Scheme

    private init(scheme: Scheme) {
        self.scheme = scheme
    }

    /// `Basic` authorization built from a username and password, which are joined with `:`
    /// and base64-encoded from their UTF-8 bytes.
    public static func basic(username: String, password: String) -> HTTPClient.Authorization {
        let joined = "\(username):\(password)"
        let encoded = Data(joined.utf8).base64EncodedString()
        return .basic(credentials: encoded)
    }

    /// `Basic` authorization using a pre-encoded credential string, used verbatim.
    public static func basic(credentials: String) -> HTTPClient.Authorization {
        let scheme = Scheme.Basic(credentials)
        return HTTPClient.Authorization(scheme: scheme)
    }

    /// `Bearer` authorization using the given token string, used verbatim.
    public static func bearer(tokens: String) -> HTTPClient.Authorization {
        let scheme = Scheme.Bearer(tokens)
        return HTTPClient.Authorization(scheme: scheme)
    }

    /// The value to send in the `Authorization` header: the scheme name, a space, then the credentials.
    public var headerValue: String {
        switch self.scheme {
        case let .Bearer(tokens):
            return "Bearer \(tokens)"
        case let .Basic(credentials):
            return "Basic \(credentials)"
        }
    }
}
|
|
}
|
|
|
|
/// An `HTTPClientResponseDelegate` that buffers the response head and all body parts in memory
/// and returns them as an `HTTPClient.Response` when the request finishes.
public class ResponseAccumulator: HTTPClientResponseDelegate {
    public typealias Response = HTTPClient.Response

    /// Progress of the response being accumulated.
    enum State {
        /// Nothing received yet.
        case idle
        /// Response head received, no body bytes yet.
        case head(HTTPResponseHead)
        /// Response head and at least one body part received; the buffer holds all body bytes so far.
        case body(HTTPResponseHead, ByteBuffer)
        /// Request already processed. NOTE(review): this class never assigns it; callbacks treat it as invalid.
        case end
        /// An error was reported through `didReceiveError`; later parts are ignored.
        case error(Error)
    }

    var state = State.idle
    let request: HTTPClient.Request

    /// - parameters:
    ///     - request: The request whose response is accumulated; its `host` is copied into the result.
    public init(request: HTTPClient.Request) {
        self.request = request
    }

    /// Records the response head. Ignored after an error; receiving a second head is a programmer error.
    public func didReceiveHead(task: HTTPClient.Task<Response>, _ head: HTTPResponseHead) -> EventLoopFuture<Void> {
        switch self.state {
        case .idle:
            self.state = .head(head)
        case .head:
            preconditionFailure("head already set")
        case .body:
            // A head was already received (that is how we got into `.body`), so this is a duplicate head.
            preconditionFailure("head already set")
        case .end:
            preconditionFailure("request already processed")
        case .error:
            break
        }
        return task.eventLoop.makeSucceededFuture(())
    }

    /// Appends `part` to the accumulated body. Ignored after an error.
    public func didReceiveBodyPart(task: HTTPClient.Task<Response>, _ part: ByteBuffer) -> EventLoopFuture<Void> {
        switch self.state {
        case .idle:
            preconditionFailure("no head received before body")
        case .head(let head):
            self.state = .body(head, part)
        case .body(let head, var body):
            var part = part
            body.writeBuffer(&part)
            self.state = .body(head, body)
        case .end:
            preconditionFailure("request already processed")
        case .error:
            break
        }
        return task.eventLoop.makeSucceededFuture(())
    }

    /// Remembers `error` so that `didFinishRequest` rethrows it.
    public func didReceiveError(task: HTTPClient.Task<Response>, _ error: Error) {
        self.state = .error(error)
    }

    /// Builds the final response from the accumulated head and body (`nil` body if none arrived).
    ///
    /// - throws: The error previously passed to `didReceiveError`, if any.
    public func didFinishRequest(task: HTTPClient.Task<Response>) throws -> Response {
        switch self.state {
        case .idle:
            preconditionFailure("no head received before end")
        case .head(let head):
            return Response(host: self.request.host, status: head.status, version: head.version, headers: head.headers, body: nil)
        case .body(let head, let body):
            return Response(host: self.request.host, status: head.status, version: head.version, headers: head.headers, body: body)
        case .end:
            preconditionFailure("request already processed")
        case .error(let error):
            throw error
        }
    }
}
|
|
|
|
/// `HTTPClientResponseDelegate` allows an implementation to receive notifications about request processing and to control how response parts are processed.
/// You can implement this protocol if you need fine-grained control over an HTTP request/response, for example, if you want to inspect the response
/// headers before deciding whether to accept a response body, or if you want to stream your request body. Pass an instance of your conforming
/// class to the `HTTPClient.execute()` method and this package will call each delegate method appropriately as the request takes place.
///
/// - note: This delegate is strongly held by the `TaskHandler`
///         for the duration of the `Request` processing and will be
///         released together with the `TaskHandler` when channel is closed.
///         Users of the library are not required to keep a reference to the
///         object that implements this protocol, but may do so if needed.
public protocol HTTPClientResponseDelegate: AnyObject {
    associatedtype Response

    /// Called when the request head is sent. Will be called once.
    ///
    /// - parameters:
    ///     - task: Current request context.
    ///     - head: Request head.
    func didSendRequestHead(task: HTTPClient.Task<Response>, _ head: HTTPRequestHead)

    /// Called when a part of the request body is sent. Could be called zero or more times.
    ///
    /// - parameters:
    ///     - task: Current request context.
    ///     - part: Request body `Part`.
    func didSendRequestPart(task: HTTPClient.Task<Response>, _ part: IOData)

    /// Called when the request is fully sent. Will be called once.
    ///
    /// - parameters:
    ///     - task: Current request context.
    func didSendRequest(task: HTTPClient.Task<Response>)

    /// Called when response head is received. Will be called once.
    /// You must return an `EventLoopFuture<Void>` that you complete when you have finished processing the head.
    /// You can create an already succeeded future by calling `task.eventLoop.makeSucceededFuture(())`.
    ///
    /// - parameters:
    ///     - task: Current request context.
    ///     - head: Received response head.
    /// - returns: `EventLoopFuture` that will be used for backpressure.
    func didReceiveHead(task: HTTPClient.Task<Response>, _ head: HTTPResponseHead) -> EventLoopFuture<Void>

    /// Called when part of a response body is received. Could be called zero or more times.
    /// You must return an `EventLoopFuture<Void>` that you complete when you have finished processing the body part.
    /// You can create an already succeeded future by calling `task.eventLoop.makeSucceededFuture(())`.
    ///
    /// - parameters:
    ///     - task: Current request context.
    ///     - buffer: Received body `Part`.
    /// - returns: `EventLoopFuture` that will be used for backpressure.
    func didReceiveBodyPart(task: HTTPClient.Task<Response>, _ buffer: ByteBuffer) -> EventLoopFuture<Void>

    /// Called when error was thrown during request execution. Will be called zero or one time only. Request processing will be stopped after that.
    ///
    /// - parameters:
    ///     - task: Current request context.
    ///     - error: Error that occurred during response processing.
    func didReceiveError(task: HTTPClient.Task<Response>, _ error: Error)

    /// Called when the complete HTTP request is finished. You must return an instance of your `Response` associated type. Will be called once, except if an error occurred.
    ///
    /// - parameters:
    ///     - task: Current request context.
    /// - returns: Result of processing.
    func didFinishRequest(task: HTTPClient.Task<Response>) throws -> Response
}
|
|
|
|
extension HTTPClientResponseDelegate {
    // Response-side defaults: accept every part immediately, which applies no backpressure.

    public func didReceiveHead(task: HTTPClient.Task<Response>, _: HTTPResponseHead) -> EventLoopFuture<Void> {
        let loop = task.eventLoop
        return loop.makeSucceededFuture(())
    }

    public func didReceiveBodyPart(task: HTTPClient.Task<Response>, _: ByteBuffer) -> EventLoopFuture<Void> {
        let loop = task.eventLoop
        return loop.makeSucceededFuture(())
    }

    /// Default: ignore the error.
    public func didReceiveError(task: HTTPClient.Task<Response>, _: Error) {}

    // Request-side notifications default to no-ops.

    public func didSendRequestHead(task: HTTPClient.Task<Response>, _ head: HTTPRequestHead) {}

    public func didSendRequestPart(task: HTTPClient.Task<Response>, _ part: IOData) {}

    public func didSendRequest(task: HTTPClient.Task<Response>) {}
}
|
|
|
|
extension URL {
    /// The URL's percent-encoded path, or `"/"` when the path is empty.
    var percentEncodedPath: String {
        if self.path.isEmpty {
            return "/"
        }
        return URLComponents(url: self, resolvingAgainstBaseURL: false)?.percentEncodedPath ?? self.path
    }

    /// The request target: `percentEncodedPath`, followed by `?` and the query when one is present.
    var uri: String {
        var uri = self.percentEncodedPath

        if let query = self.query {
            uri += "?" + query
        }

        return uri
    }

    /// Whether `other` has the same host, scheme and (explicit) port as this URL.
    func hasTheSameOrigin(as other: URL) -> Bool {
        return self.host == other.host && self.scheme == other.scheme && self.port == other.port
    }

    /// Initializes a newly created HTTP URL connecting to a unix domain socket path. The socket path is encoded as the URL's host, replacing percent encoding invalid path characters, and will use the "http+unix" scheme.
    /// - Parameters:
    ///     - socketPath: The path to the unix domain socket to connect to.
    ///     - uri: The URI path and query that will be sent to the server.
    public init?(httpURLWithSocketPath socketPath: String, uri: String = "/") {
        self.init(socketURLScheme: "http+unix", socketPath: socketPath, uri: uri)
    }

    /// Initializes a newly created HTTPS URL connecting to a unix domain socket path over TLS. The socket path is encoded as the URL's host, replacing percent encoding invalid path characters, and will use the "https+unix" scheme.
    /// - Parameters:
    ///     - socketPath: The path to the unix domain socket to connect to.
    ///     - uri: The URI path and query that will be sent to the server.
    public init?(httpsURLWithSocketPath socketPath: String, uri: String = "/") {
        self.init(socketURLScheme: "https+unix", socketPath: socketPath, uri: uri)
    }

    /// Shared implementation of the socket-path initializers: builds `<scheme>://<encoded socket path>/<uri>`.
    /// Returns nil if the socket path cannot be percent-encoded or the resulting string is not a valid URL.
    private init?(socketURLScheme scheme: String, socketPath: String, uri: String) {
        guard let host = socketPath.addingPercentEncoding(withAllowedCharacters: .urlHostAllowed) else { return nil }
        // Insert a "/" between host and uri only when the uri does not already start with one.
        let separator = uri.hasPrefix("/") ? "" : "/"
        self.init(string: "\(scheme)://\(host)\(separator)\(uri)")
    }
}
|
|
|
|
extension HTTPClient {
|
|
/// Response execution context. Will be created by the library and could be used for obtaining
/// `EventLoopFuture<Response>` of the execution or cancellation of the execution.
public final class Task<Response> {
    /// The `EventLoop` the delegate will be executed on.
    public let eventLoop: EventLoop

    /// Completed with the final response, or failed with the request's error.
    let promise: EventLoopPromise<Response>
    /// `promise`'s result with the value discarded.
    var completion: EventLoopFuture<Void>
    /// Connection the request runs on, once attached via `setConnection(_:)`.
    var connection: Connection?
    /// Whether `cancel()` has been called. Read and written under `lock`.
    var cancelled: Bool
    let lock: Lock
    let logger: Logger // We are okay to store the logger here because a Task is for only one request.

    init(eventLoop: EventLoop, logger: Logger) {
        self.eventLoop = eventLoop
        self.promise = eventLoop.makePromise()
        self.completion = self.promise.futureResult.map { _ in }
        self.cancelled = false
        self.lock = Lock()
        self.logger = logger
    }

    /// Creates a task whose promise is already failed with `error`.
    static func failedTask(eventLoop: EventLoop, error: Error, logger: Logger) -> Task<Response> {
        let task = self.init(eventLoop: eventLoop, logger: logger)
        task.promise.fail(error)
        return task
    }

    /// `EventLoopFuture` for the response returned by this request.
    public var futureResult: EventLoopFuture<Response> {
        return self.promise.futureResult
    }

    /// Waits for execution of this request to complete.
    ///
    /// - returns: The value of the `EventLoopFuture` when it completes.
    /// - throws: The error value of the `EventLoopFuture` if it errors.
    public func wait() throws -> Response {
        return try self.promise.futureResult.wait()
    }

    /// Cancels the request execution.
    ///
    /// Only the first call has an effect. If a connection is already attached, a `TaskCancelEvent`
    /// is sent through its channel; otherwise `setConnection(_:)` sends it once a connection arrives.
    public func cancel() {
        let channel: Channel? = self.lock.withLock {
            if !self.cancelled {
                self.cancelled = true
                return self.connection?.channel
            } else {
                return nil
            }
        }
        channel?.triggerUserOutboundEvent(TaskCancelEvent(), promise: nil)
    }

    /// Attaches `connection` to this task and, if the task was already cancelled, sends a
    /// `TaskCancelEvent` through the connection's channel.
    @discardableResult
    func setConnection(_ connection: Connection) -> Connection {
        return self.lock.withLock {
            self.connection = connection
            if self.cancelled {
                connection.channel.triggerUserOutboundEvent(TaskCancelEvent(), promise: nil)
            }
            return connection
        }
    }

    /// Releases the connection and then succeeds `promise` with `value`.
    func succeed<Delegate: HTTPClientResponseDelegate>(promise: EventLoopPromise<Response>?,
                                                       with value: Response,
                                                       delegateType: Delegate.Type,
                                                       closing: Bool) {
        self.releaseAssociatedConnection(delegateType: delegateType,
                                         closing: closing).whenSuccess {
            promise?.succeed(value)
        }
    }

    /// Fails the task with `error`.
    ///
    /// With a connection attached, the connection is released (marked as closing) before the
    /// promise is failed and the channel closed. Without one, the promise is failed directly so
    /// that `futureResult` waiters are never left hanging on a promise nobody will complete.
    func fail<Delegate: HTTPClientResponseDelegate>(with error: Error,
                                                    delegateType: Delegate.Type) {
        guard let connection = self.connection else {
            self.promise.fail(error)
            return
        }

        self.releaseAssociatedConnection(delegateType: delegateType, closing: true)
            .whenSuccess {
                self.promise.fail(error)
                connection.channel.close(promise: nil)
            }
    }

    /// Removes the idle-state and task handlers from the connection's pipeline and releases the
    /// connection. Completes immediately if there is no connection.
    func releaseAssociatedConnection<Delegate: HTTPClientResponseDelegate>(delegateType: Delegate.Type,
                                                                           closing: Bool) -> EventLoopFuture<Void> {
        if let connection = self.connection {
            // remove read timeout handler
            return connection.removeHandler(IdleStateHandler.self).flatMap {
                connection.removeHandler(TaskHandler<Delegate>.self)
            }.map {
                connection.release(closing: closing, logger: self.logger)
            }.flatMapError { error in
                fatalError("Couldn't remove taskHandler: \(error)")
            }

        } else {
            // TODO: This seems only reached in some internal unit test
            // Maybe there could be a better handling in the future to make
            // it an error outside of testing contexts
            return self.eventLoop.makeSucceededFuture(())
        }
    }
}
|
|
}
|
|
|
|
/// Outbound user event that `HTTPClient.Task` sends down a connection's channel pipeline when the task is cancelled.
internal struct TaskCancelEvent {
    // Marker type: it carries no payload; the event's type alone conveys the cancellation.
}
|
|
|
|
// MARK: - TaskHandler
|
|
|
|
/// Channel handler that drives a single `HTTPClient.Request`: it writes the request and forwards
/// response parts to `delegate`, completing `task` at the end.
internal class TaskHandler<Delegate: HTTPClientResponseDelegate>: RemovableChannelHandler {
    /// Progress of the request/response exchange handled by this instance.
    enum State {
        /// A request write has started; the body (if any) is being uploaded.
        case idle
        /// The request body has been written.
        case bodySent
        /// The request end has been written.
        case sent
        /// The response head has been handed to the delegate.
        case head
        /// The response head is a redirect; holds the head and target URL until the response ends.
        case redirected(HTTPResponseHead, URL)
        /// Response body parts are being delivered to the delegate.
        case body
        /// The response ended or the request failed; callbacks check this to avoid processing twice.
        case endOrError
    }

    let task: HTTPClient.Task<Delegate.Response>
    let delegate: Delegate
    let redirectHandler: RedirectHandler<Delegate.Response>?
    let ignoreUncleanSSLShutdown: Bool
    let logger: Logger // We are okay to store the logger here because a TaskHandler is just for one request.

    var state: State = .idle
    /// Body length announced by the request's `content-length` header, if any.
    var expectedBodyLength: Int?
    /// Number of body bytes written so far.
    var actualBodyLength: Int = 0
    /// Whether a `read()` arrived while reads were paused (`mayRead == false`).
    var pendingRead = false
    /// Cleared while the delegate is processing a response part (backpressure).
    var mayRead = true
    /// Whether the connection must be closed instead of reused. Once set, it may never be unset.
    var closing = false {
        didSet {
            assert(self.closing || !oldValue,
                   "BUG in AsyncHTTPClient: TaskHandler.closing went from true (no conn reuse) to false (do reuse).")
        }
    }

    let kind: HTTPClient.Request.Kind

    init(task: HTTPClient.Task<Delegate.Response>,
         kind: HTTPClient.Request.Kind,
         delegate: Delegate,
         redirectHandler: RedirectHandler<Delegate.Response>?,
         ignoreUncleanSSLShutdown: Bool,
         logger: Logger) {
        self.task = task
        self.delegate = delegate
        self.redirectHandler = redirectHandler
        self.ignoreUncleanSSLShutdown = ignoreUncleanSSLShutdown
        self.kind = kind
        self.logger = logger
    }
}
|
|
|
|
// MARK: Delegate Callouts
|
|
|
|
extension TaskHandler {
    /// Runs `body` (typically a delegate error callback) with `error` on the task's event loop,
    /// then fails the task with the same error.
    func failTaskAndNotifyDelegate<Err: Error>(error: Err,
                                               _ body: @escaping (HTTPClient.Task<Delegate.Response>, Err) -> Void) {
        func doIt() {
            body(self.task, error)
            self.task.fail(with: error, delegateType: Delegate.self)
        }

        if self.task.eventLoop.inEventLoop {
            doIt()
        } else {
            self.task.eventLoop.execute {
                doIt()
            }
        }
    }

    /// Runs a value-less delegate callout on the task's event loop without waiting for it.
    func callOutToDelegateFireAndForget(_ body: @escaping (HTTPClient.Task<Delegate.Response>) -> Void) {
        self.callOutToDelegateFireAndForget(value: ()) { (task, _: ()) in body(task) }
    }

    /// Runs `body(task, value)` on the task's event loop without waiting for it.
    func callOutToDelegateFireAndForget<Value>(value: Value,
                                               _ body: @escaping (HTTPClient.Task<Delegate.Response>, Value) -> Void) {
        if self.task.eventLoop.inEventLoop {
            body(self.task, value)
        } else {
            self.task.eventLoop.execute {
                body(self.task, value)
            }
        }
    }

    /// Runs `body(task, value)` on the task's event loop and returns its future hopped back to
    /// `channelEventLoop`, so callers can react (e.g. for backpressure) on the channel's loop.
    func callOutToDelegate<Value>(value: Value,
                                  channelEventLoop: EventLoop,
                                  _ body: @escaping (HTTPClient.Task<Delegate.Response>, Value) -> EventLoopFuture<Void>) -> EventLoopFuture<Void> {
        if self.task.eventLoop.inEventLoop {
            return body(self.task, value).hop(to: channelEventLoop)
        } else {
            return self.task.eventLoop.flatSubmit {
                body(self.task, value)
            }.hop(to: channelEventLoop)
        }
    }

    /// Runs `body` on the task's event loop. On success the task's connection is released and
    /// `promise` is succeeded with the result; if `body` throws, the task is failed.
    func callOutToDelegate<Response>(promise: EventLoopPromise<Response>? = nil,
                                     _ body: @escaping (HTTPClient.Task<Delegate.Response>) throws -> Response) where Response == Delegate.Response {
        func doIt() {
            do {
                let result = try body(self.task)

                self.task.succeed(promise: promise,
                                  with: result,
                                  delegateType: Delegate.self,
                                  closing: self.closing)
            } catch {
                self.task.fail(with: error, delegateType: Delegate.self)
            }
        }

        if self.task.eventLoop.inEventLoop {
            doIt()
        } else {
            self.task.eventLoop.submit {
                doIt()
            }.cascadeFailure(to: promise)
        }
    }

    /// Variant of `callOutToDelegate(promise:_:)` that creates the promise on `channelEventLoop`
    /// and returns its future.
    func callOutToDelegate<Response>(channelEventLoop: EventLoop,
                                     _ body: @escaping (HTTPClient.Task<Delegate.Response>) throws -> Response) -> EventLoopFuture<Response> where Response == Delegate.Response {
        let promise = channelEventLoop.makePromise(of: Response.self)
        self.callOutToDelegate(promise: promise, body)
        return promise.futureResult
    }
}
|
|
|
|
// MARK: ChannelHandler implementation
|
|
|
|
extension TaskHandler: ChannelDuplexHandler {
|
|
// Inbound: decoded HTTP response parts coming from the server.
typealias InboundIn = HTTPClientResponsePart
// Outbound: whole requests in from the client, serialized request parts out to the encoder.
typealias OutboundIn = HTTPClient.Request
typealias OutboundOut = HTTPClientRequestPart
|
|
|
|
/// Writes an `HTTPClient.Request` to the channel.
///
/// The method:
/// - builds the HTTP/1.1 request head, adding a `host` header if missing;
/// - validates the headers;
/// - writes the head, then streams the body, then writes `.end`;
/// - notifies the delegate after each stage.
///
/// Failures are routed through `errorCaught` and also fail `promise`.
///
/// The state only moves forward (`.idle` -> `.bodySent` -> `.sent`). Response parts can arrive
/// while the body is still uploading, moving the state to `.head`, `.body` or `.redirected`.
/// In that case the state is left untouched so the response-side progress (e.g. a pending
/// redirect) is not lost. `.endOrError` is never left.
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
    self.state = .idle
    let request = self.unwrapOutboundIn(data)

    var head = HTTPRequestHead(version: HTTPVersion(major: 1, minor: 1),
                               method: request.method,
                               uri: request.uri)
    var headers = request.headers

    if !request.headers.contains(name: "host") {
        let port = request.port
        var host = request.host
        // Default ports are omitted from the host header.
        if !(port == 80 && request.scheme == "http"), !(port == 443 && request.scheme == "https") {
            host += ":\(port)"
        }
        headers.add(name: "host", value: host)
    }

    do {
        try headers.validate(method: request.method, body: request.body)
    } catch {
        self.errorCaught(context: context, error: error)
        promise?.fail(error)
        return
    }

    head.headers = headers

    if head.headers[canonicalForm: "connection"].map({ $0.lowercased() }).contains("close") {
        self.closing = true
    }
    // This assert can go away when (if ever!) the above `if` correctly handles other HTTP versions. For example
    // in HTTP/1.0, we need to treat the absence of a 'connection: keep-alive' as a close too.
    assert(head.version == HTTPVersion(major: 1, minor: 1),
           "Sending a request in HTTP version \(head.version) which is unsupported by the above `if`")

    let contentLengths = head.headers[canonicalForm: "content-length"]
    assert(contentLengths.count <= 1)

    self.expectedBodyLength = contentLengths.first.flatMap { Int($0) }

    context.write(self.wrapOutboundOut(.head(head))).map {
        self.callOutToDelegateFireAndForget(value: head, self.delegate.didSendRequestHead)
    }.flatMap {
        self.writeBody(request: request, context: context)
    }.flatMap {
        context.eventLoop.assertInEventLoop()
        switch self.state {
        case .endOrError:
            return context.eventLoop.makeSucceededFuture(())
        case .idle:
            self.state = .bodySent
        case .bodySent, .sent, .head, .redirected, .body:
            // The response has already started arriving; do not rewind the state machine.
            break
        }

        if let expectedBodyLength = self.expectedBodyLength, expectedBodyLength != self.actualBodyLength {
            let error = HTTPClientError.bodyLengthMismatch
            return context.eventLoop.makeFailedFuture(error)
        }
        return context.writeAndFlush(self.wrapOutboundOut(.end(nil)))
    }.map {
        context.eventLoop.assertInEventLoop()
        switch self.state {
        case .endOrError:
            return
        case .bodySent:
            self.state = .sent
        case .idle, .sent, .head, .redirected, .body:
            // Same as above: keep any response-side progress intact.
            break
        }

        self.callOutToDelegateFireAndForget(self.delegate.didSendRequest)
    }.flatMapErrorThrowing { error in
        context.eventLoop.assertInEventLoop()
        self.errorCaught(context: context, error: error)
        throw error
    }.cascade(to: promise)
}
|
|
|
|
/// Streams `request.body` (if any) through a `StreamWriter`.
///
/// Each part is written on the channel's event loop. The user's stream closure is started on the
/// task's event loop. After each part's write completes, the delegate's `didSendRequestPart` is notified.
private func writeBody(request: HTTPClient.Request, context: ChannelHandlerContext) -> EventLoopFuture<Void> {
    guard let body = request.body else {
        return context.eventLoop.makeSucceededFuture(())
    }

    let channel = context.channel

    let writer = HTTPClient.Body.StreamWriter { part in
        let partWritten = self.task.eventLoop.makePromise(of: Void.self)

        // All writes have to be switched to the channel EL if channel and task ELs differ
        if channel.eventLoop.inEventLoop {
            self.writeBodyPart(context: context, part: part, promise: partWritten)
        } else {
            channel.eventLoop.execute {
                self.writeBodyPart(context: context, part: part, promise: partWritten)
            }
        }

        return partWritten.futureResult.map {
            self.callOutToDelegateFireAndForget(value: part, self.delegate.didSendRequestPart)
        }
    }

    // Callout to the user to start body streaming should be on task EL
    guard !self.task.eventLoop.inEventLoop else {
        return body.stream(writer)
    }
    return self.task.eventLoop.flatSubmit {
        body.stream(writer)
    }
}
|
|
|
|
/// Writes one body part, failing `promise` if the request is no longer in the `.idle` (uploading) state
/// or if the part would push the body past the announced `content-length`.
private func writeBodyPart(context: ChannelHandlerContext, part: IOData, promise: EventLoopPromise<Void>) {
    guard case .idle = self.state else {
        let error = HTTPClientError.writeAfterRequestSent
        self.errorCaught(context: context, error: error)
        promise.fail(error)
        return
    }

    let lengthAfterWrite = self.actualBodyLength + part.readableBytes
    if let limit = self.expectedBodyLength, lengthAfterWrite > limit {
        let error = HTTPClientError.bodyLengthMismatch
        self.errorCaught(context: context, error: error)
        promise.fail(error)
        return
    }

    self.actualBodyLength = lengthAfterWrite
    context.writeAndFlush(self.wrapOutboundOut(.body(part)), promise: promise)
}
|
|
|
|
/// Forwards a read request unless a delegate callback is still outstanding.
///
/// While `mayRead` is `false` the request is remembered in `pendingRead`, so that it can
/// be issued once the delegate has finished processing the previous response part.
public func read(context: ChannelHandlerContext) {
    guard self.mayRead else {
        self.pendingRead = true
        return
    }
    self.pendingRead = false
    context.read()
}
|
|
|
|
/// Drives the response state machine for each inbound response part.
///
/// - `.head`: records whether the connection must close. It then either remembers a
///   redirect target, or hands the head to the delegate and pauses reads until the
///   delegate's future completes.
/// - `.body`: hands the part to the delegate (pausing reads), unless the response is a
///   redirect or the task has already finished or failed.
/// - `.end`: moves to `.endOrError`. It then either follows the pending redirect after
///   releasing the connection, or completes the task via `didFinishRequest`.
///
/// Parts arriving once the state is `.endOrError` are ignored.
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
    switch self.unwrapInboundIn(data) {
    case .head(let head):
        if case .endOrError = self.state {
            return
        }

        if !head.isKeepAlive {
            self.closing = true
        }

        if let redirectURL = self.redirectHandler?.redirectTarget(status: head.status, headers: head.headers) {
            // The redirect is followed once the `.end` part of this response arrives.
            self.state = .redirected(head, redirectURL)
            return
        }

        self.state = .head
        self.mayRead = false
        self.callOutToDelegate(value: head, channelEventLoop: context.eventLoop, self.delegate.didReceiveHead)
            .whenComplete { self.handleBackpressureResult(context: context, result: $0) }

    case .body(let part):
        // Bodies of redirect responses, and anything after completion, are discarded.
        let dropPart: Bool
        switch self.state {
        case .redirected, .endOrError:
            dropPart = true
        default:
            dropPart = false
        }
        guard !dropPart else {
            return
        }

        self.state = .body
        self.mayRead = false
        self.callOutToDelegate(value: part, channelEventLoop: context.eventLoop, self.delegate.didReceiveBodyPart)
            .whenComplete { self.handleBackpressureResult(context: context, result: $0) }

    case .end:
        let previousState = self.state
        if case .endOrError = previousState {
            return
        }

        self.state = .endOrError
        if case .redirected(let head, let redirectURL) = previousState {
            self.task.releaseAssociatedConnection(delegateType: Delegate.self, closing: self.closing).whenSuccess {
                self.redirectHandler?.redirect(status: head.status, to: redirectURL, promise: self.task.promise)
            }
        } else {
            self.callOutToDelegate(promise: self.task.promise, self.delegate.didFinishRequest)
        }
    }
}
|
|
|
|
/// Resumes reading after a delegate callback completes, or fails the task if it failed.
///
/// Must be called on the channel's event loop.
///
/// - parameters:
///     - context: The handler context used to issue the deferred read.
///     - result: The outcome of the delegate callback.
private func handleBackpressureResult(context: ChannelHandlerContext, result: Result<Void, Error>) {
    context.eventLoop.assertInEventLoop()
    switch result {
    case .success:
        self.mayRead = true
        if self.pendingRead {
            // The deferred read is being serviced now. Clear the flag (as `read(context:)`
            // does) so that a later callback completion does not issue a second,
            // unrequested read and defeat backpressure.
            self.pendingRead = false
            context.read()
        }
    case .failure(let error):
        self.errorCaught(context: context, error: error)
    }
}
|
|
|
|
/// Turns a read-idle event into a `HTTPClientError.readTimeout` failure.
///
/// All other events, including non-read idle events, are forwarded unchanged.
func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) {
    guard let idleEvent = event as? IdleStateHandler.IdleStateEvent, idleEvent == .read else {
        context.fireUserInboundEventTriggered(event)
        return
    }
    self.errorCaught(context: context, error: HTTPClientError.readTimeout)
}
|
|
|
|
/// Handles a `TaskCancelEvent` by failing the task with `HTTPClientError.cancelled`.
///
/// The event's promise is then succeeded. Any other outbound event is passed along the
/// pipeline together with its promise.
func triggerUserOutboundEvent(context: ChannelHandlerContext, event: Any, promise: EventLoopPromise<Void>?) {
    guard event is TaskCancelEvent else {
        context.triggerUserOutboundEvent(event, promise: promise)
        return
    }
    self.errorCaught(context: context, error: HTTPClientError.cancelled)
    promise?.succeed(())
}
|
|
|
|
/// Fails an unfinished task with `HTTPClientError.remoteConnectionClosed`.
///
/// In every state the inactive event is forwarded to the next handler afterwards.
func channelInactive(context: ChannelHandlerContext) {
    defer { context.fireChannelInactive() }

    switch self.state {
    case .idle, .sent, .bodySent, .head, .body, .redirected:
        self.errorCaught(context: context, error: HTTPClientError.remoteConnectionClosed)
    case .endOrError:
        // Already completed or failed; nothing further to report.
        return
    }
}
|
|
|
|
/// Fails the task with `error`, at most once.
///
/// Once the state is `.endOrError`, further errors are ignored. An unclean TLS shutdown
/// (`NIOSSLError.uncleanShutdown`) is also ignored while in `.head` or `.body` when
/// `ignoreUncleanSSLShutdown` is set. Otherwise the state moves to `.endOrError` and the
/// task is failed, with the delegate notified via `didReceiveError`.
func errorCaught(context: ChannelHandlerContext, error: Error) {
    let isUncleanSSLShutdown: Bool
    switch error {
    case NIOSSLError.uncleanShutdown:
        isUncleanSSLShutdown = true
    default:
        isUncleanSSLShutdown = false
    }

    switch self.state {
    case .endOrError:
        // The task was already completed or failed, so the error is not reported again.
        // This also covers servers that close the connection without a TLS close_notify
        // after the request has been fully processed.
        return
    case .head, .body:
        if isUncleanSSLShutdown && self.ignoreUncleanSSLShutdown {
            // Configured to tolerate an unclean TLS shutdown while a response is in flight.
            return
        }
    case .idle, .sent, .bodySent, .redirected:
        break
    }

    self.state = .endOrError
    self.failTaskAndNotifyDelegate(error: error, self.delegate.didReceiveError)
}
|
|
|
|
/// Fails the task with `HTTPClientError.remoteConnectionClosed` if the handler is added
/// to a channel that is no longer active.
func handlerAdded(context: ChannelHandlerContext) {
    if !context.channel.isActive {
        self.errorCaught(context: context, error: HTTPClientError.remoteConnectionClosed)
    }
}
|
|
}
|
|
|
|
// MARK: - RedirectHandler
|
|
|
|
/// Decides whether a response is a redirect and, if so, re-issues the request.
internal struct RedirectHandler<ResponseType> {
    /// The request whose response may be redirected.
    let request: HTTPClient.Request
    /// Executes a follow-up request and returns its task.
    let execute: (HTTPClient.Request) -> HTTPClient.Task<ResponseType>

    /// Returns the absolute redirect target for a response, or `nil` if it should not be followed.
    ///
    /// A target is returned only when all of the following hold:
    /// - the status is one of the redirect statuses listed below;
    /// - a `Location` header is present and parses as a URL (relative to the request URL);
    /// - the request kind supports the target's scheme;
    /// - the target is not a file URL.
    func redirectTarget(status: HTTPResponseStatus, headers: HTTPHeaders) -> URL? {
        let isRedirectStatus: Bool
        switch status {
        case .movedPermanently, .found, .seeOther, .notModified, .useProxy, .temporaryRedirect, .permanentRedirect:
            isRedirectStatus = true
        default:
            isRedirectStatus = false
        }

        guard
            isRedirectStatus,
            let location = headers.first(name: "Location"),
            let target = URL(string: location, relativeTo: self.request.url),
            self.request.kind.supportsRedirects(to: target.scheme),
            !target.isFileURL
        else {
            return nil
        }

        return target.absoluteURL
    }

    /// Follows a redirect to `redirectURL`, completing `promise` with the follow-up's result.
    ///
    /// When the request carries a redirect state, the promise is failed in two cases:
    /// - `HTTPClientError.redirectLimitReached` once the remaining count is exhausted;
    /// - `HTTPClientError.redirectCycleDetected` if the target was already visited.
    ///
    /// The follow-up request is rewritten as follows:
    /// - a 303 on a non-HEAD request, or a 301/302 on a POST, becomes a body-less GET
    ///   without `Content-Length`/`Content-Type`;
    /// - credentials and origin headers are dropped when the target has a different origin.
    ///
    /// If the follow-up request cannot be constructed, the promise is failed with that error.
    func redirect(status: HTTPResponseStatus, to redirectURL: URL, promise: EventLoopPromise<ResponseType>) {
        var nextState = self.request.redirectState
        if var state = nextState {
            guard state.count > 0 else {
                promise.fail(HTTPClientError.redirectLimitReached)
                return
            }
            state.count -= 1

            if state.visited?.contains(redirectURL) == true {
                promise.fail(HTTPClientError.redirectCycleDetected)
                return
            }
            state.visited?.insert(redirectURL)

            nextState = state
        }

        let convertToGet =
            (status == .seeOther && self.request.method != .HEAD) ||
            ((status == .movedPermanently || status == .found) && self.request.method == .POST)

        var method = self.request.method
        var headers = self.request.headers
        var body = self.request.body

        if convertToGet {
            method = .GET
            body = nil
            headers.remove(name: "Content-Length")
            headers.remove(name: "Content-Type")
        }

        if !self.request.url.hasTheSameOrigin(as: redirectURL) {
            for sensitiveHeader in ["Origin", "Cookie", "Authorization", "Proxy-Authorization"] {
                headers.remove(name: sensitiveHeader)
            }
        }

        var followUpRequest: HTTPClient.Request
        do {
            followUpRequest = try HTTPClient.Request(url: redirectURL, method: method, headers: headers, body: body)
        } catch {
            promise.fail(error)
            return
        }
        followUpRequest.redirectState = nextState

        self.execute(followUpRequest).futureResult.whenComplete { result in
            // Complete the caller's promise on the promise's own event loop.
            promise.futureResult.eventLoop.execute {
                promise.completeWith(result)
            }
        }
    }
}
|