Files
VoxCPM/src/voxcpm/core.py
T

334 lines
14 KiB
Python
Raw Normal View History

2025-09-16 11:46:47 +08:00
import os
import sys
2025-09-18 19:23:13 +08:00
import re
2026-03-31 11:50:37 +08:00
import json
2025-09-16 11:46:47 +08:00
import tempfile
2025-09-19 16:56:11 -04:00
import numpy as np
2025-12-05 22:22:13 +08:00
from typing import Generator, Optional
2025-09-16 11:46:47 +08:00
from huggingface_hub import snapshot_download
2025-12-05 22:22:13 +08:00
from .model.voxcpm import VoxCPMModel, LoRAConfig
2026-03-31 11:50:37 +08:00
from .model.voxcpm2 import VoxCPM2Model
2025-09-16 11:46:47 +08:00
class VoxCPM:
    """High-level text-to-speech pipeline around a VoxCPM / VoxCPM2 model.

    The concrete model class is chosen from the ``architecture`` field of the
    model directory's ``config.json``. Optionally, a ZipEnhancer denoiser is
    attached (used to clean prompt/reference audio) and LoRA weights are loaded.
    """

    def __init__(
        self,
        voxcpm_model_path: str,
        zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
        enable_denoiser: bool = True,
        optimize: bool = True,
        lora_config: Optional[LoRAConfig] = None,
        lora_weights_path: Optional[str] = None,
    ):
        """Initialize VoxCPM TTS pipeline.

        Args:
            voxcpm_model_path: Local filesystem path to the VoxCPM model assets
                (weights, configs, etc.). Must contain a ``config.json``.
                Typically the directory returned by a prior download step.
            zipenhancer_model_path: ModelScope acoustic noise suppression model
                id or local path. If None, denoiser will not be initialized.
            enable_denoiser: Whether to initialize the denoiser pipeline.
            optimize: Whether to optimize the model with torch.compile. True by
                default, but can be disabled for debugging. When True, a short
                warm-up generation is run at the end of initialization.
            lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
                provided without lora_config, a default config will be created
                (enable_lm=True, enable_dit=True, enable_proj=False).
            lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
                containing lora_weights.ckpt). If provided, LoRA weights will be loaded.

        Raises:
            OSError: If ``config.json`` cannot be opened.
            ValueError: If the config's ``architecture`` is not ``voxcpm`` or ``voxcpm2``.
        """
        print(
            f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}",
            file=sys.stderr,
        )
        # LoRA weights can only be loaded into a model built with LoRA layers,
        # so synthesize a default config when only a weights path was given.
        if lora_weights_path is not None and lora_config is None:
            lora_config = LoRAConfig(
                enable_lm=True,
                enable_dit=True,
                enable_proj=False,
            )
            print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)

        self.tts_model = self._load_tts_model(voxcpm_model_path, optimize=optimize, lora_config=lora_config)

        if lora_weights_path is not None:
            print(f"Loading LoRA weights from: {lora_weights_path}", file=sys.stderr)
            loaded_keys, skipped_keys = self.tts_model.load_lora_weights(lora_weights_path)
            print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)

        # The text normalizer is created lazily on the first ``normalize=True`` call.
        self.text_normalizer = None
        self.denoiser = None
        if enable_denoiser and zipenhancer_model_path is not None:
            from .zipenhancer import ZipEnhancer

            self.denoiser = ZipEnhancer(zipenhancer_model_path)

        if optimize:
            print("Warm up VoxCPMModel...", file=sys.stderr)
            self.tts_model.generate(
                target_text="Hello, this is the first test sentence.",
                max_len=10,
            )

    @staticmethod
    def _load_tts_model(voxcpm_model_path: str, optimize: bool, lora_config: Optional[LoRAConfig]):
        """Instantiate the model class named by ``config.json``'s ``architecture`` field.

        A missing ``architecture`` defaults to ``"voxcpm"``; matching is case-insensitive.

        Raises:
            ValueError: If the architecture is not a string or is not supported.
        """
        config_path = os.path.join(voxcpm_model_path, "config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        arch = config.get("architecture", "voxcpm")
        # Guard against e.g. ``"architecture": null`` which would otherwise
        # surface as an AttributeError on ``.lower()``.
        if not isinstance(arch, str):
            raise ValueError(f"Unsupported architecture: {arch}")
        arch = arch.lower()
        if arch == "voxcpm2":
            model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
            print("Loaded VoxCPM2Model", file=sys.stderr)
        elif arch == "voxcpm":
            model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
            print("Loaded VoxCPMModel", file=sys.stderr)
        else:
            raise ValueError(f"Unsupported architecture: {arch}")
        return model

    @classmethod
    def from_pretrained(
        cls,
        hf_model_id: str = "openbmb/VoxCPM2",
        load_denoiser: bool = True,
        zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
        cache_dir: Optional[str] = None,
        local_files_only: bool = False,
        optimize: bool = True,
        lora_config: Optional[LoRAConfig] = None,
        lora_weights_path: Optional[str] = None,
        **kwargs,
    ):
        """Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot or a local directory.

        Args:
            hf_model_id: Explicit Hugging Face repository id (e.g. "org/repo") or
                local path. An existing directory is used as-is; anything else is
                resolved with ``huggingface_hub.snapshot_download``.
            load_denoiser: Whether to initialize the denoiser pipeline.
            zipenhancer_model_id: Denoiser model id or path for ModelScope
                acoustic noise suppression.
            cache_dir: Custom cache directory for the snapshot.
            local_files_only: If True, only use local files and do not attempt
                to download.
            optimize: Whether to optimize the model with torch.compile. True by
                default, but can be disabled for debugging.
            lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
                provided without lora_config, a default config will be created with
                enable_lm=True and enable_dit=True.
            lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
                containing lora_weights.ckpt). If provided, LoRA weights will be loaded
                after model initialization.

        Kwargs:
            Additional keyword arguments passed to the ``VoxCPM`` constructor.

        Returns:
            VoxCPM: Instance initialized from the resolved local model directory.

        Raises:
            ValueError: If ``hf_model_id`` is empty.
            Any error raised by ``snapshot_download`` (e.g. the repo cannot be
            found or, with ``local_files_only=True``, is not cached) propagates.
        """
        repo_id = hf_model_id
        if not repo_id:
            raise ValueError("You must provide hf_model_id")

        if os.path.isdir(repo_id):
            # A local directory is used directly; no Hub access is attempted.
            local_path = repo_id
        else:
            # Resolve from the Hub (or local cache); failures raise.
            local_path = snapshot_download(
                repo_id=repo_id,
                cache_dir=cache_dir,
                local_files_only=local_files_only,
            )

        return cls(
            voxcpm_model_path=local_path,
            zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
            enable_denoiser=load_denoiser,
            optimize=optimize,
            lora_config=lora_config,
            lora_weights_path=lora_weights_path,
            **kwargs,
        )

    def generate(self, *args, **kwargs) -> np.ndarray:
        """Synthesize speech and return the complete waveform.

        Accepts the same arguments as :meth:`_generate` except ``streaming``.

        Returns:
            numpy.ndarray: 1D waveform array on CPU.
        """
        chunks = self._generate(*args, streaming=False, **kwargs)
        try:
            return next(chunks)
        finally:
            # Close the generator now so its ``finally`` (temp-file cleanup)
            # runs deterministically instead of at garbage collection.
            chunks.close()

    def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
        """Synthesize speech, yielding waveform chunks as they are produced.

        Accepts the same arguments as :meth:`_generate` except ``streaming``.
        Argument validation happens when iteration starts, not at call time.
        """
        return self._generate(*args, streaming=True, **kwargs)

    def _denoise_to_temp(self, wav_path: str, temp_files: list) -> str:
        """Denoise ``wav_path`` into a new temporary ``.wav`` file and return its path.

        The temporary path is recorded in ``temp_files`` before enhancement
        starts, so the caller's cleanup removes it even if enhancement fails.
        """
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            temp_files.append(tmp.name)
        # The handle is closed before the denoiser writes to the path: reopening
        # a still-open named temporary file is not portable across platforms.
        self.denoiser.enhance(wav_path, output_path=temp_files[-1])
        return temp_files[-1]

    def _generate(
        self,
        text: str,
        prompt_wav_path: Optional[str] = None,
        prompt_text: Optional[str] = None,
        reference_wav_path: Optional[str] = None,
        cfg_value: float = 2.0,
        inference_timesteps: int = 10,
        min_len: int = 2,
        max_len: int = 4096,
        normalize: bool = False,
        denoise: bool = False,
        retry_badcase: bool = True,
        retry_badcase_max_times: int = 3,
        retry_badcase_ratio_threshold: float = 6.0,
        streaming: bool = False,
    ) -> Generator[np.ndarray, None, None]:
        """Synthesize speech for the given text as a generator of waveforms.

        Args:
            text: Input text to synthesize. All whitespace runs (including
                newlines) are collapsed to single spaces.
            prompt_wav_path: Path to prompt audio for continuation mode.
                Must be paired with ``prompt_text``.
            prompt_text: Text content corresponding to the prompt audio.
            reference_wav_path: Path to reference audio for voice cloning
                (structurally isolated via ref_audio tokens). Can be used
                alone or combined with ``prompt_wav_path`` + ``prompt_text``.
                Only supported by VoxCPM2 models.
            cfg_value: Guidance scale for the generation model.
            inference_timesteps: Number of inference steps.
            min_len: Minimum audio length.
            max_len: Maximum token length during generation.
            normalize: Whether to run text normalization before generation.
            denoise: Whether to denoise the prompt/reference audio if a
                denoiser is available.
            retry_badcase: Whether to retry badcase.
            retry_badcase_max_times: Maximum number of times to retry badcase.
            retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
            streaming: Whether to yield audio chunks per generation step.

        Yields:
            numpy.ndarray: 1D waveform array on CPU. With ``streaming=True`` one
            array per generation step; otherwise the underlying model is asked
            for non-streaming output (expected to be a single final array).

        Raises:
            ValueError: If ``text`` is not a non-empty string, if only one of
                ``prompt_wav_path``/``prompt_text`` is given, or if
                ``reference_wav_path`` is used with a non-VoxCPM2 model.
            FileNotFoundError: If a given audio path does not exist.
        """
        # isinstance must be checked first: ``.strip()`` on a non-str would
        # raise AttributeError instead of the intended ValueError.
        if not isinstance(text, str) or not text.strip():
            raise ValueError("target text must be a non-empty string")

        if prompt_wav_path is not None and not os.path.exists(prompt_wav_path):
            raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")

        if reference_wav_path is not None and not os.path.exists(reference_wav_path):
            raise FileNotFoundError(f"reference_wav_path does not exist: {reference_wav_path}")

        if (prompt_wav_path is None) != (prompt_text is None):
            raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")

        is_v2 = isinstance(self.tts_model, VoxCPM2Model)
        if reference_wav_path is not None and not is_v2:
            raise ValueError("reference_wav_path is only supported with VoxCPM2 models")

        # ``\s`` already covers newlines, so one substitution flattens the text.
        text = re.sub(r"\s+", " ", text)
        temp_files = []
        try:
            actual_prompt_path = prompt_wav_path
            actual_ref_path = reference_wav_path
            if denoise and self.denoiser is not None:
                if prompt_wav_path is not None:
                    actual_prompt_path = self._denoise_to_temp(prompt_wav_path, temp_files)
                if reference_wav_path is not None:
                    actual_ref_path = self._denoise_to_temp(reference_wav_path, temp_files)

            fixed_prompt_cache = None
            if actual_prompt_path is not None or actual_ref_path is not None:
                if is_v2:
                    fixed_prompt_cache = self.tts_model.build_prompt_cache(
                        prompt_text=prompt_text,
                        prompt_wav_path=actual_prompt_path,
                        reference_wav_path=actual_ref_path,
                    )
                else:
                    # VoxCPMModel's build_prompt_cache takes no reference audio.
                    fixed_prompt_cache = self.tts_model.build_prompt_cache(
                        prompt_text=prompt_text,
                        prompt_wav_path=actual_prompt_path,
                    )

            if normalize:
                if self.text_normalizer is None:
                    from .utils.text_normalize import TextNormalizer

                    self.text_normalizer = TextNormalizer()
                text = self.text_normalizer.normalize(text)

            generate_result = self.tts_model._generate_with_prompt_cache(
                target_text=text,
                prompt_cache=fixed_prompt_cache,
                min_len=min_len,
                max_len=max_len,
                inference_timesteps=inference_timesteps,
                cfg_value=cfg_value,
                retry_badcase=retry_badcase,
                retry_badcase_max_times=retry_badcase_max_times,
                retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
                streaming=streaming,
            )
            # Only the waveform (first tuple element) is exposed; the leading
            # dimension is squeezed away and the tensor is moved to CPU numpy.
            for wav, _, _ in generate_result:
                yield wav.squeeze(0).cpu().numpy()

        finally:
            # Best-effort removal of denoised temp files; a failure to delete
            # must not mask the synthesis result or the original exception.
            for tmp_path in temp_files:
                if tmp_path and os.path.exists(tmp_path):
                    try:
                        os.unlink(tmp_path)
                    except OSError:
                        pass

    # ------------------------------------------------------------------ #
    # LoRA Interface (delegated to the underlying model)
    # ------------------------------------------------------------------ #
    def load_lora(self, lora_weights_path: str) -> tuple:
        """Load LoRA weights from a checkpoint file.

        Args:
            lora_weights_path: Path to LoRA weights (.pth file or directory
                containing lora_weights.ckpt).

        Returns:
            tuple: (loaded_keys, skipped_keys) - lists of loaded and skipped parameter names.

        Raises:
            RuntimeError: If model was not initialized with LoRA config.
        """
        if self.tts_model.lora_config is None:
            raise RuntimeError(
                "Cannot load LoRA weights: model was not initialized with LoRA config. "
                "Please reinitialize with lora_config or lora_weights_path parameter."
            )
        return self.tts_model.load_lora_weights(lora_weights_path)

    def unload_lora(self):
        """Unload LoRA by resetting all LoRA weights to initial state (effectively disabling LoRA)."""
        self.tts_model.reset_lora_weights()

    def set_lora_enabled(self, enabled: bool):
        """Enable or disable LoRA layers without unloading weights.

        Args:
            enabled: If True, LoRA layers are active; if False, only base model is used.
        """
        self.tts_model.set_lora_enabled(enabled)

    def get_lora_state_dict(self) -> dict:
        """Get current LoRA parameters state dict.

        Returns:
            dict: State dict containing all LoRA parameters (lora_A, lora_B).
        """
        return self.tts_model.get_lora_state_dict()

    @property
    def lora_enabled(self) -> bool:
        """Return True if the underlying model was configured with LoRA.

        This only checks ``tts_model.lora_config``; it does not query the
        on/off toggle controlled by :meth:`set_lora_enabled`.
        """
        return self.tts_model.lora_config is not None