Compare commits
4 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 68af4fe502 | |
| | ee3649c1b3 | |
| | 82d77d445c | |
| | 8f95d13073 | |
@@ -239,7 +239,7 @@ voxcpm --help
|
|||||||
### Web Demo
|
### Web Demo
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python app.py # then open http://localhost:7860
|
python app.py --model-dir /path/to/VoxCPM2 --port 8808 # use a local model directory, open http://localhost:8808
|
||||||
```
|
```
|
||||||
|
|
||||||
### 🚢 Production Deployment (Nano-vLLM)
|
### 🚢 Production Deployment (Nano-vLLM)
|
||||||
@@ -415,10 +415,54 @@ VoxCPM2 achieves state-of-the-art or comparable results on public zero-shot and
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
### Internal 30-Language ASR Benchmark

We additionally run an internal multilingual intelligibility benchmark with **30 languages × 500 samples**. ASR transcription is evaluated via **Gemini 3.1 Flash Lite API**.

<details>
<summary><b>Internal 30-Language ASR Benchmark (click to expand)</b></summary>

| Language | Metric | VoxCPM2 | Fish S2-Pro |
|---|---:|---:|---:|
| ar (Arabic) | CER | 1.23% | 0.30% |
| da (Danish) | WER | 2.70% | 3.52% |
| de (German) | WER | 0.96% | 0.64% |
| el (Greek) | WER | 3.17% | 4.61% |
| en (English) | WER | 0.42% | 1.03% |
| es (Spanish) | WER | 1.33% | 0.64% |
| fi (Finnish) | WER | 2.24% | 2.80% |
| fr (French) | WER | 2.16% | 2.34% |
| he (Hebrew) | CER | 2.98% | 15.27% |
| hi (Hindi) | CER | 0.79% | 0.91% |
| id (Indonesian) | WER | 1.36% | 1.68% |
| it (Italian) | WER | 1.65% | 1.08% |
| ja (Japanese) | CER | 2.40% | 1.82% |
| km (Khmer) | CER | 2.05% | 75.15% |
| ko (Korean) | CER | 0.95% | 0.29% |
| lo (Lao) | CER | 1.90% | 87.40% |
| ms (Malay) | WER | 1.75% | 1.41% |
| my (Burmese) | CER | 1.42% | 85.27% |
| nl (Dutch) | WER | 1.25% | 1.68% |
| no (Norwegian) | WER | 2.49% | 3.76% |
| pl (Polish) | WER | 1.90% | 1.65% |
| pt (Portuguese) | WER | 1.48% | 1.49% |
| ru (Russian) | WER | 0.90% | 0.86% |
| sv (Swedish) | WER | 2.22% | 2.63% |
| sw (Swahili) | CER | 1.07% | 2.02% |
| th (Thai) | CER | 0.94% | 1.92% |
| tl (Tagalog) | WER | 2.63% | 4.00% |
| tr (Turkish) | WER | 1.65% | 1.65% |
| vi (Vietnamese) | WER | 1.56% | 5.56% |
| zh (Chinese) | CER | 0.92% | 1.02% |
| Average (30 languages) | | **1.68%** | - |

</details>
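
The table reports word error rate (WER) for some languages and character error rate (CER) for others. The scoring script is not part of this diff, so the following is only a minimal sketch of the standard edit-distance definition of both metrics. The function names `edit_distance` and `error_rate` are illustrative, and whatever text normalization the benchmark applies before scoring is unknown and not modeled here.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (substitutions, insertions, deletions)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(prev[j] + 1,          # deletion
                            curr[j - 1] + 1,      # insertion
                            prev[j - 1] + cost))  # substitution or match
        prev = curr
    return prev[-1]


def error_rate(reference, hypothesis, unit="word"):
    """WER when unit == "word", CER when unit == "char" (spaces ignored for CER)."""
    if unit == "word":
        ref, hyp = reference.split(), hypothesis.split()
    else:
        ref = list(reference.replace(" ", ""))
        hyp = list(hypothesis.replace(" ", ""))
    if not ref:
        raise ValueError("reference must not be empty")
    return edit_distance(ref, hyp) / len(ref)


# One substituted word out of three reference words.
wer = error_rate("the cat sat", "the bat sat", unit="word")
# One substituted character out of four reference characters.
cer = error_rate("abcd", "abxd", unit="char")
```

Because the error count is divided by the reference length, a system that inserts many extra tokens can score above 100%, which is one reason very high rates such as those in the Khmer, Lao, and Burmese rows are possible.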
|
||||||
|
|
||||||
### InstructTTSEval
|
### InstructTTSEval
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>Instruction-Guided Voice Design Results</b></summary>
|
<summary><b>Instruction-Guided Voice Design Results (click to expand)</b></summary>
|
||||||
|
|
||||||
| Model | InstructTTSEval-ZH | | | InstructTTSEval-EN | | |
|
| Model | InstructTTSEval-ZH | | | InstructTTSEval-EN | | |
|
||||||
|-------|:---:|:----:|:----:|:----:|:----:|:----:|
|
|-------|:---:|:----:|:----:|:----:|:----:|:----:|
|
||||||
|
|||||||
+46
-2
@@ -238,7 +238,7 @@ voxcpm --help
|
|||||||
### Web Demo
|
### Web Demo
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python app.py # then open http://localhost:7860
|
python app.py --model-dir /path/to/VoxCPM2 --port 8808 # use a local model directory, then open http://localhost:8808
|
||||||
```
|
```
|
||||||
|
|
||||||
### 🚢 Production Deployment (Nano-vLLM)
|
### 🚢 Production Deployment (Nano-vLLM)
|
||||||
@@ -414,10 +414,54 @@ VoxCPM2 在公开的零样本和可控 TTS 基准测试中取得了 SOTA 或可
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
### Internal 30-Language ASR Benchmark
|
||||||
|
|
||||||
|
We additionally run an internal multilingual intelligibility benchmark: **30 languages × 500 samples**, with ASR transcription evaluated via the **Gemini 3.1 Flash Lite API**.

<details>
<summary><b>Internal 30-Language Benchmark ASR Results (click to expand)</b></summary>

| Language | Metric | VoxCPM2 | Fish S2-Pro |
|---|---:|---:|---:|
| ar (Arabic) | CER | 1.23% | 0.30% |
| da (Danish) | WER | 2.70% | 3.52% |
| de (German) | WER | 0.96% | 0.64% |
| el (Greek) | WER | 3.17% | 4.61% |
| en (English) | WER | 0.42% | 1.03% |
| es (Spanish) | WER | 1.33% | 0.64% |
| fi (Finnish) | WER | 2.24% | 2.80% |
| fr (French) | WER | 2.16% | 2.34% |
| he (Hebrew) | CER | 2.98% | 15.27% |
| hi (Hindi) | CER | 0.79% | 0.91% |
| id (Indonesian) | WER | 1.36% | 1.68% |
| it (Italian) | WER | 1.65% | 1.08% |
| ja (Japanese) | CER | 2.40% | 1.82% |
| km (Khmer) | CER | 2.05% | 75.15% |
| ko (Korean) | CER | 0.95% | 0.29% |
| lo (Lao) | CER | 1.90% | 87.40% |
| ms (Malay) | WER | 1.75% | 1.41% |
| my (Burmese) | CER | 1.42% | 85.27% |
| nl (Dutch) | WER | 1.25% | 1.68% |
| no (Norwegian) | WER | 2.49% | 3.76% |
| pl (Polish) | WER | 1.90% | 1.65% |
| pt (Portuguese) | WER | 1.48% | 1.49% |
| ru (Russian) | WER | 0.90% | 0.86% |
| sv (Swedish) | WER | 2.22% | 2.63% |
| sw (Swahili) | CER | 1.07% | 2.02% |
| th (Thai) | CER | 0.94% | 1.92% |
| tl (Filipino) | WER | 2.63% | 4.00% |
| tr (Turkish) | WER | 1.65% | 1.65% |
| vi (Vietnamese) | WER | 1.56% | 5.56% |
| zh (Chinese) | CER | 0.92% | 1.02% |
| Average (30 languages) | | **1.68%** | - |

</details>
|
||||||
|
|
||||||
|
|
||||||
### InstructTTSEval
|
### InstructTTSEval
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary><b>Instruction-Guided Voice Design Results</b></summary>
|
<summary><b>Instruction-Guided Voice Design Results (click to expand)</b></summary>
|
||||||
|
|
||||||
| Model | InstructTTSEval-ZH | | | InstructTTSEval-EN | | |
|
| Model | InstructTTSEval-ZH | | | InstructTTSEval-EN | | |
|
||||||
|-------|:---:|:----:|:----:|:----:|:----:|:----:|
|
|-------|:---:|:----:|:----:|:----:|:----:|:----:|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ pretrained_path: /path/to/VoxCPM2/
|
|||||||
train_manifest: /path/to/train.jsonl
|
train_manifest: /path/to/train.jsonl
|
||||||
val_manifest: null
|
val_manifest: null
|
||||||
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
|
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
|
||||||
out_sample_rate: 48000 # AudioVAE decoder output rate; only used at inference, not during training
|
out_sample_rate: 48000 # AudioVAE decoder output rate; used for TensorBoard audio logging
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
|
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
|
||||||
num_workers: 8
|
num_workers: 8
|
||||||
@@ -15,6 +15,7 @@ weight_decay: 0.01
|
|||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
max_steps: 1000
|
max_steps: 1000
|
||||||
max_batch_tokens: 8192
|
max_batch_tokens: 8192
|
||||||
|
max_grad_norm: 1.0 # gradient clipping max norm; 0 = disabled
|
||||||
save_path: /path/to/checkpoints/finetune_all
|
save_path: /path/to/checkpoints/finetune_all
|
||||||
tensorboard: /path/to/logs/finetune_all
|
tensorboard: /path/to/logs/finetune_all
|
||||||
lambdas:
|
lambdas:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ pretrained_path: /path/to/VoxCPM2/
|
|||||||
train_manifest: /path/to/train.jsonl
|
train_manifest: /path/to/train.jsonl
|
||||||
val_manifest: null
|
val_manifest: null
|
||||||
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
|
sample_rate: 16000 # AudioVAE encoder input rate; must match audio_vae_config.sample_rate
|
||||||
out_sample_rate: 48000 # AudioVAE decoder output rate; only used at inference, not during training
|
out_sample_rate: 48000 # AudioVAE decoder output rate; used for TensorBoard audio logging
|
||||||
batch_size: 2
|
batch_size: 2
|
||||||
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
|
grad_accum_steps: 8 # effective batch size = batch_size × grad_accum_steps = 16
|
||||||
num_workers: 8
|
num_workers: 8
|
||||||
@@ -15,6 +15,7 @@ weight_decay: 0.01
|
|||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
max_steps: 1000
|
max_steps: 1000
|
||||||
max_batch_tokens: 8192
|
max_batch_tokens: 8192
|
||||||
|
max_grad_norm: 1.0 # gradient clipping max norm; 0 = disabled
|
||||||
save_path: /path/to/checkpoints/finetune_lora
|
save_path: /path/to/checkpoints/finetune_lora
|
||||||
tensorboard: /path/to/logs/finetune_lora
|
tensorboard: /path/to/logs/finetune_lora
|
||||||
lambdas:
|
lambdas:
|
||||||
|
|||||||
+33
-8
@@ -14,8 +14,10 @@ from typing import Optional
|
|||||||
project_root = Path(__file__).parent
|
project_root = Path(__file__).parent
|
||||||
sys.path.insert(0, str(project_root / "src"))
|
sys.path.insert(0, str(project_root / "src"))
|
||||||
|
|
||||||
# Default pretrained model path relative to this repo
|
# Default pretrained model path: prefer VoxCPM2 if it exists, fallback to VoxCPM1.5
|
||||||
default_pretrained_path = str(project_root / "models" / "openbmb__VoxCPM1.5")
|
_v2_path = project_root / "models" / "openbmb__VoxCPM2"
|
||||||
|
_v15_path = project_root / "models" / "openbmb__VoxCPM1.5"
|
||||||
|
default_pretrained_path = str(_v2_path if _v2_path.exists() else _v15_path)
|
||||||
|
|
||||||
from voxcpm.core import VoxCPM
|
from voxcpm.core import VoxCPM
|
||||||
from voxcpm.model.voxcpm import LoRAConfig
|
from voxcpm.model.voxcpm import LoRAConfig
|
||||||
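
The hunk above switches the default model directory to VoxCPM2 when it is present and otherwise keeps VoxCPM1.5. A minimal sketch of the same "first existing candidate, else last candidate" rule follows. The helper name `pick_default_model_dir` is hypothetical and does not appear in the repository; only the two directory names are taken from the diff.

```python
from pathlib import Path


def pick_default_model_dir(models_dir, candidates=("openbmb__VoxCPM2", "openbmb__VoxCPM1.5")):
    """Return the first candidate directory that exists under models_dir.

    If none exists, fall back to the last candidate without checking it,
    which mirrors the unconditional VoxCPM1.5 fallback in the diff.
    """
    models_dir = Path(models_dir)
    for name in candidates:
        path = models_dir / name
        if path.exists():
            return str(path)
    return str(models_dir / candidates[-1])
```

Note that the fallback path is returned even when it does not exist, so callers still need to handle a missing model directory.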
@@ -368,6 +370,7 @@ def start_training(
|
|||||||
warmup_steps=100,
|
warmup_steps=100,
|
||||||
max_steps=None,
|
max_steps=None,
|
||||||
sample_rate=44100,
|
sample_rate=44100,
|
||||||
|
max_grad_norm=1.0,
|
||||||
# LoRA advanced
|
# LoRA advanced
|
||||||
enable_lm=True,
|
enable_lm=True,
|
||||||
enable_dit=True,
|
enable_dit=True,
|
||||||
@@ -409,11 +412,25 @@ def start_training(
|
|||||||
# Resolve max_steps default
|
# Resolve max_steps default
|
||||||
resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters)
|
resolved_max_steps = int(max_steps) if max_steps not in (None, "", 0) else int(num_iters)
|
||||||
|
|
||||||
|
# Auto-detect out_sample_rate from model config
|
||||||
|
out_sample_rate = 0
|
||||||
|
config_file = os.path.join(pretrained_path, "config.json")
|
||||||
|
if os.path.isfile(config_file):
|
||||||
|
try:
|
||||||
|
with open(config_file, "r", encoding="utf-8") as f:
|
||||||
|
cfg = json.load(f)
|
||||||
|
out_sr = cfg.get("audio_vae_config", {}).get("out_sample_rate")
|
||||||
|
if out_sr:
|
||||||
|
out_sample_rate = int(out_sr)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"pretrained_path": pretrained_path,
|
"pretrained_path": pretrained_path,
|
||||||
"train_manifest": train_manifest,
|
"train_manifest": train_manifest,
|
||||||
"val_manifest": val_manifest,
|
"val_manifest": val_manifest,
|
||||||
"sample_rate": int(sample_rate),
|
"sample_rate": int(sample_rate),
|
||||||
|
"out_sample_rate": out_sample_rate,
|
||||||
"batch_size": int(batch_size),
|
"batch_size": int(batch_size),
|
||||||
"grad_accum_steps": int(grad_accum_steps),
|
"grad_accum_steps": int(grad_accum_steps),
|
||||||
"num_workers": int(num_workers),
|
"num_workers": int(num_workers),
|
||||||
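
The auto-detection added above reads `audio_vae_config.out_sample_rate` from the model's `config.json` and falls back to 0 when the file or key is missing. The following standalone sketch isolates that lookup so it can be exercised without the WebUI. The function name `detect_out_sample_rate` is hypothetical, and the sketch catches a narrower set of exceptions than the blanket `except Exception` in the diff, which is a deliberate deviation.

```python
import json
import os


def detect_out_sample_rate(pretrained_path):
    """Return audio_vae_config.out_sample_rate from config.json, or 0 if unavailable."""
    config_file = os.path.join(pretrained_path, "config.json")
    if not os.path.isfile(config_file):
        return 0
    try:
        with open(config_file, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        out_sr = (cfg.get("audio_vae_config") or {}).get("out_sample_rate")
        return int(out_sr) if out_sr else 0
    except (OSError, ValueError, TypeError, AttributeError):
        # Unreadable file, malformed JSON, or an unexpected structure:
        # treat the rate as unknown instead of failing the training launch.
        return 0
```

A return value of 0 keeps the meaning it has in the training script, where 0 signals that no explicit output rate was provided.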
@@ -425,6 +442,7 @@ def start_training(
|
|||||||
"weight_decay": float(weight_decay),
|
"weight_decay": float(weight_decay),
|
||||||
"warmup_steps": int(warmup_steps),
|
"warmup_steps": int(warmup_steps),
|
||||||
"max_steps": resolved_max_steps,
|
"max_steps": resolved_max_steps,
|
||||||
|
"max_grad_norm": float(max_grad_norm),
|
||||||
"save_path": checkpoints_dir,
|
"save_path": checkpoints_dir,
|
||||||
"tensorboard": tensorboard_path if tensorboard_path else logs_dir,
|
"tensorboard": tensorboard_path if tensorboard_path else logs_dir,
|
||||||
"lambdas": {"loss/diff": 1.0, "loss/stop": 1.0},
|
"lambdas": {"loss/diff": 1.0, "loss/stop": 1.0},
|
||||||
@@ -932,17 +950,19 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
max_steps = gr.Number(label="最大步数 (max_steps, 0→默认num_iters)", value=0, precision=0)
|
max_steps = gr.Number(label="最大步数 (max_steps, 0→默认num_iters)", value=0, precision=0)
|
||||||
sample_rate = gr.Number(label="采样率 (sample_rate)", value=44100, precision=0)
|
sample_rate = gr.Number(label="采样率 (sample_rate)", value=44100, precision=0)
|
||||||
tensorboard_path = gr.Textbox(label="Tensorboard 路径 (可选)", value="")
|
max_grad_norm = gr.Number(label="梯度裁剪 (max_grad_norm, 0=关闭)", value=1.0)
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
|
tensorboard_path = gr.Textbox(label="Tensorboard 路径 (可选)", value="")
|
||||||
enable_lm = gr.Checkbox(label="启用 LoRA LM (enable_lm)", value=True)
|
enable_lm = gr.Checkbox(label="启用 LoRA LM (enable_lm)", value=True)
|
||||||
enable_dit = gr.Checkbox(label="启用 LoRA DIT (enable_dit)", value=True)
|
enable_dit = gr.Checkbox(label="启用 LoRA DIT (enable_dit)", value=True)
|
||||||
|
with gr.Row():
|
||||||
enable_proj = gr.Checkbox(label="启用投影 (enable_proj)", value=False)
|
enable_proj = gr.Checkbox(label="启用投影 (enable_proj)", value=False)
|
||||||
dropout = gr.Number(label="LoRA Dropout", value=0.0)
|
dropout = gr.Number(label="LoRA Dropout", value=0.0)
|
||||||
|
|
||||||
gr.Markdown("#### 分发选项 (Distribution)")
|
gr.Markdown("#### 分发选项 (Distribution)")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
hf_model_id = gr.Textbox(
|
hf_model_id = gr.Textbox(
|
||||||
label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5"
|
label="HuggingFace Model ID (e.g., openbmb/VoxCPM2)", value=""
|
||||||
)
|
)
|
||||||
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
|
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
|
||||||
|
|
||||||
@@ -992,6 +1012,7 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
|||||||
warmup_steps,
|
warmup_steps,
|
||||||
max_steps,
|
max_steps,
|
||||||
sample_rate,
|
sample_rate,
|
||||||
|
max_grad_norm,
|
||||||
enable_lm,
|
enable_lm,
|
||||||
enable_dit,
|
enable_dit,
|
||||||
enable_proj,
|
enable_proj,
|
||||||
@@ -1150,12 +1171,13 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
|||||||
"warmup_steps": "warmup_steps",
|
"warmup_steps": "warmup_steps",
|
||||||
"max_steps": "最大步数 (max_steps)",
|
"max_steps": "最大步数 (max_steps)",
|
||||||
"sample_rate": "采样率 (sample_rate)",
|
"sample_rate": "采样率 (sample_rate)",
|
||||||
|
"max_grad_norm": "梯度裁剪 (max_grad_norm, 0=关闭)",
|
||||||
"enable_lm": "启用 LoRA LM (enable_lm)",
|
"enable_lm": "启用 LoRA LM (enable_lm)",
|
||||||
"enable_dit": "启用 LoRA DIT (enable_dit)",
|
"enable_dit": "启用 LoRA DIT (enable_dit)",
|
||||||
"enable_proj": "启用投影 (enable_proj)",
|
"enable_proj": "启用投影 (enable_proj)",
|
||||||
"dropout": "LoRA Dropout",
|
"dropout": "LoRA Dropout",
|
||||||
"tensorboard_path": "Tensorboard 路径 (可选)",
|
"tensorboard_path": "Tensorboard 路径 (可选)",
|
||||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM2)",
|
||||||
"distribute": "分发模式 (distribute)",
|
"distribute": "分发模式 (distribute)",
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
@@ -1168,12 +1190,13 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
|||||||
"warmup_steps": "Warmup Steps",
|
"warmup_steps": "Warmup Steps",
|
||||||
"max_steps": "Max Steps",
|
"max_steps": "Max Steps",
|
||||||
"sample_rate": "Sample Rate",
|
"sample_rate": "Sample Rate",
|
||||||
|
"max_grad_norm": "Max Grad Norm (0=disabled)",
|
||||||
"enable_lm": "Enable LoRA LM",
|
"enable_lm": "Enable LoRA LM",
|
||||||
"enable_dit": "Enable LoRA DIT",
|
"enable_dit": "Enable LoRA DIT",
|
||||||
"enable_proj": "Enable Projection",
|
"enable_proj": "Enable Projection",
|
||||||
"dropout": "LoRA Dropout",
|
"dropout": "LoRA Dropout",
|
||||||
"tensorboard_path": "Tensorboard Path (Optional)",
|
"tensorboard_path": "Tensorboard Path (Optional)",
|
||||||
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM2)",
|
||||||
"distribute": "Distribute Mode",
|
"distribute": "Distribute Mode",
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1203,11 +1226,12 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
|||||||
gr.update(label=adv["warmup_steps"]),
|
gr.update(label=adv["warmup_steps"]),
|
||||||
gr.update(label=adv["max_steps"]),
|
gr.update(label=adv["max_steps"]),
|
||||||
gr.update(label=adv["sample_rate"]),
|
gr.update(label=adv["sample_rate"]),
|
||||||
|
gr.update(label=adv["max_grad_norm"]),
|
||||||
|
gr.update(label=adv["tensorboard_path"]),
|
||||||
gr.update(label=adv["enable_lm"]),
|
gr.update(label=adv["enable_lm"]),
|
||||||
gr.update(label=adv["enable_dit"]),
|
gr.update(label=adv["enable_dit"]),
|
||||||
gr.update(label=adv["enable_proj"]),
|
gr.update(label=adv["enable_proj"]),
|
||||||
gr.update(label=adv["dropout"]),
|
gr.update(label=adv["dropout"]),
|
||||||
gr.update(label=adv["tensorboard_path"]),
|
|
||||||
# Distribution options
|
# Distribution options
|
||||||
gr.update(label=adv["hf_model_id"]),
|
gr.update(label=adv["hf_model_id"]),
|
||||||
gr.update(label=adv["distribute"]),
|
gr.update(label=adv["distribute"]),
|
||||||
@@ -1254,11 +1278,12 @@ with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css
|
|||||||
warmup_steps,
|
warmup_steps,
|
||||||
max_steps,
|
max_steps,
|
||||||
sample_rate,
|
sample_rate,
|
||||||
|
max_grad_norm,
|
||||||
|
tensorboard_path,
|
||||||
enable_lm,
|
enable_lm,
|
||||||
enable_dit,
|
enable_dit,
|
||||||
enable_proj,
|
enable_proj,
|
||||||
dropout,
|
dropout,
|
||||||
tensorboard_path,
|
|
||||||
# distribution outputs
|
# distribution outputs
|
||||||
hf_model_id,
|
hf_model_id,
|
||||||
distribute,
|
distribute,
|
||||||
|
|||||||
@@ -30,7 +30,8 @@ except ImportError:
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from voxcpm.model import VoxCPMModel, VoxCPM2Model
|
from voxcpm.model import VoxCPMModel, VoxCPM2Model
|
||||||
from voxcpm.model.voxcpm import LoRAConfig
|
from voxcpm.model.voxcpm import LoRAConfig as LoRAConfigV1
|
||||||
|
from voxcpm.model.voxcpm2 import LoRAConfig as LoRAConfigV2
|
||||||
from voxcpm.training import (
|
from voxcpm.training import (
|
||||||
Accelerator,
|
Accelerator,
|
||||||
BatchProcessor,
|
BatchProcessor,
|
||||||
@@ -46,7 +47,7 @@ def train(
|
|||||||
train_manifest: str,
|
train_manifest: str,
|
||||||
val_manifest: str = "",
|
val_manifest: str = "",
|
||||||
sample_rate: int = 16_000,
|
sample_rate: int = 16_000,
|
||||||
out_sample_rate: int = 0, # accepted from YAML for documentation; not used in training
|
out_sample_rate: int = 0, # AudioVAE decoder output rate; used for TensorBoard audio logging
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
grad_accum_steps: int = 1,
|
grad_accum_steps: int = 1,
|
||||||
num_workers: int = 2,
|
num_workers: int = 2,
|
||||||
@@ -64,12 +65,12 @@ def train(
|
|||||||
lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
|
lambdas: Dict[str, float] = {"loss/diff": 1.0, "loss/stop": 1.0},
|
||||||
lora: dict = None,
|
lora: dict = None,
|
||||||
config_path: str = "",
|
config_path: str = "",
|
||||||
|
max_grad_norm: float = 0.0, # gradient clipping; 0 = disabled (backward compat)
|
||||||
# Distribution options (for LoRA checkpoints)
|
# Distribution options (for LoRA checkpoints)
|
||||||
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
|
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
|
||||||
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
|
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
|
||||||
):
|
):
|
||||||
_ = config_path
|
_ = config_path
|
||||||
_ = out_sample_rate
|
|
||||||
|
|
||||||
# Validate distribution options
|
# Validate distribution options
|
||||||
if lora is not None and distribute and not hf_model_id:
|
if lora is not None and distribute and not hf_model_id:
|
||||||
@@ -93,6 +94,7 @@ def train(
|
|||||||
with open(os.path.join(pretrained_path, "config.json"), "r", encoding="utf-8") as _f:
|
with open(os.path.join(pretrained_path, "config.json"), "r", encoding="utf-8") as _f:
|
||||||
_arch = json.load(_f).get("architecture", "voxcpm").lower()
|
_arch = json.load(_f).get("architecture", "voxcpm").lower()
|
||||||
_model_cls = VoxCPM2Model if _arch == "voxcpm2" else VoxCPMModel
|
_model_cls = VoxCPM2Model if _arch == "voxcpm2" else VoxCPMModel
|
||||||
|
LoRAConfig = LoRAConfigV2 if _arch == "voxcpm2" else LoRAConfigV1
|
||||||
if accelerator.rank == 0:
|
if accelerator.rank == 0:
|
||||||
print(f"Detected architecture: {_arch} -> {_model_cls.__name__}", file=sys.stderr)
|
print(f"Detected architecture: {_arch} -> {_model_cls.__name__}", file=sys.stderr)
|
||||||
base_model = _model_cls.from_local(
|
base_model = _model_cls.from_local(
|
||||||
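
The hunk above picks both the model class and the matching `LoRAConfig` from the `architecture` field of `config.json`, defaulting to the V1 classes. The sketch below shows the same dispatch with a lookup table. `detect_architecture`, `select_by_architecture`, and the registry contents are hypothetical stand-ins; the real code selects between `VoxCPMModel` / `VoxCPM2Model` and the two `LoRAConfig` classes directly.

```python
import json
import os


def detect_architecture(pretrained_path, default="voxcpm"):
    """Read the lower-cased `architecture` field from config.json."""
    config_file = os.path.join(pretrained_path, "config.json")
    with open(config_file, "r", encoding="utf-8") as f:
        return json.load(f).get("architecture", default).lower()


def select_by_architecture(arch, registry, default_key="voxcpm"):
    """Return the registry entry for arch, falling back to the default entry."""
    return registry.get(arch, registry[default_key])


# Placeholder strings stand in for the (model class, LoRA config class) pairs.
REGISTRY = {
    "voxcpm": ("VoxCPMModel", "LoRAConfigV1"),
    "voxcpm2": ("VoxCPM2Model", "LoRAConfigV2"),
}
```

Keeping the model class and its LoRA config in one entry makes it harder to pair a V2 model with a V1 config, which is the mismatch this hunk addresses.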
@@ -178,8 +180,12 @@ def train(
|
|||||||
dataset_cnt=dataset_cnt,
|
dataset_cnt=dataset_cnt,
|
||||||
device=accelerator.device,
|
device=accelerator.device,
|
||||||
)
|
)
|
||||||
# Save audio_vae for audio generation
|
# Save audio_vae and output sample rate for audio generation.
|
||||||
|
# Prefer model's actual output rate; fall back to YAML out_sample_rate or encode rate.
|
||||||
audio_vae_for_gen = base_model.audio_vae
|
audio_vae_for_gen = base_model.audio_vae
|
||||||
|
out_sr = base_model.sample_rate # decoder output rate (e.g. 48000 for V2)
|
||||||
|
if out_sr == 0 and out_sample_rate > 0:
|
||||||
|
out_sr = out_sample_rate
|
||||||
del base_model.audio_vae
|
del base_model.audio_vae
|
||||||
model = accelerator.prepare_model(base_model)
|
model = accelerator.prepare_model(base_model)
|
||||||
unwrapped_model = accelerator.unwrap(model)
|
unwrapped_model = accelerator.unwrap(model)
|
||||||
@@ -312,8 +318,8 @@ def train(
|
|||||||
scaler = getattr(accelerator, "scaler", None)
|
scaler = getattr(accelerator, "scaler", None)
|
||||||
if scaler is not None:
|
if scaler is not None:
|
||||||
scaler.unscale_(optimizer)
|
scaler.unscale_(optimizer)
|
||||||
# Use large max_norm to only compute grad_norm without actual clipping
|
effective_max_norm = max_grad_norm if max_grad_norm > 0 else 1e9
|
||||||
grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=1e9)
|
grad_norm = torch.nn.utils.clip_grad_norm_(unwrapped_model.parameters(), max_norm=effective_max_norm)
|
||||||
|
|
||||||
accelerator.step(optimizer)
|
accelerator.step(optimizer)
|
||||||
accelerator.update()
|
accelerator.update()
|
||||||
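
`torch.nn.utils.clip_grad_norm_` returns the total gradient norm measured before any scaling, so the change above can keep logging `grad_norm` whether clipping is active or not: a `max_grad_norm` of 0 is mapped to a very large threshold that never triggers. The dependency-free sketch below mimics that behaviour on plain Python lists instead of tensors. `resolve_max_norm` and `clip_by_global_norm` are hypothetical names, and the `eps` term is an assumption modeled on PyTorch's formula.

```python
import math


def resolve_max_norm(max_grad_norm):
    """Map the config value to an effective threshold; 0 or less disables clipping."""
    return max_grad_norm if max_grad_norm > 0 else 1e9


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients so that their joint L2 norm is at most max_norm.

    Returns (clipped_grads, total_norm); total_norm is measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    coef = min(1.0, max_norm / (total_norm + eps))
    return [[g * coef for g in vec] for vec in grads], total_norm


grads = [[3.0], [4.0]]  # joint norm is 5.0
clipped, norm = clip_by_global_norm(grads, resolve_max_norm(1.0))
unclipped, same_norm = clip_by_global_norm(grads, resolve_max_norm(0.0))
```

In both calls the reported norm is 5.0; only the first call rescales the gradients.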
@@ -341,6 +347,7 @@ def train(
|
|||||||
val_ds=val_ds,
|
val_ds=val_ds,
|
||||||
audio_vae=audio_vae_for_gen,
|
audio_vae=audio_vae_for_gen,
|
||||||
sample_rate=sample_rate,
|
sample_rate=sample_rate,
|
||||||
|
out_sample_rate=out_sr,
|
||||||
val_texts=val_texts,
|
val_texts=val_texts,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
valid_interval=valid_interval,
|
valid_interval=valid_interval,
|
||||||
@@ -367,6 +374,7 @@ def validate(
|
|||||||
val_ds=None,
|
val_ds=None,
|
||||||
audio_vae=None,
|
audio_vae=None,
|
||||||
sample_rate=22050,
|
sample_rate=22050,
|
||||||
|
out_sample_rate=0,
|
||||||
val_texts=None,
|
val_texts=None,
|
||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
valid_interval=1000,
|
valid_interval=1000,
|
||||||
@@ -432,6 +440,7 @@ def validate(
|
|||||||
step,
|
step,
|
||||||
accelerator,
|
accelerator,
|
||||||
sample_rate,
|
sample_rate,
|
||||||
|
out_sample_rate=out_sample_rate,
|
||||||
val_texts=val_texts,
|
val_texts=val_texts,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
valid_interval=valid_interval,
|
valid_interval=valid_interval,
|
||||||
@@ -534,6 +543,7 @@ def generate_sample_audio(
|
|||||||
step,
|
step,
|
||||||
accelerator,
|
accelerator,
|
||||||
sample_rate=22050,
|
sample_rate=22050,
|
||||||
|
out_sample_rate=0,
|
||||||
val_texts=None,
|
val_texts=None,
|
||||||
tokenizer=None,
|
tokenizer=None,
|
||||||
pretrained_path=None,
|
pretrained_path=None,
|
||||||
@@ -548,6 +558,10 @@ def generate_sample_audio(
|
|||||||
log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}")
|
log(f"[Audio] Starting audio generation for {num_samples} samples at step {step}")
|
||||||
|
|
||||||
unwrapped_model = accelerator.unwrap(model)
|
unwrapped_model = accelerator.unwrap(model)
|
||||||
|
# Determine the correct output sample rate for generated audio.
|
||||||
|
# out_sample_rate is the decoder output rate (e.g. 48kHz for V2);
|
||||||
|
# sample_rate is the encoder input rate (e.g. 16kHz for V2).
|
||||||
|
gen_sr = out_sample_rate if out_sample_rate > 0 else sample_rate
|
||||||
|
|
||||||
for i in range(num_samples):
|
for i in range(num_samples):
|
||||||
sample = val_ds[i]
|
sample = val_ds[i]
|
||||||
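
The comment block above explains why generated audio must be logged at the decoder output rate instead of the encoder input rate. A small worked example makes the effect concrete, using the 16 kHz and 48 kHz rates from the YAML config; the helper names are hypothetical.

```python
def resolve_gen_sample_rate(sample_rate, out_sample_rate=0):
    """Prefer the decoder output rate; fall back to the encoder input rate."""
    return out_sample_rate if out_sample_rate > 0 else sample_rate


def duration_seconds(num_samples, sr):
    """Duration of a waveform with num_samples samples played at sr Hz."""
    return num_samples / sr


num_samples = 96_000  # two seconds of audio produced by a 48 kHz decoder
correct = duration_seconds(num_samples, resolve_gen_sample_rate(16_000, 48_000))
mislabeled = duration_seconds(num_samples, 16_000)
```

Labelled with the encoder rate, the same clip is reported as three times longer and plays back slowed down and pitched lower in TensorBoard.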
@@ -604,10 +618,10 @@ def generate_sample_audio(
|
|||||||
gen_audio_np = normalize_audio(gen_audio_np)
|
gen_audio_np = normalize_audio(gen_audio_np)
|
||||||
|
|
||||||
tag = f"val_sample_{i}"
|
tag = f"val_sample_{i}"
|
||||||
writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=sample_rate)
|
writer.add_audio(f"{tag}/generated_audio", gen_audio_np, global_step=step, sample_rate=gen_sr)
|
||||||
log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/sample_rate:.2f}s")
|
log(f"[Audio] Generated audio for sample {i}: duration={len(gen_audio_np)/gen_sr:.2f}s")
|
||||||
|
|
||||||
# Log reference audio
|
# Log reference audio (at encoder input rate, which is what val_ds provides)
|
||||||
if ref_audio_np is not None:
|
if ref_audio_np is not None:
|
||||||
writer.add_audio(
|
writer.add_audio(
|
||||||
f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate
|
f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate
|
||||||
@@ -615,9 +629,9 @@ def generate_sample_audio(
|
|||||||
|
|
||||||
# Generate mel spectrogram figure
|
# Generate mel spectrogram figure
|
||||||
try:
|
try:
|
||||||
mel_gen = compute_mel_spectrogram(gen_audio_np, sample_rate)
|
mel_gen = compute_mel_spectrogram(gen_audio_np, gen_sr)
|
||||||
mel_ref = compute_mel_spectrogram(ref_audio_np, sample_rate) if ref_audio_np is not None else None
|
mel_ref = compute_mel_spectrogram(ref_audio_np, sample_rate) if ref_audio_np is not None else None
|
||||||
fig = create_mel_figure(gen_audio_np, mel_gen, sample_rate, step, ref_audio_np, mel_ref)
|
fig = create_mel_figure(gen_audio_np, mel_gen, gen_sr, step, ref_audio_np, mel_ref)
|
||||||
writer.add_figure(f"{tag}/mel_spectrogram", fig, global_step=step)
|
writer.add_figure(f"{tag}/mel_spectrogram", fig, global_step=step)
|
||||||
log(f"[Audio] Created mel spectrogram figure for sample {i}")
|
log(f"[Audio] Created mel spectrogram figure for sample {i}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
+31
-46
@@ -48,25 +48,8 @@ from ..modules.minicpm4 import MiniCPM4Config, MiniCPMModel
|
|||||||
from .utils import get_dtype, mask_multichar_chinese_tokens
|
from .utils import get_dtype, mask_multichar_chinese_tokens
|
||||||
|
|
||||||
|
|
||||||
def _trim_audio_silence_vad(
|
# A simple helper that trims audio silence using VAD; not used by default
|
||||||
audio: torch.Tensor,
|
def _trim_audio_silence_vad(audio: torch.Tensor, sample_rate: int, max_silence_ms: float = 200.0, top_db: float = 35.0) -> torch.Tensor:
|
||||||
sample_rate: int,
|
|
||||||
max_silence_ms: float = 200.0,
|
|
||||||
top_db: float = 35.0,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
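"""Trim leading and trailing silence, plus long trailing pseudo-silence, using an energy threshold (VAD-style), keeping at most max_silence_ms milliseconds of silence at each end.

Long pseudo-silent tails are trimmed as well (low-energy but not fully silent segments, such as prolonged background noise).

Args:
audio: audio tensor of shape (1, T)
sample_rate: sample rate
max_silence_ms: maximum silence to keep at the start and at the end (milliseconds)
top_db: number of dB below the reference level at which audio counts as silence

Returns:
trimmed tensor of shape (1, T')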
|
|
||||||
"""
|
|
||||||
if audio.numel() == 0:
|
if audio.numel() == 0:
|
||||||
return audio
|
return audio
|
||||||
y = audio.squeeze(0).numpy()
|
y = audio.squeeze(0).numpy()
|
||||||
@@ -85,7 +68,7 @@ def _trim_audio_silence_vad(
|
|||||||
except Exception:
|
except Exception:
|
||||||
start, end = 0, n
|
start, end = 0, n
|
||||||
|
|
||||||
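# Use per-frame RMS to find the last position with sustained energy and trim the long trailing pseudo-silence (low-energy background noise, etc.)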
|
# Find the last frame with continuous energy, trim the long pseudo-silence at the end (low energy background noise, etc.)
|
||||||
n_frames = max(0, (n - frame_length) // hop_length + 1)
|
n_frames = max(0, (n - frame_length) // hop_length + 1)
|
||||||
last_voice_frame = -1
|
last_voice_frame = -1
|
||||||
for j in range(n_frames):
|
for j in range(n_frames):
|
||||||
@@ -246,6 +229,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
# Audio VAE
|
# Audio VAE
|
||||||
self.audio_vae = audio_vae
|
self.audio_vae = audio_vae
|
||||||
self.chunk_size = audio_vae.chunk_size
|
self.chunk_size = audio_vae.chunk_size
|
||||||
|
self._decode_chunk_size = getattr(audio_vae, "decode_chunk_size", audio_vae.chunk_size)
|
||||||
self._encode_sample_rate = audio_vae.sample_rate
|
self._encode_sample_rate = audio_vae.sample_rate
|
||||||
self.sample_rate = getattr(audio_vae, "out_sample_rate", audio_vae.sample_rate)
|
self.sample_rate = getattr(audio_vae, "out_sample_rate", audio_vae.sample_rate)
|
||||||
|
|
||||||
@@ -382,11 +366,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
mu=dit_hidden,
|
mu=dit_hidden,
|
||||||
patch_size=self.patch_size,
|
patch_size=self.patch_size,
|
||||||
cond=feat_cond_for_sample,
|
cond=feat_cond_for_sample,
|
||||||
n_timesteps=(
|
n_timesteps=10,
|
||||||
self.config.dit_config.cfm_config.inference_cfg_rate
|
|
||||||
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
|
|
||||||
else 10
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
|
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
|
||||||
|
|
||||||
@@ -656,14 +636,14 @@ class VoxCPM2Model(nn.Module):
|
|||||||
streaming_prefix_len=streaming_prefix_len,
|
streaming_prefix_len=streaming_prefix_len,
|
||||||
)
|
)
|
||||||
if streaming:
|
if streaming:
|
||||||
patch_len = self.patch_size * self.chunk_size
|
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||||
for latent_pred, _ in inference_result:
|
for latent_pred, _, _ctx in inference_result:
|
||||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||||
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
|
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
|
||||||
yield decode_audio
|
yield decode_audio
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat, context_len = next(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(
|
print(
|
||||||
@@ -679,10 +659,9 @@ class VoxCPM2Model(nn.Module):
|
|||||||
|
|
||||||
if not streaming:
|
if not streaming:
|
||||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||||
patch_len = self.patch_size * self.chunk_size
|
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||||
has_continuation = bool(prompt_wav_path)
|
if context_len > 0:
|
||||||
if has_continuation:
|
decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
|
||||||
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
|
|
||||||
else:
|
else:
|
||||||
decode_audio = decode_audio.squeeze(1).cpu()
|
decode_audio = decode_audio.squeeze(1).cpu()
|
||||||
yield decode_audio
|
yield decode_audio
|
||||||
@@ -944,14 +923,14 @@ class VoxCPM2Model(nn.Module):
|
|||||||
streaming_prefix_len=streaming_prefix_len,
|
streaming_prefix_len=streaming_prefix_len,
|
||||||
)
|
)
|
||||||
if streaming:
|
if streaming:
|
||||||
patch_len = self.patch_size * self.chunk_size
|
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||||
for latent_pred, pred_audio_feat in inference_result:
|
for latent_pred, pred_audio_feat, _ctx in inference_result:
|
||||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||||
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
|
decode_audio = decode_audio[..., -decode_patch_len:].squeeze(1).cpu()
|
||||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat, context_len = next(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(
|
print(
|
||||||
@@ -966,18 +945,20 @@ class VoxCPM2Model(nn.Module):
|
|||||||
break
|
break
|
||||||
if not streaming:
|
if not streaming:
|
||||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||||
patch_len = self.patch_size * self.chunk_size
|
decode_patch_len = self.patch_size * self._decode_chunk_size
|
||||||
if mode in ("continuation", "ref_continuation"):
|
if context_len > 0:
|
||||||
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
|
decode_audio = decode_audio[..., decode_patch_len * context_len:].squeeze(1).cpu()
|
||||||
else:
|
else:
|
||||||
decode_audio = decode_audio[..., :].squeeze(1).cpu()
|
decode_audio = decode_audio.squeeze(1).cpu()
|
||||||
yield (decode_audio, target_text_token, pred_audio_feat)
|
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||||
|
|
||||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
return next(self._inference(*args, streaming=False, **kwargs))
|
feat_pred, generated_feat, _ = next(self._inference(*args, streaming=False, **kwargs))
|
||||||
|
return feat_pred, generated_feat
|
||||||
|
|
||||||
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
def inference_streaming(self, *args, **kwargs) -> Generator[Tuple[torch.Tensor, List[torch.Tensor]], None, None]:
|
||||||
return self._inference(*args, streaming=True, **kwargs)
|
for feat_pred, pred_feat_seq, _ in self._inference(*args, streaming=True, **kwargs):
|
||||||
|
yield feat_pred, pred_feat_seq
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def _inference(
|
def _inference(
|
||||||
@@ -1036,6 +1017,7 @@ class VoxCPM2Model(nn.Module):
|
|||||||
# trailing audio patches as initial context so the VAE can decode smoothly.
|
# trailing audio patches as initial context so the VAE can decode smoothly.
|
||||||
# - Reference-only / zero-shot (feat_mask ends with 0): start from scratch.
|
# - Reference-only / zero-shot (feat_mask ends with 0): start from scratch.
|
||||||
has_continuation_audio = feat_mask[0, -1].item() == 1
|
has_continuation_audio = feat_mask[0, -1].item() == 1
|
||||||
|
context_len = 0
|
||||||
if has_continuation_audio:
|
if has_continuation_audio:
|
||||||
audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0]
|
audio_indices = feat_mask.squeeze(0).nonzero(as_tuple=True)[0]
|
||||||
context_len = min(streaming_prefix_len - 1, len(audio_indices))
|
context_len = min(streaming_prefix_len - 1, len(audio_indices))
|
||||||
@@ -1085,11 +1067,13 @@ class VoxCPM2Model(nn.Module):
|
|||||||
prefix_feat_cond = pred_feat
|
prefix_feat_cond = pred_feat
|
||||||
|
|
||||||
if streaming:
|
if streaming:
|
||||||
# return the last three predicted latent features to provide enough context for smooth decoding
|
|
||||||
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
|
pred_feat_chunk = torch.cat(pred_feat_seq[-streaming_prefix_len:], dim=1)
|
||||||
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
feat_pred = rearrange(pred_feat_chunk, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||||
|
|
||||||
yield feat_pred, pred_feat_seq
|
yield feat_pred, pred_feat_seq, context_len
|
||||||
|
|
||||||
|
if len(pred_feat_seq) > streaming_prefix_len:
|
||||||
|
pred_feat_seq = pred_feat_seq[-streaming_prefix_len:]
|
||||||
|
|
||||||
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
|
stop_flag = self.stop_head(self.stop_actn(self.stop_proj(lm_hidden))).argmax(dim=-1)[0].cpu().item()
|
||||||
if i > min_len and stop_flag == 1:
|
if i > min_len and stop_flag == 1:
|
||||||
@@ -1108,7 +1092,8 @@ class VoxCPM2Model(nn.Module):
|
|||||||
if not streaming:
|
if not streaming:
|
||||||
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
|
pred_feat_seq = torch.cat(pred_feat_seq, dim=1) # b, t, p, d
|
||||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||||
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
generated_feat = pred_feat_seq[:, context_len:, :, :].squeeze(0).cpu()
|
||||||
|
yield feat_pred, generated_feat, context_len
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
||||||
|
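The hunks above appear to change how prompt audio is trimmed from the decoded waveform: the samples per patch now come from the decoder strides (`math.prod(decoder_rates)`) rather than the encoder strides, and the number of trimmed patches is the `context_len` yielded by `_inference` rather than a fixed `streaming_prefix_len - 1`. The sketch below reproduces only that arithmetic on a plain Python list. The helper name `trim_context`, its parameters, and the stride values in the usage note are hypothetical and not part of the repository.

```python
import math
from typing import List, Sequence, Tuple


def trim_context(
    samples: List[float],
    patch_size: int,
    decoder_rates: Sequence[int],
    streaming_prefix_len: int,
    n_prompt_patches: int,
) -> Tuple[List[float], int]:
    """Drop decoded samples that correspond to prompt (context) patches.

    Hypothetical helper: mirrors the slicing in the diff, not the real API.
    """
    # Samples produced per latent patch, derived from the decoder upsampling strides.
    decode_patch_len = patch_size * math.prod(decoder_rates)

    # At most (streaming_prefix_len - 1) trailing prompt patches are used as context,
    # and never more than the prompt actually contains.
    context_len = max(0, min(streaming_prefix_len - 1, n_prompt_patches))

    if context_len > 0:
        return samples[decode_patch_len * context_len:], context_len
    return samples, context_len
```

For example, with `patch_size=2`, `decoder_rates=(2, 3)` and `streaming_prefix_len=4`, a prompt with two audio patches trims `2 * 12 = 24` samples, whereas a prompt with no audio leaves the waveform untouched.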
|||||||
@@ -436,6 +436,7 @@ class AudioVAE(nn.Module):
|
|||||||
self.out_sample_rate = out_sample_rate
|
self.out_sample_rate = out_sample_rate
|
||||||
self.sr_bin_boundaries = sr_bin_boundaries
|
self.sr_bin_boundaries = sr_bin_boundaries
|
||||||
self.chunk_size = math.prod(encoder_rates)
|
self.chunk_size = math.prod(encoder_rates)
|
||||||
|
self.decode_chunk_size = math.prod(decoder_rates)
|
||||||
|
|
||||||
def preprocess(self, audio_data, sample_rate):
|
def preprocess(self, audio_data, sample_rate):
|
||||||
if sample_rate is None:
|
if sample_rate is None:
|
||||||
|
|||||||
@@ -225,7 +225,7 @@ class UnifiedCFM(torch.nn.Module):
|
|||||||
losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1)
|
losses = F.mse_loss(u_pred, u_tgt.detach(), reduction="none").mean(dim=1)
|
||||||
if tgt_mask is not None:
|
if tgt_mask is not None:
|
||||||
weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1))
|
weights = self.adaptive_loss_weighting(losses, tgt_mask.squeeze(1))
|
||||||
loss = (weights * losses).sum() / torch.sum(tgt_mask)
|
loss = (weights * losses).sum() / torch.clamp(torch.sum(tgt_mask), min=1.0)
|
||||||
else:
|
else:
|
||||||
loss = losses.mean()
|
loss = losses.mean()
|
||||||
|
|
||||||
|
|||||||
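The `UnifiedCFM` hunk above clamps the mask sum to a minimum of 1.0 before dividing. In PyTorch, a batch whose target mask is all zeros would otherwise compute `0 / 0`, which yields NaN and can poison the gradients. The following is a plain-Python analogue of that guard, not the repository's code; the function name `masked_loss` and its list-based signature are hypothetical.

```python
from typing import Sequence


def masked_loss(
    losses: Sequence[float],
    weights: Sequence[float],
    mask: Sequence[int],
) -> float:
    """Weighted loss normalised by the number of unmasked positions.

    Hypothetical helper: the denominator is clamped to at least 1.0,
    mirroring torch.clamp(torch.sum(tgt_mask), min=1.0) in the diff.
    """
    numerator = sum(w * l for w, l in zip(weights, losses))
    denominator = max(float(sum(mask)), 1.0)
    return numerator / denominator
```

With a fully masked batch the numerator is zero (assuming the weights are zero wherever the mask is zero), so the clamped result is a loss of 0.0 rather than NaN.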