fix VoxCPM2 training sample_rate: 48000 -> 16000 (match AudioVAE encoder)

Made-with: Cursor
This commit is contained in:
刘鑫
2026-04-07 22:57:42 +08:00
parent da700f264e
commit 46cfce0c97
4 changed files with 53 additions and 2 deletions
+2 -1
View File
@@ -1,7 +1,8 @@
pretrained_path: /path/to/VoxCPM2/ pretrained_path: /path/to/VoxCPM2/
train_manifest: /path/to/train.jsonl train_manifest: /path/to/train.jsonl
val_manifest: null val_manifest: null
sample_rate: 48000 sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
out_sample_rate: 48000 # AudioVAE decoder output rate; only used at inference, not during training
batch_size: 2 batch_size: 2
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16 grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
num_workers: 8 num_workers: 8
+2 -1
View File
@@ -1,7 +1,8 @@
pretrained_path: /path/to/VoxCPM2/ pretrained_path: /path/to/VoxCPM2/
train_manifest: /path/to/train.jsonl train_manifest: /path/to/train.jsonl
val_manifest: null val_manifest: null
sample_rate: 48000 sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
out_sample_rate: 48000 # AudioVAE decoder output rate; only used at inference, not during training
batch_size: 2 batch_size: 2
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16 grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
num_workers: 8 num_workers: 8
+41
View File
@@ -99,6 +99,24 @@ def get_timestamp_str():
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S") return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
def detect_sample_rate(pretrained_path: str) -> Optional[int]:
"""Read audio_vae_config.sample_rate from the model's config.json.
This is the AudioVAE *encoder* input rate, which is the correct rate for
resampling training data. Returns None when detection fails.
"""
config_file = os.path.join(pretrained_path, "config.json")
if not os.path.isfile(config_file):
return None
try:
with open(config_file, "r", encoding="utf-8") as f:
cfg = json.load(f)
return int(cfg["audio_vae_config"]["sample_rate"])
except (KeyError, ValueError, json.JSONDecodeError) as e:
print(f"Warning: failed to detect sample_rate from {config_file}: {e}", file=sys.stderr)
return None
def get_or_load_asr_model(): def get_or_load_asr_model():
global asr_model global asr_model
if asr_model is None: if asr_model is None:
@@ -377,6 +395,16 @@ def start_training(
os.makedirs(checkpoints_dir, exist_ok=True) os.makedirs(checkpoints_dir, exist_ok=True)
os.makedirs(logs_dir, exist_ok=True) os.makedirs(logs_dir, exist_ok=True)
# Auto-detect sample_rate from model config.json to prevent mismatch
detected_sr = detect_sample_rate(pretrained_path)
if detected_sr is not None:
if int(sample_rate) != detected_sr:
training_log += (
f"[Auto-fix] sample_rate changed from {int(sample_rate)} to {detected_sr} "
f"(read from {pretrained_path}/config.json audio_vae_config.sample_rate)\n"
)
sample_rate = detected_sr
# Create config dictionary # Create config dictionary
# Resolve max_steps default # Resolve max_steps default
resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters) resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters)
@@ -929,6 +957,19 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
show_label=False, show_label=False,
) )
def on_pretrained_path_change(path):
"""Auto-detect sample_rate when pretrained model path changes."""
sr = detect_sample_rate(path)
if sr is not None:
return gr.update(value=sr)
return gr.update()
train_pretrained_path.change(
on_pretrained_path_change,
inputs=[train_pretrained_path],
outputs=[sample_rate],
)
start_btn.click( start_btn.click(
start_training, start_training,
inputs=[ inputs=[
+8
View File
@@ -46,6 +46,7 @@ def train(
train_manifest: str, train_manifest: str,
val_manifest: str = "", val_manifest: str = "",
sample_rate: int = 16_000, sample_rate: int = 16_000,
out_sample_rate: int = 0, # accepted from YAML for documentation; not used in training
batch_size: int = 1, batch_size: int = 1,
grad_accum_steps: int = 1, grad_accum_steps: int = 1,
num_workers: int = 2, num_workers: int = 2,
@@ -68,6 +69,7 @@ def train(
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
): ):
_ = config_path _ = config_path
_ = out_sample_rate
# Validate distribution options # Validate distribution options
if lora is not None and distribute and not hf_model_id: if lora is not None and distribute and not hf_model_id:
@@ -98,6 +100,12 @@ def train(
) )
tokenizer = base_model.text_tokenizer tokenizer = base_model.text_tokenizer
expected_sr = base_model.audio_vae.sample_rate
assert sample_rate == expected_sr, (
f"sample_rate mismatch: config says {sample_rate}, but the AudioVAE encoder expects {expected_sr}. "
f"Please set sample_rate: {expected_sr} in your training config. "
)
train_ds, val_ds = load_audio_text_datasets( train_ds, val_ds = load_audio_text_datasets(
train_manifest=train_manifest, train_manifest=train_manifest,
val_manifest=val_manifest, val_manifest=val_manifest,