fix VoxCPM2 training sample_rate: 48000 -> 16000 (match AudioVAE encoder)
Made-with: Cursor
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
pretrained_path: /path/to/VoxCPM2/
|
||||
train_manifest: /path/to/train.jsonl
|
||||
val_manifest: null
|
||||
sample_rate: 48000
|
||||
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
|
||||
out_sample_rate: 48000 # AudioVAE decoder output rate; only used at inference, not during training
|
||||
batch_size: 2
|
||||
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
|
||||
num_workers: 8
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
pretrained_path: /path/to/VoxCPM2/
|
||||
train_manifest: /path/to/train.jsonl
|
||||
val_manifest: null
|
||||
sample_rate: 48000
|
||||
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
|
||||
out_sample_rate: 48000 # AudioVAE decoder output rate; only used at inference, not during training
|
||||
batch_size: 2
|
||||
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
|
||||
num_workers: 8
|
||||
|
||||
@@ -99,6 +99,24 @@ def get_timestamp_str():
|
||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
|
||||
def detect_sample_rate(pretrained_path: str) -> Optional[int]:
|
||||
"""Read audio_vae_config.sample_rate from the model's config.json.
|
||||
|
||||
This is the AudioVAE *encoder* input rate, which is the correct rate for
|
||||
resampling training data. Returns None when detection fails.
|
||||
"""
|
||||
config_file = os.path.join(pretrained_path, "config.json")
|
||||
if not os.path.isfile(config_file):
|
||||
return None
|
||||
try:
|
||||
with open(config_file, "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
return int(cfg["audio_vae_config"]["sample_rate"])
|
||||
except (KeyError, ValueError, json.JSONDecodeError) as e:
|
||||
print(f"Warning: failed to detect sample_rate from {config_file}: {e}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
|
||||
def get_or_load_asr_model():
|
||||
global asr_model
|
||||
if asr_model is None:
|
||||
@@ -377,6 +395,16 @@ def start_training(
|
||||
os.makedirs(checkpoints_dir, exist_ok=True)
|
||||
os.makedirs(logs_dir, exist_ok=True)
|
||||
|
||||
# Auto-detect sample_rate from model config.json to prevent mismatch
|
||||
detected_sr = detect_sample_rate(pretrained_path)
|
||||
if detected_sr is not None:
|
||||
if int(sample_rate) != detected_sr:
|
||||
training_log += (
|
||||
f"[Auto-fix] sample_rate changed from {int(sample_rate)} to {detected_sr} "
|
||||
f"(read from {pretrained_path}/config.json audio_vae_config.sample_rate)\n"
|
||||
)
|
||||
sample_rate = detected_sr
|
||||
|
||||
# Create config dictionary
|
||||
# Resolve max_steps default
|
||||
resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters)
|
||||
@@ -929,6 +957,19 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
||||
show_label=False,
|
||||
)
|
||||
|
||||
def on_pretrained_path_change(path):
|
||||
"""Auto-detect sample_rate when pretrained model path changes."""
|
||||
sr = detect_sample_rate(path)
|
||||
if sr is not None:
|
||||
return gr.update(value=sr)
|
||||
return gr.update()
|
||||
|
||||
train_pretrained_path.change(
|
||||
on_pretrained_path_change,
|
||||
inputs=[train_pretrained_path],
|
||||
outputs=[sample_rate],
|
||||
)
|
||||
|
||||
start_btn.click(
|
||||
start_training,
|
||||
inputs=[
|
||||
|
||||
@@ -46,6 +46,7 @@ def train(
|
||||
train_manifest: str,
|
||||
val_manifest: str = "",
|
||||
sample_rate: int = 16_000,
|
||||
out_sample_rate: int = 0, # accepted from YAML for documentation; not used in training
|
||||
batch_size: int = 1,
|
||||
grad_accum_steps: int = 1,
|
||||
num_workers: int = 2,
|
||||
@@ -68,6 +69,7 @@ def train(
|
||||
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
|
||||
):
|
||||
_ = config_path
|
||||
_ = out_sample_rate
|
||||
|
||||
# Validate distribution options
|
||||
if lora is not None and distribute and not hf_model_id:
|
||||
@@ -98,6 +100,12 @@ def train(
|
||||
)
|
||||
tokenizer = base_model.text_tokenizer
|
||||
|
||||
expected_sr = base_model.audio_vae.sample_rate
|
||||
assert sample_rate == expected_sr, (
|
||||
f"sample_rate mismatch: config says {sample_rate}, but the AudioVAE encoder expects {expected_sr}. "
|
||||
f"Please set sample_rate: {expected_sr} in your training config. "
|
||||
)
|
||||
|
||||
train_ds, val_ds = load_audio_text_datasets(
|
||||
train_manifest=train_manifest,
|
||||
val_manifest=val_manifest,
|
||||
|
||||
Reference in New Issue
Block a user