e4e049624c
Support optional ref_audio samples in finetuning and make runtime device selection explicit while keeping auto fallback behavior consistent. Also ignore the local app override file to avoid accidental commits. Made-with: Cursor
544 lines
13 KiB
Python
544 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import importlib.util
|
|
import sys
|
|
import types
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
ROOT = Path(__file__).resolve().parents[1]
|
|
CLI_PATH = ROOT / "src" / "voxcpm" / "cli.py"
|
|
V1_MODEL_PATH = ROOT / "models" / "openbmb__VoxCPM1.5"
|
|
V2_MODEL_PATH = ROOT / "models" / "VoxCPM2-1B-newaudiovae-6hz-nope-sft"
|
|
|
|
|
|
pkg = types.ModuleType("voxcpm")
|
|
pkg.__path__ = [str(ROOT / "src" / "voxcpm")]
|
|
sys.modules.setdefault("voxcpm", pkg)
|
|
|
|
core_stub = types.ModuleType("voxcpm.core")
|
|
|
|
|
|
class StubVoxCPM:
|
|
pass
|
|
|
|
|
|
core_stub.VoxCPM = StubVoxCPM
|
|
sys.modules["voxcpm.core"] = core_stub
|
|
|
|
spec = importlib.util.spec_from_file_location("voxcpm.cli", CLI_PATH)
|
|
cli = importlib.util.module_from_spec(spec)
|
|
sys.modules["voxcpm.cli"] = cli
|
|
assert spec.loader is not None
|
|
spec.loader.exec_module(cli)
|
|
|
|
|
|
class DummyTTSModel:
|
|
sample_rate = 16000
|
|
|
|
|
|
class DummyModel:
|
|
def __init__(self):
|
|
self.tts_model = DummyTTSModel()
|
|
self.calls = []
|
|
|
|
def generate(self, **kwargs):
|
|
self.calls.append(kwargs)
|
|
return np.zeros(160, dtype=np.float32)
|
|
|
|
|
|
def run_main(monkeypatch, argv):
|
|
monkeypatch.setattr(sys, "argv", ["voxcpm", *argv])
|
|
cli.main()
|
|
|
|
|
|
def test_parser_defaults_to_voxcpm2():
|
|
parser = cli._build_parser()
|
|
args = parser.parse_args(["design", "--text", "hello", "--output", "out.wav"])
|
|
assert args.hf_model_id == "openbmb/VoxCPM2"
|
|
assert args.device == "auto"
|
|
assert args.no_optimize is False
|
|
|
|
|
|
def test_load_model_respects_no_optimize_for_local_model(monkeypatch):
|
|
calls = {}
|
|
|
|
class FakeVoxCPM:
|
|
def __init__(self, **kwargs):
|
|
calls["kwargs"] = kwargs
|
|
self.tts_model = DummyTTSModel()
|
|
|
|
monkeypatch.setattr(cli, "VoxCPM", FakeVoxCPM)
|
|
args = cli._build_parser().parse_args(
|
|
[
|
|
"design",
|
|
"--text",
|
|
"hello",
|
|
"--output",
|
|
"out.wav",
|
|
"--model-path",
|
|
str(V2_MODEL_PATH),
|
|
"--no-optimize",
|
|
]
|
|
)
|
|
|
|
cli.load_model(args)
|
|
|
|
assert calls["kwargs"]["device"] == "auto"
|
|
assert calls["kwargs"]["optimize"] is False
|
|
|
|
|
|
def test_load_model_defaults_optimize_for_hf(monkeypatch):
|
|
calls = {}
|
|
|
|
class FakeVoxCPM:
|
|
@classmethod
|
|
def from_pretrained(cls, **kwargs):
|
|
calls["kwargs"] = kwargs
|
|
return DummyModel()
|
|
|
|
monkeypatch.setattr(cli, "VoxCPM", FakeVoxCPM)
|
|
args = cli._build_parser().parse_args(
|
|
[
|
|
"design",
|
|
"--text",
|
|
"hello",
|
|
"--output",
|
|
"out.wav",
|
|
]
|
|
)
|
|
|
|
cli.load_model(args)
|
|
|
|
assert calls["kwargs"]["device"] == "auto"
|
|
assert calls["kwargs"]["optimize"] is True
|
|
|
|
|
|
def test_load_model_respects_no_optimize_for_hf(monkeypatch):
|
|
calls = {}
|
|
|
|
class FakeVoxCPM:
|
|
@classmethod
|
|
def from_pretrained(cls, **kwargs):
|
|
calls["kwargs"] = kwargs
|
|
return DummyModel()
|
|
|
|
monkeypatch.setattr(cli, "VoxCPM", FakeVoxCPM)
|
|
args = cli._build_parser().parse_args(
|
|
[
|
|
"design",
|
|
"--text",
|
|
"hello",
|
|
"--output",
|
|
"out.wav",
|
|
"--no-optimize",
|
|
]
|
|
)
|
|
|
|
cli.load_model(args)
|
|
|
|
assert calls["kwargs"]["device"] == "auto"
|
|
assert calls["kwargs"]["optimize"] is False
|
|
|
|
|
|
def test_load_model_passes_explicit_device_to_hf(monkeypatch):
|
|
calls = {}
|
|
|
|
class FakeVoxCPM:
|
|
@classmethod
|
|
def from_pretrained(cls, **kwargs):
|
|
calls["kwargs"] = kwargs
|
|
return DummyModel()
|
|
|
|
monkeypatch.setattr(cli, "VoxCPM", FakeVoxCPM)
|
|
args = cli._build_parser().parse_args(
|
|
[
|
|
"design",
|
|
"--text",
|
|
"hello",
|
|
"--output",
|
|
"out.wav",
|
|
"--device",
|
|
"mps",
|
|
]
|
|
)
|
|
|
|
cli.load_model(args)
|
|
|
|
assert calls["kwargs"]["device"] == "mps"
|
|
|
|
|
|
def test_design_subcommand_applies_control(monkeypatch, tmp_path):
|
|
dummy_model = DummyModel()
|
|
monkeypatch.setattr(cli, "load_model", lambda args: dummy_model)
|
|
monkeypatch.setattr(cli.sf, "write", lambda *args, **kwargs: None)
|
|
|
|
run_main(
|
|
monkeypatch,
|
|
[
|
|
"design",
|
|
"--text",
|
|
"hello",
|
|
"--control",
|
|
"warm female voice",
|
|
"--output",
|
|
str(tmp_path / "out.wav"),
|
|
],
|
|
)
|
|
|
|
assert dummy_model.calls[0]["text"] == "(warm female voice)hello"
|
|
assert dummy_model.calls[0]["prompt_wav_path"] is None
|
|
assert dummy_model.calls[0]["reference_wav_path"] is None
|
|
|
|
|
|
def test_clone_subcommand_reads_prompt_file(monkeypatch, tmp_path):
|
|
dummy_model = DummyModel()
|
|
prompt_audio = tmp_path / "prompt.wav"
|
|
prompt_audio.write_bytes(b"RIFF")
|
|
prompt_file = tmp_path / "prompt.txt"
|
|
prompt_file.write_text("prompt transcript\n", encoding="utf-8")
|
|
|
|
monkeypatch.setattr(cli, "load_model", lambda args: dummy_model)
|
|
monkeypatch.setattr(cli.sf, "write", lambda *args, **kwargs: None)
|
|
|
|
run_main(
|
|
monkeypatch,
|
|
[
|
|
"clone",
|
|
"--text",
|
|
"hello",
|
|
"--prompt-audio",
|
|
str(prompt_audio),
|
|
"--prompt-file",
|
|
str(prompt_file),
|
|
"--output",
|
|
str(tmp_path / "out.wav"),
|
|
],
|
|
)
|
|
|
|
assert dummy_model.calls[0]["prompt_wav_path"] == str(prompt_audio)
|
|
assert dummy_model.calls[0]["prompt_text"] == "prompt transcript"
|
|
|
|
|
|
def test_clone_rejects_reference_audio_for_v1_local_model(monkeypatch, tmp_path):
|
|
reference_audio = tmp_path / "ref.wav"
|
|
reference_audio.write_bytes(b"RIFF")
|
|
monkeypatch.setattr(
|
|
sys,
|
|
"argv",
|
|
[
|
|
"voxcpm",
|
|
"clone",
|
|
"--text",
|
|
"hello",
|
|
"--reference-audio",
|
|
str(reference_audio),
|
|
"--model-path",
|
|
str(V1_MODEL_PATH),
|
|
"--output",
|
|
str(tmp_path / "out.wav"),
|
|
],
|
|
)
|
|
|
|
with pytest.raises(SystemExit):
|
|
cli.main()
|
|
|
|
|
|
def test_clone_rejects_reference_audio_for_v1_hf_model_id(monkeypatch, tmp_path):
|
|
reference_audio = tmp_path / "ref.wav"
|
|
reference_audio.write_bytes(b"RIFF")
|
|
monkeypatch.setattr(
|
|
sys,
|
|
"argv",
|
|
[
|
|
"voxcpm",
|
|
"clone",
|
|
"--text",
|
|
"hello",
|
|
"--reference-audio",
|
|
str(reference_audio),
|
|
"--hf-model-id",
|
|
"openbmb/VoxCPM1.5",
|
|
"--output",
|
|
str(tmp_path / "out.wav"),
|
|
],
|
|
)
|
|
|
|
with pytest.raises(SystemExit):
|
|
cli.main()
|
|
|
|
|
|
def test_legacy_root_args_still_work_and_warn(monkeypatch, tmp_path, capsys):
|
|
dummy_model = DummyModel()
|
|
monkeypatch.setattr(cli, "load_model", lambda args: dummy_model)
|
|
monkeypatch.setattr(cli.sf, "write", lambda *args, **kwargs: None)
|
|
|
|
run_main(
|
|
monkeypatch,
|
|
[
|
|
"--text",
|
|
"hello",
|
|
"--output",
|
|
str(tmp_path / "out.wav"),
|
|
],
|
|
)
|
|
|
|
captured = capsys.readouterr()
|
|
assert "deprecated" in captured.err
|
|
assert dummy_model.calls[0]["text"] == "hello"
|
|
|
|
|
|
def test_batch_subcommand_applies_control(monkeypatch, tmp_path):
|
|
dummy_model = DummyModel()
|
|
input_file = tmp_path / "texts.txt"
|
|
input_file.write_text("hello\nworld\n", encoding="utf-8")
|
|
|
|
monkeypatch.setattr(cli, "load_model", lambda args: dummy_model)
|
|
monkeypatch.setattr(cli.sf, "write", lambda *args, **kwargs: None)
|
|
|
|
run_main(
|
|
monkeypatch,
|
|
[
|
|
"batch",
|
|
"--input",
|
|
str(input_file),
|
|
"--output-dir",
|
|
str(tmp_path / "outs"),
|
|
"--control",
|
|
"calm narrator",
|
|
],
|
|
)
|
|
|
|
assert [call["text"] for call in dummy_model.calls] == [
|
|
"(calm narrator)hello",
|
|
"(calm narrator)world",
|
|
]
|
|
|
|
|
|
def test_legacy_clone_with_prompt_file_still_works(monkeypatch, tmp_path, capsys):
|
|
dummy_model = DummyModel()
|
|
prompt_audio = tmp_path / "prompt.wav"
|
|
prompt_audio.write_bytes(b"RIFF")
|
|
prompt_file = tmp_path / "prompt.txt"
|
|
prompt_file.write_text("legacy transcript", encoding="utf-8")
|
|
|
|
monkeypatch.setattr(cli, "load_model", lambda args: dummy_model)
|
|
monkeypatch.setattr(cli.sf, "write", lambda *args, **kwargs: None)
|
|
|
|
run_main(
|
|
monkeypatch,
|
|
[
|
|
"--text",
|
|
"hello",
|
|
"--prompt-audio",
|
|
str(prompt_audio),
|
|
"--prompt-file",
|
|
str(prompt_file),
|
|
"--output",
|
|
str(tmp_path / "out.wav"),
|
|
],
|
|
)
|
|
|
|
captured = capsys.readouterr()
|
|
assert "deprecated" in captured.err
|
|
assert dummy_model.calls[0]["prompt_text"] == "legacy transcript"
|
|
|
|
|
|
def test_invalid_prompt_text_and_prompt_file_combination(monkeypatch, tmp_path, capsys):
|
|
prompt_audio = tmp_path / "prompt.wav"
|
|
prompt_audio.write_bytes(b"RIFF")
|
|
prompt_file = tmp_path / "prompt.txt"
|
|
prompt_file.write_text("transcript", encoding="utf-8")
|
|
|
|
monkeypatch.setattr(
|
|
sys,
|
|
"argv",
|
|
[
|
|
"voxcpm",
|
|
"clone",
|
|
"--text",
|
|
"hello",
|
|
"--prompt-audio",
|
|
str(prompt_audio),
|
|
"--prompt-text",
|
|
"inline transcript",
|
|
"--prompt-file",
|
|
str(prompt_file),
|
|
"--output",
|
|
str(tmp_path / "out.wav"),
|
|
],
|
|
)
|
|
|
|
with pytest.raises(SystemExit):
|
|
cli.main()
|
|
|
|
assert "Use either --prompt-text or --prompt-file" in capsys.readouterr().err
|
|
|
|
|
|
def test_missing_prompt_file_reports_parser_error(monkeypatch, tmp_path, capsys):
|
|
prompt_audio = tmp_path / "prompt.wav"
|
|
prompt_audio.write_bytes(b"RIFF")
|
|
monkeypatch.setattr(
|
|
sys,
|
|
"argv",
|
|
[
|
|
"voxcpm",
|
|
"clone",
|
|
"--text",
|
|
"hello",
|
|
"--prompt-audio",
|
|
str(prompt_audio),
|
|
"--prompt-file",
|
|
str(tmp_path / "missing.txt"),
|
|
"--output",
|
|
str(tmp_path / "out.wav"),
|
|
],
|
|
)
|
|
|
|
with pytest.raises(SystemExit):
|
|
cli.main()
|
|
|
|
assert "prompt text file" in capsys.readouterr().err
|
|
|
|
|
|
def test_design_rejects_prompt_audio_args(monkeypatch, tmp_path, capsys):
|
|
prompt_audio = tmp_path / "prompt.wav"
|
|
prompt_audio.write_bytes(b"RIFF")
|
|
monkeypatch.setattr(
|
|
sys,
|
|
"argv",
|
|
[
|
|
"voxcpm",
|
|
"design",
|
|
"--text",
|
|
"hello",
|
|
"--prompt-audio",
|
|
str(prompt_audio),
|
|
"--prompt-text",
|
|
"transcript",
|
|
"--output",
|
|
str(tmp_path / "out.wav"),
|
|
],
|
|
)
|
|
|
|
with pytest.raises(SystemExit):
|
|
cli.main()
|
|
|
|
assert "does not accept prompt/reference audio" in capsys.readouterr().err
|
|
|
|
|
|
def test_clone_rejects_prompt_audio_without_transcript(monkeypatch, tmp_path, capsys):
|
|
prompt_audio = tmp_path / "prompt.wav"
|
|
prompt_audio.write_bytes(b"RIFF")
|
|
monkeypatch.setattr(
|
|
sys,
|
|
"argv",
|
|
[
|
|
"voxcpm",
|
|
"clone",
|
|
"--text",
|
|
"hello",
|
|
"--prompt-audio",
|
|
str(prompt_audio),
|
|
"--output",
|
|
str(tmp_path / "out.wav"),
|
|
],
|
|
)
|
|
|
|
with pytest.raises(SystemExit):
|
|
cli.main()
|
|
|
|
assert (
|
|
"--prompt-audio requires --prompt-text or --prompt-file"
|
|
in capsys.readouterr().err
|
|
)
|
|
|
|
|
|
def test_clone_rejects_transcript_without_prompt_audio(monkeypatch, tmp_path, capsys):
|
|
monkeypatch.setattr(
|
|
sys,
|
|
"argv",
|
|
[
|
|
"voxcpm",
|
|
"clone",
|
|
"--text",
|
|
"hello",
|
|
"--prompt-text",
|
|
"transcript",
|
|
"--output",
|
|
str(tmp_path / "out.wav"),
|
|
],
|
|
)
|
|
|
|
with pytest.raises(SystemExit):
|
|
cli.main()
|
|
|
|
assert (
|
|
"--prompt-text/--prompt-file requires --prompt-audio" in capsys.readouterr().err
|
|
)
|
|
|
|
|
|
def test_batch_rejects_control_with_prompt_transcript(monkeypatch, tmp_path, capsys):
|
|
input_file = tmp_path / "texts.txt"
|
|
input_file.write_text("hello\n", encoding="utf-8")
|
|
prompt_audio = tmp_path / "prompt.wav"
|
|
prompt_audio.write_bytes(b"RIFF")
|
|
monkeypatch.setattr(
|
|
sys,
|
|
"argv",
|
|
[
|
|
"voxcpm",
|
|
"batch",
|
|
"--input",
|
|
str(input_file),
|
|
"--output-dir",
|
|
str(tmp_path / "outs"),
|
|
"--control",
|
|
"calm narrator",
|
|
"--prompt-audio",
|
|
str(prompt_audio),
|
|
"--prompt-text",
|
|
"transcript",
|
|
],
|
|
)
|
|
|
|
with pytest.raises(SystemExit):
|
|
cli.main()
|
|
|
|
assert "--control cannot be used together" in capsys.readouterr().err
|
|
|
|
|
|
def test_detect_model_architecture_uses_local_configs():
|
|
parser = cli._build_parser()
|
|
v1_args = parser.parse_args(
|
|
[
|
|
"clone",
|
|
"--text",
|
|
"hello",
|
|
"--reference-audio",
|
|
"ref.wav",
|
|
"--model-path",
|
|
str(V1_MODEL_PATH),
|
|
"--output",
|
|
"out.wav",
|
|
]
|
|
)
|
|
v2_args = parser.parse_args(
|
|
[
|
|
"clone",
|
|
"--text",
|
|
"hello",
|
|
"--reference-audio",
|
|
"ref.wav",
|
|
"--model-path",
|
|
str(V2_MODEL_PATH),
|
|
"--output",
|
|
"out.wav",
|
|
]
|
|
)
|
|
|
|
assert cli.detect_model_architecture(v1_args) == "voxcpm"
|
|
assert cli.detect_model_architecture(v2_args) == "voxcpm2"
|