fix: handle LoRA rank mismatch during inference in lora_ft_webui

Pass the selected LoRA checkpoint to load_model() on first load so the
model initializes with the correct rank from lora_config.json instead of
always defaulting to r=32.

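On subsequent LoRA hot-swaps, compare the rank (r) in the currently
loaded model's LoRA config with the rank in the newly selected
checkpoint's config. If the two differ, automatically reload the model
with the new checkpoint's config instead of hot-loading its weights,
preventing tensor shape mismatch errors (fixes #283).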

Made-with: Cursor
This commit is contained in:
liuxin
2026-04-28 10:52:57 +08:00
parent 86bff0fc82
commit 19b6bf7590
+25 -4
View File
@@ -281,27 +281,48 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr) print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
# 加载模型 # 加载模型
lora_to_load = lora_selection if lora_selection and lora_selection != "None" else None
try: try:
print(f"Loading base model: {base_model_path}", file=sys.stderr) print(f"Loading base model: {base_model_path}", file=sys.stderr)
load_model(base_model_path) load_model(base_model_path, lora_to_load)
if lora_selection and lora_selection != "None": if lora_to_load:
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr) print(f"Model loaded with LoRA: {lora_selection}", file=sys.stderr)
except Exception as e: except Exception as e:
error_msg = f"Failed to load model from {base_model_path}: {str(e)}" error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
print(error_msg, file=sys.stderr) print(error_msg, file=sys.stderr)
return None, error_msg return None, error_msg
lora_just_loaded = lora_to_load
else:
lora_just_loaded = None
# Handle LoRA hot-swapping # Handle LoRA hot-swapping
assert current_model is not None, "Model must be loaded before inference" assert current_model is not None, "Model must be loaded before inference"
if lora_selection and lora_selection != "None": if lora_selection and lora_selection != "None":
full_lora_path = os.path.join("lora", lora_selection) full_lora_path = os.path.join("lora", lora_selection)
if lora_just_loaded != lora_selection:
new_lora_config, new_base_model = load_lora_config_from_checkpoint(full_lora_path)
current_r = current_model.tts_model.lora_config.r if current_model.tts_model.lora_config else None
new_r = new_lora_config.r if new_lora_config else None
if new_r is not None and current_r is not None and new_r != current_r:
print(f"LoRA rank mismatch (model r={current_r}, checkpoint r={new_r}), reloading...", file=sys.stderr)
reload_base = (
new_base_model if new_base_model and os.path.exists(new_base_model)
else (pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path)
)
try:
load_model(reload_base, lora_selection)
except Exception as e:
return None, f"Failed to reload model for LoRA rank change: {e}"
else:
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr) print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
try: try:
current_model.load_lora(full_lora_path) current_model.load_lora(full_lora_path)
current_model.set_lora_enabled(True)
except Exception as e: except Exception as e:
print(f"Error loading LoRA: {e}", file=sys.stderr) print(f"Error loading LoRA: {e}", file=sys.stderr)
return None, f"Error loading LoRA: {e}" return None, f"Error loading LoRA: {e}"
current_model.set_lora_enabled(True)
else: else:
print("Disabling LoRA", file=sys.stderr) print("Disabling LoRA", file=sys.stderr)
current_model.set_lora_enabled(False) current_model.set_lora_enabled(False)