update voxcpm2
This commit is contained in:
@@ -3,13 +3,14 @@ import sys
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import spaces
|
import spaces # noqa: F401
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from funasr import AutoModel
|
from funasr import AutoModel
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
if os.environ.get("HF_REPO_ID", "").strip() == "":
|
if os.environ.get("HF_REPO_ID", "").strip() == "":
|
||||||
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM1.5"
|
os.environ["HF_REPO_ID"] = "openbmb/VoxCPM2"
|
||||||
|
|
||||||
import voxcpm
|
import voxcpm
|
||||||
|
|
||||||
@@ -24,13 +25,13 @@ class VoxCPMDemo:
|
|||||||
self.asr_model: Optional[AutoModel] = AutoModel(
|
self.asr_model: Optional[AutoModel] = AutoModel(
|
||||||
model=self.asr_model_id,
|
model=self.asr_model_id,
|
||||||
disable_update=True,
|
disable_update=True,
|
||||||
log_level='DEBUG',
|
log_level="DEBUG",
|
||||||
device="cuda:0" if self.device == "cuda" else "cpu",
|
device="cuda:0" if self.device == "cuda" else "cpu",
|
||||||
)
|
)
|
||||||
|
|
||||||
# TTS model (lazy init)
|
# TTS model (lazy init)
|
||||||
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
|
self.voxcpm_model: Optional[voxcpm.VoxCPM] = None
|
||||||
self.default_local_model_dir = "./models/VoxCPM1.5"
|
self.default_local_model_dir = "/Users/xinliu/Downloads/VoxCPM2-0.5B-newaudiovae-6hz-0316"
|
||||||
|
|
||||||
# ---------- Model helpers ----------
|
# ---------- Model helpers ----------
|
||||||
def _resolve_model_dir(self) -> str:
|
def _resolve_model_dir(self) -> str:
|
||||||
@@ -49,6 +50,7 @@ class VoxCPMDemo:
|
|||||||
if not os.path.isdir(target_dir):
|
if not os.path.isdir(target_dir):
|
||||||
try:
|
try:
|
||||||
from huggingface_hub import snapshot_download # type: ignore
|
from huggingface_hub import snapshot_download # type: ignore
|
||||||
|
|
||||||
os.makedirs(target_dir, exist_ok=True)
|
os.makedirs(target_dir, exist_ok=True)
|
||||||
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
|
print(f"Downloading model from HF repo '{repo_id}' to '{target_dir}' ...", file=sys.stderr)
|
||||||
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
|
snapshot_download(repo_id=repo_id, local_dir=target_dir, local_dir_use_symlinks=False)
|
||||||
@@ -64,7 +66,7 @@ class VoxCPMDemo:
|
|||||||
print("Model not loaded, initializing...", file=sys.stderr)
|
print("Model not loaded, initializing...", file=sys.stderr)
|
||||||
model_dir = self._resolve_model_dir()
|
model_dir = self._resolve_model_dir()
|
||||||
print(f"Using model dir: {model_dir}", file=sys.stderr)
|
print(f"Using model dir: {model_dir}", file=sys.stderr)
|
||||||
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir)
|
self.voxcpm_model = voxcpm.VoxCPM(voxcpm_model_path=model_dir, optimize=False)
|
||||||
print("Model loaded successfully.", file=sys.stderr)
|
print("Model loaded successfully.", file=sys.stderr)
|
||||||
return self.voxcpm_model
|
return self.voxcpm_model
|
||||||
|
|
||||||
@@ -73,21 +75,24 @@ class VoxCPMDemo:
|
|||||||
if prompt_wav is None:
|
if prompt_wav is None:
|
||||||
return ""
|
return ""
|
||||||
res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
|
res = self.asr_model.generate(input=prompt_wav, language="auto", use_itn=True)
|
||||||
text = res[0]["text"].split('|>')[-1]
|
text = res[0]["text"].split("|>")[-1]
|
||||||
return text
|
return text
|
||||||
|
|
||||||
def generate_tts_audio(
|
def generate_tts_audio(
|
||||||
self,
|
self,
|
||||||
text_input: str,
|
text_input: str,
|
||||||
prompt_wav_path_input: Optional[str] = None,
|
control_instruction: str = "",
|
||||||
prompt_text_input: Optional[str] = None,
|
reference_wav_path_input: Optional[str] = None,
|
||||||
cfg_value_input: float = 2.0,
|
cfg_value_input: float = 2.0,
|
||||||
inference_timesteps_input: int = 10,
|
inference_timesteps_input: int = 10,
|
||||||
do_normalize: bool = True,
|
do_normalize: bool = True,
|
||||||
denoise: bool = True,
|
denoise: bool = True,
|
||||||
) -> Tuple[int, np.ndarray]:
|
) -> Tuple[int, np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Generate speech from text using VoxCPM; optional reference audio for voice style guidance.
|
Generate speech from text using VoxCPM.
|
||||||
|
- If reference_wav provided: Prompt isolation mode (voice cloning)
|
||||||
|
- If no reference_wav: Voice design mode (use control_instruction to describe voice)
|
||||||
|
|
||||||
Returns (sample_rate, waveform_numpy)
|
Returns (sample_rate, waveform_numpy)
|
||||||
"""
|
"""
|
||||||
current_model = self.get_or_load_voxcpm()
|
current_model = self.get_or_load_voxcpm()
|
||||||
@@ -96,14 +101,25 @@ class VoxCPMDemo:
|
|||||||
if len(text) == 0:
|
if len(text) == 0:
|
||||||
raise ValueError("Please input text to synthesize.")
|
raise ValueError("Please input text to synthesize.")
|
||||||
|
|
||||||
prompt_wav_path = prompt_wav_path_input if prompt_wav_path_input else None
|
# 处理 control instruction
|
||||||
prompt_text = prompt_text_input if prompt_text_input else None
|
control = (control_instruction or "").strip()
|
||||||
|
if control:
|
||||||
|
final_text = f"({control}){text}"
|
||||||
|
else:
|
||||||
|
final_text = text
|
||||||
|
|
||||||
print(f"Generating audio for text: '{text[:60]}...'", file=sys.stderr)
|
reference_wav_path = reference_wav_path_input if reference_wav_path_input else None
|
||||||
|
|
||||||
|
# 判断模式
|
||||||
|
if reference_wav_path:
|
||||||
|
print(f"[Prompt Isolation Mode] reference_wav: {reference_wav_path}", file=sys.stderr)
|
||||||
|
else:
|
||||||
|
print(f"[Voice Design Mode] control: {control[:50] if control else 'None'}...", file=sys.stderr)
|
||||||
|
|
||||||
|
print(f"Generating audio for text: '{final_text[:80]}...'", file=sys.stderr)
|
||||||
wav = current_model.generate(
|
wav = current_model.generate(
|
||||||
text=text,
|
text=final_text,
|
||||||
prompt_text=prompt_text,
|
reference_wav_path=reference_wav_path,
|
||||||
prompt_wav_path=prompt_wav_path,
|
|
||||||
cfg_value=float(cfg_value_input),
|
cfg_value=float(cfg_value_input),
|
||||||
inference_timesteps=int(inference_timesteps_input),
|
inference_timesteps=int(inference_timesteps_input),
|
||||||
normalize=do_normalize,
|
normalize=do_normalize,
|
||||||
@@ -114,46 +130,53 @@ class VoxCPMDemo:
|
|||||||
|
|
||||||
# ---------- UI Builders ----------
|
# ---------- UI Builders ----------
|
||||||
|
|
||||||
|
THEME = gr.themes.Soft(
|
||||||
|
primary_hue="blue",
|
||||||
|
secondary_hue="gray",
|
||||||
|
neutral_hue="slate",
|
||||||
|
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"],
|
||||||
|
)
|
||||||
|
|
||||||
|
CSS = """
|
||||||
|
.logo-container {
|
||||||
|
text-align: center;
|
||||||
|
margin: 0.5rem 0 1rem 0;
|
||||||
|
}
|
||||||
|
.logo-container img {
|
||||||
|
height: 80px;
|
||||||
|
width: auto;
|
||||||
|
max-width: 200px;
|
||||||
|
display: inline-block;
|
||||||
|
}
|
||||||
|
/* Bold accordion labels */
|
||||||
|
#acc_quick > .label-wrap,
|
||||||
|
#acc_tips > .label-wrap,
|
||||||
|
#acc_quick > .label-wrap > span,
|
||||||
|
#acc_tips > .label-wrap > span,
|
||||||
|
#acc_quick summary,
|
||||||
|
#acc_tips summary {
|
||||||
|
font-weight: 600 !important;
|
||||||
|
font-size: 1.1em !important;
|
||||||
|
}
|
||||||
|
/* Bold labels for specific checkboxes */
|
||||||
|
#chk_denoise label,
|
||||||
|
#chk_denoise span,
|
||||||
|
#chk_normalize label,
|
||||||
|
#chk_normalize span {
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def create_demo_interface(demo: VoxCPMDemo):
|
def create_demo_interface(demo: VoxCPMDemo):
|
||||||
"""Build the Gradio UI for VoxCPM demo."""
|
"""Build the Gradio UI for VoxCPM demo."""
|
||||||
# static assets (logo path)
|
gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
|
||||||
gr.set_static_paths(paths=[Path.cwd().absolute()/"assets"])
|
|
||||||
|
|
||||||
with gr.Blocks(
|
with gr.Blocks() as interface:
|
||||||
theme=gr.themes.Soft(
|
gr.HTML(
|
||||||
primary_hue="blue",
|
'<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>',
|
||||||
secondary_hue="gray",
|
padding=True,
|
||||||
neutral_hue="slate",
|
)
|
||||||
font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
|
|
||||||
),
|
|
||||||
css="""
|
|
||||||
.logo-container {
|
|
||||||
text-align: center;
|
|
||||||
margin: 0.5rem 0 1rem 0;
|
|
||||||
}
|
|
||||||
.logo-container img {
|
|
||||||
height: 80px;
|
|
||||||
width: auto;
|
|
||||||
max-width: 200px;
|
|
||||||
display: inline-block;
|
|
||||||
}
|
|
||||||
/* Bold accordion labels */
|
|
||||||
#acc_quick details > summary,
|
|
||||||
#acc_tips details > summary {
|
|
||||||
font-weight: 600 !important;
|
|
||||||
font-size: 1.1em !important;
|
|
||||||
}
|
|
||||||
/* Bold labels for specific checkboxes */
|
|
||||||
#chk_denoise label,
|
|
||||||
#chk_denoise span,
|
|
||||||
#chk_normalize label,
|
|
||||||
#chk_normalize span {
|
|
||||||
font-weight: 600;
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
) as interface:
|
|
||||||
# Header logo
|
|
||||||
gr.HTML('<div class="logo-container"><img src="/gradio_api/file=assets/voxcpm_logo.png" alt="VoxCPM Logo"></div>')
|
|
||||||
|
|
||||||
# Quick Start
|
# Quick Start
|
||||||
with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"):
|
with gr.Accordion("📋 Quick Start Guide |快速入门", open=False, elem_id="acc_quick"):
|
||||||
@@ -200,34 +223,56 @@ def create_demo_interface(demo: VoxCPMDemo):
|
|||||||
# Main controls
|
# Main controls
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
prompt_wav = gr.Audio(
|
# 1. Reference Audio
|
||||||
sources=["upload", 'microphone'],
|
# gr.Markdown("### 🎤 Reference Audio (Optional)")
|
||||||
|
# gr.Markdown("*提供参考音频进行音色克隆;不提供则使用 Voice Design 模式*")
|
||||||
|
reference_wav = gr.Audio(
|
||||||
|
sources=["upload", "microphone"],
|
||||||
type="filepath",
|
type="filepath",
|
||||||
label="Prompt Speech (Optional, or let VoxCPM improvise)",
|
label="Reference Audio (Optional)",
|
||||||
value="./examples/example.wav",
|
|
||||||
)
|
)
|
||||||
DoDenoisePromptAudio = gr.Checkbox(
|
DoDenoisePromptAudio = gr.Checkbox(
|
||||||
value=False,
|
value=False,
|
||||||
label="Prompt Speech Enhancement",
|
label="Reference Audio Enhancement",
|
||||||
elem_id="chk_denoise",
|
elem_id="chk_denoise",
|
||||||
info="We use ZipEnhancer model to denoise the prompt audio."
|
info="Use ZipEnhancer to denoise the reference audio",
|
||||||
)
|
)
|
||||||
with gr.Row():
|
|
||||||
prompt_text = gr.Textbox(
|
# 2. Control Instruction
|
||||||
value="Just by listening a few minutes a day, you'll be able to eliminate negative thoughts by conditioning your mind to be more positive.",
|
# gr.Markdown("### 🎛️ Control Instruction (Optional)")
|
||||||
label="Prompt Text",
|
# gr.Markdown("*描述声音风格、情感等,格式:`(instruction) text`*")
|
||||||
placeholder="Please enter the prompt text. Automatic recognition is supported, and you can correct the results yourself..."
|
control_instruction = gr.Textbox(
|
||||||
)
|
value="",
|
||||||
run_btn = gr.Button("Generate Speech", variant="primary")
|
label="Control Instruction",
|
||||||
|
placeholder="*描述声音风格、情感等,格式:`(instruction) text`,例如:年轻女性,温柔甜美 / 悲伤地说 / an excited young man*",
|
||||||
|
lines=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Target Text
|
||||||
|
# gr.Markdown("### 📝 Target Text")
|
||||||
|
text = gr.Textbox(
|
||||||
|
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
|
||||||
|
label="Target Text",
|
||||||
|
lines=3,
|
||||||
|
)
|
||||||
|
DoNormalizeText = gr.Checkbox(
|
||||||
|
value=False,
|
||||||
|
label="Text Normalization",
|
||||||
|
elem_id="chk_normalize",
|
||||||
|
info="Use wetext library to normalize the input text",
|
||||||
|
)
|
||||||
|
|
||||||
|
run_btn = gr.Button("🔊 Generate Speech", variant="primary", size="lg")
|
||||||
|
|
||||||
with gr.Column():
|
with gr.Column():
|
||||||
|
gr.Markdown("### ⚙️ Generation Settings")
|
||||||
cfg_value = gr.Slider(
|
cfg_value = gr.Slider(
|
||||||
minimum=1.0,
|
minimum=1.0,
|
||||||
maximum=3.0,
|
maximum=3.0,
|
||||||
value=2.0,
|
value=2.0,
|
||||||
step=0.1,
|
step=0.1,
|
||||||
label="CFG Value (Guidance Scale)",
|
label="CFG Value (Guidance Scale)",
|
||||||
info="Higher values increase adherence to prompt, lower values allow more creativity"
|
info="Higher = more adherence to prompt; Lower = more creativity",
|
||||||
)
|
)
|
||||||
inference_timesteps = gr.Slider(
|
inference_timesteps = gr.Slider(
|
||||||
minimum=4,
|
minimum=4,
|
||||||
@@ -235,40 +280,54 @@ def create_demo_interface(demo: VoxCPMDemo):
|
|||||||
value=10,
|
value=10,
|
||||||
step=1,
|
step=1,
|
||||||
label="Inference Timesteps",
|
label="Inference Timesteps",
|
||||||
info="Number of inference timesteps for generation (higher values may improve quality but slower)"
|
info="Higher = better quality but slower",
|
||||||
)
|
)
|
||||||
with gr.Row():
|
|
||||||
text = gr.Textbox(
|
gr.Markdown("### 🔈 Output")
|
||||||
value="VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly realistic speech.",
|
audio_output = gr.Audio(label="Generated Audio")
|
||||||
label="Target Text",
|
|
||||||
)
|
gr.Markdown("""
|
||||||
with gr.Row():
|
---
|
||||||
DoNormalizeText = gr.Checkbox(
|
**模式说明 / Mode Info:**
|
||||||
value=False,
|
- **有 Reference Audio** → Prompt 隔离模式(音色克隆)
|
||||||
label="Text Normalization",
|
- **无 Reference Audio** → Voice Design 模式(用 Control Instruction 描述声音)
|
||||||
elem_id="chk_normalize",
|
|
||||||
info="We use wetext library to normalize the input text."
|
**Control Instruction 示例:**
|
||||||
)
|
- `年轻女性,温柔甜美`
|
||||||
audio_output = gr.Audio(label="Output Audio")
|
- `悲伤地说`
|
||||||
|
- `an excited young man`
|
||||||
|
""")
|
||||||
|
|
||||||
# Wiring
|
# Wiring
|
||||||
run_btn.click(
|
run_btn.click(
|
||||||
fn=demo.generate_tts_audio,
|
fn=demo.generate_tts_audio,
|
||||||
inputs=[text, prompt_wav, prompt_text, cfg_value, inference_timesteps, DoNormalizeText, DoDenoisePromptAudio],
|
inputs=[
|
||||||
|
text,
|
||||||
|
control_instruction,
|
||||||
|
reference_wav,
|
||||||
|
cfg_value,
|
||||||
|
inference_timesteps,
|
||||||
|
DoNormalizeText,
|
||||||
|
DoDenoisePromptAudio,
|
||||||
|
],
|
||||||
outputs=[audio_output],
|
outputs=[audio_output],
|
||||||
show_progress=True,
|
show_progress=True,
|
||||||
api_name="generate",
|
api_name="generate",
|
||||||
)
|
)
|
||||||
prompt_wav.change(fn=demo.prompt_wav_recognition, inputs=[prompt_wav], outputs=[prompt_text])
|
|
||||||
|
|
||||||
return interface
|
return interface
|
||||||
|
|
||||||
|
|
||||||
def run_demo(server_name: str = "localhost", server_port: int = 7860, show_error: bool = True):
|
def run_demo(server_name: str = "0.0.0.0", server_port: int = 7869, show_error: bool = True):
|
||||||
demo = VoxCPMDemo()
|
demo = VoxCPMDemo()
|
||||||
interface = create_demo_interface(demo)
|
interface = create_demo_interface(demo)
|
||||||
# Recommended to enable queue on Spaces for better throughput
|
interface.queue(max_size=10, default_concurrency_limit=1).launch(
|
||||||
interface.queue(max_size=10, default_concurrency_limit=1).launch(server_name=server_name, server_port=server_port, show_error=show_error)
|
server_name=server_name,
|
||||||
|
server_port=server_port,
|
||||||
|
show_error=show_error,
|
||||||
|
theme=THEME,
|
||||||
|
css=CSS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ lora:
|
|||||||
enable_lm: true
|
enable_lm: true
|
||||||
enable_dit: true
|
enable_dit: true
|
||||||
enable_proj: false
|
enable_proj: false
|
||||||
r: 32
|
r: 8
|
||||||
alpha: 16
|
alpha: 16
|
||||||
dropout: 0.0
|
dropout: 0.0
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ lora:
|
|||||||
enable_lm: true
|
enable_lm: true
|
||||||
enable_dit: true
|
enable_dit: true
|
||||||
enable_proj: false
|
enable_proj: false
|
||||||
r: 32
|
r: 8
|
||||||
alpha: 16
|
alpha: 16
|
||||||
dropout: 0.0
|
dropout: 0.0
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
+210
-223
@@ -1,18 +1,14 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
|
||||||
import glob
|
|
||||||
import json
|
import json
|
||||||
import yaml
|
import yaml
|
||||||
import shutil
|
|
||||||
import datetime
|
import datetime
|
||||||
import subprocess
|
import subprocess
|
||||||
import threading
|
import threading
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import torch
|
import torch
|
||||||
import soundfile as sf
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, List
|
from typing import Optional
|
||||||
|
|
||||||
# Add src to sys.path
|
# Add src to sys.path
|
||||||
project_root = Path(__file__).parent
|
project_root = Path(__file__).parent
|
||||||
@@ -89,7 +85,7 @@ LANG_DICT = {
|
|||||||
"lang_select": "Language / 语言",
|
"lang_select": "Language / 语言",
|
||||||
"refresh": "刷新",
|
"refresh": "刷新",
|
||||||
"output_name": "输出目录名称 (可选,若存在则继续训练)",
|
"output_name": "输出目录名称 (可选,若存在则继续训练)",
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Global variables
|
# Global variables
|
||||||
@@ -98,9 +94,11 @@ asr_model: Optional[AutoModel] = None
|
|||||||
training_process: Optional[subprocess.Popen] = None
|
training_process: Optional[subprocess.Popen] = None
|
||||||
training_log = ""
|
training_log = ""
|
||||||
|
|
||||||
|
|
||||||
def get_timestamp_str():
|
def get_timestamp_str():
|
||||||
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
|
||||||
|
|
||||||
def get_or_load_asr_model():
|
def get_or_load_asr_model():
|
||||||
global asr_model
|
global asr_model
|
||||||
if asr_model is None:
|
if asr_model is None:
|
||||||
@@ -109,23 +107,25 @@ def get_or_load_asr_model():
|
|||||||
asr_model = AutoModel(
|
asr_model = AutoModel(
|
||||||
model="iic/SenseVoiceSmall",
|
model="iic/SenseVoiceSmall",
|
||||||
disable_update=True,
|
disable_update=True,
|
||||||
log_level='ERROR',
|
log_level="ERROR",
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
return asr_model
|
return asr_model
|
||||||
|
|
||||||
|
|
||||||
def recognize_audio(audio_path):
|
def recognize_audio(audio_path):
|
||||||
if not audio_path:
|
if not audio_path:
|
||||||
return ""
|
return ""
|
||||||
try:
|
try:
|
||||||
model = get_or_load_asr_model()
|
model = get_or_load_asr_model()
|
||||||
res = model.generate(input=audio_path, language="auto", use_itn=True)
|
res = model.generate(input=audio_path, language="auto", use_itn=True)
|
||||||
text = res[0]["text"].split('|>')[-1]
|
text = res[0]["text"].split("|>")[-1]
|
||||||
return text
|
return text
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ASR Error: {e}", file=sys.stderr)
|
print(f"ASR Error: {e}", file=sys.stderr)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def scan_lora_checkpoints(root_dir="lora", with_info=False):
|
def scan_lora_checkpoints(root_dir="lora", with_info=False):
|
||||||
"""
|
"""
|
||||||
Scans for LoRA checkpoints in the lora directory.
|
Scans for LoRA checkpoints in the lora directory.
|
||||||
@@ -165,11 +165,12 @@ def scan_lora_checkpoints(root_dir="lora", with_info=False):
|
|||||||
# Also check for checkpoints in the default location if they exist
|
# Also check for checkpoints in the default location if they exist
|
||||||
default_ckpt = "checkpoints/finetune_lora"
|
default_ckpt = "checkpoints/finetune_lora"
|
||||||
if os.path.exists(os.path.join(root_dir, default_ckpt)):
|
if os.path.exists(os.path.join(root_dir, default_ckpt)):
|
||||||
# This might be covered by the walk, but good to be sure
|
# This might be covered by the walk, but good to be sure
|
||||||
pass
|
pass
|
||||||
|
|
||||||
return sorted(checkpoints, reverse=True)
|
return sorted(checkpoints, reverse=True)
|
||||||
|
|
||||||
|
|
||||||
def load_lora_config_from_checkpoint(lora_path):
|
def load_lora_config_from_checkpoint(lora_path):
|
||||||
"""Load LoRA config from lora_config.json if available."""
|
"""Load LoRA config from lora_config.json if available."""
|
||||||
lora_config_file = os.path.join(lora_path, "lora_config.json")
|
lora_config_file = os.path.join(lora_path, "lora_config.json")
|
||||||
@@ -184,6 +185,7 @@ def load_lora_config_from_checkpoint(lora_path):
|
|||||||
print(f"Warning: Failed to load lora_config.json: {e}", file=sys.stderr)
|
print(f"Warning: Failed to load lora_config.json: {e}", file=sys.stderr)
|
||||||
return None, None
|
return None, None
|
||||||
|
|
||||||
|
|
||||||
def get_default_lora_config():
|
def get_default_lora_config():
|
||||||
"""Return default LoRA config for hot-swapping support."""
|
"""Return default LoRA config for hot-swapping support."""
|
||||||
return LoRAConfig(
|
return LoRAConfig(
|
||||||
@@ -192,9 +194,10 @@ def get_default_lora_config():
|
|||||||
r=32,
|
r=32,
|
||||||
alpha=16,
|
alpha=16,
|
||||||
target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
|
target_modules_lm=["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||||
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"]
|
target_modules_dit=["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_model(pretrained_path, lora_path=None):
|
def load_model(pretrained_path, lora_path=None):
|
||||||
global current_model
|
global current_model
|
||||||
print(f"Loading model from {pretrained_path}...", file=sys.stderr)
|
print(f"Loading model from {pretrained_path}...", file=sys.stderr)
|
||||||
@@ -228,9 +231,8 @@ def load_model(pretrained_path, lora_path=None):
|
|||||||
)
|
)
|
||||||
return "Model loaded successfully!"
|
return "Model loaded successfully!"
|
||||||
|
|
||||||
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None):
|
|
||||||
global current_model
|
|
||||||
|
|
||||||
|
def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, steps, seed, pretrained_path=None):
|
||||||
# 如果选择了 LoRA 模型且当前模型未加载,尝试从 LoRA config 读取 base_model
|
# 如果选择了 LoRA 模型且当前模型未加载,尝试从 LoRA config 读取 base_model
|
||||||
if current_model is None:
|
if current_model is None:
|
||||||
# 优先使用用户指定的预训练模型路径
|
# 优先使用用户指定的预训练模型路径
|
||||||
@@ -261,7 +263,7 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
|||||||
# 加载模型
|
# 加载模型
|
||||||
try:
|
try:
|
||||||
print(f"Loading base model: {base_model_path}", file=sys.stderr)
|
print(f"Loading base model: {base_model_path}", file=sys.stderr)
|
||||||
status_msg = load_model(base_model_path)
|
load_model(base_model_path)
|
||||||
if lora_selection and lora_selection != "None":
|
if lora_selection and lora_selection != "None":
|
||||||
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
|
print(f"Model loaded for LoRA: {lora_selection}", file=sys.stderr)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -270,6 +272,7 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
|||||||
return None, error_msg
|
return None, error_msg
|
||||||
|
|
||||||
# Handle LoRA hot-swapping
|
# Handle LoRA hot-swapping
|
||||||
|
assert current_model is not None, "Model must be loaded before inference"
|
||||||
if lora_selection and lora_selection != "None":
|
if lora_selection and lora_selection != "None":
|
||||||
full_lora_path = os.path.join("lora", lora_selection)
|
full_lora_path = os.path.join("lora", lora_selection)
|
||||||
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
|
print(f"Hot-loading LoRA: {full_lora_path}", file=sys.stderr)
|
||||||
@@ -317,14 +320,16 @@ def run_inference(text, prompt_wav, prompt_text, lora_selection, cfg_scale, step
|
|||||||
prompt_text=final_prompt_text,
|
prompt_text=final_prompt_text,
|
||||||
cfg_value=cfg_scale,
|
cfg_value=cfg_scale,
|
||||||
inference_timesteps=steps,
|
inference_timesteps=steps,
|
||||||
denoise=False
|
denoise=False,
|
||||||
)
|
)
|
||||||
return (current_model.tts_model.sample_rate, audio_np), "Generation Success"
|
return (current_model.tts_model.sample_rate, audio_np), "Generation Success"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return None, f"Error: {str(e)}"
|
return None, f"Error: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
def start_training(
|
def start_training(
|
||||||
pretrained_path,
|
pretrained_path,
|
||||||
train_manifest,
|
train_manifest,
|
||||||
@@ -355,7 +360,7 @@ def start_training(
|
|||||||
hf_model_id="",
|
hf_model_id="",
|
||||||
distribute=False,
|
distribute=False,
|
||||||
):
|
):
|
||||||
global training_process, training_log
|
global training_log
|
||||||
|
|
||||||
if training_process is not None and training_process.poll() is None:
|
if training_process is not None and training_process.poll() is None:
|
||||||
return "Training is already running!"
|
return "Training is already running!"
|
||||||
@@ -394,10 +399,7 @@ def start_training(
|
|||||||
"max_steps": resolved_max_steps,
|
"max_steps": resolved_max_steps,
|
||||||
"save_path": checkpoints_dir,
|
"save_path": checkpoints_dir,
|
||||||
"tensorboard": tensorboard_path if tensorboard_path else logs_dir,
|
"tensorboard": tensorboard_path if tensorboard_path else logs_dir,
|
||||||
"lambdas": {
|
"lambdas": {"loss/diff": 1.0, "loss/stop": 1.0},
|
||||||
"loss/diff": 1.0,
|
|
||||||
"loss/stop": 1.0
|
|
||||||
},
|
|
||||||
"lora": {
|
"lora": {
|
||||||
"enable_lm": bool(enable_lm),
|
"enable_lm": bool(enable_lm),
|
||||||
"enable_dit": bool(enable_dit),
|
"enable_dit": bool(enable_dit),
|
||||||
@@ -406,7 +408,7 @@ def start_training(
|
|||||||
"alpha": int(lora_alpha),
|
"alpha": int(lora_alpha),
|
||||||
"dropout": float(dropout),
|
"dropout": float(dropout),
|
||||||
"target_modules_lm": ["q_proj", "v_proj", "k_proj", "o_proj"],
|
"target_modules_lm": ["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||||
"target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"]
|
"target_modules_dit": ["q_proj", "v_proj", "k_proj", "o_proj"],
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -420,25 +422,15 @@ def start_training(
|
|||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
yaml.dump(config, f)
|
yaml.dump(config, f)
|
||||||
|
|
||||||
cmd = [
|
cmd = [sys.executable, "scripts/train_voxcpm_finetune.py", "--config_path", config_path]
|
||||||
sys.executable,
|
|
||||||
"scripts/train_voxcpm_finetune.py",
|
|
||||||
"--config_path",
|
|
||||||
config_path
|
|
||||||
]
|
|
||||||
|
|
||||||
training_log = f"Starting training...\nConfig saved to {config_path}\nOutput dir: {save_dir}\n"
|
training_log = f"Starting training...\nConfig saved to {config_path}\nOutput dir: {save_dir}\n"
|
||||||
|
|
||||||
def run_process():
|
def run_process():
|
||||||
global training_process, training_log
|
global training_process, training_log
|
||||||
training_process = subprocess.Popen(
|
training_process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
|
||||||
cmd,
|
|
||||||
stdout=subprocess.PIPE,
|
|
||||||
stderr=subprocess.STDOUT,
|
|
||||||
text=True,
|
|
||||||
bufsize=1
|
|
||||||
)
|
|
||||||
|
|
||||||
|
assert training_process.stdout is not None
|
||||||
for line in training_process.stdout:
|
for line in training_process.stdout:
|
||||||
training_log += line
|
training_log += line
|
||||||
# Keep log size manageable
|
# Keep log size manageable
|
||||||
@@ -452,17 +444,20 @@ def start_training(
|
|||||||
|
|
||||||
return f"Training started! Check 'lora/{timestamp}'"
|
return f"Training started! Check 'lora/{timestamp}'"
|
||||||
|
|
||||||
|
|
||||||
def get_training_log():
|
def get_training_log():
|
||||||
return training_log
|
return training_log
|
||||||
|
|
||||||
|
|
||||||
def stop_training():
|
def stop_training():
|
||||||
global training_process, training_log
|
global training_log
|
||||||
if training_process is not None and training_process.poll() is None:
|
if training_process is not None and training_process.poll() is None:
|
||||||
training_process.terminate()
|
training_process.terminate()
|
||||||
training_log += "\nTraining terminated by user."
|
training_log += "\nTraining terminated by user."
|
||||||
return "Training stopped."
|
return "Training stopped."
|
||||||
return "No training running."
|
return "No training running."
|
||||||
|
|
||||||
|
|
||||||
# --- GUI Layout ---
|
# --- GUI Layout ---
|
||||||
|
|
||||||
# 自定义CSS样式
|
# 自定义CSS样式
|
||||||
@@ -830,14 +825,10 @@ label {
|
|||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with gr.Blocks(
|
with gr.Blocks(title="VoxCPM LoRA WebUI", theme=gr.themes.Soft(), css=custom_css) as app:
|
||||||
title="VoxCPM LoRA WebUI",
|
|
||||||
theme=gr.themes.Soft(),
|
|
||||||
css=custom_css
|
|
||||||
) as app:
|
|
||||||
|
|
||||||
# State for language
|
# State for language
|
||||||
lang_state = gr.State("zh") # Default to Chinese
|
lang_state = gr.State("zh") # Default to Chinese
|
||||||
|
|
||||||
# 标题区域
|
# 标题区域
|
||||||
with gr.Row(elem_classes="title-section"):
|
with gr.Row(elem_classes="title-section"):
|
||||||
@@ -850,10 +841,7 @@ with gr.Blocks(
|
|||||||
""")
|
""")
|
||||||
with gr.Column(scale=1):
|
with gr.Column(scale=1):
|
||||||
lang_btn = gr.Radio(
|
lang_btn = gr.Radio(
|
||||||
choices=["en", "zh"],
|
choices=["en", "zh"], value="zh", label="🌐 Language / 语言", elem_classes="lang-selector"
|
||||||
value="zh",
|
|
||||||
label="🌐 Language / 语言",
|
|
||||||
elem_classes="lang-selector"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Tabs(elem_classes="tabs") as tabs:
|
with gr.Tabs(elem_classes="tabs") as tabs:
|
||||||
@@ -869,79 +857,40 @@ with gr.Blocks(
|
|||||||
gr.Markdown("#### 📁 基础配置")
|
gr.Markdown("#### 📁 基础配置")
|
||||||
|
|
||||||
train_pretrained_path = gr.Textbox(
|
train_pretrained_path = gr.Textbox(
|
||||||
label="📂 预训练模型路径",
|
label="📂 预训练模型路径", value=default_pretrained_path, elem_classes="input-field"
|
||||||
value=default_pretrained_path,
|
|
||||||
elem_classes="input-field"
|
|
||||||
)
|
)
|
||||||
train_manifest = gr.Textbox(
|
train_manifest = gr.Textbox(
|
||||||
label="📋 训练数据清单 (jsonl)",
|
label="📋 训练数据清单 (jsonl)",
|
||||||
value="examples/train_data_example.jsonl",
|
value="examples/train_data_example.jsonl",
|
||||||
elem_classes="input-field"
|
elem_classes="input-field",
|
||||||
)
|
|
||||||
val_manifest = gr.Textbox(
|
|
||||||
label="📊 验证数据清单 (可选)",
|
|
||||||
value="",
|
|
||||||
elem_classes="input-field"
|
|
||||||
)
|
)
|
||||||
|
val_manifest = gr.Textbox(label="📊 验证数据清单 (可选)", value="", elem_classes="input-field")
|
||||||
|
|
||||||
gr.Markdown("#### ⚙️ 训练参数")
|
gr.Markdown("#### ⚙️ 训练参数")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
lr = gr.Number(
|
lr = gr.Number(label="📈 学习率 (Learning Rate)", value=1e-4, elem_classes="input-field")
|
||||||
label="📈 学习率 (Learning Rate)",
|
|
||||||
value=1e-4,
|
|
||||||
elem_classes="input-field"
|
|
||||||
)
|
|
||||||
num_iters = gr.Number(
|
num_iters = gr.Number(
|
||||||
label="🔄 最大迭代次数",
|
label="🔄 最大迭代次数", value=2000, precision=0, elem_classes="input-field"
|
||||||
value=2000,
|
|
||||||
precision=0,
|
|
||||||
elem_classes="input-field"
|
|
||||||
)
|
)
|
||||||
batch_size = gr.Number(
|
batch_size = gr.Number(
|
||||||
label="📦 批次大小 (Batch Size)",
|
label="📦 批次大小 (Batch Size)", value=1, precision=0, elem_classes="input-field"
|
||||||
value=1,
|
|
||||||
precision=0,
|
|
||||||
elem_classes="input-field"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
lora_rank = gr.Number(
|
lora_rank = gr.Number(label="🎯 LoRA Rank", value=32, precision=0, elem_classes="input-field")
|
||||||
label="🎯 LoRA Rank",
|
lora_alpha = gr.Number(label="⚖️ LoRA Alpha", value=16, precision=0, elem_classes="input-field")
|
||||||
value=32,
|
|
||||||
precision=0,
|
|
||||||
elem_classes="input-field"
|
|
||||||
)
|
|
||||||
lora_alpha = gr.Number(
|
|
||||||
label="⚖️ LoRA Alpha",
|
|
||||||
value=16,
|
|
||||||
precision=0,
|
|
||||||
elem_classes="input-field"
|
|
||||||
)
|
|
||||||
save_interval = gr.Number(
|
save_interval = gr.Number(
|
||||||
label="💾 保存间隔 (Steps)",
|
label="💾 保存间隔 (Steps)", value=1000, precision=0, elem_classes="input-field"
|
||||||
value=1000,
|
|
||||||
precision=0,
|
|
||||||
elem_classes="input-field"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
output_name = gr.Textbox(
|
output_name = gr.Textbox(
|
||||||
label="📁 输出目录名称 (可选,若存在则继续训练)",
|
label="📁 输出目录名称 (可选,若存在则继续训练)", value="", elem_classes="input-field"
|
||||||
value="",
|
|
||||||
elem_classes="input-field"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
start_btn = gr.Button(
|
start_btn = gr.Button("▶️ 开始训练", variant="primary", elem_classes="button-primary")
|
||||||
"▶️ 开始训练",
|
stop_btn = gr.Button("⏹️ 停止训练", variant="stop", elem_classes="button-stop")
|
||||||
variant="primary",
|
|
||||||
elem_classes="button-primary"
|
|
||||||
)
|
|
||||||
stop_btn = gr.Button(
|
|
||||||
"⏹️ 停止训练",
|
|
||||||
variant="stop",
|
|
||||||
elem_classes="button-stop"
|
|
||||||
)
|
|
||||||
|
|
||||||
with gr.Accordion("🔧 高级选项 (Advanced)", open=False, elem_classes="accordion"):
|
with gr.Accordion("🔧 高级选项 (Advanced)", open=False, elem_classes="accordion"):
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@@ -964,7 +913,9 @@ with gr.Blocks(
|
|||||||
|
|
||||||
gr.Markdown("#### 分发选项 (Distribution)")
|
gr.Markdown("#### 分发选项 (Distribution)")
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
hf_model_id = gr.Textbox(label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5")
|
hf_model_id = gr.Textbox(
|
||||||
|
label="HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)", value="openbmb/VoxCPM1.5"
|
||||||
|
)
|
||||||
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
|
distribute = gr.Checkbox(label="分发模式 (distribute)", value=False)
|
||||||
|
|
||||||
with gr.Column(scale=2, elem_classes="form-section"):
|
with gr.Column(scale=2, elem_classes="form-section"):
|
||||||
@@ -975,23 +926,41 @@ with gr.Blocks(
|
|||||||
max_lines=30,
|
max_lines=30,
|
||||||
interactive=False,
|
interactive=False,
|
||||||
elem_classes="input-field",
|
elem_classes="input-field",
|
||||||
show_label=False
|
show_label=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
start_btn.click(
|
start_btn.click(
|
||||||
start_training,
|
start_training,
|
||||||
inputs=[
|
inputs=[
|
||||||
train_pretrained_path, train_manifest, val_manifest,
|
train_pretrained_path,
|
||||||
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval,
|
train_manifest,
|
||||||
|
val_manifest,
|
||||||
|
lr,
|
||||||
|
num_iters,
|
||||||
|
batch_size,
|
||||||
|
lora_rank,
|
||||||
|
lora_alpha,
|
||||||
|
save_interval,
|
||||||
output_name,
|
output_name,
|
||||||
# advanced
|
# advanced
|
||||||
grad_accum_steps, num_workers, log_interval, valid_interval,
|
grad_accum_steps,
|
||||||
weight_decay, warmup_steps, max_steps, sample_rate,
|
num_workers,
|
||||||
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path,
|
log_interval,
|
||||||
|
valid_interval,
|
||||||
|
weight_decay,
|
||||||
|
warmup_steps,
|
||||||
|
max_steps,
|
||||||
|
sample_rate,
|
||||||
|
enable_lm,
|
||||||
|
enable_dit,
|
||||||
|
enable_proj,
|
||||||
|
dropout,
|
||||||
|
tensorboard_path,
|
||||||
# distribution
|
# distribution
|
||||||
hf_model_id, distribute
|
hf_model_id,
|
||||||
|
distribute,
|
||||||
],
|
],
|
||||||
outputs=[logs_out] # Initial message
|
outputs=[logs_out], # Initial message
|
||||||
)
|
)
|
||||||
stop_btn.click(stop_training, outputs=[logs_out])
|
stop_btn.click(stop_training, outputs=[logs_out])
|
||||||
|
|
||||||
@@ -1016,21 +985,17 @@ with gr.Blocks(
|
|||||||
value="Hello, this is a test of the VoxCPM LoRA model.",
|
value="Hello, this is a test of the VoxCPM LoRA model.",
|
||||||
elem_classes="input-field",
|
elem_classes="input-field",
|
||||||
lines=4,
|
lines=4,
|
||||||
placeholder="输入要合成的文本内容..."
|
placeholder="输入要合成的文本内容...",
|
||||||
)
|
)
|
||||||
|
|
||||||
gr.Markdown("**🎭 声音克隆(可选)**")
|
gr.Markdown("**🎭 声音克隆(可选)**")
|
||||||
|
|
||||||
prompt_wav = gr.Audio(
|
prompt_wav = gr.Audio(label="🎵 参考音频", type="filepath", elem_classes="input-field")
|
||||||
label="🎵 参考音频",
|
|
||||||
type="filepath",
|
|
||||||
elem_classes="input-field"
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt_text = gr.Textbox(
|
prompt_text = gr.Textbox(
|
||||||
label="📝 参考文本(可选)",
|
label="📝 参考文本(可选)",
|
||||||
elem_classes="input-field",
|
elem_classes="input-field",
|
||||||
placeholder="如不填写,将自动识别参考音频内容"
|
placeholder="如不填写,将自动识别参考音频内容",
|
||||||
)
|
)
|
||||||
|
|
||||||
# 中栏:模型选择和参数配置 (35%)
|
# 中栏:模型选择和参数配置 (35%)
|
||||||
@@ -1043,14 +1008,10 @@ with gr.Blocks(
|
|||||||
value="None",
|
value="None",
|
||||||
interactive=True,
|
interactive=True,
|
||||||
elem_classes="input-field",
|
elem_classes="input-field",
|
||||||
info="选择训练好的 LoRA 模型,或选择 None 使用基础模型"
|
info="选择训练好的 LoRA 模型,或选择 None 使用基础模型",
|
||||||
)
|
)
|
||||||
|
|
||||||
refresh_lora_btn = gr.Button(
|
refresh_lora_btn = gr.Button("🔄 刷新模型列表", elem_classes="button-refresh", size="sm")
|
||||||
"🔄 刷新模型列表",
|
|
||||||
elem_classes="button-refresh",
|
|
||||||
size="sm"
|
|
||||||
)
|
|
||||||
|
|
||||||
gr.Markdown("#### ⚙️ 生成参数")
|
gr.Markdown("#### ⚙️ 生成参数")
|
||||||
|
|
||||||
@@ -1060,7 +1021,7 @@ with gr.Blocks(
|
|||||||
maximum=5.0,
|
maximum=5.0,
|
||||||
value=2.0,
|
value=2.0,
|
||||||
step=0.1,
|
step=0.1,
|
||||||
info="引导系数,值越大越贴近提示"
|
info="引导系数,值越大越贴近提示",
|
||||||
)
|
)
|
||||||
|
|
||||||
steps = gr.Slider(
|
steps = gr.Slider(
|
||||||
@@ -1069,7 +1030,7 @@ with gr.Blocks(
|
|||||||
maximum=50,
|
maximum=50,
|
||||||
value=10,
|
value=10,
|
||||||
step=1,
|
step=1,
|
||||||
info="生成质量与步数成正比,但耗时更长"
|
info="生成质量与步数成正比,但耗时更长",
|
||||||
)
|
)
|
||||||
|
|
||||||
seed = gr.Number(
|
seed = gr.Number(
|
||||||
@@ -1077,25 +1038,16 @@ with gr.Blocks(
|
|||||||
value=-1,
|
value=-1,
|
||||||
precision=0,
|
precision=0,
|
||||||
elem_classes="input-field",
|
elem_classes="input-field",
|
||||||
info="-1 为随机,固定值可复现结果"
|
info="-1 为随机,固定值可复现结果",
|
||||||
)
|
)
|
||||||
|
|
||||||
generate_btn = gr.Button(
|
generate_btn = gr.Button("🎵 生成音频", variant="primary", elem_classes="button-primary", size="lg")
|
||||||
"🎵 生成音频",
|
|
||||||
variant="primary",
|
|
||||||
elem_classes="button-primary",
|
|
||||||
size="lg"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 右栏:生成结果 (30%)
|
# 右栏:生成结果 (30%)
|
||||||
with gr.Column(scale=30, elem_classes="form-section"):
|
with gr.Column(scale=30, elem_classes="form-section"):
|
||||||
gr.Markdown("#### 🎧 生成结果")
|
gr.Markdown("#### 🎧 生成结果")
|
||||||
|
|
||||||
audio_out = gr.Audio(
|
audio_out = gr.Audio(label="", elem_classes="input-field", show_label=False)
|
||||||
label="",
|
|
||||||
elem_classes="input-field",
|
|
||||||
show_label=False
|
|
||||||
)
|
|
||||||
|
|
||||||
gr.Markdown("#### 📋 状态信息")
|
gr.Markdown("#### 📋 状态信息")
|
||||||
|
|
||||||
@@ -1105,7 +1057,7 @@ with gr.Blocks(
|
|||||||
elem_classes="input-field",
|
elem_classes="input-field",
|
||||||
show_label=False,
|
show_label=False,
|
||||||
lines=3,
|
lines=3,
|
||||||
placeholder="等待生成..."
|
placeholder="等待生成...",
|
||||||
)
|
)
|
||||||
|
|
||||||
def refresh_loras():
|
def refresh_loras():
|
||||||
@@ -1126,16 +1078,21 @@ with gr.Blocks(
|
|||||||
refresh_lora_btn.click(refresh_loras, outputs=[lora_select])
|
refresh_lora_btn.click(refresh_loras, outputs=[lora_select])
|
||||||
|
|
||||||
# Auto-recognize audio when uploaded
|
# Auto-recognize audio when uploaded
|
||||||
prompt_wav.change(
|
prompt_wav.change(fn=recognize_audio, inputs=[prompt_wav], outputs=[prompt_text])
|
||||||
fn=recognize_audio,
|
|
||||||
inputs=[prompt_wav],
|
|
||||||
outputs=[prompt_text]
|
|
||||||
)
|
|
||||||
|
|
||||||
generate_btn.click(
|
generate_btn.click(
|
||||||
run_inference,
|
run_inference,
|
||||||
inputs=[infer_text, prompt_wav, prompt_text, lora_select, cfg_scale, steps, seed, train_pretrained_path],
|
inputs=[
|
||||||
outputs=[audio_out, status_out]
|
infer_text,
|
||||||
|
prompt_wav,
|
||||||
|
prompt_text,
|
||||||
|
lora_select,
|
||||||
|
cfg_scale,
|
||||||
|
steps,
|
||||||
|
seed,
|
||||||
|
train_pretrained_path,
|
||||||
|
],
|
||||||
|
outputs=[audio_out, status_out],
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- Language Switching Logic ---
|
# --- Language Switching Logic ---
|
||||||
@@ -1144,108 +1101,138 @@ with gr.Blocks(
|
|||||||
# Labels for advanced options
|
# Labels for advanced options
|
||||||
if lang == "zh":
|
if lang == "zh":
|
||||||
adv = {
|
adv = {
|
||||||
'grad_accum_steps': "梯度累积 (grad_accum_steps)",
|
"grad_accum_steps": "梯度累积 (grad_accum_steps)",
|
||||||
'num_workers': "数据加载线程 (num_workers)",
|
"num_workers": "数据加载线程 (num_workers)",
|
||||||
'log_interval': "日志间隔 (log_interval)",
|
"log_interval": "日志间隔 (log_interval)",
|
||||||
'valid_interval': "验证间隔 (valid_interval)",
|
"valid_interval": "验证间隔 (valid_interval)",
|
||||||
'weight_decay': "权重衰减 (weight_decay)",
|
"weight_decay": "权重衰减 (weight_decay)",
|
||||||
'warmup_steps': "warmup_steps",
|
"warmup_steps": "warmup_steps",
|
||||||
'max_steps': "最大步数 (max_steps)",
|
"max_steps": "最大步数 (max_steps)",
|
||||||
'sample_rate': "采样率 (sample_rate)",
|
"sample_rate": "采样率 (sample_rate)",
|
||||||
'enable_lm': "启用 LoRA LM (enable_lm)",
|
"enable_lm": "启用 LoRA LM (enable_lm)",
|
||||||
'enable_dit': "启用 LoRA DIT (enable_dit)",
|
"enable_dit": "启用 LoRA DIT (enable_dit)",
|
||||||
'enable_proj': "启用投影 (enable_proj)",
|
"enable_proj": "启用投影 (enable_proj)",
|
||||||
'dropout': "LoRA Dropout",
|
"dropout": "LoRA Dropout",
|
||||||
'tensorboard_path': "Tensorboard 路径 (可选)",
|
"tensorboard_path": "Tensorboard 路径 (可选)",
|
||||||
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||||
'distribute': "分发模式 (distribute)",
|
"distribute": "分发模式 (distribute)",
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
adv = {
|
adv = {
|
||||||
'grad_accum_steps': "Grad Accum Steps",
|
"grad_accum_steps": "Grad Accum Steps",
|
||||||
'num_workers': "Num Workers",
|
"num_workers": "Num Workers",
|
||||||
'log_interval': "Log Interval",
|
"log_interval": "Log Interval",
|
||||||
'valid_interval': "Valid Interval",
|
"valid_interval": "Valid Interval",
|
||||||
'weight_decay': "Weight Decay",
|
"weight_decay": "Weight Decay",
|
||||||
'warmup_steps': "Warmup Steps",
|
"warmup_steps": "Warmup Steps",
|
||||||
'max_steps': "Max Steps",
|
"max_steps": "Max Steps",
|
||||||
'sample_rate': "Sample Rate",
|
"sample_rate": "Sample Rate",
|
||||||
'enable_lm': "Enable LoRA LM",
|
"enable_lm": "Enable LoRA LM",
|
||||||
'enable_dit': "Enable LoRA DIT",
|
"enable_dit": "Enable LoRA DIT",
|
||||||
'enable_proj': "Enable Projection",
|
"enable_proj": "Enable Projection",
|
||||||
'dropout': "LoRA Dropout",
|
"dropout": "LoRA Dropout",
|
||||||
'tensorboard_path': "Tensorboard Path (Optional)",
|
"tensorboard_path": "Tensorboard Path (Optional)",
|
||||||
'hf_model_id': "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
"hf_model_id": "HuggingFace Model ID (e.g., openbmb/VoxCPM1.5)",
|
||||||
'distribute': "Distribute Mode",
|
"distribute": "Distribute Mode",
|
||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
gr.update(value=f"# {d['title']}"),
|
gr.update(value=f"# {d['title']}"),
|
||||||
gr.update(label=d['tab_train']),
|
gr.update(label=d["tab_train"]),
|
||||||
gr.update(label=d['tab_infer']),
|
gr.update(label=d["tab_infer"]),
|
||||||
gr.update(label=d['pretrained_path']),
|
gr.update(label=d["pretrained_path"]),
|
||||||
gr.update(label=d['train_manifest']),
|
gr.update(label=d["train_manifest"]),
|
||||||
gr.update(label=d['val_manifest']),
|
gr.update(label=d["val_manifest"]),
|
||||||
gr.update(label=d['lr']),
|
gr.update(label=d["lr"]),
|
||||||
gr.update(label=d['max_iters']),
|
gr.update(label=d["max_iters"]),
|
||||||
gr.update(label=d['batch_size']),
|
gr.update(label=d["batch_size"]),
|
||||||
gr.update(label=d['lora_rank']),
|
gr.update(label=d["lora_rank"]),
|
||||||
gr.update(label=d['lora_alpha']),
|
gr.update(label=d["lora_alpha"]),
|
||||||
gr.update(label=d['save_interval']),
|
gr.update(label=d["save_interval"]),
|
||||||
gr.update(label=d['output_name']),
|
gr.update(label=d["output_name"]),
|
||||||
gr.update(value=d['start_train']),
|
gr.update(value=d["start_train"]),
|
||||||
gr.update(value=d['stop_train']),
|
gr.update(value=d["stop_train"]),
|
||||||
gr.update(label=d['train_logs']),
|
gr.update(label=d["train_logs"]),
|
||||||
# Advanced options (must match outputs order)
|
# Advanced options (must match outputs order)
|
||||||
gr.update(label=adv['grad_accum_steps']),
|
gr.update(label=adv["grad_accum_steps"]),
|
||||||
gr.update(label=adv['num_workers']),
|
gr.update(label=adv["num_workers"]),
|
||||||
gr.update(label=adv['log_interval']),
|
gr.update(label=adv["log_interval"]),
|
||||||
gr.update(label=adv['valid_interval']),
|
gr.update(label=adv["valid_interval"]),
|
||||||
gr.update(label=adv['weight_decay']),
|
gr.update(label=adv["weight_decay"]),
|
||||||
gr.update(label=adv['warmup_steps']),
|
gr.update(label=adv["warmup_steps"]),
|
||||||
gr.update(label=adv['max_steps']),
|
gr.update(label=adv["max_steps"]),
|
||||||
gr.update(label=adv['sample_rate']),
|
gr.update(label=adv["sample_rate"]),
|
||||||
gr.update(label=adv['enable_lm']),
|
gr.update(label=adv["enable_lm"]),
|
||||||
gr.update(label=adv['enable_dit']),
|
gr.update(label=adv["enable_dit"]),
|
||||||
gr.update(label=adv['enable_proj']),
|
gr.update(label=adv["enable_proj"]),
|
||||||
gr.update(label=adv['dropout']),
|
gr.update(label=adv["dropout"]),
|
||||||
gr.update(label=adv['tensorboard_path']),
|
gr.update(label=adv["tensorboard_path"]),
|
||||||
# Distribution options
|
# Distribution options
|
||||||
gr.update(label=adv['hf_model_id']),
|
gr.update(label=adv["hf_model_id"]),
|
||||||
gr.update(label=adv['distribute']),
|
gr.update(label=adv["distribute"]),
|
||||||
# Inference section
|
# Inference section
|
||||||
gr.update(label=d['text_to_synth']),
|
gr.update(label=d["text_to_synth"]),
|
||||||
gr.update(label=d['ref_audio']),
|
gr.update(label=d["ref_audio"]),
|
||||||
gr.update(label=d['ref_text']),
|
gr.update(label=d["ref_text"]),
|
||||||
gr.update(label=d['select_lora']),
|
gr.update(label=d["select_lora"]),
|
||||||
gr.update(value=d['refresh']),
|
gr.update(value=d["refresh"]),
|
||||||
gr.update(label=d['cfg_scale']),
|
gr.update(label=d["cfg_scale"]),
|
||||||
gr.update(label=d['infer_steps']),
|
gr.update(label=d["infer_steps"]),
|
||||||
gr.update(label=d['seed']),
|
gr.update(label=d["seed"]),
|
||||||
gr.update(value=d['gen_audio']),
|
gr.update(value=d["gen_audio"]),
|
||||||
gr.update(label=d['gen_output']),
|
gr.update(label=d["gen_output"]),
|
||||||
gr.update(label=d['status']),
|
gr.update(label=d["status"]),
|
||||||
)
|
)
|
||||||
|
|
||||||
lang_btn.change(
|
lang_btn.change(
|
||||||
change_language,
|
change_language,
|
||||||
inputs=[lang_btn],
|
inputs=[lang_btn],
|
||||||
outputs=[
|
outputs=[
|
||||||
title_md, tab_train, tab_infer,
|
title_md,
|
||||||
train_pretrained_path, train_manifest, val_manifest,
|
tab_train,
|
||||||
lr, num_iters, batch_size, lora_rank, lora_alpha, save_interval,
|
tab_infer,
|
||||||
|
train_pretrained_path,
|
||||||
|
train_manifest,
|
||||||
|
val_manifest,
|
||||||
|
lr,
|
||||||
|
num_iters,
|
||||||
|
batch_size,
|
||||||
|
lora_rank,
|
||||||
|
lora_alpha,
|
||||||
|
save_interval,
|
||||||
output_name,
|
output_name,
|
||||||
start_btn, stop_btn, logs_out,
|
start_btn,
|
||||||
|
stop_btn,
|
||||||
|
logs_out,
|
||||||
# advanced outputs
|
# advanced outputs
|
||||||
grad_accum_steps, num_workers, log_interval, valid_interval,
|
grad_accum_steps,
|
||||||
weight_decay, warmup_steps, max_steps, sample_rate,
|
num_workers,
|
||||||
enable_lm, enable_dit, enable_proj, dropout, tensorboard_path,
|
log_interval,
|
||||||
|
valid_interval,
|
||||||
|
weight_decay,
|
||||||
|
warmup_steps,
|
||||||
|
max_steps,
|
||||||
|
sample_rate,
|
||||||
|
enable_lm,
|
||||||
|
enable_dit,
|
||||||
|
enable_proj,
|
||||||
|
dropout,
|
||||||
|
tensorboard_path,
|
||||||
# distribution outputs
|
# distribution outputs
|
||||||
hf_model_id, distribute,
|
hf_model_id,
|
||||||
infer_text, prompt_wav, prompt_text,
|
distribute,
|
||||||
lora_select, refresh_lora_btn, cfg_scale, steps, seed,
|
infer_text,
|
||||||
generate_btn, audio_out, status_out
|
prompt_wav,
|
||||||
]
|
prompt_text,
|
||||||
|
lora_select,
|
||||||
|
refresh_lora_btn,
|
||||||
|
cfg_scale,
|
||||||
|
steps,
|
||||||
|
seed,
|
||||||
|
generate_btn,
|
||||||
|
audio_out,
|
||||||
|
status_out,
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
+1
-3
@@ -30,7 +30,7 @@ dependencies = [
|
|||||||
"torchcodec",
|
"torchcodec",
|
||||||
"transformers>=4.36.2",
|
"transformers>=4.36.2",
|
||||||
"einops",
|
"einops",
|
||||||
"gradio<6",
|
"gradio>=6,<7",
|
||||||
"inflect",
|
"inflect",
|
||||||
"addict",
|
"addict",
|
||||||
"wetext",
|
"wetext",
|
||||||
@@ -57,7 +57,6 @@ dev = [
|
|||||||
"pytest-cov>=2.0",
|
"pytest-cov>=2.0",
|
||||||
"black>=21.0",
|
"black>=21.0",
|
||||||
"flake8>=3.8",
|
"flake8>=3.8",
|
||||||
"mypy>=0.800",
|
|
||||||
"pre-commit>=2.0",
|
"pre-commit>=2.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -90,7 +89,6 @@ extend-exclude = '''
|
|||||||
\.eggs
|
\.eggs
|
||||||
| \.git
|
| \.git
|
||||||
| \.hg
|
| \.hg
|
||||||
| \.mypy_cache
|
|
||||||
| \.tox
|
| \.tox
|
||||||
| \.venv
|
| \.venv
|
||||||
| build
|
| build
|
||||||
|
|||||||
@@ -125,7 +125,10 @@ def main():
|
|||||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
|
sf.write(str(out_path), audio_np, model.tts_model.sample_rate)
|
||||||
|
|
||||||
print(f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
print(
|
||||||
|
f"[FT Inference] Saved to: {out_path}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -127,7 +127,9 @@ def main():
|
|||||||
|
|
||||||
print(f"Loaded config from: {lora_config_path}", file=sys.stderr)
|
print(f"Loaded config from: {lora_config_path}", file=sys.stderr)
|
||||||
print(f" Base model: {pretrained_path}", file=sys.stderr)
|
print(f" Base model: {pretrained_path}", file=sys.stderr)
|
||||||
print(f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr)
|
print(
|
||||||
|
f" LoRA config: r={lora_cfg.r}, alpha={lora_cfg.alpha}" if lora_cfg else " LoRA config: None", file=sys.stderr
|
||||||
|
)
|
||||||
|
|
||||||
# 3. Load model with LoRA (no denoiser)
|
# 3. Load model with LoRA (no denoiser)
|
||||||
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}", file=sys.stderr)
|
print(f"\n[1/2] Loading model with LoRA: {pretrained_path}", file=sys.stderr)
|
||||||
@@ -146,10 +148,10 @@ def main():
|
|||||||
out_path = Path(args.output)
|
out_path = Path(args.output)
|
||||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
print(f"\n[2/2] Starting synthesis tests...", file=sys.stderr)
|
print("\n[2/2] Starting synthesis tests...", file=sys.stderr)
|
||||||
|
|
||||||
# === Test 1: With LoRA ===
|
# === Test 1: With LoRA ===
|
||||||
print(f"\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
|
print("\n [Test 1] Synthesize with LoRA...", file=sys.stderr)
|
||||||
audio_np = model.generate(
|
audio_np = model.generate(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
prompt_wav_path=prompt_wav_path,
|
prompt_wav_path=prompt_wav_path,
|
||||||
@@ -162,10 +164,13 @@ def main():
|
|||||||
)
|
)
|
||||||
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
|
lora_output = out_path.with_stem(out_path.stem + "_with_lora")
|
||||||
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
|
sf.write(str(lora_output), audio_np, model.tts_model.sample_rate)
|
||||||
print(f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
print(
|
||||||
|
f" Saved: {lora_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
# === Test 2: Disable LoRA (via set_lora_enabled) ===
|
# === Test 2: Disable LoRA (via set_lora_enabled) ===
|
||||||
print(f"\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
|
print("\n [Test 2] Disable LoRA (set_lora_enabled=False)...", file=sys.stderr)
|
||||||
model.set_lora_enabled(False)
|
model.set_lora_enabled(False)
|
||||||
audio_np = model.generate(
|
audio_np = model.generate(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
@@ -179,10 +184,13 @@ def main():
|
|||||||
)
|
)
|
||||||
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
|
disabled_output = out_path.with_stem(out_path.stem + "_lora_disabled")
|
||||||
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
|
sf.write(str(disabled_output), audio_np, model.tts_model.sample_rate)
|
||||||
print(f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
print(
|
||||||
|
f" Saved: {disabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
# === Test 3: Re-enable LoRA ===
|
# === Test 3: Re-enable LoRA ===
|
||||||
print(f"\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
|
print("\n [Test 3] Re-enable LoRA (set_lora_enabled=True)...", file=sys.stderr)
|
||||||
model.set_lora_enabled(True)
|
model.set_lora_enabled(True)
|
||||||
audio_np = model.generate(
|
audio_np = model.generate(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
@@ -196,10 +204,13 @@ def main():
|
|||||||
)
|
)
|
||||||
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
|
reenabled_output = out_path.with_stem(out_path.stem + "_lora_reenabled")
|
||||||
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
|
sf.write(str(reenabled_output), audio_np, model.tts_model.sample_rate)
|
||||||
print(f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
print(
|
||||||
|
f" Saved: {reenabled_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
# === Test 4: Unload LoRA (reset_lora_weights) ===
|
# === Test 4: Unload LoRA (reset_lora_weights) ===
|
||||||
print(f"\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
|
print("\n [Test 4] Unload LoRA (unload_lora)...", file=sys.stderr)
|
||||||
model.unload_lora()
|
model.unload_lora()
|
||||||
audio_np = model.generate(
|
audio_np = model.generate(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
@@ -213,10 +224,13 @@ def main():
|
|||||||
)
|
)
|
||||||
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
|
reset_output = out_path.with_stem(out_path.stem + "_lora_reset")
|
||||||
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
|
sf.write(str(reset_output), audio_np, model.tts_model.sample_rate)
|
||||||
print(f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
print(
|
||||||
|
f" Saved: {reset_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
# === Test 5: Hot-reload LoRA (load_lora) ===
|
# === Test 5: Hot-reload LoRA (load_lora) ===
|
||||||
print(f"\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
|
print("\n [Test 5] Hot-reload LoRA (load_lora)...", file=sys.stderr)
|
||||||
loaded, skipped = model.load_lora(ckpt_dir)
|
loaded, skipped = model.load_lora(ckpt_dir)
|
||||||
print(f" Reloaded {len(loaded)} parameters", file=sys.stderr)
|
print(f" Reloaded {len(loaded)} parameters", file=sys.stderr)
|
||||||
audio_np = model.generate(
|
audio_np = model.generate(
|
||||||
@@ -231,9 +245,12 @@ def main():
|
|||||||
)
|
)
|
||||||
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
|
reload_output = out_path.with_stem(out_path.stem + "_lora_reloaded")
|
||||||
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
|
sf.write(str(reload_output), audio_np, model.tts_model.sample_rate)
|
||||||
print(f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s", file=sys.stderr)
|
print(
|
||||||
|
f" Saved: {reload_output}, duration: {len(audio_np) / model.tts_model.sample_rate:.2f}s",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
print(f"\n[Done] All tests completed!", file=sys.stderr)
|
print("\n[Done] All tests completed!", file=sys.stderr)
|
||||||
print(f" - with_lora: {lora_output}", file=sys.stderr)
|
print(f" - with_lora: {lora_output}", file=sys.stderr)
|
||||||
print(f" - lora_disabled: {disabled_output}", file=sys.stderr)
|
print(f" - lora_disabled: {disabled_output}", file=sys.stderr)
|
||||||
print(f" - lora_reenabled: {reenabled_output}", file=sys.stderr)
|
print(f" - lora_reenabled: {reenabled_output}", file=sys.stderr)
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ project_root = Path(__file__).parent.parent
|
|||||||
sys.path.insert(0, str(project_root / "src"))
|
sys.path.insert(0, str(project_root / "src"))
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
from typing import Dict, Optional
|
from typing import Dict
|
||||||
|
|
||||||
import argbind
|
import argbind
|
||||||
import torch
|
import torch
|
||||||
@@ -17,16 +17,19 @@ from transformers import get_cosine_schedule_with_warmup
|
|||||||
import signal
|
import signal
|
||||||
import os
|
import os
|
||||||
|
|
||||||
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from safetensors.torch import save_file
|
from safetensors.torch import save_file
|
||||||
|
|
||||||
SAFETENSORS_AVAILABLE = True
|
SAFETENSORS_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
SAFETENSORS_AVAILABLE = False
|
SAFETENSORS_AVAILABLE = False
|
||||||
print("Warning: safetensors not available, will use pytorch format", file=sys.stderr)
|
print("Warning: safetensors not available, will use pytorch format", file=sys.stderr)
|
||||||
|
|
||||||
from voxcpm.model import VoxCPMModel
|
import json
|
||||||
|
|
||||||
|
from voxcpm.model import VoxCPMModel, VoxCPM2Model
|
||||||
from voxcpm.model.voxcpm import LoRAConfig
|
from voxcpm.model.voxcpm import LoRAConfig
|
||||||
from voxcpm.training import (
|
from voxcpm.training import (
|
||||||
Accelerator,
|
Accelerator,
|
||||||
@@ -61,8 +64,8 @@ def train(
|
|||||||
lora: dict = None,
|
lora: dict = None,
|
||||||
config_path: str = "",
|
config_path: str = "",
|
||||||
# Distribution options (for LoRA checkpoints)
|
# Distribution options (for LoRA checkpoints)
|
||||||
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
|
hf_model_id: str = "", # HuggingFace model ID (e.g., "openbmb/VoxCPM1.5")
|
||||||
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
|
distribute: bool = False, # If True, save hf_model_id as base_model; otherwise save pretrained_path
|
||||||
):
|
):
|
||||||
_ = config_path
|
_ = config_path
|
||||||
|
|
||||||
@@ -84,7 +87,15 @@ def train(
|
|||||||
writer = SummaryWriter(log_dir=str(tb_dir)) if accelerator.rank == 0 else None
|
writer = SummaryWriter(log_dir=str(tb_dir)) if accelerator.rank == 0 else None
|
||||||
tracker = TrainingTracker(writer=writer, log_file=str(save_dir / "train.log"), rank=accelerator.rank)
|
tracker = TrainingTracker(writer=writer, log_file=str(save_dir / "train.log"), rank=accelerator.rank)
|
||||||
|
|
||||||
base_model = VoxCPMModel.from_local(pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None)
|
# Auto-detect model architecture from config.json
|
||||||
|
with open(os.path.join(pretrained_path, "config.json"), "r", encoding="utf-8") as _f:
|
||||||
|
_arch = json.load(_f).get("architecture", "voxcpm").lower()
|
||||||
|
_model_cls = VoxCPM2Model if _arch == "voxcpm2" else VoxCPMModel
|
||||||
|
if accelerator.rank == 0:
|
||||||
|
print(f"Detected architecture: {_arch} -> {_model_cls.__name__}", file=sys.stderr)
|
||||||
|
base_model = _model_cls.from_local(
|
||||||
|
pretrained_path, optimize=False, training=True, lora_config=LoRAConfig(**lora) if lora else None
|
||||||
|
)
|
||||||
tokenizer = base_model.text_tokenizer
|
tokenizer = base_model.text_tokenizer
|
||||||
|
|
||||||
train_ds, val_ds = load_audio_text_datasets(
|
train_ds, val_ds = load_audio_text_datasets(
|
||||||
@@ -166,7 +177,6 @@ def train(
|
|||||||
unwrapped_model = accelerator.unwrap(model)
|
unwrapped_model = accelerator.unwrap(model)
|
||||||
unwrapped_model.train()
|
unwrapped_model.train()
|
||||||
|
|
||||||
|
|
||||||
# Only print param info on rank 0 to avoid cluttered output
|
# Only print param info on rank 0 to avoid cluttered output
|
||||||
if accelerator.rank == 0:
|
if accelerator.rank == 0:
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
@@ -199,7 +209,19 @@ def train(
|
|||||||
resume = {"step": start_step}
|
resume = {"step": start_step}
|
||||||
|
|
||||||
# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
|
# Register signal handler to save checkpoint on termination (SIGTERM/SIGINT)
|
||||||
def _signal_handler(signum, frame, _model=model, _optim=optimizer, _sched=scheduler, _save_dir=save_dir, _pretrained=pretrained_path, _hf_id=hf_model_id, _dist=distribute, _resume=resume, _rank=accelerator.rank):
|
def _signal_handler(
|
||||||
|
signum,
|
||||||
|
frame,
|
||||||
|
_model=model,
|
||||||
|
_optim=optimizer,
|
||||||
|
_sched=scheduler,
|
||||||
|
_save_dir=save_dir,
|
||||||
|
_pretrained=pretrained_path,
|
||||||
|
_hf_id=hf_model_id,
|
||||||
|
_dist=distribute,
|
||||||
|
_resume=resume,
|
||||||
|
_rank=accelerator.rank,
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
cur_step = int(_resume.get("step", start_step))
|
cur_step = int(_resume.get("step", start_step))
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -229,8 +251,8 @@ def train(
|
|||||||
except StopIteration:
|
except StopIteration:
|
||||||
data_epoch += 1
|
data_epoch += 1
|
||||||
# Key: set DistributedSampler epoch to ensure different data order each epoch
|
# Key: set DistributedSampler epoch to ensure different data order each epoch
|
||||||
sampler = getattr(train_loader, 'sampler', None)
|
sampler = getattr(train_loader, "sampler", None)
|
||||||
if hasattr(sampler, 'set_epoch'):
|
if hasattr(sampler, "set_epoch"):
|
||||||
sampler.set_epoch(data_epoch)
|
sampler.set_epoch(data_epoch)
|
||||||
train_iter = iter(train_loader)
|
train_iter = iter(train_loader)
|
||||||
return next(train_iter)
|
return next(train_iter)
|
||||||
@@ -250,7 +272,7 @@ def train(
|
|||||||
|
|
||||||
# Only sync gradients on the last micro-batch
|
# Only sync gradients on the last micro-batch
|
||||||
# Use no_sync() for intermediate steps to reduce communication overhead
|
# Use no_sync() for intermediate steps to reduce communication overhead
|
||||||
is_last_micro_step = (micro_step == grad_accum_steps - 1)
|
is_last_micro_step = micro_step == grad_accum_steps - 1
|
||||||
sync_context = contextlib.nullcontext() if is_last_micro_step else accelerator.no_sync()
|
sync_context = contextlib.nullcontext() if is_last_micro_step else accelerator.no_sync()
|
||||||
|
|
||||||
with sync_context:
|
with sync_context:
|
||||||
@@ -299,10 +321,22 @@ def train(
|
|||||||
tracker.log_metrics(loss_values, split="train")
|
tracker.log_metrics(loss_values, split="train")
|
||||||
|
|
||||||
if val_loader is not None and (step % valid_interval == 0 or step == num_iters - 1):
|
if val_loader is not None and (step % valid_interval == 0 or step == num_iters - 1):
|
||||||
validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
|
validate(
|
||||||
writer=writer, step=step, val_ds=val_ds, audio_vae=audio_vae_for_gen,
|
model,
|
||||||
sample_rate=sample_rate, val_texts=val_texts, tokenizer=tokenizer,
|
val_loader,
|
||||||
valid_interval=valid_interval)
|
batch_processor,
|
||||||
|
accelerator,
|
||||||
|
tracker,
|
||||||
|
lambdas,
|
||||||
|
writer=writer,
|
||||||
|
step=step,
|
||||||
|
val_ds=val_ds,
|
||||||
|
audio_vae=audio_vae_for_gen,
|
||||||
|
sample_rate=sample_rate,
|
||||||
|
val_texts=val_texts,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
valid_interval=valid_interval,
|
||||||
|
)
|
||||||
|
|
||||||
if (step % save_interval == 0 or step == num_iters - 1) and accelerator.rank == 0:
|
if (step % save_interval == 0 or step == num_iters - 1) and accelerator.rank == 0:
|
||||||
save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute)
|
save_checkpoint(model, optimizer, scheduler, save_dir, step, pretrained_path, hf_model_id, distribute)
|
||||||
@@ -313,11 +347,24 @@ def train(
|
|||||||
writer.close()
|
writer.close()
|
||||||
|
|
||||||
|
|
||||||
def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
|
def validate(
|
||||||
writer=None, step=0, val_ds=None, audio_vae=None, sample_rate=22050,
|
model,
|
||||||
val_texts=None, tokenizer=None, valid_interval=1000):
|
val_loader,
|
||||||
|
batch_processor,
|
||||||
|
accelerator,
|
||||||
|
tracker,
|
||||||
|
lambdas,
|
||||||
|
writer=None,
|
||||||
|
step=0,
|
||||||
|
val_ds=None,
|
||||||
|
audio_vae=None,
|
||||||
|
sample_rate=22050,
|
||||||
|
val_texts=None,
|
||||||
|
tokenizer=None,
|
||||||
|
valid_interval=1000,
|
||||||
|
):
|
||||||
"""Validate and generate sample audio"""
|
"""Validate and generate sample audio"""
|
||||||
import numpy as np
|
import numpy as np # noqa: F401
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
@@ -369,13 +416,24 @@ def validate(model, val_loader, batch_processor, accelerator, tracker, lambdas,
|
|||||||
# Generate sample audio for TensorBoard display
|
# Generate sample audio for TensorBoard display
|
||||||
if writer is not None and val_ds is not None and audio_vae is not None and accelerator.rank == 0:
|
if writer is not None and val_ds is not None and audio_vae is not None and accelerator.rank == 0:
|
||||||
try:
|
try:
|
||||||
generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate,
|
generate_sample_audio(
|
||||||
val_texts=val_texts, tokenizer=tokenizer, valid_interval=valid_interval,
|
model,
|
||||||
tracker=tracker)
|
val_ds,
|
||||||
|
audio_vae,
|
||||||
|
writer,
|
||||||
|
step,
|
||||||
|
accelerator,
|
||||||
|
sample_rate,
|
||||||
|
val_texts=val_texts,
|
||||||
|
tokenizer=tokenizer,
|
||||||
|
valid_interval=valid_interval,
|
||||||
|
tracker=tracker,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
tracker.print(f"[Warning] Failed to generate sample audio: {e}")
|
tracker.print(f"[Warning] Failed to generate sample audio: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
import io
|
import io
|
||||||
|
|
||||||
buf = io.StringIO()
|
buf = io.StringIO()
|
||||||
traceback.print_exc(file=buf)
|
traceback.print_exc(file=buf)
|
||||||
tracker.print(buf.getvalue())
|
tracker.print(buf.getvalue())
|
||||||
@@ -398,6 +456,7 @@ def compute_mel_spectrogram(audio_np, sample_rate, n_mels=128):
|
|||||||
"""Compute Mel Spectrogram (dB) using librosa"""
|
"""Compute Mel Spectrogram (dB) using librosa"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import librosa
|
import librosa
|
||||||
|
|
||||||
audio_np = audio_np.flatten().astype(np.float32)
|
audio_np = audio_np.flatten().astype(np.float32)
|
||||||
mel = librosa.feature.melspectrogram(y=audio_np, sr=sample_rate, n_mels=n_mels, fmax=sample_rate // 2)
|
mel = librosa.feature.melspectrogram(y=audio_np, sr=sample_rate, n_mels=n_mels, fmax=sample_rate // 2)
|
||||||
return librosa.power_to_db(mel, ref=np.max)
|
return librosa.power_to_db(mel, ref=np.max)
|
||||||
@@ -408,7 +467,8 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
|
|||||||
Create mel spectrogram figure: show comparison if reference audio exists, otherwise show generated only
|
Create mel spectrogram figure: show comparison if reference audio exists, otherwise show generated only
|
||||||
"""
|
"""
|
||||||
import matplotlib
|
import matplotlib
|
||||||
matplotlib.use('Agg')
|
|
||||||
|
matplotlib.use("Agg")
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import librosa.display
|
import librosa.display
|
||||||
|
|
||||||
@@ -419,19 +479,32 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
|
|||||||
# Comparison mode: reference vs generated
|
# Comparison mode: reference vs generated
|
||||||
fig, (ax_ref, ax_gen) = plt.subplots(2, 1, figsize=(12, 8))
|
fig, (ax_ref, ax_gen) = plt.subplots(2, 1, figsize=(12, 8))
|
||||||
|
|
||||||
img_ref = librosa.display.specshow(ref_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_ref)
|
img_ref = librosa.display.specshow(
|
||||||
ax_ref.set_title(f'Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}', fontsize=10, fontweight='bold', color='#28A745')
|
ref_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax_ref
|
||||||
plt.colorbar(img_ref, ax=ax_ref, format='%+2.0f dB', pad=0.02)
|
)
|
||||||
|
ax_ref.set_title(
|
||||||
|
f"Reference (GT) - {len(ref_audio_np)/sample_rate:.2f}s{step_str}",
|
||||||
|
fontsize=10,
|
||||||
|
fontweight="bold",
|
||||||
|
color="#28A745",
|
||||||
|
)
|
||||||
|
plt.colorbar(img_ref, ax=ax_ref, format="%+2.0f dB", pad=0.02)
|
||||||
|
|
||||||
img_gen = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax_gen)
|
img_gen = librosa.display.specshow(
|
||||||
ax_gen.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s', fontsize=10, fontweight='bold', color='#DC3545')
|
gen_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax_gen
|
||||||
plt.colorbar(img_gen, ax=ax_gen, format='%+2.0f dB', pad=0.02)
|
)
|
||||||
|
ax_gen.set_title(
|
||||||
|
f"Generated - {len(gen_audio_np)/sample_rate:.2f}s", fontsize=10, fontweight="bold", color="#DC3545"
|
||||||
|
)
|
||||||
|
plt.colorbar(img_gen, ax=ax_gen, format="%+2.0f dB", pad=0.02)
|
||||||
else:
|
else:
|
||||||
# Single figure mode: show generated only
|
# Single figure mode: show generated only
|
||||||
fig, ax = plt.subplots(figsize=(12, 4))
|
fig, ax = plt.subplots(figsize=(12, 4))
|
||||||
img = librosa.display.specshow(gen_mel, sr=sample_rate, x_axis='time', y_axis='mel', fmax=fmax, cmap='viridis', ax=ax)
|
img = librosa.display.specshow(
|
||||||
ax.set_title(f'Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}', fontsize=11, fontweight='bold')
|
gen_mel, sr=sample_rate, x_axis="time", y_axis="mel", fmax=fmax, cmap="viridis", ax=ax
|
||||||
plt.colorbar(img, ax=ax, format='%+2.0f dB', pad=0.02)
|
)
|
||||||
|
ax.set_title(f"Generated - {len(gen_audio_np)/sample_rate:.2f}s{step_str}", fontsize=11, fontweight="bold")
|
||||||
|
plt.colorbar(img, ax=ax, format="%+2.0f dB", pad=0.02)
|
||||||
|
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
return fig
|
return fig
|
||||||
@@ -440,13 +513,25 @@ def create_mel_figure(gen_audio_np, gen_mel, sample_rate, step=None, ref_audio_n
|
|||||||
def normalize_audio(audio_np):
|
def normalize_audio(audio_np):
|
||||||
"""Normalize audio to [-0.9, 0.9]"""
|
"""Normalize audio to [-0.9, 0.9]"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
max_val = np.abs(audio_np).max()
|
max_val = np.abs(audio_np).max()
|
||||||
return audio_np / max_val * 0.9 if max_val > 0 else audio_np
|
return audio_np / max_val * 0.9 if max_val > 0 else audio_np
|
||||||
|
|
||||||
|
|
||||||
def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, sample_rate=22050,
|
def generate_sample_audio(
|
||||||
val_texts=None, tokenizer=None, pretrained_path=None, valid_interval=1000,
|
model,
|
||||||
tracker=None):
|
val_ds,
|
||||||
|
audio_vae,
|
||||||
|
writer,
|
||||||
|
step,
|
||||||
|
accelerator,
|
||||||
|
sample_rate=22050,
|
||||||
|
val_texts=None,
|
||||||
|
tokenizer=None,
|
||||||
|
pretrained_path=None,
|
||||||
|
valid_interval=1000,
|
||||||
|
tracker=None,
|
||||||
|
):
|
||||||
"""Select 2 fixed validation samples, generate audio and log to TensorBoard"""
|
"""Select 2 fixed validation samples, generate audio and log to TensorBoard"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -468,7 +553,10 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
|||||||
ref_sr = sample["audio"].get("sampling_rate", sample_rate)
|
ref_sr = sample["audio"].get("sampling_rate", sample_rate)
|
||||||
if ref_sr != sample_rate:
|
if ref_sr != sample_rate:
|
||||||
import torchaudio.functional as F
|
import torchaudio.functional as F
|
||||||
ref_audio_np = F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
|
|
||||||
|
ref_audio_np = (
|
||||||
|
F.resample(torch.from_numpy(ref_audio_np).unsqueeze(0), ref_sr, sample_rate).squeeze(0).numpy()
|
||||||
|
)
|
||||||
log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s")
|
log(f"[Audio] Loaded reference audio for sample {i}: duration={len(ref_audio_np)/sample_rate:.2f}s")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log(f"[Warning] Failed to load reference audio: {e}")
|
log(f"[Warning] Failed to load reference audio: {e}")
|
||||||
@@ -500,7 +588,11 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Process generated audio
|
# Process generated audio
|
||||||
gen_audio_np = generated.cpu().float().numpy().flatten() if isinstance(generated, torch.Tensor) else np.array(generated, dtype=np.float32).flatten()
|
gen_audio_np = (
|
||||||
|
generated.cpu().float().numpy().flatten()
|
||||||
|
if isinstance(generated, torch.Tensor)
|
||||||
|
else np.array(generated, dtype=np.float32).flatten()
|
||||||
|
)
|
||||||
gen_audio_np = normalize_audio(gen_audio_np)
|
gen_audio_np = normalize_audio(gen_audio_np)
|
||||||
|
|
||||||
tag = f"val_sample_{i}"
|
tag = f"val_sample_{i}"
|
||||||
@@ -509,7 +601,9 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
|||||||
|
|
||||||
# Log reference audio
|
# Log reference audio
|
||||||
if ref_audio_np is not None:
|
if ref_audio_np is not None:
|
||||||
writer.add_audio(f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate)
|
writer.add_audio(
|
||||||
|
f"{tag}/reference_audio", normalize_audio(ref_audio_np), global_step=step, sample_rate=sample_rate
|
||||||
|
)
|
||||||
|
|
||||||
# Generate mel spectrogram figure
|
# Generate mel spectrogram figure
|
||||||
try:
|
try:
|
||||||
@@ -524,6 +618,7 @@ def generate_sample_audio(model, val_ds, audio_vae, writer, step, accelerator, s
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
log(f"[Warning] Failed to generate audio for sample {i}: {e}")
|
log(f"[Warning] Failed to generate audio for sample {i}: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
@@ -545,8 +640,6 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
|
|||||||
Called by all ranks so that distributed state stays aligned.
|
Called by all ranks so that distributed state stays aligned.
|
||||||
Returns the step number to resume from, or 0 if no checkpoint found.
|
Returns the step number to resume from, or 0 if no checkpoint found.
|
||||||
"""
|
"""
|
||||||
import json
|
|
||||||
|
|
||||||
latest_folder = save_dir / "latest"
|
latest_folder = save_dir / "latest"
|
||||||
if not latest_folder.exists():
|
if not latest_folder.exists():
|
||||||
return 0
|
return 0
|
||||||
@@ -564,6 +657,7 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
|
|||||||
if lora_weights_path.exists():
|
if lora_weights_path.exists():
|
||||||
if lora_weights_path.suffix == ".safetensors":
|
if lora_weights_path.suffix == ".safetensors":
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
state_dict = load_file(str(lora_weights_path))
|
state_dict = load_file(str(lora_weights_path))
|
||||||
else:
|
else:
|
||||||
ckpt = torch.load(lora_weights_path, map_location="cpu")
|
ckpt = torch.load(lora_weights_path, map_location="cpu")
|
||||||
@@ -581,6 +675,7 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
|
|||||||
if model_path.exists():
|
if model_path.exists():
|
||||||
if model_path.suffix == ".safetensors":
|
if model_path.suffix == ".safetensors":
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
state_dict = load_file(str(model_path))
|
state_dict = load_file(str(model_path))
|
||||||
else:
|
else:
|
||||||
ckpt = torch.load(model_path, map_location="cpu")
|
ckpt = torch.load(model_path, map_location="cpu")
|
||||||
@@ -625,13 +720,21 @@ def load_checkpoint(model, optimizer, scheduler, save_dir: Path, rank: int = 0):
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pretrained_path: str = None, hf_model_id: str = "", distribute: bool = False):
|
def save_checkpoint(
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
scheduler,
|
||||||
|
save_dir: Path,
|
||||||
|
step: int,
|
||||||
|
pretrained_path: str = None,
|
||||||
|
hf_model_id: str = "",
|
||||||
|
distribute: bool = False,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Save checkpoint with different strategies for full finetune vs LoRA:
|
Save checkpoint with different strategies for full finetune vs LoRA:
|
||||||
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
|
- Full finetune: save non-vae weights to model.safetensors (or pytorch_model.bin if safetensors unavailable)
|
||||||
- LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable)
|
- LoRA: save only lora weights to lora_weights.safetensors (or lora_weights.ckpt if safetensors unavailable)
|
||||||
"""
|
"""
|
||||||
import json
|
|
||||||
import shutil
|
import shutil
|
||||||
|
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
@@ -671,7 +774,14 @@ def save_checkpoint(model, optimizer, scheduler, save_dir: Path, step: int, pret
|
|||||||
# Copy config files from pretrained path
|
# Copy config files from pretrained path
|
||||||
if pretrained_path:
|
if pretrained_path:
|
||||||
pretrained_dir = Path(pretrained_path)
|
pretrained_dir = Path(pretrained_path)
|
||||||
files_to_copy = ["config.json", "audiovae.pth", "tokenizer.json", "special_tokens_map.json", "tokenizer_config.json"]
|
files_to_copy = [
|
||||||
|
"config.json",
|
||||||
|
"audiovae.pth",
|
||||||
|
"audiovae.safetensors",
|
||||||
|
"tokenizer.json",
|
||||||
|
"special_tokens_map.json",
|
||||||
|
"tokenizer_config.json",
|
||||||
|
]
|
||||||
for fname in files_to_copy:
|
for fname in files_to_copy:
|
||||||
src = pretrained_dir / fname
|
src = pretrained_dir / fname
|
||||||
if src.exists():
|
if src.exists():
|
||||||
|
|||||||
+46
-27
@@ -13,11 +13,11 @@ import soundfile as sf
|
|||||||
|
|
||||||
from voxcpm.core import VoxCPM
|
from voxcpm.core import VoxCPM
|
||||||
|
|
||||||
|
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
# Validators
|
# Validators
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
|
def validate_file_exists(file_path: str, file_type: str = "file") -> Path:
|
||||||
path = Path(file_path)
|
path = Path(file_path)
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
@@ -53,12 +53,11 @@ def validate_ranges(args, parser):
|
|||||||
# Model loading
|
# Model loading
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
def load_model(args) -> VoxCPM:
|
def load_model(args) -> VoxCPM:
|
||||||
print("Loading VoxCPM model...", file=sys.stderr)
|
print("Loading VoxCPM model...", file=sys.stderr)
|
||||||
|
|
||||||
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get(
|
zipenhancer_path = getattr(args, "zipenhancer_path", None) or os.environ.get("ZIPENHANCER_MODEL_PATH", None)
|
||||||
"ZIPENHANCER_MODEL_PATH", None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build LoRA config if provided
|
# Build LoRA config if provided
|
||||||
lora_config = None
|
lora_config = None
|
||||||
@@ -119,22 +118,29 @@ def load_model(args) -> VoxCPM:
|
|||||||
# Commands
|
# Commands
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
def cmd_clone(args):
|
def cmd_clone(args):
|
||||||
if not args.text:
|
if not args.text:
|
||||||
sys.exit("Error: Please provide --text for synthesis")
|
sys.exit("Error: Please provide --text for synthesis")
|
||||||
|
|
||||||
if not args.prompt_audio or not args.prompt_text:
|
has_prompt = args.prompt_audio and args.prompt_text
|
||||||
sys.exit("Error: Voice cloning requires both --prompt-audio and --prompt-text")
|
has_ref = args.reference_audio is not None
|
||||||
|
if not has_prompt and not has_ref:
|
||||||
|
sys.exit("Error: Voice cloning requires --prompt-audio + --prompt-text, or --reference-audio, or both")
|
||||||
|
|
||||||
prompt_audio_path = validate_file_exists(args.prompt_audio, "reference audio file")
|
if args.prompt_audio:
|
||||||
|
validate_file_exists(args.prompt_audio, "prompt audio file")
|
||||||
|
if args.reference_audio:
|
||||||
|
validate_file_exists(args.reference_audio, "reference audio file")
|
||||||
output_path = validate_output_path(args.output)
|
output_path = validate_output_path(args.output)
|
||||||
|
|
||||||
model = load_model(args)
|
model = load_model(args)
|
||||||
|
|
||||||
audio_array = model.generate(
|
audio_array = model.generate(
|
||||||
text=args.text,
|
text=args.text,
|
||||||
prompt_wav_path=str(prompt_audio_path),
|
prompt_wav_path=args.prompt_audio if has_prompt else None,
|
||||||
prompt_text=args.prompt_text,
|
prompt_text=args.prompt_text if has_prompt else None,
|
||||||
|
reference_wav_path=args.reference_audio,
|
||||||
cfg_value=args.cfg_value,
|
cfg_value=args.cfg_value,
|
||||||
inference_timesteps=args.inference_timesteps,
|
inference_timesteps=args.inference_timesteps,
|
||||||
normalize=args.normalize,
|
normalize=args.normalize,
|
||||||
@@ -185,7 +191,11 @@ def cmd_batch(args):
|
|||||||
|
|
||||||
prompt_audio_path = None
|
prompt_audio_path = None
|
||||||
if args.prompt_audio:
|
if args.prompt_audio:
|
||||||
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "reference audio file"))
|
prompt_audio_path = str(validate_file_exists(args.prompt_audio, "prompt audio file"))
|
||||||
|
|
||||||
|
reference_audio_path = None
|
||||||
|
if args.reference_audio:
|
||||||
|
reference_audio_path = str(validate_file_exists(args.reference_audio, "reference audio file"))
|
||||||
|
|
||||||
success_count = 0
|
success_count = 0
|
||||||
|
|
||||||
@@ -195,10 +205,11 @@ def cmd_batch(args):
|
|||||||
text=text,
|
text=text,
|
||||||
prompt_wav_path=prompt_audio_path,
|
prompt_wav_path=prompt_audio_path,
|
||||||
prompt_text=args.prompt_text,
|
prompt_text=args.prompt_text,
|
||||||
|
reference_wav_path=reference_audio_path,
|
||||||
cfg_value=args.cfg_value,
|
cfg_value=args.cfg_value,
|
||||||
inference_timesteps=args.inference_timesteps,
|
inference_timesteps=args.inference_timesteps,
|
||||||
normalize=args.normalize,
|
normalize=args.normalize,
|
||||||
denoise=args.denoise and prompt_audio_path is not None,
|
denoise=args.denoise and (prompt_audio_path is not None or reference_audio_path is not None),
|
||||||
)
|
)
|
||||||
|
|
||||||
output_file = output_dir / f"output_{i:03d}.wav"
|
output_file = output_dir / f"output_{i:03d}.wav"
|
||||||
@@ -218,6 +229,7 @@ def cmd_batch(args):
|
|||||||
# Parser
|
# Parser
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
def _build_unified_parser():
|
def _build_unified_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="VoxCPM CLI - voice cloning, direct TTS, and batch processing",
|
description="VoxCPM CLI - voice cloning, direct TTS, and batch processing",
|
||||||
@@ -236,34 +248,40 @@ Examples:
|
|||||||
parser.add_argument("--text", "-t", help="Text to synthesize (single or clone mode)")
|
parser.add_argument("--text", "-t", help="Text to synthesize (single or clone mode)")
|
||||||
parser.add_argument("--output", "-o", help="Output audio file path (single or clone mode)")
|
parser.add_argument("--output", "-o", help="Output audio file path (single or clone mode)")
|
||||||
|
|
||||||
# Prompt
|
# Prompt / Reference
|
||||||
parser.add_argument("--prompt-audio", "-pa", help="Reference audio file path (clone mode)")
|
parser.add_argument(
|
||||||
parser.add_argument("--prompt-text", "-pt", help="Reference text corresponding to the audio")
|
"--prompt-audio", "-pa", help="Prompt audio file path (continuation mode, requires --prompt-text)"
|
||||||
parser.add_argument("--denoise", action="store_true", help="Enable prompt speech enhancement")
|
)
|
||||||
|
parser.add_argument("--prompt-text", "-pt", help="Text corresponding to the prompt audio")
|
||||||
|
parser.add_argument(
|
||||||
|
"--reference-audio", "-ra", help="Reference audio for voice cloning (isolated mode, VoxCPM2 only)"
|
||||||
|
)
|
||||||
|
parser.add_argument("--denoise", action="store_true", help="Enable prompt/reference speech enhancement")
|
||||||
|
|
||||||
# Generation parameters
|
# Generation parameters
|
||||||
parser.add_argument("--cfg-value", type=float, default=2.0,
|
parser.add_argument(
|
||||||
help="CFG guidance scale (float, recommended 0.5–5.0, default: 2.0)")
|
"--cfg-value", type=float, default=2.0, help="CFG guidance scale (float, recommended 0.5–5.0, default: 2.0)"
|
||||||
parser.add_argument("--inference-timesteps", type=int, default=10,
|
)
|
||||||
help="Inference steps (int, 1–100, default: 10)")
|
parser.add_argument("--inference-timesteps", type=int, default=10, help="Inference steps (int, 1–100, default: 10)")
|
||||||
parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
|
parser.add_argument("--normalize", action="store_true", help="Enable text normalization")
|
||||||
|
|
||||||
# Model loading
|
# Model loading
|
||||||
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path")
|
parser.add_argument("--model-path", type=str, help="Local VoxCPM model path")
|
||||||
parser.add_argument("--hf-model-id", type=str, default="openbmb/VoxCPM1.5",
|
parser.add_argument(
|
||||||
help="Hugging Face repo id (default: openbmb/VoxCPM1.5)")
|
"--hf-model-id", type=str, default="openbmb/VoxCPM1.5", help="Hugging Face repo id (default: openbmb/VoxCPM1.5)"
|
||||||
|
)
|
||||||
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
|
parser.add_argument("--cache-dir", type=str, help="Cache directory for Hub downloads")
|
||||||
parser.add_argument("--local-files-only", action="store_true", help="Disable network access")
|
parser.add_argument("--local-files-only", action="store_true", help="Disable network access")
|
||||||
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
|
parser.add_argument("--no-denoiser", action="store_true", help="Disable denoiser model loading")
|
||||||
parser.add_argument("--zipenhancer-path", type=str,
|
parser.add_argument(
|
||||||
help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)")
|
"--zipenhancer-path", type=str, help="ZipEnhancer model id or local path (or env ZIPENHANCER_MODEL_PATH)"
|
||||||
|
)
|
||||||
|
|
||||||
# LoRA
|
# LoRA
|
||||||
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights")
|
parser.add_argument("--lora-path", type=str, help="Path to LoRA weights")
|
||||||
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)")
|
parser.add_argument("--lora-r", type=int, default=32, help="LoRA rank (positive int, default: 32)")
|
||||||
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha (positive int, default: 16)")
|
parser.add_argument("--lora-alpha", type=int, default=16, help="LoRA alpha (positive int, default: 16)")
|
||||||
parser.add_argument("--lora-dropout", type=float, default=0.0,
|
parser.add_argument("--lora-dropout", type=float, default=0.0, help="LoRA dropout rate (0.0–1.0, default: 0.0)")
|
||||||
help="LoRA dropout rate (0.0–1.0, default: 0.0)")
|
|
||||||
parser.add_argument("--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers")
|
parser.add_argument("--lora-disable-lm", action="store_true", help="Disable LoRA on LM layers")
|
||||||
parser.add_argument("--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers")
|
parser.add_argument("--lora-disable-dit", action="store_true", help="Disable LoRA on DiT layers")
|
||||||
parser.add_argument("--lora-enable-proj", action="store_true", help="Enable LoRA on projection layers")
|
parser.add_argument("--lora-enable-proj", action="store_true", help="Enable LoRA on projection layers")
|
||||||
@@ -275,6 +293,7 @@ Examples:
|
|||||||
# Entrypoint
|
# Entrypoint
|
||||||
# -----------------------------
|
# -----------------------------
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = _build_unified_parser()
|
parser = _build_unified_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
@@ -296,8 +315,8 @@ def main():
|
|||||||
if not args.text or not args.output:
|
if not args.text or not args.output:
|
||||||
parser.error("Single-sample mode requires --text and --output")
|
parser.error("Single-sample mode requires --text and --output")
|
||||||
|
|
||||||
# Clone mode
|
# Clone mode (prompt continuation, reference isolation, or both)
|
||||||
if args.prompt_audio or args.prompt_text:
|
if args.prompt_audio or args.prompt_text or args.reference_audio:
|
||||||
return cmd_clone(args)
|
return cmd_clone(args)
|
||||||
|
|
||||||
# Direct synthesis
|
# Direct synthesis
|
||||||
|
|||||||
+127
-76
@@ -1,21 +1,25 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import re
|
import re
|
||||||
|
import json
|
||||||
import tempfile
|
import tempfile
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from typing import Generator, Optional
|
from typing import Generator, Optional
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from .model.voxcpm import VoxCPMModel, LoRAConfig
|
from .model.voxcpm import VoxCPMModel, LoRAConfig
|
||||||
|
from .model.voxcpm2 import VoxCPM2Model
|
||||||
|
|
||||||
|
|
||||||
class VoxCPM:
|
class VoxCPM:
|
||||||
def __init__(self,
|
def __init__(
|
||||||
voxcpm_model_path : str,
|
self,
|
||||||
zipenhancer_model_path : str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
voxcpm_model_path: str,
|
||||||
enable_denoiser : bool = True,
|
zipenhancer_model_path: str | None = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||||
optimize: bool = True,
|
enable_denoiser: bool = True,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
optimize: bool = True,
|
||||||
lora_weights_path: Optional[str] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
):
|
lora_weights_path: Optional[str] = None,
|
||||||
|
):
|
||||||
"""Initialize VoxCPM TTS pipeline.
|
"""Initialize VoxCPM TTS pipeline.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -31,7 +35,10 @@ class VoxCPM:
|
|||||||
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
lora_weights_path: Path to pre-trained LoRA weights (.pth file or directory
|
||||||
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
|
containing lora_weights.ckpt). If provided, LoRA weights will be loaded.
|
||||||
"""
|
"""
|
||||||
print(f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}", file=sys.stderr)
|
print(
|
||||||
|
f"voxcpm_model_path: {voxcpm_model_path}, zipenhancer_model_path: {zipenhancer_model_path}, enable_denoiser: {enable_denoiser}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
# If lora_weights_path is provided but no lora_config, create a default one
|
# If lora_weights_path is provided but no lora_config, create a default one
|
||||||
if lora_weights_path is not None and lora_config is None:
|
if lora_weights_path is not None and lora_config is None:
|
||||||
@@ -42,7 +49,20 @@ class VoxCPM:
|
|||||||
)
|
)
|
||||||
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)
|
print(f"Auto-created default LoRAConfig for loading weights from: {lora_weights_path}", file=sys.stderr)
|
||||||
|
|
||||||
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
# Determine model type from config.json architecture field
|
||||||
|
config_path = os.path.join(voxcpm_model_path, "config.json")
|
||||||
|
with open(config_path, "r", encoding="utf-8") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
arch = config.get("architecture", "voxcpm").lower()
|
||||||
|
|
||||||
|
if arch == "voxcpm2":
|
||||||
|
self.tts_model = VoxCPM2Model.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||||
|
print("Loaded VoxCPM2Model", file=sys.stderr)
|
||||||
|
elif arch == "voxcpm":
|
||||||
|
self.tts_model = VoxCPMModel.from_local(voxcpm_model_path, optimize=optimize, lora_config=lora_config)
|
||||||
|
print("Loaded VoxCPMModel", file=sys.stderr)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported architecture: {arch}")
|
||||||
|
|
||||||
# Load LoRA weights if path is provided
|
# Load LoRA weights if path is provided
|
||||||
if lora_weights_path is not None:
|
if lora_weights_path is not None:
|
||||||
@@ -51,8 +71,10 @@ class VoxCPM:
|
|||||||
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)
|
print(f"Loaded {len(loaded_keys)} LoRA parameters, skipped {len(skipped_keys)}", file=sys.stderr)
|
||||||
|
|
||||||
self.text_normalizer = None
|
self.text_normalizer = None
|
||||||
|
self.denoiser = None
|
||||||
if enable_denoiser and zipenhancer_model_path is not None:
|
if enable_denoiser and zipenhancer_model_path is not None:
|
||||||
from .zipenhancer import ZipEnhancer
|
from .zipenhancer import ZipEnhancer
|
||||||
|
|
||||||
self.denoiser = ZipEnhancer(zipenhancer_model_path)
|
self.denoiser = ZipEnhancer(zipenhancer_model_path)
|
||||||
else:
|
else:
|
||||||
self.denoiser = None
|
self.denoiser = None
|
||||||
@@ -64,17 +86,18 @@ class VoxCPM:
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls,
|
def from_pretrained(
|
||||||
hf_model_id: str = "openbmb/VoxCPM1.5",
|
cls,
|
||||||
load_denoiser: bool = True,
|
hf_model_id: str = "openbmb/VoxCPM2",
|
||||||
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
load_denoiser: bool = True,
|
||||||
cache_dir: str = None,
|
zipenhancer_model_id: str = "iic/speech_zipenhancer_ans_multiloss_16k_base",
|
||||||
local_files_only: bool = False,
|
cache_dir: str = None,
|
||||||
optimize: bool = True,
|
local_files_only: bool = False,
|
||||||
lora_config: Optional[LoRAConfig] = None,
|
optimize: bool = True,
|
||||||
lora_weights_path: Optional[str] = None,
|
lora_config: Optional[LoRAConfig] = None,
|
||||||
**kwargs,
|
lora_weights_path: Optional[str] = None,
|
||||||
):
|
**kwargs,
|
||||||
|
):
|
||||||
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
|
"""Instantiate ``VoxCPM`` from a Hugging Face Hub snapshot.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -134,46 +157,47 @@ class VoxCPM:
|
|||||||
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
|
def generate_streaming(self, *args, **kwargs) -> Generator[np.ndarray, None, None]:
|
||||||
return self._generate(*args, streaming=True, **kwargs)
|
return self._generate(*args, streaming=True, **kwargs)
|
||||||
|
|
||||||
def _generate(self,
|
def _generate(
|
||||||
text : str,
|
self,
|
||||||
prompt_wav_path : str = None,
|
text: str,
|
||||||
prompt_text : str = None,
|
prompt_wav_path: str = None,
|
||||||
cfg_value : float = 2.0,
|
prompt_text: str = None,
|
||||||
inference_timesteps : int = 10,
|
reference_wav_path: str = None,
|
||||||
min_len : int = 2,
|
cfg_value: float = 2.0,
|
||||||
max_len : int = 4096,
|
inference_timesteps: int = 10,
|
||||||
normalize : bool = False,
|
min_len: int = 2,
|
||||||
denoise : bool = False,
|
max_len: int = 4096,
|
||||||
retry_badcase : bool = True,
|
normalize: bool = False,
|
||||||
retry_badcase_max_times : int = 3,
|
denoise: bool = False,
|
||||||
retry_badcase_ratio_threshold : float = 6.0,
|
retry_badcase: bool = True,
|
||||||
streaming: bool = False,
|
retry_badcase_max_times: int = 3,
|
||||||
) -> Generator[np.ndarray, None, None]:
|
retry_badcase_ratio_threshold: float = 6.0,
|
||||||
|
streaming: bool = False,
|
||||||
|
) -> Generator[np.ndarray, None, None]:
|
||||||
"""Synthesize speech for the given text and return a single waveform.
|
"""Synthesize speech for the given text and return a single waveform.
|
||||||
|
|
||||||
This method optionally builds and reuses a prompt cache. If an external
|
|
||||||
prompt (``prompt_wav_path`` + ``prompt_text``) is provided, it will be
|
|
||||||
used for all sub-sentences. Otherwise, the prompt cache is built from
|
|
||||||
the first generated result and reused for the remaining text chunks.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: Input text. Can include newlines; each non-empty line is
|
text: Input text to synthesize.
|
||||||
treated as a sub-sentence.
|
prompt_wav_path: Path to prompt audio for continuation mode.
|
||||||
prompt_wav_path: Path to a reference audio file for prompting.
|
Must be paired with ``prompt_text``.
|
||||||
prompt_text: Text content corresponding to the prompt audio.
|
prompt_text: Text content corresponding to the prompt audio.
|
||||||
|
reference_wav_path: Path to reference audio for voice cloning
|
||||||
|
(structurally isolated via ref_audio tokens). Can be used
|
||||||
|
alone or combined with ``prompt_wav_path`` + ``prompt_text``.
|
||||||
cfg_value: Guidance scale for the generation model.
|
cfg_value: Guidance scale for the generation model.
|
||||||
inference_timesteps: Number of inference steps.
|
inference_timesteps: Number of inference steps.
|
||||||
|
min_len: Minimum audio length.
|
||||||
max_len: Maximum token length during generation.
|
max_len: Maximum token length during generation.
|
||||||
normalize: Whether to run text normalization before generation.
|
normalize: Whether to run text normalization before generation.
|
||||||
denoise: Whether to denoise the prompt audio if a denoiser is
|
denoise: Whether to denoise the prompt/reference audio if a
|
||||||
available.
|
denoiser is available.
|
||||||
retry_badcase: Whether to retry badcase.
|
retry_badcase: Whether to retry badcase.
|
||||||
retry_badcase_max_times: Maximum number of times to retry badcase.
|
retry_badcase_max_times: Maximum number of times to retry badcase.
|
||||||
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
|
retry_badcase_ratio_threshold: Threshold for audio-to-text ratio.
|
||||||
streaming: Whether to return a generator of audio chunks.
|
streaming: Whether to return a generator of audio chunks.
|
||||||
Returns:
|
Returns:
|
||||||
Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
|
Generator of numpy.ndarray: 1D waveform array (float32) on CPU.
|
||||||
Yields audio chunks for each generations step if ``streaming=True``,
|
Yields audio chunks for each generation step if ``streaming=True``,
|
||||||
otherwise yields a single array containing the final audio.
|
otherwise yields a single array containing the final audio.
|
||||||
"""
|
"""
|
||||||
if not text.strip() or not isinstance(text, str):
|
if not text.strip() or not isinstance(text, str):
|
||||||
@@ -183,55 +207,82 @@ class VoxCPM:
|
|||||||
if not os.path.exists(prompt_wav_path):
|
if not os.path.exists(prompt_wav_path):
|
||||||
raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
|
raise FileNotFoundError(f"prompt_wav_path does not exist: {prompt_wav_path}")
|
||||||
|
|
||||||
|
if reference_wav_path is not None:
|
||||||
|
if not os.path.exists(reference_wav_path):
|
||||||
|
raise FileNotFoundError(f"reference_wav_path does not exist: {reference_wav_path}")
|
||||||
|
|
||||||
if (prompt_wav_path is None) != (prompt_text is None):
|
if (prompt_wav_path is None) != (prompt_text is None):
|
||||||
raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
|
raise ValueError("prompt_wav_path and prompt_text must both be provided or both be None")
|
||||||
|
|
||||||
|
is_v2 = isinstance(self.tts_model, VoxCPM2Model)
|
||||||
|
if reference_wav_path is not None and not is_v2:
|
||||||
|
raise ValueError("reference_wav_path is only supported with VoxCPM2 models")
|
||||||
|
|
||||||
text = text.replace("\n", " ")
|
text = text.replace("\n", " ")
|
||||||
text = re.sub(r'\s+', ' ', text)
|
text = re.sub(r"\s+", " ", text)
|
||||||
temp_prompt_wav_path = None
|
temp_files = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if prompt_wav_path is not None and prompt_text is not None:
|
actual_prompt_path = prompt_wav_path
|
||||||
if denoise and self.denoiser is not None:
|
actual_ref_path = reference_wav_path
|
||||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
|
||||||
temp_prompt_wav_path = tmp_file.name
|
if denoise and self.denoiser is not None:
|
||||||
self.denoiser.enhance(prompt_wav_path, output_path=temp_prompt_wav_path)
|
if prompt_wav_path is not None:
|
||||||
prompt_wav_path = temp_prompt_wav_path
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||||
fixed_prompt_cache = self.tts_model.build_prompt_cache(
|
temp_files.append(tmp.name)
|
||||||
prompt_wav_path=prompt_wav_path,
|
self.denoiser.enhance(prompt_wav_path, output_path=temp_files[-1])
|
||||||
prompt_text=prompt_text
|
actual_prompt_path = temp_files[-1]
|
||||||
)
|
if reference_wav_path is not None:
|
||||||
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||||
|
temp_files.append(tmp.name)
|
||||||
|
self.denoiser.enhance(reference_wav_path, output_path=temp_files[-1])
|
||||||
|
actual_ref_path = temp_files[-1]
|
||||||
|
|
||||||
|
if actual_prompt_path is not None or actual_ref_path is not None:
|
||||||
|
if is_v2:
|
||||||
|
fixed_prompt_cache = self.tts_model.build_prompt_cache(
|
||||||
|
prompt_text=prompt_text,
|
||||||
|
prompt_wav_path=actual_prompt_path,
|
||||||
|
reference_wav_path=actual_ref_path,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
fixed_prompt_cache = self.tts_model.build_prompt_cache(
|
||||||
|
prompt_text=prompt_text,
|
||||||
|
prompt_wav_path=actual_prompt_path,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
fixed_prompt_cache = None # will be built from the first inference
|
fixed_prompt_cache = None
|
||||||
|
|
||||||
if normalize:
|
if normalize:
|
||||||
if self.text_normalizer is None:
|
if self.text_normalizer is None:
|
||||||
from .utils.text_normalize import TextNormalizer
|
from .utils.text_normalize import TextNormalizer
|
||||||
|
|
||||||
self.text_normalizer = TextNormalizer()
|
self.text_normalizer = TextNormalizer()
|
||||||
text = self.text_normalizer.normalize(text)
|
text = self.text_normalizer.normalize(text)
|
||||||
|
|
||||||
generate_result = self.tts_model._generate_with_prompt_cache(
|
generate_result = self.tts_model._generate_with_prompt_cache(
|
||||||
target_text=text,
|
target_text=text,
|
||||||
prompt_cache=fixed_prompt_cache,
|
prompt_cache=fixed_prompt_cache,
|
||||||
min_len=min_len,
|
min_len=min_len,
|
||||||
max_len=max_len,
|
max_len=max_len,
|
||||||
inference_timesteps=inference_timesteps,
|
inference_timesteps=inference_timesteps,
|
||||||
cfg_value=cfg_value,
|
cfg_value=cfg_value,
|
||||||
retry_badcase=retry_badcase,
|
retry_badcase=retry_badcase,
|
||||||
retry_badcase_max_times=retry_badcase_max_times,
|
retry_badcase_max_times=retry_badcase_max_times,
|
||||||
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
retry_badcase_ratio_threshold=retry_badcase_ratio_threshold,
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
)
|
)
|
||||||
|
|
||||||
for wav, _, _ in generate_result:
|
for wav, _, _ in generate_result:
|
||||||
yield wav.squeeze(0).cpu().numpy()
|
yield wav.squeeze(0).cpu().numpy()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
if temp_prompt_wav_path and os.path.exists(temp_prompt_wav_path):
|
for tmp_path in temp_files:
|
||||||
try:
|
if tmp_path and os.path.exists(tmp_path):
|
||||||
os.unlink(temp_prompt_wav_path)
|
try:
|
||||||
except OSError:
|
os.unlink(tmp_path)
|
||||||
pass
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
# LoRA Interface (delegated to VoxCPMModel)
|
# LoRA Interface (delegated to VoxCPMModel)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
from .voxcpm import VoxCPMModel
|
from .voxcpm import VoxCPMModel
|
||||||
|
from .voxcpm2 import VoxCPM2Model
|
||||||
|
|
||||||
__all__ = ["VoxCPMModel"]
|
__all__ = ["VoxCPMModel", "VoxCPM2Model"]
|
||||||
|
|||||||
@@ -24,8 +24,7 @@ def mask_multichar_chinese_tokens(tokenizer: PreTrainedTokenizer):
|
|||||||
"""
|
"""
|
||||||
# Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
|
# Pre-compute multi-character tokens (length >= 2, pure Chinese characters)
|
||||||
multichar_tokens = {
|
multichar_tokens = {
|
||||||
token for token in tokenizer.vocab.keys()
|
token for token in tokenizer.vocab.keys() if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
|
||||||
if len(token) >= 2 and all("\u4e00" <= c <= "\u9fff" for c in token)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
class CharTokenizerWrapper:
|
class CharTokenizerWrapper:
|
||||||
|
|||||||
+80
-67
@@ -24,7 +24,6 @@ from typing import Tuple, Union, Generator, List, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
import torchaudio
|
import torchaudio
|
||||||
import warnings
|
import warnings
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
@@ -32,6 +31,7 @@ from pydantic import BaseModel
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
SAFETENSORS_AVAILABLE = True
|
SAFETENSORS_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
SAFETENSORS_AVAILABLE = False
|
SAFETENSORS_AVAILABLE = False
|
||||||
@@ -84,9 +84,9 @@ class VoxCPMConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class LoRAConfig(BaseModel):
|
class LoRAConfig(BaseModel):
|
||||||
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
|
enable_lm: bool = False # Apply LoRA to base_lm + residual_lm
|
||||||
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
|
enable_dit: bool = False # Apply LoRA to VoxCPMLocDiT
|
||||||
enable_proj: bool = False # Apply LoRA to projection Linear layers
|
enable_proj: bool = False # Apply LoRA to projection Linear layers
|
||||||
|
|
||||||
r: int = 8
|
r: int = 8
|
||||||
alpha: int = 16
|
alpha: int = 16
|
||||||
@@ -168,7 +168,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
config.lm_config.hidden_size,
|
config.lm_config.hidden_size,
|
||||||
config.lm_config.hidden_size,
|
config.lm_config.hidden_size,
|
||||||
config.scalar_quantization_latent_dim,
|
config.scalar_quantization_latent_dim,
|
||||||
config.scalar_quantization_scale
|
config.scalar_quantization_scale,
|
||||||
)
|
)
|
||||||
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
|
self.enc_to_lm_proj = nn.Linear(config.encoder_config.hidden_dim, config.lm_config.hidden_size)
|
||||||
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
|
self.lm_to_dit_proj = nn.Linear(config.lm_config.hidden_size, config.dit_config.hidden_dim)
|
||||||
@@ -196,9 +196,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
# LM: base_lm + residual_lm
|
# LM: base_lm + residual_lm
|
||||||
if cfg.enable_lm:
|
if cfg.enable_lm:
|
||||||
for lm in [self.base_lm, self.residual_lm]:
|
for lm in [self.base_lm, self.residual_lm]:
|
||||||
apply_lora_to_named_linear_modules(
|
apply_lora_to_named_linear_modules(lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs)
|
||||||
lm, target_submodule_names=cfg.target_modules_lm, **lora_kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
# DiT: feat_decoder.estimator
|
# DiT: feat_decoder.estimator
|
||||||
if cfg.enable_dit:
|
if cfg.enable_dit:
|
||||||
@@ -209,6 +207,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
# 投影层
|
# 投影层
|
||||||
if cfg.enable_proj:
|
if cfg.enable_proj:
|
||||||
from ..modules.layers.lora import LoRALinear
|
from ..modules.layers.lora import LoRALinear
|
||||||
|
|
||||||
for attr_name in cfg.target_proj_modules:
|
for attr_name in cfg.target_proj_modules:
|
||||||
module = getattr(self, attr_name, None)
|
module = getattr(self, attr_name, None)
|
||||||
if isinstance(module, nn.Linear):
|
if isinstance(module, nn.Linear):
|
||||||
@@ -221,13 +220,17 @@ class VoxCPMModel(nn.Module):
|
|||||||
if self.device != "cuda":
|
if self.device != "cuda":
|
||||||
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
|
raise ValueError("VoxCPMModel can only be optimized on CUDA device")
|
||||||
try:
|
try:
|
||||||
import triton
|
import triton # noqa: F401
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ValueError("triton is not installed")
|
raise ValueError("triton is not installed")
|
||||||
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
self.base_lm.forward_step = torch.compile(self.base_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
||||||
self.residual_lm.forward_step = torch.compile(self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True)
|
self.residual_lm.forward_step = torch.compile(
|
||||||
|
self.residual_lm.forward_step, mode="reduce-overhead", fullgraph=True
|
||||||
|
)
|
||||||
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
self.feat_encoder = torch.compile(self.feat_encoder, mode="reduce-overhead", fullgraph=True)
|
||||||
self.feat_decoder.estimator = torch.compile(self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True)
|
self.feat_decoder.estimator = torch.compile(
|
||||||
|
self.feat_decoder.estimator, mode="reduce-overhead", fullgraph=True
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
|
print(f"Warning: torch.compile disabled - {e}", file=sys.stderr)
|
||||||
return self
|
return self
|
||||||
@@ -313,9 +316,11 @@ class VoxCPMModel(nn.Module):
|
|||||||
mu=dit_hidden,
|
mu=dit_hidden,
|
||||||
patch_size=self.patch_size,
|
patch_size=self.patch_size,
|
||||||
cond=feat_cond_for_sample,
|
cond=feat_cond_for_sample,
|
||||||
n_timesteps=self.config.dit_config.cfm_config.inference_cfg_rate
|
n_timesteps=(
|
||||||
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
|
self.config.dit_config.cfm_config.inference_cfg_rate
|
||||||
else 10,
|
if hasattr(self.config.dit_config.cfm_config, "inference_cfg_rate")
|
||||||
|
else 10
|
||||||
|
),
|
||||||
)
|
)
|
||||||
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
|
feat_pred = rearrange(feat_pred_seq.transpose(1, 2), "(b t) d p -> b d (t p)", b=B, p=self.patch_size)
|
||||||
|
|
||||||
@@ -331,7 +336,6 @@ class VoxCPMModel(nn.Module):
|
|||||||
def _dtype(self):
|
def _dtype(self):
|
||||||
return get_dtype(self.config.dtype)
|
return get_dtype(self.config.dtype)
|
||||||
|
|
||||||
|
|
||||||
def generate(self, *args, **kwargs) -> torch.Tensor:
|
def generate(self, *args, **kwargs) -> torch.Tensor:
|
||||||
return next(self._generate(*args, streaming=False, **kwargs))
|
return next(self._generate(*args, streaming=False, **kwargs))
|
||||||
|
|
||||||
@@ -350,7 +354,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
cfg_value: float = 2.0,
|
cfg_value: float = 2.0,
|
||||||
retry_badcase: bool = False,
|
retry_badcase: bool = False,
|
||||||
retry_badcase_max_times: int = 3,
|
retry_badcase_max_times: int = 3,
|
||||||
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
|
retry_badcase_ratio_threshold: float = 6.0, # setting acceptable ratio of audio length to text length (for badcase detection)
|
||||||
streaming: bool = False,
|
streaming: bool = False,
|
||||||
) -> Generator[torch.Tensor, None, None]:
|
) -> Generator[torch.Tensor, None, None]:
|
||||||
if retry_badcase and streaming:
|
if retry_badcase and streaming:
|
||||||
@@ -444,7 +448,9 @@ class VoxCPMModel(nn.Module):
|
|||||||
audio_feat,
|
audio_feat,
|
||||||
audio_mask,
|
audio_mask,
|
||||||
min_len=min_len,
|
min_len=min_len,
|
||||||
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
|
max_len=min(
|
||||||
|
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
|
||||||
|
), # avoid too long audio
|
||||||
inference_timesteps=inference_timesteps,
|
inference_timesteps=inference_timesteps,
|
||||||
cfg_value=cfg_value,
|
cfg_value=cfg_value,
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
@@ -460,7 +466,10 @@ class VoxCPMModel(nn.Module):
|
|||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat = next(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
|
print(
|
||||||
|
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
retry_badcase_times += 1
|
retry_badcase_times += 1
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
@@ -514,7 +523,9 @@ class VoxCPMModel(nn.Module):
|
|||||||
self.audio_vae.latent_dim,
|
self.audio_vae.latent_dim,
|
||||||
-1,
|
-1,
|
||||||
self.patch_size,
|
self.patch_size,
|
||||||
).permute(1, 2, 0) # (D, T, P)
|
).permute(
|
||||||
|
1, 2, 0
|
||||||
|
) # (D, T, P)
|
||||||
# build prompt cache - only save raw text and audio features
|
# build prompt cache - only save raw text and audio features
|
||||||
prompt_cache = {
|
prompt_cache = {
|
||||||
"prompt_text": prompt_text,
|
"prompt_text": prompt_text,
|
||||||
@@ -523,7 +534,6 @@ class VoxCPMModel(nn.Module):
|
|||||||
|
|
||||||
return prompt_cache
|
return prompt_cache
|
||||||
|
|
||||||
|
|
||||||
def merge_prompt_cache(
|
def merge_prompt_cache(
|
||||||
self,
|
self,
|
||||||
original_cache: dict,
|
original_cache: dict,
|
||||||
@@ -560,17 +570,14 @@ class VoxCPMModel(nn.Module):
|
|||||||
|
|
||||||
return merged_cache
|
return merged_cache
|
||||||
|
|
||||||
|
|
||||||
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
def generate_with_prompt_cache(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
return next(self._generate_with_prompt_cache(*args, streaming=False, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
def generate_with_prompt_cache_streaming(
|
def generate_with_prompt_cache_streaming(
|
||||||
self, *args, **kwargs
|
self, *args, **kwargs
|
||||||
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
|
) -> Generator[Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]], None, None]:
|
||||||
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
|
return self._generate_with_prompt_cache(*args, streaming=True, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def _generate_with_prompt_cache(
|
def _generate_with_prompt_cache(
|
||||||
self,
|
self,
|
||||||
@@ -645,8 +652,12 @@ class VoxCPMModel(nn.Module):
|
|||||||
)
|
)
|
||||||
text_token = torch.cat([text_token, text_pad_token])
|
text_token = torch.cat([text_token, text_pad_token])
|
||||||
audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
|
audio_feat = torch.cat([audio_pad_feat, prompt_audio_feat], dim=0)
|
||||||
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
|
text_mask = (
|
||||||
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
|
torch.cat([torch.ones(text_length), torch.zeros(audio_length)]).type(torch.int32).to(text_token.device)
|
||||||
|
)
|
||||||
|
audio_mask = (
|
||||||
|
torch.cat([torch.zeros(text_length), torch.ones(audio_length)]).type(torch.int32).to(text_token.device)
|
||||||
|
)
|
||||||
|
|
||||||
text_token = text_token.unsqueeze(0).to(self.device)
|
text_token = text_token.unsqueeze(0).to(self.device)
|
||||||
text_mask = text_mask.unsqueeze(0).to(self.device)
|
text_mask = text_mask.unsqueeze(0).to(self.device)
|
||||||
@@ -663,7 +674,9 @@ class VoxCPMModel(nn.Module):
|
|||||||
audio_feat,
|
audio_feat,
|
||||||
audio_mask,
|
audio_mask,
|
||||||
min_len=min_len,
|
min_len=min_len,
|
||||||
max_len=min(int(target_text_length * retry_badcase_ratio_threshold + 10), max_len), # avoid too long audio
|
max_len=min(
|
||||||
|
int(target_text_length * retry_badcase_ratio_threshold + 10), max_len
|
||||||
|
), # avoid too long audio
|
||||||
inference_timesteps=inference_timesteps,
|
inference_timesteps=inference_timesteps,
|
||||||
cfg_value=cfg_value,
|
cfg_value=cfg_value,
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
@@ -674,17 +687,16 @@ class VoxCPMModel(nn.Module):
|
|||||||
for latent_pred, pred_audio_feat in inference_result:
|
for latent_pred, pred_audio_feat in inference_result:
|
||||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||||
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
|
decode_audio = decode_audio[..., -patch_len:].squeeze(1).cpu()
|
||||||
yield (
|
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||||
decode_audio,
|
|
||||||
target_text_token,
|
|
||||||
pred_audio_feat
|
|
||||||
)
|
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
latent_pred, pred_audio_feat = next(inference_result)
|
latent_pred, pred_audio_feat = next(inference_result)
|
||||||
if retry_badcase:
|
if retry_badcase:
|
||||||
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
if pred_audio_feat.shape[0] >= target_text_length * retry_badcase_ratio_threshold:
|
||||||
print(f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...", file=sys.stderr)
|
print(
|
||||||
|
f" Badcase detected, audio_text_ratio={pred_audio_feat.shape[0] / target_text_length}, retrying...",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
retry_badcase_times += 1
|
retry_badcase_times += 1
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
@@ -695,14 +707,10 @@ class VoxCPMModel(nn.Module):
|
|||||||
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
decode_audio = self.audio_vae.decode(latent_pred.to(torch.float32))
|
||||||
patch_len = self.patch_size * self.chunk_size
|
patch_len = self.patch_size * self.chunk_size
|
||||||
if audio_mask.sum().item() > 0:
|
if audio_mask.sum().item() > 0:
|
||||||
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1):].squeeze(1).cpu()
|
decode_audio = decode_audio[..., patch_len * (streaming_prefix_len - 1) :].squeeze(1).cpu()
|
||||||
else:
|
else:
|
||||||
decode_audio = decode_audio[..., :].squeeze(1).cpu()
|
decode_audio = decode_audio[..., :].squeeze(1).cpu()
|
||||||
yield (
|
yield (decode_audio, target_text_token, pred_audio_feat)
|
||||||
decode_audio,
|
|
||||||
target_text_token,
|
|
||||||
pred_audio_feat
|
|
||||||
)
|
|
||||||
|
|
||||||
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
def inference(self, *args, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
return next(self._inference(*args, streaming=False, **kwargs))
|
return next(self._inference(*args, streaming=False, **kwargs))
|
||||||
@@ -782,7 +790,6 @@ class VoxCPMModel(nn.Module):
|
|||||||
enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
|
enc_outputs = self.fsq_layer(enc_outputs) * feat_mask.unsqueeze(-1) + enc_outputs * text_mask.unsqueeze(-1)
|
||||||
lm_hidden = enc_outputs[:, -1, :]
|
lm_hidden = enc_outputs[:, -1, :]
|
||||||
|
|
||||||
|
|
||||||
residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
|
residual_enc_outputs, residual_kv_cache_tuple = self.residual_lm(
|
||||||
inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
|
inputs_embeds=enc_outputs + feat_mask.unsqueeze(-1) * feat_embed,
|
||||||
is_causal=True,
|
is_causal=True,
|
||||||
@@ -790,7 +797,6 @@ class VoxCPMModel(nn.Module):
|
|||||||
self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
|
self.residual_lm.kv_cache.fill_caches(residual_kv_cache_tuple)
|
||||||
residual_hidden = residual_enc_outputs[:, -1, :]
|
residual_hidden = residual_enc_outputs[:, -1, :]
|
||||||
|
|
||||||
|
|
||||||
for i in tqdm(range(max_len)):
|
for i in tqdm(range(max_len)):
|
||||||
dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) # [b, h_dit]
|
dit_hidden_1 = self.lm_to_dit_proj(lm_hidden) # [b, h_dit]
|
||||||
dit_hidden_2 = self.res_to_dit_proj(residual_hidden) # [b, h_dit]
|
dit_hidden_2 = self.res_to_dit_proj(residual_hidden) # [b, h_dit]
|
||||||
@@ -827,10 +833,10 @@ class VoxCPMModel(nn.Module):
|
|||||||
curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
|
curr_embed[:, 0, :], torch.tensor([self.base_lm.kv_cache.step()], device=curr_embed.device)
|
||||||
).clone()
|
).clone()
|
||||||
|
|
||||||
|
|
||||||
lm_hidden = self.fsq_layer(lm_hidden)
|
lm_hidden = self.fsq_layer(lm_hidden)
|
||||||
residual_hidden = self.residual_lm.forward_step(
|
residual_hidden = self.residual_lm.forward_step(
|
||||||
lm_hidden + curr_embed[:, 0, :], torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device)
|
lm_hidden + curr_embed[:, 0, :],
|
||||||
|
torch.tensor([self.residual_lm.kv_cache.step()], device=curr_embed.device),
|
||||||
).clone()
|
).clone()
|
||||||
|
|
||||||
if not streaming:
|
if not streaming:
|
||||||
@@ -838,29 +844,41 @@ class VoxCPMModel(nn.Module):
|
|||||||
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
feat_pred = rearrange(pred_feat_seq, "b t p d -> b d (t p)", b=B, p=self.patch_size)
|
||||||
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
yield feat_pred, pred_feat_seq.squeeze(0).cpu()
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
def from_local(cls, path: str, optimize: bool = True, training: bool = False, lora_config: LoRAConfig = None):
|
||||||
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
config = VoxCPMConfig.model_validate_json(open(os.path.join(path, "config.json")).read())
|
||||||
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
tokenizer = LlamaTokenizerFast.from_pretrained(path)
|
||||||
audio_vae_config = getattr(config, 'audio_vae_config', None)
|
audio_vae_config = getattr(config, "audio_vae_config", None)
|
||||||
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
|
audio_vae = AudioVAE(config=audio_vae_config) if audio_vae_config else AudioVAE()
|
||||||
vae_state_dict = torch.load(
|
# Try to load AudioVAE from safetensors first, fallback to pytorch
|
||||||
os.path.join(path, "audiovae.pth"),
|
audiovae_safetensors_path = os.path.join(path, "audiovae.safetensors")
|
||||||
map_location="cpu",
|
audiovae_pth_path = os.path.join(path, "audiovae.pth")
|
||||||
weights_only=True,
|
if os.path.exists(audiovae_safetensors_path) and SAFETENSORS_AVAILABLE:
|
||||||
)["state_dict"]
|
print(f"Loading AudioVAE from safetensors: {audiovae_safetensors_path}", file=sys.stderr)
|
||||||
|
vae_state_dict = load_file(audiovae_safetensors_path, device="cpu")
|
||||||
|
elif os.path.exists(audiovae_pth_path):
|
||||||
|
print(f"Loading AudioVAE from pytorch: {audiovae_pth_path}", file=sys.stderr)
|
||||||
|
checkpoint = torch.load(
|
||||||
|
audiovae_pth_path,
|
||||||
|
map_location="cpu",
|
||||||
|
weights_only=True,
|
||||||
|
)
|
||||||
|
vae_state_dict = checkpoint.get("state_dict", checkpoint)
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"AudioVAE checkpoint not found. Expected either {audiovae_safetensors_path} or {audiovae_pth_path}"
|
||||||
|
)
|
||||||
model = cls(config, tokenizer, audio_vae, lora_config)
|
model = cls(config, tokenizer, audio_vae, lora_config)
|
||||||
if not training:
|
if not training:
|
||||||
lm_dtype = get_dtype(model.config.dtype)
|
lm_dtype = get_dtype(model.config.dtype)
|
||||||
model = model.to(lm_dtype)
|
model = model.to(lm_dtype)
|
||||||
else: # training mode
|
else: # training mode
|
||||||
for name, param in model.named_parameters():
|
for name, param in model.named_parameters():
|
||||||
if "audio_vae" in name: # freeze VAE weights
|
if "audio_vae" in name: # freeze VAE weights
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
continue
|
continue
|
||||||
if lora_config is not None:
|
if lora_config is not None:
|
||||||
if "lora" not in name: # freeze non-LoRA weights
|
if "lora" not in name: # freeze non-LoRA weights
|
||||||
param.requires_grad = False
|
param.requires_grad = False
|
||||||
model.audio_vae = model.audio_vae.to(torch.float32)
|
model.audio_vae = model.audio_vae.to(torch.float32)
|
||||||
|
|
||||||
@@ -880,9 +898,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
)
|
)
|
||||||
model_state_dict = checkpoint.get("state_dict", checkpoint)
|
model_state_dict = checkpoint.get("state_dict", checkpoint)
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}")
|
||||||
f"Model file not found. Expected either {safetensors_path} or {pytorch_model_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
for kw, val in vae_state_dict.items():
|
for kw, val in vae_state_dict.items():
|
||||||
model_state_dict[f"audio_vae.{kw}"] = val
|
model_state_dict[f"audio_vae.{kw}"] = val
|
||||||
@@ -900,6 +916,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
def _iter_lora_modules(self):
|
def _iter_lora_modules(self):
|
||||||
"""Iterate over all LoRA modules."""
|
"""Iterate over all LoRA modules."""
|
||||||
from ..modules.layers.lora import LoRALinear
|
from ..modules.layers.lora import LoRALinear
|
||||||
|
|
||||||
for module in self.modules():
|
for module in self.modules():
|
||||||
if isinstance(module, LoRALinear):
|
if isinstance(module, LoRALinear):
|
||||||
yield module
|
yield module
|
||||||
@@ -919,15 +936,15 @@ class VoxCPMModel(nn.Module):
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
device = device or self.device
|
device = device or self.device
|
||||||
lora_path = Path(lora_path)
|
lora_p = Path(lora_path)
|
||||||
|
|
||||||
# Try safetensors first, then fallback to .ckpt
|
# Try safetensors first, then fallback to .ckpt
|
||||||
if lora_path.is_dir():
|
if lora_p.is_dir():
|
||||||
safetensors_file = lora_path / "lora_weights.safetensors"
|
safetensors_file = lora_p / "lora_weights.safetensors"
|
||||||
ckpt_file = lora_path / "lora_weights.ckpt"
|
ckpt_file = lora_p / "lora_weights.ckpt"
|
||||||
else:
|
else:
|
||||||
safetensors_file = lora_path if lora_path.suffix == ".safetensors" else None
|
safetensors_file = lora_p if lora_p.suffix == ".safetensors" else None
|
||||||
ckpt_file = lora_path if lora_path.suffix in [".ckpt", ".pth"] else None
|
ckpt_file = lora_p if lora_p.suffix in [".ckpt", ".pth"] else None
|
||||||
|
|
||||||
# Load from safetensors if available
|
# Load from safetensors if available
|
||||||
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
if safetensors_file and safetensors_file.exists() and SAFETENSORS_AVAILABLE:
|
||||||
@@ -936,9 +953,7 @@ class VoxCPMModel(nn.Module):
|
|||||||
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
|
ckpt = torch.load(ckpt_file, map_location=device, weights_only=False)
|
||||||
state_dict = ckpt.get("state_dict", ckpt)
|
state_dict = ckpt.get("state_dict", ckpt)
|
||||||
else:
|
else:
|
||||||
raise FileNotFoundError(
|
raise FileNotFoundError(f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}")
|
||||||
f"LoRA checkpoint not found. Expected either {safetensors_file} or {ckpt_file}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Build param mapping (handle torch.compile's _orig_mod prefix)
|
# Build param mapping (handle torch.compile's _orig_mod prefix)
|
||||||
model_params = dict(self.named_parameters())
|
model_params = dict(self.named_parameters())
|
||||||
@@ -967,6 +982,4 @@ class VoxCPMModel(nn.Module):
|
|||||||
|
|
||||||
def get_lora_state_dict(self) -> dict:
|
def get_lora_state_dict(self) -> dict:
|
||||||
"""Get all LoRA parameters (lora_A/lora_B)."""
|
"""Get all LoRA parameters (lora_A/lora_B)."""
|
||||||
return {name: param.data.clone()
|
return {name: param.data.clone() for name, param in self.named_parameters() if "lora_" in name}
|
||||||
for name, param in self.named_parameters()
|
|
||||||
if "lora_" in name}
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1 +1,2 @@
|
|||||||
from .audio_vae import AudioVAE, AudioVAEConfig
|
from .audio_vae import AudioVAE, AudioVAEConfig
|
||||||
|
from .audio_vae_v2 import AudioVAE as AudioVAEV2, AudioVAEConfig as AudioVAEConfigV2
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import math
|
import math
|
||||||
from typing import List, Union, Optional
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -285,7 +285,7 @@ class AudioVAE(nn.Module):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: Optional[AudioVAEConfig] = None,
|
config: AudioVAEConfig = None,
|
||||||
):
|
):
|
||||||
# 如果没有传入config,使用默认配置
|
# 如果没有传入config,使用默认配置
|
||||||
if config is None:
|
if config is None:
|
||||||
|
|||||||
@@ -0,0 +1,486 @@
|
|||||||
|
import math
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
from torch.nn.utils import weight_norm
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
def WNConv1d(*args, **kwargs):
|
||||||
|
return weight_norm(nn.Conv1d(*args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
def WNConvTranspose1d(*args, **kwargs):
|
||||||
|
return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
class CausalConv1d(nn.Conv1d):
|
||||||
|
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.__padding = padding
|
||||||
|
self.__output_padding = output_padding
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x_pad = F.pad(x, (self.__padding * 2 - self.__output_padding, 0))
|
||||||
|
return super().forward(x_pad)
|
||||||
|
|
||||||
|
|
||||||
|
class CausalTransposeConv1d(nn.ConvTranspose1d):
|
||||||
|
def __init__(self, *args, padding: int = 0, output_padding: int = 0, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.__padding = padding
|
||||||
|
self.__output_padding = output_padding
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return super().forward(x)[..., : -(self.__padding * 2 - self.__output_padding)]
|
||||||
|
|
||||||
|
|
||||||
|
def WNCausalConv1d(*args, **kwargs):
|
||||||
|
return weight_norm(CausalConv1d(*args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
def WNCausalTransposeConv1d(*args, **kwargs):
|
||||||
|
return weight_norm(CausalTransposeConv1d(*args, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
|
# Scripting this brings model speed up 1.4x
|
||||||
|
@torch.jit.script
|
||||||
|
def snake(x, alpha):
|
||||||
|
shape = x.shape
|
||||||
|
x = x.reshape(shape[0], shape[1], -1)
|
||||||
|
x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2)
|
||||||
|
x = x.reshape(shape)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class Snake1d(nn.Module):
|
||||||
|
def __init__(self, channels):
|
||||||
|
super().__init__()
|
||||||
|
self.alpha = nn.Parameter(torch.ones(1, channels, 1))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return snake(x, self.alpha)
|
||||||
|
|
||||||
|
|
||||||
|
def init_weights(m):
|
||||||
|
if isinstance(m, nn.Conv1d):
|
||||||
|
nn.init.trunc_normal_(m.weight, std=0.02)
|
||||||
|
if m.bias is not None:
|
||||||
|
nn.init.constant_(m.bias, 0)
|
||||||
|
|
||||||
|
|
||||||
|
class CausalResidualUnit(nn.Module):
|
||||||
|
def __init__(self, dim: int = 16, dilation: int = 1, kernel: int = 7, groups: int = 1):
|
||||||
|
super().__init__()
|
||||||
|
pad = ((7 - 1) * dilation) // 2
|
||||||
|
self.block = nn.Sequential(
|
||||||
|
Snake1d(dim),
|
||||||
|
WNCausalConv1d(
|
||||||
|
dim,
|
||||||
|
dim,
|
||||||
|
kernel_size=kernel,
|
||||||
|
dilation=dilation,
|
||||||
|
padding=pad,
|
||||||
|
groups=groups,
|
||||||
|
),
|
||||||
|
Snake1d(dim),
|
||||||
|
WNCausalConv1d(dim, dim, kernel_size=1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
y = self.block(x)
|
||||||
|
pad = (x.shape[-1] - y.shape[-1]) // 2
|
||||||
|
assert pad == 0
|
||||||
|
if pad > 0:
|
||||||
|
x = x[..., pad:-pad]
|
||||||
|
return x + y
|
||||||
|
|
||||||
|
|
||||||
|
class CausalEncoderBlock(nn.Module):
|
||||||
|
def __init__(self, output_dim: int = 16, input_dim=None, stride: int = 1, groups=1):
|
||||||
|
super().__init__()
|
||||||
|
input_dim = input_dim or output_dim // 2
|
||||||
|
self.block = nn.Sequential(
|
||||||
|
CausalResidualUnit(input_dim, dilation=1, groups=groups),
|
||||||
|
CausalResidualUnit(input_dim, dilation=3, groups=groups),
|
||||||
|
CausalResidualUnit(input_dim, dilation=9, groups=groups),
|
||||||
|
Snake1d(input_dim),
|
||||||
|
WNCausalConv1d(
|
||||||
|
input_dim,
|
||||||
|
output_dim,
|
||||||
|
kernel_size=2 * stride,
|
||||||
|
stride=stride,
|
||||||
|
padding=math.ceil(stride / 2),
|
||||||
|
output_padding=stride % 2,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.block(x)
|
||||||
|
|
||||||
|
|
||||||
|
class CausalEncoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int = 64,
|
||||||
|
latent_dim: int = 32,
|
||||||
|
strides: list = [2, 4, 8, 8],
|
||||||
|
depthwise: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
# Create first convolution
|
||||||
|
self.block = [WNCausalConv1d(1, d_model, kernel_size=7, padding=3)]
|
||||||
|
|
||||||
|
# Create EncoderBlocks that double channels as they downsample by `stride`
|
||||||
|
for stride in strides:
|
||||||
|
d_model *= 2
|
||||||
|
groups = d_model // 2 if depthwise else 1
|
||||||
|
self.block += [CausalEncoderBlock(output_dim=d_model, stride=stride, groups=groups)]
|
||||||
|
|
||||||
|
groups = d_model if depthwise else 1
|
||||||
|
|
||||||
|
# Create two convolution, for mu and logvar
|
||||||
|
self.fc_mu = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
|
||||||
|
self.fc_logvar = WNCausalConv1d(d_model, latent_dim, kernel_size=3, padding=1)
|
||||||
|
|
||||||
|
# Wrap black into nn.Sequential
|
||||||
|
self.block = nn.Sequential(*self.block)
|
||||||
|
self.enc_dim = d_model
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
hidden_state = self.block(x)
|
||||||
|
return {
|
||||||
|
"hidden_state": hidden_state,
|
||||||
|
"mu": self.fc_mu(hidden_state),
|
||||||
|
"logvar": self.fc_logvar(hidden_state),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class NoiseBlock(nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = WNCausalConv1d(dim, dim, kernel_size=1, bias=False)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B, C, T = x.shape
|
||||||
|
noise = torch.randn((B, 1, T), device=x.device, dtype=x.dtype)
|
||||||
|
h = self.linear(x)
|
||||||
|
n = noise * h
|
||||||
|
x = x + n
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class CausalDecoderBlock(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim: int = 16,
|
||||||
|
output_dim: int = 8,
|
||||||
|
stride: int = 1,
|
||||||
|
groups=1,
|
||||||
|
use_noise_block: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
layers = [
|
||||||
|
Snake1d(input_dim),
|
||||||
|
WNCausalTransposeConv1d(
|
||||||
|
input_dim,
|
||||||
|
output_dim,
|
||||||
|
kernel_size=2 * stride,
|
||||||
|
stride=stride,
|
||||||
|
padding=math.ceil(stride / 2),
|
||||||
|
output_padding=stride % 2,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
if use_noise_block:
|
||||||
|
layers.append(NoiseBlock(output_dim))
|
||||||
|
layers.extend(
|
||||||
|
[
|
||||||
|
CausalResidualUnit(output_dim, dilation=1, groups=groups),
|
||||||
|
CausalResidualUnit(output_dim, dilation=3, groups=groups),
|
||||||
|
CausalResidualUnit(output_dim, dilation=9, groups=groups),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.block = nn.Sequential(*layers)
|
||||||
|
self.input_channels = input_dim
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.block(x)
|
||||||
|
|
||||||
|
|
||||||
|
class TransposeLastTwoDim(torch.nn.Module):
|
||||||
|
def forward(self, x):
|
||||||
|
return torch.transpose(x, -1, -2)
|
||||||
|
|
||||||
|
|
||||||
|
class SampleRateConditionLayer(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_dim: int,
|
||||||
|
sr_bin_buckets: int = None,
|
||||||
|
cond_type: str = "scale_bias",
|
||||||
|
cond_dim: int = 128,
|
||||||
|
out_layer: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.cond_type, out_layer_in_dim = cond_type, input_dim
|
||||||
|
|
||||||
|
if cond_type == "scale_bias":
|
||||||
|
self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||||
|
self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||||
|
nn.init.ones_(self.scale_embed.weight)
|
||||||
|
nn.init.zeros_(self.bias_embed.weight)
|
||||||
|
elif cond_type == "scale_bias_init":
|
||||||
|
self.scale_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||||
|
self.bias_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||||
|
nn.init.normal_(self.scale_embed.weight, mean=1)
|
||||||
|
nn.init.normal_(self.bias_embed.weight)
|
||||||
|
elif cond_type == "add":
|
||||||
|
self.cond_embed = nn.Embedding(sr_bin_buckets, input_dim)
|
||||||
|
nn.init.normal_(self.cond_embed.weight)
|
||||||
|
elif cond_type == "concat":
|
||||||
|
self.cond_embed = nn.Embedding(sr_bin_buckets, cond_dim)
|
||||||
|
assert out_layer, "out_layer must be True for concat cond_type"
|
||||||
|
out_layer_in_dim = input_dim + cond_dim
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid cond_type: {cond_type}")
|
||||||
|
|
||||||
|
if out_layer:
|
||||||
|
self.out_layer = nn.Sequential(
|
||||||
|
Snake1d(out_layer_in_dim),
|
||||||
|
WNCausalConv1d(out_layer_in_dim, input_dim, kernel_size=1),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.out_layer = nn.Identity()
|
||||||
|
|
||||||
|
def forward(self, x, sr_cond):
|
||||||
|
if self.cond_type == "scale_bias" or self.cond_type == "scale_bias_init":
|
||||||
|
x = x * self.scale_embed(sr_cond).unsqueeze(-1) + self.bias_embed(sr_cond).unsqueeze(-1)
|
||||||
|
elif self.cond_type == "add":
|
||||||
|
x = x + self.cond_embed(sr_cond).unsqueeze(-1)
|
||||||
|
elif self.cond_type == "concat":
|
||||||
|
x = torch.cat([x, self.cond_embed(sr_cond).unsqueeze(-1).repeat(1, 1, x.shape[-1])], dim=1)
|
||||||
|
|
||||||
|
return self.out_layer(x)
|
||||||
|
|
||||||
|
|
||||||
|
class CausalDecoder(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
input_channel,
|
||||||
|
channels,
|
||||||
|
rates,
|
||||||
|
depthwise: bool = False,
|
||||||
|
d_out: int = 1,
|
||||||
|
use_noise_block: bool = False,
|
||||||
|
sr_bin_boundaries: List[int] = None,
|
||||||
|
cond_type: str = "scale_bias",
|
||||||
|
cond_dim: int = 128,
|
||||||
|
cond_out_layer: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# Add first conv layer
|
||||||
|
if depthwise:
|
||||||
|
layers = [
|
||||||
|
WNCausalConv1d(input_channel, input_channel, kernel_size=7, padding=3, groups=input_channel),
|
||||||
|
WNCausalConv1d(input_channel, channels, kernel_size=1),
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
layers = [WNCausalConv1d(input_channel, channels, kernel_size=7, padding=3)]
|
||||||
|
|
||||||
|
# Add upsampling + MRF blocks
|
||||||
|
for i, stride in enumerate(rates):
|
||||||
|
input_dim = channels // 2**i
|
||||||
|
output_dim = channels // 2 ** (i + 1)
|
||||||
|
groups = output_dim if depthwise else 1
|
||||||
|
layers += [
|
||||||
|
CausalDecoderBlock(
|
||||||
|
input_dim,
|
||||||
|
output_dim,
|
||||||
|
stride,
|
||||||
|
groups=groups,
|
||||||
|
use_noise_block=use_noise_block,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Add final conv layer
|
||||||
|
layers += [
|
||||||
|
Snake1d(output_dim),
|
||||||
|
WNCausalConv1d(output_dim, d_out, kernel_size=7, padding=3),
|
||||||
|
nn.Tanh(),
|
||||||
|
]
|
||||||
|
|
||||||
|
if sr_bin_boundaries is None:
|
||||||
|
self.model = nn.Sequential(*layers)
|
||||||
|
self.sr_bin_boundaries = None
|
||||||
|
else:
|
||||||
|
self.model = nn.ModuleList(layers)
|
||||||
|
|
||||||
|
self.register_buffer("sr_bin_boundaries", torch.tensor(sr_bin_boundaries, dtype=torch.int32))
|
||||||
|
self.sr_bin_buckets = len(sr_bin_boundaries) + 1
|
||||||
|
|
||||||
|
cond_layers = []
|
||||||
|
for layer in self.model:
|
||||||
|
if layer.__class__.__name__ == "CausalDecoderBlock":
|
||||||
|
cond_layers.append(
|
||||||
|
SampleRateConditionLayer(
|
||||||
|
input_dim=layer.input_channels,
|
||||||
|
sr_bin_buckets=self.sr_bin_buckets,
|
||||||
|
cond_type=cond_type,
|
||||||
|
cond_dim=cond_dim,
|
||||||
|
out_layer=cond_out_layer,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
cond_layers.append(None)
|
||||||
|
self.sr_cond_model = nn.ModuleList(cond_layers)
|
||||||
|
|
||||||
|
def get_sr_idx(self, sr):
|
||||||
|
return torch.bucketize(sr, self.sr_bin_boundaries)
|
||||||
|
|
||||||
|
def forward(self, x, sr_cond=None):
|
||||||
|
if self.sr_bin_boundaries is not None:
|
||||||
|
# assert sr_cond is not None
|
||||||
|
sr_cond = self.get_sr_idx(sr_cond)
|
||||||
|
|
||||||
|
for layer, sr_cond_layer in zip(self.model, self.sr_cond_model):
|
||||||
|
if sr_cond_layer is not None:
|
||||||
|
x = sr_cond_layer(x, sr_cond)
|
||||||
|
x = layer(x)
|
||||||
|
return x
|
||||||
|
else:
|
||||||
|
return self.model(x)
|
||||||
|
|
||||||
|
|
||||||
|
class AudioVAEConfig(BaseModel):
|
||||||
|
encoder_dim: int = 128
|
||||||
|
encoder_rates: List[int] = [2, 5, 8, 8]
|
||||||
|
latent_dim: int = 64
|
||||||
|
decoder_dim: int = 2048
|
||||||
|
decoder_rates: List[int] = [8, 6, 5, 2, 2, 2]
|
||||||
|
depthwise: bool = True
|
||||||
|
sample_rate: int = 16000
|
||||||
|
out_sample_rate: int = 48000
|
||||||
|
use_noise_block: bool = False
|
||||||
|
sr_bin_boundaries: Optional[List[int]] = [20000, 30000, 40000]
|
||||||
|
cond_type: str = "scale_bias"
|
||||||
|
cond_dim: int = 128
|
||||||
|
cond_out_layer: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class AudioVAE(nn.Module):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: AudioVAEConfig = None,
|
||||||
|
):
|
||||||
|
# 如果没有传入config,使用默认配置
|
||||||
|
if config is None:
|
||||||
|
config = AudioVAEConfig()
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
encoder_dim = config.encoder_dim
|
||||||
|
encoder_rates = config.encoder_rates
|
||||||
|
latent_dim = config.latent_dim
|
||||||
|
decoder_dim = config.decoder_dim
|
||||||
|
decoder_rates = config.decoder_rates
|
||||||
|
depthwise = config.depthwise
|
||||||
|
sample_rate = config.sample_rate
|
||||||
|
out_sample_rate = config.out_sample_rate
|
||||||
|
use_noise_block = config.use_noise_block
|
||||||
|
sr_bin_boundaries = config.sr_bin_boundaries
|
||||||
|
cond_type = config.cond_type
|
||||||
|
cond_dim = config.cond_dim
|
||||||
|
cond_out_layer = config.cond_out_layer
|
||||||
|
|
||||||
|
self.encoder_dim = encoder_dim
|
||||||
|
self.encoder_rates = encoder_rates
|
||||||
|
self.decoder_dim = decoder_dim
|
||||||
|
self.decoder_rates = decoder_rates
|
||||||
|
self.depthwise = depthwise
|
||||||
|
|
||||||
|
self.use_noise_block = use_noise_block
|
||||||
|
|
||||||
|
if latent_dim is None:
|
||||||
|
latent_dim = encoder_dim * (2 ** len(encoder_rates))
|
||||||
|
|
||||||
|
self.latent_dim = latent_dim
|
||||||
|
self.hop_length = np.prod(encoder_rates)
|
||||||
|
self.encoder = CausalEncoder(
|
||||||
|
encoder_dim,
|
||||||
|
latent_dim,
|
||||||
|
encoder_rates,
|
||||||
|
depthwise=depthwise,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.decoder = CausalDecoder(
|
||||||
|
latent_dim,
|
||||||
|
decoder_dim,
|
||||||
|
decoder_rates,
|
||||||
|
depthwise=depthwise,
|
||||||
|
use_noise_block=use_noise_block,
|
||||||
|
sr_bin_boundaries=sr_bin_boundaries,
|
||||||
|
cond_type=cond_type,
|
||||||
|
cond_dim=cond_dim,
|
||||||
|
cond_out_layer=cond_out_layer,
|
||||||
|
)
|
||||||
|
self.sample_rate = sample_rate
|
||||||
|
self.out_sample_rate = out_sample_rate
|
||||||
|
self.sr_bin_boundaries = sr_bin_boundaries
|
||||||
|
self.chunk_size = math.prod(encoder_rates)
|
||||||
|
|
||||||
|
def preprocess(self, audio_data, sample_rate):
|
||||||
|
if sample_rate is None:
|
||||||
|
sample_rate = self.sample_rate
|
||||||
|
assert sample_rate == self.sample_rate
|
||||||
|
pad_to = self.hop_length
|
||||||
|
length = audio_data.shape[-1]
|
||||||
|
right_pad = math.ceil(length / pad_to) * pad_to - length
|
||||||
|
audio_data = nn.functional.pad(audio_data, (0, right_pad))
|
||||||
|
|
||||||
|
return audio_data
|
||||||
|
|
||||||
|
def decode(self, z: torch.Tensor, sr_cond: torch.Tensor = None):
|
||||||
|
"""Decode given latent codes and return audio data
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
z : Tensor[B x D x T]
|
||||||
|
Quantized continuous representation of input
|
||||||
|
length : int, optional
|
||||||
|
Number of samples in output audio, by default None
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
dict
|
||||||
|
A dictionary with the following keys:
|
||||||
|
"audio" : Tensor[B x 1 x length]
|
||||||
|
Decoded audio data.
|
||||||
|
"""
|
||||||
|
if self.sr_bin_boundaries is not None:
|
||||||
|
# use default output sample rate
|
||||||
|
if sr_cond is None:
|
||||||
|
sr_cond = torch.tensor([self.out_sample_rate], device=z.device, dtype=torch.int32)
|
||||||
|
return self.decoder(z, sr_cond)
|
||||||
|
|
||||||
|
def encode(self, audio_data: torch.Tensor, sample_rate: int):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
audio_data: Tensor[B x 1 x T]
|
||||||
|
sample_rate: int
|
||||||
|
Returns:
|
||||||
|
z: Tensor[B x D x T]
|
||||||
|
"""
|
||||||
|
if audio_data.ndim == 2:
|
||||||
|
audio_data = audio_data.unsqueeze(1)
|
||||||
|
|
||||||
|
audio_data = self.preprocess(audio_data, sample_rate)
|
||||||
|
return self.encoder(audio_data)["mu"]
|
||||||
@@ -128,6 +128,3 @@ def apply_lora_to_named_linear_modules(
|
|||||||
dropout=dropout,
|
dropout=dropout,
|
||||||
)
|
)
|
||||||
setattr(parent, short_name, lora_layer)
|
setattr(parent, short_name, lora_layer)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,2 +1,3 @@
|
|||||||
from .unified_cfm import UnifiedCFM, CfmConfig
|
from .unified_cfm import UnifiedCFM, CfmConfig
|
||||||
from .local_dit import VoxCPMLocDiT
|
from .local_dit import VoxCPMLocDiT
|
||||||
|
from .local_dit_v2 import VoxCPMLocDiT as VoxCPMLocDiTV2
|
||||||
|
|||||||
@@ -0,0 +1,116 @@
|
|||||||
|
import torch
|
||||||
|
from ..minicpm4 import MiniCPMModel, MiniCPM4Config
|
||||||
|
import torch.nn as nn
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class SinusoidalPosEmb(torch.nn.Module):
|
||||||
|
def __init__(self, dim):
|
||||||
|
super().__init__()
|
||||||
|
self.dim = dim
|
||||||
|
assert self.dim % 2 == 0, "SinusoidalPosEmb requires dim to be even"
|
||||||
|
|
||||||
|
def forward(self, x, scale=1000):
|
||||||
|
if x.ndim < 1:
|
||||||
|
x = x.unsqueeze(0)
|
||||||
|
device = x.device
|
||||||
|
half_dim = self.dim // 2
|
||||||
|
emb = math.log(10000) / (half_dim - 1)
|
||||||
|
emb = torch.exp(torch.arange(half_dim, dtype=x.dtype, device=device) * -emb)
|
||||||
|
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
||||||
|
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||||
|
return emb
|
||||||
|
|
||||||
|
|
||||||
|
class TimestepEmbedding(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
time_embed_dim: int,
|
||||||
|
out_dim: int = None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
|
||||||
|
self.act = nn.SiLU()
|
||||||
|
if out_dim is not None:
|
||||||
|
time_embed_dim_out = out_dim
|
||||||
|
else:
|
||||||
|
time_embed_dim_out = time_embed_dim
|
||||||
|
|
||||||
|
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, bias=True)
|
||||||
|
|
||||||
|
def forward(self, sample):
|
||||||
|
sample = self.linear_1(sample)
|
||||||
|
sample = self.act(sample)
|
||||||
|
sample = self.linear_2(sample)
|
||||||
|
return sample
|
||||||
|
|
||||||
|
|
||||||
|
class VoxCPMLocDiT(nn.Module):
|
||||||
|
"""
|
||||||
|
Diffusion model with a Transformer backbone.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: MiniCPM4Config,
|
||||||
|
in_channels: int = 64,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.in_channels = in_channels
|
||||||
|
self.out_channels = in_channels
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
self.in_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
|
||||||
|
self.cond_proj = nn.Linear(in_channels, config.hidden_size, bias=True)
|
||||||
|
self.out_proj = nn.Linear(config.hidden_size, self.out_channels, bias=True)
|
||||||
|
|
||||||
|
self.time_embeddings = SinusoidalPosEmb(config.hidden_size)
|
||||||
|
self.time_mlp = TimestepEmbedding(
|
||||||
|
in_channels=config.hidden_size,
|
||||||
|
time_embed_dim=config.hidden_size,
|
||||||
|
)
|
||||||
|
self.delta_time_mlp = TimestepEmbedding(
|
||||||
|
in_channels=config.hidden_size,
|
||||||
|
time_embed_dim=config.hidden_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.vocab_size == 0, "vocab_size must be 0 for local DiT"
|
||||||
|
self.decoder = MiniCPMModel(config)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
mu: torch.Tensor,
|
||||||
|
t: torch.Tensor,
|
||||||
|
cond: torch.Tensor,
|
||||||
|
dt: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Forward pass of DiT.
|
||||||
|
x: (N, C, T) tensor of inputs
|
||||||
|
mu: (N, C) tensor of hidden embedding
|
||||||
|
t: (N,) tensor of diffusion timesteps
|
||||||
|
cond: (N, C, T') tensor of prefix conditions
|
||||||
|
dt: (N,) used for mean velocity (may be supported in the future...)
|
||||||
|
"""
|
||||||
|
x = self.in_proj(x.transpose(1, 2).contiguous())
|
||||||
|
|
||||||
|
cond = self.cond_proj(cond.transpose(1, 2).contiguous())
|
||||||
|
prefix = cond.size(1)
|
||||||
|
|
||||||
|
t = self.time_embeddings(t).to(x.dtype)
|
||||||
|
t = self.time_mlp(t)
|
||||||
|
dt = self.time_embeddings(dt).to(x.dtype)
|
||||||
|
dt = self.delta_time_mlp(dt)
|
||||||
|
t = t + dt
|
||||||
|
|
||||||
|
mu = mu.view(x.size(0), -1, x.size(-1))
|
||||||
|
x = torch.cat([mu, (t).unsqueeze(1), cond, x], dim=1)
|
||||||
|
|
||||||
|
hidden, _ = self.decoder(x, is_causal=False)
|
||||||
|
hidden = hidden[:, prefix + mu.size(1) + 1 :, :]
|
||||||
|
hidden = self.out_proj(hidden)
|
||||||
|
|
||||||
|
return hidden.transpose(1, 2).contiguous()
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import List, Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -138,7 +138,9 @@ class UnifiedCFM(torch.nn.Module):
|
|||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
# Training loss
|
# Training loss
|
||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
def adaptive_loss_weighting(self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3):
|
def adaptive_loss_weighting(
|
||||||
|
self, losses: torch.Tensor, mask: torch.Tensor | None = None, p: float = 0.0, epsilon: float = 1e-3
|
||||||
|
):
|
||||||
weights = 1.0 / ((losses + epsilon).pow(p))
|
weights = 1.0 / ((losses + epsilon).pow(p))
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
weights = weights * mask
|
weights = weights * mask
|
||||||
@@ -193,8 +195,7 @@ class UnifiedCFM(torch.nn.Module):
|
|||||||
cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale
|
cond = cond + noisy_mask.view(-1, 1, 1) * torch.randn_like(cond) * self.noise_cond_scale
|
||||||
|
|
||||||
ratio_r_neq_t = (
|
ratio_r_neq_t = (
|
||||||
self.ratio_r_neq_t_range[0]
|
self.ratio_r_neq_t_range[0] + progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
|
||||||
+ progress * (self.ratio_r_neq_t_range[1] - self.ratio_r_neq_t_range[0])
|
|
||||||
if self.mean_mode
|
if self.mean_mode
|
||||||
else 0.0
|
else 0.0
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -64,10 +64,8 @@ class MiniCPMLongRoPE(nn.Module):
|
|||||||
self.long_factor = config.rope_scaling.long_factor
|
self.long_factor = config.rope_scaling.long_factor
|
||||||
self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings
|
self.original_max_position_embeddings = config.rope_scaling.original_max_position_embeddings
|
||||||
|
|
||||||
scale = (self.max_position_embeddings / self.original_max_position_embeddings)
|
scale = self.max_position_embeddings / self.original_max_position_embeddings
|
||||||
self.scaling_factor = math.sqrt(
|
self.scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings))
|
||||||
1 + math.log(scale) / math.log(self.original_max_position_embeddings)
|
|
||||||
)
|
|
||||||
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
|
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||||
|
|
||||||
@@ -76,11 +74,7 @@ class MiniCPMLongRoPE(nn.Module):
|
|||||||
self.register_buffer("cos_cached", torch.empty(0), persistent=False)
|
self.register_buffer("cos_cached", torch.empty(0), persistent=False)
|
||||||
self.register_buffer("sin_cached", torch.empty(0), persistent=False)
|
self.register_buffer("sin_cached", torch.empty(0), persistent=False)
|
||||||
|
|
||||||
self._set_cos_sin_cache(
|
self._set_cos_sin_cache(seq_len=self.max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32)
|
||||||
seq_len=self.max_position_embeddings,
|
|
||||||
device=self.inv_freq.device,
|
|
||||||
dtype=torch.float32
|
|
||||||
)
|
|
||||||
|
|
||||||
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
||||||
"""设置cos和sin缓存"""
|
"""设置cos和sin缓存"""
|
||||||
@@ -93,8 +87,7 @@ class MiniCPMLongRoPE(nn.Module):
|
|||||||
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)
|
ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=device)
|
||||||
|
|
||||||
freqs = torch.mul(
|
freqs = torch.mul(
|
||||||
torch.outer(t, 1.0 / ext_factors).to(device=device),
|
torch.outer(t, 1.0 / ext_factors).to(device=device), self.inv_freq.to(device=device).to(dtype)
|
||||||
self.inv_freq.to(device=device).to(dtype)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建embeddings
|
# 创建embeddings
|
||||||
@@ -123,7 +116,9 @@ class MiniCPMAttention(nn.Module):
|
|||||||
self.layer_idx = layer_idx
|
self.layer_idx = layer_idx
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.num_heads = config.num_attention_heads
|
self.num_heads = config.num_attention_heads
|
||||||
self.head_dim = config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
self.head_dim = (
|
||||||
|
config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
|
||||||
|
)
|
||||||
self.num_key_value_heads = config.num_key_value_heads
|
self.num_key_value_heads = config.num_key_value_heads
|
||||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||||
self.max_position_embeddings = config.max_position_embeddings
|
self.max_position_embeddings = config.max_position_embeddings
|
||||||
@@ -413,7 +408,11 @@ class MiniCPMModel(nn.Module):
|
|||||||
self.kv_cache = StaticKVCache(
|
self.kv_cache = StaticKVCache(
|
||||||
num_layers=self.config.num_hidden_layers,
|
num_layers=self.config.num_hidden_layers,
|
||||||
num_kv_heads=self.config.num_key_value_heads,
|
num_kv_heads=self.config.num_key_value_heads,
|
||||||
dim_kv_head=self.config.hidden_size // self.config.num_attention_heads if self.config.kv_channels is None else self.config.kv_channels,
|
dim_kv_head=(
|
||||||
|
self.config.hidden_size // self.config.num_attention_heads
|
||||||
|
if self.config.kv_channels is None
|
||||||
|
else self.config.kv_channels
|
||||||
|
),
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
|||||||
@@ -25,4 +25,3 @@ __all__ = [
|
|||||||
"load_audio_text_datasets",
|
"load_audio_text_datasets",
|
||||||
"build_dataloader",
|
"build_dataloader",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -47,9 +47,7 @@ class Accelerator:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
|
self.scaler = torch.amp.GradScaler("cuda") if (amp and torch.cuda.is_available()) else DummyScaler()
|
||||||
self.device_ctx = (
|
self.device_ctx = torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
|
||||||
torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
|
|
||||||
)
|
|
||||||
self._ddp_model = None # For no_sync support
|
self._ddp_model = None # For no_sync support
|
||||||
|
|
||||||
def _set_seed(self, seed: int):
|
def _set_seed(self, seed: int):
|
||||||
@@ -84,7 +82,7 @@ class Accelerator:
|
|||||||
# Model helpers
|
# Model helpers
|
||||||
# ------------------------------------------------------------------ #
|
# ------------------------------------------------------------------ #
|
||||||
def prepare_model(self, model: torch.nn.Module, **kwargs):
|
def prepare_model(self, model: torch.nn.Module, **kwargs):
|
||||||
if hasattr(model, 'device'): # make sure the matrix will be moved to the correct device
|
if hasattr(model, "device"): # make sure the matrix will be moved to the correct device
|
||||||
model.device = self.device
|
model.device = self.device
|
||||||
model = model.to(self.device)
|
model = model.to(self.device)
|
||||||
if self.world_size > 1:
|
if self.world_size > 1:
|
||||||
@@ -163,4 +161,3 @@ class Accelerator:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
|
def unwrap(model: torch.nn.Module) -> torch.nn.Module:
|
||||||
return model.module if hasattr(model, "module") else model
|
return model.module if hasattr(model, "module") else model
|
||||||
|
|
||||||
|
|||||||
@@ -36,5 +36,3 @@ def parse_args_with_config(config_path: str | Path | None = None):
|
|||||||
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
|
yaml_args = argbind.parse_args(yaml_args=yaml_args, argv=[])
|
||||||
cli_args.update(yaml_args)
|
cli_args.update(yaml_args)
|
||||||
return cli_args
|
return cli_args
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import argbind
|
import argbind
|
||||||
@@ -11,7 +10,6 @@ from ..model.voxcpm import VoxCPMConfig
|
|||||||
from ..modules.audiovae import AudioVAE
|
from ..modules.audiovae import AudioVAE
|
||||||
from .packers import AudioFeatureProcessingPacker
|
from .packers import AudioFeatureProcessingPacker
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_TEXT_COLUMN = "text"
|
DEFAULT_TEXT_COLUMN = "text"
|
||||||
DEFAULT_AUDIO_COLUMN = "audio"
|
DEFAULT_AUDIO_COLUMN = "audio"
|
||||||
DEFAULT_ID_COLUMN = "dataset_id"
|
DEFAULT_ID_COLUMN = "dataset_id"
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
|
from typing import Dict, List
|
||||||
from typing import Dict, List, Tuple
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -147,31 +146,26 @@ class AudioFeatureProcessingPacker:
|
|||||||
|
|
||||||
def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
|
def pad_1d(x: torch.Tensor, pad_value: int = 0) -> torch.Tensor:
|
||||||
if x.size(0) >= max_len:
|
if x.size(0) >= max_len:
|
||||||
return x[: max_len]
|
return x[:max_len]
|
||||||
pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
|
pad = torch.full((max_len - x.size(0),), pad_value, dtype=x.dtype, device=x.device)
|
||||||
return torch.cat([x, pad], dim=0)
|
return torch.cat([x, pad], dim=0)
|
||||||
|
|
||||||
def pad_3d(x: torch.Tensor) -> torch.Tensor:
|
def pad_3d(x: torch.Tensor) -> torch.Tensor:
|
||||||
# x: [T, P, D]
|
# x: [T, P, D]
|
||||||
if x.size(0) >= max_len:
|
if x.size(0) >= max_len:
|
||||||
return x[: max_len]
|
return x[:max_len]
|
||||||
pad = torch.zeros(
|
pad = torch.zeros((max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device)
|
||||||
(max_len - x.size(0),) + x.shape[1:], dtype=x.dtype, device=x.device
|
|
||||||
)
|
|
||||||
return torch.cat([x, pad], dim=0)
|
return torch.cat([x, pad], dim=0)
|
||||||
|
|
||||||
if lengths:
|
if lengths:
|
||||||
text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
|
text_tokens_batch = torch.stack([pad_1d(t, pad_value=0) for t in text_tokens_list], dim=0)
|
||||||
text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
|
text_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in text_mask_list], dim=0)
|
||||||
audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
|
audio_feats_batch = torch.stack([pad_3d(f) for f in audio_feats_list], dim=0)
|
||||||
audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
|
audio_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in audio_mask_list], dim=0)
|
||||||
loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
|
loss_mask_batch = torch.stack([pad_1d(m, pad_value=0) for m in loss_mask_list], dim=0)
|
||||||
labels_batch = torch.stack([pad_1d(l, pad_value=0) for l in labels_list], dim=0)
|
labels_batch = torch.stack([pad_1d(lbl, pad_value=0) for lbl in labels_list], dim=0)
|
||||||
audio_task_ids_batch = torch.stack(
|
audio_task_ids_batch = torch.stack([pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0)
|
||||||
[pad_1d(t, pad_value=0) for t in audio_task_ids_list], dim=0
|
audio_dataset_ids_batch = torch.stack([pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0)
|
||||||
)
|
|
||||||
audio_dataset_ids_batch = torch.stack(
|
|
||||||
[pad_1d(d, pad_value=0) for d in audio_dataset_ids_list], dim=0
|
|
||||||
)
|
|
||||||
|
|
||||||
# Position ids: [B, T], simple 0..L_i-1 then padded with 0
|
# Position ids: [B, T], simple 0..L_i-1 then padded with 0
|
||||||
position_ids_list = []
|
position_ids_list = []
|
||||||
@@ -265,13 +259,27 @@ class AudioFeatureProcessingPacker:
|
|||||||
)
|
)
|
||||||
audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)
|
audio_feat_info = torch.cat([audio_pad_feat, audio_feat_info, audio_pad_feat[0:1, ...]], dim=0)
|
||||||
|
|
||||||
text_mask = torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)]).type(torch.int32).to(
|
text_mask = (
|
||||||
text_token.device
|
torch.cat([torch.ones(text_length), torch.zeros(audio_length), torch.ones(1)])
|
||||||
|
.type(torch.int32)
|
||||||
|
.to(text_token.device)
|
||||||
|
)
|
||||||
|
audio_mask = (
|
||||||
|
torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)])
|
||||||
|
.type(torch.int32)
|
||||||
|
.to(text_token.device)
|
||||||
|
)
|
||||||
|
loss_mask = (
|
||||||
|
torch.cat(
|
||||||
|
[
|
||||||
|
torch.zeros(text_length),
|
||||||
|
torch.zeros(audio_length) if is_prompt else torch.ones(audio_length),
|
||||||
|
torch.zeros(1),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
.type(torch.int32)
|
||||||
|
.to(text_token.device)
|
||||||
)
|
)
|
||||||
audio_mask = torch.cat([torch.zeros(text_length), torch.ones(audio_length), torch.zeros(1)]).type(
|
|
||||||
torch.int32
|
|
||||||
).to(text_token.device)
|
|
||||||
loss_mask = torch.cat([torch.zeros(text_length), torch.zeros(audio_length) if is_prompt else torch.ones(audio_length), torch.zeros(1)]).type(torch.int32).to(text_token.device)
|
|
||||||
|
|
||||||
labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
|
labels = torch.zeros(text_length + audio_length + 1).type(torch.int32).to(text_token.device)
|
||||||
labels[-2] = 1
|
labels[-2] = 1
|
||||||
@@ -286,4 +294,3 @@ class AudioFeatureProcessingPacker:
|
|||||||
audio_duration,
|
audio_duration,
|
||||||
text_token_count,
|
text_token_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -18,4 +18,3 @@ class TrainingState:
|
|||||||
val_loader: object
|
val_loader: object
|
||||||
tracker: object
|
tracker: object
|
||||||
batch_processor: object
|
batch_processor: object
|
||||||
|
|
||||||
|
|||||||
@@ -76,4 +76,3 @@ class TrainingTracker:
|
|||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def live(self):
|
def live(self):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,10 @@
|
|||||||
import re
|
import re
|
||||||
import regex
|
import regex
|
||||||
import inflect
|
import inflect
|
||||||
from functools import partial
|
|
||||||
from wetext import Normalizer
|
from wetext import Normalizer
|
||||||
|
|
||||||
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
|
chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]+")
|
||||||
|
|
||||||
|
|
||||||
# whether contain chinese character
|
# whether contain chinese character
|
||||||
def contains_chinese(text):
|
def contains_chinese(text):
|
||||||
@@ -14,19 +14,19 @@ def contains_chinese(text):
|
|||||||
|
|
||||||
# replace special symbol
|
# replace special symbol
|
||||||
def replace_corner_mark(text):
|
def replace_corner_mark(text):
|
||||||
text = text.replace('²', '平方')
|
text = text.replace("²", "平方")
|
||||||
text = text.replace('³', '立方')
|
text = text.replace("³", "立方")
|
||||||
text = text.replace('√', '根号')
|
text = text.replace("√", "根号")
|
||||||
text = text.replace('≈', '约等于')
|
text = text.replace("≈", "约等于")
|
||||||
text = text.replace('<', '小于')
|
text = text.replace("<", "小于")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
# remove meaningless symbol
|
# remove meaningless symbol
|
||||||
def remove_bracket(text):
|
def remove_bracket(text):
|
||||||
text = text.replace('(', ' ').replace(')', ' ')
|
text = text.replace("(", " ").replace(")", " ")
|
||||||
text = text.replace('【', ' ').replace('】', ' ')
|
text = text.replace("【", " ").replace("】", " ")
|
||||||
text = text.replace('`', '').replace('`', '')
|
text = text.replace("`", "").replace("`", "")
|
||||||
text = text.replace("——", " ")
|
text = text.replace("——", " ")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
@@ -38,7 +38,7 @@ def spell_out_number(text: str, inflect_parser):
|
|||||||
for i, c in enumerate(text):
|
for i, c in enumerate(text):
|
||||||
if not c.isdigit():
|
if not c.isdigit():
|
||||||
if st is not None:
|
if st is not None:
|
||||||
num_str = inflect_parser.number_to_words(text[st: i])
|
num_str = inflect_parser.number_to_words(text[st:i])
|
||||||
new_text.append(num_str)
|
new_text.append(num_str)
|
||||||
st = None
|
st = None
|
||||||
new_text.append(c)
|
new_text.append(c)
|
||||||
@@ -48,7 +48,7 @@ def spell_out_number(text: str, inflect_parser):
|
|||||||
if st is not None and st < len(text):
|
if st is not None and st < len(text):
|
||||||
num_str = inflect_parser.number_to_words(text[st:])
|
num_str = inflect_parser.number_to_words(text[st:])
|
||||||
new_text.append(num_str)
|
new_text.append(num_str)
|
||||||
return ''.join(new_text)
|
return "".join(new_text)
|
||||||
|
|
||||||
|
|
||||||
# split paragrah logic:
|
# split paragrah logic:
|
||||||
@@ -69,18 +69,18 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
|
|||||||
return len(tokenize(_text)) < merge_len
|
return len(tokenize(_text)) < merge_len
|
||||||
|
|
||||||
if lang == "zh":
|
if lang == "zh":
|
||||||
pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
|
pounc = ["。", "?", "!", ";", ":", "、", ".", "?", "!", ";"]
|
||||||
else:
|
else:
|
||||||
pounc = ['.', '?', '!', ';', ':']
|
pounc = [".", "?", "!", ";", ":"]
|
||||||
if comma_split:
|
if comma_split:
|
||||||
pounc.extend([',', ','])
|
pounc.extend([",", ","])
|
||||||
st = 0
|
st = 0
|
||||||
utts = []
|
utts = []
|
||||||
for i, c in enumerate(text):
|
for i, c in enumerate(text):
|
||||||
if c in pounc:
|
if c in pounc:
|
||||||
if len(text[st: i]) > 0:
|
if len(text[st:i]) > 0:
|
||||||
utts.append(text[st: i] + c)
|
utts.append(text[st:i] + c)
|
||||||
if i + 1 < len(text) and text[i + 1] in ['"', '”']:
|
if i + 1 < len(text) and text[i + 1] in ['"', "”"]:
|
||||||
tmp = utts.pop(-1)
|
tmp = utts.pop(-1)
|
||||||
utts.append(tmp + text[i + 1])
|
utts.append(tmp + text[i + 1])
|
||||||
st = i + 2
|
st = i + 2
|
||||||
@@ -88,9 +88,9 @@ def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=
|
|||||||
st = i + 1
|
st = i + 1
|
||||||
if len(utts) == 0:
|
if len(utts) == 0:
|
||||||
if lang == "zh":
|
if lang == "zh":
|
||||||
utts.append(text + '。')
|
utts.append(text + "。")
|
||||||
else:
|
else:
|
||||||
utts.append(text + '.')
|
utts.append(text + ".")
|
||||||
final_utts = []
|
final_utts = []
|
||||||
cur_utt = ""
|
cur_utt = ""
|
||||||
for utt in utts:
|
for utt in utts:
|
||||||
@@ -112,13 +112,13 @@ def replace_blank(text: str):
|
|||||||
out_str = []
|
out_str = []
|
||||||
for i, c in enumerate(text):
|
for i, c in enumerate(text):
|
||||||
if c == " ":
|
if c == " ":
|
||||||
if ((text[i + 1].isascii() and text[i + 1] != " ") and
|
if (text[i + 1].isascii() and text[i + 1] != " ") and (text[i - 1].isascii() and text[i - 1] != " "):
|
||||||
(text[i - 1].isascii() and text[i - 1] != " ")):
|
|
||||||
out_str.append(c)
|
out_str.append(c)
|
||||||
else:
|
else:
|
||||||
out_str.append(c)
|
out_str.append(c)
|
||||||
return "".join(out_str)
|
return "".join(out_str)
|
||||||
|
|
||||||
|
|
||||||
def clean_markdown(md_text: str) -> str:
|
def clean_markdown(md_text: str) -> str:
|
||||||
# 去除代码块 ``` ```(包括多行)
|
# 去除代码块 ``` ```(包括多行)
|
||||||
md_text = re.sub(r"```.*?```", "", md_text, flags=re.DOTALL)
|
md_text = re.sub(r"```.*?```", "", md_text, flags=re.DOTALL)
|
||||||
@@ -133,7 +133,7 @@ def clean_markdown(md_text: str) -> str:
|
|||||||
md_text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", md_text)
|
md_text = re.sub(r"\[([^\]]+)\]\([^)]+\)", r"\1", md_text)
|
||||||
|
|
||||||
# 替换无序列表符号
|
# 替换无序列表符号
|
||||||
md_text = re.sub(r'^(\s*)-\s+', r'\1', md_text, flags=re.MULTILINE)
|
md_text = re.sub(r"^(\s*)-\s+", r"\1", md_text, flags=re.MULTILINE)
|
||||||
|
|
||||||
# 去除HTML标签
|
# 去除HTML标签
|
||||||
md_text = re.sub(r"<[^>]+>", "", md_text)
|
md_text = re.sub(r"<[^>]+>", "", md_text)
|
||||||
@@ -152,13 +152,14 @@ def clean_text(text):
|
|||||||
# 去除 Markdown 语法
|
# 去除 Markdown 语法
|
||||||
text = clean_markdown(text)
|
text = clean_markdown(text)
|
||||||
# 匹配并移除表情符号
|
# 匹配并移除表情符号
|
||||||
text = regex.compile(r'\p{Emoji_Presentation}|\p{Emoji}\uFE0F', flags=regex.UNICODE).sub("",text)
|
text = regex.compile(r"\p{Emoji_Presentation}|\p{Emoji}\uFE0F", flags=regex.UNICODE).sub("", text)
|
||||||
# 去除换行符
|
# 去除换行符
|
||||||
text = text.replace("\n", " ")
|
text = text.replace("\n", " ")
|
||||||
text = text.replace("\t", " ")
|
text = text.replace("\t", " ")
|
||||||
text = text.replace('"', "\“")
|
text = text.replace("“", '"').replace("”", '"')
|
||||||
return text
|
return text
|
||||||
|
|
||||||
|
|
||||||
class TextNormalizer:
|
class TextNormalizer:
|
||||||
def __init__(self, tokenizer=None):
|
def __init__(self, tokenizer=None):
|
||||||
self.tokenizer = tokenizer
|
self.tokenizer = tokenizer
|
||||||
@@ -171,9 +172,11 @@ class TextNormalizer:
|
|||||||
lang = "zh" if contains_chinese(text) else "en"
|
lang = "zh" if contains_chinese(text) else "en"
|
||||||
text = clean_text(text)
|
text = clean_text(text)
|
||||||
if lang == "zh":
|
if lang == "zh":
|
||||||
text = text.replace("=", "等于") # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
|
text = text.replace(
|
||||||
if re.search(r'([\d$%^*_+≥≤≠×÷?=])', text): # 避免 英文连字符被错误正则为减
|
"=", "等于"
|
||||||
text = re.sub(r'(?<=[a-zA-Z0-9])-(?=\d)', ' - ', text) # 修复 x-2 被正则为 x负2
|
) # 修复 ”550 + 320 等于 870 千卡。“ 被错误正则为 ”五百五十加三百二十等于八七十千卡.“
|
||||||
|
if re.search(r"([\d$%^*_+≥≤≠×÷?=])", text): # 避免 英文连字符被错误正则为减
|
||||||
|
text = re.sub(r"(?<=[a-zA-Z0-9])-(?=\d)", " - ", text) # 修复 x-2 被正则为 x负2
|
||||||
text = self.zh_tn_model.normalize(text)
|
text = self.zh_tn_model.normalize(text)
|
||||||
text = replace_blank(text)
|
text = replace_blank(text)
|
||||||
text = replace_corner_mark(text)
|
text = replace_corner_mark(text)
|
||||||
|
|||||||
@@ -7,15 +7,15 @@ Related dependencies are imported only when denoising functionality is needed.
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from typing import Optional, Union
|
from typing import Optional
|
||||||
import torchaudio
|
import torchaudio
|
||||||
import torch
|
|
||||||
from modelscope.pipelines import pipeline
|
from modelscope.pipelines import pipeline
|
||||||
from modelscope.utils.constant import Tasks
|
from modelscope.utils.constant import Tasks
|
||||||
|
|
||||||
|
|
||||||
class ZipEnhancer:
|
class ZipEnhancer:
|
||||||
"""ZipEnhancer Audio Denoising Enhancer"""
|
"""ZipEnhancer Audio Denoising Enhancer"""
|
||||||
|
|
||||||
def __init__(self, model_path: str = "iic/speech_zipenhancer_ans_multiloss_16k_base"):
|
def __init__(self, model_path: str = "iic/speech_zipenhancer_ans_multiloss_16k_base"):
|
||||||
"""
|
"""
|
||||||
Initialize ZipEnhancer
|
Initialize ZipEnhancer
|
||||||
@@ -23,10 +23,7 @@ class ZipEnhancer:
|
|||||||
model_path: ModelScope model path or local path
|
model_path: ModelScope model path or local path
|
||||||
"""
|
"""
|
||||||
self.model_path = model_path
|
self.model_path = model_path
|
||||||
self._pipeline = pipeline(
|
self._pipeline = pipeline(Tasks.acoustic_noise_suppression, model=self.model_path)
|
||||||
Tasks.acoustic_noise_suppression,
|
|
||||||
model=self.model_path
|
|
||||||
)
|
|
||||||
|
|
||||||
def _normalize_loudness(self, wav_path: str):
|
def _normalize_loudness(self, wav_path: str):
|
||||||
"""
|
"""
|
||||||
@@ -37,11 +34,10 @@ class ZipEnhancer:
|
|||||||
"""
|
"""
|
||||||
audio, sr = torchaudio.load(wav_path)
|
audio, sr = torchaudio.load(wav_path)
|
||||||
loudness = torchaudio.functional.loudness(audio, sr)
|
loudness = torchaudio.functional.loudness(audio, sr)
|
||||||
normalized_audio = torchaudio.functional.gain(audio, -20-loudness)
|
normalized_audio = torchaudio.functional.gain(audio, -20 - loudness)
|
||||||
torchaudio.save(wav_path, normalized_audio, sr)
|
torchaudio.save(wav_path, normalized_audio, sr)
|
||||||
|
|
||||||
def enhance(self, input_path: str, output_path: Optional[str] = None,
|
def enhance(self, input_path: str, output_path: Optional[str] = None, normalize_loudness: bool = True) -> str:
|
||||||
normalize_loudness: bool = True) -> str:
|
|
||||||
"""
|
"""
|
||||||
Audio denoising enhancement
|
Audio denoising enhancement
|
||||||
Args:
|
Args:
|
||||||
@@ -57,7 +53,7 @@ class ZipEnhancer:
|
|||||||
raise FileNotFoundError(f"Input audio file does not exist: {input_path}")
|
raise FileNotFoundError(f"Input audio file does not exist: {input_path}")
|
||||||
# Create temporary file if no output path is specified
|
# Create temporary file if no output path is specified
|
||||||
if output_path is None:
|
if output_path is None:
|
||||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
|
||||||
output_path = tmp_file.name
|
output_path = tmp_file.name
|
||||||
try:
|
try:
|
||||||
# Perform denoising processing
|
# Perform denoising processing
|
||||||
|
|||||||
Reference in New Issue
Block a user