From 5611bd08a03503319be3ce44b3a8c4a6ffb2b53f Mon Sep 17 00:00:00 2001 From: Labmem-Zhouyx <913703649@qq.com> Date: Thu, 9 Apr 2026 00:30:19 +0800 Subject: [PATCH] optim app.py --- README.md | 4 ++-- README_zh.md | 2 +- app.py | 44 +++++++++++--------------------------------- 3 files changed, 14 insertions(+), 36 deletions(-) diff --git a/README.md b/README.md index 1bebc7e..2eb8224 100644 --- a/README.md +++ b/README.md @@ -238,8 +238,8 @@ voxcpm --help ### Web Demo -```bash -python app.py --model-dir /path/to/VoxCPM2 --port 8808 # use a local model directory, open http://localhost:8808 +```bash +python app.py --port 8808 # then open in browser: http://localhost:8808 ``` ### 🚢 Production Deployment (Nano-vLLM) diff --git a/README_zh.md b/README_zh.md index 907812d..b990bb5 100644 --- a/README_zh.md +++ b/README_zh.md @@ -238,7 +238,7 @@ voxcpm --help ### Web Demo ```bash -python app.py --model-dir /path/to/VoxCPM2 --port 8808 # 指定本地模型路径,然后打开 http://localhost:8808 +python app.py --port 8808 # 然后在浏览器打开 http://localhost:8808 ``` ### 🚢 生产部署(Nano-vLLM) diff --git a/app.py b/app.py index 1503b9d..ac008bc 100644 --- a/app.py +++ b/app.py @@ -9,8 +9,6 @@ from funasr import AutoModel from pathlib import Path os.environ["TOKENIZERS_PARALLELISM"] = "false" -if os.environ.get("HF_REPO_ID", "").strip() == "": - os.environ["HF_REPO_ID"] = "openbmb/VoxCPM2" import voxcpm @@ -221,7 +219,7 @@ _APP_THEME = gr.themes.Soft( # ---------- Model ---------- class VoxCPMDemo: - def __init__(self, model_dir: Optional[str] = None) -> None: + def __init__(self, model_id: str = "openbmb/VoxCPM2") -> None: self.device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Running on device: {self.device}") @@ -234,36 +232,13 @@ class VoxCPMDemo: ) self.voxcpm_model: Optional[voxcpm.VoxCPM] = None - self.explicit_model_dir = model_dir - - def _resolve_model_dir(self) -> str: - if self.explicit_model_dir and os.path.isdir(self.explicit_model_dir): - return self.explicit_model_dir - env_model_dir = os.environ.get("VOXCPM_MODEL_DIR", "").strip() - if env_model_dir and os.path.isdir(env_model_dir): - return env_model_dir - repo_id = os.environ.get("HF_REPO_ID", "").strip() - if len(repo_id) > 0: - target_dir = os.path.join("models", repo_id.replace("/", "__")) - if not os.path.isdir(target_dir): - try: - from huggingface_hub import snapshot_download - os.makedirs(target_dir, exist_ok=True) - logger.info(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...") - snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False) - except Exception as e: - logger.warning(f"HF download failed: {e}. Falling back to 'models'.") - return "models" - return target_dir - return "models" + self._model_id = model_id def get_or_load_voxcpm(self) -> voxcpm.VoxCPM: if self.voxcpm_model is not None: return self.voxcpm_model - logger.info("Model not loaded, initializing...") - model_dir = self._resolve_model_dir() - logger.info(f"Using model dir: {model_dir}") - self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir, optimize=True) + logger.info(f"Loading model: {self._model_id}") + self.voxcpm_model = voxcpm.VoxCPM.from_pretrained(self._model_id, optimize=True) logger.info("Model loaded successfully.") return self.voxcpm_model @@ -507,9 +482,9 @@ def run_demo( server_name: str = "0.0.0.0", server_port: int = 8808, show_error: bool = True, - model_dir: Optional[str] = None, + model_id: str = "openbmb/VoxCPM2", ): - demo = VoxCPMDemo(model_dir=model_dir) + demo = VoxCPMDemo(model_id=model_id) interface = create_demo_interface(demo) interface.queue(max_size=10, default_concurrency_limit=1).launch( server_name=server_name, @@ -524,7 +499,10 @@ def run_demo( if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() - parser.add_argument("--model-dir", type=str, default=None, help="Path to VoxCPM2 checkpoint directory") + parser.add_argument( + "--model-id", type=str, default="openbmb/VoxCPM2", + help="Local path or HuggingFace repo ID (default: openbmb/VoxCPM2)", + ) parser.add_argument("--port", type=int, default=8808, help="Server port") args = parser.parse_args() - run_demo(model_dir=args.model_dir, server_port=args.port) + run_demo(model_id=args.model_id, server_port=args.port)