diff --git a/Sources/FluidAudio/ModelNames.swift b/Sources/FluidAudio/ModelNames.swift index fbb7d281..1784fd7a 100644 --- a/Sources/FluidAudio/ModelNames.swift +++ b/Sources/FluidAudio/ModelNames.swift @@ -667,7 +667,14 @@ public enum ModelNames { public static let mimiDecoder = "mimi_decoder" public static let mimiEncoder = "mimi_encoder" + /// Chunk-16 variant of `cond_step`. Same I/O schema as the chunk-1 + /// model except the conditioning input has sequence dim 16 instead + /// of 1. Used by the chunked prefill pipeline to amortize CoreML + /// dispatch overhead across longer text/voice prompts. + public static let condStepChunk16 = "cond_step_chunk16" + public static let condStepFile = condStep + ".mlmodelc" + public static let condStepChunk16File = condStepChunk16 + ".mlmodelc" public static let flowlmStepFile = flowlmStep + ".mlmodelc" public static let flowlmStepV2File = flowlmStepV2 + ".mlmodelc" public static let flowDecoderFile = flowDecoder + ".mlmodelc" diff --git a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsModelStore.swift b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsModelStore.swift index 831e784a..9d649d2f 100644 --- a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsModelStore.swift +++ b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsModelStore.swift @@ -14,6 +14,7 @@ public actor PocketTtsModelStore { private let logger = AppLogger(subsystem: "com.fluidaudio.tts", category: "PocketTtsModelStore") private var condStepModel: MLModel? + private var condStepChunkModelStorage: MLModel? private var flowlmStepModel: MLModel? private var flowDecoderModel: MLModel? private var mimiDecoderModel: MLModel? @@ -22,11 +23,13 @@ public actor PocketTtsModelStore { private var voiceCache: [String: PocketTtsVoiceData] = [:] private var languageRootDirectory: URL? private var condLayerKeys: PocketTtsLayerKeys? + private var condStepChunkLayerKeysCache: PocketTtsLayerKeys? private var flowlmLayerKeys: PocketTtsLayerKeys? private var mimiDecoderKeysCache: PocketTtsMimiKeys? private let directory: URL? public let language: PocketTtsLanguage public let precision: PocketTtsPrecision + public let condStepMode: PocketTtsCondStepMode /// - Parameters: /// - language: Which upstream language pack to load. Defaults to @@ -38,14 +41,24 @@ public actor PocketTtsModelStore { /// `flowlm_step.mlmodelc` for `flowlm_stepv2.mlmodelc` from the /// same upstream `v2//` directory; the other three submodels /// stay at fp16. + /// - condStepMode: Which `cond_step` dispatch strategy the synthesizer + /// should use for KV cache prefill. Defaults to `.legacy` (per-token + /// dispatch — preserves the upstream behaviour). `.chunked(chunk: + /// 16)` additionally loads `cond_step_chunk16.mlmodelc` from the + /// same `v2//` directory and lets the synthesizer dispatch + /// prompt prefill in 16-token chunks plus a per-token tail. The + /// chunk-16 file is **not yet published on HuggingFace** — see + /// `PocketTtsCondStepMode.chunked` for placement details. public init( language: PocketTtsLanguage = .english, directory: URL? = nil, - precision: PocketTtsPrecision = .fp16 + precision: PocketTtsPrecision = .fp16, + condStepMode: PocketTtsCondStepMode = .legacy ) { self.language = language self.directory = directory self.precision = precision + self.condStepMode = condStepMode } /// Load all four CoreML models and the constants bundle. @@ -117,11 +130,63 @@ public actor PocketTtsModelStore { let elapsed = Date().timeIntervalSince(loadStart) logger.info("All PocketTTS models loaded in \(String(format: "%.2f", elapsed))s") + // Optionally load the chunked cond_step variant. The chunk-N model + // shares the K/V cache + position output schema of chunk-1, so the + // existing `.condStep` discovery kind applies — only the input + // sequence dim differs (1 → N). + if case .chunked(let chunk) = condStepMode { + try await loadCondStepChunkModel( + chunk: chunk, + languageRoot: languageRoot, + config: config, + expectedLayers: expectedLayers + ) + } + // Load constants constantsBundle = try PocketTtsConstantsLoader.load(from: languageRoot) logger.info("PocketTTS constants loaded") } + private func loadCondStepChunkModel( + chunk: Int, + languageRoot: URL, + config: MLModelConfiguration, + expectedLayers: Int + ) async throws { + // Only chunk-16 is supported initially. Reject other sizes loudly so + // callers don't silently fall back to a missing artifact. + guard chunk == 16 else { + throw PocketTTSError.modelNotFound( + "PocketTTS chunked cond_step only supports chunk=16 today (requested \(chunk))" + ) + } + + let file = ModelNames.PocketTTS.condStepChunk16File + let modelURL = languageRoot.appendingPathComponent(file) + guard FileManager.default.fileExists(atPath: modelURL.path) else { + throw PocketTTSError.modelNotFound( + "PocketTTS \(file) not found at \(modelURL.path). " + + "The chunked cond_step variant is not yet published on HuggingFace; " + + "place the compiled mlmodelc at this path manually to enable .chunked(chunk: 16)." + ) + } + + let chunkLoadStart = Date() + let model = try MLModel(contentsOf: modelURL, configuration: config) + condStepChunkModelStorage = model + condStepChunkLayerKeysCache = try PocketTtsLayerKeys.discover( + from: model, + kind: .condStep, + expectedLayers: expectedLayers, + modelName: ModelNames.PocketTTS.condStepChunk16 + ) + let chunkElapsed = Date().timeIntervalSince(chunkLoadStart) + logger.info( + "Loaded \(file) (chunk=\(chunk)) in \(String(format: "%.2f", chunkElapsed))s" + ) + } + /// The conditioning step model (KV cache prefill). public func condStep() throws -> MLModel { guard let model = condStepModel else { @@ -170,6 +235,44 @@ public actor PocketTtsModelStore { return keys } + /// The chunked cond_step model. Throws when the store was initialized + /// in `.legacy` mode (or when the chunk model file failed to load) — + /// callers should gate on `condStepChunkSize() != nil` first. + /// + /// Returns a non-optional `MLModel` to match the Sendable behaviour of + /// `condStep()`; `Optional` would require `MLModel` itself to + /// satisfy strict-mode Sendable (the `@preconcurrency import CoreML` + /// trick only covers the bare `MLModel`). + public func condStepChunkModel() throws -> MLModel { + guard let model = condStepChunkModelStorage else { + throw PocketTTSError.modelNotFound( + "PocketTTS chunked cond_step model not loaded (mode = \(condStepMode))" + ) + } + return model + } + + /// Discovered output names for the chunked cond_step model. Same + /// throwing semantics as `condStepChunkModel()`. + func condStepChunkLayerKeys() throws -> PocketTtsLayerKeys { + guard let keys = condStepChunkLayerKeysCache else { + throw PocketTTSError.modelNotFound( + "PocketTTS chunked cond_step layer keys not discovered (mode = \(condStepMode))" + ) + } + return keys + } + + /// The chunk size if the store was initialized in `.chunked` mode, + /// otherwise `nil`. Use this as the cheap gate before calling the + /// throwing chunk-model accessors. + public func condStepChunkSize() -> Int? { + if case .chunked(let chunk) = condStepMode { + return chunk + } + return nil + } + /// Discovered output names for the flowlm_step transformer model. func flowLMStepLayerKeys() throws -> PocketTtsLayerKeys { guard let keys = flowlmLayerKeys else { diff --git a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSession.swift b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSession.swift index 138f3276..450ce819 100644 --- a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSession.swift +++ b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSession.swift @@ -61,6 +61,9 @@ public actor PocketTtsSession { private let condLayerKeys: PocketTtsLayerKeys private let flowlmLayerKeys: PocketTtsLayerKeys private let mimiKeys: PocketTtsMimiKeys + private let chunkCondModel: MLModel? + private let chunkCondLayerKeys: PocketTtsLayerKeys? + private let chunkSize: Int? // Persistent state private let voiceKVSnapshot: PocketTtsSynthesizer.KVCacheState @@ -88,7 +91,10 @@ public actor PocketTtsSession { mimiKeys: PocketTtsMimiKeys, bosEmb: MLMultiArray, temperature: Float, - seed: UInt64 + seed: UInt64, + chunkCondModel: MLModel? = nil, + chunkCondLayerKeys: PocketTtsLayerKeys? = nil, + chunkSize: Int? = nil ) { self.voiceKVSnapshot = voiceKVSnapshot self.mimiState = mimiState @@ -103,6 +109,9 @@ public actor PocketTtsSession { self.bosEmb = bosEmb self.temperature = temperature self.rng = SeededRNG(seed: seed) + self.chunkCondModel = chunkCondModel + self.chunkCondLayerKeys = chunkCondLayerKeys + self.chunkSize = chunkSize // Text queue channel let (textStream, textContinuation) = AsyncStream.makeStream(of: String.self) @@ -178,12 +187,26 @@ public actor PocketTtsSession { let tokenIds = constants.tokenizer.encode(normalizedChunk) let textEmbeddings = PocketTtsSynthesizer.embedTokens(tokenIds, constants: constants) - // Clone voice KV snapshot and prefill text tokens only + // Clone voice KV snapshot and prefill text tokens only. + // Uses hybrid chunk-N + chunk-1 dispatch when the session was + // created from a store opened in `.chunked` mode. var kvState = try PocketTtsSynthesizer.cloneKVCacheState(voiceKVSnapshot) - kvState = try await PocketTtsSynthesizer.prefillKVCacheText( - state: kvState, textEmbeddings: textEmbeddings, model: condModel, - layerKeys: condLayerKeys - ) + if let cm = chunkCondModel, let ck = chunkCondLayerKeys, let cs = chunkSize, cs > 1 { + kvState = try await PocketTtsSynthesizer.prefillKVCacheTextHybrid( + state: kvState, + textEmbeddings: textEmbeddings, + chunkModel: cm, + chunkLayerKeys: ck, + chunkSize: cs, + perTokenModel: condModel, + perTokenLayerKeys: condLayerKeys + ) + } else { + kvState = try await PocketTtsSynthesizer.prefillKVCacheText( + state: kvState, textEmbeddings: textEmbeddings, model: condModel, + layerKeys: condLayerKeys + ) + } // Generation loop let maxGenLen = PocketTtsSynthesizer.estimateMaxFrames(text: text) diff --git a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+Bench.swift b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+Bench.swift new file mode 100644 index 00000000..d30fb953 --- /dev/null +++ b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+Bench.swift @@ -0,0 +1,138 @@ +@preconcurrency import CoreML +import Foundation + +extension PocketTtsSynthesizer { + + /// Time `prefillKVCache` end-to-end for a single text + voice. Used by + /// `PocketTtsManager.benchmarkCondStepPrefill` and the `pocket-tts-cond-bench` + /// CLI subcommand to A/B compare the legacy chunk-1 and hybrid chunk-N + /// dispatch paths. + /// + /// Must be called within a `withModelStore` context. The store's + /// `condStepMode` decides whether the hybrid path runs. + static func benchmarkCondStepPrefill( + text: String, + voice: String, + warmup: Int, + iters: Int + ) async throws -> PocketTtsManager.CondStepPrefillBenchmarkResult { + let store = try currentModelStore() + let voiceData = try await store.voiceData(for: voice) + return try await benchmarkCondStepPrefill( + text: text, voiceData: voiceData, warmup: warmup, iters: iters + ) + } + + /// Same as the voice-name overload but takes pre-resolved voice data. + static func benchmarkCondStepPrefill( + text: String, + voiceData: PocketTtsVoiceData, + warmup: Int, + iters: Int + ) async throws -> PocketTtsManager.CondStepPrefillBenchmarkResult { + let store = try currentModelStore() + + let constants = try await store.constants() + let condModel = try await store.condStep() + let condLayerKeys = try await store.condStepLayerKeys() + + // Tokenize + embed once. The benchmark times only the prefill loop — + // tokenization is cheap and identical across iterations and across + // legacy vs chunked dispatch. + let (normalizedChunk, _) = normalizeText(text) + let tokenIds = constants.tokenizer.encode(normalizedChunk) + let textEmbeddings = embedTokens(tokenIds, constants: constants) + + // Resolve chunk resources only when the store is in chunked mode. + let chunkSize = await store.condStepChunkSize() + let chunkModel: MLModel? + let chunkLayerKeys: PocketTtsLayerKeys? + if chunkSize != nil { + chunkModel = try await store.condStepChunkModel() + chunkLayerKeys = try await store.condStepChunkLayerKeys() + } else { + chunkModel = nil + chunkLayerKeys = nil + } + + // Voice token count for reporting. Snapshot voices skip cond_step + // entirely on the voice side (`promptLength == 0` in that case). + let voiceTokens: Int = { + if voiceData.cacheSnapshot != nil { return 0 } + return voiceData.promptLength + }() + + // Warmup iterations (not recorded). Each runs a full prefill so the + // CoreML graph, MIL caches, and ANE/GPU pipelines are warm before + // we start sampling. + for _ in 0.. KVCacheState { + if let cm = chunkModel, let ck = chunkLayerKeys, let cs = chunkSize { + return try await prefillKVCache( + voiceData: voiceData, + textEmbeddings: textEmbeddings, + model: model, + layerKeys: layerKeys, + chunkModel: cm, + chunkLayerKeys: ck, + chunkSize: cs + ) + } + return try await prefillKVCache( + voiceData: voiceData, + textEmbeddings: textEmbeddings, + model: model, + layerKeys: layerKeys + ) + } +} diff --git a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+KVCache.swift b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+KVCache.swift index f6c5de75..578fd070 100644 --- a/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+KVCache.swift +++ b/Sources/FluidAudio/TTS/PocketTTS/Pipeline/PocketTtsSynthesizer+KVCache.swift @@ -119,6 +119,48 @@ extension PocketTtsSynthesizer { } } + /// Run the chunked conditioning step model on a `[1, chunkSize, 1024]` + /// conditioning tensor, updating the KV cache in place. + /// + /// Same I/O contract as `runCondStep` — the chunk-N CoreML graph reuses + /// the cond_step output names (cache + position per layer); only the + /// input sequence dim differs (1 → chunkSize). The model returns the + /// post-chunk K/V cache and the position counter advanced by `chunkSize`. + static func runCondStepChunk( + conditioning: MLMultiArray, + state: inout KVCacheState, + model: MLModel, + layerKeys: PocketTtsLayerKeys + ) async throws { + let layers = layerKeys.layerCount + var inputDict: [String: Any] = [ + "conditioning": conditioning + ] + + for i in 0.. KVCacheState { + var state = state + let dim = PocketTtsConstants.embeddingDim + let total = voiceData.promptLength + let nBig = total / chunkSize + let nOne = total % chunkSize + + var idx = 0 + for _ in 0.. KVCacheState { + var state = state + let dim = PocketTtsConstants.embeddingDim + let total = textEmbeddings.count + let nBig = total / chunkSize + let nOne = total % chunkSize + + var idx = 0 + for _ in 0.. KVCacheState { + let useHybrid = + chunkModel != nil && chunkLayerKeys != nil && chunkSize != nil + && (chunkSize ?? 0) > 1 + var state: KVCacheState if let snapshot = voiceData.cacheSnapshot { state = try kvCacheStateFromSnapshot(snapshot, layers: layerKeys.layerCount) } else { let emptyState = try emptyKVCacheState(layers: layerKeys.layerCount) - state = try await prefillKVCacheVoice( - state: emptyState, voiceData: voiceData, model: model, layerKeys: layerKeys + if useHybrid, let cm = chunkModel, let ck = chunkLayerKeys, let cs = chunkSize { + state = try await prefillKVCacheVoiceHybrid( + state: emptyState, + voiceData: voiceData, + chunkModel: cm, + chunkLayerKeys: ck, + chunkSize: cs, + perTokenModel: model, + perTokenLayerKeys: layerKeys + ) + } else { + state = try await prefillKVCacheVoice( + state: emptyState, voiceData: voiceData, model: model, layerKeys: layerKeys + ) + } + } + + if useHybrid, let cm = chunkModel, let ck = chunkLayerKeys, let cs = chunkSize { + state = try await prefillKVCacheTextHybrid( + state: state, + textEmbeddings: textEmbeddings, + chunkModel: cm, + chunkLayerKeys: ck, + chunkSize: cs, + perTokenModel: model, + perTokenLayerKeys: layerKeys + ) + } else { + state = try await prefillKVCacheText( + state: state, textEmbeddings: textEmbeddings, model: model, layerKeys: layerKeys ) } - state = try await prefillKVCacheText( - state: state, textEmbeddings: textEmbeddings, model: model, layerKeys: layerKeys - ) let finalPos = state.positions[0][0].floatValue logger.info("KV cache prefilled to position \(Int(finalPos))") @@ -297,6 +481,46 @@ extension PocketTtsSynthesizer { return array } + /// Create a `[1, count, 1024]` MLMultiArray from a flat audio-prompt + /// buffer laid out as `[token0_dim0..token0_dimN-1, token1_dim0..., ...]`. + private static func createConditioningChunkFromFlat( + source: [Float], startToken: Int, count: Int, dim: Int + ) throws -> MLMultiArray { + let array = try MLMultiArray( + shape: [1, NSNumber(value: count), NSNumber(value: dim)], + dataType: .float32 + ) + let total = count * dim + let ptr = array.dataPointer.bindMemory(to: Float.self, capacity: total) + source.withUnsafeBufferPointer { buffer in + guard let base = buffer.baseAddress else { return } + ptr.update(from: base.advanced(by: startToken * dim), count: total) + } + return array + } + + /// Create a `[1, count, 1024]` MLMultiArray by concatenating + /// `count` token embeddings starting at `startToken` from a `[[Float]]` + /// embedding source. + private static func createConditioningChunkFromEmbeddings( + source: [[Float]], startToken: Int, count: Int, dim: Int + ) throws -> MLMultiArray { + let array = try MLMultiArray( + shape: [1, NSNumber(value: count), NSNumber(value: dim)], + dataType: .float32 + ) + let total = count * dim + let ptr = array.dataPointer.bindMemory(to: Float.self, capacity: total) + for i in 0.. 1 { + let cm = try await store.condStepChunkModel() + let ck = try await store.condStepChunkLayerKeys() + voiceKVSnapshot = try await prefillKVCacheVoiceHybrid( + state: emptyState, + voiceData: voiceData, + chunkModel: cm, + chunkLayerKeys: ck, + chunkSize: cs, + perTokenModel: condModel, + perTokenLayerKeys: condLayerKeys + ) + } else { + voiceKVSnapshot = try await prefillKVCacheVoice( + state: emptyState, voiceData: voiceData, model: condModel, + layerKeys: condLayerKeys + ) + } } logger.info( "Session voice prefill at position \(Int(voiceKVSnapshot.positions[0][0].floatValue))" ) - let session = PocketTtsSession( - voiceKVSnapshot: voiceKVSnapshot, - mimiState: mimiState, - constants: constants, - condModel: condModel, - stepModel: stepModel, - flowModel: flowModel, - mimiModel: mimiModel, - condLayerKeys: condLayerKeys, - flowlmLayerKeys: flowlmLayerKeys, - mimiKeys: mimiKeys, - bosEmb: bosEmb, - temperature: temperature, - seed: seedValue - ) + let session: PocketTtsSession + if let cs = chunkSize { + let cm = try await store.condStepChunkModel() + let ck = try await store.condStepChunkLayerKeys() + session = PocketTtsSession( + voiceKVSnapshot: voiceKVSnapshot, + mimiState: mimiState, + constants: constants, + condModel: condModel, + stepModel: stepModel, + flowModel: flowModel, + mimiModel: mimiModel, + condLayerKeys: condLayerKeys, + flowlmLayerKeys: flowlmLayerKeys, + mimiKeys: mimiKeys, + bosEmb: bosEmb, + temperature: temperature, + seed: seedValue, + chunkCondModel: cm, + chunkCondLayerKeys: ck, + chunkSize: cs + ) + } else { + session = PocketTtsSession( + voiceKVSnapshot: voiceKVSnapshot, + mimiState: mimiState, + constants: constants, + condModel: condModel, + stepModel: stepModel, + flowModel: flowModel, + mimiModel: mimiModel, + condLayerKeys: condLayerKeys, + flowlmLayerKeys: flowlmLayerKeys, + mimiKeys: mimiKeys, + bosEmb: bosEmb, + temperature: temperature, + seed: seedValue + ) + } await session.start() return session } @@ -333,6 +405,9 @@ public struct PocketTtsSynthesizer { var rng: SeededRNG let chunkCount: Int let temperature: Float + let chunkCondModel: MLModel? + let chunkCondLayerKeys: PocketTtsLayerKeys? + let chunkSize: Int? init( constants: PocketTtsConstantsBundle, @@ -349,7 +424,10 @@ public struct PocketTtsSynthesizer { bosEmb: MLMultiArray, seedValue: UInt64, chunkCount: Int, - temperature: Float + temperature: Float, + chunkCondModel: MLModel? = nil, + chunkCondLayerKeys: PocketTtsLayerKeys? = nil, + chunkSize: Int? = nil ) { self.constants = constants self.voiceData = voiceData @@ -366,6 +444,9 @@ public struct PocketTtsSynthesizer { self.rng = SeededRNG(seed: seedValue) self.chunkCount = chunkCount self.temperature = temperature + self.chunkCondModel = chunkCondModel + self.chunkCondLayerKeys = chunkCondLayerKeys + self.chunkSize = chunkSize } /// Flow decode using actor-isolated RNG state. @@ -435,12 +516,25 @@ public struct PocketTtsSynthesizer { let textEmbeddings = PocketTtsSynthesizer.embedTokens( tokenIds, constants: constants) - var kvState = try await PocketTtsSynthesizer.prefillKVCache( - voiceData: voiceData, - textEmbeddings: textEmbeddings, - model: condModel, - layerKeys: condLayerKeys - ) + var kvState: KVCacheState + if let cm = chunkCondModel, let ck = chunkCondLayerKeys, let cs = chunkSize { + kvState = try await PocketTtsSynthesizer.prefillKVCache( + voiceData: voiceData, + textEmbeddings: textEmbeddings, + model: condModel, + layerKeys: condLayerKeys, + chunkModel: cm, + chunkLayerKeys: ck, + chunkSize: cs + ) + } else { + kvState = try await PocketTtsSynthesizer.prefillKVCache( + voiceData: voiceData, + textEmbeddings: textEmbeddings, + model: condModel, + layerKeys: condLayerKeys + ) + } let maxGenLen = PocketTtsSynthesizer.estimateMaxFrames(text: chunkText) var eosStep: Int? diff --git a/Sources/FluidAudio/TTS/PocketTTS/PocketTtsConstants.swift b/Sources/FluidAudio/TTS/PocketTTS/PocketTtsConstants.swift index 85a514dd..50b02426 100644 --- a/Sources/FluidAudio/TTS/PocketTTS/PocketTtsConstants.swift +++ b/Sources/FluidAudio/TTS/PocketTTS/PocketTtsConstants.swift @@ -135,3 +135,34 @@ public enum PocketTtsPrecision: Sendable, Hashable { case fp16 case int8 } + +/// Selects which `cond_step` dispatch strategy the PocketTTS pipeline uses +/// for prefilling the KV cache from voice tokens and text embeddings. +/// +/// `cond_step` is invoked once per token in the prompt during prefill — for +/// the default English voice (~125 frames) plus a typical sentence (~17–80 +/// text tokens) that is 140–200 CoreML calls before the first audio frame +/// is generated. Each call has a fixed overhead (Python/Swift interop, MIL +/// dispatch, kernel launch) that dominates wall time at small T. +/// +/// - `legacy`: dispatch the chunk-1 model once per token. This matches the +/// reference Python implementation and the upstream FluidAudio behaviour +/// prior to chunked prefill landing. Always works with the published +/// `cond_step.mlmodelc` artifact. +/// +/// - `chunked(chunk: N)`: hybrid dispatch — `T / N` calls to a chunk-N +/// variant of `cond_step` plus `T % N` chunk-1 tail calls. Same KV cache +/// schema as legacy (the chunk-N graph reuses the cond_step output names; +/// only the input sequence dim differs). Empirically 3–14× faster prefill +/// on Apple Silicon depending on T. +/// +/// Only `chunked(chunk: 16)` is supported initially. The chunk-16 model +/// file (`cond_step_chunk16.mlmodelc`) is **not yet published on +/// HuggingFace** — callers must place it manually under +/// `/v2//cond_step_chunk16.mlmodelc`. Loading the model +/// store in this mode without the file present throws +/// `PocketTTSError.modelNotFound` with the expected path. +public enum PocketTtsCondStepMode: Sendable, Equatable { + case legacy + case chunked(chunk: Int) +} diff --git a/Sources/FluidAudio/TTS/PocketTTS/PocketTtsManager.swift b/Sources/FluidAudio/TTS/PocketTTS/PocketTtsManager.swift index e95de9e2..c16395f9 100644 --- a/Sources/FluidAudio/TTS/PocketTTS/PocketTtsManager.swift +++ b/Sources/FluidAudio/TTS/PocketTTS/PocketTtsManager.swift @@ -40,12 +40,14 @@ public actor PocketTtsManager { defaultVoice: String = PocketTtsConstants.defaultVoice, language: PocketTtsLanguage = .english, directory: URL? = nil, - precision: PocketTtsPrecision = .fp16 + precision: PocketTtsPrecision = .fp16, + condStepMode: PocketTtsCondStepMode = .legacy ) { self.modelStore = PocketTtsModelStore( language: language, directory: directory, - precision: precision + precision: precision, + condStepMode: condStepMode ) self.defaultVoice = defaultVoice self.language = language @@ -334,6 +336,57 @@ public actor PocketTtsManager { isInitialized = false } + // MARK: - Benchmarks + + /// Result of a single `cond_step` prefill benchmark configuration. + public struct CondStepPrefillBenchmarkResult: Sendable { + /// Number of text tokens fed through `cond_step` (does not include voice prefill). + public let textTokens: Int + /// Number of voice tokens (only > 0 when the voice has no pre-baked snapshot). + public let voiceTokens: Int + /// Per-iteration wall-clock prefill durations in seconds. Excludes warmup. + public let durations: [TimeInterval] + /// Whether the chunked-N + chunk-1 hybrid dispatch was used. + public let usingChunked: Bool + /// Chunk size when `usingChunked == true`, else `nil`. + public let chunkSize: Int? + } + + /// Benchmark `cond_step` prefill for a single text on the given voice. + /// + /// Times only the `prefillKVCache` call — no flowlm_step, flow_decoder, + /// or mimi_decoder invocations. The first `warmup` iterations run but + /// are not recorded; the next `iters` iterations are timed and returned + /// in `durations`. + /// + /// Pipeline path: + /// - `.legacy` store mode: per-token chunk-1 dispatch (matches upstream behaviour). + /// - `.chunked(chunk: N)` store mode: hybrid (`T / N` chunk-N calls + `T % N` + /// chunk-1 tail calls). + /// + /// Use the same `text` and `voice` across two managers (one in each + /// mode) to A/B compare prefill latency end-to-end. + public func benchmarkCondStepPrefill( + text: String, + voice: String? = nil, + warmup: Int = 3, + iters: Int = 30 + ) async throws -> CondStepPrefillBenchmarkResult { + guard isInitialized else { + throw PocketTTSError.modelNotFound("PocketTTS model not initialized") + } + guard iters > 0 else { + throw PocketTTSError.processingFailed("benchmarkCondStepPrefill: iters must be > 0") + } + + let selectedVoice = voice ?? defaultVoice + return try await PocketTtsSynthesizer.withModelStore(modelStore) { + try await PocketTtsSynthesizer.benchmarkCondStepPrefill( + text: text, voice: selectedVoice, warmup: warmup, iters: iters + ) + } + } + // MARK: - Voice Cloning /// Check if voice cloning is available (mimi_encoder model present). diff --git a/Sources/FluidAudioCLI/Commands/PocketTtsCondBenchCommand.swift b/Sources/FluidAudioCLI/Commands/PocketTtsCondBenchCommand.swift new file mode 100644 index 00000000..c072658d --- /dev/null +++ b/Sources/FluidAudioCLI/Commands/PocketTtsCondBenchCommand.swift @@ -0,0 +1,239 @@ +import FluidAudio +import Foundation + +/// Benchmark `cond_step` prefill latency for two PocketTTS dispatch modes +/// (legacy chunk-1 vs hybrid chunk-N + chunk-1) on the same text and voice. +/// +/// Run via: +/// ``` +/// swift run -c release fluidaudio pocket-tts-cond-bench \ +/// --text "Hello world." --voice alba --iters 30 --warmup 3 +/// ``` +/// +/// Requires the chunk-N model file (`cond_step_chunk16.mlmodelc`) to be +/// placed manually under `/v2//` — the file is +/// not yet published on HuggingFace. See +/// `PocketTtsCondStepMode.chunked` for placement details. +public enum PocketTtsCondBenchCommand { + + private static let logger = AppLogger(category: "PocketTtsCondBench") + + public static func run(arguments: [String]) async { + var text = "Hello world, this is a test of the pocket TTS system." + var voice = "alba" + var languageRaw = "english" + var iters = 30 + var warmup = 3 + var chunk = 16 + var alsoSynth = false + + var i = 0 + while i < arguments.count { + let arg = arguments[i] + switch arg { + case "--text": + if i + 1 < arguments.count { + text = arguments[i + 1] + i += 1 + } + case "--voice": + if i + 1 < arguments.count { + voice = arguments[i + 1] + i += 1 + } + case "--language": + if i + 1 < arguments.count { + languageRaw = arguments[i + 1] + i += 1 + } + case "--iters": + if i + 1 < arguments.count, let n = Int(arguments[i + 1]), n > 0 { + iters = n + i += 1 + } + case "--warmup": + if i + 1 < arguments.count, let n = Int(arguments[i + 1]), n >= 0 { + warmup = n + i += 1 + } + case "--chunk": + if i + 1 < arguments.count, let n = Int(arguments[i + 1]), n > 0 { + chunk = n + i += 1 + } + case "--also-synth": + alsoSynth = true + case "--help", "-h": + printUsage() + return + default: + logger.warning("Unknown argument: \(arg)") + } + i += 1 + } + + guard let language = PocketTtsLanguage(rawValue: languageRaw) else { + let supported = PocketTtsLanguage.allCases.map { $0.rawValue }.joined(separator: ", ") + logger.error("Unknown language '\(languageRaw)'. Supported: \(supported)") + return + } + + do { + try await runBenchmark( + text: text, + voice: voice, + language: language, + iters: iters, + warmup: warmup, + chunk: chunk, + alsoSynth: alsoSynth + ) + } catch { + logger.error("pocket-tts-cond-bench failed: \(error)") + } + } + + private static func printUsage() { + let usage = """ + Usage: fluidaudio pocket-tts-cond-bench [options] + + Options: + --text Text to prefill (default: short test sentence) + --voice Voice id (default: alba) + --language Language pack id (default: english) + --iters Timed iterations per config (default: 30) + --warmup Warmup iterations per config, not recorded (default: 3) + --chunk Chunk size for the hybrid path (default: 16) + --also-synth Also synthesize one WAV per config to /tmp + """ + print(usage) + } + + private static func runBenchmark( + text: String, + voice: String, + language: PocketTtsLanguage, + iters: Int, + warmup: Int, + chunk: Int, + alsoSynth: Bool + ) async throws { + print("PocketTTS cond_step prefill benchmark") + print(" language: \(language.rawValue)") + print(" voice: \(voice)") + print(" text: \"\(text)\"") + print(" warmup: \(warmup) iters") + print(" iters: \(iters)") + print(" chunk: \(chunk)") + print("") + + // --- Legacy (chunk-1) --- + let legacyManager = PocketTtsManager( + defaultVoice: voice, language: language, condStepMode: .legacy + ) + print("Initializing legacy manager (chunk-1 dispatch)...") + let legacyInitStart = Date() + try await legacyManager.initialize() + let legacyInitElapsed = String(format: "%.2f", Date().timeIntervalSince(legacyInitStart)) + print(" done in \(legacyInitElapsed)s") + + print("Running legacy benchmark...") + let legacy = try await legacyManager.benchmarkCondStepPrefill( + text: text, voice: voice, warmup: warmup, iters: iters + ) + + // --- Hybrid (chunk-N + chunk-1) --- + let chunkedManager = PocketTtsManager( + defaultVoice: voice, language: language, condStepMode: .chunked(chunk: chunk) + ) + print("Initializing chunked manager (chunk-\(chunk) + chunk-1 hybrid)...") + let chunkedInitStart = Date() + do { + try await chunkedManager.initialize() + let chunkedInitElapsed = String(format: "%.2f", Date().timeIntervalSince(chunkedInitStart)) + print(" done in \(chunkedInitElapsed)s") + } catch { + logger.error( + "chunked manager init failed (is cond_step_chunk\(chunk).mlmodelc placed under the language root?): \(error)" + ) + throw error + } + + print("Running chunked benchmark...") + let chunked = try await chunkedManager.benchmarkCondStepPrefill( + text: text, voice: voice, warmup: warmup, iters: iters + ) + + printSummary(legacy: legacy, chunked: chunked) + + if alsoSynth { + print("") + print("Synthesizing one WAV per config...") + let legacyURL = URL(fileURLWithPath: "/tmp/pocket-tts-cond-bench-legacy.wav") + let chunkedURL = URL(fileURLWithPath: "/tmp/pocket-tts-cond-bench-chunked.wav") + try await legacyManager.synthesizeToFile(text: text, outputURL: legacyURL, voice: voice) + try await chunkedManager.synthesizeToFile(text: text, outputURL: chunkedURL, voice: voice) + print(" wrote \(legacyURL.path)") + print(" wrote \(chunkedURL.path)") + } + } + + private static func printSummary( + legacy: PocketTtsManager.CondStepPrefillBenchmarkResult, + chunked: PocketTtsManager.CondStepPrefillBenchmarkResult + ) { + let legacyMs = stats(legacy.durations) + let chunkedMs = stats(chunked.durations) + let speedup = legacyMs.median / max(chunkedMs.median, 1e-9) + + print("") + print( + "SUMMARY — text tokens=\(legacy.textTokens), voice tokens=\(legacy.voiceTokens) (legacy) / \(chunked.voiceTokens) (chunked)" + ) + print(" config | median | min | stdev | iters") + print(" ---------+-----------+-----------+-----------+-------") + print( + " legacy | \(fmt(legacyMs.median)) | \(fmt(legacyMs.min)) | \(fmt(legacyMs.stdev)) | \(legacy.durations.count)" + ) + let chunkLabel = chunked.chunkSize.map { "chunk-\($0)" } ?? "chunked" + print( + " \(pad(chunkLabel, 8)) | \(fmt(chunkedMs.median)) | \(fmt(chunkedMs.min)) | \(fmt(chunkedMs.stdev)) | \(chunked.durations.count)" + ) + print("") + let speedupStr = String(format: "%.2f", speedup) + print(" speedup (legacy median / chunked median): \(speedupStr)x") + } + + private static func fmt(_ seconds: TimeInterval) -> String { + // Render in milliseconds, fixed-width 8 chars (e.g. " 12.34 ms"). + let ms = seconds * 1000.0 + return String(format: "%6.2f ms", ms) + } + + private static func pad(_ s: String, _ width: Int) -> String { + if s.count >= width { return s } + return s + String(repeating: " ", count: width - s.count) + } + + private static func stats( + _ values: [TimeInterval] + ) -> ( + median: TimeInterval, min: TimeInterval, stdev: TimeInterval + ) { + guard !values.isEmpty else { return (0, 0, 0) } + let sorted = values.sorted() + let mid = sorted.count / 2 + let median: TimeInterval + if sorted.count.isMultiple(of: 2) { + median = (sorted[mid - 1] + sorted[mid]) / 2 + } else { + median = sorted[mid] + } + let minVal = sorted.first ?? 0 + let mean = values.reduce(0, +) / Double(values.count) + let variance = + values.reduce(0) { $0 + ($1 - mean) * ($1 - mean) } / Double(values.count) + let stdev = variance.squareRoot() + return (median, minVal, stdev) + } +} diff --git a/Sources/FluidAudioCLI/FluidAudioCLI.swift b/Sources/FluidAudioCLI/FluidAudioCLI.swift index 1b601209..b106a0ce 100644 --- a/Sources/FluidAudioCLI/FluidAudioCLI.swift +++ b/Sources/FluidAudioCLI/FluidAudioCLI.swift @@ -52,6 +52,8 @@ struct FluidAudioCLI { await TTSAsrVerifyCommand.run(arguments: Array(arguments.dropFirst(2))) case "tts-benchmark": await TtsBenchmarkCommand.run(arguments: Array(arguments.dropFirst(2))) + case "pocket-tts-cond-bench": + await PocketTtsCondBenchCommand.run(arguments: Array(arguments.dropFirst(2))) case "minimax-corpus": await MinimaxCorpusCommand.run(arguments: Array(arguments.dropFirst(2))) case "diarization-benchmark": @@ -121,6 +123,7 @@ struct FluidAudioCLI { magpie Magpie TTS Multilingual 357M (experimental, ~0.04 RTFx — slow, needs perf work) tts-asr-verify Batch TTS→ASR roundtrip WER verification tts-benchmark Quantitative TTS benchmark (latency, quality, compute-unit sweep) + pocket-tts-cond-bench Benchmark PocketTTS cond_step prefill: legacy chunk-1 vs hybrid chunk-16 minimax-corpus Fetch MiniMax TTS Multilingual Test Set into Benchmarks/tts/corpus/minimax parakeet-eou Run Parakeet EOU Streaming ASR on a single file ctc-earnings-benchmark Run CTC keyword spotting benchmark on Earnings22 diff --git a/Tests/FluidAudioTests/TTS/PocketTTS/PocketTtsCondStepHybridTests.swift b/Tests/FluidAudioTests/TTS/PocketTTS/PocketTtsCondStepHybridTests.swift new file mode 100644 index 00000000..dd5abae0 --- /dev/null +++ b/Tests/FluidAudioTests/TTS/PocketTTS/PocketTtsCondStepHybridTests.swift @@ -0,0 +1,114 @@ +@preconcurrency import CoreML +import Foundation +import XCTest + +@testable import FluidAudio + +/// Tests for the chunked `cond_step` dispatch mode added alongside the +/// existing chunk-1 (legacy) pipeline. +/// +/// These tests are pure-logic: they verify the model store's mode plumbing +/// and accessor behaviour without requiring any downloaded CoreML artifacts. +/// End-to-end parity vs the legacy path is validated by the +/// `pocket-tts-cond-bench` CLI subcommand (which requires the chunk-16 +/// `mlmodelc` file to be placed manually under the language root). +final class PocketTtsCondStepHybridTests: XCTestCase { + + // MARK: - PocketTtsCondStepMode + + func testCondStepModeEquatable() { + XCTAssertEqual(PocketTtsCondStepMode.legacy, .legacy) + XCTAssertEqual(PocketTtsCondStepMode.chunked(chunk: 16), .chunked(chunk: 16)) + XCTAssertNotEqual(PocketTtsCondStepMode.legacy, .chunked(chunk: 16)) + XCTAssertNotEqual( + PocketTtsCondStepMode.chunked(chunk: 16), + PocketTtsCondStepMode.chunked(chunk: 32) + ) + } + + // MARK: - PocketTtsModelStore.condStepMode plumbing + + func testStoreDefaultModeIsLegacy() async { + let store = PocketTtsModelStore(language: .english) + let mode = await store.condStepMode + XCTAssertEqual(mode, .legacy) + let chunkSize = await store.condStepChunkSize() + XCTAssertNil(chunkSize, "legacy mode must not expose a chunk size") + } + + func testStoreChunkedModeExposesChunkSize() async { + let store = PocketTtsModelStore( + language: .english, condStepMode: .chunked(chunk: 16) + ) + let mode = await store.condStepMode + XCTAssertEqual(mode, .chunked(chunk: 16)) + let chunkSize = await store.condStepChunkSize() + XCTAssertEqual(chunkSize, 16) + } + + // MARK: - Accessor error semantics + + func testCondStepChunkModelThrowsWhenNotLoaded() async { + // Legacy store never loads the chunk model — accessor must throw + // a clean `modelNotFound` instead of returning a stale value. + let store = PocketTtsModelStore(language: .english) + do { + _ = try await store.condStepChunkModel() + XCTFail("expected condStepChunkModel() to throw in legacy mode") + } catch let error as PocketTTSError { + if case .modelNotFound = error { + // Expected + } else { + XCTFail("expected .modelNotFound, got \(error)") + } + } catch { + XCTFail("expected PocketTTSError, got \(type(of: error)): \(error)") + } + } + + func testCondStepChunkModelThrowsWhenChunkedButUnloaded() async { + // Chunked mode without `loadIfNeeded()` — accessor must throw + // (we never run the loader because that would require network + + // the unpublished chunk-16 file). + let store = PocketTtsModelStore( + language: .english, condStepMode: .chunked(chunk: 16) + ) + do { + _ = try await store.condStepChunkModel() + XCTFail("expected condStepChunkModel() to throw before load") + } catch let error as PocketTTSError { + if case .modelNotFound = error { + // Expected + } else { + XCTFail("expected .modelNotFound, got \(error)") + } + } catch { + XCTFail("expected PocketTTSError, got \(type(of: error)): \(error)") + } + } + + // MARK: - PocketTtsManager init plumbing + + func testManagerAcceptsCondStepMode() async { + // Smoke test: verify the manager init compiles + threads the mode + // through to the underlying store. We don't initialize() because + // that would require network access. + let manager = PocketTtsManager( + defaultVoice: "alba", + language: .english, + condStepMode: .chunked(chunk: 16) + ) + let isAvailable = await manager.isAvailable + XCTAssertFalse(isAvailable, "manager should not be available before initialize()") + } + + // MARK: - ModelNames + + func testCondStepChunk16FilenameMatchesConvention() { + // The chunk-16 mlmodelc file lives next to cond_step.mlmodelc under + // the v2// language root. Mismatching the filename here would + // surface as a confusing modelNotFound at load time. + XCTAssertEqual(ModelNames.PocketTTS.condStepChunk16, "cond_step_chunk16") + XCTAssertEqual(ModelNames.PocketTTS.condStepChunk16File, "cond_step_chunk16.mlmodelc") + } +}