diff --git a/README.md b/README.md index 4486e4c..9f7ec49 100644 --- a/README.md +++ b/README.md @@ -126,47 +126,72 @@ print("saved: output_streaming.wav") After installation, the entry point is `voxcpm` (or use `python -m voxcpm.cli`). ```bash -# 1) Direct synthesis (single text) -voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." --output out.wav +# 1) Voice design (VoxCPM2-first) +voxcpm design \ + --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \ + --output out.wav -# 2) Voice cloning (reference audio + transcript) -voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \ +# 2) Voice design with control instruction +voxcpm design \ + --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \ + --control "Young female voice, warm and gentle, slightly smiling" \ + --output out.wav + +# 3) Voice cloning (reference audio only, VoxCPM2) +voxcpm clone \ + --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \ + --reference-audio path/to/voice.wav \ + --output out.wav + +# 4) Hi-Fi / advanced cloning (prompt audio + transcript) +voxcpm clone \ + --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \ --prompt-audio path/to/voice.wav \ --prompt-text "reference transcript" \ - --output out.wav \ - # --denoise + --output out.wav -# (Optinal) Voice cloning (reference audio + transcript file) -voxcpm --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \ +# 5) Prompt transcript from file +voxcpm clone \ + --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \ --prompt-audio path/to/voice.wav \ --prompt-file "/path/to/text-file" \ - --output out.wav \ - # --denoise + --output out.wav -# 3) Batch processing (one text per line) -voxcpm --input examples/input.txt --output-dir outs -# (optional) Batch + cloning -voxcpm --input examples/input.txt --output-dir outs \ +# 6) Advanced cloning: prompt + reference together +voxcpm clone \ + --text "VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech." \ --prompt-audio path/to/voice.wav \ --prompt-text "reference transcript" \ - # --denoise + --reference-audio path/to/voice.wav \ + --output out.wav \ + --denoise -# 4) Inference parameters (quality/speed) -voxcpm --text "..." --output out.wav \ +# 7) Batch processing (one text per line) +voxcpm batch --input examples/input.txt --output-dir outs + +# 8) Batch + cloning +voxcpm batch --input examples/input.txt --output-dir outs \ + --reference-audio path/to/voice.wav + +# 9) Inference parameters (quality/speed) +voxcpm design --text "..." --output out.wav \ --cfg-value 2.0 --inference-timesteps 10 --normalize -# 5) Model loading +# 10) Model loading # Prefer local path -voxcpm --text "..." --output out.wav --model-path /path/to/VoxCPM_model_dir +voxcpm design --text "..." --output out.wav --model-path /path/to/VoxCPM_model_dir # Or from Hugging Face (auto download/cache) -voxcpm --text "..." --output out.wav \ - --hf-model-id openbmb/VoxCPM1.5 --cache-dir ~/.cache/huggingface --local-files-only +voxcpm design --text "..." --output out.wav \ + --hf-model-id openbmb/VoxCPM2 --cache-dir ~/.cache/huggingface --local-files-only -# 6) Denoiser control -voxcpm --text "..." --output out.wav \ +# 11) Denoiser control +voxcpm clone --text "..." --output out.wav --reference-audio path/to/voice.wav \ --no-denoiser --zipenhancer-path iic/speech_zipenhancer_ans_multiloss_16k_base -# 7) Help +# 12) Legacy root arguments still work but are deprecated +voxcpm --text "..." --output out.wav + +# 13) Help voxcpm --help python -m voxcpm.cli --help ``` diff --git a/app.py b/app.py index 2ef1022..598e797 100644 --- a/app.py +++ b/app.py @@ -1,9 +1,9 @@ import os import sys +import logging import numpy as np import torch import gradio as gr -import spaces # noqa: F401 from typing import Optional, Tuple from funasr import AutoModel from pathlib import Path @@ -14,130 +14,150 @@ if os.environ.get("HF_REPO_ID", "").strip() == "": import voxcpm +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) +logger = logging.getLogger(__name__) -class VoxCPMDemo: - def __init__(self) -> None: - self.device = "cuda" if torch.cuda.is_available() else "cpu" - print(f"🚀 Running on device: {self.device}", file=sys.stderr) +# ---------- Inline i18n (en + zh-CN only) ---------- - # ASR model for prompt text recognition - self.asr_model_id = "iic/SenseVoiceSmall" - self.asr_model: Optional[AutoModel] = AutoModel( - model=self.asr_model_id, - disable_update=True, - log_level="DEBUG", - device="cuda:0" if self.device == "cuda" else "cpu", - ) - - # TTS model (lazy init) - self.voxcpm_model: Optional[voxcpm.VoxCPM] = None - self.default_local_model_dir = "/Users/xinliu/Downloads/VoxCPM2-0.5B-newaudiovae-6hz-0316" - - # ---------- Model helpers ---------- - def _resolve_model_dir(self) -> str: - """ - Resolve model directory: - 1) Use local checkpoint directory if exists - 2) If HF_REPO_ID env is set, download into models/{repo} - 3) Fallback to 'models' - """ - if os.path.isdir(self.default_local_model_dir): - return self.default_local_model_dir - - repo_id = os.environ.get("HF_REPO_ID", "").strip() - if len(repo_id) > 0: - target_dir = os.path.join("models", repo_id.replace("/", "__")) - if not os.path.isdir(target_dir): - try: - from huggingface_hub import snapshot_download # type: ignore - - os.makedirs(target_dir, exist_ok=True) - print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr) - snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False) - except Exception as e: - print(f"Warning: HF download failed: {e}. Falling back to 'data'.", file=sys.stderr) - return "models" - return target_dir - return "models" - - def get_or_load_voxcpm(self) -> voxcpm.VoxCPM: - if self.voxcpm_model is not None: - return self.voxcpm_model - print("Model not loaded, initializing...", file=sys.stderr) - model_dir = self._resolve_model_dir() - print(f"Using model dir: {model_dir}", file=sys.stderr) - self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir, optimize=False) - print("Model loaded successfully.", file=sys.stderr) - return self.voxcpm_model - - # ---------- Functional endpoints ---------- - def prompt_wav_recognition(self, prompt_wav: Optional[str]) -> str: - if prompt_wav is None: - return "" - res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True) - text = res[0]["text"].split("|>")[-1] - return text - - def generate_tts_audio( - self, - text_input: str, - control_instruction: str = "", - reference_wav_path_input: Optional[str] = None, - cfg_value_input: float = 2.0, - inference_timesteps_input: int = 10, - do_normalize: bool = True, - denoise: bool = True, - ) -> Tuple[int, np.ndarray]: - """ - Generate speech from text using VoxCPM. - - If reference_wav provided: Prompt isolation mode (voice cloning) - - If no reference_wav: Voice design mode (use control_instruction to describe voice) - - Returns (sample_rate, waveform_numpy) - """ - current_model = self.get_or_load_voxcpm() - - text = (text_input or "").strip() - if len(text) == 0: - raise ValueError("Please input text to synthesize.") - - # 处理 control instruction - control = (control_instruction or "").strip() - if control: - final_text = f"({control}){text}" - else: - final_text = text - - reference_wav_path = reference_wav_path_input if reference_wav_path_input else None - - # 判断模式 - if reference_wav_path: - print(f"[Prompt Isolation Mode] reference_wav: {reference_wav_path}", file=sys.stderr) - else: - print(f"[Voice Design Mode] control: {control[:50] if control else 'None'}...", file=sys.stderr) - - print(f"Generating audio for text: '{final_text[:80]}...'", file=sys.stderr) - wav = current_model.generate( - text=final_text, - reference_wav_path=reference_wav_path, - cfg_value=float(cfg_value_input), - inference_timesteps=int(inference_timesteps_input), - normalize=do_normalize, - denoise=denoise, - ) - return (current_model.tts_model.sample_rate, wav) - - -# ---------- UI Builders ---------- - -THEME = gr.themes.Soft( - primary_hue="blue", - secondary_hue="gray", - neutral_hue="slate", - font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"], +_USAGE_INSTRUCTIONS_EN = ( + "**Usage Instructions:**\n\n" + "🎨 **Voice Design** — Create a voice from scratch \n" + "No reference audio needed. Simply describe the desired gender, tone, and emotion " + "in Control Instruction, and VoxCPM will generate a unique voice for you.\n\n" + "🎛️ **Controllable Voice Cloning** — Clone with style control \n" + "Upload reference audio and use Control Instruction to guide speed, emotion, style, and more.\n\n" + "🎙️ **Hi-Fi Cloning** — Maximum voice similarity \n" + "For the best cloning quality, enable and provide the reference audio transcript " + "to reproduce the original voice as closely as possible." ) -CSS = """ +_EXAMPLES_FOOTER_EN = ( + "---\n" + "**Voice Description Examples:** \n" + "You can describe it like this: \n" + "【Example 1: Melancholic/Tsundere Female】 \n" + 'Control Instruction: "A young beautiful girl with a sweet voice, ' + 'tsundere tone, slow speaking pace, and a touch of sadness." \n' + 'Target Text: "I never asked you to stay... It\'s not like I care or anything. ' + 'But... why does it still hurt so much now that you\'re gone?" \n\n' + "【Example 2: Lazy/Casual Male】 \n" + 'Control Instruction: "Lazy and drawling male voice, nasal, ' + 'very relaxed and casual." \n' + 'Target Text: "Dude, did you see that set? The waves out there are totally gnarly today, bro. ' + "Just catching barrels all morning. It's like, totally righteous, you know what I mean?\"" +) + +_USAGE_INSTRUCTIONS_ZH = ( + "**使用说明:**\n\n" + "🎨 **Voice Design — 声音定制** \n" + "无需上传参考音频,只需在 Control Instruction 中描述你想要的性别、音色和情绪," + "VoxCPM 即可凭空为你生成专属音色。\n\n" + "🎛️ **Controllable Voice Cloning — 可控音色克隆** \n" + "支持上传参考音频,并可以给instruction文本来指导控制语速、情绪、风格等表现。\n\n" + "🎙️ **Hi-Fi Cloning — 高保真克隆** \n" + "追求最佳克隆效果,启用并上传参考音频文本来最大程度克隆原始音色。\n\n" +) + +_EXAMPLES_FOOTER_ZH = ( + "---\n" + "**声音描述示例:** \n" + "你可以这样输入(中英文均可): \n" + "【示例1:深宫太后】 \n" + '`Control Instruction`: `"中老年女性,声音低沉阴冷,语速慢而有力,' + '每个字都像是深思熟虑后说出,带有深不可测的城府和威胁感。"` \n' + '`Target Text`: `"哀家在这深宫待了四十年,什么风浪没见过?你以为瞒得过哀家?"` \n\n' + "【示例2:暴躁男声】 \n" + '`Control Instruction`: `"暴躁的中年男声,语速较快,充满无奈和愤怒"` \n' + '`Target Text`: `"踩离合!踩刹车啊!你往哪儿开呢?前面是树你看不见吗?' + '我教了你八百遍了,打死方向盘!你是不是想把车给我开到沟里去?"`\n\n' + "💡 **方言生成特别说明:** \n" + '当前版本若要生成纯正的方言,请务必在"Target Text"中直接输入方言专属的词汇和表达,' + "并配合方言的音色描述。 \n\n" + "【示例一:广东话】 \n" + '`Control Instruction`: `"广东话,中年男性,语气平淡"` \n' + "✅ 正确的 `Target Text`(使用粤语表达):" + '`"伙計,唔該一個A餐,凍奶茶少甜!"` \n' + "❌ 错误的 `Target Text`(使用普通话):" + '`"伙计,麻烦来一个A餐,冻奶茶少甜!"` \n\n' + "【示例二:河南话】 \n" + '`Control Instruction`: `"河南话,接地气的大叔"` \n' + "✅ 正确的 `Target Text`(使用河南话表达):" + '`"恁这是弄啥嘞?晌午吃啥饭?"` \n' + "❌ 错误的 `Target Text`(使用普通话):" + '`"你这是在干什么呢?中午吃什么饭?"` \n\n' + "🤖 **实用小技巧:不知道怎么写地道的方言?** \n" + "您可以先在 豆包、DeepSeek、Kimi 等 AI 助手中输入普通话," + "让它们帮你翻译成方言文本,然后再复制粘贴到 `Target Text` 中直接使用! \n\n" + "📢 **研发小贴士:** \n" + '我们正在努力优化 AI!后续版本将支持"输入普通话文本,一键生成方言口音"的功能,敬请期待!' +) + +_I18N_TRANSLATIONS = { + "en": { + "reference_audio_label": "Reference Audio (optional — for cloning)", + "show_prompt_text_label": "Enable Prompt Text (improves voice similarity)", + "show_prompt_text_info": "Uses the ASR transcript of reference audio for higher cloning fidelity. Control Instruction will be disabled.", + "prompt_text_label": "Prompt Text (auto-filled by ASR, editable)", + "prompt_text_placeholder": "The transcript of your reference audio will appear here...", + "control_label": "Control Instruction (optional, only support English and Chinese)", + "control_placeholder": "e.g. 年轻女性,温柔甜美 / sadly / an excited young man", + "target_text_label": "Target Text", + "generate_btn": "Generate Speech", + "generated_audio_label": "Generated Audio", + "advanced_settings_title": "Advanced Settings", + "ref_denoise_label": "Reference audio enhancement", + "ref_denoise_info": "Denoise reference audio with ZipEnhancer", + "normalize_label": "Text normalization", + "normalize_info": "Normalize input text with wetext", + "cfg_label": "CFG (guidance scale)", + "cfg_info": "Higher = stronger prompt adherence; lower = more variation", + "usage_instructions": _USAGE_INSTRUCTIONS_EN, + "examples_footer": _EXAMPLES_FOOTER_EN, + }, + "zh-CN": { + "reference_audio_label": "参考音频(可选 - 用于克隆)", + "show_prompt_text_label": "启用 Prompt Text(提升音色还原度)", + "show_prompt_text_info": "使用参考音频的文本内容提升克隆相似度,开启后 Control Instruction 将被禁用", + "prompt_text_label": "Prompt Text(ASR 自动填充,可编辑)", + "prompt_text_placeholder": "参考音频的文本内容将自动识别到这里...", + "control_label": "Control Instruction(可选,仅支持中文和英文)", + "control_placeholder": "如:年轻女性,温柔甜美 / sadly / an excited young man", + "target_text_label": "Target Text(要合成的文本)", + "generate_btn": "开始生成", + "generated_audio_label": "生成音频", + "advanced_settings_title": "高级设置", + "ref_denoise_label": "参考音频降噪增强", + "ref_denoise_info": "使用 ZipEnhancer 对参考音频进行降噪", + "normalize_label": "文本规范化", + "normalize_info": "使用 wetext 对输入文本进行规范化处理", + "cfg_label": "CFG Value(引导强度)", + "cfg_info": "数值越高,越贴合提示要求;数值越低,变化空间越大", + "usage_instructions": _USAGE_INSTRUCTIONS_ZH, + "examples_footer": _EXAMPLES_FOOTER_ZH, + }, + "zh-Hans": None, # alias, filled below + "zh": None, # alias, filled below +} +_I18N_TRANSLATIONS["zh-Hans"] = _I18N_TRANSLATIONS["zh-CN"] +_I18N_TRANSLATIONS["zh"] = _I18N_TRANSLATIONS["zh-CN"] + +for _d in _I18N_TRANSLATIONS.values(): + if _d is not None: + for _k, _v in _I18N_TRANSLATIONS["en"].items(): + _d.setdefault(_k, _v) + +I18N = gr.I18n(**_I18N_TRANSLATIONS) + +DEFAULT_TARGET_TEXT = ( + "VoxCPM is an innovative end-to-end TTS model from ModelBest, " + "designed to generate highly realistic speech." +) + +_CUSTOM_CSS = """ .logo-container { text-align: center; margin: 0.5rem 0 1rem 0; @@ -148,165 +168,314 @@ CSS = """ max-width: 200px; display: inline-block; } -/* Bold accordion labels */ -#acc_quick > .label-wrap, -#acc_tips > .label-wrap, -#acc_quick > .label-wrap > span, -#acc_tips > .label-wrap > span, -#acc_quick summary, -#acc_tips summary { - font-weight: 600 !important; - font-size: 1.1em !important; + +/* Toggle switch style */ +.switch-toggle { + padding: 8px 12px; + border-radius: 8px; + background: var(--block-background-fill); } -/* Bold labels for specific checkboxes */ -#chk_denoise label, -#chk_denoise span, -#chk_normalize label, -#chk_normalize span { - font-weight: 600; +.switch-toggle input[type="checkbox"] { + appearance: none; + -webkit-appearance: none; + width: 44px; + height: 24px; + background: #ccc; + border-radius: 12px; + position: relative; + cursor: pointer; + transition: background 0.3s ease; + flex-shrink: 0; +} +.switch-toggle input[type="checkbox"]::after { + content: ""; + position: absolute; + top: 2px; + left: 2px; + width: 20px; + height: 20px; + background: white; + border-radius: 50%; + transition: transform 0.3s ease; + box-shadow: 0 1px 3px rgba(0,0,0,0.2); +} +.switch-toggle input[type="checkbox"]:checked { + background: var(--color-accent); +} +.switch-toggle input[type="checkbox"]:checked::after { + transform: translateX(20px); } """ +_APP_THEME = gr.themes.Soft( + primary_hue="blue", + secondary_hue="gray", + neutral_hue="slate", + font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"], +) + + +# ---------- Model ---------- + +class VoxCPMDemo: + def __init__(self, model_dir: Optional[str] = None) -> None: + self.device = "cuda" if torch.cuda.is_available() else "cpu" + logger.info(f"Running on device: {self.device}") + + self.asr_model_id = "iic/SenseVoiceSmall" + self.asr_model: Optional[AutoModel] = AutoModel( + model=self.asr_model_id, + disable_update=True, + log_level="DEBUG", + device="cuda:0" if self.device == "cuda" else "cpu", + ) + + self.voxcpm_model: Optional[voxcpm.VoxCPM] = None + self.explicit_model_dir = model_dir + + def _resolve_model_dir(self) -> str: + if self.explicit_model_dir and os.path.isdir(self.explicit_model_dir): + return self.explicit_model_dir + env_model_dir = os.environ.get("VOXCPM_MODEL_DIR", "").strip() + if env_model_dir and os.path.isdir(env_model_dir): + return env_model_dir + repo_id = os.environ.get("HF_REPO_ID", "").strip() + if len(repo_id) > 0: + target_dir = os.path.join("models", repo_id.replace("/", "__")) + if not os.path.isdir(target_dir): + try: + from huggingface_hub import snapshot_download + os.makedirs(target_dir, exist_ok=True) + logger.info(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...") + snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False) + except Exception as e: + logger.warning(f"HF download failed: {e}. Falling back to 'models'.") + return "models" + return target_dir + return "models" + + def get_or_load_voxcpm(self) -> voxcpm.VoxCPM: + if self.voxcpm_model is not None: + return self.voxcpm_model + logger.info("Model not loaded, initializing...") + model_dir = self._resolve_model_dir() + logger.info(f"Using model dir: {model_dir}") + self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir, optimize=True) + logger.info("Model loaded successfully.") + return self.voxcpm_model + + def prompt_wav_recognition(self, prompt_wav: Optional[str]) -> str: + if prompt_wav is None: + return "" + res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True) + return res[0]["text"].split("|>")[-1] + + def _build_generate_kwargs( + self, + *, + final_text: str, + audio_path: Optional[str], + prompt_text_clean: Optional[str], + cfg_value_input: float, + do_normalize: bool, + denoise: bool, + ) -> dict: + generate_kwargs = dict( + text=final_text, + reference_wav_path=audio_path, + cfg_value=float(cfg_value_input), + inference_timesteps=10, + normalize=do_normalize, + denoise=denoise, + ) + if prompt_text_clean and audio_path: + generate_kwargs["prompt_wav_path"] = audio_path + generate_kwargs["prompt_text"] = prompt_text_clean + return generate_kwargs + + def generate_tts_audio( + self, + text_input: str, + control_instruction: str = "", + reference_wav_path_input: Optional[str] = None, + prompt_text: str = "", + cfg_value_input: float = 2.0, + do_normalize: bool = True, + denoise: bool = True, + ) -> Tuple[int, np.ndarray]: + current_model = self.get_or_load_voxcpm() + + text = (text_input or "").strip() + if len(text) == 0: + raise ValueError("Please input text to synthesize.") + + control = (control_instruction or "").strip() + final_text = f"({control}){text}" if control else text + + audio_path = reference_wav_path_input if reference_wav_path_input else None + prompt_text_clean = (prompt_text or "").strip() or None + + if audio_path and prompt_text_clean: + logger.info(f"[Voice Cloning] prompt_wav + prompt_text + reference_wav") + elif audio_path: + logger.info(f"[Voice Control] reference_wav only") + else: + logger.info(f"[Voice Design] control: {control[:50] if control else 'None'}...") + + logger.info(f"Generating audio for text: '{final_text[:80]}...'") + generate_kwargs = self._build_generate_kwargs( + final_text=final_text, + audio_path=audio_path, + prompt_text_clean=prompt_text_clean, + cfg_value_input=cfg_value_input, + do_normalize=do_normalize, + denoise=denoise, + ) + wav = current_model.generate(**generate_kwargs) + return (current_model.tts_model.sample_rate, wav) + + +# ---------- UI ---------- def create_demo_interface(demo: VoxCPMDemo): - """Build the Gradio UI for VoxCPM demo.""" gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"]) + def _generate( + text: str, + control_instruction: str, + ref_wav: Optional[str], + use_prompt_text: bool, + prompt_text_value: str, + cfg_value: float, + do_normalize: bool, + denoise: bool, + ): + actual_prompt_text = prompt_text_value.strip() if use_prompt_text else "" + actual_control = "" if use_prompt_text else control_instruction + sr, wav_np = demo.generate_tts_audio( + text_input=text, + control_instruction=actual_control, + reference_wav_path_input=ref_wav, + prompt_text=actual_prompt_text, + cfg_value_input=cfg_value, + do_normalize=do_normalize, + denoise=denoise, + ) + return (sr, wav_np) + + def _on_toggle_instant(checked): + """Instant UI toggle — no ASR, no blocking.""" + if checked: + return ( + gr.update(visible=True, value="", placeholder="Recognizing reference audio..."), + gr.update(visible=False), + ) + return ( + gr.update(visible=False), + gr.update(visible=True, interactive=True), + ) + + def _run_asr_if_needed(checked, audio_path): + """Run ASR after the UI has updated. Only when toggled ON.""" + if not checked or not audio_path: + return gr.update() + try: + logger.info("Running ASR on reference audio...") + asr_text = demo.prompt_wav_recognition(audio_path) + logger.info(f"ASR result: {asr_text[:60]}...") + return gr.update(value=asr_text) + except Exception as e: + logger.warning(f"ASR recognition failed: {e}") + return gr.update(value="") + with gr.Blocks() as interface: gr.HTML( - '
VoxCPM Logo
', - padding=True, + '
' + 'VoxCPM Logo' + "
" ) - # Quick Start - with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"): - gr.Markdown(""" - ### How to Use |使用说明 - 1. **(Optional) Provide a Voice Prompt** - Upload or record an audio clip to provide the desired voice characteristics for synthesis. - **(可选)提供参考声音** - 上传或录制一段音频,为声音合成提供音色、语调和情感等个性化特征 - 2. **(Optional) Enter prompt text** - If you provided a voice prompt, enter the corresponding transcript here (auto-recognition available). - **(可选项)输入参考文本** - 如果提供了参考语音,请输入其对应的文本内容(支持自动识别)。 - 3. **Enter target text** - Type the text you want the model to speak. - **输入目标文本** - 输入您希望模型朗读的文字内容。 - 4. **Generate Speech** - Click the "Generate" button to create your audio. - **生成语音** - 点击"生成"按钮,即可为您创造出音频。 - """) + gr.Markdown(I18N("usage_instructions")) - # Pro Tips - with gr.Accordion("💡 Pro Tips |使用建议", open=False, elem_id="acc_tips"): - gr.Markdown(""" - ### Prompt Speech Enhancement|参考语音降噪 - - **Enable** to remove background noise for a clean voice, with an external ZipEnhancer component. However, this will limit the audio sampling rate to 16kHz, restricting the cloning quality ceiling. - **启用**:通过 ZipEnhancer 组件消除背景噪音,但会将音频采样率限制在16kHz,限制克隆上限。 - - **Disable** to preserve the original audio's all information, including background atmosphere, and support audio cloning up to 44.1kHz sampling rate. - **禁用**:保留原始音频的全部信息,包括背景环境声,最高支持44.1kHz的音频复刻。 - - ### Text Normalization|文本正则化 - - **Enable** to process general text with an external WeTextProcessing component. - **启用**:使用 WeTextProcessing 组件,可支持常见文本的正则化处理。 - - **Disable** to use VoxCPM's native text understanding ability. For example, it supports phonemes input (For Chinese, phonemes are converted using pinyin, {ni3}{hao3}; For English, phonemes are converted using CMUDict, {HH AH0 L OW1}), try it! - **禁用**:将使用 VoxCPM 内置的文本理解能力。如,支持音素输入(如中文转拼音:{ni3}{hao3};英文转CMUDict:{HH AH0 L OW1})和公式符号合成,尝试一下! - - ### CFG Value|CFG 值 - - **Lower CFG** if the voice prompt sounds strained or expressive, or instability occurs with long text input. - **调低**:如果提示语音听起来不自然或过于夸张,或者长文本输入出现稳定性问题。 - - **Higher CFG** for better adherence to the prompt speech style or input text, or instability occurs with too short text input. - **调高**:为更好地贴合提示音频的风格或输入文本, 或者极短文本输入出现稳定性问题。 - - ### Inference Timesteps|推理时间步 - - **Lower** for faster synthesis speed. - **调低**:合成速度更快。 - - **Higher** for better synthesis quality. - **调高**:合成质量更佳。 - """) - - # Main controls with gr.Row(): with gr.Column(): - # 1. Reference Audio - # gr.Markdown("### 🎤 Reference Audio (Optional)") - # gr.Markdown("*提供参考音频进行音色克隆;不提供则使用 Voice Design 模式*") reference_wav = gr.Audio( sources=["upload", "microphone"], type="filepath", - label="Reference Audio (Optional)", + label=I18N("reference_audio_label"), ) - DoDenoisePromptAudio = gr.Checkbox( + show_prompt_text = gr.Checkbox( value=False, - label="Reference Audio Enhancement", - elem_id="chk_denoise", - info="Use ZipEnhancer to denoise the reference audio", + label=I18N("show_prompt_text_label"), + info=I18N("show_prompt_text_info"), + elem_classes=["switch-toggle"], + ) + prompt_text = gr.Textbox( + value="", + label=I18N("prompt_text_label"), + placeholder=I18N("prompt_text_placeholder"), + lines=2, + visible=False, ) - - # 2. Control Instruction - # gr.Markdown("### 🎛️ Control Instruction (Optional)") - # gr.Markdown("*描述声音风格、情感等,格式:`(instruction) text`*") control_instruction = gr.Textbox( value="", - label="Control Instruction", - placeholder="*描述声音风格、情感等,格式:`(instruction) text`,例如:年轻女性,温柔甜美 / 悲伤地说 / an excited young man*", + label=I18N("control_label"), + placeholder=I18N("control_placeholder"), lines=2, ) - - # 3. Target Text - # gr.Markdown("### 📝 Target Text") text = gr.Textbox( - value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.", - label="Target Text", + value=DEFAULT_TARGET_TEXT, + label=I18N("target_text_label"), lines=3, ) - DoNormalizeText = gr.Checkbox( - value=False, - label="Text Normalization", - elem_id="chk_normalize", - info="Use wetext library to normalize the input text", - ) - run_btn = gr.Button("🔊 Generate Speech", variant="primary", size="lg") + with gr.Accordion(I18N("advanced_settings_title"), open=False): + DoDenoisePromptAudio = gr.Checkbox( + value=False, + label=I18N("ref_denoise_label"), + elem_classes=["switch-toggle"], + info=I18N("ref_denoise_info"), + ) + DoNormalizeText = gr.Checkbox( + value=False, + label=I18N("normalize_label"), + elem_classes=["switch-toggle"], + info=I18N("normalize_info"), + ) + cfg_value = gr.Slider( + minimum=1.0, + maximum=3.0, + value=2.0, + step=0.1, + label=I18N("cfg_label"), + info=I18N("cfg_info"), + ) + + run_btn = gr.Button(I18N("generate_btn"), variant="primary", size="lg") with gr.Column(): - gr.Markdown("### ⚙️ Generation Settings") - cfg_value = gr.Slider( - minimum=1.0, - maximum=3.0, - value=2.0, - step=0.1, - label="CFG Value (Guidance Scale)", - info="Higher = more adherence to prompt; Lower = more creativity", - ) - inference_timesteps = gr.Slider( - minimum=4, - maximum=30, - value=10, - step=1, - label="Inference Timesteps", - info="Higher = better quality but slower", - ) + audio_output = gr.Audio(label=I18N("generated_audio_label")) + gr.Markdown(I18N("examples_footer")) - gr.Markdown("### 🔈 Output") - audio_output = gr.Audio(label="Generated Audio") + show_prompt_text.change( + fn=_on_toggle_instant, + inputs=[show_prompt_text], + outputs=[prompt_text, control_instruction], + ).then( + fn=_run_asr_if_needed, + inputs=[show_prompt_text, reference_wav], + outputs=[prompt_text], + ) - gr.Markdown(""" - --- - **模式说明 / Mode Info:** - - **有 Reference Audio** → Prompt 隔离模式(音色克隆) - - **无 Reference Audio** → Voice Design 模式(用 Control Instruction 描述声音) - - **Control Instruction 示例:** - - `年轻女性,温柔甜美` - - `悲伤地说` - - `an excited young man` - """) - - # Wiring run_btn.click( - fn=demo.generate_tts_audio, + fn=_generate, inputs=[ text, control_instruction, reference_wav, + show_prompt_text, + prompt_text, cfg_value, - inference_timesteps, DoNormalizeText, DoDenoisePromptAudio, ], @@ -317,18 +486,28 @@ def create_demo_interface(demo: VoxCPMDemo): return interface - -def run_demo(server_name: str = "0.0.0.0", server_port: int = 7869, show_error: bool = True): - demo = VoxCPMDemo() +def run_demo( + server_name: str = "0.0.0.0", + server_port: int = 8808, + show_error: bool = True, + model_dir: Optional[str] = None, +): + demo = VoxCPMDemo(model_dir=model_dir) interface = create_demo_interface(demo) interface.queue(max_size=10, default_concurrency_limit=1).launch( server_name=server_name, server_port=server_port, show_error=show_error, - theme=THEME, - css=CSS, + i18n=I18N, + theme=_APP_THEME, + css=_CUSTOM_CSS, ) if __name__ == "__main__": - run_demo() + import argparse + parser = argparse.ArgumentParser() + parser.add_argument("--model-dir", type=str, default=None, help="Path to VoxCPM2 checkpoint directory") + parser.add_argument("--port", type=int, default=8808, help="Server port") + args = parser.parse_args() + run_demo(model_dir=args.model_dir, server_port=args.port) diff --git a/app_old.py b/app_old.py new file mode 100644 index 0000000..d46c2e1 --- /dev/null +++ b/app_old.py @@ -0,0 +1,280 @@ +import os +import sys +import numpy as np +import torch +import gradio as gr +from typing import Optional, Tuple +from funasr import AutoModel +from pathlib import Path +os.environ["TOKENIZERS_PARALLELISM"] = "false" +if os.environ.get("HF_REPO_ID", "").strip() == "": + os.environ["HF_REPO_ID"] = "openbmb/VoxCPM1.5" + +import voxcpm + + +class VoxCPMDemo: + def __init__(self) -> None: + self.device = "cuda" if torch.cuda.is_available() else "cpu" + print(f"🚀 Running on device: {self.device}", file=sys.stderr) + + # ASR model for prompt text recognition + self.asr_model_id = "iic/SenseVoiceSmall" + self.asr_model: Optional[AutoModel] = AutoModel( + model=self.asr_model_id, + disable_update=True, + log_level='DEBUG', + device="cuda:0" if self.device == "cuda" else "cpu", + ) + + # TTS model (lazy init) + self.voxcpm_model: Optional[voxcpm.VoxCPM] = None + self.default_local_model_dir = "./models/VoxCPM1.5" + + # ---------- Model helpers ---------- + def _resolve_model_dir(self) -> str: + """ + Resolve model directory: + 1) Use local checkpoint directory if exists + 2) If HF_REPO_ID env is set, download into models/{repo} + 3) Fallback to 'models' + """ + if os.path.isdir(self.default_local_model_dir): + return self.default_local_model_dir + + repo_id = os.environ.get("HF_REPO_ID", "").strip() + if len(repo_id) > 0: + target_dir = os.path.join("models", repo_id.replace("/", "__")) + if not os.path.isdir(target_dir): + try: + from huggingface_hub import snapshot_download # type: ignore + os.makedirs(target_dir, exist_ok=True) + print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr) + snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False) + except Exception as e: + print(f"Warning: HF download failed: {e}. Falling back to 'data'.", file=sys.stderr) + return "models" + return target_dir + return "models" + + def get_or_load_voxcpm(self) -> voxcpm.VoxCPM: + if self.voxcpm_model is not None: + return self.voxcpm_model + print("Model not loaded, initializing...", file=sys.stderr) + model_dir = self._resolve_model_dir() + print(f"Using model dir: {model_dir}", file=sys.stderr) + self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir) + print("Model loaded successfully.", file=sys.stderr) + return self.voxcpm_model + + # ---------- Functional endpoints ---------- + def prompt_wav_recognition(self, prompt_wav: Optional[str]) -> str: + if prompt_wav is None: + return "" + res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True) + text = res[0]["text"].split('|>')[-1] + return text + + def generate_tts_audio( + self, + text_input: str, + prompt_wav_path_input: Optional[str] = None, + prompt_text_input: Optional[str] = None, + cfg_value_input: float = 2.0, + inference_timesteps_input: int = 10, + do_normalize: bool = True, + denoise: bool = True, + ) -> Tuple[int, np.ndarray]: + """ + Generate speech from text using VoxCPM; optional reference audio for voice style guidance. + Returns (sample_rate, waveform_numpy) + """ + current_model = self.get_or_load_voxcpm() + + text = (text_input or "").strip() + if len(text) == 0: + raise ValueError("Please input text to synthesize.") + + prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None + prompt_text = prompt_text_input if prompt_text_input else None + + print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr) + wav = current_model.generate( + text=text, + prompt_text=prompt_text, + prompt_wav_path=prompt_wav_path, + cfg_value=float(cfg_value_input), + inference_timesteps=int(inference_timesteps_input), + normalize=do_normalize, + denoise=denoise, + ) + return (current_model.tts_model.sample_rate, wav) + + +# ---------- UI Builders ---------- + +_APP_THEME = gr.themes.Soft( + primary_hue="blue", + secondary_hue="gray", + neutral_hue="slate", + font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"], +) + +_CUSTOM_CSS = """ +.logo-container { + text-align: center; + margin: 0.5rem 0 1rem 0; +} +.logo-container img { + height: 80px; + width: auto; + max-width: 200px; + display: inline-block; +} +/* Bold accordion labels */ +#acc_quick details > summary, +#acc_tips details > summary { + font-weight: 600 !important; + font-size: 1.1em !important; +} +/* Bold labels for specific checkboxes */ +#chk_denoise label, +#chk_denoise span, +#chk_normalize label, +#chk_normalize span { + font-weight: 600; +} +""" + + +def create_demo_interface(demo: VoxCPMDemo): + """Build the Gradio UI for VoxCPM demo.""" + gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"]) + + with gr.Blocks() as interface: + # Header logo + gr.HTML('
VoxCPM Logo
') + + # Quick Start + with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"): + gr.Markdown(""" + ### How to Use |使用说明 + 1. **(Optional) Provide a Voice Prompt** - Upload or record an audio clip to provide the desired voice characteristics for synthesis. + **(可选)提供参考声音** - 上传或录制一段音频,为声音合成提供音色、语调和情感等个性化特征 + 2. **(Optional) Enter prompt text** - If you provided a voice prompt, enter the corresponding transcript here (auto-recognition available). + **(可选项)输入参考文本** - 如果提供了参考语音,请输入其对应的文本内容(支持自动识别)。 + 3. **Enter target text** - Type the text you want the model to speak. + **输入目标文本** - 输入您希望模型朗读的文字内容。 + 4. **Generate Speech** - Click the "Generate" button to create your audio. + **生成语音** - 点击"生成"按钮,即可为您创造出音频。 + """) + + # Pro Tips + with gr.Accordion("💡 Pro Tips |使用建议", open=False, elem_id="acc_tips"): + gr.Markdown(""" + ### Prompt Speech Enhancement|参考语音降噪 + - **Enable** to remove background noise for a clean voice, with an external ZipEnhancer component. However, this will limit the audio sampling rate to 16kHz, restricting the cloning quality ceiling. + **启用**:通过 ZipEnhancer 组件消除背景噪音,但会将音频采样率限制在16kHz,限制克隆上限。 + - **Disable** to preserve the original audio's all information, including background atmosphere, and support audio cloning up to 44.1kHz sampling rate. + **禁用**:保留原始音频的全部信息,包括背景环境声,最高支持44.1kHz的音频复刻。 + + ### Text Normalization|文本正则化 + - **Enable** to process general text with an external WeTextProcessing component. + **启用**:使用 WeTextProcessing 组件,可支持常见文本的正则化处理。 + - **Disable** to use VoxCPM's native text understanding ability. For example, it supports phonemes input (For Chinese, phonemes are converted using pinyin, {ni3}{hao3}; For English, phonemes are converted using CMUDict, {HH AH0 L OW1}), try it! + **禁用**:将使用 VoxCPM 内置的文本理解能力。如,支持音素输入(如中文转拼音:{ni3}{hao3};英文转CMUDict:{HH AH0 L OW1})和公式符号合成,尝试一下! + + ### CFG Value|CFG 值 + - **Lower CFG** if the voice prompt sounds strained or expressive, or instability occurs with long text input. + **调低**:如果提示语音听起来不自然或过于夸张,或者长文本输入出现稳定性问题。 + - **Higher CFG** for better adherence to the prompt speech style or input text, or instability occurs with too short text input. + **调高**:为更好地贴合提示音频的风格或输入文本, 或者极短文本输入出现稳定性问题。 + + ### Inference Timesteps|推理时间步 + - **Lower** for faster synthesis speed. + **调低**:合成速度更快。 + - **Higher** for better synthesis quality. + **调高**:合成质量更佳。 + """) + + # Main controls + with gr.Row(): + with gr.Column(): + prompt_wav = gr.Audio( + sources=["upload", 'microphone'], + type="filepath", + label="Prompt Speech (Optional, or let VoxCPM improvise)", + value="./examples/example.wav", + ) + DoDenoisePromptAudio = gr.Checkbox( + value=False, + label="Prompt Speech Enhancement", + elem_id="chk_denoise", + info="We use ZipEnhancer model to denoise the prompt audio." + ) + with gr.Row(): + prompt_text = gr.Textbox( + value="Just by listening a few minutes a day, you'll be able to eliminate negative thoughts by conditioning your mind to be more positive.", + label="Prompt Text", + placeholder="Please enter the prompt text. Automatic recognition is supported, and you can correct the results yourself..." + ) + run_btn = gr.Button("Generate Speech", variant="primary") + + with gr.Column(): + cfg_value = gr.Slider( + minimum=1.0, + maximum=3.0, + value=2.0, + step=0.1, + label="CFG Value (Guidance Scale)", + info="Higher values increase adherence to prompt, lower values allow more creativity" + ) + inference_timesteps = gr.Slider( + minimum=4, + maximum=30, + value=10, + step=1, + label="Inference Timesteps", + info="Number of inference timesteps for generation (higher values may improve quality but slower)" + ) + with gr.Row(): + text = gr.Textbox( + value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.", + label="Target Text", + ) + with gr.Row(): + DoNormalizeText = gr.Checkbox( + value=False, + label="Text Normalization", + elem_id="chk_normalize", + info="We use wetext library to normalize the input text." + ) + audio_output = gr.Audio(label="Output Audio") + + # Wiring + run_btn.click( + fn=demo.generate_tts_audio, + inputs=[text, prompt_wav, prompt_text, cfg_value, inference_timesteps, DoNormalizeText, DoDenoisePromptAudio], + outputs=[audio_output], + show_progress=True, + api_name="generate", + ) + prompt_wav.change(fn=demo.prompt_wav_recognition, inputs=[prompt_wav], outputs=[prompt_text]) + + return interface + + +def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error: bool = True): + demo = VoxCPMDemo() + interface = create_demo_interface(demo) + interface.queue(max_size=10, default_concurrency_limit=1).launch( + server_name=server_name, + server_port=server_port, + show_error=show_error, + theme=_APP_THEME, + css=_CUSTOM_CSS, + ) + + +if __name__ == "__main__": + run_demo() \ No newline at end of file diff --git a/smoke_outputs/v15_prompt_clone.wav b/smoke_outputs/v15_prompt_clone.wav new file mode 100644 index 0000000..38be901 Binary files /dev/null and b/smoke_outputs/v15_prompt_clone.wav differ diff --git a/smoke_outputs/v2_design.wav b/smoke_outputs/v2_design.wav new file mode 100644 index 0000000..6154ac2 Binary files /dev/null and b/smoke_outputs/v2_design.wav differ diff --git a/smoke_outputs/v2_reference_clone.wav b/smoke_outputs/v2_reference_clone.wav new file mode 100644 index 0000000..fd5e8a9 Binary files /dev/null and b/smoke_outputs/v2_reference_clone.wav differ diff --git a/src/voxcpm/cli.py b/src/voxcpm/cli.py index 8d1698c..dd9feae 100644 --- a/src/voxcpm/cli.py +++ b/src/voxcpm/cli.py @@ -2,17 +2,22 @@ """ VoxCPM Command Line Interface -Unified CLI for voice cloning, direct TTS synthesis, and batch processing. +VoxCPM2-first CLI for voice design, cloning, and batch processing. """ import argparse +import json import os import sys from pathlib import Path + import soundfile as sf from voxcpm.core import VoxCPM + +DEFAULT_HF_MODEL_ID = "openbmb/VoxCPM2" + # ----------------------------- # Validators # ----------------------------- @@ -25,6 +30,13 @@ def validate_file_exists(file_path: str, file_type: str = "file") -> Path: return path +def require_file_exists(file_path: str, parser, file_type: str = "file") -> Path: + try: + return validate_file_exists(file_path, file_type) + except FileNotFoundError as exc: + parser.error(str(exc)) + + def validate_output_path(output_path: str) -> Path: path = Path(output_path) path.parent.mkdir(parents=True, exist_ok=True) @@ -49,6 +61,113 @@ def validate_ranges(args, parser): parser.error("--lora-dropout must be between 0.0 and 1.0") +def warn_legacy_mode(): + print( + "Warning: legacy root CLI arguments are deprecated. Prefer `voxcpm design|clone|batch ...`.", + file=sys.stderr, + ) + + +def build_final_text(text: str, control: str | None) -> str: + control = (control or "").strip() + return f"({control}){text}" if control else text + + +def resolve_prompt_text(args, parser) -> str | None: + prompt_text = getattr(args, "prompt_text", None) + prompt_file = getattr(args, "prompt_file", None) + + if prompt_text and prompt_file: + parser.error("Use either --prompt-text or --prompt-file, not both.") + + if prompt_file: + prompt_path = require_file_exists(prompt_file, parser, "prompt text file") + return prompt_path.read_text(encoding="utf-8").strip() + + if prompt_text: + return prompt_text.strip() + + return None + + +def detect_model_architecture(args) -> str | None: + model_location = getattr(args, "model_path", None) or getattr( + args, "hf_model_id", None + ) + if not model_location: + return None + + if os.path.isdir(model_location): + config_path = Path(model_location) / "config.json" + if not config_path.exists(): + return None + + with open(config_path, "r", encoding="utf-8") as f: + return json.load(f).get("architecture", "voxcpm").lower() + + model_hint = str(model_location).lower() + if "voxcpm2" in model_hint: + return "voxcpm2" + if ( + "voxcpm1.5" in model_hint + or "voxcpm-1.5" in model_hint + or "voxcpm_1.5" in model_hint + ): + return "voxcpm" + + return None + + +def validate_prompt_related_args(args, parser, prompt_text: str | None): + if prompt_text and not args.prompt_audio: + parser.error("--prompt-text/--prompt-file requires --prompt-audio.") + + if args.prompt_audio and not prompt_text: + parser.error("--prompt-audio requires --prompt-text or --prompt-file.") + + if args.control and prompt_text: + parser.error( + "--control cannot be used together with --prompt-text or --prompt-file." + ) + + +def validate_reference_support(args, parser): + if not getattr(args, "reference_audio", None): + return + + arch = detect_model_architecture(args) + if arch == "voxcpm": + parser.error("--reference-audio is only supported with VoxCPM2 models.") + + +def validate_design_args(args, parser): + prompt_text = resolve_prompt_text(args, parser) + if args.prompt_audio or args.reference_audio or prompt_text: + parser.error( + "`design` does not accept prompt/reference audio. Use `clone` instead." + ) + + +def validate_clone_args(args, parser): + prompt_text = resolve_prompt_text(args, parser) + validate_prompt_related_args(args, parser, prompt_text) + validate_reference_support(args, parser) + + if not args.prompt_audio and not args.reference_audio: + parser.error( + "`clone` requires --reference-audio, or --prompt-audio with --prompt-text/--prompt-file." + ) + + return prompt_text + + +def validate_batch_args(args, parser): + prompt_text = resolve_prompt_text(args, parser) + validate_prompt_related_args(args, parser, prompt_text) + validate_reference_support(args, parser) + return prompt_text + + # ----------------------------- # Model loading # ----------------------------- @@ -57,7 +176,9 @@ def validate_ranges(args, parser): def load_model(args) -> VoxCPM: print("Loading VoxCPM model...", file=sys.stderr) - zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get("ZIPENHANCER_MODEL_PATH", None) + zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get( + "ZIPENHANCER_MODEL_PATH", None + ) # Build LoRA config if provided lora_config = None @@ -87,6 +208,7 @@ def load_model(args) -> VoxCPM: voxcpm_model_path=args.model_path, zipenhancer_model_path=zipenhancer_path, enable_denoiser=not args.no_denoiser, + optimize=not args.no_optimize, lora_config=lora_config, lora_weights_path=lora_weights_path, ) @@ -104,6 +226,7 @@ def load_model(args) -> VoxCPM: zipenhancer_model_id=zipenhancer_path, cache_dir=args.cache_dir, local_files_only=args.local_files_only, + optimize=not args.no_optimize, lora_config=lora_config, lora_weights_path=lora_weights_path, ) @@ -119,32 +242,26 @@ def load_model(args) -> VoxCPM: # ----------------------------- -def cmd_clone(args): - if not args.text: - sys.exit("Error: Please provide --text for synthesis") - - has_prompt = args.prompt_audio and args.prompt_text - has_ref = args.reference_audio is not None - if not has_prompt and not has_ref: - sys.exit("Error: Voice cloning requires --prompt-audio + --prompt-text, or --reference-audio, or both") +def _run_single(args, parser, *, text: str, output: str, prompt_text: str | None): + output_path = validate_output_path(output) if args.prompt_audio: - validate_file_exists(args.prompt_audio, "prompt audio file") + require_file_exists(args.prompt_audio, parser, "prompt audio file") if args.reference_audio: - validate_file_exists(args.reference_audio, "reference audio file") - output_path = validate_output_path(args.output) + require_file_exists(args.reference_audio, parser, "reference audio file") model = load_model(args) audio_array = model.generate( - text=args.text, - prompt_wav_path=args.prompt_audio if has_prompt else None, - prompt_text=args.prompt_text if has_prompt else None, + text=text, + prompt_wav_path=args.prompt_audio, + prompt_text=prompt_text, reference_wav_path=args.reference_audio, cfg_value=args.cfg_value, inference_timesteps=args.inference_timesteps, normalize=args.normalize, - denoise=args.denoise, + denoise=args.denoise + and (args.prompt_audio is not None or args.reference_audio is not None), ) sf.write(str(output_path), audio_array, model.tts_model.sample_rate) @@ -153,31 +270,24 @@ def cmd_clone(args): print(f"Saved audio to: {output_path} ({duration:.2f}s)", file=sys.stderr) -def cmd_synthesize(args): - if not args.text: - sys.exit("Error: Please provide --text for synthesis") - - output_path = validate_output_path(args.output) - model = load_model(args) - - audio_array = model.generate( - text=args.text, - prompt_wav_path=None, - prompt_text=None, - cfg_value=args.cfg_value, - inference_timesteps=args.inference_timesteps, - normalize=args.normalize, - denoise=False, +def cmd_design(args, parser): + validate_design_args(args, parser) + final_text = build_final_text(args.text, args.control) + return _run_single( + args, parser, text=final_text, output=args.output, prompt_text=None ) - sf.write(str(output_path), audio_array, model.tts_model.sample_rate) - duration = len(audio_array) / model.tts_model.sample_rate - print(f"Saved audio to: {output_path} ({duration:.2f}s)", file=sys.stderr) +def cmd_clone(args, parser): + prompt_text = validate_clone_args(args, parser) + final_text = build_final_text(args.text, args.control) + return _run_single( + args, parser, text=final_text, output=args.output, prompt_text=prompt_text + ) -def cmd_batch(args): - input_file = validate_file_exists(args.input, "input file") +def cmd_batch(args, parser): + input_file = require_file_exists(args.input, parser, "input file") output_dir = Path(args.output_dir) output_dir.mkdir(parents=True, exist_ok=True) @@ -187,29 +297,36 @@ def cmd_batch(args): if not texts: sys.exit("Error: Input file is empty") + prompt_text = validate_batch_args(args, parser) model = load_model(args) prompt_audio_path = None if args.prompt_audio: - prompt_audio_path = str(validate_file_exists(args.prompt_audio, "prompt audio file")) + prompt_audio_path = str( + require_file_exists(args.prompt_audio, parser, "prompt audio file") + ) reference_audio_path = None if args.reference_audio: - reference_audio_path = str(validate_file_exists(args.reference_audio, "reference audio file")) + reference_audio_path = str( + require_file_exists(args.reference_audio, parser, "reference audio file") + ) success_count = 0 for i, text in enumerate(texts, 1): try: + final_text = build_final_text(text, args.control) audio_array = model.generate( - text=text, + text=final_text, prompt_wav_path=prompt_audio_path, - prompt_text=args.prompt_text, + prompt_text=prompt_text, reference_wav_path=reference_audio_path, cfg_value=args.cfg_value, inference_timesteps=args.inference_timesteps, normalize=args.normalize, - denoise=args.denoise and (prompt_audio_path is not None or reference_audio_path is not None), + denoise=args.denoise + and (prompt_audio_path is not None or reference_audio_path is not None), ) output_file = output_dir / f"output_{i:03d}.wav" @@ -230,97 +347,251 @@ def cmd_batch(args): # ----------------------------- -def _build_unified_parser(): +def _add_common_generation_args(parser): + parser.add_argument("--text", "-t", help="Text to synthesize") + parser.add_argument( + "--control", + type=str, + help="Control instruction for VoxCPM2 voice design/cloning", + ) + parser.add_argument( + "--cfg-value", + type=float, + default=2.0, + help="CFG guidance scale (float, recommended 0.5–5.0, default: 2.0)", + ) + parser.add_argument( + "--inference-timesteps", + type=int, + default=10, + help="Inference steps (int, 1–100, default: 10)", + ) + parser.add_argument( + "--normalize", action="store_true", help="Enable text normalization" + ) + + +def _add_prompt_reference_args(parser): + parser.add_argument( + "--prompt-audio", + "-pa", + help="Prompt audio file path (continuation mode, requires --prompt-text or --prompt-file)", + ) + parser.add_argument( + "--prompt-text", "-pt", help="Text corresponding to the prompt audio" + ) + parser.add_argument( + "--prompt-file", type=str, help="Text file corresponding to the prompt audio" + ) + parser.add_argument( + "--reference-audio", + "-ra", + help="Reference audio for voice cloning (VoxCPM2 only)", + ) + parser.add_argument( + "--denoise", + action="store_true", + help="Enable prompt/reference speech enhancement", + ) + + +def _add_model_args(parser): + parser.add_argument("--model-path", type=str, help="Local VoxCPM model path") + parser.add_argument( + "--hf-model-id", + type=str, + default=DEFAULT_HF_MODEL_ID, + help=f"Hugging Face repo id (default: {DEFAULT_HF_MODEL_ID})", + ) + parser.add_argument( + "--cache-dir", type=str, help="Cache directory for Hub downloads" + ) + parser.add_argument( + "--local-files-only", action="store_true", help="Disable network access" + ) + parser.add_argument( + "--no-denoiser", action="store_true", help="Disable denoiser model loading" + ) + parser.add_argument( + "--no-optimize", + action="store_true", + help="Disable model optimization during loading", + ) + parser.add_argument( + "--zipenhancer-path", + type=str, + help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)", + ) + + +def _add_lora_args(parser): + parser.add_argument("--lora-path", type=str, help="Path to LoRA weights") + parser.add_argument( + "--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)" + ) + parser.add_argument( + "--lora-alpha", + type=int, + default=16, + help="LoRA alpha (positive int, default: 16)", + ) + parser.add_argument( + "--lora-dropout", + type=float, + default=0.0, + help="LoRA dropout rate (0.0–1.0, default: 0.0)", + ) + parser.add_argument( + "--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers" + ) + parser.add_argument( + "--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers" + ) + parser.add_argument( + "--lora-enable-proj", + action="store_true", + help="Enable LoRA on projection layers", + ) + + +def _build_parser(): parser = argparse.ArgumentParser( - description="VoxCPM CLI - voice cloning, direct TTS, and batch processing", + description="VoxCPM CLI - VoxCPM2-first voice design, cloning, and batch processing", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: - voxcpm --text "Hello world" --output out.wav - voxcpm --text "Hello" --prompt-audio ref.wav --prompt-text "hi" --output out.wav --denoise - voxcpm --input texts.txt --output-dir ./outs + voxcpm design --text "Hello world" --output out.wav + voxcpm design --text "Hello world" --control "warm female voice" --output out.wav + voxcpm clone --text "Hello" --reference-audio ref.wav --output out.wav + voxcpm batch --input texts.txt --output-dir ./outs --reference-audio ref.wav """, ) - # Mode selection + subparsers = parser.add_subparsers(dest="command") + + design_parser = subparsers.add_parser( + "design", help="Generate speech with VoxCPM2-first voice design" + ) + _add_common_generation_args(design_parser) + _add_prompt_reference_args(design_parser) + _add_model_args(design_parser) + _add_lora_args(design_parser) + design_parser.add_argument( + "--output", "-o", required=True, help="Output audio file path" + ) + + clone_parser = subparsers.add_parser( + "clone", help="Clone a voice with reference/prompt audio" + ) + _add_common_generation_args(clone_parser) + _add_prompt_reference_args(clone_parser) + _add_model_args(clone_parser) + _add_lora_args(clone_parser) + clone_parser.add_argument( + "--output", "-o", required=True, help="Output audio file path" + ) + + batch_parser = subparsers.add_parser( + "batch", help="Batch-generate one line per output file" + ) + batch_parser.add_argument( + "--input", "-i", required=True, help="Input text file (one text per line)" + ) + batch_parser.add_argument( + "--output-dir", "-od", required=True, help="Output directory" + ) + batch_parser.add_argument( + "--control", + type=str, + help="Control instruction for VoxCPM2 voice design/cloning", + ) + _add_prompt_reference_args(batch_parser) + batch_parser.add_argument( + "--cfg-value", + type=float, + default=2.0, + help="CFG guidance scale (float, recommended 0.5–5.0, default: 2.0)", + ) + batch_parser.add_argument( + "--inference-timesteps", + type=int, + default=10, + help="Inference steps (int, 1–100, default: 10)", + ) + batch_parser.add_argument( + "--normalize", action="store_true", help="Enable text normalization" + ) + _add_model_args(batch_parser) + _add_lora_args(batch_parser) + + # Legacy root arguments parser.add_argument("--input", "-i", help="Input text file (batch mode only)") - parser.add_argument("--output-dir", "-od", help="Output directory (batch mode only)") - parser.add_argument("--text", "-t", help="Text to synthesize (single or clone mode)") - parser.add_argument("--output", "-o", help="Output audio file path (single or clone mode)") - - # Prompt / Reference parser.add_argument( - "--prompt-audio", "-pa", help="Prompt audio file path (continuation mode, requires --prompt-text)" + "--output-dir", "-od", help="Output directory (batch mode only)" ) - parser.add_argument("--prompt-text", "-pt", help="Text corresponding to the prompt audio") + _add_common_generation_args(parser) parser.add_argument( - "--reference-audio", "-ra", help="Reference audio for voice cloning (isolated mode, VoxCPM2 only)" + "--output", "-o", help="Output audio file path (single or clone mode)" ) - parser.add_argument("--denoise", action="store_true", help="Enable prompt/reference speech enhancement") - - # Generation parameters - parser.add_argument( - "--cfg-value", type=float, default=2.0, help="CFG guidance scale (float, recommended 0.5–5.0, default: 2.0)" - ) - parser.add_argument("--inference-timesteps", type=int, default=10, help="Inference steps (int, 1–100, default: 10)") - parser.add_argument("--normalize", action="store_true", help="Enable text normalization") - - # Model loading - parser.add_argument("--model-path", type=str, help="Local VoxCPM model path") - parser.add_argument( - "--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (default: openbmb/VoxCPM1.5)" - ) - parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads") - parser.add_argument("--local-files-only", action="store_true", help="Disable network access") - parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading") - parser.add_argument( - "--zipenhancer-path", type=str, help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)" - ) - - # LoRA - parser.add_argument("--lora-path", type=str, help="Path to LoRA weights") - parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)") - parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha (positive int, default: 16)") - parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (0.0–1.0, default: 0.0)") - parser.add_argument("--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers") - parser.add_argument("--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers") - parser.add_argument("--lora-enable-proj", action="store_true", help="Enable LoRA on projection layers") + _add_prompt_reference_args(parser) + _add_model_args(parser) + _add_lora_args(parser) return parser +def _dispatch_legacy(args, parser): + warn_legacy_mode() + + if args.input and args.text: + parser.error( + "Use either batch mode (--input) or single mode (--text), not both." + ) + + if args.input: + if not args.output_dir: + parser.error("Batch mode requires --output-dir") + return cmd_batch(args, parser) + + if not args.text or not args.output: + parser.error("Single-sample legacy mode requires --text and --output") + + if ( + args.prompt_audio + or args.prompt_text + or args.prompt_file + or args.reference_audio + ): + return cmd_clone(args, parser) + + return cmd_design(args, parser) + + # ----------------------------- # Entrypoint # ----------------------------- def main(): - parser = _build_unified_parser() + parser = _build_parser() args = parser.parse_args() - # Validate ranges validate_ranges(args, parser) - # Mode conflict checks - if args.input and args.text: - parser.error("Use either batch mode (--input) or single mode (--text), not both.") + if args.command == "design": + if not args.text: + parser.error("`design` requires --text") + return cmd_design(args, parser) - # Batch mode - if args.input: - if not args.output_dir: - parser.error("Batch mode requires --output-dir") - return cmd_batch(args) + if args.command == "clone": + if not args.text or not args.output: + parser.error("`clone` requires --text and --output") + return cmd_clone(args, parser) - # Single mode - if not args.text or not args.output: - parser.error("Single-sample mode requires --text and --output") + if args.command == "batch": + return cmd_batch(args, parser) - # Clone mode (prompt continuation, reference isolation, or both) - if args.prompt_audio or args.prompt_text or args.reference_audio: - return cmd_clone(args) - - # Direct synthesis - return cmd_synthesize(args) + return _dispatch_legacy(args, parser) if __name__ == "__main__": diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 0000000..509ef6d --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,512 @@ +from __future__ import annotations + +import importlib.util +import sys +import types +from pathlib import Path + +import numpy as np +import pytest + +ROOT = Path(__file__).resolve().parents[1] +CLI_PATH = ROOT / "src" / "voxcpm" / "cli.py" +V1_MODEL_PATH = ROOT / "models" / "openbmb__VoxCPM1.5" +V2_MODEL_PATH = ROOT / "models" / "VoxCPM2-1B-newaudiovae-6hz-nope-sft" + + +pkg = types.ModuleType("voxcpm") +pkg.__path__ = [str(ROOT / "src" / "voxcpm")] +sys.modules.setdefault("voxcpm", pkg) + +core_stub = types.ModuleType("voxcpm.core") + + +class StubVoxCPM: + pass + + +core_stub.VoxCPM = StubVoxCPM +sys.modules["voxcpm.core"] = core_stub + +spec = importlib.util.spec_from_file_location("voxcpm.cli", CLI_PATH) +cli = importlib.util.module_from_spec(spec) +sys.modules["voxcpm.cli"] = cli +assert spec.loader is not None +spec.loader.exec_module(cli) + + +class DummyTTSModel: + sample_rate = 16000 + + +class DummyModel: + def __init__(self): + self.tts_model = DummyTTSModel() + self.calls = [] + + def generate(self, **kwargs): + self.calls.append(kwargs) + return np.zeros(160, dtype=np.float32) + + +def run_main(monkeypatch, argv): + monkeypatch.setattr(sys, "argv", ["voxcpm", *argv]) + cli.main() + + +def test_parser_defaults_to_voxcpm2(): + parser = cli._build_parser() + args = parser.parse_args(["design", "--text", "hello", "--output", "out.wav"]) + assert args.hf_model_id == "openbmb/VoxCPM2" + assert args.no_optimize is False + + +def test_load_model_respects_no_optimize_for_local_model(monkeypatch): + calls = {} + + class FakeVoxCPM: + def __init__(self, **kwargs): + calls["kwargs"] = kwargs + self.tts_model = DummyTTSModel() + + monkeypatch.setattr(cli, "VoxCPM", FakeVoxCPM) + args = cli._build_parser().parse_args( + [ + "design", + "--text", + "hello", + "--output", + "out.wav", + "--model-path", + str(V2_MODEL_PATH), + "--no-optimize", + ] + ) + + cli.load_model(args) + + assert calls["kwargs"]["optimize"] is False + + +def test_load_model_defaults_optimize_for_hf(monkeypatch): + calls = {} + + class FakeVoxCPM: + @classmethod + def from_pretrained(cls, **kwargs): + calls["kwargs"] = kwargs + return DummyModel() + + monkeypatch.setattr(cli, "VoxCPM", FakeVoxCPM) + args = cli._build_parser().parse_args( + [ + "design", + "--text", + "hello", + "--output", + "out.wav", + ] + ) + + cli.load_model(args) + + assert calls["kwargs"]["optimize"] is True + + +def test_load_model_respects_no_optimize_for_hf(monkeypatch): + calls = {} + + class FakeVoxCPM: + @classmethod + def from_pretrained(cls, **kwargs): + calls["kwargs"] = kwargs + return DummyModel() + + monkeypatch.setattr(cli, "VoxCPM", FakeVoxCPM) + args = cli._build_parser().parse_args( + [ + "design", + "--text", + "hello", + "--output", + "out.wav", + "--no-optimize", + ] + ) + + cli.load_model(args) + + assert calls["kwargs"]["optimize"] is False + + +def test_design_subcommand_applies_control(monkeypatch, tmp_path): + dummy_model = DummyModel() + monkeypatch.setattr(cli, "load_model", lambda args: dummy_model) + monkeypatch.setattr(cli.sf, "write", lambda *args, **kwargs: None) + + run_main( + monkeypatch, + [ + "design", + "--text", + "hello", + "--control", + "warm female voice", + "--output", + str(tmp_path / "out.wav"), + ], + ) + + assert dummy_model.calls[0]["text"] == "(warm female voice)hello" + assert dummy_model.calls[0]["prompt_wav_path"] is None + assert dummy_model.calls[0]["reference_wav_path"] is None + + +def test_clone_subcommand_reads_prompt_file(monkeypatch, tmp_path): + dummy_model = DummyModel() + prompt_audio = tmp_path / "prompt.wav" + prompt_audio.write_bytes(b"RIFF") + prompt_file = tmp_path / "prompt.txt" + prompt_file.write_text("prompt transcript\n", encoding="utf-8") + + monkeypatch.setattr(cli, "load_model", lambda args: dummy_model) + monkeypatch.setattr(cli.sf, "write", lambda *args, **kwargs: None) + + run_main( + monkeypatch, + [ + "clone", + "--text", + "hello", + "--prompt-audio", + str(prompt_audio), + "--prompt-file", + str(prompt_file), + "--output", + str(tmp_path / "out.wav"), + ], + ) + + assert dummy_model.calls[0]["prompt_wav_path"] == str(prompt_audio) + assert dummy_model.calls[0]["prompt_text"] == "prompt transcript" + + +def test_clone_rejects_reference_audio_for_v1_local_model(monkeypatch, tmp_path): + reference_audio = tmp_path / "ref.wav" + reference_audio.write_bytes(b"RIFF") + monkeypatch.setattr( + sys, + "argv", + [ + "voxcpm", + "clone", + "--text", + "hello", + "--reference-audio", + str(reference_audio), + "--model-path", + str(V1_MODEL_PATH), + "--output", + str(tmp_path / "out.wav"), + ], + ) + + with pytest.raises(SystemExit): + cli.main() + + +def test_clone_rejects_reference_audio_for_v1_hf_model_id(monkeypatch, tmp_path): + reference_audio = tmp_path / "ref.wav" + reference_audio.write_bytes(b"RIFF") + monkeypatch.setattr( + sys, + "argv", + [ + "voxcpm", + "clone", + "--text", + "hello", + "--reference-audio", + str(reference_audio), + "--hf-model-id", + "openbmb/VoxCPM1.5", + "--output", + str(tmp_path / "out.wav"), + ], + ) + + with pytest.raises(SystemExit): + cli.main() + + +def test_legacy_root_args_still_work_and_warn(monkeypatch, tmp_path, capsys): + dummy_model = DummyModel() + monkeypatch.setattr(cli, "load_model", lambda args: dummy_model) + monkeypatch.setattr(cli.sf, "write", lambda *args, **kwargs: None) + + run_main( + monkeypatch, + [ + "--text", + "hello", + "--output", + str(tmp_path / "out.wav"), + ], + ) + + captured = capsys.readouterr() + assert "deprecated" in captured.err + assert dummy_model.calls[0]["text"] == "hello" + + +def test_batch_subcommand_applies_control(monkeypatch, tmp_path): + dummy_model = DummyModel() + input_file = tmp_path / "texts.txt" + input_file.write_text("hello\nworld\n", encoding="utf-8") + + monkeypatch.setattr(cli, "load_model", lambda args: dummy_model) + monkeypatch.setattr(cli.sf, "write", lambda *args, **kwargs: None) + + run_main( + monkeypatch, + [ + "batch", + "--input", + str(input_file), + "--output-dir", + str(tmp_path / "outs"), + "--control", + "calm narrator", + ], + ) + + assert [call["text"] for call in dummy_model.calls] == [ + "(calm narrator)hello", + "(calm narrator)world", + ] + + +def test_legacy_clone_with_prompt_file_still_works(monkeypatch, tmp_path, capsys): + dummy_model = DummyModel() + prompt_audio = tmp_path / "prompt.wav" + prompt_audio.write_bytes(b"RIFF") + prompt_file = tmp_path / "prompt.txt" + prompt_file.write_text("legacy transcript", encoding="utf-8") + + monkeypatch.setattr(cli, "load_model", lambda args: dummy_model) + monkeypatch.setattr(cli.sf, "write", lambda *args, **kwargs: None) + + run_main( + monkeypatch, + [ + "--text", + "hello", + "--prompt-audio", + str(prompt_audio), + "--prompt-file", + str(prompt_file), + "--output", + str(tmp_path / "out.wav"), + ], + ) + + captured = capsys.readouterr() + assert "deprecated" in captured.err + assert dummy_model.calls[0]["prompt_text"] == "legacy transcript" + + +def test_invalid_prompt_text_and_prompt_file_combination(monkeypatch, tmp_path, capsys): + prompt_audio = tmp_path / "prompt.wav" + prompt_audio.write_bytes(b"RIFF") + prompt_file = tmp_path / "prompt.txt" + prompt_file.write_text("transcript", encoding="utf-8") + + monkeypatch.setattr( + sys, + "argv", + [ + "voxcpm", + "clone", + "--text", + "hello", + "--prompt-audio", + str(prompt_audio), + "--prompt-text", + "inline transcript", + "--prompt-file", + str(prompt_file), + "--output", + str(tmp_path / "out.wav"), + ], + ) + + with pytest.raises(SystemExit): + cli.main() + + assert "Use either --prompt-text or --prompt-file" in capsys.readouterr().err + + +def test_missing_prompt_file_reports_parser_error(monkeypatch, tmp_path, capsys): + prompt_audio = tmp_path / "prompt.wav" + prompt_audio.write_bytes(b"RIFF") + monkeypatch.setattr( + sys, + "argv", + [ + "voxcpm", + "clone", + "--text", + "hello", + "--prompt-audio", + str(prompt_audio), + "--prompt-file", + str(tmp_path / "missing.txt"), + "--output", + str(tmp_path / "out.wav"), + ], + ) + + with pytest.raises(SystemExit): + cli.main() + + assert "prompt text file" in capsys.readouterr().err + + +def test_design_rejects_prompt_audio_args(monkeypatch, tmp_path, capsys): + prompt_audio = tmp_path / "prompt.wav" + prompt_audio.write_bytes(b"RIFF") + monkeypatch.setattr( + sys, + "argv", + [ + "voxcpm", + "design", + "--text", + "hello", + "--prompt-audio", + str(prompt_audio), + "--prompt-text", + "transcript", + "--output", + str(tmp_path / "out.wav"), + ], + ) + + with pytest.raises(SystemExit): + cli.main() + + assert "does not accept prompt/reference audio" in capsys.readouterr().err + + +def test_clone_rejects_prompt_audio_without_transcript(monkeypatch, tmp_path, capsys): + prompt_audio = tmp_path / "prompt.wav" + prompt_audio.write_bytes(b"RIFF") + monkeypatch.setattr( + sys, + "argv", + [ + "voxcpm", + "clone", + "--text", + "hello", + "--prompt-audio", + str(prompt_audio), + "--output", + str(tmp_path / "out.wav"), + ], + ) + + with pytest.raises(SystemExit): + cli.main() + + assert ( + "--prompt-audio requires --prompt-text or --prompt-file" + in capsys.readouterr().err + ) + + +def test_clone_rejects_transcript_without_prompt_audio(monkeypatch, tmp_path, capsys): + monkeypatch.setattr( + sys, + "argv", + [ + "voxcpm", + "clone", + "--text", + "hello", + "--prompt-text", + "transcript", + "--output", + str(tmp_path / "out.wav"), + ], + ) + + with pytest.raises(SystemExit): + cli.main() + + assert ( + "--prompt-text/--prompt-file requires --prompt-audio" in capsys.readouterr().err + ) + + +def test_batch_rejects_control_with_prompt_transcript(monkeypatch, tmp_path, capsys): + input_file = tmp_path / "texts.txt" + input_file.write_text("hello\n", encoding="utf-8") + prompt_audio = tmp_path / "prompt.wav" + prompt_audio.write_bytes(b"RIFF") + monkeypatch.setattr( + sys, + "argv", + [ + "voxcpm", + "batch", + "--input", + str(input_file), + "--output-dir", + str(tmp_path / "outs"), + "--control", + "calm narrator", + "--prompt-audio", + str(prompt_audio), + "--prompt-text", + "transcript", + ], + ) + + with pytest.raises(SystemExit): + cli.main() + + assert "--control cannot be used together" in capsys.readouterr().err + + +def test_detect_model_architecture_uses_local_configs(): + parser = cli._build_parser() + v1_args = parser.parse_args( + [ + "clone", + "--text", + "hello", + "--reference-audio", + "ref.wav", + "--model-path", + str(V1_MODEL_PATH), + "--output", + "out.wav", + ] + ) + v2_args = parser.parse_args( + [ + "clone", + "--text", + "hello", + "--reference-audio", + "ref.wav", + "--model-path", + str(V2_MODEL_PATH), + "--output", + "out.wav", + ] + ) + + assert cli.detect_model_architecture(v1_args) == "voxcpm" + assert cli.detect_model_architecture(v2_args) == "voxcpm2"