diff --git a/lora_ft_webui.py b/lora_ft_webui.py index 439261b..e4a6822 100644 --- a/lora_ft_webui.py +++ b/lora_ft_webui.py @@ -281,27 +281,48 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step print(f"Warning: Failed to read base_model from LoRA config: {e}", file=sys.stderr) # 加载模型 + lora_to_load = lora_selection if lora_selection and lora_selection != "None" else None try: print(f"Loading base model: {base_model_path}", file=sys.stderr) - load_model(base_model_path) - if lora_selection and lora_selection != "None": - print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr) + load_model(base_model_path, lora_to_load) + if lora_to_load: + print(f"Model loaded with LoRA: {lora_selection}", file=sys.stderr) except Exception as e: error_msg = f"Failed to load model from {base_model_path}: {str(e)}" print(error_msg, file=sys.stderr) return None, error_msg + lora_just_loaded = lora_to_load + else: + lora_just_loaded = None # Handle LoRA hot-swapping assert current_model is not None, "Model must be loaded before inference" if lora_selection and lora_selection != "None": full_lora_path = os.path.join("lora", lora_selection) - print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr) - try: - current_model.load_lora(full_lora_path) - current_model.set_lora_enabled(True) - except Exception as e: - print(f"Error loading LoRA: {e}", file=sys.stderr) - return None, f"Error loading LoRA: {e}" + + if lora_just_loaded != lora_selection: + new_lora_config, new_base_model = load_lora_config_from_checkpoint(full_lora_path) + current_r = current_model.tts_model.lora_config.r if current_model.tts_model.lora_config else None + new_r = new_lora_config.r if new_lora_config else None + + if new_r is not None and current_r is not None and new_r != current_r: + print(f"LoRA rank mismatch (model r={current_r}, checkpoint r={new_r}), reloading...", file=sys.stderr) + reload_base = ( + new_base_model if new_base_model and os.path.exists(new_base_model) + else (pretrained_path if pretrained_path and pretrained_path.strip() else default_pretrained_path) + ) + try: + load_model(reload_base, lora_selection) + except Exception as e: + return None, f"Failed to reload model for LoRA rank change: {e}" + else: + print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr) + try: + current_model.load_lora(full_lora_path) + except Exception as e: + print(f"Error loading LoRA: {e}", file=sys.stderr) + return None, f"Error loading LoRA: {e}" + current_model.set_lora_enabled(True) else: print("Disabling LoRA", file=sys.stderr) current_model.set_lora_enabled(False)