Audit Publishers.Sequence for thread safety

This commit is contained in:
Sergej Jaskiewicz
2019-09-22 04:29:47 +03:00
committed by Sergej Jaskiewicz
parent 39dd9e40bf
commit 3990ec2afb
2 changed files with 99 additions and 36 deletions
@@ -29,7 +29,9 @@ extension Publishers {
where Failure == Downstream.Failure,
Elements.Element == Downstream.Input
{
if let inner = Inner(downstream: subscriber, sequence: sequence) {
var iterator = sequence.makeIterator()
if iterator.next() != nil {
let inner = Inner(downstream: subscriber, sequence: sequence)
subscriber.receive(subscription: inner)
} else {
subscriber.receive(subscription: Subscriptions.empty)
@@ -44,66 +46,89 @@ extension Publishers.Sequence {
private final class Inner<Downstream: Subscriber, Elements: Sequence, Failure>
: Subscription,
CustomStringConvertible,
CustomReflectable
CustomReflectable,
CustomPlaygroundDisplayConvertible
where Downstream.Input == Elements.Element,
Downstream.Failure == Failure
{
// NOTE: This class has been audited for thread-safety
typealias Iterator = Elements.Iterator
typealias Element = Elements.Element
private var _downstream: Downstream?
private var _sequence: Elements?
private var _iterator: Iterator?
private var _nextValue: Element?
private var sequence: Elements?
private var downstream: Downstream?
private var iterator: Iterator
private var next: Element?
private var pendingDemand = Subscribers.Demand.none
private var recursion = false
private var lock = Lock(recursive: false)
init?(downstream: Downstream, sequence: Elements) {
// Early exit if the sequence is empty
var iterator = sequence.makeIterator()
guard iterator.next() != nil else { return nil }
_downstream = downstream
_sequence = sequence
_iterator = sequence.makeIterator()
_nextValue = _iterator?.next()
fileprivate init(downstream: Downstream, sequence: Elements) {
self.sequence = sequence
self.downstream = downstream
self.iterator = sequence.makeIterator()
next = iterator.next()
}
var description: String {
return _sequence.map(String.init(describing:)) ?? "Sequence"
return sequence.map(String.init(describing:)) ?? "Sequence"
}
var customMirror: Mirror {
let children: CollectionOfOne<(label: String?, value: Any)> =
.init(("sequence", _sequence ?? [Element]()))
let children =
CollectionOfOne<Mirror.Child>(("sequence", sequence ?? [Element]()))
return Mirror(self, children: children)
}
var playgroundDescription: Any { return description }
func request(_ demand: Subscribers.Demand) {
lock.lock()
guard let downstream = self.downstream else {
lock.unlock()
return
}
pendingDemand += demand
if recursion {
lock.unlock()
return
}
guard let downstream = _downstream else { return }
while pendingDemand > 0 {
if let current = self.next {
pendingDemand -= 1
var demand = demand
while demand > 0 {
if let nextValue = _nextValue {
demand += downstream.receive(nextValue)
demand -= 1
// Combine calls next() while the lock is held.
// It is possible to engineer a custom Sequence that would cause
// a dedlock here, but it would be something insane.
let next = iterator.next()
recursion = true
lock.unlock()
let additionalDemand = downstream.receive(current)
lock.lock()
recursion = false
pendingDemand += additionalDemand
self.next = next
}
_nextValue = _iterator?.next()
if _nextValue == nil {
_downstream?.receive(completion: .finished)
cancel()
break
if next == nil {
self.downstream = nil
self.sequence = nil
lock.unlock()
downstream.receive(completion: .finished)
return
}
}
lock.unlock()
}
func cancel() {
_downstream = nil
_iterator = nil
_sequence = nil
lock.lock()
downstream = nil
sequence = nil
lock.unlock()
}
}
}
@@ -167,7 +167,7 @@ final class SequenceTests: XCTestCase {
}
func testPublishesCorrectValues() {
let sequence: Publishers.Sequence = (1...5).publisher
let sequence = makePublisher(1...5)
var history = [Int]()
_ = sequence.sink {
@@ -177,6 +177,44 @@ final class SequenceTests: XCTestCase {
XCTAssertEqual(history, [1, 2, 3, 4, 5])
}
func testRecursion() {
let sequence = makePublisher(1...5)
var history = [Int]()
var storedSubscription: Subscription?
let tracking = TrackingSubscriberBase<Int, Never>(
receiveSubscription: { subscription in
storedSubscription = subscription
subscription.request(.none) // Shouldn't crash
subscription.request(.max(1))
},
receiveValue: { value in
storedSubscription?.request(.max(1))
history.append(value)
return .none
}
)
sequence.subscribe(tracking)
XCTAssertEqual(history, [1, 2, 3, 4, 5])
}
func testReflection() throws {
func testCustomMirror(_ mirror: Mirror) -> Bool {
return mirror.children.count == 1 &&
mirror.children.first!.label == "sequence" &&
(mirror.children.first!.value as! ClosedRange<Int>) == 1...5
}
try testSubscriptionReflection(description: "1...5",
customMirror: testCustomMirror,
playgroundDescription: "1...5",
sut: makePublisher(1...5))
}
func testLifecycle() throws {
var deinitCounter = 0