update voxcpm2
This commit is contained in:
@@ -2,14 +2,15 @@ import os
|
||||
import sys
|
||||
import numpy as np
|
||||
import torch
|
||||
import gradio as gr
|
||||
import spaces
|
||||
import gradio as gr
|
||||
import spaces # noqa: F401
|
||||
from typing import Optional, Tuple
|
||||
from funasr import AutoModel
|
||||
from pathlib import Path
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
if os.environ.get("HF_REPO_ID", "").strip() == "":
|
||||
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM1.5"
|
||||
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM2"
|
||||
|
||||
import voxcpm
|
||||
|
||||
@@ -24,13 +25,13 @@ class VoxCPMDemo:
|
||||
self.asr_model: Optional[AutoModel] = AutoModel(
|
||||
model=self.asr_model_id,
|
||||
disable_update=True,
|
||||
log_level='DEBUG',
|
||||
log_level="DEBUG",
|
||||
device="cuda:0" if self.device == "cuda" else "cpu",
|
||||
)
|
||||
|
||||
# TTS model (lazy init)
|
||||
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
|
||||
self.default_local_model_dir = "./models/VoxCPM1.5"
|
||||
self.default_local_model_dir = "/Users/xinliu/Downloads/VoxCPM2-0.5B-newaudiovae-6hz-0316"
|
||||
|
||||
# ---------- Model helpers ----------
|
||||
def _resolve_model_dir(self) -> str:
|
||||
@@ -49,6 +50,7 @@ class VoxCPMDemo:
|
||||
if not os.path.isdir(target_dir):
|
||||
try:
|
||||
from huggingface_hub import snapshot_download # type: ignore
|
||||
|
||||
os.makedirs(target_dir, exist_ok=True)
|
||||
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
|
||||
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
|
||||
@@ -64,7 +66,7 @@ class VoxCPMDemo:
|
||||
print("Model not loaded, initializing...", file=sys.stderr)
|
||||
model_dir = self._resolve_model_dir()
|
||||
print(f"Using model dir: {model_dir}", file=sys.stderr)
|
||||
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
|
||||
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir, optimize=False)
|
||||
print("Model loaded successfully.", file=sys.stderr)
|
||||
return self.voxcpm_model
|
||||
|
||||
@@ -73,21 +75,24 @@ class VoxCPMDemo:
|
||||
if prompt_wav is None:
|
||||
return ""
|
||||
res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
|
||||
text = res[0]["text"].split('|>')[-1]
|
||||
text = res[0]["text"].split("|>")[-1]
|
||||
return text
|
||||
|
||||
def generate_tts_audio(
|
||||
self,
|
||||
text_input: str,
|
||||
prompt_wav_path_input: Optional[str] = None,
|
||||
prompt_text_input: Optional[str] = None,
|
||||
control_instruction: str = "",
|
||||
reference_wav_path_input: Optional[str] = None,
|
||||
cfg_value_input: float = 2.0,
|
||||
inference_timesteps_input: int = 10,
|
||||
do_normalize: bool = True,
|
||||
denoise: bool = True,
|
||||
) -> Tuple[int, np.ndarray]:
|
||||
"""
|
||||
Generate speech from text using VoxCPM; optional reference audio for voice style guidance.
|
||||
Generate speech from text using VoxCPM.
|
||||
- If reference_wav provided: Prompt isolation mode (voice cloning)
|
||||
- If no reference_wav: Voice design mode (use control_instruction to describe voice)
|
||||
|
||||
Returns (sample_rate, waveform_numpy)
|
||||
"""
|
||||
current_model = self.get_or_load_voxcpm()
|
||||
@@ -96,14 +101,25 @@ class VoxCPMDemo:
|
||||
if len(text) == 0:
|
||||
raise ValueError("Please input text to synthesize.")
|
||||
|
||||
prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
|
||||
prompt_text = prompt_text_input if prompt_text_input else None
|
||||
# 处理 control instruction
|
||||
control = (control_instruction or "").strip()
|
||||
if control:
|
||||
final_text = f"({control}){text}"
|
||||
else:
|
||||
final_text = text
|
||||
|
||||
print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr)
|
||||
reference_wav_path = reference_wav_path_input if reference_wav_path_input else None
|
||||
|
||||
# 判断模式
|
||||
if reference_wav_path:
|
||||
print(f"[Prompt Isolation Mode] reference_wav: {reference_wav_path}", file=sys.stderr)
|
||||
else:
|
||||
print(f"[Voice Design Mode] control: {control[:50] if control else 'None'}...", file=sys.stderr)
|
||||
|
||||
print(f"Generating audio for text: '{final_text[:80]}...'", file=sys.stderr)
|
||||
wav = current_model.generate(
|
||||
text=text,
|
||||
prompt_text=prompt_text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
text=final_text,
|
||||
reference_wav_path=reference_wav_path,
|
||||
cfg_value=float(cfg_value_input),
|
||||
inference_timesteps=int(inference_timesteps_input),
|
||||
normalize=do_normalize,
|
||||
@@ -114,46 +130,53 @@ class VoxCPMDemo:
|
||||
|
||||
# ---------- UI Builders ----------
|
||||
|
||||
THEME = gr.themes.Soft(
|
||||
primary_hue="blue",
|
||||
secondary_hue="gray",
|
||||
neutral_hue="slate",
|
||||
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
|
||||
)
|
||||
|
||||
CSS = """
|
||||
.logo-container {
|
||||
text-align: center;
|
||||
margin: 0.5rem 0 1rem 0;
|
||||
}
|
||||
.logo-container img {
|
||||
height: 80px;
|
||||
width: auto;
|
||||
max-width: 200px;
|
||||
display: inline-block;
|
||||
}
|
||||
/* Bold accordion labels */
|
||||
#acc_quick > .label-wrap,
|
||||
#acc_tips > .label-wrap,
|
||||
#acc_quick > .label-wrap > span,
|
||||
#acc_tips > .label-wrap > span,
|
||||
#acc_quick summary,
|
||||
#acc_tips summary {
|
||||
font-weight: 600 !important;
|
||||
font-size: 1.1em !important;
|
||||
}
|
||||
/* Bold labels for specific checkboxes */
|
||||
#chk_denoise label,
|
||||
#chk_denoise span,
|
||||
#chk_normalize label,
|
||||
#chk_normalize span {
|
||||
font-weight: 600;
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def create_demo_interface(demo: VoxCPMDemo):
|
||||
"""Build the Gradio UI for VoxCPM demo."""
|
||||
# static assets (logo path)
|
||||
gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
|
||||
gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
|
||||
|
||||
with gr.Blocks(
|
||||
theme=gr.themes.Soft(
|
||||
primary_hue="blue",
|
||||
secondary_hue="gray",
|
||||
neutral_hue="slate",
|
||||
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
|
||||
),
|
||||
css="""
|
||||
.logo-container {
|
||||
text-align: center;
|
||||
margin: 0.5rem 0 1rem 0;
|
||||
}
|
||||
.logo-container img {
|
||||
height: 80px;
|
||||
width: auto;
|
||||
max-width: 200px;
|
||||
display: inline-block;
|
||||
}
|
||||
/* Bold accordion labels */
|
||||
#acc_quick details > summary,
|
||||
#acc_tips details > summary {
|
||||
font-weight: 600 !important;
|
||||
font-size: 1.1em !important;
|
||||
}
|
||||
/* Bold labels for specific checkboxes */
|
||||
#chk_denoise label,
|
||||
#chk_denoise span,
|
||||
#chk_normalize label,
|
||||
#chk_normalize span {
|
||||
font-weight: 600;
|
||||
}
|
||||
"""
|
||||
) as interface:
|
||||
# Header logo
|
||||
gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>')
|
||||
with gr.Blocks() as interface:
|
||||
gr.HTML(
|
||||
'<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>',
|
||||
padding=True,
|
||||
)
|
||||
|
||||
# Quick Start
|
||||
with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"):
|
||||
@@ -200,34 +223,56 @@ def create_demo_interface(demo: VoxCPMDemo):
|
||||
# Main controls
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
prompt_wav = gr.Audio(
|
||||
sources=["upload", 'microphone'],
|
||||
# 1. Reference Audio
|
||||
# gr.Markdown("### 🎤 Reference Audio (Optional)")
|
||||
# gr.Markdown("*提供参考音频进行音色克隆;不提供则使用 Voice Design 模式*")
|
||||
reference_wav = gr.Audio(
|
||||
sources=["upload", "microphone"],
|
||||
type="filepath",
|
||||
label="Prompt Speech (Optional, or let VoxCPM improvise)",
|
||||
value="./examples/example.wav",
|
||||
label="Reference Audio (Optional)",
|
||||
)
|
||||
DoDenoisePromptAudio = gr.Checkbox(
|
||||
value=False,
|
||||
label="Prompt Speech Enhancement",
|
||||
label="Reference Audio Enhancement",
|
||||
elem_id="chk_denoise",
|
||||
info="We use ZipEnhancer model to denoise the prompt audio."
|
||||
info="Use ZipEnhancer to denoise the reference audio",
|
||||
)
|
||||
with gr.Row():
|
||||
prompt_text = gr.Textbox(
|
||||
value="Just by listening a few minutes a day, you'll be able to eliminate negative thoughts by conditioning your mind to be more positive.",
|
||||
label="Prompt Text",
|
||||
placeholder="Please enter the prompt text. Automatic recognition is supported, and you can correct the results yourself..."
|
||||
)
|
||||
run_btn = gr.Button("Generate Speech", variant="primary")
|
||||
|
||||
# 2. Control Instruction
|
||||
# gr.Markdown("### 🎛️ Control Instruction (Optional)")
|
||||
# gr.Markdown("*描述声音风格、情感等,格式:`(instruction) text`*")
|
||||
control_instruction = gr.Textbox(
|
||||
value="",
|
||||
label="Control Instruction",
|
||||
placeholder="*描述声音风格、情感等,格式:`(instruction) text`,例如:年轻女性,温柔甜美 / 悲伤地说 / an excited young man*",
|
||||
lines=2,
|
||||
)
|
||||
|
||||
# 3. Target Text
|
||||
# gr.Markdown("### 📝 Target Text")
|
||||
text = gr.Textbox(
|
||||
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
|
||||
label="Target Text",
|
||||
lines=3,
|
||||
)
|
||||
DoNormalizeText = gr.Checkbox(
|
||||
value=False,
|
||||
label="Text Normalization",
|
||||
elem_id="chk_normalize",
|
||||
info="Use wetext library to normalize the input text",
|
||||
)
|
||||
|
||||
run_btn = gr.Button("🔊 Generate Speech", variant="primary", size="lg")
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown("### ⚙️ Generation Settings")
|
||||
cfg_value = gr.Slider(
|
||||
minimum=1.0,
|
||||
maximum=3.0,
|
||||
value=2.0,
|
||||
step=0.1,
|
||||
label="CFG Value (Guidance Scale)",
|
||||
info="Higher values increase adherence to prompt, lower values allow more creativity"
|
||||
info="Higher = more adherence to prompt; Lower = more creativity",
|
||||
)
|
||||
inference_timesteps = gr.Slider(
|
||||
minimum=4,
|
||||
@@ -235,41 +280,55 @@ def create_demo_interface(demo: VoxCPMDemo):
|
||||
value=10,
|
||||
step=1,
|
||||
label="Inference Timesteps",
|
||||
info="Number of inference timesteps for generation (higher values may improve quality but slower)"
|
||||
info="Higher = better quality but slower",
|
||||
)
|
||||
with gr.Row():
|
||||
text = gr.Textbox(
|
||||
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
|
||||
label="Target Text",
|
||||
)
|
||||
with gr.Row():
|
||||
DoNormalizeText = gr.Checkbox(
|
||||
value=False,
|
||||
label="Text Normalization",
|
||||
elem_id="chk_normalize",
|
||||
info="We use wetext library to normalize the input text."
|
||||
)
|
||||
audio_output = gr.Audio(label="Output Audio")
|
||||
|
||||
gr.Markdown("### 🔈 Output")
|
||||
audio_output = gr.Audio(label="Generated Audio")
|
||||
|
||||
gr.Markdown("""
|
||||
---
|
||||
**模式说明 / Mode Info:**
|
||||
- **有 Reference Audio** → Prompt 隔离模式(音色克隆)
|
||||
- **无 Reference Audio** → Voice Design 模式(用 Control Instruction 描述声音)
|
||||
|
||||
**Control Instruction 示例:**
|
||||
- `年轻女性,温柔甜美`
|
||||
- `悲伤地说`
|
||||
- `an excited young man`
|
||||
""")
|
||||
|
||||
# Wiring
|
||||
run_btn.click(
|
||||
fn=demo.generate_tts_audio,
|
||||
inputs=[text, prompt_wav, prompt_text, cfg_value, inference_timesteps, DoNormalizeText, DoDenoisePromptAudio],
|
||||
inputs=[
|
||||
text,
|
||||
control_instruction,
|
||||
reference_wav,
|
||||
cfg_value,
|
||||
inference_timesteps,
|
||||
DoNormalizeText,
|
||||
DoDenoisePromptAudio,
|
||||
],
|
||||
outputs=[audio_output],
|
||||
show_progress=True,
|
||||
api_name="generate",
|
||||
)
|
||||
prompt_wav.change(fn=demo.prompt_wav_recognition, inputs=[prompt_wav], outputs=[prompt_text])
|
||||
|
||||
return interface
|
||||
|
||||
|
||||
def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error: bool = True):
|
||||
def run_demo(server_name: str = "0.0.0.0", server_port: int = 7869, show_error: bool = True):
|
||||
demo = VoxCPMDemo()
|
||||
interface = create_demo_interface(demo)
|
||||
# Recommended to enable queue on Spaces for better throughput
|
||||
interface.queue(max_size=10, default_concurrency_limit=1).launch(server_name=server_name, server_port=server_port, show_error=show_error)
|
||||
interface.queue(max_size=10, default_concurrency_limit=1).launch(
|
||||
server_name=server_name,
|
||||
server_port=server_port,
|
||||
show_error=show_error,
|
||||
theme=THEME,
|
||||
css=CSS,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_demo()
|
||||
run_demo()
|
||||
|
||||
Reference in New Issue
Block a user