Files
VoxCPM/scripts/test_voxcpm_lora_infer.py
T

263 lines
8.6 KiB
Python
Raw Normal View History

2025-12-05 21:00:01 +08:00
#!/usr/bin/env python3
"""
LoRA inference test script.
Usage:
python scripts/test_voxcpm_lora_infer.py \
--lora_ckpt checkpoints/step_0002000 \
--text "Hello, this is LoRA finetuned result." \
--output lora_test.wav
With voice cloning:
python scripts/test_voxcpm_lora_infer.py \
--lora_ckpt checkpoints/step_0002000 \
--text "This is voice cloning result." \
--prompt_audio path/to/ref.wav \
--prompt_text "Reference audio transcript" \
--output lora_clone.wav
Note: The script reads base_model path and lora_config from lora_config.json
in the checkpoint directory (saved automatically during training).
2025-12-05 21:00:01 +08:00
"""
import argparse
import json
import sys
2025-12-05 21:00:01 +08:00
from pathlib import Path
import soundfile as sf
2025-12-05 22:22:13 +08:00
from voxcpm.core import VoxCPM
2025-12-05 21:00:01 +08:00
from voxcpm.model.voxcpm import LoRAConfig
def parse_args():
parser = argparse.ArgumentParser("VoxCPM LoRA inference test")
parser.add_argument(
"--lora_ckpt",
2025-12-05 21:00:01 +08:00
type=str,
required=True,
help="LoRA checkpoint directory (contains lora_weights.safetensors and lora_config.json)",
2025-12-05 21:00:01 +08:00
)
parser.add_argument(
"--base_model",
2025-12-05 21:00:01 +08:00
type=str,
default="",
help="Optional: override base model path (default: read from lora_config.json)",
2025-12-05 21:00:01 +08:00
)
parser.add_argument(
"--text",
type=str,
required=True,
help="Target text to synthesize",
)
parser.add_argument(
"--prompt_audio",
type=str,
default="",
help="Optional: reference audio path for voice cloning",
)
parser.add_argument(
"--prompt_text",
type=str,
default="",
help="Optional: transcript of reference audio",
)
parser.add_argument(
"--output",
type=str,
default="lora_test.wav",
help="Output wav file path",
)
parser.add_argument(
"--cfg_value",
type=float,
default=2.0,
help="CFG scale (default: 2.0)",
)
parser.add_argument(
"--inference_timesteps",
type=int,
default=10,
help="Diffusion inference steps (default: 10)",
)
parser.add_argument(
"--max_len",
type=int,
default=600,
help="Max generation steps",
)
2025-12-05 22:22:13 +08:00
parser.add_argument(
"--normalize",
action="store_true",
help="Enable text normalization",
)
2025-12-05 21:00:01 +08:00
return parser.parse_args()
def main():
args = parse_args()
# 1. Check LoRA checkpoint directory
ckpt_dir = Path(args.lora_ckpt)
if not ckpt_dir.exists():
2025-12-05 22:22:13 +08:00
raise FileNotFoundError(f"LoRA checkpoint not found: {ckpt_dir}")
# 2. Load lora_config.json from checkpoint
lora_config_path = ckpt_dir / "lora_config.json"
if not lora_config_path.exists():
raise FileNotFoundError(
f"lora_config.json not found in {ckpt_dir}. "
"Make sure the checkpoint was saved with the updated training script."
)
2026-03-31 11:50:37 +08:00
with open(lora_config_path, "r", encoding="utf-8") as f:
lora_info = json.load(f)
2026-03-31 11:50:37 +08:00
# Get base model path (command line arg overrides config)
pretrained_path = args.base_model if args.base_model else lora_info.get("base_model")
if not pretrained_path:
raise ValueError("base_model not found in lora_config.json and --base_model not provided")
2026-03-31 11:50:37 +08:00
# Get LoRA config
lora_cfg_dict = lora_info.get("lora_config", {})
lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
2026-03-31 11:50:37 +08:00
print(f"Loaded config from: {lora_config_path}", file=sys.stderr)
print(f" Base model: {pretrained_path}", file=sys.stderr)
2026-03-31 11:50:37 +08:00
print(
f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr
)
2025-12-05 22:22:13 +08:00
# 3. Load model with LoRA (no denoiser)
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}", file=sys.stderr)
print(f" LoRA weights: {ckpt_dir}", file=sys.stderr)
2025-12-05 22:22:13 +08:00
model = VoxCPM.from_pretrained(
hf_model_id=pretrained_path,
load_denoiser=False,
optimize=True,
2025-12-05 21:00:01 +08:00
lora_config=lora_cfg,
lora_weights_path=str(ckpt_dir),
2025-12-05 21:00:01 +08:00
)
# 4. Synthesize audio
2025-12-05 22:22:13 +08:00
prompt_wav_path = args.prompt_audio if args.prompt_audio else None
prompt_text = args.prompt_text if args.prompt_text else None
2025-12-05 21:00:01 +08:00
out_path = Path(args.output)
out_path.parent.mkdir(parents=True, exist_ok=True)
2026-03-31 11:50:37 +08:00
print("\n[2/2] Starting synthesis tests...", file=sys.stderr)
2025-12-05 21:00:01 +08:00
# === Test 1: With LoRA ===
2026-03-31 11:50:37 +08:00
print("\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
2025-12-05 22:22:13 +08:00
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_len=args.max_len,
2025-12-05 22:22:13 +08:00
normalize=args.normalize,
denoise=False,
)
2025-12-05 21:00:01 +08:00
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
2025-12-05 22:22:13 +08:00
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
2026-03-31 11:50:37 +08:00
print(
f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
2025-12-05 21:00:01 +08:00
# === Test 2: Disable LoRA (via set_lora_enabled) ===
2026-03-31 11:50:37 +08:00
print("\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
2025-12-05 21:00:01 +08:00
model.set_lora_enabled(False)
2025-12-05 22:22:13 +08:00
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_len=args.max_len,
2025-12-05 22:22:13 +08:00
normalize=args.normalize,
denoise=False,
)
2025-12-05 21:00:01 +08:00
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
2025-12-05 22:22:13 +08:00
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
2026-03-31 11:50:37 +08:00
print(
f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
2025-12-05 21:00:01 +08:00
# === Test 3: Re-enable LoRA ===
2026-03-31 11:50:37 +08:00
print("\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
2025-12-05 21:00:01 +08:00
model.set_lora_enabled(True)
2025-12-05 22:22:13 +08:00
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_len=args.max_len,
2025-12-05 22:22:13 +08:00
normalize=args.normalize,
denoise=False,
)
2025-12-05 21:00:01 +08:00
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
2025-12-05 22:22:13 +08:00
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
2026-03-31 11:50:37 +08:00
print(
f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
2025-12-05 21:00:01 +08:00
# === Test 4: Unload LoRA (reset_lora_weights) ===
2026-03-31 11:50:37 +08:00
print("\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
2025-12-05 22:22:13 +08:00
model.unload_lora()
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_len=args.max_len,
2025-12-05 22:22:13 +08:00
normalize=args.normalize,
denoise=False,
)
2025-12-05 21:00:01 +08:00
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
2025-12-05 22:22:13 +08:00
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
2026-03-31 11:50:37 +08:00
print(
f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
2025-12-05 21:00:01 +08:00
2025-12-05 22:22:13 +08:00
# === Test 5: Hot-reload LoRA (load_lora) ===
2026-03-31 11:50:37 +08:00
print("\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
loaded, skipped = model.load_lora(ckpt_dir)
print(f" Reloaded {len(loaded)} parameters", file=sys.stderr)
2025-12-05 22:22:13 +08:00
audio_np = model.generate(
text=args.text,
prompt_wav_path=prompt_wav_path,
prompt_text=prompt_text,
cfg_value=args.cfg_value,
inference_timesteps=args.inference_timesteps,
max_len=args.max_len,
2025-12-05 22:22:13 +08:00
normalize=args.normalize,
denoise=False,
)
2025-12-05 21:00:01 +08:00
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
2025-12-05 22:22:13 +08:00
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
2026-03-31 11:50:37 +08:00
print(
f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
file=sys.stderr,
)
2026-03-31 11:50:37 +08:00
print("\n[Done] All tests completed!", file=sys.stderr)
print(f" - with_lora: {lora_output}", file=sys.stderr)
print(f" - lora_disabled: {disabled_output}", file=sys.stderr)
print(f" - lora_reenabled: {reenabled_output}", file=sys.stderr)
print(f" - lora_reset: {reset_output}", file=sys.stderr)
print(f" - lora_reloaded: {reload_output}", file=sys.stderr)
2025-12-05 21:00:01 +08:00
if __name__ == "__main__":
main()