update finetuning pipeline and runtime device handling

Support optional ref_audio samples in finetuning and make runtime device selection explicit while keeping auto fallback behavior consistent. Also ignore the local app override file to avoid accidental commits.

Made-with: Cursor
This commit is contained in:
刘鑫
2026-04-11 11:08:50 +08:00
parent abf01b9bf3
commit e4e049624c
10 changed files with 379 additions and 47 deletions
+21 -2
View File
@@ -17,6 +17,7 @@ class VoxCPM:
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
enable_denoiser: bool = True,
optimize: bool = True,
device: str | None = None,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
):
@@ -30,6 +31,9 @@ class VoxCPM:
id or local path. If None, denoiser will not be initialized.
enable_denoiser: Whether to initialize the denoiser pipeline.
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
device: Runtime device. If set to ``None`` or ``"auto"``, VoxCPM
will choose automatically (preferring CUDA, then MPS, then CPU).
If set explicitly, that device is used or a clear error is raised.
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
provided without lora_config, a default config will be created.
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
@@ -56,10 +60,20 @@ class VoxCPM:
arch = config.get("architecture", "voxcpm").lower()
if arch == "voxcpm2":
self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
self.tts_model = VoxCPM2Model.from_local(
voxcpm_model_path,
optimize=optimize,
device=device,
lora_config=lora_config,
)
print("Loaded VoxCPM2Model", file=sys.stderr)
elif arch == "voxcpm":
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
self.tts_model = VoxCPMModel.from_local(
voxcpm_model_path,
optimize=optimize,
device=device,
lora_config=lora_config,
)
print("Loaded VoxCPMModel", file=sys.stderr)
else:
raise ValueError(f"Unsupported architecture: {arch}")
@@ -94,6 +108,7 @@ class VoxCPM:
cache_dir: str = None,
local_files_only: bool = False,
optimize: bool = True,
device: str | None = None,
lora_config: Optional[LoRAConfig] = None,
lora_weights_path: Optional[str] = None,
**kwargs,
@@ -109,6 +124,9 @@ class VoxCPM:
cache_dir: Custom cache directory for the snapshot.
local_files_only: If True, only use local files and do not attempt
to download.
device: Runtime device. Use ``None``/``"auto"`` for automatic
fallback, or an explicit value such as ``"cpu"``, ``"mps"``,
``"cuda"``, or ``"cuda:0"``.
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
provided without lora_config, a default config will be created with
enable_lm=True and enable_dit=True.
@@ -146,6 +164,7 @@ class VoxCPM:
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
enable_denoiser=load_denoiser,
optimize=optimize,
device=device,
lora_config=lora_config,
lora_weights_path=lora_weights_path,
**kwargs,