fix: ft log and setting
This commit is contained in:
@@ -30,7 +30,8 @@ except ImportError:
|
||||
import json
|
||||
|
||||
from voxcpm.model import VoxCPMModel, VoxCPM2Model
|
||||
from voxcpm.model.voxcpm import LoRAConfig
|
||||
from voxcpm.model.voxcpm import LoRAConfig as LoRAConfigV1
|
||||
from voxcpm.model.voxcpm2 import LoRAConfig as LoRAConfigV2
|
||||
from voxcpm.training import (
|
||||
Accelerator,
|
||||
BatchProcessor,
|
||||
@@ -46,7 +47,7 @@ def train(
|
||||
train_manifest: str,
|
||||
val_manifest: str = "",
|
||||
sample_rate: int = 16_000,
|
||||
out_sample_rate: int = 0, # accepted from YAML for documentation; not used in training
|
||||
out_sample_rate: int = 0, # AudioVAE decoder output rate; used for TensorBoard audio logging
|
||||
batch_size: int = 1,
|
||||
grad_accum_steps: int = 1,
|
||||
num_workers: int = 2,
|
||||
@@ -64,12 +65,12 @@ def train(
|
||||
lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
|
||||
lora: dict = None,
|
||||
config_path: str = "",
|
||||
max_grad_norm: float = 0.0, # gradient clipping; 0 = disabled (backward compat)
|
||||
# Distribution options (for LoRA checkpoints)
|
||||
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
|
||||
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
|
||||
):
|
||||
_ = config_path
|
||||
_ = out_sample_rate
|
||||
|
||||
# Validate distribution options
|
||||
if lora is not None and distribute and not hf_model_id:
|
||||
@@ -93,6 +94,7 @@ def train(
|
||||
with open(os.path.join(pretrained_path, "config.json"), "r", encoding="utf-8") as _f:
|
||||
_arch = json.load(_f).get("architecture", "voxcpm").lower()
|
||||
_model_cls = VoxCPM2Model if _arch == "voxcpm2" else VoxCPMModel
|
||||
LoRAConfig = LoRAConfigV2 if _arch == "voxcpm2" else LoRAConfigV1
|
||||
if accelerator.rank == 0:
|
||||
print(f"Detected architecture: {_arch} -> {_model_cls.__name__}", file=sys.stderr)
|
||||
base_model = _model_cls.from_local(
|
||||
@@ -178,8 +180,12 @@ def train(
|
||||
dataset_cnt=dataset_cnt,
|
||||
device=accelerator.device,
|
||||
)
|
||||
# Save audio_vae for audio generation
|
||||
# Save audio_vae and output sample rate for audio generation.
|
||||
# Prefer model's actual output rate; fall back to YAML out_sample_rate or encode rate.
|
||||
audio_vae_for_gen = base_model.audio_vae
|
||||
out_sr = base_model.sample_rate # decoder output rate (e.g. 48000 for V2)
|
||||
if out_sr == 0 and out_sample_rate > 0:
|
||||
out_sr = out_sample_rate
|
||||
del base_model.audio_vae
|
||||
model = accelerator.prepare_model(base_model)
|
||||
unwrapped_model = accelerator.unwrap(model)
|
||||
@@ -312,8 +318,8 @@ def train(
|
||||
scaler = getattr(accelerator, "scaler", None)
|
||||
if scaler is not None:
|
||||
scaler.unscale_(optimizer)
|
||||
# Use large max_norm to only compute grad_norm without actual clipping
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=1e9)
|
||||
effective_max_norm = max_grad_norm if max_grad_norm > 0 else 1e9
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=effective_max_norm)
|
||||
|
||||
accelerator.step(optimizer)
|
||||
accelerator.update()
|
||||
@@ -341,6 +347,7 @@ def train(
|
||||
val_ds=val_ds,
|
||||
audio_vae=audio_vae_for_gen,
|
||||
sample_rate=sample_rate,
|
||||
out_sample_rate=out_sr,
|
||||
val_texts=val_texts,
|
||||
tokenizer=tokenizer,
|
||||
valid_interval=valid_interval,
|
||||
@@ -367,6 +374,7 @@ def validate(
|
||||
val_ds=None,
|
||||
audio_vae=None,
|
||||
sample_rate=22050,
|
||||
out_sample_rate=0,
|
||||
val_texts=None,
|
||||
tokenizer=None,
|
||||
valid_interval=1000,
|
||||
@@ -432,6 +440,7 @@ def validate(
|
||||
step,
|
||||
accelerator,
|
||||
sample_rate,
|
||||
out_sample_rate=out_sample_rate,
|
||||
val_texts=val_texts,
|
||||
tokenizer=tokenizer,
|
||||
valid_interval=valid_interval,
|
||||
@@ -534,6 +543,7 @@ def generate_sample_audio(
|
||||
step,
|
||||
accelerator,
|
||||
sample_rate=22050,
|
||||
out_sample_rate=0,
|
||||
val_texts=None,
|
||||
tokenizer=None,
|
||||
pretrained_path=None,
|
||||
@@ -548,6 +558,10 @@ def generate_sample_audio(
|
||||
log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}")
|
||||
|
||||
unwrapped_model = accelerator.unwrap(model)
|
||||
# Determine the correct output sample rate for generated audio.
|
||||
# out_sample_rate is the decoder output rate (e.g. 48kHz for V2);
|
||||
# sample_rate is the encoder input rate (e.g. 16kHz for V2).
|
||||
gen_sr = out_sample_rate if out_sample_rate > 0 else sample_rate
|
||||
|
||||
for i in range(num_samples):
|
||||
sample = val_ds[i]
|
||||
@@ -604,10 +618,10 @@ def generate_sample_audio(
|
||||
gen_audio_np = normalize_audio(gen_audio_np)
|
||||
|
||||
tag = f"val_sample_{i}"
|
||||
writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=sample_rate)
|
||||
log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/sample_rate:.2f}s")
|
||||
writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=gen_sr)
|
||||
log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/gen_sr:.2f}s")
|
||||
|
||||
# Log reference audio
|
||||
# Log reference audio (at encoder input rate, which is what val_ds provides)
|
||||
if ref_audio_np is not None:
|
||||
writer.add_audio(
|
||||
f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate
|
||||
@@ -615,9 +629,9 @@ def generate_sample_audio(
|
||||
|
||||
# Generate mel spectrogram figure
|
||||
try:
|
||||
mel_gen = compute_mel_spectrogram(gen_audio_np, sample_rate)
|
||||
mel_gen = compute_mel_spectrogram(gen_audio_np, gen_sr)
|
||||
mel_ref = compute_mel_spectrogram(ref_audio_np, sample_rate) if ref_audio_np is not None else None
|
||||
fig = create_mel_figure(gen_audio_np, mel_gen, sample_rate, step, ref_audio_np, mel_ref)
|
||||
fig = create_mel_figure(gen_audio_np, mel_gen, gen_sr, step, ref_audio_np, mel_ref)
|
||||
writer.add_figure(f"{tag}/mel_spectrogram", fig, global_step=step)
|
||||
log(f"[Audio] Created mel spectrogram figure for sample {i}")
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user