update voxcpm2
This commit is contained in:
@@ -112,22 +112,24 @@ def main():
|
||||
f"lora_config.json not found in {ckpt_dir}. "
|
||||
"Make sure the checkpoint was saved with the updated training script."
|
||||
)
|
||||
|
||||
|
||||
with open(lora_config_path, "r", encoding="utf-8") as f:
|
||||
lora_info = json.load(f)
|
||||
|
||||
|
||||
# Get base model path (command line arg overrides config)
|
||||
pretrained_path = args.base_model if args.base_model else lora_info.get("base_model")
|
||||
if not pretrained_path:
|
||||
raise ValueError("base_model not found in lora_config.json and --base_model not provided")
|
||||
|
||||
|
||||
# Get LoRA config
|
||||
lora_cfg_dict = lora_info.get("lora_config", {})
|
||||
lora_cfg = LoRAConfig(**lora_cfg_dict) if lora_cfg_dict else None
|
||||
|
||||
|
||||
print(f"Loaded config from: {lora_config_path}", file=sys.stderr)
|
||||
print(f" Base model: {pretrained_path}", file=sys.stderr)
|
||||
print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr)
|
||||
print(
|
||||
f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr
|
||||
)
|
||||
|
||||
# 3. Load model with LoRA (no denoiser)
|
||||
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}", file=sys.stderr)
|
||||
@@ -146,10 +148,10 @@ def main():
|
||||
out_path = Path(args.output)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(f"\n[2/2] Starting synthesis tests...", file=sys.stderr)
|
||||
|
||||
print("\n[2/2] Starting synthesis tests...", file=sys.stderr)
|
||||
|
||||
# === Test 1: With LoRA ===
|
||||
print(f"\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
|
||||
print("\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
prompt_wav_path=prompt_wav_path,
|
||||
@@ -162,10 +164,13 @@ def main():
|
||||
)
|
||||
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
|
||||
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
print(
|
||||
f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# === Test 2: Disable LoRA (via set_lora_enabled) ===
|
||||
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
|
||||
print("\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
|
||||
model.set_lora_enabled(False)
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
@@ -179,10 +184,13 @@ def main():
|
||||
)
|
||||
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
|
||||
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
print(
|
||||
f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# === Test 3: Re-enable LoRA ===
|
||||
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
|
||||
print("\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
|
||||
model.set_lora_enabled(True)
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
@@ -196,10 +204,13 @@ def main():
|
||||
)
|
||||
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
|
||||
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
print(
|
||||
f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# === Test 4: Unload LoRA (reset_lora_weights) ===
|
||||
print(f"\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
|
||||
print("\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
|
||||
model.unload_lora()
|
||||
audio_np = model.generate(
|
||||
text=args.text,
|
||||
@@ -213,10 +224,13 @@ def main():
|
||||
)
|
||||
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
|
||||
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
print(
|
||||
f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# === Test 5: Hot-reload LoRA (load_lora) ===
|
||||
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
|
||||
print("\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
|
||||
loaded, skipped = model.load_lora(ckpt_dir)
|
||||
print(f" Reloaded {len(loaded)} parameters", file=sys.stderr)
|
||||
audio_np = model.generate(
|
||||
@@ -231,9 +245,12 @@ def main():
|
||||
)
|
||||
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
|
||||
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
|
||||
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
||||
print(
|
||||
f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
print(f"\n[Done] All tests completed!", file=sys.stderr)
|
||||
print("\n[Done] All tests completed!", file=sys.stderr)
|
||||
print(f" - with_lora: {lora_output}", file=sys.stderr)
|
||||
print(f" - lora_disabled: {disabled_output}", file=sys.stderr)
|
||||
print(f" - lora_reenabled: {reenabled_output}", file=sys.stderr)
|
||||
|
||||
Reference in New Issue
Block a user