//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2017-2021 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import NIOCore

/// Errors thrown by the NIO websocket module.
public enum NIOWebSocketError: Error {
    /// The frame being sent is larger than the configured maximum
    /// acceptable frame size
    case invalidFrameLength

    /// A control frame may not be fragmented.
    case fragmentedControlFrame

    /// A control frame may not have a length more than 125 bytes.
    case multiByteControlFrameLength
}

extension WebSocketErrorCode {
    /// Maps a `NIOWebSocketError` onto the WebSocket close code that describes it.
    ///
    /// - Parameters:
    ///   - error: The error to translate.
    init(_ error: NIOWebSocketError) {
        switch error {
        case .invalidFrameLength:
            self = .messageTooLarge
        case .fragmentedControlFrame,
            .multiByteControlFrameLength:
            self = .protocolError
        }
    }
}

extension ByteBuffer {
    /// Applies the WebSocket unmasking operation.
    ///
    /// - Parameters:
    ///   - maskingKey: The masking key.
    ///   - indexOffset: An integer offset to apply to the index into the masking key.
    ///         This is used when masking multiple "contiguous" byte buffers, to ensure that
    ///         the masking key is applied uniformly to the collection rather than from the
    ///         start each time.
    public mutating func webSocketUnmask(_ maskingKey: WebSocketMaskingKey, indexOffset: Int = 0) {
        // Shhhh: secretly unmasking and masking are the same operation!
        webSocketMask(maskingKey, indexOffset: indexOffset)
    }

    /// Applies the websocket masking operation.
    ///
    /// Every readable byte is XORed in place with the masking key byte selected by
    /// `(byte index + indexOffset) mod 4`.
    ///
    /// - Parameters:
    ///   - maskingKey: The masking key.
    ///   - indexOffset: An integer offset to apply to the index into the masking key.
    ///         This is used when masking multiple "contiguous" byte buffers, to ensure that
    ///         the masking key is applied uniformly to the collection rather than from the
    ///         start each time.
    public mutating func webSocketMask(_ maskingKey: WebSocketMaskingKey, indexOffset: Int = 0) {
        // Reduce the offset into 0..<4 once, up front. Swift's `%` keeps the sign of the
        // dividend, so a negative offset would otherwise yield a negative key index, and an
        // offset close to `Int.max` could overflow when added to the byte index.
        let keyOffset = ((indexOffset % 4) + 4) % 4
        self.withUnsafeMutableReadableBytes {
            for (index, byte) in $0.enumerated() {
                $0[index] = byte ^ maskingKey[(index + keyOffset) % 4]
            }
        }
    }
}

/// The current state of the frame decoder.
enum DecoderState {
    /// Waiting for a frame.
    case idle

    /// The initial frame byte has been received, but the length byte
    /// has not.
    case firstByteReceived(firstByte: UInt8)

    /// The length byte indicates that we need to wait for the length word, and we're
    /// currently waiting for it.
    case waitingForLengthWord(firstByte: UInt8, masked: Bool)

    /// The length byte indicates that we need to wait for the length qword, and
    /// we're currently waiting for it.
    case waitingForLengthQWord(firstByte: UInt8, masked: Bool)

    /// The mask bit indicates we are expecting a mask key.
    case waitingForMask(firstByte: UInt8, length: Int)

    /// All the header data is complete, we are waiting for the application data.
    case waitingForData(firstByte: UInt8, length: Int, maskingKey: WebSocketMaskingKey?)
}

/// The outcome of a single incremental parse step.
enum ParseResult {
    /// The buffer does not yet hold enough bytes to make progress.
    case insufficientData

    /// The parser advanced to a new state and should be stepped again.
    case continueParsing

    /// A complete frame was parsed.
    case result(WebSocketFrame)
}

/// An incremental websocket frame parser.
///
/// This parser attempts to parse a websocket frame incrementally, keeping as much parsing state around as possible to ensure that
/// we don't repeatedly partially parse the data.
struct WSParser {
    /// The current state of the decoder during incremental parse.
    var state: DecoderState = .idle

    /// Advances the parser by at most one state transition.
    ///
    /// - Parameters:
    ///   - buffer: The buffer to consume bytes from. Bytes are only consumed when the
    ///         field the current state is waiting for can be read in full.
    /// - Returns: `.insufficientData` if the field being waited for is not fully available,
    ///     `.continueParsing` if a header field was consumed and the state advanced, or
    ///     `.result` with the completed frame (after which the state is `.idle` again).
    mutating func parseStep(_ buffer: inout ByteBuffer) -> ParseResult {
        switch self.state {
        case .idle:
            // This is a new buffer. We want to find the first octet and save it off.
            guard let firstByte = buffer.readInteger(as: UInt8.self) else {
                return .insufficientData
            }
            self.state = .firstByteReceived(firstByte: firstByte)
            return .continueParsing

        case .firstByteReceived(let firstByte):
            // Now we're looking for the length. We begin by finding the length byte to see if we
            // need any more data.
            guard let lengthByte = buffer.readInteger(as: UInt8.self) else {
                return .insufficientData
            }

            // The top bit is the mask flag, the low seven bits are the (possibly escaped) length.
            let masked = (lengthByte & 0x80) != 0

            switch (lengthByte & 0x7F, masked) {
            case (126, _):
                self.state = .waitingForLengthWord(firstByte: firstByte, masked: masked)
            case (127, _):
                self.state = .waitingForLengthQWord(firstByte: firstByte, masked: masked)
            case (let len, true):
                assert(len <= 125)
                self.state = .waitingForMask(firstByte: firstByte, length: Int(len))
            case (let len, false):
                assert(len <= 125)
                self.state = .waitingForData(firstByte: firstByte, length: Int(len), maskingKey: nil)
            }
            return .continueParsing

        case .waitingForLengthWord(let firstByte, let masked):
            // We've got a one-word length here.
            guard let lengthWord = buffer.readInteger(as: UInt16.self) else {
                return .insufficientData
            }

            if masked {
                self.state = .waitingForMask(firstByte: firstByte, length: Int(lengthWord))
            } else {
                self.state = .waitingForData(firstByte: firstByte, length: Int(lengthWord), maskingKey: nil)
            }
            return .continueParsing

        case .waitingForLengthQWord(let firstByte, let masked):
            // We've got a qword of length here.
            guard let lengthQWord = buffer.readInteger(as: UInt64.self) else {
                return .insufficientData
            }

            // The length comes straight off the wire and may exceed `Int.max`. A plain
            // `Int(lengthQWord)` would trap in that case, letting a remote peer crash the
            // process, so clamp instead: the clamped value is then rejected by
            // `validateState(maxFrameSize:)` like any other oversized frame.
            let length = Int(clamping: lengthQWord)

            if masked {
                self.state = .waitingForMask(firstByte: firstByte, length: length)
            } else {
                self.state = .waitingForData(firstByte: firstByte, length: length, maskingKey: nil)
            }
            return .continueParsing

        case .waitingForMask(let firstByte, let length):
            // We're waiting for the masking key.
            guard let maskingKey = buffer.readInteger(as: UInt32.self) else {
                return .insufficientData
            }

            self.state = .waitingForData(
                firstByte: firstByte,
                length: length,
                maskingKey: WebSocketMaskingKey(networkRepresentation: maskingKey)
            )
            return .continueParsing

        case .waitingForData(let firstByte, let length, let maskingKey):
            guard let data = buffer.readSlice(length: length) else {
                return .insufficientData
            }

            let frame = WebSocketFrame(firstByte: firstByte, maskKey: maskingKey, applicationData: data)
            self.state = .idle
            return .result(frame)
        }
    }

    /// Apply a number of validations to the incremental state, ensuring that the frame we're
    /// receiving is valid.
    ///
    /// - Parameters:
    ///   - maxFrameSize: The largest payload length that will be accepted.
    /// - Throws: `NIOWebSocketError.invalidFrameLength` if the announced length exceeds
    ///     `maxFrameSize`, `NIOWebSocketError.fragmentedControlFrame` if a control frame
    ///     does not have its FIN bit set, or `NIOWebSocketError.multiByteControlFrameLength`
    ///     if a control frame announces more than 125 bytes.
    func validateState(maxFrameSize: Int) throws {
        switch self.state {
        case .waitingForMask(let firstByte, let length),
            .waitingForData(let firstByte, let length, _):
            if length > maxFrameSize {
                throw NIOWebSocketError.invalidFrameLength
            }

            // Bit 0x08 of the first byte marks a control frame; bit 0x80 is the FIN bit.
            let isControlFrame = (firstByte & 0x08) != 0
            let isFragment = (firstByte & 0x80) == 0

            if isControlFrame && isFragment {
                throw NIOWebSocketError.fragmentedControlFrame
            }
            if isControlFrame && length > 125 {
                throw NIOWebSocketError.multiByteControlFrameLength
            }
        case .idle, .firstByteReceived, .waitingForLengthWord, .waitingForLengthQWord:
            // No validation necessary in this state as we have no length to validate.
            break
        }
    }
}

/// An inbound `ChannelHandler` that deserializes websocket frames into a structured
/// format for further processing.
///
/// This decoder has limited enforcement of compliance to RFC 6455. In particular, to guarantee
/// that the decoder can handle arbitrary extensions, only normative MUST/MUST NOTs that do not
/// relate to extensions (e.g. the requirement that control frames not have lengths larger than
/// 125 bytes) are enforced by this decoder.
///
/// This decoder does not have any support for decoding extensions. If you wish to support
/// extensions, you should implement a message-to-message decoder that performs the appropriate
/// frame transformation as needed. All the frame data is assumed to be application data by this
/// parser.
public final class WebSocketFrameDecoder: ByteToMessageDecoder {
    public typealias InboundIn = ByteBuffer
    public typealias InboundOut = WebSocketFrame
    public typealias OutboundOut = WebSocketFrame

    /// The maximum frame size the decoder is willing to tolerate from the remote peer.
    let maxFrameSize: Int

    /// Our parser state.
    private var parser = WSParser()

    /// Construct a new `WebSocketFrameDecoder`
    ///
    /// - Parameters:
    ///   - maxFrameSize: The maximum frame size the decoder is willing to tolerate from the
    ///         remote peer. WebSockets in principle allows frame sizes up to `2**64` bytes, but
    ///         this is an objectively unreasonable maximum value (on AMD64 systems it is not
    ///         possible to even allocate a buffer large enough to handle this size), so we
    ///         set a lower one. The default value is the same as the default HTTP/2 max frame
    ///         size, `2**14` bytes. Users may override this to any value up to `UInt32.max`.
    ///         Users are strongly encouraged not to increase this value unless they absolutely
    ///         must, as the decoder will not produce partial frames, meaning that it will hold
    ///         on to data until the *entire* body is received.
    public init(maxFrameSize: Int = 1 << 14) {
        precondition(maxFrameSize <= UInt32.max, "invalid overlarge max frame size")
        self.maxFrameSize = maxFrameSize
    }

    /// Parses bytes from `buffer` until a frame is produced or more data is required.
    ///
    /// - Parameters:
    ///   - context: The context used to fire the decoded frame down the pipeline.
    ///   - buffer: The bytes received so far.
    /// - Returns: `.continue` after a frame has been fired, `.needMoreData` when the buffer
    ///     does not hold enough bytes for the next parse step.
    /// - Throws: Any `NIOWebSocketError` raised while validating a partially parsed header.
    public func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
        // Even though the calling code will loop around calling us in `decode`, we can't quite
        // rely on that: sometimes we have zero-length elements to parse, and the caller doesn't
        // guarantee to call us with zero-length bytes.
        while true {
            switch self.parser.parseStep(&buffer) {
            case .result(let frame):
                context.fireChannelRead(WebSocketFrameDecoder.wrapInboundOut(frame))
                return .continue
            case .continueParsing:
                // Validate as soon as the header tells us something, so that bad frames are
                // rejected before we wait for their payload.
                try self.parser.validateState(maxFrameSize: self.maxFrameSize)
            // loop again, might be 'waiting' for 0 bytes
            case .insufficientData:
                return .needMoreData
            }
        }
    }
}

@available(*, unavailable)
extension WebSocketFrameDecoder: Sendable {}