From 3990ec2afb2fd281c2c50908462d2900f108bb65 Mon Sep 17 00:00:00 2001 From: Sergej Jaskiewicz Date: Sun, 22 Sep 2019 04:29:47 +0300 Subject: [PATCH] Audit Publishers.Sequence for thread safety --- .../Publishers/Publishers.Sequence.swift | 95 ++++++++++++------- .../PublisherTests/SequenceTests.swift | 40 +++++++- 2 files changed, 99 insertions(+), 36 deletions(-) diff --git a/Sources/OpenCombine/Publishers/Publishers.Sequence.swift b/Sources/OpenCombine/Publishers/Publishers.Sequence.swift index 3b86409..b9474c8 100644 --- a/Sources/OpenCombine/Publishers/Publishers.Sequence.swift +++ b/Sources/OpenCombine/Publishers/Publishers.Sequence.swift @@ -29,7 +29,9 @@ extension Publishers { where Failure == Downstream.Failure, Elements.Element == Downstream.Input { - if let inner = Inner(downstream: subscriber, sequence: sequence) { + var iterator = sequence.makeIterator() + if iterator.next() != nil { + let inner = Inner(downstream: subscriber, sequence: sequence) subscriber.receive(subscription: inner) } else { subscriber.receive(subscription: Subscriptions.empty) @@ -44,66 +46,89 @@ extension Publishers.Sequence { private final class Inner : Subscription, CustomStringConvertible, - CustomReflectable + CustomReflectable, + CustomPlaygroundDisplayConvertible where Downstream.Input == Elements.Element, Downstream.Failure == Failure { + // NOTE: This class has been audited for thread-safety + typealias Iterator = Elements.Iterator typealias Element = Elements.Element - private var _downstream: Downstream? - private var _sequence: Elements? - private var _iterator: Iterator? - private var _nextValue: Element? + private var sequence: Elements? + private var downstream: Downstream? + private var iterator: Iterator + private var next: Element? + private var pendingDemand = Subscribers.Demand.none + private var recursion = false + private var lock = Lock(recursive: false) - init?(downstream: Downstream, sequence: Elements) { - - // Early exit if the sequence is empty - var iterator = sequence.makeIterator() - guard iterator.next() != nil else { return nil } - - _downstream = downstream - _sequence = sequence - _iterator = sequence.makeIterator() - _nextValue = _iterator?.next() + fileprivate init(downstream: Downstream, sequence: Elements) { + self.sequence = sequence + self.downstream = downstream + self.iterator = sequence.makeIterator() + next = iterator.next() } var description: String { - return _sequence.map(String.init(describing:)) ?? "Sequence" + return sequence.map(String.init(describing:)) ?? "Sequence" } var customMirror: Mirror { - let children: CollectionOfOne<(label: String?, value: Any)> = - .init(("sequence", _sequence ?? [Element]())) + let children = + CollectionOfOne(("sequence", sequence ?? [Element]())) return Mirror(self, children: children) } + var playgroundDescription: Any { return description } + func request(_ demand: Subscribers.Demand) { + lock.lock() + guard let downstream = self.downstream else { + lock.unlock() + return + } + pendingDemand += demand + if recursion { + lock.unlock() + return + } - guard let downstream = _downstream else { return } + while pendingDemand > 0 { + if let current = self.next { + pendingDemand -= 1 - var demand = demand - - while demand > 0 { - if let nextValue = _nextValue { - demand += downstream.receive(nextValue) - demand -= 1 + // Combine calls next() while the lock is held. + // It is possible to engineer a custom Sequence that would cause + // a dedlock here, but it would be something insane. + let next = iterator.next() + recursion = true + lock.unlock() + let additionalDemand = downstream.receive(current) + lock.lock() + recursion = false + pendingDemand += additionalDemand + self.next = next } - _nextValue = _iterator?.next() - - if _nextValue == nil { - _downstream?.receive(completion: .finished) - cancel() - break + if next == nil { + self.downstream = nil + self.sequence = nil + lock.unlock() + downstream.receive(completion: .finished) + return } } + + lock.unlock() } func cancel() { - _downstream = nil - _iterator = nil - _sequence = nil + lock.lock() + downstream = nil + sequence = nil + lock.unlock() } } } diff --git a/Tests/OpenCombineTests/PublisherTests/SequenceTests.swift b/Tests/OpenCombineTests/PublisherTests/SequenceTests.swift index 0d9adb0..615a9f1 100644 --- a/Tests/OpenCombineTests/PublisherTests/SequenceTests.swift +++ b/Tests/OpenCombineTests/PublisherTests/SequenceTests.swift @@ -167,7 +167,7 @@ final class SequenceTests: XCTestCase { } func testPublishesCorrectValues() { - let sequence: Publishers.Sequence = (1...5).publisher + let sequence = makePublisher(1...5) var history = [Int]() _ = sequence.sink { @@ -177,6 +177,44 @@ final class SequenceTests: XCTestCase { XCTAssertEqual(history, [1, 2, 3, 4, 5]) } + func testRecursion() { + let sequence = makePublisher(1...5) + + var history = [Int]() + var storedSubscription: Subscription? + + let tracking = TrackingSubscriberBase( + receiveSubscription: { subscription in + storedSubscription = subscription + subscription.request(.none) // Shouldn't crash + subscription.request(.max(1)) + }, + receiveValue: { value in + storedSubscription?.request(.max(1)) + history.append(value) + return .none + } + ) + + sequence.subscribe(tracking) + + XCTAssertEqual(history, [1, 2, 3, 4, 5]) + } + + func testReflection() throws { + + func testCustomMirror(_ mirror: Mirror) -> Bool { + return mirror.children.count == 1 && + mirror.children.first!.label == "sequence" && + (mirror.children.first!.value as! ClosedRange) == 1...5 + } + + try testSubscriptionReflection(description: "1...5", + customMirror: testCustomMirror, + playgroundDescription: "1...5", + sut: makePublisher(1...5)) + } + func testLifecycle() throws { var deinitCounter = 0