Modify lora inference api
This commit is contained in:
@@ -52,6 +52,22 @@ def load_model(args) -> VoxCPM:
|
||||
"ZIPENHANCER_MODEL_PATH", None
|
||||
)
|
||||
|
||||
# Build LoRA config if lora_path is provided
|
||||
lora_config = None
|
||||
lora_weights_path = getattr(args, "lora_path", None)
|
||||
if lora_weights_path:
|
||||
from voxcpm.model.voxcpm import LoRAConfig
|
||||
lora_config = LoRAConfig(
|
||||
enable_lm=getattr(args, "lora_enable_lm", True),
|
||||
enable_dit=getattr(args, "lora_enable_dit", True),
|
||||
enable_proj=getattr(args, "lora_enable_proj", False),
|
||||
r=getattr(args, "lora_r", 32),
|
||||
alpha=getattr(args, "lora_alpha", 16),
|
||||
dropout=getattr(args, "lora_dropout", 0.0),
|
||||
)
|
||||
print(f"LoRA config: r={lora_config.r}, alpha={lora_config.alpha}, "
|
||||
f"lm={lora_config.enable_lm}, dit={lora_config.enable_dit}, proj={lora_config.enable_proj}")
|
||||
|
||||
# Load from local path if provided
|
||||
if getattr(args, "model_path", None):
|
||||
try:
|
||||
@@ -59,6 +75,8 @@ def load_model(args) -> VoxCPM:
|
||||
voxcpm_model_path=args.model_path,
|
||||
zipenhancer_model_path=zipenhancer_path,
|
||||
enable_denoiser=not getattr(args, "no_denoiser", False),
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
)
|
||||
print("Model loaded (local).")
|
||||
return model
|
||||
@@ -74,6 +92,8 @@ def load_model(args) -> VoxCPM:
|
||||
zipenhancer_model_id=zipenhancer_path,
|
||||
cache_dir=getattr(args, "cache_dir", None),
|
||||
local_files_only=getattr(args, "local_files_only", False),
|
||||
lora_config=lora_config,
|
||||
lora_weights_path=lora_weights_path,
|
||||
)
|
||||
print("Model loaded (from_pretrained).")
|
||||
return model
|
||||
@@ -256,6 +276,15 @@ Examples:
|
||||
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
|
||||
parser.add_argument("--zipenhancer-path", type=str, default="iic/speech_zipenhancer_ans_multiloss_16k_base", help="ZipEnhancer model id or local path (default reads from env)")
|
||||
|
||||
# LoRA parameters
|
||||
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights (.pth file or directory containing lora_weights.ckpt)")
|
||||
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (default: 32)")
|
||||
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha scaling factor (default: 16)")
|
||||
parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (default: 0.0)")
|
||||
parser.add_argument("--lora-enable-lm", action="store_true", default=True, help="Apply LoRA to LM layers (default: True)")
|
||||
parser.add_argument("--lora-enable-dit", action="store_true", default=True, help="Apply LoRA to DiT layers (default: True)")
|
||||
parser.add_argument("--lora-enable-proj", action="store_true", default=False, help="Apply LoRA to projection layers (default: False)")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user