Harden LoRA checkpoint loading against untrusted pickle payloads
LoRA is a first-class workflow in VoxCPM, and the project already prefers safetensors plus weights-only fallback loading for base model artifacts. The legacy LoRA .ckpt/.pth path was the remaining place that still deserialized arbitrary pickle objects, so this switches it to weights_only=True and adds focused regression coverage for both model loaders. Constraint: Must preserve compatibility with tensor-only legacy LoRA checkpoints Rejected: Remove .ckpt/.pth support entirely | too disruptive for existing users Confidence: high Scope-risk: narrow Reversibility: clean Directive: Keep LoRA artifact handling aligned with the existing safetensors-first, weights-only loading pattern Tested: python3 -m pytest -q tests/test_lora_checkpoint_loading.py tests/test_model_utils.py -q Not-tested: Full end-to-end LoRA hot-load with heavyweight model assets
This commit is contained in:
@@ -957,7 +957,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
||||||
state_dict = load_file(str(safetensors_file), device=device)
|
state_dict = load_file(str(safetensors_file), device=device)
|
||||||
elif ckpt_file and ckpt_file.exists():
|
elif ckpt_file and ckpt_file.exists():
|
||||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
|
ckpt = torch.load(ckpt_file, map_location=device, weights_only=True)
|
||||||
state_dict = ckpt.get("state_dict", ckpt)
|
state_dict = ckpt.get("state_dict", ckpt)
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
|
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
|
||||||
|
|||||||
@@ -1208,7 +1208,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
||||||
state_dict = load_file(str(safetensors_file), device=device)
|
state_dict = load_file(str(safetensors_file), device=device)
|
||||||
elif ckpt_file and ckpt_file.exists():
|
elif ckpt_file and ckpt_file.exists():
|
||||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
|
ckpt = torch.load(ckpt_file, map_location=device, weights_only=True)
|
||||||
state_dict = ckpt.get("state_dict", ckpt)
|
state_dict = ckpt.get("state_dict", ckpt)
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
|
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
|
||||||
|
|||||||
@@ -0,0 +1,150 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
SRC = ROOT / "src"
|
||||||
|
|
||||||
|
|
||||||
|
def _load_module(name: str, path: Path):
|
||||||
|
spec = importlib.util.spec_from_file_location(name, path)
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
assert spec.loader is not None
|
||||||
|
sys.modules[name] = module
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
def bootstrap_repo_modules(monkeypatch):
|
||||||
|
for name, path in [
|
||||||
|
("voxcpm", SRC / "voxcpm"),
|
||||||
|
("voxcpm.model", SRC / "voxcpm" / "model"),
|
||||||
|
("voxcpm.modules", SRC / "voxcpm" / "modules"),
|
||||||
|
]:
|
||||||
|
pkg = types.ModuleType(name)
|
||||||
|
pkg.__path__ = [str(path)]
|
||||||
|
monkeypatch.setitem(sys.modules, name, pkg)
|
||||||
|
|
||||||
|
hh = types.ModuleType("huggingface_hub")
|
||||||
|
hh.snapshot_download = lambda *a, **k: "/tmp/fake"
|
||||||
|
monkeypatch.setitem(sys.modules, "huggingface_hub", hh)
|
||||||
|
|
||||||
|
pydantic = types.ModuleType("pydantic")
|
||||||
|
|
||||||
|
class BaseModel:
|
||||||
|
@classmethod
|
||||||
|
def model_rebuild(cls):
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def model_validate_json(cls, s):
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
def model_dump(self):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
pydantic.BaseModel = BaseModel
|
||||||
|
monkeypatch.setitem(sys.modules, "pydantic", pydantic)
|
||||||
|
|
||||||
|
torchaudio = types.ModuleType("torchaudio")
|
||||||
|
monkeypatch.setitem(sys.modules, "torchaudio", torchaudio)
|
||||||
|
|
||||||
|
librosa = types.ModuleType("librosa")
|
||||||
|
librosa.effects = types.SimpleNamespace(trim=lambda *a, **k: (None, (0, 0)))
|
||||||
|
monkeypatch.setitem(sys.modules, "librosa", librosa)
|
||||||
|
|
||||||
|
einops = types.ModuleType("einops")
|
||||||
|
einops.rearrange = lambda x, *a, **k: x
|
||||||
|
monkeypatch.setitem(sys.modules, "einops", einops)
|
||||||
|
|
||||||
|
tqdm_pkg = types.ModuleType("tqdm")
|
||||||
|
tqdm_pkg.__path__ = ["/nonexistent"]
|
||||||
|
tqdm_pkg.tqdm = lambda x, *a, **k: x
|
||||||
|
monkeypatch.setitem(sys.modules, "tqdm", tqdm_pkg)
|
||||||
|
|
||||||
|
tqdm_auto = types.ModuleType("tqdm.auto")
|
||||||
|
tqdm_auto.tqdm = lambda x, *a, **k: x
|
||||||
|
monkeypatch.setitem(sys.modules, "tqdm.auto", tqdm_auto)
|
||||||
|
|
||||||
|
transformers = types.ModuleType("transformers")
|
||||||
|
|
||||||
|
class LlamaTokenizerFast:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class PreTrainedTokenizer:
|
||||||
|
pass
|
||||||
|
|
||||||
|
transformers.LlamaTokenizerFast = LlamaTokenizerFast
|
||||||
|
transformers.PreTrainedTokenizer = PreTrainedTokenizer
|
||||||
|
monkeypatch.setitem(sys.modules, "transformers", transformers)
|
||||||
|
|
||||||
|
internal_mods = {
|
||||||
|
"voxcpm.modules.audiovae": ["AudioVAE", "AudioVAEConfig", "AudioVAEV2", "AudioVAEConfigV2"],
|
||||||
|
"voxcpm.modules.layers": ["ScalarQuantizationLayer"],
|
||||||
|
"voxcpm.modules.locdit": ["CfmConfig", "UnifiedCFM", "VoxCPMLocDiT", "VoxCPMLocDiTV2"],
|
||||||
|
"voxcpm.modules.locenc": ["VoxCPMLocEnc"],
|
||||||
|
"voxcpm.modules.minicpm4": ["MiniCPM4Config", "MiniCPMModel"],
|
||||||
|
"voxcpm.modules.layers.lora": ["apply_lora_to_named_linear_modules", "LoRALinear"],
|
||||||
|
}
|
||||||
|
for modname, names in internal_mods.items():
|
||||||
|
module = types.ModuleType(modname)
|
||||||
|
for name in names:
|
||||||
|
if name == "apply_lora_to_named_linear_modules":
|
||||||
|
setattr(module, name, lambda *a, **k: None)
|
||||||
|
else:
|
||||||
|
setattr(module, name, type(name, (), {}))
|
||||||
|
monkeypatch.setitem(sys.modules, modname, module)
|
||||||
|
|
||||||
|
_load_module("voxcpm.model.utils", SRC / "voxcpm" / "model" / "utils.py")
|
||||||
|
voxcpm = _load_module("voxcpm.model.voxcpm", SRC / "voxcpm" / "model" / "voxcpm.py")
|
||||||
|
voxcpm2 = _load_module("voxcpm.model.voxcpm2", SRC / "voxcpm" / "model" / "voxcpm2.py")
|
||||||
|
return voxcpm.VoxCPMModel, voxcpm2.VoxCPM2Model
|
||||||
|
|
||||||
|
|
||||||
|
class DummyModel:
|
||||||
|
device = "cpu"
|
||||||
|
|
||||||
|
def named_parameters(self):
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("module_name", ["v1", "v2"])
|
||||||
|
def test_load_lora_weights_accepts_tensor_only_legacy_checkpoints(monkeypatch, tmp_path, module_name):
|
||||||
|
VoxCPMModel, VoxCPM2Model = bootstrap_repo_modules(monkeypatch)
|
||||||
|
cls = VoxCPMModel if module_name == "v1" else VoxCPM2Model
|
||||||
|
|
||||||
|
ckpt_path = tmp_path / "lora_weights.ckpt"
|
||||||
|
torch.save({"state_dict": {"fake": torch.zeros(1)}}, ckpt_path)
|
||||||
|
|
||||||
|
loaded, skipped = cls.load_lora_weights(DummyModel(), str(ckpt_path), device="cpu")
|
||||||
|
|
||||||
|
assert loaded == []
|
||||||
|
assert skipped == ["fake"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("module_name", ["v1", "v2"])
|
||||||
|
def test_load_lora_weights_rejects_malicious_pickle_payloads(monkeypatch, tmp_path, module_name):
|
||||||
|
VoxCPMModel, VoxCPM2Model = bootstrap_repo_modules(monkeypatch)
|
||||||
|
cls = VoxCPMModel if module_name == "v1" else VoxCPM2Model
|
||||||
|
|
||||||
|
ckpt_path = tmp_path / "lora_weights.ckpt"
|
||||||
|
marker_path = tmp_path / f"{module_name}-marker.txt"
|
||||||
|
|
||||||
|
class Exploit:
|
||||||
|
def __reduce__(self):
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
return (pathlib.Path.write_text, (marker_path, f"{module_name} executed\n"))
|
||||||
|
|
||||||
|
torch.save({"state_dict": {"fake": torch.zeros(1)}, "boom": Exploit()}, ckpt_path)
|
||||||
|
|
||||||
|
with pytest.raises(Exception, match="Weights only load failed"):
|
||||||
|
cls.load_lora_weights(DummyModel(), str(ckpt_path), device="cpu")
|
||||||
|
|
||||||
|
assert not marker_path.exists()
|
||||||
Reference in New Issue
Block a user