fix VoxCPM2 training sample_rate: 48000 -> 16000 (match AudioVAE encoder)
Made-with: Cursor
This commit is contained in:
@@ -1,7 +1,8 @@
|
|||||||
pretrained_path: /path/to/VoxCPM2/
|
pretrained_path: /path/to/VoxCPM2/
|
||||||
train_manifest: /path/to/train.jsonl
|
train_manifest: /path/to/train.jsonl
|
||||||
val_manifest: null
|
val_manifest: null
|
||||||
sample_rate: 48000
|
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
|
||||||
|
out_sample_rate: 48000 # AudioVAE decoder output rate; only used at inference, not during training
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
|
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
|
||||||
num_workers: 8
|
num_workers: 8
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
pretrained_path: /path/to/VoxCPM2/
|
pretrained_path: /path/to/VoxCPM2/
|
||||||
train_manifest: /path/to/train.jsonl
|
train_manifest: /path/to/train.jsonl
|
||||||
val_manifest: null
|
val_manifest: null
|
||||||
sample_rate: 48000
|
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
|
||||||
|
out_sample_rate: 48000 # AudioVAE decoder output rate; only used at inference, not during training
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
|
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
|
||||||
num_workers: 8
|
num_workers: 8
|
||||||
|
|||||||
@@ -99,6 +99,24 @@ def get_timestamp_str():
|
|||||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
|
||||||
|
|
||||||
|
def detect_sample_rate(pretrained_path: str) -> Optional[int]:
    """Read audio_vae_config.sample_rate from the model's config.json.

    This is the AudioVAE *encoder* input rate, which is the correct rate for
    resampling training data.

    Args:
        pretrained_path: Directory expected to contain ``config.json``.

    Returns:
        The sample rate as an int, or None when detection fails: the path is
        empty, ``config.json`` is missing or unreadable, the JSON is invalid,
        or the ``audio_vae_config.sample_rate`` entry is absent or malformed.
    """
    # An empty path would make os.path.join() yield a bare "config.json",
    # i.e. a file in the current working directory, and None would raise
    # TypeError. Neither identifies a model, so report "not detected".
    if not pretrained_path:
        return None

    config_file = os.path.join(pretrained_path, "config.json")
    if not os.path.isfile(config_file):
        return None

    try:
        with open(config_file, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        return int(cfg["audio_vae_config"]["sample_rate"])
    except (KeyError, TypeError, ValueError, OSError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
        # TypeError covers non-dict JSON nodes and a null sample_rate;
        # OSError covers a file that exists but cannot be read.
        print(f"Warning: failed to detect sample_rate from {config_file}: {e}", file=sys.stderr)
        return None
|
||||||
|
|
||||||
|
|
||||||
def get_or_load_asr_model():
|
def get_or_load_asr_model():
|
||||||
global asr_model
|
global asr_model
|
||||||
if asr_model is None:
|
if asr_model is None:
|
||||||
@@ -377,6 +395,16 @@ def start_training(
|
|||||||
os.makedirs(checkpoints_dir, exist_ok=True)
|
os.makedirs(checkpoints_dir, exist_ok=True)
|
||||||
os.makedirs(logs_dir, exist_ok=True)
|
os.makedirs(logs_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# Auto-detect sample_rate from model config.json to prevent mismatch
|
||||||
|
detected_sr = detect_sample_rate(pretrained_path)
|
||||||
|
if detected_sr is not None:
|
||||||
|
if int(sample_rate) != detected_sr:
|
||||||
|
training_log += (
|
||||||
|
f"[Auto-fix] sample_rate changed from {int(sample_rate)} to {detected_sr} "
|
||||||
|
f"(read from {pretrained_path}/config.json audio_vae_config.sample_rate)\n"
|
||||||
|
)
|
||||||
|
sample_rate = detected_sr
|
||||||
|
|
||||||
# Create config dictionary
|
# Create config dictionary
|
||||||
# Resolve max_steps default
|
# Resolve max_steps default
|
||||||
resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters)
|
resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters)
|
||||||
@@ -929,6 +957,19 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
|||||||
show_label=False,
|
show_label=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def on_pretrained_path_change(path):
    """Refresh the sample_rate field when the pretrained model path changes.

    Looks up the sample rate for ``path`` via ``detect_sample_rate``. When a
    rate is found, the returned update sets it as the new value; otherwise an
    empty update is returned.
    """
    detected = detect_sample_rate(path)
    if detected is None:
        return gr.update()
    return gr.update(value=detected)
|
||||||
|
|
||||||
|
train_pretrained_path.change(
|
||||||
|
on_pretrained_path_change,
|
||||||
|
inputs=[train_pretrained_path],
|
||||||
|
outputs=[sample_rate],
|
||||||
|
)
|
||||||
|
|
||||||
start_btn.click(
|
start_btn.click(
|
||||||
start_training,
|
start_training,
|
||||||
inputs=[
|
inputs=[
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ def train(
|
|||||||
train_manifest: str,
|
train_manifest: str,
|
||||||
val_manifest: str = "",
|
val_manifest: str = "",
|
||||||
sample_rate: int = 16_000,
|
sample_rate: int = 16_000,
|
||||||
|
out_sample_rate: int = 0, # accepted from YAML for documentation; not used in training
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
grad_accum_steps: int = 1,
|
grad_accum_steps: int = 1,
|
||||||
num_workers: int = 2,
|
num_workers: int = 2,
|
||||||
@@ -68,6 +69,7 @@ def train(
|
|||||||
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
|
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
|
||||||
):
|
):
|
||||||
_ = config_path
|
_ = config_path
|
||||||
|
_ = out_sample_rate
|
||||||
|
|
||||||
# Validate distribution options
|
# Validate distribution options
|
||||||
if lora is not None and distribute and not hf_model_id:
|
if lora is not None and distribute and not hf_model_id:
|
||||||
@@ -98,6 +100,12 @@ def train(
|
|||||||
)
|
)
|
||||||
tokenizer = base_model.text_tokenizer
|
tokenizer = base_model.text_tokenizer
|
||||||
|
|
||||||
|
expected_sr = base_model.audio_vae.sample_rate
|
||||||
|
assert sample_rate == expected_sr, (
|
||||||
|
f"sample_rate mismatch: config says {sample_rate}, but the AudioVAE encoder expects {expected_sr}. "
|
||||||
|
f"Please set sample_rate: {expected_sr} in your training config. "
|
||||||
|
)
|
||||||
|
|
||||||
train_ds, val_ds = load_audio_text_datasets(
|
train_ds, val_ds = load_audio_text_datasets(
|
||||||
train_manifest=train_manifest,
|
train_manifest=train_manifest,
|
||||||
val_manifest=val_manifest,
|
val_manifest=val_manifest,
|
||||||
|
|||||||
Reference in New Issue
Block a user