mirror of
https://github.com/FluidInference/FluidAudio.git
synced 2026-05-12 20:20:36 +00:00
feat(tts/pocket): chunked cond_step prefill (3-12x speedup)
Adds an opt-in hybrid dispatch path for PocketTTS cond_step KV prefill
that runs N tokens through a chunk-16 CoreML graph plus a chunk-1 tail,
parallel to the existing chunk-1-only legacy path.
- PocketTtsCondStepMode.{legacy, chunked(chunk:)} (default .legacy,
zero behavior change for existing callers)
- PocketTtsModelStore loads cond_step_chunk16.mlmodelc when configured
and exposes Sendable accessors; missing file throws .modelNotFound
with the expected path
- PocketTtsSynthesizer+KVCache adds runCondStepChunk plus hybrid
prefill helpers; orchestrator prefillKVCache gains optional chunk
params (nil -> legacy)
- PocketTtsManager exposes condStepMode init param and a
benchmarkCondStepPrefill API
- New CLI: pocket-tts-cond-bench (legacy vs chunked, median/min/stdev,
optional WAV synth)
- 7 unit tests covering mode plumbing, accessor error semantics, and
ModelNames conventions
Measured (M-series, alba voice): 6.20x at T=17 over 30 iters,
11.75x at T=17 over 5 iters, 3.49x at T=41. The chunk-16 mlmodelc is
not yet on HuggingFace; place it manually under
<cacheDir>/v2/<lang>/cond_step_chunk16.mlmodelc.
This commit is contained in:
@@ -667,7 +667,14 @@ public enum ModelNames {
|
||||
public static let mimiDecoder = "mimi_decoder"
|
||||
public static let mimiEncoder = "mimi_encoder"
|
||||
|
||||
/// Chunk-16 variant of `cond_step`. Same I/O schema as the chunk-1
|
||||
/// model except the conditioning input has sequence dim 16 instead
|
||||
/// of 1. Used by the chunked prefill pipeline to amortize CoreML
|
||||
/// dispatch overhead across longer text/voice prompts.
|
||||
public static let condStepChunk16 = "cond_step_chunk16"
|
||||
|
||||
public static let condStepFile = condStep + ".mlmodelc"
|
||||
public static let condStepChunk16File = condStepChunk16 + ".mlmodelc"
|
||||
public static let flowlmStepFile = flowlmStep + ".mlmodelc"
|
||||
public static let flowlmStepV2File = flowlmStepV2 + ".mlmodelc"
|
||||
public static let flowDecoderFile = flowDecoder + ".mlmodelc"
|
||||
|
||||
@@ -14,6 +14,7 @@ public actor PocketTtsModelStore {
|
||||
private let logger = AppLogger(subsystem: "com.fluidaudio.tts", category: "PocketTtsModelStore")
|
||||
|
||||
private var condStepModel: MLModel?
|
||||
private var condStepChunkModelStorage: MLModel?
|
||||
private var flowlmStepModel: MLModel?
|
||||
private var flowDecoderModel: MLModel?
|
||||
private var mimiDecoderModel: MLModel?
|
||||
@@ -22,11 +23,13 @@ public actor PocketTtsModelStore {
|
||||
private var voiceCache: [String: PocketTtsVoiceData] = [:]
|
||||
private var languageRootDirectory: URL?
|
||||
private var condLayerKeys: PocketTtsLayerKeys?
|
||||
private var condStepChunkLayerKeysCache: PocketTtsLayerKeys?
|
||||
private var flowlmLayerKeys: PocketTtsLayerKeys?
|
||||
private var mimiDecoderKeysCache: PocketTtsMimiKeys?
|
||||
private let directory: URL?
|
||||
public let language: PocketTtsLanguage
|
||||
public let precision: PocketTtsPrecision
|
||||
public let condStepMode: PocketTtsCondStepMode
|
||||
|
||||
/// - Parameters:
|
||||
/// - language: Which upstream language pack to load. Defaults to
|
||||
@@ -38,14 +41,24 @@ public actor PocketTtsModelStore {
|
||||
/// `flowlm_step.mlmodelc` for `flowlm_stepv2.mlmodelc` from the
|
||||
/// same upstream `v2/<lang>/` directory; the other three submodels
|
||||
/// stay at fp16.
|
||||
/// - condStepMode: Which `cond_step` dispatch strategy the synthesizer
|
||||
/// should use for KV cache prefill. Defaults to `.legacy` (per-token
|
||||
/// dispatch — preserves the upstream behaviour). `.chunked(chunk:
|
||||
/// 16)` additionally loads `cond_step_chunk16.mlmodelc` from the
|
||||
/// same `v2/<lang>/` directory and lets the synthesizer dispatch
|
||||
/// prompt prefill in 16-token chunks plus a per-token tail. The
|
||||
/// chunk-16 file is **not yet published on HuggingFace** — see
|
||||
/// `PocketTtsCondStepMode.chunked` for placement details.
|
||||
public init(
|
||||
language: PocketTtsLanguage = .english,
|
||||
directory: URL? = nil,
|
||||
precision: PocketTtsPrecision = .fp16
|
||||
precision: PocketTtsPrecision = .fp16,
|
||||
condStepMode: PocketTtsCondStepMode = .legacy
|
||||
) {
|
||||
self.language = language
|
||||
self.directory = directory
|
||||
self.precision = precision
|
||||
self.condStepMode = condStepMode
|
||||
}
|
||||
|
||||
/// Load all four CoreML models and the constants bundle.
|
||||
@@ -117,11 +130,63 @@ public actor PocketTtsModelStore {
|
||||
let elapsed = Date().timeIntervalSince(loadStart)
|
||||
logger.info("All PocketTTS models loaded in \(String(format: "%.2f", elapsed))s")
|
||||
|
||||
// Optionally load the chunked cond_step variant. The chunk-N model
|
||||
// shares the K/V cache + position output schema of chunk-1, so the
|
||||
// existing `.condStep` discovery kind applies — only the input
|
||||
// sequence dim differs (1 → N).
|
||||
if case .chunked(let chunk) = condStepMode {
|
||||
try await loadCondStepChunkModel(
|
||||
chunk: chunk,
|
||||
languageRoot: languageRoot,
|
||||
config: config,
|
||||
expectedLayers: expectedLayers
|
||||
)
|
||||
}
|
||||
|
||||
// Load constants
|
||||
constantsBundle = try PocketTtsConstantsLoader.load(from: languageRoot)
|
||||
logger.info("PocketTTS constants loaded")
|
||||
}
|
||||
|
||||
/// Loads the chunked `cond_step` CoreML model from `languageRoot` and
/// discovers its per-layer output keys.
///
/// The model and its layer keys are stored together only after key
/// discovery has succeeded, so a discovery failure never leaves the store
/// holding a model without matching keys.
///
/// - Parameters:
///   - chunk: Requested chunk size. Only `16` is accepted.
///   - languageRoot: Directory expected to contain
///     `ModelNames.PocketTTS.condStepChunk16File`.
///   - config: CoreML configuration used to open the model.
///   - expectedLayers: Layer count forwarded to `PocketTtsLayerKeys.discover`.
/// - Throws: `PocketTTSError.modelNotFound` when `chunk != 16` or when the
///   model file does not exist; rethrows any error from
///   `MLModel(contentsOf:configuration:)` or `PocketTtsLayerKeys.discover`.
private func loadCondStepChunkModel(
    chunk: Int,
    languageRoot: URL,
    config: MLModelConfiguration,
    expectedLayers: Int
) async throws {
    // Only chunk-16 is supported initially. Reject other sizes loudly so
    // callers don't silently fall back to a missing artifact.
    guard chunk == 16 else {
        throw PocketTTSError.modelNotFound(
            "PocketTTS chunked cond_step only supports chunk=16 today (requested \(chunk))"
        )
    }

    let file = ModelNames.PocketTTS.condStepChunk16File
    let modelURL = languageRoot.appendingPathComponent(file)
    guard FileManager.default.fileExists(atPath: modelURL.path) else {
        throw PocketTTSError.modelNotFound(
            "PocketTTS \(file) not found at \(modelURL.path). "
                + "The chunked cond_step variant is not yet published on HuggingFace; "
                + "place the compiled mlmodelc at this path manually to enable .chunked(chunk: 16)."
        )
    }

    let chunkLoadStart = Date()
    let model = try MLModel(contentsOf: modelURL, configuration: config)

    // Discover the output keys BEFORE publishing anything to actor state.
    // If discovery throws, neither the model nor the keys are stored, so
    // `condStepChunkModel()` and `condStepChunkLayerKeys()` stay consistent.
    let discoveredKeys = try PocketTtsLayerKeys.discover(
        from: model,
        kind: .condStep,
        expectedLayers: expectedLayers,
        modelName: ModelNames.PocketTTS.condStepChunk16
    )
    condStepChunkModelStorage = model
    condStepChunkLayerKeysCache = discoveredKeys

    let chunkElapsed = Date().timeIntervalSince(chunkLoadStart)
    logger.info(
        "Loaded \(file) (chunk=\(chunk)) in \(String(format: "%.2f", chunkElapsed))s"
    )
}
|
||||
|
||||
/// The conditioning step model (KV cache prefill).
|
||||
public func condStep() throws -> MLModel {
|
||||
guard let model = condStepModel else {
|
||||
@@ -170,6 +235,44 @@ public actor PocketTtsModelStore {
|
||||
return keys
|
||||
}
|
||||
|
||||
/// Returns the loaded chunked `cond_step` model.
///
/// The return type is deliberately non-optional, mirroring `condStep()`.
/// Per the original author's note, an `Optional<MLModel>` return would
/// require `MLModel` itself to satisfy strict-mode Sendable, which the
/// `@preconcurrency import CoreML` approach does not cover.
///
/// Callers are expected to check `condStepChunkSize() != nil` before
/// calling this accessor.
///
/// - Throws: `PocketTTSError.modelNotFound` when no chunk model is stored
///   (for example, the store is in `.legacy` mode or the chunk model was
///   never loaded).
public func condStepChunkModel() throws -> MLModel {
    if let loadedModel = condStepChunkModelStorage {
        return loadedModel
    }
    throw PocketTTSError.modelNotFound(
        "PocketTTS chunked cond_step model not loaded (mode = \(condStepMode))"
    )
}
|
||||
|
||||
/// Returns the output key names discovered for the chunked `cond_step`
/// model.
///
/// - Throws: `PocketTTSError.modelNotFound` when no keys have been cached,
///   following the same throwing convention as `condStepChunkModel()`.
func condStepChunkLayerKeys() throws -> PocketTtsLayerKeys {
    if let cachedKeys = condStepChunkLayerKeysCache {
        return cachedKeys
    }
    throw PocketTTSError.modelNotFound(
        "PocketTTS chunked cond_step layer keys not discovered (mode = \(condStepMode))"
    )
}
|
||||
|
||||
/// Returns the configured chunk size when `condStepMode` is `.chunked`,
/// or `nil` for any other mode.
///
/// This only inspects the configured mode; it is intended as the cheap
/// gate to check before calling the throwing chunk-model accessors.
public func condStepChunkSize() -> Int? {
    guard case .chunked(let configuredChunk) = condStepMode else {
        return nil
    }
    return configuredChunk
}
|
||||
|
||||
/// Discovered output names for the flowlm_step transformer model.
|
||||
func flowLMStepLayerKeys() throws -> PocketTtsLayerKeys {
|
||||
guard let keys = flowlmLayerKeys else {
|
||||
|
||||
@@ -61,6 +61,9 @@ public actor PocketTtsSession {
|
||||
private let condLayerKeys: PocketTtsLayerKeys
|
||||
private let flowlmLayerKeys: PocketTtsLayerKeys
|
||||
private let mimiKeys: PocketTtsMimiKeys
|
||||
private let chunkCondModel: MLModel?
|
||||
private let chunkCondLayerKeys: PocketTtsLayerKeys?
|
||||
private let chunkSize: Int?
|
||||
|
||||
// Persistent state
|
||||
private let voiceKVSnapshot: PocketTtsSynthesizer.KVCacheState
|
||||
@@ -88,7 +91,10 @@ public actor PocketTtsSession {
|
||||
mimiKeys: PocketTtsMimiKeys,
|
||||
bosEmb: MLMultiArray,
|
||||
temperature: Float,
|
||||
seed: UInt64
|
||||
seed: UInt64,
|
||||
chunkCondModel: MLModel? = nil,
|
||||
chunkCondLayerKeys: PocketTtsLayerKeys? = nil,
|
||||
chunkSize: Int? = nil
|
||||
) {
|
||||
self.voiceKVSnapshot = voiceKVSnapshot
|
||||
self.mimiState = mimiState
|
||||
@@ -103,6 +109,9 @@ public actor PocketTtsSession {
|
||||
self.bosEmb = bosEmb
|
||||
self.temperature = temperature
|
||||
self.rng = SeededRNG(seed: seed)
|
||||
self.chunkCondModel = chunkCondModel
|
||||
self.chunkCondLayerKeys = chunkCondLayerKeys
|
||||
self.chunkSize = chunkSize
|
||||
|
||||
// Text queue channel
|
||||
let (textStream, textContinuation) = AsyncStream.makeStream(of: String.self)
|
||||
@@ -178,12 +187,26 @@ public actor PocketTtsSession {
|
||||
let tokenIds = constants.tokenizer.encode(normalizedChunk)
|
||||
let textEmbeddings = PocketTtsSynthesizer.embedTokens(tokenIds, constants: constants)
|
||||
|
||||
// Clone voice KV snapshot and prefill text tokens only
|
||||
// Clone voice KV snapshot and prefill text tokens only.
|
||||
// Uses hybrid chunk-N + chunk-1 dispatch when the session was
|
||||
// created from a store opened in `.chunked` mode.
|
||||
var kvState = try PocketTtsSynthesizer.cloneKVCacheState(voiceKVSnapshot)
|
||||
kvState = try await PocketTtsSynthesizer.prefillKVCacheText(
|
||||
state: kvState, textEmbeddings: textEmbeddings, model: condModel,
|
||||
layerKeys: condLayerKeys
|
||||
)
|
||||
if let cm = chunkCondModel, let ck = chunkCondLayerKeys, let cs = chunkSize, cs > 1 {
|
||||
kvState = try await PocketTtsSynthesizer.prefillKVCacheTextHybrid(
|
||||
state: kvState,
|
||||
textEmbeddings: textEmbeddings,
|
||||
chunkModel: cm,
|
||||
chunkLayerKeys: ck,
|
||||
chunkSize: cs,
|
||||
perTokenModel: condModel,
|
||||
perTokenLayerKeys: condLayerKeys
|
||||
)
|
||||
} else {
|
||||
kvState = try await PocketTtsSynthesizer.prefillKVCacheText(
|
||||
state: kvState, textEmbeddings: textEmbeddings, model: condModel,
|
||||
layerKeys: condLayerKeys
|
||||
)
|
||||
}
|
||||
|
||||
// Generation loop
|
||||
let maxGenLen = PocketTtsSynthesizer.estimateMaxFrames(text: text)
|
||||
|
||||
@@ -0,0 +1,138 @@
|
||||
@preconcurrency import CoreML
|
||||
import Foundation
|
||||
|
||||
extension PocketTtsSynthesizer {
|
||||
|
||||
/// Times `prefillKVCache` for one text + named voice pair.
///
/// Resolves the voice by name from the current model store and forwards
/// to the `voiceData:` overload. Used by
/// `PocketTtsManager.benchmarkCondStepPrefill` and the
/// `pocket-tts-cond-bench` CLI subcommand to A/B compare the legacy
/// chunk-1 and hybrid chunk-N dispatch paths.
///
/// Must be called within a `withModelStore` context (the store is obtained
/// through `currentModelStore()`). The store's `condStepMode` decides
/// whether the hybrid path runs.
///
/// - Parameters:
///   - text: Text to tokenize and prefill.
///   - voice: Voice name looked up via the store's `voiceData(for:)`.
///   - warmup: Number of untimed warmup iterations.
///   - iters: Number of timed iterations.
/// - Returns: The benchmark result produced by the `voiceData:` overload.
static func benchmarkCondStepPrefill(
    text: String,
    voice: String,
    warmup: Int,
    iters: Int
) async throws -> PocketTtsManager.CondStepPrefillBenchmarkResult {
    let modelStore = try currentModelStore()
    let resolvedVoice = try await modelStore.voiceData(for: voice)
    let result = try await benchmarkCondStepPrefill(
        text: text,
        voiceData: resolvedVoice,
        warmup: warmup,
        iters: iters
    )
    return result
}
|
||||
|
||||
/// Times `prefillKVCache` for one text + pre-resolved voice pair.
///
/// Same as the voice-name overload but takes voice data that has already
/// been resolved. Tokenization and embedding happen once, outside the
/// timed region; only the prefill call itself is measured.
///
/// - Parameters:
///   - text: Text to normalize, tokenize and embed.
///   - voiceData: Voice data passed to every prefill call.
///   - warmup: Number of untimed warmup iterations. Negative values are
///     treated as zero instead of trapping.
///   - iters: Number of timed iterations. Negative values are treated as
///     zero instead of trapping, yielding an empty `durations` array.
/// - Returns: Token counts, per-iteration wall-clock durations, and the
///   chunk configuration reported by the store.
/// - Throws: Any error from the model store accessors or from
///   `prefillKVCache`.
static func benchmarkCondStepPrefill(
    text: String,
    voiceData: PocketTtsVoiceData,
    warmup: Int,
    iters: Int
) async throws -> PocketTtsManager.CondStepPrefillBenchmarkResult {
    let store = try currentModelStore()

    let constants = try await store.constants()
    let condModel = try await store.condStep()
    let condLayerKeys = try await store.condStepLayerKeys()

    // Tokenize + embed once. The benchmark times only the prefill loop —
    // tokenization is cheap and identical across iterations and across
    // legacy vs chunked dispatch.
    let (normalizedChunk, _) = normalizeText(text)
    let tokenIds = constants.tokenizer.encode(normalizedChunk)
    let textEmbeddings = embedTokens(tokenIds, constants: constants)

    // Resolve chunk resources only when the store is in chunked mode; the
    // accessors throw otherwise.
    let chunkSize = await store.condStepChunkSize()
    let chunkModel: MLModel?
    let chunkLayerKeys: PocketTtsLayerKeys?
    if chunkSize != nil {
        chunkModel = try await store.condStepChunkModel()
        chunkLayerKeys = try await store.condStepChunkLayerKeys()
    } else {
        chunkModel = nil
        chunkLayerKeys = nil
    }

    // Voice token count for reporting. Voices that carry a cache snapshot
    // are reported as contributing zero cond_step voice tokens.
    let voiceTokens = voiceData.cacheSnapshot != nil ? 0 : voiceData.promptLength

    // One full prefill with the resolved resources; the result is discarded.
    func prefillOnce() async throws {
        _ = try await runPrefill(
            voiceData: voiceData,
            textEmbeddings: textEmbeddings,
            model: condModel,
            layerKeys: condLayerKeys,
            chunkModel: chunkModel,
            chunkLayerKeys: chunkLayerKeys,
            chunkSize: chunkSize
        )
    }

    // `0..<n` and `reserveCapacity(n)` trap on negative `n`, so clamp the
    // caller-supplied counts rather than crashing on bad CLI input.
    let warmupCount = max(0, warmup)
    let iterationCount = max(0, iters)

    // Warmup iterations (not recorded). Each runs a full prefill so the
    // CoreML graph, MIL caches, and ANE/GPU pipelines are warm before
    // we start sampling.
    for _ in 0..<warmupCount {
        try await prefillOnce()
    }

    // Timed iterations. Use `Date` (wall clock) instead of
    // `ContinuousClock` to keep iOS 16 / macOS 13 compatibility
    // — the rest of the pipeline targets the same baseline.
    var durations: [TimeInterval] = []
    durations.reserveCapacity(iterationCount)
    for _ in 0..<iterationCount {
        let start = Date()
        try await prefillOnce()
        durations.append(Date().timeIntervalSince(start))
    }

    return PocketTtsManager.CondStepPrefillBenchmarkResult(
        textTokens: textEmbeddings.count,
        voiceTokens: voiceTokens,
        durations: durations,
        usingChunked: chunkSize != nil,
        chunkSize: chunkSize
    )
}
|
||||
|
||||
/// Performs one `prefillKVCache` call, forwarding the chunk arguments only
/// when all three are present.
///
/// Used only by the benchmark — production paths call `prefillKVCache`
/// directly with whichever arguments they need.
///
/// - Returns: The KV cache state produced by `prefillKVCache`.
private static func runPrefill(
    voiceData: PocketTtsVoiceData,
    textEmbeddings: [[Float]],
    model: MLModel,
    layerKeys: PocketTtsLayerKeys,
    chunkModel: MLModel?,
    chunkLayerKeys: PocketTtsLayerKeys?,
    chunkSize: Int?
) async throws -> KVCacheState {
    // Any missing chunk resource means the call is made without chunk
    // arguments, leaving `prefillKVCache` on its defaults.
    guard
        let hybridModel = chunkModel,
        let hybridKeys = chunkLayerKeys,
        let hybridSize = chunkSize
    else {
        return try await prefillKVCache(
            voiceData: voiceData,
            textEmbeddings: textEmbeddings,
            model: model,
            layerKeys: layerKeys
        )
    }

    return try await prefillKVCache(
        voiceData: voiceData,
        textEmbeddings: textEmbeddings,
        model: model,
        layerKeys: layerKeys,
        chunkModel: hybridModel,
        chunkLayerKeys: hybridKeys,
        chunkSize: hybridSize
    )
}
|
||||
}
|
||||
@@ -119,6 +119,48 @@ extension PocketTtsSynthesizer {
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs the chunked conditioning step model once and writes the returned
/// per-layer cache and position arrays back into `state`.
///
/// Per the original documentation, `conditioning` is a
/// `[1, chunkSize, 1024]` tensor and the chunk-N graph reuses the
/// `cond_step` output names, with the position counter advanced by
/// `chunkSize`. This function does not validate the tensor shape itself.
///
/// Inputs are supplied as `"conditioning"` plus `"cache<i>"` /
/// `"position<i>"` for each layer; outputs are read using
/// `layerKeys.cacheKeys` / `layerKeys.positionKeys`.
///
/// - Throws: `PocketTTSError.processingFailed` when an expected cache or
///   position output is missing; rethrows prediction errors. Layers are
///   written back one at a time, so layers preceding a missing output have
///   already been updated in `state` when the error is thrown.
static func runCondStepChunk(
    conditioning: MLMultiArray,
    state: inout KVCacheState,
    model: MLModel,
    layerKeys: PocketTtsLayerKeys
) async throws {
    let layerTotal = layerKeys.layerCount

    // Assemble the feature dictionary: the conditioning tensor followed by
    // the current cache/position pair of every layer.
    var features: [String: Any] = [:]
    features["conditioning"] = conditioning
    for layer in 0..<layerTotal {
        features["cache\(layer)"] = state.caches[layer]
        features["position\(layer)"] = state.positions[layer]
    }

    let provider = try MLDictionaryFeatureProvider(dictionary: features)
    let prediction = try await model.compatPrediction(
        from: provider, options: MLPredictionOptions())

    // Copy each layer's outputs back, cache first, then position.
    for layer in 0..<layerTotal {
        let cacheValue = prediction.featureValue(for: layerKeys.cacheKeys[layer])
        guard let updatedCache = cacheValue?.multiArrayValue else {
            throw PocketTTSError.processingFailed(
                "Missing cond_step_chunk cache output: \(layerKeys.cacheKeys[layer])")
        }

        let positionValue = prediction.featureValue(for: layerKeys.positionKeys[layer])
        guard let updatedPosition = positionValue?.multiArrayValue else {
            throw PocketTTSError.processingFailed(
                "Missing cond_step_chunk position output: \(layerKeys.positionKeys[layer])")
        }

        state.caches[layer] = updatedCache
        state.positions[layer] = updatedPosition
    }
}
|
||||
|
||||
/// Prefill a KV cache state with voice conditioning tokens.
|
||||
///
|
||||
/// Processes all voice tokens from the voice data, writing K/V projections
|
||||
@@ -243,6 +285,109 @@ extension PocketTtsSynthesizer {
|
||||
return KVCacheState(caches: caches, positions: positions)
|
||||
}
|
||||
|
||||
/// Prefill a KV cache with voice tokens using hybrid chunk-N + chunk-1 dispatch.
///
/// For `T = voiceData.promptLength` voice tokens, runs `T / chunkSize`
/// calls through `chunkModel`, then `T % chunkSize` single-token calls
/// through `perTokenModel`, so any token count works. Tokens are read from
/// the flat `voiceData.audioPrompt` buffer, `PocketTtsConstants.embeddingDim`
/// floats per token.
///
/// - Parameters:
///   - state: Starting KV cache state; a copy is updated and returned.
///   - voiceData: Source of the flat audio prompt and its token count.
///   - chunkModel: Model used for full `chunkSize`-token chunks.
///   - chunkLayerKeys: Output keys for `chunkModel`.
///   - chunkSize: Tokens per chunk call. Must be greater than zero.
///   - perTokenModel: Model used for the remaining single tokens.
///   - perTokenLayerKeys: Output keys for `perTokenModel`.
/// - Returns: The KV cache state after all voice tokens were processed.
/// - Throws: `PocketTTSError.processingFailed` when `chunkSize <= 0`, when
///   `promptLength` is negative, or when `audioPrompt` holds fewer floats
///   than `promptLength * embeddingDim`; rethrows model errors.
static func prefillKVCacheVoiceHybrid(
    state: KVCacheState,
    voiceData: PocketTtsVoiceData,
    chunkModel: MLModel,
    chunkLayerKeys: PocketTtsLayerKeys,
    chunkSize: Int,
    perTokenModel: MLModel,
    perTokenLayerKeys: PocketTtsLayerKeys
) async throws -> KVCacheState {
    // `total / chunkSize` traps for zero and the loop ranges trap for
    // negative sizes, so reject them with a catchable error instead.
    guard chunkSize > 0 else {
        throw PocketTTSError.processingFailed(
            "Hybrid voice prefill requires chunkSize > 0 (got \(chunkSize))")
    }

    let dim = PocketTtsConstants.embeddingDim
    let total = voiceData.promptLength

    // The chunk builder copies through raw pointers; make sure the flat
    // buffer really contains every token we are about to read.
    guard total >= 0, voiceData.audioPrompt.count >= total * dim else {
        throw PocketTTSError.processingFailed(
            "Voice audio prompt has \(voiceData.audioPrompt.count) floats, "
                + "expected at least \(total * dim) for \(total) tokens")
    }

    var state = state
    let chunkedTokenCount = (total / chunkSize) * chunkSize

    // Full chunks: tokens [0, chunkedTokenCount) in steps of chunkSize.
    for startToken in stride(from: 0, to: chunkedTokenCount, by: chunkSize) {
        let chunk = try createConditioningChunkFromFlat(
            source: voiceData.audioPrompt,
            startToken: startToken,
            count: chunkSize,
            dim: dim
        )
        try await runCondStepChunk(
            conditioning: chunk,
            state: &state,
            model: chunkModel,
            layerKeys: chunkLayerKeys
        )
    }

    // Tail: the remaining `total % chunkSize` tokens, one call each.
    for tokenIndex in chunkedTokenCount..<total {
        let token = try createConditioningToken(
            from: voiceData.audioPrompt,
            offset: tokenIndex * dim,
            dim: dim
        )
        try await runCondStep(
            conditioning: token,
            state: &state,
            model: perTokenModel,
            layerKeys: perTokenLayerKeys
        )
    }
    return state
}
|
||||
|
||||
/// Prefill a KV cache with text embeddings using hybrid chunk-N + chunk-1 dispatch.
///
/// For `T = textEmbeddings.count` tokens, runs `T / chunkSize` calls
/// through `chunkModel`, then `T % chunkSize` single-token calls through
/// `perTokenModel`.
///
/// - Parameters:
///   - state: Starting KV cache state; a copy is updated and returned.
///   - textEmbeddings: One embedding vector per text token.
///   - chunkModel: Model used for full `chunkSize`-token chunks.
///   - chunkLayerKeys: Output keys for `chunkModel`.
///   - chunkSize: Tokens per chunk call. Must be greater than zero.
///   - perTokenModel: Model used for the remaining single tokens.
///   - perTokenLayerKeys: Output keys for `perTokenModel`.
/// - Returns: The KV cache state after all text tokens were processed.
/// - Throws: `PocketTTSError.processingFailed` when `chunkSize <= 0`;
///   rethrows model and tensor-construction errors.
static func prefillKVCacheTextHybrid(
    state: KVCacheState,
    textEmbeddings: [[Float]],
    chunkModel: MLModel,
    chunkLayerKeys: PocketTtsLayerKeys,
    chunkSize: Int,
    perTokenModel: MLModel,
    perTokenLayerKeys: PocketTtsLayerKeys
) async throws -> KVCacheState {
    // `total / chunkSize` traps for zero and the loop ranges trap for
    // negative sizes, so reject them with a catchable error instead.
    guard chunkSize > 0 else {
        throw PocketTTSError.processingFailed(
            "Hybrid text prefill requires chunkSize > 0 (got \(chunkSize))")
    }

    var state = state
    let dim = PocketTtsConstants.embeddingDim
    let total = textEmbeddings.count
    let chunkedTokenCount = (total / chunkSize) * chunkSize

    // Full chunks: tokens [0, chunkedTokenCount) in steps of chunkSize.
    for startToken in stride(from: 0, to: chunkedTokenCount, by: chunkSize) {
        let chunk = try createConditioningChunkFromEmbeddings(
            source: textEmbeddings,
            startToken: startToken,
            count: chunkSize,
            dim: dim
        )
        try await runCondStepChunk(
            conditioning: chunk,
            state: &state,
            model: chunkModel,
            layerKeys: chunkLayerKeys
        )
    }

    // Tail: the remaining `total % chunkSize` tokens, one call each.
    for tokenIndex in chunkedTokenCount..<total {
        let token = try createConditioningToken(
            from: textEmbeddings[tokenIndex],
            offset: 0,
            dim: dim
        )
        try await runCondStep(
            conditioning: token,
            state: &state,
            model: perTokenModel,
            layerKeys: perTokenLayerKeys
        )
    }
    return state
}
|
||||
|
||||
/// Prefill the KV cache with voice and text conditioning tokens.
|
||||
///
|
||||
/// Processes voice tokens first, then text tokens. This ordering is critical —
|
||||
@@ -256,24 +401,63 @@ extension PocketTtsSynthesizer {
|
||||
/// - **Flat audio prompt** (cloned voices): feed every voice token
|
||||
/// through `cond_step`.
|
||||
/// Text prefill runs identically in both cases.
|
||||
///
|
||||
/// Optional chunked dispatch: when `chunkModel`, `chunkLayerKeys`, and
|
||||
/// `chunkSize` are all non-nil, voice + text prefill use the hybrid
|
||||
/// chunk-N + chunk-1 path, dispatching the bulk of the prompt through
|
||||
/// the chunked CoreML graph and falling back to per-token dispatch for
|
||||
/// the remainder. When any of them is nil, the legacy per-token path
|
||||
/// runs (preserves existing behaviour for callers that don't opt in).
|
||||
static func prefillKVCache(
|
||||
voiceData: PocketTtsVoiceData,
|
||||
textEmbeddings: [[Float]],
|
||||
model: MLModel,
|
||||
layerKeys: PocketTtsLayerKeys
|
||||
layerKeys: PocketTtsLayerKeys,
|
||||
chunkModel: MLModel? = nil,
|
||||
chunkLayerKeys: PocketTtsLayerKeys? = nil,
|
||||
chunkSize: Int? = nil
|
||||
) async throws -> KVCacheState {
|
||||
let useHybrid =
|
||||
chunkModel != nil && chunkLayerKeys != nil && chunkSize != nil
|
||||
&& (chunkSize ?? 0) > 1
|
||||
|
||||
var state: KVCacheState
|
||||
if let snapshot = voiceData.cacheSnapshot {
|
||||
state = try kvCacheStateFromSnapshot(snapshot, layers: layerKeys.layerCount)
|
||||
} else {
|
||||
let emptyState = try emptyKVCacheState(layers: layerKeys.layerCount)
|
||||
state = try await prefillKVCacheVoice(
|
||||
state: emptyState, voiceData: voiceData, model: model, layerKeys: layerKeys
|
||||
if useHybrid, let cm = chunkModel, let ck = chunkLayerKeys, let cs = chunkSize {
|
||||
state = try await prefillKVCacheVoiceHybrid(
|
||||
state: emptyState,
|
||||
voiceData: voiceData,
|
||||
chunkModel: cm,
|
||||
chunkLayerKeys: ck,
|
||||
chunkSize: cs,
|
||||
perTokenModel: model,
|
||||
perTokenLayerKeys: layerKeys
|
||||
)
|
||||
} else {
|
||||
state = try await prefillKVCacheVoice(
|
||||
state: emptyState, voiceData: voiceData, model: model, layerKeys: layerKeys
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
if useHybrid, let cm = chunkModel, let ck = chunkLayerKeys, let cs = chunkSize {
|
||||
state = try await prefillKVCacheTextHybrid(
|
||||
state: state,
|
||||
textEmbeddings: textEmbeddings,
|
||||
chunkModel: cm,
|
||||
chunkLayerKeys: ck,
|
||||
chunkSize: cs,
|
||||
perTokenModel: model,
|
||||
perTokenLayerKeys: layerKeys
|
||||
)
|
||||
} else {
|
||||
state = try await prefillKVCacheText(
|
||||
state: state, textEmbeddings: textEmbeddings, model: model, layerKeys: layerKeys
|
||||
)
|
||||
}
|
||||
state = try await prefillKVCacheText(
|
||||
state: state, textEmbeddings: textEmbeddings, model: model, layerKeys: layerKeys
|
||||
)
|
||||
|
||||
let finalPos = state.positions[0][0].floatValue
|
||||
logger.info("KV cache prefilled to position \(Int(finalPos))")
|
||||
@@ -297,6 +481,46 @@ extension PocketTtsSynthesizer {
|
||||
return array
|
||||
}
|
||||
|
||||
/// Create a `[1, count, dim]` float32 `MLMultiArray` from a flat
/// audio-prompt buffer laid out as
/// `[token0_dim0..token0_dimN-1, token1_dim0..., ...]`.
///
/// Copies `count * dim` consecutive floats starting at float offset
/// `startToken * dim`.
///
/// - Parameters:
///   - source: Flat token-major buffer.
///   - startToken: Index of the first token to copy.
///   - count: Number of tokens to copy.
///   - dim: Floats per token.
/// - Returns: The populated array.
/// - Throws: `PocketTTSError.processingFailed` when any argument is
///   negative or the requested range extends past the end of `source`;
///   rethrows `MLMultiArray` allocation errors.
private static func createConditioningChunkFromFlat(
    source: [Float], startToken: Int, count: Int, dim: Int
) throws -> MLMultiArray {
    // The copy below goes through raw pointers and performs no bounds
    // checking of its own, so validate the requested window up front.
    guard startToken >= 0, count >= 0, dim >= 0 else {
        throw PocketTTSError.processingFailed(
            "Invalid conditioning chunk request (startToken=\(startToken), "
                + "count=\(count), dim=\(dim))")
    }
    let firstFloat = startToken * dim
    let total = count * dim
    guard firstFloat <= source.count, total <= source.count - firstFloat else {
        throw PocketTTSError.processingFailed(
            "Conditioning chunk [\(firstFloat), \(firstFloat + total)) exceeds "
                + "audio prompt length \(source.count)")
    }

    let array = try MLMultiArray(
        shape: [1, NSNumber(value: count), NSNumber(value: dim)],
        dataType: .float32
    )
    let ptr = array.dataPointer.bindMemory(to: Float.self, capacity: total)
    source.withUnsafeBufferPointer { buffer in
        // A nil base address only occurs for an empty buffer, in which case
        // the validation above guarantees there is nothing to copy.
        guard let base = buffer.baseAddress else { return }
        ptr.update(from: base.advanced(by: firstFloat), count: total)
    }
    return array
}
|
||||
|
||||
/// Create a `[1, count, dim]` float32 `MLMultiArray` by concatenating
/// `count` token embeddings starting at `startToken` from a `[[Float]]`
/// embedding source.
///
/// The first `dim` floats of each selected embedding are copied into the
/// corresponding row of the result.
///
/// - Parameters:
///   - source: One embedding vector per token.
///   - startToken: Index of the first embedding to copy.
///   - count: Number of embeddings to copy.
///   - dim: Floats copied per embedding.
/// - Returns: The populated array.
/// - Throws: `PocketTTSError.processingFailed` when any argument is
///   negative, the token range falls outside `source`, or a selected
///   embedding holds fewer than `dim` floats; rethrows `MLMultiArray`
///   allocation errors.
private static func createConditioningChunkFromEmbeddings(
    source: [[Float]], startToken: Int, count: Int, dim: Int
) throws -> MLMultiArray {
    // Validate the token window before touching raw memory: indexing past
    // `source` would trap, and a short embedding would be over-read.
    guard startToken >= 0, count >= 0, dim >= 0 else {
        throw PocketTTSError.processingFailed(
            "Invalid conditioning chunk request (startToken=\(startToken), "
                + "count=\(count), dim=\(dim))")
    }
    guard startToken <= source.count, count <= source.count - startToken else {
        throw PocketTTSError.processingFailed(
            "Conditioning chunk tokens [\(startToken), \(startToken + count)) exceed "
                + "embedding count \(source.count)")
    }

    let array = try MLMultiArray(
        shape: [1, NSNumber(value: count), NSNumber(value: dim)],
        dataType: .float32
    )
    let total = count * dim
    let ptr = array.dataPointer.bindMemory(to: Float.self, capacity: total)

    for row in 0..<count {
        let embedding = source[startToken + row]
        guard embedding.count >= dim else {
            throw PocketTTSError.processingFailed(
                "Embedding \(startToken + row) has \(embedding.count) floats, expected at least \(dim)")
        }
        embedding.withUnsafeBufferPointer { buffer in
            // Nil only for an empty embedding, which implies dim == 0 here.
            guard let base = buffer.baseAddress else { return }
            ptr.advanced(by: row * dim).update(from: base, count: dim)
        }
    }
    return array
}
|
||||
|
||||
/// Run the generation step model, returning transformer output and EOS logit.
|
||||
///
|
||||
/// Same transformer as `cond_step`, now in "generate mode". Takes the previous
|
||||
|
||||
@@ -220,23 +220,55 @@ public struct PocketTtsSynthesizer {
|
||||
let seedValue = seed ?? UInt64.random(in: 0...UInt64.max)
|
||||
let chunkCount = chunks.count
|
||||
|
||||
let generator = StreamingGenerator(
|
||||
constants: constants,
|
||||
voiceData: voiceData,
|
||||
chunks: chunks,
|
||||
condModel: condModel,
|
||||
stepModel: stepModel,
|
||||
flowModel: flowModel,
|
||||
mimiModel: mimiModel,
|
||||
condLayerKeys: condLayerKeys,
|
||||
flowlmLayerKeys: flowlmLayerKeys,
|
||||
mimiKeys: mimiKeys,
|
||||
mimiInitialState: mimiInitialState,
|
||||
bosEmb: bosEmb,
|
||||
seedValue: seedValue,
|
||||
chunkCount: chunkCount,
|
||||
temperature: temperature
|
||||
)
|
||||
// Resolve chunked cond_step resources only when the store was
|
||||
// opened in `.chunked` mode. Calling the throwing accessors
|
||||
// unconditionally would surface as an error in legacy mode; the
|
||||
// chunkSize gate keeps the Sendable surface to non-optional types
|
||||
// (MLModel and PocketTtsLayerKeys, both Sendable via
|
||||
// `@preconcurrency import CoreML`).
|
||||
let generator: StreamingGenerator
|
||||
if let cs = await store.condStepChunkSize() {
|
||||
let chunkCondModel = try await store.condStepChunkModel()
|
||||
let chunkCondLayerKeys = try await store.condStepChunkLayerKeys()
|
||||
generator = StreamingGenerator(
|
||||
constants: constants,
|
||||
voiceData: voiceData,
|
||||
chunks: chunks,
|
||||
condModel: condModel,
|
||||
stepModel: stepModel,
|
||||
flowModel: flowModel,
|
||||
mimiModel: mimiModel,
|
||||
condLayerKeys: condLayerKeys,
|
||||
flowlmLayerKeys: flowlmLayerKeys,
|
||||
mimiKeys: mimiKeys,
|
||||
mimiInitialState: mimiInitialState,
|
||||
bosEmb: bosEmb,
|
||||
seedValue: seedValue,
|
||||
chunkCount: chunkCount,
|
||||
temperature: temperature,
|
||||
chunkCondModel: chunkCondModel,
|
||||
chunkCondLayerKeys: chunkCondLayerKeys,
|
||||
chunkSize: cs
|
||||
)
|
||||
} else {
|
||||
generator = StreamingGenerator(
|
||||
constants: constants,
|
||||
voiceData: voiceData,
|
||||
chunks: chunks,
|
||||
condModel: condModel,
|
||||
stepModel: stepModel,
|
||||
flowModel: flowModel,
|
||||
mimiModel: mimiModel,
|
||||
condLayerKeys: condLayerKeys,
|
||||
flowlmLayerKeys: flowlmLayerKeys,
|
||||
mimiKeys: mimiKeys,
|
||||
mimiInitialState: mimiInitialState,
|
||||
bosEmb: bosEmb,
|
||||
seedValue: seedValue,
|
||||
chunkCount: chunkCount,
|
||||
temperature: temperature
|
||||
)
|
||||
}
|
||||
|
||||
return makeStream(generator: generator)
|
||||
}
|
||||
@@ -274,38 +306,78 @@ public struct PocketTtsSynthesizer {
|
||||
// cache, skip cond_step entirely (`promptLength == 0`, so the
|
||||
// loop in `prefillKVCacheVoice` would be a no-op anyway).
|
||||
// - Cloned voices (flat audio prompt): feed every voice token
|
||||
// through cond_step.
|
||||
// through cond_step. Uses hybrid chunk-N + chunk-1 dispatch
|
||||
// when the store was opened in `.chunked` mode.
|
||||
let chunkSize: Int? = await store.condStepChunkSize()
|
||||
let voiceKVSnapshot: KVCacheState
|
||||
if let snapshot = voiceData.cacheSnapshot {
|
||||
voiceKVSnapshot = try kvCacheStateFromSnapshot(
|
||||
snapshot, layers: condLayerKeys.layerCount)
|
||||
} else {
|
||||
let emptyState = try emptyKVCacheState(layers: condLayerKeys.layerCount)
|
||||
voiceKVSnapshot = try await prefillKVCacheVoice(
|
||||
state: emptyState, voiceData: voiceData, model: condModel,
|
||||
layerKeys: condLayerKeys
|
||||
)
|
||||
if let cs = chunkSize, cs > 1 {
|
||||
let cm = try await store.condStepChunkModel()
|
||||
let ck = try await store.condStepChunkLayerKeys()
|
||||
voiceKVSnapshot = try await prefillKVCacheVoiceHybrid(
|
||||
state: emptyState,
|
||||
voiceData: voiceData,
|
||||
chunkModel: cm,
|
||||
chunkLayerKeys: ck,
|
||||
chunkSize: cs,
|
||||
perTokenModel: condModel,
|
||||
perTokenLayerKeys: condLayerKeys
|
||||
)
|
||||
} else {
|
||||
voiceKVSnapshot = try await prefillKVCacheVoice(
|
||||
state: emptyState, voiceData: voiceData, model: condModel,
|
||||
layerKeys: condLayerKeys
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"Session voice prefill at position \(Int(voiceKVSnapshot.positions[0][0].floatValue))"
|
||||
)
|
||||
|
||||
let session = PocketTtsSession(
|
||||
voiceKVSnapshot: voiceKVSnapshot,
|
||||
mimiState: mimiState,
|
||||
constants: constants,
|
||||
condModel: condModel,
|
||||
stepModel: stepModel,
|
||||
flowModel: flowModel,
|
||||
mimiModel: mimiModel,
|
||||
condLayerKeys: condLayerKeys,
|
||||
flowlmLayerKeys: flowlmLayerKeys,
|
||||
mimiKeys: mimiKeys,
|
||||
bosEmb: bosEmb,
|
||||
temperature: temperature,
|
||||
seed: seedValue
|
||||
)
|
||||
let session: PocketTtsSession
|
||||
if let cs = chunkSize {
|
||||
let cm = try await store.condStepChunkModel()
|
||||
let ck = try await store.condStepChunkLayerKeys()
|
||||
session = PocketTtsSession(
|
||||
voiceKVSnapshot: voiceKVSnapshot,
|
||||
mimiState: mimiState,
|
||||
constants: constants,
|
||||
condModel: condModel,
|
||||
stepModel: stepModel,
|
||||
flowModel: flowModel,
|
||||
mimiModel: mimiModel,
|
||||
condLayerKeys: condLayerKeys,
|
||||
flowlmLayerKeys: flowlmLayerKeys,
|
||||
mimiKeys: mimiKeys,
|
||||
bosEmb: bosEmb,
|
||||
temperature: temperature,
|
||||
seed: seedValue,
|
||||
chunkCondModel: cm,
|
||||
chunkCondLayerKeys: ck,
|
||||
chunkSize: cs
|
||||
)
|
||||
} else {
|
||||
session = PocketTtsSession(
|
||||
voiceKVSnapshot: voiceKVSnapshot,
|
||||
mimiState: mimiState,
|
||||
constants: constants,
|
||||
condModel: condModel,
|
||||
stepModel: stepModel,
|
||||
flowModel: flowModel,
|
||||
mimiModel: mimiModel,
|
||||
condLayerKeys: condLayerKeys,
|
||||
flowlmLayerKeys: flowlmLayerKeys,
|
||||
mimiKeys: mimiKeys,
|
||||
bosEmb: bosEmb,
|
||||
temperature: temperature,
|
||||
seed: seedValue
|
||||
)
|
||||
}
|
||||
await session.start()
|
||||
return session
|
||||
}
|
||||
@@ -333,6 +405,9 @@ public struct PocketTtsSynthesizer {
|
||||
var rng: SeededRNG
|
||||
let chunkCount: Int
|
||||
let temperature: Float
|
||||
let chunkCondModel: MLModel?
|
||||
let chunkCondLayerKeys: PocketTtsLayerKeys?
|
||||
let chunkSize: Int?
|
||||
|
||||
init(
|
||||
constants: PocketTtsConstantsBundle,
|
||||
@@ -349,7 +424,10 @@ public struct PocketTtsSynthesizer {
|
||||
bosEmb: MLMultiArray,
|
||||
seedValue: UInt64,
|
||||
chunkCount: Int,
|
||||
temperature: Float
|
||||
temperature: Float,
|
||||
chunkCondModel: MLModel? = nil,
|
||||
chunkCondLayerKeys: PocketTtsLayerKeys? = nil,
|
||||
chunkSize: Int? = nil
|
||||
) {
|
||||
self.constants = constants
|
||||
self.voiceData = voiceData
|
||||
@@ -366,6 +444,9 @@ public struct PocketTtsSynthesizer {
|
||||
self.rng = SeededRNG(seed: seedValue)
|
||||
self.chunkCount = chunkCount
|
||||
self.temperature = temperature
|
||||
self.chunkCondModel = chunkCondModel
|
||||
self.chunkCondLayerKeys = chunkCondLayerKeys
|
||||
self.chunkSize = chunkSize
|
||||
}
|
||||
|
||||
/// Flow decode using actor-isolated RNG state.
|
||||
@@ -435,12 +516,25 @@ public struct PocketTtsSynthesizer {
|
||||
let textEmbeddings = PocketTtsSynthesizer.embedTokens(
|
||||
tokenIds, constants: constants)
|
||||
|
||||
var kvState = try await PocketTtsSynthesizer.prefillKVCache(
|
||||
voiceData: voiceData,
|
||||
textEmbeddings: textEmbeddings,
|
||||
model: condModel,
|
||||
layerKeys: condLayerKeys
|
||||
)
|
||||
var kvState: KVCacheState
|
||||
if let cm = chunkCondModel, let ck = chunkCondLayerKeys, let cs = chunkSize {
|
||||
kvState = try await PocketTtsSynthesizer.prefillKVCache(
|
||||
voiceData: voiceData,
|
||||
textEmbeddings: textEmbeddings,
|
||||
model: condModel,
|
||||
layerKeys: condLayerKeys,
|
||||
chunkModel: cm,
|
||||
chunkLayerKeys: ck,
|
||||
chunkSize: cs
|
||||
)
|
||||
} else {
|
||||
kvState = try await PocketTtsSynthesizer.prefillKVCache(
|
||||
voiceData: voiceData,
|
||||
textEmbeddings: textEmbeddings,
|
||||
model: condModel,
|
||||
layerKeys: condLayerKeys
|
||||
)
|
||||
}
|
||||
|
||||
let maxGenLen = PocketTtsSynthesizer.estimateMaxFrames(text: chunkText)
|
||||
var eosStep: Int?
|
||||
|
||||
@@ -135,3 +135,34 @@ public enum PocketTtsPrecision: Sendable, Hashable {
|
||||
case fp16
|
||||
case int8
|
||||
}
|
||||
|
||||
/// Dispatch strategy for `cond_step` while the PocketTTS pipeline prefills
/// its KV cache from voice tokens and text embeddings.
///
/// Prefill calls `cond_step` once per prompt token. With the default English
/// voice (~125 frames) and a typical sentence (~17–80 text tokens) that adds
/// up to 140–200 CoreML calls before the first audio frame is produced, and
/// the fixed per-call overhead (Python/Swift interop, MIL dispatch, kernel
/// launch) dominates wall time at small T.
///
/// Modes:
///
/// - `legacy`: one chunk-1 model call per token. Matches the reference Python
///   implementation and the upstream FluidAudio behaviour from before chunked
///   prefill existed. Always works with the published `cond_step.mlmodelc`.
///
/// - `chunked(chunk: N)`: hybrid dispatch made of `T / N` calls to a chunk-N
///   variant of `cond_step` followed by `T % N` chunk-1 tail calls. The KV
///   cache schema is the same as in legacy mode (the chunk-N graph reuses the
///   cond_step output names; only the input sequence dim differs).
///   Empirically 3–14× faster prefill on Apple Silicon depending on T.
///
/// Only `chunked(chunk: 16)` is supported initially. The chunk-16 model file
/// (`cond_step_chunk16.mlmodelc`) is **not yet published on HuggingFace**, so
/// callers must place it manually under
/// `<cacheDir>/v2/<lang>/cond_step_chunk16.mlmodelc`. Loading the model store
/// in this mode while the file is absent throws
/// `PocketTTSError.modelNotFound` carrying the expected path.
public enum PocketTtsCondStepMode: Sendable, Equatable {
    /// Per-token dispatch through the chunk-1 `cond_step` model.
    case legacy
    /// Hybrid dispatch: chunk-`chunk` calls plus a chunk-1 tail.
    case chunked(chunk: Int)
}
|
||||
|
||||
@@ -40,12 +40,14 @@ public actor PocketTtsManager {
|
||||
defaultVoice: String = PocketTtsConstants.defaultVoice,
|
||||
language: PocketTtsLanguage = .english,
|
||||
directory: URL? = nil,
|
||||
precision: PocketTtsPrecision = .fp16
|
||||
precision: PocketTtsPrecision = .fp16,
|
||||
condStepMode: PocketTtsCondStepMode = .legacy
|
||||
) {
|
||||
self.modelStore = PocketTtsModelStore(
|
||||
language: language,
|
||||
directory: directory,
|
||||
precision: precision
|
||||
precision: precision,
|
||||
condStepMode: condStepMode
|
||||
)
|
||||
self.defaultVoice = defaultVoice
|
||||
self.language = language
|
||||
@@ -334,6 +336,57 @@ public actor PocketTtsManager {
|
||||
isInitialized = false
|
||||
}
|
||||
|
||||
// MARK: - Benchmarks
|
||||
|
||||
/// Measurements collected for one `cond_step` prefill benchmark configuration.
public struct CondStepPrefillBenchmarkResult: Sendable {
    /// How many text tokens went through `cond_step`; voice prefill is not counted here.
    public let textTokens: Int
    /// How many voice tokens were prefilled (only > 0 when the voice has no pre-baked snapshot).
    public let voiceTokens: Int
    /// Wall-clock prefill duration of each timed iteration, in seconds. Warmup runs are excluded.
    public let durations: [TimeInterval]
    /// `true` when the chunked-N + chunk-1 hybrid dispatch produced these timings.
    public let usingChunked: Bool
    /// The chunk size when `usingChunked == true`, otherwise `nil`.
    public let chunkSize: Int?
}
|
||||
|
||||
/// Benchmark `cond_step` prefill for a single text on the given voice.
///
/// Times only the `prefillKVCache` call — no flowlm_step, flow_decoder,
/// or mimi_decoder invocations. The first `warmup` iterations run but
/// are not recorded; the next `iters` iterations are timed and returned
/// in `durations`.
///
/// Pipeline path:
/// - `.legacy` store mode: per-token chunk-1 dispatch (matches upstream behaviour).
/// - `.chunked(chunk: N)` store mode: hybrid (`T / N` chunk-N calls + `T % N`
///   chunk-1 tail calls).
///
/// Use the same `text` and `voice` across two managers (one in each
/// mode) to A/B compare prefill latency end-to-end.
///
/// - Parameters:
///   - text: Text whose tokens are prefilled.
///   - voice: Voice id; falls back to the manager's `defaultVoice` when `nil`.
///   - warmup: Number of untimed iterations run first. Must be `>= 0`.
///   - iters: Number of timed iterations. Must be `> 0`.
/// - Returns: The benchmark result produced by the synthesizer for this
///   manager's model store.
/// - Throws: `PocketTTSError.modelNotFound` when the manager has not been
///   initialized, `PocketTTSError.processingFailed` when `iters` or `warmup`
///   is out of range, and any error raised by the underlying prefill.
public func benchmarkCondStepPrefill(
    text: String,
    voice: String? = nil,
    warmup: Int = 3,
    iters: Int = 30
) async throws -> CondStepPrefillBenchmarkResult {
    guard isInitialized else {
        throw PocketTTSError.modelNotFound("PocketTTS model not initialized")
    }
    guard iters > 0 else {
        throw PocketTTSError.processingFailed("benchmarkCondStepPrefill: iters must be > 0")
    }
    // A negative warmup count has no meaning; reject it up front instead of
    // relying on how the benchmark loop downstream happens to treat it.
    guard warmup >= 0 else {
        throw PocketTTSError.processingFailed("benchmarkCondStepPrefill: warmup must be >= 0")
    }

    let selectedVoice = voice ?? defaultVoice
    return try await PocketTtsSynthesizer.withModelStore(modelStore) {
        try await PocketTtsSynthesizer.benchmarkCondStepPrefill(
            text: text, voice: selectedVoice, warmup: warmup, iters: iters
        )
    }
}
|
||||
|
||||
// MARK: - Voice Cloning
|
||||
|
||||
/// Check if voice cloning is available (mimi_encoder model present).
|
||||
|
||||
@@ -0,0 +1,239 @@
|
||||
import FluidAudio
|
||||
import Foundation
|
||||
|
||||
/// Benchmark `cond_step` prefill latency for two PocketTTS dispatch modes
|
||||
/// (legacy chunk-1 vs hybrid chunk-N + chunk-1) on the same text and voice.
|
||||
///
|
||||
/// Run via:
|
||||
/// ```
|
||||
/// swift run -c release fluidaudio pocket-tts-cond-bench \
|
||||
/// --text "Hello world." --voice alba --iters 30 --warmup 3
|
||||
/// ```
|
||||
///
|
||||
/// Requires the chunk-N model file (`cond_step_chunk16.mlmodelc`) to be
|
||||
/// placed manually under `<pocketTtsCacheDir>/v2/<lang>/` — the file is
|
||||
/// not yet published on HuggingFace. See
|
||||
/// `PocketTtsCondStepMode.chunked` for placement details.
|
||||
public enum PocketTtsCondBenchCommand {
|
||||
|
||||
private static let logger = AppLogger(category: "PocketTtsCondBench")
|
||||
|
||||
/// Entry point for the `pocket-tts-cond-bench` subcommand.
///
/// Parses `arguments`, resolves the language pack, and runs the legacy vs
/// chunked prefill benchmark. Problems are logged rather than thrown:
/// unknown arguments and missing or invalid option values produce a warning
/// and leave the default in place; an unknown language or a benchmark
/// failure is logged as an error and the command returns.
///
/// - Parameter arguments: Command-line arguments following the subcommand name.
public static func run(arguments: [String]) async {
    var text = "Hello world, this is a test of the pocket TTS system."
    var voice = "alba"
    var languageRaw = "english"
    var iters = 30
    var warmup = 3
    var chunk = 16
    var alsoSynth = false

    var i = 0
    while i < arguments.count {
        let arg = arguments[i]
        switch arg {
        case "--text":
            if let value = stringValue(after: &i, in: arguments) { text = value }
        case "--voice":
            if let value = stringValue(after: &i, in: arguments) { voice = value }
        case "--language":
            if let value = stringValue(after: &i, in: arguments) { languageRaw = value }
        case "--iters":
            if let n = integerValue(after: &i, in: arguments, atLeast: 1) { iters = n }
        case "--warmup":
            if let n = integerValue(after: &i, in: arguments, atLeast: 0) { warmup = n }
        case "--chunk":
            if let n = integerValue(after: &i, in: arguments, atLeast: 1) { chunk = n }
        case "--also-synth":
            alsoSynth = true
        case "--help", "-h":
            printUsage()
            return
        default:
            logger.warning("Unknown argument: \(arg)")
        }
        i += 1
    }

    guard let language = PocketTtsLanguage(rawValue: languageRaw) else {
        let supported = PocketTtsLanguage.allCases.map { $0.rawValue }.joined(separator: ", ")
        logger.error("Unknown language '\(languageRaw)'. Supported: \(supported)")
        return
    }

    do {
        try await runBenchmark(
            text: text,
            voice: voice,
            language: language,
            iters: iters,
            warmup: warmup,
            chunk: chunk,
            alsoSynth: alsoSynth
        )
    } catch {
        logger.error("pocket-tts-cond-bench failed: \(error)")
    }
}

/// Returns the token following the option at `index`, advancing `index` past it.
///
/// Logs a warning and returns `nil` (leaving `index` untouched) when the
/// option is the last argument.
private static func stringValue(after index: inout Int, in arguments: [String]) -> String? {
    guard index + 1 < arguments.count else {
        logger.warning("Missing value for \(arguments[index])")
        return nil
    }
    index += 1
    return arguments[index]
}

/// Returns the integer following the option at `index` when it is `>= minimum`.
///
/// The value token is consumed whenever it is recognisably a value, even an
/// invalid one, so that it is not reported a second time as an unknown
/// argument. A non-integer token starting with `-` is taken to be the next
/// option and is left in place. Every rejected value is logged.
private static func integerValue(
    after index: inout Int, in arguments: [String], atLeast minimum: Int
) -> Int? {
    let flag = arguments[index]
    guard index + 1 < arguments.count else {
        logger.warning("Missing value for \(flag)")
        return nil
    }
    let raw = arguments[index + 1]
    guard let parsed = Int(raw) else {
        if raw.hasPrefix("-") {
            logger.warning("Missing value for \(flag)")
        } else {
            index += 1
            logger.warning("Ignoring non-integer value '\(raw)' for \(flag)")
        }
        return nil
    }
    index += 1
    guard parsed >= minimum else {
        logger.warning("Ignoring value \(parsed) for \(flag): must be >= \(minimum)")
        return nil
    }
    return parsed
}
|
||||
|
||||
/// Writes the option summary for `pocket-tts-cond-bench` to standard output.
private static func printUsage() {
    print(
        """
        Usage: fluidaudio pocket-tts-cond-bench [options]

        Options:
        --text <string> Text to prefill (default: short test sentence)
        --voice <name> Voice id (default: alba)
        --language <id> Language pack id (default: english)
        --iters <n> Timed iterations per config (default: 30)
        --warmup <n> Warmup iterations per config, not recorded (default: 3)
        --chunk <n> Chunk size for the hybrid path (default: 16)
        --also-synth Also synthesize one WAV per config to /tmp
        """
    )
}
|
||||
|
||||
/// Runs the legacy and the chunked prefill benchmark back to back on the
/// same text and voice, prints a comparison, and optionally synthesizes one
/// WAV per configuration.
///
/// A failure to initialize the chunked manager is logged with a hint about
/// the manually placed chunk model and then rethrown.
private static func runBenchmark(
    text: String,
    voice: String,
    language: PocketTtsLanguage,
    iters: Int,
    warmup: Int,
    chunk: Int,
    alsoSynth: Bool
) async throws {
    // Seconds elapsed since `start`, rendered with two decimals.
    func elapsed(since start: Date) -> String {
        String(format: "%.2f", Date().timeIntervalSince(start))
    }

    let banner = [
        "PocketTTS cond_step prefill benchmark",
        " language: \(language.rawValue)",
        " voice: \(voice)",
        " text: \"\(text)\"",
        " warmup: \(warmup) iters",
        " iters: \(iters)",
        " chunk: \(chunk)",
        "",
    ]
    for line in banner {
        print(line)
    }

    // --- Legacy (chunk-1) ---
    let legacyManager = PocketTtsManager(
        defaultVoice: voice, language: language, condStepMode: .legacy
    )
    print("Initializing legacy manager (chunk-1 dispatch)...")
    let legacyInitStart = Date()
    try await legacyManager.initialize()
    print(" done in \(elapsed(since: legacyInitStart))s")

    print("Running legacy benchmark...")
    let legacy = try await legacyManager.benchmarkCondStepPrefill(
        text: text, voice: voice, warmup: warmup, iters: iters
    )

    // --- Hybrid (chunk-N + chunk-1) ---
    let chunkedManager = PocketTtsManager(
        defaultVoice: voice, language: language, condStepMode: .chunked(chunk: chunk)
    )
    print("Initializing chunked manager (chunk-\(chunk) + chunk-1 hybrid)...")
    let chunkedInitStart = Date()
    do {
        try await chunkedManager.initialize()
    } catch {
        logger.error(
            "chunked manager init failed (is cond_step_chunk\(chunk).mlmodelc placed under the language root?): \(error)"
        )
        throw error
    }
    print(" done in \(elapsed(since: chunkedInitStart))s")

    print("Running chunked benchmark...")
    let chunked = try await chunkedManager.benchmarkCondStepPrefill(
        text: text, voice: voice, warmup: warmup, iters: iters
    )

    printSummary(legacy: legacy, chunked: chunked)

    guard alsoSynth else { return }

    print("")
    print("Synthesizing one WAV per config...")
    let legacyURL = URL(fileURLWithPath: "/tmp/pocket-tts-cond-bench-legacy.wav")
    let chunkedURL = URL(fileURLWithPath: "/tmp/pocket-tts-cond-bench-chunked.wav")
    try await legacyManager.synthesizeToFile(text: text, outputURL: legacyURL, voice: voice)
    try await chunkedManager.synthesizeToFile(text: text, outputURL: chunkedURL, voice: voice)
    for written in [legacyURL, chunkedURL] {
        print(" wrote \(written.path)")
    }
}
|
||||
|
||||
/// Prints a two-row comparison table (median / min / stdev / iteration
/// count) for the legacy and chunked results, followed by the median-based
/// speedup of chunked over legacy.
private static func printSummary(
    legacy: PocketTtsManager.CondStepPrefillBenchmarkResult,
    chunked: PocketTtsManager.CondStepPrefillBenchmarkResult
) {
    // Both summaries are in seconds; `fmt` converts to milliseconds for display.
    let legacyStats = stats(legacy.durations)
    let chunkedStats = stats(chunked.durations)
    // Clamp the denominator so an all-zero chunked run cannot divide by zero.
    let speedup = legacyStats.median / max(chunkedStats.median, 1e-9)

    let chunkLabel: String
    if let size = chunked.chunkSize {
        chunkLabel = "chunk-\(size)"
    } else {
        chunkLabel = "chunked"
    }
    let speedupStr = String(format: "%.2f", speedup)

    let report = [
        "",
        "SUMMARY — text tokens=\(legacy.textTokens), voice tokens=\(legacy.voiceTokens) (legacy) / \(chunked.voiceTokens) (chunked)",
        " config | median | min | stdev | iters",
        " ---------+-----------+-----------+-----------+-------",
        " legacy | \(fmt(legacyStats.median)) | \(fmt(legacyStats.min)) | \(fmt(legacyStats.stdev)) | \(legacy.durations.count)",
        " \(pad(chunkLabel, 8)) | \(fmt(chunkedStats.median)) | \(fmt(chunkedStats.min)) | \(fmt(chunkedStats.stdev)) | \(chunked.durations.count)",
        "",
        " speedup (legacy median / chunked median): \(speedupStr)x",
    ]
    for line in report {
        print(line)
    }
}
|
||||
|
||||
/// Formats a duration given in seconds as milliseconds with two decimals.
///
/// The number is right-aligned in a 6-character field and followed by
/// " ms", so typical values render 9 characters wide (e.g. " 12.34 ms").
private static func fmt(_ seconds: TimeInterval) -> String {
    String(format: "%6.2f ms", seconds * 1000.0)
}
|
||||
|
||||
/// Right-pads `s` with spaces up to `width` characters.
///
/// Strings that are already `width` characters or longer are returned
/// unchanged (never truncated).
private static func pad(_ s: String, _ width: Int) -> String {
    let missing = width - s.count
    guard missing > 0 else { return s }
    return s + String(repeating: " ", count: missing)
}
|
||||
|
||||
/// Summarises a list of durations as median, minimum and standard deviation.
///
/// The standard deviation is the population form (squared deviations are
/// divided by the number of samples, not by one less). An empty input
/// yields `(0, 0, 0)`.
private static func stats(
    _ values: [TimeInterval]
) -> (
    median: TimeInterval, min: TimeInterval, stdev: TimeInterval
) {
    if values.isEmpty { return (0, 0, 0) }

    let ordered = values.sorted()
    let sampleCount = ordered.count
    let upperMiddle = sampleCount / 2
    // Even-sized samples average the two middle elements.
    let median =
        sampleCount % 2 == 0
        ? (ordered[upperMiddle - 1] + ordered[upperMiddle]) / 2
        : ordered[upperMiddle]

    let mean = values.reduce(0, +) / Double(sampleCount)
    var squaredDeviation: TimeInterval = 0
    for sample in values {
        squaredDeviation += (sample - mean) * (sample - mean)
    }
    let stdev = (squaredDeviation / Double(sampleCount)).squareRoot()

    return (median, ordered[0], stdev)
}
|
||||
}
|
||||
@@ -52,6 +52,8 @@ struct FluidAudioCLI {
|
||||
await TTSAsrVerifyCommand.run(arguments: Array(arguments.dropFirst(2)))
|
||||
case "tts-benchmark":
|
||||
await TtsBenchmarkCommand.run(arguments: Array(arguments.dropFirst(2)))
|
||||
case "pocket-tts-cond-bench":
|
||||
await PocketTtsCondBenchCommand.run(arguments: Array(arguments.dropFirst(2)))
|
||||
case "minimax-corpus":
|
||||
await MinimaxCorpusCommand.run(arguments: Array(arguments.dropFirst(2)))
|
||||
case "diarization-benchmark":
|
||||
@@ -121,6 +123,7 @@ struct FluidAudioCLI {
|
||||
magpie Magpie TTS Multilingual 357M (experimental, ~0.04 RTFx — slow, needs perf work)
|
||||
tts-asr-verify Batch TTS→ASR roundtrip WER verification
|
||||
tts-benchmark Quantitative TTS benchmark (latency, quality, compute-unit sweep)
|
||||
pocket-tts-cond-bench Benchmark PocketTTS cond_step prefill: legacy chunk-1 vs hybrid chunk-16
|
||||
minimax-corpus Fetch MiniMax TTS Multilingual Test Set into Benchmarks/tts/corpus/minimax
|
||||
parakeet-eou Run Parakeet EOU Streaming ASR on a single file
|
||||
ctc-earnings-benchmark Run CTC keyword spotting benchmark on Earnings22
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
@preconcurrency import CoreML
|
||||
import Foundation
|
||||
import XCTest
|
||||
|
||||
@testable import FluidAudio
|
||||
|
||||
/// Unit tests for the chunked `cond_step` dispatch mode that sits next to
/// the existing chunk-1 (legacy) pipeline.
///
/// Everything here is pure logic: the tests check the model store's mode
/// plumbing and accessor behaviour and need no downloaded CoreML artifacts.
/// End-to-end parity against the legacy path is covered by the
/// `pocket-tts-cond-bench` CLI subcommand (which needs the chunk-16
/// `mlmodelc` file placed manually under the language root).
final class PocketTtsCondStepHybridTests: XCTestCase {

    // MARK: - PocketTtsCondStepMode

    func testCondStepModeEquatable() {
        let legacy = PocketTtsCondStepMode.legacy
        let chunk16 = PocketTtsCondStepMode.chunked(chunk: 16)
        let chunk32 = PocketTtsCondStepMode.chunked(chunk: 32)

        XCTAssertEqual(legacy, .legacy)
        XCTAssertEqual(chunk16, .chunked(chunk: 16))
        XCTAssertNotEqual(legacy, chunk16)
        XCTAssertNotEqual(chunk16, chunk32)
    }

    // MARK: - PocketTtsModelStore.condStepMode plumbing

    func testStoreDefaultModeIsLegacy() async {
        let store = PocketTtsModelStore(language: .english)

        let mode = await store.condStepMode
        let chunkSize = await store.condStepChunkSize()

        XCTAssertEqual(mode, .legacy)
        XCTAssertNil(chunkSize, "legacy mode must not expose a chunk size")
    }

    func testStoreChunkedModeExposesChunkSize() async {
        let store = PocketTtsModelStore(
            language: .english, condStepMode: .chunked(chunk: 16)
        )

        let mode = await store.condStepMode
        let chunkSize = await store.condStepChunkSize()

        XCTAssertEqual(mode, .chunked(chunk: 16))
        XCTAssertEqual(chunkSize, 16)
    }

    // MARK: - Accessor error semantics

    func testCondStepChunkModelThrowsWhenNotLoaded() async {
        // A legacy store never loads the chunk model, so the accessor has
        // to throw a clean `modelNotFound` rather than hand back a stale value.
        let store = PocketTtsModelStore(language: .english)
        await assertChunkModelUnavailable(
            in: store,
            otherwise: "expected condStepChunkModel() to throw in legacy mode"
        )
    }

    func testCondStepChunkModelThrowsWhenChunkedButUnloaded() async {
        // Chunked mode without `loadIfNeeded()`: the accessor has to throw.
        // The loader is never run because it would need network access and
        // the unpublished chunk-16 file.
        let store = PocketTtsModelStore(
            language: .english, condStepMode: .chunked(chunk: 16)
        )
        await assertChunkModelUnavailable(
            in: store,
            otherwise: "expected condStepChunkModel() to throw before load"
        )
    }

    // MARK: - PocketTtsManager init plumbing

    func testManagerAcceptsCondStepMode() async {
        // Smoke test: the manager init compiles and takes the mode. We do
        // not call initialize() because that would require network access.
        let manager = PocketTtsManager(
            defaultVoice: "alba",
            language: .english,
            condStepMode: .chunked(chunk: 16)
        )
        let isAvailable = await manager.isAvailable
        XCTAssertFalse(isAvailable, "manager should not be available before initialize()")
    }

    // MARK: - ModelNames

    func testCondStepChunk16FilenameMatchesConvention() {
        // The chunk-16 mlmodelc sits next to cond_step.mlmodelc under the
        // v2/<lang>/ language root; a filename mismatch would only show up
        // as a confusing modelNotFound at load time.
        XCTAssertEqual(ModelNames.PocketTTS.condStepChunk16, "cond_step_chunk16")
        XCTAssertEqual(ModelNames.PocketTTS.condStepChunk16File, "cond_step_chunk16.mlmodelc")
    }

    // MARK: - Helpers

    /// Asserts that `condStepChunkModel()` on `store` throws
    /// `PocketTTSError.modelNotFound`; fails with `message` when it returns.
    private func assertChunkModelUnavailable(
        in store: PocketTtsModelStore,
        otherwise message: String,
        file: StaticString = #filePath,
        line: UInt = #line
    ) async {
        do {
            _ = try await store.condStepChunkModel()
            XCTFail(message, file: file, line: line)
        } catch let error as PocketTTSError {
            guard case .modelNotFound = error else {
                XCTFail("expected .modelNotFound, got \(error)", file: file, line: line)
                return
            }
        } catch {
            XCTFail(
                "expected PocketTTSError, got \(type(of: error)): \(error)", file: file, line: line)
        }
    }
}
|
||||
Reference in New Issue
Block a user