Update finetuning pipeline and runtime device handling

Support optional ref_audio samples in finetuning, and make runtime device selection explicit while keeping the automatic fallback behavior consistent. Also ignore the local app override file to avoid accidental commits.

Made-with: Cursor
This commit is contained in:
刘鑫
2026-04-11 11:08:50 +08:00
parent abf01b9bf3
commit e4e049624c
10 changed files with 379 additions and 47 deletions
+13 -9
View File
@@ -44,7 +44,7 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
from ..modules.locenc import VoxCPMLocEnc
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
from .utils import get_dtype, mask_multichar_chinese_tokens
from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device
class VoxCPMEncoderConfig(BaseModel):
@@ -109,18 +109,15 @@ class VoxCPMModel(nn.Module):
tokenizer: LlamaTokenizerFast,
audio_vae: AudioVAE,
lora_config: LoRAConfig = None,
device: str | None = None,
):
super().__init__()
self.config = config
self.lora_config = lora_config
self.feat_dim = config.feat_dim
self.patch_size = config.patch_size
self.device = config.device
if not torch.cuda.is_available():
if torch.backends.mps.is_available():
self.device = "mps"
else:
self.device = "cpu"
self.device = resolve_runtime_device(device, config.device)
self.config.device = self.device
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
# Text-Semantic LM
@@ -847,7 +844,14 @@ class VoxCPMModel(nn.Module):
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
@classmethod
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
def from_local(
cls,
path: str,
optimize: bool = True,
training: bool = False,
device: str | None = None,
lora_config: LoRAConfig = None,
):
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
tokenizer = LlamaTokenizerFast.from_pretrained(path)
audio_vae_config = getattr(config, "audio_vae_config", None)
@@ -870,7 +874,7 @@ class VoxCPMModel(nn.Module):
raise FileNotFoundError(
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
)
model = cls(config, tokenizer, audio_vae, lora_config)
model = cls(config, tokenizer, audio_vae, lora_config, device=device)
if not training:
lm_dtype = get_dtype(model.config.dtype)
model = model.to(lm_dtype)