fix: handle LoRA rank mismatch during inference in lora_ft_webui
Pass the selected LoRA checkpoint to load_model() on first load so that the model is initialized with the correct rank from lora_config.json instead of always defaulting to r=32. On subsequent LoRA hot-swaps, detect a rank mismatch between the loaded model and the new checkpoint and automatically reload the model with the new checkpoint's config, preventing tensor shape mismatch errors (fixes #283).

Made-with: Cursor
This commit is contained in:
+31
-10
@@ -281,27 +281,48 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
||||
print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
|
||||
|
||||
# 加载模型
|
||||
lora_to_load = lora_selection if lora_selection and lora_selection != "None" else None
|
||||
try:
|
||||
print(f"Loading base model: {base_model_path}", file=sys.stderr)
|
||||
load_model(base_model_path)
|
||||
if lora_selection and lora_selection != "None":
|
||||
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
|
||||
load_model(base_model_path, lora_to_load)
|
||||
if lora_to_load:
|
||||
print(f"Model loaded with LoRA: {lora_selection}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
|
||||
print(error_msg, file=sys.stderr)
|
||||
return None, error_msg
|
||||
lora_just_loaded = lora_to_load
|
||||
else:
|
||||
lora_just_loaded = None
|
||||
|
||||
# Handle LoRA hot-swapping
|
||||
assert current_model is not None, "Model must be loaded before inference"
|
||||
if lora_selection and lora_selection != "None":
|
||||
full_lora_path = os.path.join("lora", lora_selection)
|
||||
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
|
||||
try:
|
||||
current_model.load_lora(full_lora_path)
|
||||
current_model.set_lora_enabled(True)
|
||||
except Exception as e:
|
||||
print(f"Error loading LoRA: {e}", file=sys.stderr)
|
||||
return None, f"Error loading LoRA: {e}"
|
||||
|
||||
if lora_just_loaded != lora_selection:
|
||||
new_lora_config, new_base_model = load_lora_config_from_checkpoint(full_lora_path)
|
||||
current_r = current_model.tts_model.lora_config.r if current_model.tts_model.lora_config else None
|
||||
new_r = new_lora_config.r if new_lora_config else None
|
||||
|
||||
if new_r is not None and current_r is not None and new_r != current_r:
|
||||
print(f"LoRA rank mismatch (model r={current_r}, checkpoint r={new_r}), reloading...", file=sys.stderr)
|
||||
reload_base = (
|
||||
new_base_model if new_base_model and os.path.exists(new_base_model)
|
||||
else (pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path)
|
||||
)
|
||||
try:
|
||||
load_model(reload_base, lora_selection)
|
||||
except Exception as e:
|
||||
return None, f"Failed to reload model for LoRA rank change: {e}"
|
||||
else:
|
||||
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
|
||||
try:
|
||||
current_model.load_lora(full_lora_path)
|
||||
except Exception as e:
|
||||
print(f"Error loading LoRA: {e}", file=sys.stderr)
|
||||
return None, f"Error loading LoRA: {e}"
|
||||
current_model.set_lora_enabled(True)
|
||||
else:
|
||||
print("Disabling LoRA", file=sys.stderr)
|
||||
current_model.set_lora_enabled(False)
|
||||
|
||||
Reference in New Issue
Block a user