update finetuning pipeline and runtime device handling
Support optional ref_audio samples in finetuning and make runtime device selection explicit while keeping auto fallback behavior consistent. Also ignore the local app override file to avoid accidental commits. Made-with: Cursor
This commit is contained in:
+21
-2
@@ -17,6 +17,7 @@ class VoxCPM:
|
||||
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||
enable_denoiser: bool = True,
|
||||
optimize: bool = True,
|
||||
device: str | None = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
):
|
||||
@@ -30,6 +31,9 @@ class VoxCPM:
|
||||
id or local path. If None, denoiser will not be initialized.
|
||||
enable_denoiser: Whether to initialize the denoiser pipeline.
|
||||
optimize: Whether to optimize the model with torch.compile. True by default, but can be disabled for debugging.
|
||||
device: Runtime device. If set to ``None`` or ``"auto"``, VoxCPM
|
||||
will choose automatically (preferring CUDA, then MPS, then CPU).
|
||||
If set explicitly, that device is used or a clear error is raised.
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
provided without lora_config, a default config will be created.
|
||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||
@@ -56,10 +60,20 @@ class VoxCPM:
|
||||
arch = config.get("architecture", "voxcpm").lower()
|
||||
|
||||
if arch == "voxcpm2":
|
||||
self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||
self.tts_model = VoxCPM2Model.from_local(
|
||||
voxcpm_model_path,
|
||||
optimize=optimize,
|
||||
device=device,
|
||||
lora_config=lora_config,
|
||||
)
|
||||
print("Loaded VoxCPM2Model", file=sys.stderr)
|
||||
elif arch == "voxcpm":
|
||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||
self.tts_model = VoxCPMModel.from_local(
|
||||
voxcpm_model_path,
|
||||
optimize=optimize,
|
||||
device=device,
|
||||
lora_config=lora_config,
|
||||
)
|
||||
print("Loaded VoxCPMModel", file=sys.stderr)
|
||||
else:
|
||||
raise ValueError(f"Unsupported architecture: {arch}")
|
||||
@@ -94,6 +108,7 @@ class VoxCPM:
|
||||
cache_dir: str = None,
|
||||
local_files_only: bool = False,
|
||||
optimize: bool = True,
|
||||
device: str | None = None,
|
||||
lora_config: Optional[LoRAConfig] = None,
|
||||
lora_weights_path: Optional[str] = None,
|
||||
**kwargs,
|
||||
@@ -109,6 +124,9 @@ class VoxCPM:
|
||||
cache_dir: Custom cache directory for the snapshot.
|
||||
local_files_only: If True, only use local files and do not attempt
|
||||
to download.
|
||||
device: Runtime device. Use ``None``/``"auto"`` for automatic
|
||||
fallback, or an explicit value such as ``"cpu"``, ``"mps"``,
|
||||
``"cuda"``, or ``"cuda:0"``.
|
||||
lora_config: LoRA configuration for fine-tuning. If lora_weights_path is
|
||||
provided without lora_config, a default config will be created with
|
||||
enable_lm=True and enable_dit=True.
|
||||
@@ -146,6 +164,7 @@ class VoxCPM:
|
||||
zipenhancer_model_path=zipenhancer_model_id if load_denoiser else None,
|
||||
enable_denoiser=load_denoiser,
|
||||
optimize=optimize,
|
||||
device=device,
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
**kwargs,
|
||||
|
||||
Reference in New Issue
Block a user