Update finetuning pipeline and runtime device handling
Support optional ref_audio samples in finetuning, and make runtime device selection explicit while keeping the automatic fallback behavior consistent. Also ignore the local app override file to avoid accidental commits.

Made-with: Cursor
This commit is contained in:
@@ -44,7 +44,7 @@ from ..modules.layers.lora import apply_lora_to_named_linear_modules
|
||||
from ..modules.locdit import CfmConfig, UnifiedCFM, VoxCPMLocDiT
|
||||
from ..modules.locenc import VoxCPMLocEnc
|
||||
from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
||||
from .utils import get_dtype, mask_multichar_chinese_tokens
|
||||
from .utils import get_dtype, mask_multichar_chinese_tokens, resolve_runtime_device
|
||||
|
||||
|
||||
class VoxCPMEncoderConfig(BaseModel):
|
||||
@@ -109,18 +109,15 @@ class VoxCPMModel(nn.Module):
|
||||
tokenizer: LlamaTokenizerFast,
|
||||
audio_vae: AudioVAE,
|
||||
lora_config: LoRAConfig = None,
|
||||
device: str | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
self.feat_dim = config.feat_dim
|
||||
self.patch_size = config.patch_size
|
||||
self.device = config.device
|
||||
if not torch.cuda.is_available():
|
||||
if torch.backends.mps.is_available():
|
||||
self.device = "mps"
|
||||
else:
|
||||
self.device = "cpu"
|
||||
self.device = resolve_runtime_device(device, config.device)
|
||||
self.config.device = self.device
|
||||
print(f"Running on device: {self.device}, dtype: {self.config.dtype}", file=sys.stderr)
|
||||
|
||||
# Text-Semantic LM
|
||||
@@ -847,7 +844,14 @@ class VoxCPMModel(nn.Module):
|
||||
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
||||
|
||||
@classmethod
|
||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
||||
def from_local(
|
||||
cls,
|
||||
path: str,
|
||||
optimize: bool = True,
|
||||
training: bool = False,
|
||||
device: str | None = None,
|
||||
lora_config: LoRAConfig = None,
|
||||
):
|
||||
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
||||
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
||||
audio_vae_config = getattr(config, "audio_vae_config", None)
|
||||
@@ -870,7 +874,7 @@ class VoxCPMModel(nn.Module):
|
||||
raise FileNotFoundError(
|
||||
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
|
||||
)
|
||||
model = cls(config, tokenizer, audio_vae, lora_config)
|
||||
model = cls(config, tokenizer, audio_vae, lora_config, device=device)
|
||||
if not training:
|
||||
lm_dtype = get_dtype(model.config.dtype)
|
||||
model = model.to(lm_dtype)
|
||||
|
||||
Reference in New Issue
Block a user