fix: handle LoRA rank mismatch during inference in lora_ft_webui
Pass the selected LoRA checkpoint to load_model() on first load so the model initializes with the correct rank from lora_config.json instead of always defaulting to r=32. On subsequent LoRA hot-swaps, detect rank incompatibility and automatically reload the model with the new checkpoint's config, preventing tensor shape mismatch errors (fixes #283). Made-with: Cursor
This commit is contained in:
+25
-4
@@ -281,27 +281,48 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
|||||||
print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
|
print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr)
|
||||||
|
|
||||||
# 加载模型
|
# 加载模型
|
||||||
|
lora_to_load = lora_selection if lora_selection and lora_selection != "None" else None
|
||||||
try:
|
try:
|
||||||
print(f"Loading base model: {base_model_path}", file=sys.stderr)
|
print(f"Loading base model: {base_model_path}", file=sys.stderr)
|
||||||
load_model(base_model_path)
|
load_model(base_model_path, lora_to_load)
|
||||||
if lora_selection and lora_selection != "None":
|
if lora_to_load:
|
||||||
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
|
print(f"Model loaded with LoRA: {lora_selection}", file=sys.stderr)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
|
error_msg = f"Failed to load model from {base_model_path}: {str(e)}"
|
||||||
print(error_msg, file=sys.stderr)
|
print(error_msg, file=sys.stderr)
|
||||||
return None, error_msg
|
return None, error_msg
|
||||||
|
lora_just_loaded = lora_to_load
|
||||||
|
else:
|
||||||
|
lora_just_loaded = None
|
||||||
|
|
||||||
# Handle LoRA hot-swapping
|
# Handle LoRA hot-swapping
|
||||||
assert current_model is not None, "Model must be loaded before inference"
|
assert current_model is not None, "Model must be loaded before inference"
|
||||||
if lora_selection and lora_selection != "None":
|
if lora_selection and lora_selection != "None":
|
||||||
full_lora_path = os.path.join("lora", lora_selection)
|
full_lora_path = os.path.join("lora", lora_selection)
|
||||||
|
|
||||||
|
if lora_just_loaded != lora_selection:
|
||||||
|
new_lora_config, new_base_model = load_lora_config_from_checkpoint(full_lora_path)
|
||||||
|
current_r = current_model.tts_model.lora_config.r if current_model.tts_model.lora_config else None
|
||||||
|
new_r = new_lora_config.r if new_lora_config else None
|
||||||
|
|
||||||
|
if new_r is not None and current_r is not None and new_r != current_r:
|
||||||
|
print(f"LoRA rank mismatch (model r={current_r}, checkpoint r={new_r}), reloading...", file=sys.stderr)
|
||||||
|
reload_base = (
|
||||||
|
new_base_model if new_base_model and os.path.exists(new_base_model)
|
||||||
|
else (pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
load_model(reload_base, lora_selection)
|
||||||
|
except Exception as e:
|
||||||
|
return None, f"Failed to reload model for LoRA rank change: {e}"
|
||||||
|
else:
|
||||||
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
|
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
|
||||||
try:
|
try:
|
||||||
current_model.load_lora(full_lora_path)
|
current_model.load_lora(full_lora_path)
|
||||||
current_model.set_lora_enabled(True)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error loading LoRA: {e}", file=sys.stderr)
|
print(f"Error loading LoRA: {e}", file=sys.stderr)
|
||||||
return None, f"Error loading LoRA: {e}"
|
return None, f"Error loading LoRA: {e}"
|
||||||
|
current_model.set_lora_enabled(True)
|
||||||
else:
|
else:
|
||||||
print("Disabling LoRA", file=sys.stderr)
|
print("Disabling LoRA", file=sys.stderr)
|
||||||
current_model.set_lora_enabled(False)
|
current_model.set_lora_enabled(False)
|
||||||
|
|||||||
Reference in New Issue
Block a user