mirror of
https://github.com/FluidInference/FluidAudio.git
synced 2026-05-12 20:20:36 +00:00
chore: consolidate Python scripts into Scripts/ (#344)
## Summary

- Move `Benchmarks/nemo` to `Scripts/nemo_ami_benchmark`
- Move `Tools/voice_cloning` to `Scripts/voice_cloning`
- Remove now-empty `Benchmarks/` and `Tools/` top-level directories

Consolidates standalone Python utilities into a single `Scripts/` directory to reduce top-level clutter.

## Test plan

- [x] Verify files moved correctly (no content changes)
@@ -0,0 +1,171 @@
# NeMo Sortformer AMI Benchmark

This directory contains tools for comparing the Swift/CoreML Sortformer implementation against NVIDIA's original NeMo Sortformer model.

## Overview

The `nemo_ami_benchmark.py` script runs NVIDIA's Sortformer model on the AMI SDM dataset to provide a baseline for the Swift/CoreML implementation.

## Requirements

### Python Environment

```bash
# Create a virtual environment with Python 3.10+
python3.10 -m venv .venv
source .venv/bin/activate

# Install dependencies (quote the extras so zsh does not glob the brackets)
pip install torch torchaudio torchcodec
pip install "nemo_toolkit[asr]" pyannote.metrics
```

### HuggingFace Authentication

The NVIDIA Sortformer model is gated and requires HuggingFace authentication:

1. Create an account at [huggingface.co](https://huggingface.co)
2. Accept the model license at [nvidia/diar_streaming_sortformer_4spk-v2.1](https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2.1)
3. Create an access token at [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens)
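
The usage examples in this README pass the token through the `HF_TOKEN` environment variable. A minimal sketch of a pre-flight check follows; `require_hf_token` is a hypothetical helper for illustration, not part of the benchmark script.

```python
import os
from typing import Mapping, Optional


def require_hf_token(env: Optional[Mapping[str, str]] = None) -> str:
    """Return the token stored in HF_TOKEN, or raise with a hint (hypothetical helper)."""
    env = os.environ if env is None else env
    token = env.get("HF_TOKEN", "").strip()
    if not token:
        raise RuntimeError(
            "HF_TOKEN is not set; create a token at "
            "https://huggingface.co/settings/tokens and export it first."
        )
    return token
```

Calling this before loading the model turns a late download failure into an immediate, readable error.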

### AMI Dataset

Download the AMI SDM test set audio files and RTTM ground truth:

```bash
# Audio files should be in:
~/FluidAudioDatasets/ami_official/sdm/

# RTTM files should be in:
~/FluidAudioDatasets/ami_official/rttm/
```

RTTM files can be downloaded from [pyannote AMI diarization setup](https://github.com/pyannote/AMI-diarization-setup).

## Usage

### Basic Usage

```bash
# Run on single file
HF_TOKEN="your_token" python nemo_ami_benchmark.py --single-file ES2004a --device cpu

# Run on all 16 AMI test meetings
HF_TOKEN="your_token" python nemo_ami_benchmark.py --device cpu

# Save results to JSON
HF_TOKEN="your_token" python nemo_ami_benchmark.py --output results.json
```

### Command Line Options

| Option | Description | Default |
|--------|-------------|---------|
| `--audio-dir` | Path to AMI audio files | `~/FluidAudioDatasets/ami_official/sdm` |
| `--rttm-dir` | Path to RTTM ground truth files | `~/FluidAudioDatasets/ami_official/rttm` |
| `--output`, `-o` | Output JSON file path | None |
| `--single-file` | Run on single meeting (e.g., ES2004a) | All 16 meetings |
| `--device` | Device to use (cpu, cuda, mps) | mps if available, else cpu |
| `--batch` | Use batch mode instead of streaming | False |
| `--model-path` | Path to local .nemo model file | Downloads from HuggingFace |

## Configuration Settings

### Model Configuration

| Parameter | Value | Description |
|-----------|-------|-------------|
| Model | `nvidia/diar_sortformer_4spk-v1` | NVIDIA Sortformer 4-speaker model |
| Sample Rate | 16000 Hz | Audio sample rate |
| Frame Duration | 80 ms | Duration per output frame |
| Num Speakers | 4 | Maximum number of speakers |

### High-Latency Streaming Config

These settings match the Swift `SortformerConfig.nvidiaHighLatency`:

| Parameter | Value | Description |
|-----------|-------|-------------|
| Chunk Length | 340 frames | Core chunk length in encoder frames |
| Left Context | 1 frame | Left context in encoder frames |
| Right Context | 40 frames | Right context in encoder frames |
| Subsampling Factor | 8 | Mel frames per encoder frame |
| **Total Context** | **30.4 seconds** | (340 + 40) * 8 * 10 ms (chunk plus right context) |
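
The 30.4-second figure can be reproduced from the streaming parameters used by the benchmark script in this commit (`chunk_len = 340`, `chunk_right_context = 40`). The sketch below assumes a 10 ms mel hop and a subsampling factor of 8, as stated in the table; the constant and function names are illustrative.

```python
MEL_HOP_SECONDS = 0.010   # assumed 10 ms mel hop
SUBSAMPLING_FACTOR = 8    # mel frames per encoder frame


def frames_to_seconds(encoder_frames: int) -> float:
    """Convert a number of encoder frames to seconds (80 ms per frame)."""
    return encoder_frames * SUBSAMPLING_FACTOR * MEL_HOP_SECONDS


chunk_len = 340       # core chunk, in encoder frames
right_context = 40    # look-ahead, in encoder frames

latency = frames_to_seconds(chunk_len + right_context)
print(f"{latency:.1f} s")  # 380 frames at 80 ms each
```

The single left-context frame adds another 80 ms of audio to each chunk but does not add look-ahead delay.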

### Post-Processing Config

| Parameter | Value | Description |
|-----------|-------|-------------|
| Onset Threshold | 0.5 | Threshold for speaker activity detection |
| Offset Threshold | 0.5 | Threshold for speaker activity end |
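
One common way to apply onset and offset thresholds is hysteresis: a speaker becomes active when the probability rises above the onset value and inactive when it drops below the offset value. The sketch below illustrates that idea under this assumption; it is not taken from the benchmark script, and `probs_to_segments` is a hypothetical name.

```python
def probs_to_segments(probs, onset=0.5, offset=0.5, frame_shift=0.08):
    """Turn one speaker's per-frame probabilities into (start_s, end_s) segments."""
    segments = []
    active = False
    start = 0
    for i, p in enumerate(probs):
        if not active and p > onset:
            # Activity begins once the probability exceeds the onset threshold.
            active, start = True, i
        elif active and p < offset:
            # Activity ends once the probability falls below the offset threshold.
            segments.append((start * frame_shift, i * frame_shift))
            active = False
    if active:
        # Close a segment that is still open at the end of the audio.
        segments.append((start * frame_shift, len(probs) * frame_shift))
    return segments


print(probs_to_segments([0.1, 0.7, 0.9, 0.6, 0.2, 0.1, 0.8, 0.8]))
```

With both thresholds at 0.5, as in the table, this reduces to plain thresholding; setting the offset lower than the onset makes segments less likely to fragment on brief dips.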

## AMI Test Meetings

The benchmark runs on 16 AMI SDM test meetings:

| Series | Meetings |
|--------|----------|
| EN2002 | EN2002a, EN2002b, EN2002c, EN2002d |
| ES2004 | ES2004a, ES2004b, ES2004c, ES2004d |
| IS1009 | IS1009a, IS1009b, IS1009c, IS1009d |
| TS3003 | TS3003a, TS3003b, TS3003c, TS3003d |

## Output Metrics

| Metric | Description |
|--------|-------------|
| DER | Diarization Error Rate (Miss + FA + SE) |
| Miss % | Missed speech (false negatives) |
| FA % | False alarm (false positives) |
| SE % | Speaker error (wrong speaker assigned) |
| Speakers | Detected / Ground truth speaker count |
| RTFx | Real-time factor (audio duration / processing time) |
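
The two headline metrics are simple ratios. The sketch below computes DER from frame counts, normalized by the number of reference-speech frames (the normalization used by `calculate_der` in the script included in this commit), and RTFx from durations. The numbers in the example are made up for illustration.

```python
def der_percent(miss, false_alarm, speaker_error, reference_speech):
    """DER = (Miss + FA + SE) / reference speech, as a percentage."""
    return 100.0 * (miss + false_alarm + speaker_error) / reference_speech


def rtfx(audio_seconds, processing_seconds):
    """Real-time factor: values above 1 mean faster than real time."""
    return audio_seconds / processing_seconds


# Hypothetical counts: 1000 reference-speech frames, 300 missed,
# 10 false alarms, 20 speaker errors.
print(f"DER:  {der_percent(300, 10, 20, 1000):.1f}%")
# Hypothetical timing: 600 s of audio processed in 3000 s.
print(f"RTFx: {rtfx(600, 3000):.1f}x")
```

Because false alarms are counted against reference speech only, DER can exceed 100% on recordings with little speech and many false alarms.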

## Example Output

```
================================================================================
NEMO SORTFORMER AMI BENCHMARK
================================================================================
Device: cpu
Mode: Streaming (30.4s chunks)
Audio dir: /Users/user/FluidAudioDatasets/ami_official/sdm
RTTM dir: /Users/user/FluidAudioDatasets/ami_official/rttm
Meetings: 1

Loading Sortformer model...
Model loaded in 2.35s

----------------------------------------------------------------------
Meeting           DER %   Miss %     FA %     SE %   Speakers     RTFx
----------------------------------------------------------------------
ES2004a           34.0%    30.7%     0.9%     2.3%       4/ 4     0.2x
----------------------------------------------------------------------
AVERAGE           34.0%    30.7%     0.9%     2.3%          -     0.2x
======================================================================
```

## Comparison with Swift/CoreML

| Metric | NeMo Python (CPU) | Swift/CoreML (ANE) |
|--------|-------------------|---------------------|
| DER | 34.0% | 32.3% |
| Miss Rate | 30.7% | ~29% |
| False Alarm | 0.9% | ~1% |
| Speaker Error | 2.3% | ~2% |
| RTFx | 0.2x | ~5x |

The Swift/CoreML implementation achieves comparable accuracy (32.3% vs. 34.0% DER) while running roughly 25 times faster (about 5x vs. 0.2x real time), thanks to Apple Neural Engine acceleration.

## Notes

- CPU inference is slow (~0.2x real-time). Use CUDA for faster inference if available.
- MPS (Apple Silicon GPU) may have memory issues with long audio files.
- The NeMo model runs in batch mode; Swift implements true streaming chunking on top.

## References

- [NVIDIA Sortformer Model](https://huggingface.co/nvidia/diar_sortformer_4spk-v1)
- [AMI Corpus](https://groups.inf.ed.ac.uk/ami/corpus/)
- [pyannote AMI Diarization Setup](https://github.com/pyannote/AMI-diarization-setup)
- [NeMo Toolkit](https://github.com/NVIDIA/NeMo)
+665
@@ -0,0 +1,665 @@
#!/usr/bin/env python3
"""
NeMo Sortformer Benchmark

Benchmarks the NeMo streaming Sortformer model on the same files as:
- SortformerBenchmark.swift
- single_file.py

Uses streaming parameters:
- chunk_len = 340
- left_context = 1
- right_context = 40
- fifo_len = 40
- spkcache_len = 188
- spkcache_update_period = 300
"""
import os
# Tolerate duplicate OpenMP runtimes; set before torch is imported.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import json
import time
import argparse
import urllib.request
from pathlib import Path
from itertools import permutations
import numpy as np
import torch
from nemo.collections.asr.models import SortformerEncLabelModel


# ============================================================
# AMI RTTM Download
# ============================================================
# pyannote AMI-diarization-setup repository
AMI_RTTM_URL = "https://raw.githubusercontent.com/pyannote/AMI-diarization-setup/main/only_words/rttms/test"

def download_ami_rttm(meeting_name: str, output_dir: Path) -> str | None:
    """Download AMI RTTM file from pyannote AMI-diarization-setup repository.

    Returns the local path, or None if the download fails.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    output_path = output_dir / f"{meeting_name}.rttm"

    # Reuse a previously downloaded file
    if output_path.exists():
        return str(output_path)

    # Files in pyannote repo are named {meeting}.rttm (not {meeting}.Mix-Headset.rttm)
    url = f"{AMI_RTTM_URL}/{meeting_name}.rttm"
    try:
        print(f" Downloading RTTM from {url}...")
        urllib.request.urlretrieve(url, output_path)
        return str(output_path)
    except Exception as e:
        print(f" Failed to download RTTM: {e}")
        return None


# ============================================================
# Benchmark Configuration
# ============================================================
STREAMING_CONFIG = {
    'chunk_len': 340,
    'chunk_left_context': 1,
    'chunk_right_context': 40,
    'fifo_len': 40,
    'spkcache_len': 188,
    'spkcache_update_period': 300,
}

FRAME_SHIFT = 0.08  # 80ms per frame (matches Swift)
SAMPLE_RATE = 16000
NUM_SPEAKERS = 4


# ============================================================
# Data Paths (matches SortformerBenchmark.swift)
# ============================================================
def get_home_dir():
    return Path.home()


def get_audio_path(meeting_name: str, dataset: str) -> str:
    """Get audio file path for a meeting."""
    home = get_home_dir()

    if dataset == "ami":
        return str(home / f"FluidAudioDatasets/ami_official/sdm/{meeting_name}.Mix-Headset.wav")
    elif dataset == "voxconverse":
        return str(home / f"FluidAudioDatasets/voxconverse/voxconverse_test_wav/{meeting_name}.wav")
    elif dataset == "callhome":
        return str(home / f"FluidAudioDatasets/callhome_eng/{meeting_name}.wav")
    else:
        raise ValueError(f"Unknown dataset: {dataset}")


def get_rttm_path(meeting_name: str, dataset: str, auto_download: bool = True) -> str:
    """Get RTTM ground truth path for a meeting."""
    home = get_home_dir()
    script_dir = Path(__file__).parent

    if dataset == "ami":
        # First try local RTTMs in cache
        cache_dir = script_dir / "rttm_cache" / "ami"
        cached_rttm = cache_dir / f"{meeting_name}.rttm"
        if cached_rttm.exists():
            return str(cached_rttm)

        # Try local project RTTM
        local_rttm = script_dir / f"Streaming-Sortformer-Conversion/{meeting_name}.rttm"
        if local_rttm.exists():
            return str(local_rttm)

        # Try dataset RTTM
        dataset_rttm = home / f"FluidAudioDatasets/ami_official/rttm/{meeting_name}.rttm"
        if dataset_rttm.exists():
            return str(dataset_rttm)

        # Auto-download if enabled
        if auto_download:
            downloaded = download_ami_rttm(meeting_name, cache_dir)
            if downloaded:
                return downloaded

        return str(cached_rttm)  # Return path even if not downloaded (will fail later)

    elif dataset == "voxconverse":
        return str(home / f"FluidAudioDatasets/voxconverse/rttm_repo/test/{meeting_name}.rttm")
    elif dataset == "callhome":
        return str(home / f"FluidAudioDatasets/callhome_eng/rttm/{meeting_name}.rttm")
    else:
        raise ValueError(f"Unknown dataset: {dataset}")


def get_ami_files(max_files: int | None = None) -> list:
    """Get list of AMI test set meetings (matches Swift benchmark)."""
    # Official AMI SDM test set (16 meetings) - matches NeMo evaluation
    all_meetings = [
        "EN2002a", "EN2002b", "EN2002c", "EN2002d",
        "ES2004a", "ES2004b", "ES2004c", "ES2004d",
        "IS1009a", "IS1009b", "IS1009c", "IS1009d",
        "TS3003a", "TS3003b", "TS3003c", "TS3003d",
    ]

    # Keep only meetings whose audio file is present locally
    available = []
    for meeting in all_meetings:
        if Path(get_audio_path(meeting, "ami")).exists():
            available.append(meeting)

    if max_files:
        return available[:max_files]
    return available


def get_voxconverse_files(max_files: int | None = None) -> list:
    """Get list of VoxConverse test files."""
    home = get_home_dir()
    vox_dir = home / "FluidAudioDatasets/voxconverse/voxconverse_test_wav"

    if not vox_dir.exists():
        return []

    # Keep only files that have a matching RTTM
    available = []
    for wav_file in sorted(vox_dir.glob("*.wav")):
        name = wav_file.stem
        rttm_path = home / f"FluidAudioDatasets/voxconverse/rttm_repo/test/{name}.rttm"
        if rttm_path.exists():
            available.append(name)

    if max_files:
        return available[:max_files]
    return available


def get_callhome_files(max_files: int | None = None) -> list:
    """Get list of CALLHOME files."""
    home = get_home_dir()
    callhome_dir = home / "FluidAudioDatasets/callhome_eng"

    if not callhome_dir.exists():
        return []

    # Keep only files that have a matching RTTM
    available = []
    for wav_file in sorted(callhome_dir.glob("*.wav")):
        name = wav_file.stem
        rttm_path = callhome_dir / f"rttm/{name}.rttm"
        if rttm_path.exists():
            available.append(name)

    if max_files:
        return available[:max_files]
    return available


# ============================================================
# RTTM Ground Truth Loading
# ============================================================
def load_rttm(rttm_path: str) -> list:
    """
    Load RTTM file and return list of segments.
    Format: SPEAKER <file> 1 <start> <duration> <NA> <NA> <speaker_id> <NA> <NA>
    """
    if not Path(rttm_path).exists():
        return []

    segments = []
    with open(rttm_path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            # Skip blank, malformed, and non-SPEAKER lines
            if len(parts) < 8 or parts[0] != "SPEAKER":
                continue

            try:
                start_time = float(parts[3])
                duration = float(parts[4])
                speaker_id = parts[7]
                end_time = start_time + duration

                segments.append({
                    'speaker_id': speaker_id,
                    'start': start_time,
                    'end': end_time,
                })
            except (ValueError, IndexError):
                continue

    speakers = set(s['speaker_id'] for s in segments)
    print(f" [RTTM] Loaded {len(segments)} segments, speakers: {sorted(speakers)}")
    return segments


# ============================================================
# DER Calculation (matches Swift implementation)
# ============================================================
def calculate_der(predictions: np.ndarray, ground_truth: list,
                  threshold: float = 0.5, frame_shift: float = 0.08) -> dict:
    """
    Calculate DER using simple frame-level binary comparison.
    This matches the NeMo/Swift evaluation approach.

    Args:
        predictions: [num_frames, num_speakers] probability array
        ground_truth: List of RTTM segments with 'speaker_id', 'start', 'end'
        threshold: Speaker activity threshold
        frame_shift: Time per frame in seconds

    Returns:
        dict with 'der', 'miss', 'fa', 'se' percentages
    """
    num_frames = predictions.shape[0]
    num_speakers = predictions.shape[1]

    # Create reference binary matrix [num_frames, num_speakers]
    ref_binary = np.zeros((num_frames, num_speakers), dtype=np.float32)

    # Map ground truth speakers to indices (speakers beyond num_speakers are dropped)
    speaker_labels = sorted(set(s['speaker_id'] for s in ground_truth))
    speaker_map = {label: idx for idx, label in enumerate(speaker_labels) if idx < num_speakers}

    # Fill reference binary from ground truth segments
    for segment in ground_truth:
        spk_id = segment['speaker_id']
        if spk_id not in speaker_map:
            continue
        spk_idx = speaker_map[spk_id]
        start_frame = max(0, min(int(segment['start'] / frame_shift), num_frames))
        end_frame = max(0, min(int(segment['end'] / frame_shift), num_frames))
        ref_binary[start_frame:end_frame, spk_idx] = 1.0

    # Create prediction binary matrix
    pred_binary = (predictions > threshold).astype(np.float32)

    # Try all permutations to find best DER
    best_der = float('inf')
    best_miss = 0
    best_fa = 0
    best_se = 0

    for perm in permutations(range(num_speakers)):
        miss_frames = 0
        fa_frames = 0
        se_frames = 0
        total_ref_speech = 0

        for frame in range(num_frames):
            ref_speech = ref_binary[frame].any()
            pred_speech_permuted = any(pred_binary[frame, perm[spk]] > 0 for spk in range(num_speakers))

            if ref_speech:
                total_ref_speech += 1

            if ref_speech and not pred_speech_permuted:
                miss_frames += 1
            elif not ref_speech and pred_speech_permuted:
                fa_frames += 1
            elif ref_speech and pred_speech_permuted:
                # Calculate speaker error: half the symmetric difference
                # between reference and (permuted) predicted speaker sets
                ref_spks = set(spk for spk in range(num_speakers) if ref_binary[frame, spk] > 0)
                pred_spks = set(spk for spk in range(num_speakers) if pred_binary[frame, perm[spk]] > 0)
                sym_diff = ref_spks.symmetric_difference(pred_spks)
                se_frames += len(sym_diff) / 2.0

        # All rates are normalized by the number of reference-speech frames
        if total_ref_speech > 0:
            der = (miss_frames + fa_frames + se_frames) / total_ref_speech * 100
            if der < best_der:
                best_der = der
                best_miss = miss_frames / total_ref_speech * 100
                best_fa = fa_frames / total_ref_speech * 100
                best_se = se_frames / total_ref_speech * 100

    return {
        'der': best_der,
        'miss': best_miss,
        'fa': best_fa,
        'se': best_se,
    }


# ============================================================
# NeMo Sortformer Inference
# ============================================================
def run_inference(model, audio_path: str) -> tuple:
    """
    Run NeMo Sortformer streaming inference on an audio file.

    Returns:
        (predictions, duration, processing_time)
        - predictions: [num_frames, num_speakers] probability array
        - duration: Audio duration in seconds
        - processing_time: Inference time in seconds
    """
    start_time = time.time()

    # Run inference
    predicted_segments, predicted_probs = model.diarize(
        audio=audio_path,
        batch_size=1,
        include_tensor_outputs=True
    )

    processing_time = time.time() - start_time

    # Process output probabilities
    probs = predicted_probs[0].squeeze().cpu().numpy()  # [num_frames, num_speakers]

    # Calculate duration from number of frames
    num_frames = probs.shape[0]
    duration = num_frames * FRAME_SHIFT

    return probs, duration, processing_time


def process_audio_file(model, audio_path: str, threshold: float, verbose: bool) -> dict | None:
    """Process a single audio file without ground truth (inference only)."""
    if not Path(audio_path).exists():
        print(f"❌ Audio file not found: {audio_path}")
        return None

    try:
        print(f" Running inference on {audio_path}...")
        probs, duration, processing_time = run_inference(model, audio_path)

        rtfx = duration / processing_time

        # Print probability statistics
        min_val = probs.min()
        max_val = probs.max()
        mean_val = probs.mean()
        above_05 = (probs > 0.5).sum()
        total_vals = probs.size

        print(f" Audio duration: {duration:.2f}s")
        print(f" Processing time: {processing_time:.2f}s")
        print(f" RTFx: {rtfx:.1f}x")
        print(f" Prob stats: min={min_val:.3f}, max={max_val:.3f}, mean={mean_val:.3f}")
        print(f" Activity: {above_05}/{total_vals} values ({above_05/total_vals*100:.1f}%) above 0.5")

        # Count detected speakers (any frame above threshold)
        detected_speakers = sum(1 for spk in range(probs.shape[1]) if (probs[:, spk] > threshold).any())
        print(f" Detected speakers: {detected_speakers}")

        return {
            'file': audio_path,
            'duration': duration,
            'processing_time': processing_time,
            'rtfx': rtfx,
            'num_frames': probs.shape[0],
            'detected_speakers': detected_speakers,
            'prob_min': float(min_val),
            'prob_max': float(max_val),
            'prob_mean': float(mean_val),
        }

    except Exception as e:
        import traceback
        print(f"❌ Error processing {audio_path}: {e}")
        traceback.print_exc()
        return None


def process_meeting(model, meeting_name: str, dataset: str, threshold: float, verbose: bool) -> dict | None:
    """Process a single meeting and return benchmark results."""
    audio_path = get_audio_path(meeting_name, dataset)
    rttm_path = get_rttm_path(meeting_name, dataset, auto_download=True)

    if not Path(audio_path).exists():
        print(f"❌ Audio file not found: {audio_path}")
        return None

    try:
        # Run inference
        print(f" Running inference on {audio_path}...")
        probs, duration, processing_time = run_inference(model, audio_path)

        rtfx = duration / processing_time

        # Print probability statistics
        min_val = probs.min()
        max_val = probs.max()
        mean_val = probs.mean()
        above_05 = (probs > 0.5).sum()
        total_vals = probs.size

        print(f" Prob stats: min={min_val:.3f}, max={max_val:.3f}, mean={mean_val:.3f}")
        print(f" Activity: {above_05}/{total_vals} values ({above_05/total_vals*100:.1f}%) above 0.5")

        # Load ground truth
        ground_truth = load_rttm(rttm_path)
        if not ground_truth:
            print(f"⚠️ No ground truth found for {meeting_name}")
            return None

        # Calculate DER
        metrics = calculate_der(probs, ground_truth, threshold=threshold, frame_shift=FRAME_SHIFT)

        # Count speakers
        detected_speakers = sum(1 for spk in range(probs.shape[1]) if (probs[:, spk] > threshold).any())
        gt_speakers = len(set(s['speaker_id'] for s in ground_truth))

        return {
            'meeting': meeting_name,
            'der': metrics['der'],
            'miss': metrics['miss'],
            'fa': metrics['fa'],
            'se': metrics['se'],
            'rtfx': rtfx,
            'processing_time': processing_time,
            'duration': duration,
            'num_frames': probs.shape[0],
            'detected_speakers': detected_speakers,
            'gt_speakers': gt_speakers,
        }

    except Exception as e:
        import traceback
        print(f"❌ Error processing {meeting_name}: {e}")
        traceback.print_exc()
        return None


# ============================================================
# Main Benchmark
# ============================================================
def run_benchmark(args):
    """Run the full benchmark."""
    print("🚀 Starting NeMo Sortformer Benchmark")
    print(f" Dataset: {args.dataset}")
    print(f" Threshold: {args.threshold}")
    print(f" Device: {args.device}")
    print()

    # Load model
    print("🔧 Loading NeMo Sortformer model...")
    model_load_start = time.time()

    device = torch.device(args.device)
    model = SortformerEncLabelModel.from_pretrained(
        "nvidia/diar_streaming_sortformer_4spk-v2.1",
        map_location=device
    )
    model.eval()
    model.to(device)

    # Apply streaming configuration
    modules = model.sortformer_modules
    modules.chunk_len = STREAMING_CONFIG['chunk_len']
    modules.chunk_left_context = STREAMING_CONFIG['chunk_left_context']
    modules.chunk_right_context = STREAMING_CONFIG['chunk_right_context']
    modules.fifo_len = STREAMING_CONFIG['fifo_len']
    modules.spkcache_len = STREAMING_CONFIG['spkcache_len']
    modules.spkcache_update_period = STREAMING_CONFIG['spkcache_update_period']

    # Validate streaming parameters
    modules._check_streaming_parameters()

    model_load_time = time.time() - model_load_start
    print(f"✅ Model loaded in {model_load_time:.2f}s")
    print(f" chunk_len={modules.chunk_len}, left_ctx={modules.chunk_left_context}, right_ctx={modules.chunk_right_context}")
    print(f" fifo_len={modules.fifo_len}, spkcache_len={modules.spkcache_len}, update_period={modules.spkcache_update_period}")
    print()

    # Get files to process
    if args.single_file:
        files_to_process = [args.single_file]
    else:
        if args.dataset == "ami":
            files_to_process = get_ami_files(args.max_files)
        elif args.dataset == "voxconverse":
            files_to_process = get_voxconverse_files(args.max_files)
        elif args.dataset == "callhome":
            files_to_process = get_callhome_files(args.max_files)
        else:
            print(f"❌ Unknown dataset: {args.dataset}")
            return

    if not files_to_process:
        print("❌ No files found to process")
        return

    print(f"📂 Processing {len(files_to_process)} file(s)")
    print()

    # Process each file
    all_results = []

    for i, meeting in enumerate(files_to_process):
        print("=" * 60)
        print(f"[{i+1}/{len(files_to_process)}] Processing: {meeting}")
        print("=" * 60)

        result = process_meeting(model, meeting, args.dataset, args.threshold, args.verbose)

        if result:
            all_results.append(result)
            print(f"📊 Results for {meeting}:")
            print(f" DER: {result['der']:.1f}%")
            print(f" RTFx: {result['rtfx']:.1f}x")
            print(f" Speakers: {result['detected_speakers']} detected / {result['gt_speakers']} truth")
        print()

    # Print final summary
    if all_results:
        print_summary(all_results)

    # Save results
    if args.output:
        with open(args.output, 'w') as f:
            json.dump(all_results, f, indent=2)
        print(f"💾 Results saved to: {args.output}")


def print_summary(results: list):
    """Print benchmark summary."""
    print()
    print("=" * 80)
    print("NEMO SORTFORMER BENCHMARK SUMMARY")
    print("=" * 80)

    print("📋 Results Sorted by DER:")
    print("-" * 70)
    print(f"{'Meeting':<14} {'DER %':>8} {'Miss %':>8} {'FA %':>8} {'SE %':>8} {'Speakers':>10} {'RTFx':>8}")
    print("-" * 70)

    for result in sorted(results, key=lambda x: x['der']):
        speaker_info = f"{result['detected_speakers']}/{result['gt_speakers']}"
        print(f"{result['meeting']:<14} {result['der']:>8.1f} {result['miss']:>8.1f} {result['fa']:>8.1f} {result['se']:>8.1f} {speaker_info:>10} {result['rtfx']:>8.1f}")

    print("-" * 70)

    # Calculate averages (unweighted mean over meetings)
    n = len(results)
    avg_der = sum(r['der'] for r in results) / n
    avg_miss = sum(r['miss'] for r in results) / n
    avg_fa = sum(r['fa'] for r in results) / n
    avg_se = sum(r['se'] for r in results) / n
    avg_rtfx = sum(r['rtfx'] for r in results) / n

    print(f"{'AVERAGE':<14} {avg_der:>8.1f} {avg_miss:>8.1f} {avg_fa:>8.1f} {avg_se:>8.1f} {'-':>10} {avg_rtfx:>8.1f}")
    print("=" * 70)

    print()
    print("✅ Target Check:")
    if avg_der < 15:
        print(f" ✅ DER < 15% (achieved: {avg_der:.1f}%)")
    elif avg_der < 20:
        print(f" 🟡 DER < 20% (achieved: {avg_der:.1f}%)")
    else:
        print(f" ❌ DER >= 20% (achieved: {avg_der:.1f}%)")

    if avg_rtfx > 1:
        print(f" ✅ RTFx > 1x (achieved: {avg_rtfx:.1f}x)")
    else:
        print(f" ❌ RTFx <= 1x (achieved: {avg_rtfx:.1f}x)")


def run_single_audio(args):
    """Run inference on a single audio file without ground truth."""
    print("🚀 Starting NeMo Sortformer Inference")
    print(f" Audio: {args.audio}")
    print(f" Threshold: {args.threshold}")
    print(f" Device: {args.device}")
    print()

    # Load model
    print("🔧 Loading NeMo Sortformer model...")
    model_load_start = time.time()

    device = torch.device(args.device)
    model = SortformerEncLabelModel.from_pretrained(
        "nvidia/diar_streaming_sortformer_4spk-v2.1",
        map_location=device
    )
    model.eval()
    model.to(device)

    # Apply streaming configuration
    modules = model.sortformer_modules
    modules.chunk_len = STREAMING_CONFIG['chunk_len']
    modules.chunk_left_context = STREAMING_CONFIG['chunk_left_context']
    modules.chunk_right_context = STREAMING_CONFIG['chunk_right_context']
    modules.fifo_len = STREAMING_CONFIG['fifo_len']
    modules.spkcache_len = STREAMING_CONFIG['spkcache_len']
    modules.spkcache_update_period = STREAMING_CONFIG['spkcache_update_period']
    modules._check_streaming_parameters()

    model_load_time = time.time() - model_load_start
    print(f"✅ Model loaded in {model_load_time:.2f}s")
    print(f" chunk_len={modules.chunk_len}, left_ctx={modules.chunk_left_context}, right_ctx={modules.chunk_right_context}")
    print()

    print("=" * 60)
    result = process_audio_file(model, args.audio, args.threshold, args.verbose)
    print("=" * 60)

    if result and args.output:
        with open(args.output, 'w') as f:
            json.dump(result, f, indent=2)
        print(f"💾 Results saved to: {args.output}")
|
||||
|
||||
|
||||
def main():
    parser = argparse.ArgumentParser(description="NeMo Sortformer Benchmark")
    parser.add_argument("--dataset", choices=["ami", "voxconverse", "callhome"],
                        default="ami", help="Dataset to benchmark on")
    parser.add_argument("--single-file", type=str, default=None,
                        help="Process a specific meeting (e.g., ES2004a)")
    parser.add_argument("--audio", type=str, default=None,
                        help="Process a single audio file (no ground truth, inference only)")
    parser.add_argument("--max-files", type=int, default=None,
                        help="Maximum number of files to process")
    parser.add_argument("--threshold", type=float, default=0.5,
                        help="Speaker activity threshold")
    parser.add_argument("--device", type=str, default="cpu",
                        help="Device to run on (cpu, cuda, mps)")
    parser.add_argument("--output", type=str, default=None,
                        help="Output JSON file for results")
    parser.add_argument("--verbose", action="store_true",
                        help="Enable verbose output")

    args = parser.parse_args()

    if args.audio:
        run_single_audio(args)
    else:
        run_benchmark(args)


if __name__ == "__main__":
    main()
Executable
+273
@@ -0,0 +1,273 @@
#!/usr/bin/env python3
"""
FluidAudio Benchmark Suite

Runs ASR, VAD, and Diarization benchmarks and saves results to JSON.
Compare results against Documentation/Benchmarks.md baselines.

Usage:
    python run_benchmarks.py              # Run all benchmarks
    python run_benchmarks.py --quick      # Quick smoke test
    python run_benchmarks.py --asr-only   # ASR benchmark only
    python run_benchmarks.py --vad-only   # VAD benchmark only
    python run_benchmarks.py --diar-only  # Diarization only
"""

import argparse
import json
import subprocess
import sys
from datetime import datetime
from pathlib import Path


# Baseline values from Documentation/Benchmarks.md
BASELINES = {
    "asr": {
        "wer_percent": 5.8,
        "rtfx_min": 200,  # M4 Pro: ~210x
        "description": "LibriSpeech test-clean, Parakeet TDT 0.6B"
    },
    "vad": {
        "f1_percent": 85.0,
        "rtfx_min": 500,
        "description": "VOiCES dataset, Silero VAD"
    },
    "diarization": {
        "der_percent": 17.7,
        "rtfx_min": 1.0,
        "description": "AMI SDM, pyannote-based"
    }
}


def run_command(cmd: list[str], output_file: Path | None = None) -> tuple[int, str]:
    """Run a command and optionally save output."""
    print(f"Running: {' '.join(cmd)}")

    result = subprocess.run(
        cmd,
        capture_output=True,
        text=True
    )

    output = result.stdout + result.stderr

    if output_file:
        output_file.write_text(output)

    return result.returncode, output


def build_release() -> bool:
    """Build the project in release mode."""
    print("\n" + "=" * 60)
    print("Building release...")
    print("=" * 60)

    returncode, _ = run_command(["swift", "build", "-c", "release"])

    if returncode != 0:
        print("ERROR: Build failed!")
        return False

    print("Build successful.")
    return True


def run_asr_benchmark(output_dir: Path, quick: bool = False) -> dict | None:
    """Run ASR benchmark on LibriSpeech test-clean."""
    print("\n" + "=" * 60)
    print("ASR Benchmark (LibriSpeech test-clean)")
    print("=" * 60)

    max_files = "100" if quick else "all"
    output_json = output_dir / "asr_results.json"

    cmd = [
        "swift", "run", "-c", "release", "fluidaudio", "asr-benchmark",
        "--subset", "test-clean",
        "--max-files", max_files,
        "--output", str(output_json)
    ]

    returncode, output = run_command(cmd, output_dir / "asr_log.txt")

    if returncode != 0:
        print("ERROR: ASR benchmark failed!")
        return None

    if output_json.exists():
        return json.loads(output_json.read_text())

    return None


def run_vad_benchmark(output_dir: Path, quick: bool = False) -> dict | None:
    """Run VAD benchmark."""
    print("\n" + "=" * 60)
    print("VAD Benchmark")
    print("=" * 60)

    dataset = "mini50" if quick else "voices-subset"
    output_json = output_dir / "vad_results.json"

    cmd = [
        "swift", "run", "-c", "release", "fluidaudio", "vad-benchmark",
        "--dataset", dataset,
        "--all-files",
        "--threshold", "0.5",
        "--output", str(output_json)
    ]

    returncode, output = run_command(cmd, output_dir / "vad_log.txt")

    if returncode != 0:
        print("ERROR: VAD benchmark failed!")
        return None

    if output_json.exists():
        return json.loads(output_json.read_text())

    return None


def run_diarization_benchmark(output_dir: Path, quick: bool = False) -> dict | None:
    """Run diarization benchmark on AMI SDM."""
    print("\n" + "=" * 60)
    print("Diarization Benchmark (AMI SDM)")
    print("=" * 60)

    output_json = output_dir / "diarization_results.json"

    cmd = [
        "swift", "run", "-c", "release", "fluidaudio", "diarization-benchmark",
        "--auto-download",
        "--output", str(output_json)
    ]

    if quick:
        cmd.extend(["--single-file", "ES2004a"])

    returncode, output = run_command(cmd, output_dir / "diarization_log.txt")

    if returncode != 0:
        print("ERROR: Diarization benchmark failed!")
        return None

    if output_json.exists():
        return json.loads(output_json.read_text())

    return None


def compare_results(results: dict) -> None:
    """Compare results against baselines."""
    print("\n" + "=" * 60)
    print("Results vs Baselines (Documentation/Benchmarks.md)")
    print("=" * 60)

    if "asr" in results and results["asr"]:
        asr = results["asr"]
        baseline = BASELINES["asr"]
        wer = asr.get("wer", asr.get("average_wer", 0)) * 100
        rtfx = asr.get("rtfx", asr.get("median_rtfx", 0))

        wer_status = "✓" if wer <= baseline["wer_percent"] * 1.1 else "✗"
        rtfx_status = "✓" if rtfx >= baseline["rtfx_min"] * 0.8 else "✗"

        print(f"\nASR ({baseline['description']}):")
        print(f" WER: {wer:.1f}% (baseline: {baseline['wer_percent']}%) {wer_status}")
        print(f" RTFx: {rtfx:.1f}x (baseline: {baseline['rtfx_min']}x+) {rtfx_status}")

    if "vad" in results and results["vad"]:
        vad = results["vad"]
        baseline = BASELINES["vad"]
        f1 = vad.get("f1_score", 0)
        rtfx = vad.get("rtfx", 0)

        f1_status = "✓" if f1 >= baseline["f1_percent"] * 0.9 else "✗"
        rtfx_status = "✓" if rtfx >= baseline["rtfx_min"] * 0.5 else "✗"

        print(f"\nVAD ({baseline['description']}):")
        print(f" F1: {f1:.1f}% (baseline: {baseline['f1_percent']}%+) {f1_status}")
        print(f" RTFx: {rtfx:.1f}x (baseline: {baseline['rtfx_min']}x+) {rtfx_status}")

    if "diarization" in results and results["diarization"]:
        diar = results["diarization"]
        baseline = BASELINES["diarization"]
        der = diar.get("der", diar.get("average_der", 0)) * 100
        rtfx = diar.get("rtfx", diar.get("average_rtfx", 0))

        der_status = "✓" if der <= baseline["der_percent"] * 1.2 else "✗"
        rtfx_status = "✓" if rtfx >= baseline["rtfx_min"] else "✗"

        print(f"\nDiarization ({baseline['description']}):")
        print(f" DER: {der:.1f}% (baseline: {baseline['der_percent']}%) {der_status}")
        print(f" RTFx: {rtfx:.1f}x (baseline: {baseline['rtfx_min']}x+) {rtfx_status}")


def main():
    parser = argparse.ArgumentParser(description="FluidAudio Benchmark Suite")
    parser.add_argument("--quick", action="store_true", help="Quick smoke test with smaller datasets")
    parser.add_argument("--asr-only", action="store_true", help="Run ASR benchmark only")
    parser.add_argument("--vad-only", action="store_true", help="Run VAD benchmark only")
    parser.add_argument("--diar-only", action="store_true", help="Run diarization benchmark only")
    parser.add_argument("--output-dir", type=str, help="Output directory for results")
    args = parser.parse_args()

    # Determine which benchmarks to run
    run_all = not (args.asr_only or args.vad_only or args.diar_only)

    # Setup output directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        output_dir = Path("benchmark-results") / timestamp

    output_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print("FluidAudio Benchmark Suite")
    print("=" * 60)
    print(f"Mode: {'Quick' if args.quick else 'Full'}")
    print(f"Output: {output_dir}")
    print(f"Time: {timestamp}")

    # Build first
    if not build_release():
        sys.exit(1)

    results = {}

    # Run benchmarks
    if run_all or args.asr_only:
        results["asr"] = run_asr_benchmark(output_dir, args.quick)

    if run_all or args.vad_only:
        results["vad"] = run_vad_benchmark(output_dir, args.quick)

    if run_all or args.diar_only:
        results["diarization"] = run_diarization_benchmark(output_dir, args.quick)

    # Save combined results
    combined_output = output_dir / "benchmark_results.json"
    combined_output.write_text(json.dumps({
        "timestamp": timestamp,
        "mode": "quick" if args.quick else "full",
        "baselines": BASELINES,
        "results": results
    }, indent=2))

    # Compare against baselines
    compare_results(results)

    print("\n" + "=" * 60)
    print("Benchmark complete!")
    print("=" * 60)
    print(f"Results saved to: {combined_output}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,77 @@
# Voice Cloning Evaluation Scripts

Tools for evaluating PocketTTS voice cloning quality using spectral similarity.

## evaluate_voice.py

Compares a reference voice sample with synthesized TTS output using mel-spectrogram and MFCC similarity metrics. No neural network required.
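
As a rough illustration of the core metric, the sketch below computes cosine similarity between two mean feature vectors in plain Python. The function name `cosine_sim` and the sample vectors are illustrative only; the script itself operates on NumPy arrays of mel and MFCC features.

```python
import math


def cosine_sim(a, b):
    """Cosine similarity of two equal-length vectors; 0.0 if either is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


# Hypothetical mean-spectrum vectors for a reference and a synthesized clip.
reference = [1.0, 2.0, 3.0]
synthesized = [2.0, 4.0, 6.0]
print(cosine_sim(reference, synthesized))
```

Cosine similarity ignores overall scale, so two vectors that differ only by a constant gain score close to 1.0.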

### Install

```bash
pip install librosa numpy
# Or minimal (scipy fallback):
pip install scipy numpy

# Optional for plotting:
pip install matplotlib
```

### Usage

```bash
# Basic comparison
python evaluate_voice.py reference.wav synthesized.wav

# With visualization
python evaluate_voice.py reference.wav synthesized.wav --plot

# JSON output
python evaluate_voice.py reference.wav synthesized.wav --json
```
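
The `--json` output can be consumed by other tooling. The following is a minimal sketch, assuming the JSON object carries the `combined_similarity` and `quality` keys that `evaluate_voice.py` stores in its metrics dictionary; the `summarize` helper and the sample payload are hypothetical.

```python
import json


def summarize(json_text):
    """Return a one-line summary from the script's JSON metrics."""
    metrics = json.loads(json_text)
    return f"{metrics['quality']} ({metrics['combined_similarity']:.2f})"


# Hypothetical payload shaped like the script's metrics dictionary.
sample = (
    '{"mel_similarity": 0.92, "mfcc_similarity": 0.89, '
    '"mfcc_std_similarity": 0.85, "combined_similarity": 0.894, '
    '"quality": "Good"}'
)
print(summarize(sample))  # prints: Good (0.89)
```

Because the script logs through Python's `logging` module, whose default handler writes to stderr, stdout should contain only the JSON document when `--json` is used.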

### Metrics

| Metric | Description |
|--------|-------------|
| Mel Similarity | Cosine similarity of mean mel spectrum (voice timbre) |
| MFCC Similarity | Cosine similarity of mean MFCCs (voice characteristics) |
| MFCC Std Similarity | Similarity of MFCC dynamics |
| Combined Score | Weighted average (0.4 mel + 0.4 mfcc + 0.2 mfcc_std) |
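
The weighting in the last row can be sketched as a small helper. The function name `combined_score` is illustrative; the weights are the ones listed in the table.

```python
def combined_score(mel, mfcc, mfcc_std):
    """Weighted average of the three similarity metrics (0.4 / 0.4 / 0.2)."""
    return 0.4 * mel + 0.4 * mfcc + 0.2 * mfcc_std


# Example inputs chosen for illustration only.
print(round(combined_score(0.9234, 0.8876, 0.8543), 4))
```

The two mean-based metrics carry most of the weight, so a mismatch in MFCC dynamics alone moves the combined score by at most 0.2.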

### Quality Thresholds

| Score | Quality | Meaning |
|-------|---------|---------|
| 0.90+ | Excellent | Very close spectral match |
| 0.80+ | Good | Similar voice characteristics |
| 0.70+ | Fair | Some similarity |
| <0.70 | Poor | Different spectral characteristics |
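
The thresholds above map a combined score to a label as in the sketch below. The helper name `quality_label` is hypothetical; the cut-offs mirror the table.

```python
def quality_label(score):
    """Map a combined similarity score to the quality label from the table."""
    if score >= 0.90:
        return "Excellent"
    if score >= 0.80:
        return "Good"
    if score >= 0.70:
        return "Fair"
    return "Poor"


for value in (0.95, 0.85, 0.75, 0.40):
    print(value, quality_label(value))
```

Each threshold is inclusive, so a score of exactly 0.80 is labeled "Good" rather than "Fair".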

### Example Workflow

```bash
# 1. Clone a voice using FluidAudio CLI
fluidaudio tts "Hello, this is a test." --backend pocket --clone-voice speaker.wav -o output.wav

# 2. Evaluate the result
python Scripts/voice_cloning/evaluate_voice.py speaker.wav output.wav --plot
```

### Output Example

```
Reference: speaker.wav
Synthesized: output.wav

Reference duration: 5.23s
Synthesized duration: 2.15s

Computing spectral similarity...

Mel Similarity: 0.9234
MFCC Similarity: 0.8876
MFCC Std Similarity: 0.8543
Combined Score: 0.8953
Quality: Good
```
Executable
+296
@@ -0,0 +1,296 @@
#!/usr/bin/env python3
"""Evaluate voice cloning quality using spectral similarity.

Compares a reference voice sample with synthesized TTS output using
mel-spectrogram cosine similarity - no neural network required.

Requirements:
    pip install librosa numpy scipy

Usage:
    python evaluate_voice.py reference.wav synthesized.wav
    python evaluate_voice.py reference.wav synthesized.wav --plot
"""
import argparse
import logging
import sys
from pathlib import Path

import numpy as np

logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger(__name__)

SAMPLE_RATE = 24000  # PocketTTS native sample rate


def load_audio(path: Path) -> np.ndarray:
    """Load audio and resample to target sample rate."""
    try:
        import librosa
        audio, _ = librosa.load(str(path), sr=SAMPLE_RATE, mono=True)
        return audio
    except ImportError:
        from scipy.io import wavfile
        from scipy import signal
        sr, audio = wavfile.read(str(path))
        if audio.dtype == np.int16:
            audio = audio.astype(np.float32) / 32768.0
        elif audio.dtype == np.int32:
            audio = audio.astype(np.float32) / 2147483648.0
        if len(audio.shape) > 1:
            audio = audio.mean(axis=1)
        if sr != SAMPLE_RATE:
            num_samples = int(len(audio) * SAMPLE_RATE / sr)
            audio = signal.resample(audio, num_samples)
        return audio.astype(np.float32)


def compute_mel_spectrogram(audio: np.ndarray, n_mels: int = 80, n_fft: int = 1024,
                            hop_length: int = 256) -> np.ndarray:
    """Compute mel spectrogram."""
    try:
        import librosa
        mel = librosa.feature.melspectrogram(
            y=audio, sr=SAMPLE_RATE, n_mels=n_mels,
            n_fft=n_fft, hop_length=hop_length
        )
        return librosa.power_to_db(mel, ref=np.max)
    except ImportError:
        # Fallback using scipy
        from scipy import signal

        # Simple STFT
        _, _, Sxx = signal.spectrogram(audio, fs=SAMPLE_RATE, nperseg=n_fft,
                                       noverlap=n_fft - hop_length)
        # Approximate mel scaling (simplified)
        mel_basis = np.zeros((n_mels, Sxx.shape[0]))
        for i in range(n_mels):
            center = int(Sxx.shape[0] * (i + 1) / (n_mels + 1))
            width = max(1, Sxx.shape[0] // (n_mels * 2))
            mel_basis[i, max(0, center-width):min(Sxx.shape[0], center+width)] = 1
        mel_basis = mel_basis / (mel_basis.sum(axis=1, keepdims=True) + 1e-8)
        mel = np.dot(mel_basis, Sxx)
        return 10 * np.log10(mel + 1e-10)


def compute_mfcc(audio: np.ndarray, n_mfcc: int = 13) -> np.ndarray:
    """Compute MFCCs."""
    try:
        import librosa
        return librosa.feature.mfcc(y=audio, sr=SAMPLE_RATE, n_mfcc=n_mfcc)
    except ImportError:
        mel = compute_mel_spectrogram(audio)
        from scipy.fftpack import dct
        return dct(mel, type=2, axis=0, norm='ortho')[:n_mfcc]


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Compute cosine similarity between two vectors."""
    a_flat = a.flatten()
    b_flat = b.flatten()
    # Truncate to same length
    min_len = min(len(a_flat), len(b_flat))
    a_flat = a_flat[:min_len]
    b_flat = b_flat[:min_len]

    norm_a = np.linalg.norm(a_flat)
    norm_b = np.linalg.norm(b_flat)
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return float(np.dot(a_flat, b_flat) / (norm_a * norm_b))


def compute_spectral_similarity(ref_audio: np.ndarray, syn_audio: np.ndarray) -> dict:
    """Compute spectral similarity metrics."""
    # Compute mel spectrograms
    ref_mel = compute_mel_spectrogram(ref_audio)
    syn_mel = compute_mel_spectrogram(syn_audio)

    # Compute mean mel vectors (voice timbre signature)
    ref_mel_mean = ref_mel.mean(axis=1)
    syn_mel_mean = syn_mel.mean(axis=1)
    mel_similarity = cosine_similarity(ref_mel_mean, syn_mel_mean)

    # Compute MFCCs
    ref_mfcc = compute_mfcc(ref_audio)
    syn_mfcc = compute_mfcc(syn_audio)

    # MFCC mean (captures voice characteristics)
    ref_mfcc_mean = ref_mfcc.mean(axis=1)
    syn_mfcc_mean = syn_mfcc.mean(axis=1)
    mfcc_similarity = cosine_similarity(ref_mfcc_mean, syn_mfcc_mean)

    # MFCC std (captures dynamics)
    ref_mfcc_std = ref_mfcc.std(axis=1)
    syn_mfcc_std = syn_mfcc.std(axis=1)
    mfcc_std_similarity = cosine_similarity(ref_mfcc_std, syn_mfcc_std)

    return {
        'mel_similarity': mel_similarity,
        'mfcc_similarity': mfcc_similarity,
        'mfcc_std_similarity': mfcc_std_similarity,
    }


def evaluate_voice_cloning(
    reference_path: Path,
    synthesized_path: Path,
    plot: bool = False
) -> dict:
    """Evaluate voice cloning quality using spectral similarity."""
    logger.info(f"Reference: {reference_path}")
    logger.info(f"Synthesized: {synthesized_path}")
    logger.info("")

    # Load audio
    ref_audio = load_audio(reference_path)
    syn_audio = load_audio(synthesized_path)

    logger.info(f"Reference duration: {len(ref_audio) / SAMPLE_RATE:.2f}s")
    logger.info(f"Synthesized duration: {len(syn_audio) / SAMPLE_RATE:.2f}s")
    logger.info("")

    # Compute spectral similarity
    logger.info("Computing spectral similarity...")
    metrics = compute_spectral_similarity(ref_audio, syn_audio)

    # Combined score (weighted average)
    combined = (
        0.4 * metrics['mel_similarity'] +
        0.4 * metrics['mfcc_similarity'] +
        0.2 * metrics['mfcc_std_similarity']
    )
    metrics['combined_similarity'] = combined

    logger.info("")
    logger.info(f" Mel Similarity: {metrics['mel_similarity']:.4f}")
    logger.info(f" MFCC Similarity: {metrics['mfcc_similarity']:.4f}")
    logger.info(f" MFCC Std Similarity: {metrics['mfcc_std_similarity']:.4f}")
    logger.info(f" Combined Score: {combined:.4f}")

    # Quality interpretation
    if combined >= 0.90:
        quality = "Excellent"
    elif combined >= 0.80:
        quality = "Good"
    elif combined >= 0.70:
        quality = "Fair"
    else:
        quality = "Poor"

    metrics['quality'] = quality
    logger.info(f" Quality: {quality}")

    # Plot if requested
    if plot:
        plot_spectrograms(ref_audio, syn_audio, reference_path.stem, synthesized_path.stem)

    return metrics


def plot_spectrograms(ref_audio: np.ndarray, syn_audio: np.ndarray,
                      ref_name: str, syn_name: str):
    """Visualize mel spectrograms."""
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        logger.warning("matplotlib not installed, skipping plot")
        return

    ref_mel = compute_mel_spectrogram(ref_audio)
    syn_mel = compute_mel_spectrogram(syn_audio)

    fig, axes = plt.subplots(2, 2, figsize=(14, 8))

    # Reference mel spectrogram
    im0 = axes[0, 0].imshow(ref_mel, aspect='auto', origin='lower', cmap='magma')
    axes[0, 0].set_title(f'Reference: {ref_name}')
    axes[0, 0].set_ylabel('Mel bin')
    plt.colorbar(im0, ax=axes[0, 0], format='%+2.0f dB')

    # Synthesized mel spectrogram
    im1 = axes[0, 1].imshow(syn_mel, aspect='auto', origin='lower', cmap='magma')
    axes[0, 1].set_title(f'Synthesized: {syn_name}')
    axes[0, 1].set_ylabel('Mel bin')
    plt.colorbar(im1, ax=axes[0, 1], format='%+2.0f dB')

    # Mean mel comparison
    ref_mel_mean = ref_mel.mean(axis=1)
    syn_mel_mean = syn_mel.mean(axis=1)
    axes[1, 0].plot(ref_mel_mean, label='Reference', alpha=0.8)
    axes[1, 0].plot(syn_mel_mean, label='Synthesized', alpha=0.8)
    axes[1, 0].set_xlabel('Mel bin')
    axes[1, 0].set_ylabel('Mean energy (dB)')
    axes[1, 0].set_title('Mean Mel Spectrum (Voice Timbre)')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)

    # MFCC comparison
    ref_mfcc = compute_mfcc(ref_audio).mean(axis=1)
    syn_mfcc = compute_mfcc(syn_audio).mean(axis=1)
    x = np.arange(len(ref_mfcc))
    width = 0.35
    axes[1, 1].bar(x - width/2, ref_mfcc, width, label='Reference', alpha=0.8)
    axes[1, 1].bar(x + width/2, syn_mfcc, width, label='Synthesized', alpha=0.8)
    axes[1, 1].set_xlabel('MFCC coefficient')
    axes[1, 1].set_ylabel('Value')
    axes[1, 1].set_title('Mean MFCCs')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('spectral_comparison.png', dpi=150)
    logger.info("\nSaved comparison plot to: spectral_comparison.png")
    plt.show()


def main():
    parser = argparse.ArgumentParser(
        description="Evaluate voice cloning using spectral similarity",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Spectral Similarity Thresholds:
  0.90+  Excellent - Very close spectral match
  0.80+  Good - Similar voice characteristics
  0.70+  Fair - Some similarity
  <0.70  Poor - Different spectral characteristics

Metrics:
  - Mel Similarity: Cosine similarity of mean mel spectrum (timbre)
  - MFCC Similarity: Cosine similarity of mean MFCCs (voice characteristics)
  - MFCC Std Similarity: Similarity of MFCC dynamics

Requirements:
  pip install librosa numpy
  # Or minimal: pip install scipy numpy

Examples:
  python evaluate_voice.py original_speaker.wav tts_output.wav
  python evaluate_voice.py reference.wav synthesized.wav --plot
        """
    )
    parser.add_argument("reference", type=Path, help="Reference voice audio file")
    parser.add_argument("synthesized", type=Path, help="Synthesized TTS audio file")
    parser.add_argument("--plot", action="store_true", help="Show spectrogram comparison plots")
    parser.add_argument("--json", action="store_true", help="Output metrics as JSON")

    args = parser.parse_args()

    if not args.reference.exists():
        logger.error(f"Reference file not found: {args.reference}")
        sys.exit(1)
    if not args.synthesized.exists():
        logger.error(f"Synthesized file not found: {args.synthesized}")
        sys.exit(1)

    metrics = evaluate_voice_cloning(args.reference, args.synthesized, plot=args.plot)

    if args.json:
        import json
        print(json.dumps(metrics, indent=2))


if __name__ == "__main__":
    main()