From ec2acec8a19457e82125b462e3df06520f3dd372 Mon Sep 17 00:00:00 2001 From: JunghwanNA <70629228+shaun0927@users.noreply.github.com> Date: Sat, 18 Apr 2026 00:31:28 +0900 Subject: [PATCH] Harden LoRA checkpoint loading against untrusted pickle payloads LoRA is a first-class workflow in VoxCPM, and the project already prefers safetensors plus weights-only fallback loading for base model artifacts. The legacy LoRA .ckpt/.pth path was the remaining place that still deserialized arbitrary pickle objects, so this switches it to weights_only=True and adds focused regression coverage for both model loaders. Constraint: Must preserve compatibility with tensor-only legacy LoRA checkpoints Rejected: Remove .ckpt/.pth support entirely | too disruptive for existing users Confidence: high Scope-risk: narrow Reversibility: clean Directive: Keep LoRA artifact handling aligned with the existing safetensors-first, weights-only loading pattern Tested: python3 -m pytest -q tests/test_lora_checkpoint_loading.py tests/test_model_utils.py -q Not-tested: Full end-to-end LoRA hot-load with heavyweight model assets --- src/voxcpm/model/voxcpm.py | 2 +- src/voxcpm/model/voxcpm2.py | 2 +- tests/test_lora_checkpoint_loading.py | 150 ++++++++++++++++++++++++++ 3 files changed, 152 insertions(+), 2 deletions(-) create mode 100644 tests/test_lora_checkpoint_loading.py diff --git a/src/voxcpm/model/voxcpm.py b/src/voxcpm/model/voxcpm.py index e1c50e9..c08e1e8 100644 --- a/src/voxcpm/model/voxcpm.py +++ b/src/voxcpm/model/voxcpm.py @@ -957,7 +957,7 @@ class VoxCPMModel(nn.Module): if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE: state_dict = load_file(str(safetensors_file), device=device) elif ckpt_file and ckpt_file.exists(): - ckpt = torch.load(ckpt_file, map_location=device, weights_only=False) + ckpt = torch.load(ckpt_file, map_location=device, weights_only=True) state_dict = ckpt.get("state_dict", ckpt) else: raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}") diff --git a/src/voxcpm/model/voxcpm2.py b/src/voxcpm/model/voxcpm2.py index 7909647..8692ba0 100644 --- a/src/voxcpm/model/voxcpm2.py +++ b/src/voxcpm/model/voxcpm2.py @@ -1208,7 +1208,7 @@ class VoxCPM2Model(nn.Module): if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE: state_dict = load_file(str(safetensors_file), device=device) elif ckpt_file and ckpt_file.exists(): - ckpt = torch.load(ckpt_file, map_location=device, weights_only=False) + ckpt = torch.load(ckpt_file, map_location=device, weights_only=True) state_dict = ckpt.get("state_dict", ckpt) else: raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}") diff --git a/tests/test_lora_checkpoint_loading.py b/tests/test_lora_checkpoint_loading.py new file mode 100644 index 0000000..3a1a8a1 --- /dev/null +++ b/tests/test_lora_checkpoint_loading.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import importlib.util +import sys +import types +from pathlib import Path + +import pytest +import torch + +ROOT = Path(__file__).resolve().parents[1] +SRC = ROOT / "src" + + +def _load_module(name: str, path: Path): + spec = importlib.util.spec_from_file_location(name, path) + module = importlib.util.module_from_spec(spec) + assert spec.loader is not None + sys.modules[name] = module + spec.loader.exec_module(module) + return module + + +def bootstrap_repo_modules(monkeypatch): + for name, path in [ + ("voxcpm", SRC / "voxcpm"), + ("voxcpm.model", SRC / "voxcpm" / "model"), + ("voxcpm.modules", SRC / "voxcpm" / "modules"), + ]: + pkg = types.ModuleType(name) + pkg.__path__ = [str(path)] + monkeypatch.setitem(sys.modules, name, pkg) + + hh = types.ModuleType("huggingface_hub") + hh.snapshot_download = lambda *a, **k: "/tmp/fake" + monkeypatch.setitem(sys.modules, "huggingface_hub", hh) + + pydantic = types.ModuleType("pydantic") + + class BaseModel: + @classmethod + def model_rebuild(cls): + return None + + @classmethod + def model_validate_json(cls, s): + return cls() + + def model_dump(self): + return {} + + pydantic.BaseModel = BaseModel + monkeypatch.setitem(sys.modules, "pydantic", pydantic) + + torchaudio = types.ModuleType("torchaudio") + monkeypatch.setitem(sys.modules, "torchaudio", torchaudio) + + librosa = types.ModuleType("librosa") + librosa.effects = types.SimpleNamespace(trim=lambda *a, **k: (None, (0, 0))) + monkeypatch.setitem(sys.modules, "librosa", librosa) + + einops = types.ModuleType("einops") + einops.rearrange = lambda x, *a, **k: x + monkeypatch.setitem(sys.modules, "einops", einops) + + tqdm_pkg = types.ModuleType("tqdm") + tqdm_pkg.__path__ = ["/nonexistent"] + tqdm_pkg.tqdm = lambda x, *a, **k: x + monkeypatch.setitem(sys.modules, "tqdm", tqdm_pkg) + + tqdm_auto = types.ModuleType("tqdm.auto") + tqdm_auto.tqdm = lambda x, *a, **k: x + monkeypatch.setitem(sys.modules, "tqdm.auto", tqdm_auto) + + transformers = types.ModuleType("transformers") + + class LlamaTokenizerFast: + pass + + class PreTrainedTokenizer: + pass + + transformers.LlamaTokenizerFast = LlamaTokenizerFast + transformers.PreTrainedTokenizer = PreTrainedTokenizer + monkeypatch.setitem(sys.modules, "transformers", transformers) + + internal_mods = { + "voxcpm.modules.audiovae": ["AudioVAE", "AudioVAEConfig", "AudioVAEV2", "AudioVAEConfigV2"], + "voxcpm.modules.layers": ["ScalarQuantizationLayer"], + "voxcpm.modules.locdit": ["CfmConfig", "UnifiedCFM", "VoxCPMLocDiT", "VoxCPMLocDiTV2"], + "voxcpm.modules.locenc": ["VoxCPMLocEnc"], + "voxcpm.modules.minicpm4": ["MiniCPM4Config", "MiniCPMModel"], + "voxcpm.modules.layers.lora": ["apply_lora_to_named_linear_modules", "LoRALinear"], + } + for modname, names in internal_mods.items(): + module = types.ModuleType(modname) + for name in names: + if name == "apply_lora_to_named_linear_modules": + setattr(module, name, lambda *a, **k: None) + else: + setattr(module, name, type(name, (), {})) + monkeypatch.setitem(sys.modules, modname, module) + + _load_module("voxcpm.model.utils", SRC / "voxcpm" / "model" / "utils.py") + voxcpm = _load_module("voxcpm.model.voxcpm", SRC / "voxcpm" / "model" / "voxcpm.py") + voxcpm2 = _load_module("voxcpm.model.voxcpm2", SRC / "voxcpm" / "model" / "voxcpm2.py") + return voxcpm.VoxCPMModel, voxcpm2.VoxCPM2Model + + +class DummyModel: + device = "cpu" + + def named_parameters(self): + return [] + + +@pytest.mark.parametrize("module_name", ["v1", "v2"]) +def test_load_lora_weights_accepts_tensor_only_legacy_checkpoints(monkeypatch, tmp_path, module_name): + VoxCPMModel, VoxCPM2Model = bootstrap_repo_modules(monkeypatch) + cls = VoxCPMModel if module_name == "v1" else VoxCPM2Model + + ckpt_path = tmp_path / "lora_weights.ckpt" + torch.save({"state_dict": {"fake": torch.zeros(1)}}, ckpt_path) + + loaded, skipped = cls.load_lora_weights(DummyModel(), str(ckpt_path), device="cpu") + + assert loaded == [] + assert skipped == ["fake"] + + +@pytest.mark.parametrize("module_name", ["v1", "v2"]) +def test_load_lora_weights_rejects_malicious_pickle_payloads(monkeypatch, tmp_path, module_name): + VoxCPMModel, VoxCPM2Model = bootstrap_repo_modules(monkeypatch) + cls = VoxCPMModel if module_name == "v1" else VoxCPM2Model + + ckpt_path = tmp_path / "lora_weights.ckpt" + marker_path = tmp_path / f"{module_name}-marker.txt" + + class Exploit: + def __reduce__(self): + import pathlib + + return (pathlib.Path.write_text, (marker_path, f"{module_name} executed\n")) + + torch.save({"state_dict": {"fake": torch.zeros(1)}, "boom": Exploit()}, ckpt_path) + + with pytest.raises(Exception, match="Weights only load failed"): + cls.load_lora_weights(DummyModel(), str(ckpt_path), device="cpu") + + assert not marker_path.exists()