mirror of
https://github.com/devnen/Kitten-TTS-Server.git
synced 2026-05-12 18:00:34 +00:00
d9b2d3462c
Support all KittenTTS models (Nano/Micro/Mini across v0.1, v0.2, v0.8) with automatic download from HuggingFace, async background loading with progress tracking, cancellation support, and named voices for every model. - Model selector dropdown in Web UI with real-time status and progress modal - Async model loading in background thread keeps server responsive during downloads - Cancel in-progress downloads when switching to a different model - Named voices: Amber, Felix, Clara, Marcus, Ivy, Oscar, Nora, Reed (v0.1/v0.2) and Bella, Jasper, Luna, Bruno, Rosie, Hugo, Kiki, Leo (v0.8) - Voice dropdown auto-updates when switching models - New API endpoints: /api/model-info, /api/model-registry, /api/model-status, /api/cancel-loading - Models cached in project-local model_cache directory - ONNX2 model type support with voice alias resolution - Voice embedding shape handling for v0.8 (400,256) -> (1,256) slicing
755 lines
27 KiB
Python
755 lines
27 KiB
Python
# File: server.py
|
|
# Main FastAPI application for the TTS Server.
|
|
# Handles API requests for text-to-speech generation, UI serving,
|
|
# configuration management, and file uploads.
|
|
|
|
import os
|
|
import io
|
|
import logging
|
|
import logging.handlers # For RotatingFileHandler
|
|
import shutil
|
|
import time
|
|
import uuid
|
|
import yaml # For loading presets
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from contextlib import asynccontextmanager
|
|
from typing import Optional, List, Dict, Any, Literal
|
|
import webbrowser # For automatic browser opening
|
|
import threading # For automatic browser opening
|
|
|
|
from fastapi import (
|
|
FastAPI,
|
|
HTTPException,
|
|
Request,
|
|
File,
|
|
UploadFile,
|
|
Form,
|
|
BackgroundTasks,
|
|
)
|
|
from fastapi.responses import (
|
|
HTMLResponse,
|
|
JSONResponse,
|
|
StreamingResponse,
|
|
FileResponse,
|
|
)
|
|
from fastapi.staticfiles import StaticFiles
|
|
# from fastapi.templating import Jinja2Templates # Not used, serving static HTML
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
# --- Internal Project Imports ---
|
|
from config import (
|
|
config_manager,
|
|
get_host,
|
|
get_port,
|
|
get_log_file_path,
|
|
get_output_path,
|
|
get_ui_title,
|
|
get_gen_default_speed,
|
|
get_gen_default_language,
|
|
get_audio_sample_rate,
|
|
get_full_config_for_template,
|
|
get_audio_output_format,
|
|
)
|
|
|
|
import engine # TTS Engine interface
|
|
from models import ( # Pydantic models
|
|
CustomTTSRequest,
|
|
ErrorResponse,
|
|
UpdateStatusResponse,
|
|
)
|
|
import utils # Utility functions
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
class OpenAISpeechRequest(BaseModel):
    """Request body accepted by the OpenAI-compatible ``/v1/audio/speech`` route.

    The text to speak arrives under the JSON key ``input`` and is exposed on
    the model as ``input_`` through a field alias.
    """

    # Model identifier supplied by the client (required).
    model: str

    # Text to synthesize; populated from the JSON key "input".
    input_: str = Field(..., alias="input")

    # Voice name forwarded to the engine.
    voice: str

    # Output encoding; only these three values pass validation.
    response_format: Literal["wav", "opus", "mp3"] = "wav"

    # Speed value forwarded to the engine.
    speed: float = 1.0

    # Optional integer seed.
    seed: Optional[int] = None
|
|
|
|
|
|
# --- Logging Configuration ---
|
|
# Resolve the log destination and rotation limits from configuration.
log_file_path_obj = get_log_file_path()
log_file_max_size_mb = config_manager.get_int("server.log_file_max_size_mb", 10)
log_backup_count = config_manager.get_int("server.log_file_backup_count", 5)

# The log directory has to exist before the file handler opens its file.
log_file_path_obj.parent.mkdir(parents=True, exist_ok=True)

# Rotating file handler first, console handler second.
_rotating_file_handler = logging.handlers.RotatingFileHandler(
    filename=str(log_file_path_obj),
    maxBytes=log_file_max_size_mb * 1024 * 1024,
    backupCount=log_backup_count,
    encoding="utf-8",
)
_console_handler = logging.StreamHandler()

logging.basicConfig(
    handlers=[_rotating_file_handler, _console_handler],
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)

# Raise the threshold of these two third-party loggers to WARNING.
for _noisy_logger_name in ("uvicorn.access", "watchfiles"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)

logger = logging.getLogger(__name__)

# --- Global Variables & Application Setup ---
# Set once startup has finished (or failed); the browser-opener thread waits on it.
startup_complete_event = threading.Event()
|
|
|
|
|
|
def _delayed_browser_open(host: str, port: int):
    """Open the server's main page in a web browser once startup is signalled.

    Waits up to 30 seconds for ``startup_complete_event``. If the event is not
    set in time, a warning is logged and no browser is opened. Otherwise the
    function sleeps 1.5 seconds and asks ``webbrowser`` to open the root URL.

    Args:
        host: Address the server is bound to. The wildcard addresses
            ``0.0.0.0`` and ``::`` are replaced with ``localhost``; any other
            bare IPv6 literal is wrapped in brackets so the URL is well formed.
        port: TCP port the server listens on.

    Any exception is logged and swallowed: opening a browser is a convenience
    and must never take the server down.
    """
    try:
        # Event.wait() returns False only when the timeout expired unset.
        if not startup_complete_event.wait(timeout=30):
            logger.warning(
                "Server startup did not signal completion within timeout. Browser will not be opened automatically."
            )
            return

        time.sleep(1.5)

        # A wildcard bind address cannot be used as a URL host.
        display_host = "localhost" if host in ("0.0.0.0", "::") else host
        # IPv6 literals must be bracketed inside a URL authority.
        if ":" in display_host and not display_host.startswith("["):
            display_host = f"[{display_host}]"

        browser_url = f"http://{display_host}:{port}/"
        logger.info(f"Attempting to open web browser to: {browser_url}")
        webbrowser.open(browser_url)
    except Exception as e:
        logger.error(f"Failed to open browser automatically: {e}", exc_info=True)
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup work before the app serves requests and log shutdown after.

    Startup creates the output, ``ui`` and model-cache directories, loads the
    TTS model through ``engine.load_model()`` and, when loading succeeds,
    starts a daemon thread that opens the web UI in a browser. A startup
    failure is logged and the application is still handed to the server.

    The generator yields exactly once, outside the startup ``try``. With the
    yield inside ``try/except Exception`` (plus a second yield in the handler),
    an exception thrown in at the yield point would be reported as a startup
    failure and the second yield would make ``asynccontextmanager`` raise
    ``RuntimeError``.
    """
    logger.info("TTS Server: Initializing application...")
    try:
        logger.info(f"Configuration loaded. Log file at: {get_log_file_path()}")

        required_dirs = [
            get_output_path(),
            Path("ui"),
            config_manager.get_path(
                "paths.model_cache", "./model_cache", ensure_absolute=True
            ),
        ]
        for directory in required_dirs:
            directory.mkdir(parents=True, exist_ok=True)

        if engine.load_model():
            logger.info("TTS Model loaded successfully via engine.")
            # Passing args directly avoids a closure over local variables.
            browser_thread = threading.Thread(
                target=_delayed_browser_open,
                args=(get_host(), get_port()),
                daemon=True,
            )
            browser_thread.start()
        else:
            logger.critical(
                "CRITICAL: TTS Model failed to load on startup. Server might not function correctly."
            )

        logger.info("Application startup sequence complete.")
    except Exception as e_startup:
        logger.error(
            f"FATAL ERROR during application startup: {e_startup}", exc_info=True
        )

    # Release the browser-opener thread whether startup succeeded or failed.
    startup_complete_event.set()

    try:
        yield
    finally:
        logger.info("TTS Server: Application shutdown sequence initiated...")
        logger.info("TTS Server: Application shutdown complete.")
|
|
|
|
|
|
# --- FastAPI Application Instance ---
|
|
# The ASGI application object; startup and shutdown are handled by `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title=get_ui_title(),
    description="Text-to-Speech server with advanced UI and API capabilities.",
    version="3.0.0",  # Multi-model support
)

# --- CORS Middleware ---
# NOTE(review): a wildcard origin is combined with allow_credentials=True.
# Confirm that credentialed cross-origin requests are really intended here.
_cors_settings = {
    "allow_origins": ["*", "null"],
    "allow_credentials": True,
    "allow_methods": ["GET", "POST", "OPTIONS"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
|
|
|
|
# --- Static Files and HTML Templates ---
|
|
# UI assets live in the "ui" directory next to this file.
ui_static_path = Path(__file__).parent / "ui"
if not ui_static_path.is_dir():
    logger.warning(
        f"UI static assets directory not found at '{ui_static_path}'. UI may not load correctly."
    )
else:
    app.mount("/ui", StaticFiles(directory=ui_static_path), name="ui_static_assets")

# Requests under '/vendor/*' are served from the 'vendor' folder inside the UI directory.
_vendor_static_path = ui_static_path / "vendor"
if not _vendor_static_path.is_dir():
    logger.warning(
        f"Vendor directory not found at '{ui_static_path}' /vendor. Wavesurfer might not load."
    )
else:
    app.mount(
        "/vendor", StaticFiles(directory=_vendor_static_path), name="vendor_files"
    )
|
|
|
|
|
|
@app.get("/styles.css", include_in_schema=False)
async def get_main_styles():
    """Serve ``styles.css`` from the UI directory; respond 404 when it is missing."""
    stylesheet = ui_static_path / "styles.css"
    if not stylesheet.is_file():
        raise HTTPException(status_code=404, detail="styles.css not found")
    return FileResponse(stylesheet)
|
|
|
|
|
|
@app.get("/script.js", include_in_schema=False)
async def get_main_script():
    """Serve ``script.js`` from the UI directory; respond 404 when it is missing."""
    script_path = ui_static_path / "script.js"
    if not script_path.is_file():
        raise HTTPException(status_code=404, detail="script.js not found")
    return FileResponse(script_path)
|
|
|
|
|
|
# Expose generated audio files under /outputs. A mount failure is logged
# instead of aborting module import.
outputs_static_path = get_output_path(ensure_absolute=True)
try:
    _outputs_static_app = StaticFiles(directory=str(outputs_static_path))
    app.mount("/outputs", _outputs_static_app, name="generated_outputs")
except RuntimeError as e_mount_outputs:
    logger.error(
        f"Failed to mount /outputs directory '{outputs_static_path}': {e_mount_outputs}. "
        "Output files may not be accessible via URL."
    )
|
|
|
|
# templates removed - serving index.html as static file
|
|
|
|
# --- API Endpoints ---
|
|
|
|
|
|
# --- Main UI Route ---
|
|
@app.get("/", response_class=HTMLResponse, include_in_schema=False)
async def get_web_ui(request: Request):
    """Return ``index.html`` from the UI directory as the main web interface.

    Responds with a small 404 HTML page when the file is missing and a 500
    HTML page when serving it raises.
    """
    logger.info("Request received for main UI page ('/').")
    try:
        landing_page = ui_static_path / "index.html"
        if not landing_page.is_file():
            return HTMLResponse(
                "<html><body><h1>Not Found</h1><p>index.html not found.</p></body></html>",
                status_code=404,
            )
        return FileResponse(landing_page, media_type="text/html")
    except Exception as e_render:
        logger.error(f"Error rendering main UI page: {e_render}", exc_info=True)
        error_page = (
            "<html><body><h1>Internal Server Error</h1><p>Could not load the TTS interface. "
            "Please check server logs for more details.</p></body></html>"
        )
        return HTMLResponse(error_page, status_code=500)
|
|
|
|
|
|
# --- API Endpoint for Model Information ---
|
|
@app.get("/api/model-info", tags=["Model Information"])
async def get_model_info_endpoint():
    """Return whatever ``engine.get_model_info()`` reports about the current model.

    Raises:
        HTTPException: 500 when the engine call raises.
    """
    logger.debug("Request received for /api/model-info")
    try:
        model_details = engine.get_model_info()
    except Exception as e:
        logger.error(f"Error getting model info: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to retrieve model information")
    return model_details
|
|
|
|
|
|
@app.get("/api/model-registry", tags=["Model Information"])
async def get_model_registry_endpoint():
    """Return the engine's model registry, used to fill the UI model dropdown."""
    registry = engine.get_model_registry()
    return registry
|
|
|
|
|
|
@app.get("/api/model-status", tags=["Model Information"])
async def get_model_status_endpoint():
    """Return the engine's current download/loading status for model switching."""
    status = engine.get_download_status()
    return status
|
|
|
|
|
|
# --- API Endpoint for Initial UI Data ---
|
|
@app.get("/api/ui/initial-data", tags=["UI Helpers"])
async def get_ui_initial_data():
    """
    Collect everything the UI needs for its first render in one response.

    The payload holds the full configuration, the presets read from
    ``presets.yaml`` in the UI directory (empty when the file is missing or is
    not a YAML list), an empty generation-result placeholder, and the engine's
    model info, model registry and available voices.

    Raises:
        HTTPException: 500 when any step raises.
    """
    logger.info("Request received for /api/ui/initial-data.")
    try:
        ui_config = get_full_config_for_template()

        # Model information for the UI
        current_model = engine.get_model_info()
        registry = engine.get_model_registry()

        presets = []
        presets_path = ui_static_path / "presets.yaml"
        if not presets_path.exists():
            logger.info(
                f"Presets file not found: {presets_path}. No presets will be loaded for initial data."
            )
        else:
            with open(presets_path, "r", encoding="utf-8") as handle:
                parsed = yaml.safe_load(handle)
            if isinstance(parsed, list):
                presets = parsed
            else:
                logger.warning(
                    f"Invalid format in {presets_path}. Expected a list, got {type(parsed)}."
                )

        # No generation has happened yet, so every result field starts empty.
        empty_result = dict.fromkeys(
            ("outputUrl", "filename", "genTime", "submittedVoice")
        )

        return {
            "config": ui_config,
            "presets": presets,
            "initial_gen_result": empty_result,
            "model_info": current_model,
            "model_registry": registry,
            "available_voices": engine.get_available_voices(),
        }
    except Exception as e:
        logger.error(f"Error preparing initial UI data for API: {e}", exc_info=True)
        raise HTTPException(
            status_code=500, detail="Failed to load initial data for UI."
        )
|
|
|
|
|
|
# --- Configuration Management API Endpoints ---
|
|
@app.post("/save_settings", response_model=UpdateStatusResponse, tags=["Configuration"])
async def save_settings_endpoint(request: Request):
    """
    Saves partial configuration updates through ``config_manager.update_and_save``.

    The request body must be a JSON object. The response flags
    ``restart_needed`` when the update contains any of the top-level keys
    ``server``, ``tts_engine``, ``paths`` or ``model``.

    Raises:
        HTTPException: 400 when the body is not valid JSON or not a JSON
            object; 500 when saving fails or an unexpected error occurs.
    """
    logger.info("Request received for /save_settings.")
    try:
        partial_update = await request.json()
        if not isinstance(partial_update, dict):
            raise ValueError("Request body must be a JSON object for /save_settings.")
        logger.debug(f"Received partial config data to save: {partial_update}")

        if not config_manager.update_and_save(partial_update):
            logger.error(
                "Failed to save configuration via config_manager.update_and_save."
            )
            raise HTTPException(
                status_code=500,
                detail="Failed to save configuration file due to an internal error.",
            )

        restart_needed = any(
            key in partial_update
            for key in ["server", "tts_engine", "paths", "model"]
        )
        message = "Settings saved successfully."
        if restart_needed:
            message += " A server restart may be required for some changes to take full effect."
        return UpdateStatusResponse(message=message, restart_needed=restart_needed)
    except HTTPException:
        # Already a well-formed HTTP error; do not let the generic handler
        # below log it a second time and overwrite its detail.
        raise
    except ValueError as ve:
        logger.error(f"Invalid data format for /save_settings: {ve}")
        raise HTTPException(status_code=400, detail=f"Invalid request data: {str(ve)}")
    except Exception as e:
        logger.error(f"Error processing /save_settings request: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error during settings save: {str(e)}",
        )
|
|
|
|
|
|
@app.post(
    "/reset_settings", response_model=UpdateStatusResponse, tags=["Configuration"]
)
async def reset_settings_endpoint():
    """Resets the configuration to defaults through ``config_manager.reset_and_save``.

    Raises:
        HTTPException: 500 when the reset reports failure or raises.
    """
    logger.warning("Request received to reset all configurations to default values.")
    try:
        if not config_manager.reset_and_save():
            logger.error("Failed to reset and save configuration via config_manager.")
            raise HTTPException(
                status_code=500, detail="Failed to reset and save configuration file."
            )

        logger.info("Configuration successfully reset to defaults and saved.")
        return UpdateStatusResponse(
            message="Configuration reset to defaults. Please reload the page. A server restart may be beneficial.",
            restart_needed=True,
        )
    except HTTPException:
        # Keep the specific error raised above instead of re-wrapping it.
        raise
    except Exception as e:
        logger.error(f"Error processing /reset_settings request: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error during settings reset: {str(e)}",
        )
|
|
|
|
|
|
@app.post(
    "/restart_server", response_model=UpdateStatusResponse, tags=["Configuration"]
)
async def restart_server_endpoint():
    """
    Starts a model reload through ``engine.reload_model_async()``.

    The response is sent as soon as the reload has been requested; clients
    follow progress by polling /api/model-status.

    Raises:
        HTTPException: 500 when the reload cannot be initiated.
    """
    logger.info("Request received for /restart_server (Async Model Hot-Swap).")

    try:
        engine.reload_model_async()
        reload_started = UpdateStatusResponse(
            message="Model reload initiated in background. Poll /api/model-status for progress.",
            restart_needed=False,
        )
        return reload_started
    except Exception as e:
        logger.error(f"Error initiating model hot-swap: {e}", exc_info=True)
        failure_detail = f"Failed to initiate model reload: {str(e)}"
        raise HTTPException(status_code=500, detail=failure_detail)
|
|
|
|
|
|
@app.post("/api/cancel-loading", tags=["Configuration"])
async def cancel_loading_endpoint():
    """Ask the engine to cancel model loading and report whether anything was in progress."""
    logger.info("Request received for /api/cancel-loading.")
    if engine.cancel_loading():
        outcome = "Model loading cancellation requested."
    else:
        outcome = "No model loading in progress."
    return {"message": outcome}
|
|
|
|
|
|
@app.post("/api/unload", tags=["Configuration"])
async def unload_model_endpoint():
    """
    Unloads the TTS model through ``engine.unload_model()``.

    The model will need to be reloaded (via /restart_server) before TTS requests can be processed.

    Raises:
        HTTPException: 500 when the engine reports failure or raises.
    """
    logger.info("Request received for /api/unload (Model Unload).")
    try:
        if not engine.unload_model():
            raise HTTPException(status_code=500, detail="Failed to unload model.")
        return {"message": "Model unloaded successfully. Resources released."}
    except HTTPException:
        # Pass the specific failure through; the generic handler below would
        # replace its detail with str(exception).
        raise
    except Exception as e:
        logger.error(f"Error during model unload: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# --- TTS Generation Endpoint ---
|
|
|
|
|
|
def _merge_audio_segments(
    segments: List[np.ndarray], sample_rate: int, silence_duration_ms: int = 200
) -> np.ndarray:
    """Join per-chunk audio into one array with faded edges and silent gaps.

    A single segment is returned untouched. With several segments, each
    boundary gets a 10 ms linear fade-out on the earlier segment, a block of
    float32 zeros lasting ``silence_duration_ms``, and a 10 ms linear fade-in
    on the later segment. A boundary is left unfaded when either neighbour is
    shorter than the fade.

    The fades are applied to copies, so arrays owned by the caller (and by
    the engine that produced them) are never modified.
    """
    if len(segments) == 1:
        return segments[0]

    silence = np.zeros(
        int(silence_duration_ms / 1000 * sample_rate), dtype=np.float32
    )
    fade_len = int(0.01 * sample_rate)  # 10ms crossfade

    faded = [segment.copy() for segment in segments]
    # A zero-length fade would turn `[-0:]` into a whole-array slice, so skip it.
    if fade_len > 0:
        fade_out = np.linspace(1, 0, fade_len)
        fade_in = np.linspace(0, 1, fade_len)
        for earlier, later in zip(faded, faded[1:]):
            if len(earlier) >= fade_len and len(later) >= fade_len:
                earlier[-fade_len:] *= fade_out
                later[:fade_len] *= fade_in

    pieces = [faded[0]]
    for segment in faded[1:]:
        pieces.append(silence)
        pieces.append(segment)
    return np.concatenate(pieces)


@app.post(
    "/tts",
    tags=["TTS Generation"],
    summary="Generate speech with custom parameters",
    responses={
        200: {
            "content": {"audio/wav": {}, "audio/opus": {}},
            "description": "Successful audio generation.",
        },
        400: {
            "model": ErrorResponse,
            "description": "Invalid request parameters or input.",
        },
        500: {
            "model": ErrorResponse,
            "description": "Internal server error during generation.",
        },
        503: {
            "model": ErrorResponse,
            "description": "TTS engine not available or model not loaded.",
        },
    },
)
async def custom_tts_endpoint(
    request: CustomTTSRequest, background_tasks: BackgroundTasks
):
    """
    Generates speech audio from text using specified parameters.

    The text is optionally split into sentence-based chunks, each chunk is
    synthesized by the engine, the chunks are joined with short fades and
    200 ms of silence, and the result is encoded and returned as an
    attachment stream.

    Raises:
        HTTPException: 503 when no model is loaded; 400 when splitting yields
            no chunks; 500 when synthesis, merging or encoding fails.
    """
    perf_monitor = utils.PerformanceMonitor(
        enabled=config_manager.get_bool("server.enable_performance_monitor", False)
    )
    perf_monitor.record("TTS request received")

    if not engine.MODEL_LOADED:
        logger.error("TTS request failed: Model not loaded.")
        raise HTTPException(
            status_code=503,
            detail="TTS engine model is not currently loaded or available.",
        )

    logger.info(
        f"Received /tts request: voice='{request.voice}', format='{request.output_format}'"
    )
    logger.debug(
        f"TTS params: speed={request.speed}, split={request.split_text}, chunk_size={request.chunk_size}"
    )
    logger.debug(f"Input text (first 100 chars): '{request.text[:100]}...'")

    perf_monitor.record("Parameters resolved")

    all_audio_segments_np: List[np.ndarray] = []
    final_output_sample_rate = get_audio_sample_rate()
    engine_output_sample_rate: Optional[int] = None

    # Resolve the chunk size once so the split threshold and the splitter
    # always agree; a missing or zero value falls back to 120.
    chunk_size_to_use = request.chunk_size or 120
    if request.split_text and len(request.text) > chunk_size_to_use * 1.5:
        logger.info(f"Splitting text into chunks of size ~{chunk_size_to_use}.")
        text_chunks = utils.chunk_text_by_sentences(request.text, chunk_size_to_use)
        perf_monitor.record(f"Text split into {len(text_chunks)} chunks")
    else:
        text_chunks = [request.text]
        logger.info(
            "Processing text as a single chunk (splitting not enabled or text too short)."
        )

    if not text_chunks:
        raise HTTPException(
            status_code=400, detail="Text processing resulted in no usable chunks."
        )

    # The speed is the same for every chunk, so resolve it before the loop.
    speed_to_use = (
        request.speed if request.speed is not None else get_gen_default_speed()
    )

    for i, chunk in enumerate(text_chunks):
        logger.info(f"Synthesizing chunk {i+1}/{len(text_chunks)}...")
        try:
            chunk_audio_np, chunk_sr_from_engine = engine.synthesize(
                text=chunk,
                voice=request.voice,
                speed=speed_to_use,
            )
            perf_monitor.record(f"Engine synthesized chunk {i+1}")

            if chunk_audio_np is None or chunk_sr_from_engine is None:
                error_detail = f"TTS engine failed to synthesize audio for chunk {i+1}."
                logger.error(error_detail)
                raise HTTPException(status_code=500, detail=error_detail)

            if engine_output_sample_rate is None:
                engine_output_sample_rate = chunk_sr_from_engine
            elif engine_output_sample_rate != chunk_sr_from_engine:
                logger.warning(
                    f"Inconsistent sample rate from engine: chunk {i+1} ({chunk_sr_from_engine}Hz) "
                    f"differs from previous ({engine_output_sample_rate}Hz). Using first chunk's SR."
                )

            # Speed is handled by the engine, so no speed post-processing happens here.
            all_audio_segments_np.append(chunk_audio_np)

        except HTTPException:
            raise
        except Exception as e_chunk:
            error_detail = f"Error processing audio chunk {i+1}: {str(e_chunk)}"
            logger.error(error_detail, exc_info=True)
            raise HTTPException(status_code=500, detail=error_detail)

    if not all_audio_segments_np:
        logger.error("No audio segments were successfully generated.")
        raise HTTPException(
            status_code=500, detail="Audio generation resulted in no output."
        )

    if engine_output_sample_rate is None:
        logger.error("Engine output sample rate could not be determined.")
        raise HTTPException(
            status_code=500, detail="Failed to determine engine sample rate."
        )

    silence_duration_ms = 200  # silence between chunks
    try:
        final_audio_np = _merge_audio_segments(
            all_audio_segments_np, engine_output_sample_rate, silence_duration_ms
        )
        if len(all_audio_segments_np) > 1:
            logger.debug(
                f"Added {silence_duration_ms}ms silence between {len(all_audio_segments_np)} chunks"
            )
        perf_monitor.record("All audio chunks processed and concatenated")
    except (ValueError, TypeError) as e_concat:
        # ValueError: mismatched shapes; TypeError: a dtype the fades cannot
        # be applied to in place.
        logger.error(f"Audio concatenation failed: {e_concat}", exc_info=True)
        for idx, seg in enumerate(all_audio_segments_np):
            logger.error(f"Segment {idx} shape: {seg.shape}, dtype: {seg.dtype}")
        raise HTTPException(
            status_code=500, detail=f"Audio concatenation error: {e_concat}"
        )

    output_format_str = request.output_format or get_audio_output_format()

    encoded_audio_bytes = utils.encode_audio(
        audio_array=final_audio_np,
        sample_rate=engine_output_sample_rate,
        output_format=output_format_str,
        target_sample_rate=final_output_sample_rate,
    )
    perf_monitor.record(
        f"Final audio encoded to {output_format_str} (target SR: {final_output_sample_rate}Hz from engine SR: {engine_output_sample_rate}Hz)"
    )

    # Anything under 100 bytes is treated as a failed encode.
    if encoded_audio_bytes is None or len(encoded_audio_bytes) < 100:
        logger.error(
            f"Failed to encode final audio to format: {output_format_str} or output is too small ({len(encoded_audio_bytes or b'')} bytes)."
        )
        raise HTTPException(
            status_code=500,
            detail=f"Failed to encode audio to {output_format_str} or generated invalid audio.",
        )

    media_type = f"audio/{output_format_str}"
    timestamp_str = time.strftime("%Y%m%d_%H%M%S")
    suggested_filename_base = f"tts_output_{timestamp_str}"
    download_filename = utils.sanitize_filename(
        f"{suggested_filename_base}.{output_format_str}"
    )
    headers = {"Content-Disposition": f'attachment; filename="{download_filename}"'}

    logger.info(
        f"Successfully generated audio: {download_filename}, {len(encoded_audio_bytes)} bytes, type {media_type}."
    )
    logger.debug(perf_monitor.report())

    return StreamingResponse(
        io.BytesIO(encoded_audio_bytes), media_type=media_type, headers=headers
    )
|
|
|
|
|
|
@app.post("/v1/audio/speech", tags=["OpenAI Compatible"])
async def openai_speech_endpoint(request: OpenAISpeechRequest):
    """Synthesize ``request.input_`` in one engine call and stream the encoded audio.

    Only ``input_``, ``voice``, ``speed`` and ``response_format`` are read from
    the request; ``model`` and ``seed`` are accepted but not used here.

    Raises:
        HTTPException: 503 when no model is loaded; 500 when synthesis or
            encoding fails or an unexpected error occurs.
    """
    if not engine.MODEL_LOADED:
        raise HTTPException(
            status_code=503,
            detail="TTS engine model is not currently loaded or available.",
        )

    try:
        audio_np, sr = engine.synthesize(
            text=request.input_,
            voice=request.voice,
            speed=request.speed,
        )

        if audio_np is None or sr is None:
            raise HTTPException(
                status_code=500, detail="TTS engine failed to synthesize audio."
            )

        # Drop length-1 axes from a 2-D result before encoding.
        if audio_np.ndim == 2:
            audio_np = audio_np.squeeze()

        encoded_audio = utils.encode_audio(
            audio_array=audio_np,
            sample_rate=sr,
            output_format=request.response_format,
            target_sample_rate=get_audio_sample_rate(),
        )

        if encoded_audio is None:
            raise HTTPException(status_code=500, detail="Failed to encode audio.")

        media_type = f"audio/{request.response_format}"
        return StreamingResponse(io.BytesIO(encoded_audio), media_type=media_type)

    except HTTPException:
        # The errors raised above already carry the right detail; the generic
        # handler below would log them again and replace the detail.
        raise
    except Exception as e:
        logger.error(f"Error in openai_speech_endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
|
|
|
|
|
# --- Main Execution ---
|
|
if __name__ == "__main__":
    # Direct launch: read the bind address from config, announce the URLs,
    # then hand control to uvicorn.
    server_host = get_host()
    server_port = get_port()

    logger.info(f"Starting TTS Server directly on http://{server_host}:{server_port}")
    logger.info(
        f"API documentation will be available at http://{server_host}:{server_port}/docs"
    )
    logger.info(f"Web UI will be available at http://{server_host}:{server_port}/")

    import uvicorn

    # Single worker process, auto-reload disabled.
    uvicorn_options = {
        "host": server_host,
        "port": server_port,
        "log_level": "info",
        "workers": 1,
        "reload": False,
    }
    uvicorn.run("server:app", **uvicorn_options)
|