362 lines
12 KiB
Swift
362 lines
12 KiB
Swift
// This source file is part of the Swift.org Server APIs open source project
//
// Copyright (c) 2017 Swift Server API project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See http://swift.org/LICENSE.txt for license information
//
|
|
|
|
import Foundation
|
|
import Dispatch
|
|
import ServerSecurity
|
|
|
|
/// Errors thrown by `PoCSocket` when a sanity check or a socket system call fails.
///
/// :nodoc:
public enum PoCSocketError: Error {
    /// A socket system call failed; carries the `errno` value reported for that failure.
    case SocketOSError(errno: Int32)

    /// Sanity-check failures raised before any system call is made:
    /// - `InvalidSocketError`: the descriptor is not usable for the requested operation.
    /// - `InvalidReadLengthError`: the requested read length is out of range.
    /// - `InvalidWriteLengthError`: the requested write length is out of range.
    /// - `InvalidBufferError`: the supplied buffer is unusable.
    case InvalidSocketError, InvalidReadLengthError, InvalidWriteLengthError, InvalidBufferError
}
|
|
|
|
/// Simple wrapper around the `socket(2)` functions we need for Proof of Concept testing.
///
/// Intentionally a thin layer over `recv(2)`/`send(2)`, so it uses the same argument types.
/// Note that no method names here are the same as any system call names.
/// This is because we expect the caller might need functionality we haven't implemented here.
internal class PoCSocket: ConnectionDelegate {

    /// File descriptor for the socket supplied by the OS. `-1` means there is no valid socket.
    internal var socketfd: Int32 = -1

    /// The TCP port the server is actually listening on. Set once `bindAndListen` succeeds.
    internal var listeningPort: Int32 = -1

    /// Track state between `listen(2)` and `shutdown(2)`.
    private let _isListeningLock = DispatchSemaphore(value: 1)
    private var _isListening: Bool = false
    internal private(set) var isListening: Bool {
        get {
            _isListeningLock.wait()
            defer { _isListeningLock.signal() }
            return _isListening
        }
        set {
            _isListeningLock.wait()
            defer { _isListeningLock.signal() }
            _isListening = newValue
        }
    }

    /// Track state between `accept(2)`/`bind(2)` and `close(2)`.
    private let _isConnectedLock = DispatchSemaphore(value: 1)
    private var _isConnected: Bool = false
    internal private(set) var isConnected: Bool {
        get {
            _isConnectedLock.wait()
            defer { _isConnectedLock.signal() }
            return _isConnected
        }
        set {
            _isConnectedLock.wait()
            defer { _isConnectedLock.signal() }
            _isConnected = newValue
        }
    }

    /// Track whether a shutdown is in progress so `acceptClientConnection` can return `nil`
    /// instead of reporting the accept errors a shutdown causes.
    private let _isShuttingDownLock = DispatchSemaphore(value: 1)
    private var _isShuttingDown: Bool = false
    private var isShuttingDown: Bool {
        get {
            _isShuttingDownLock.wait()
            defer { _isShuttingDownLock.signal() }
            return _isShuttingDown
        }
        set {
            _isShuttingDownLock.wait()
            defer { _isShuttingDownLock.signal() }
            _isShuttingDown = newValue
        }
    }

    /// Delegate that provides the TLS implementation; when `nil`, reads and writes go
    /// straight to `recv(2)`/`send(2)`.
    public var TLSdelegate: TLSServiceDelegate? = nil

    /// Return the file descriptor as a connection endpoint for ConnectionDelegate.
    public var endpoint: ConnectionType {
        return ConnectionType.socket(self.socketfd)
    }

    /// Track whether the socket has already been closed. Only accessed through `markClosed()`
    /// so that the check and the update happen under a single lock acquisition.
    private let _hasClosedLock = DispatchSemaphore(value: 1)
    private var _hasClosed: Bool = false

    /// Atomically marks the socket as closed.
    ///
    /// - Returns: `true` if this call performed the transition (the caller is now responsible
    ///   for closing the descriptor), `false` if the socket was already marked closed.
    private func markClosed() -> Bool {
        _hasClosedLock.wait()
        defer { _hasClosedLock.signal() }
        if _hasClosed {
            return false
        }
        _hasClosed = true
        return true
    }

    /// Call recv(2) with buffer allocated by our caller and return the output
    ///
    /// - Parameters:
    ///   - readBuffer: Buffer to read into. Note this needs to be `inout` because we're modifying it
    ///     and we want Swift 4+'s ownership checks to make sure no one else is at the same time
    ///   - maxLength: Max length that can be read. Buffer *must* be at least this big!!!
    /// - Returns: Number of bytes read or -1 on failure as per `recv(2)`
    /// - Throws: `PoCSocketError.InvalidReadLengthError` if `maxLength` is not in `1...Int32.max`,
    ///   `PoCSocketError.InvalidSocketError` if there is no valid descriptor, and anything thrown
    ///   by the TLS delegate's `willReceive`.
    internal func socketRead(into readBuffer: inout UnsafeMutablePointer<Int8>, maxLength: Int) throws -> Int {
        if maxLength <= 0 || maxLength > Int(Int32.max) {
            throw PoCSocketError.InvalidReadLengthError
        }
        if socketfd < 0 {
            throw PoCSocketError.InvalidSocketError
        }

        // `readBuffer` is a non-optional pointer, so Swift already guarantees it is not nil.
        // The caller remains responsible for it covering at least `maxLength` elements.

        // Make sure data from a previous read is never re-used.
        readBuffer.initialize(repeating: 0, count: maxLength)

        let read: Int
        if let tls = self.TLSdelegate {
            // HTTPS
            read = try tls.willReceive(into: readBuffer, bufSize: maxLength)
        } else {
            // HTTP
            read = recv(self.socketfd, readBuffer, maxLength, Int32(0))
        }
        // Leave this as a local variable to facilitate setting a watchpoint in lldb
        return read
    }

    /// Pass the buffer handed to us into send(2).
    ///
    /// - Parameters:
    ///   - buffer: buffer containing data to write.
    ///   - bufSize: number of bytes to write. Buffer must be this long
    /// - Returns: number of bytes written or -1. See `send(2)`
    /// - Throws: `PoCSocketError.InvalidSocketError` if there is no valid descriptor,
    ///   `PoCSocketError.InvalidWriteLengthError` if `bufSize` is not in `0...Int32.max`,
    ///   and anything thrown by the TLS delegate's `willSend`.
    @discardableResult internal func socketWrite(from buffer: UnsafeRawPointer, bufSize: Int) throws -> Int {
        if socketfd < 0 {
            throw PoCSocketError.InvalidSocketError
        }
        if bufSize < 0 || bufSize > Int(Int32.max) {
            throw PoCSocketError.InvalidWriteLengthError
        }

        // `buffer` is a non-optional pointer, so Swift already guarantees it is not nil.

        let sent: Int
        if let tls = self.TLSdelegate {
            // HTTPS
            sent = try tls.willSend(buffer: buffer, bufSize: bufSize)
        } else {
            // HTTP
            sent = send(self.socketfd, buffer, bufSize, Int32(0))
        }
        // Leave this as a local variable to facilitate setting a watchpoint in lldb
        return sent
    }

    /// Calls `shutdown(2)` (if the socket is listening or connected) and `close(2)` on the socket.
    ///
    /// Safe to call more than once and from several threads: the descriptor is closed at most once.
    internal func shutdownAndClose() {
        self.isShuttingDown = true
        if let tls = self.TLSdelegate {
            tls.willDestroy()
        }

        // Nothing to do if no socket was ever opened.
        guard socketfd >= 0 else {
            return
        }
        // Atomically claim the close. A separate check followed by a later update would let two
        // callers both close the same descriptor number, which the OS may already have handed
        // out to an unrelated socket after the first close.
        guard markClosed() else {
            return
        }

        if self.isListening || self.isConnected {
            _ = shutdown(self.socketfd, Int32(SHUT_RDWR))
            self.isListening = false
        }
        self.isConnected = false
        _ = close(self.socketfd)
    }

    /// Thin wrapper around `accept(2)`
    ///
    /// - Returns: PoCSocket object for newly connected socket or nil if we've been told to shutdown
    /// - Throws: `PoCSocketError.InvalidSocketError` if this socket is not listening,
    ///   `PoCSocketError.SocketOSError` if accept still fails after several retries, and anything
    ///   thrown by the TLS delegate's `didAccept` (in which case the new connection is closed).
    internal func acceptClientConnection() throws -> PoCSocket? {
        if socketfd < 0 || !isListening {
            throw PoCSocketError.InvalidSocketError
        }

        var remainingRetries = 100
        var acceptFD: Int32 = -1
        var acceptErrno: Int32 = 0

        repeat {
            var acceptAddr = sockaddr_in()
            var addrSize = socklen_t(MemoryLayout<sockaddr_in>.size)

            acceptFD = withUnsafeMutablePointer(to: &acceptAddr) { addrPointer in
                addrPointer.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPointer in
                    accept(self.socketfd, sockaddrPointer, &addrSize)
                }
            }
            // Capture errno immediately: the locked getter and print() below may overwrite it.
            acceptErrno = acceptFD < 0 ? errno : 0

            if acceptFD < 0 && acceptErrno != EINTR {
                // fail
                if isShuttingDown {
                    return nil
                }
                remainingRetries -= 1
                print("Could not accept on socket \(socketfd). Error is \(acceptErrno). Will retry.")
            }
        } while acceptFD < 0 && remainingRetries > 0

        if acceptFD < 0 {
            throw PoCSocketError.SocketOSError(errno: acceptErrno)
        }

        let retVal = PoCSocket()
        retVal.isConnected = true
        retVal.socketfd = acceptFD

        // TLS delegate does post accept handling and verification
        if let tls = self.TLSdelegate {
            do {
                try tls.didAccept(connection: retVal)
            } catch {
                // The caller never receives `retVal`, so close it here rather than leak the fd.
                retVal.shutdownAndClose()
                throw error
            }
        }

        return retVal
    }

    /// Create a TCP socket, then call `bind(2)` and `listen(2)` on it.
    ///
    /// On failure the newly created descriptor is closed and `socketfd` is reset to `-1`.
    ///
    /// - Parameters:
    ///   - port: port to bind in host byte order (`0` lets the OS choose); see `bind(2)`
    ///   - maxBacklogSize: backlog argument to `listen(2)`
    /// - Throws: `PoCSocketError.SocketOSError(errno: EINVAL)` if `port` is outside `0...65535`,
    ///   `PoCSocketError.InvalidSocketError` if `socket(2)` fails, `PoCSocketError.SocketOSError`
    ///   if a later system call fails, and anything thrown by the TLS delegate's `didCreateServer`.
    internal func bindAndListen(on port: Int = 0, maxBacklogSize: Int32 = 100) throws {
        // `sin_port` is 16 bits wide; reject values that would otherwise trap in `UInt16(port)`.
        guard port >= 0 && port <= Int(UInt16.max) else {
            throw PoCSocketError.SocketOSError(errno: EINVAL)
        }

        #if os(Linux)
        socketfd = socket(Int32(AF_INET), Int32(SOCK_STREAM.rawValue), Int32(IPPROTO_TCP))
        #else
        socketfd = socket(Int32(AF_INET), Int32(SOCK_STREAM), Int32(IPPROTO_TCP))
        #endif

        if socketfd < 0 {
            throw PoCSocketError.InvalidSocketError
        }

        do {
            try configureAndListen(port: UInt16(port), maxBacklogSize: maxBacklogSize)
        } catch {
            // Don't leak the descriptor, and don't leave a stale number behind for a later close.
            _ = close(self.socketfd)
            self.socketfd = -1
            throw error
        }
    }

    /// Initializes the TLS delegate, sets the reuse options, binds `socketfd` to `port` on all
    /// interfaces, starts listening and records the port actually bound in `listeningPort`.
    /// Assumes `socketfd` holds a freshly created socket.
    private func configureAndListen(port: UInt16, maxBacklogSize: Int32) throws {
        // Initialize delegate
        if let tls = self.TLSdelegate {
            try tls.didCreateServer()
        }

        var on: Int32 = 1
        // Allow address reuse
        if setsockopt(self.socketfd, SOL_SOCKET, SO_REUSEADDR, &on, socklen_t(MemoryLayout<Int32>.size)) < 0 {
            throw PoCSocketError.SocketOSError(errno: errno)
        }
        // Allow port reuse
        if setsockopt(self.socketfd, SOL_SOCKET, SO_REUSEPORT, &on, socklen_t(MemoryLayout<Int32>.size)) < 0 {
            throw PoCSocketError.SocketOSError(errno: errno)
        }

        // `bigEndian` converts host order to network order on every platform (same as htons).
        #if os(Linux)
        var addr = sockaddr_in(
            sin_family: sa_family_t(AF_INET),
            sin_port: port.bigEndian,
            sin_addr: in_addr(s_addr: in_addr_t(0)),
            sin_zero: (0, 0, 0, 0, 0, 0, 0, 0))
        #else
        var addr = sockaddr_in(
            sin_len: UInt8(MemoryLayout<sockaddr_in>.stride),
            sin_family: UInt8(AF_INET),
            sin_port: port.bigEndian,
            sin_addr: in_addr(s_addr: in_addr_t(0)),
            sin_zero: (0, 0, 0, 0, 0, 0, 0, 0))
        #endif

        let bindResult = withUnsafePointer(to: &addr) { addrPointer in
            addrPointer.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPointer in
                bind(self.socketfd, sockaddrPointer, socklen_t(MemoryLayout<sockaddr_in>.size))
            }
        }
        // A failed bind must not be ignored: listen(2) on an unbound socket silently picks
        // an ephemeral port instead of the one requested.
        if bindResult < 0 {
            throw PoCSocketError.SocketOSError(errno: errno)
        }

        if listen(self.socketfd, maxBacklogSize) < 0 {
            throw PoCSocketError.SocketOSError(errno: errno)
        }

        // Ask the OS which port we actually got (matters when `port` was 0).
        var boundAddr = sockaddr_in()
        var boundLength = socklen_t(MemoryLayout<sockaddr_in>.size)
        let nameResult = withUnsafeMutablePointer(to: &boundAddr) { addrPointer in
            addrPointer.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockaddrPointer in
                getsockname(self.socketfd, sockaddrPointer, &boundLength)
            }
        }
        if nameResult != 0 {
            throw PoCSocketError.SocketOSError(errno: errno)
        }

        listeningPort = Int32(UInt16(bigEndian: boundAddr.sin_port))
        // Only report listening once every step above has succeeded.
        isListening = true
    }

    /// Check to see if socket is being used
    ///
    /// - Returns: whether socket is listening or connected
    internal func isOpen() -> Bool {
        return isListening || isConnected
    }

    /// Sets the socket to blocking or non-blocking mode by toggling `O_NONBLOCK`.
    ///
    /// - Parameter mode: true for blocking, false for nonBlocking
    /// - Returns: the value returned by `fcntl(2)` for `F_SETFL` (not the resulting flags)
    /// - Throws: `PoCSocketError.SocketOSError` if either `fcntl(2)` call fails
    @discardableResult internal func setBlocking(mode: Bool) throws -> Int32 {
        let flags = fcntl(self.socketfd, F_GETFL)
        if flags < 0 {
            //Failed
            throw PoCSocketError.SocketOSError(errno: errno)
        }

        let newFlags = mode ? flags & ~O_NONBLOCK : flags | O_NONBLOCK

        let result = fcntl(self.socketfd, F_SETFL, newFlags)
        if result < 0 {
            //Failed
            throw PoCSocketError.SocketOSError(errno: errno)
        }
        return result
    }
}
|