feat: add voxcpm validate CLI for pre-flight training data checks
Add a new `validate` subcommand that checks JSONL training manifests before starting expensive fine-tuning jobs. This catches format issues, missing audio files, and data quality problems early. The validator performs: - JSONL format validation (each line must be valid JSON) - Required column checks (text, audio) - Audio file existence and readability verification - Duration and text length statistics (min, max, mean, median) - Optional ref_audio column validation - Warnings for very short (<0.3s) or very long (>30s) audio samples Usage: voxcpm validate --manifest train.jsonl voxcpm validate --manifest train.jsonl --sample-rate 16000 --verbose The module uses lazy imports for soundfile, so it works even in minimal environments. Includes 11 unit tests covering all validation paths.
This commit is contained in:
@@ -288,6 +288,24 @@ def cmd_clone(args, parser):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_validate(args, parser):
|
||||||
|
from voxcpm.training.validate import (
|
||||||
|
print_validation_report,
|
||||||
|
validate_manifest,
|
||||||
|
)
|
||||||
|
|
||||||
|
manifest = str(require_file_exists(args.manifest, parser, "manifest file"))
|
||||||
|
result = validate_manifest(
|
||||||
|
manifest_path=manifest,
|
||||||
|
sample_rate=args.sample_rate,
|
||||||
|
max_samples=args.max_samples,
|
||||||
|
verbose=args.verbose,
|
||||||
|
)
|
||||||
|
print_validation_report(result, manifest)
|
||||||
|
if not result.is_valid:
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def cmd_batch(args, parser):
|
def cmd_batch(args, parser):
|
||||||
input_file = require_file_exists(args.input, parser, "input file")
|
input_file = require_file_exists(args.input, parser, "input file")
|
||||||
output_dir = Path(args.output_dir)
|
output_dir = Path(args.output_dir)
|
||||||
@@ -532,6 +550,30 @@ Examples:
|
|||||||
_add_model_args(batch_parser)
|
_add_model_args(batch_parser)
|
||||||
_add_lora_args(batch_parser)
|
_add_lora_args(batch_parser)
|
||||||
|
|
||||||
|
# Validate subcommand
|
||||||
|
validate_parser = subparsers.add_parser(
|
||||||
|
"validate",
|
||||||
|
help="Validate a training data manifest (JSONL) before fine-tuning",
|
||||||
|
)
|
||||||
|
validate_parser.add_argument(
|
||||||
|
"--manifest", "-m", required=True, help="Path to JSONL training manifest"
|
||||||
|
)
|
||||||
|
validate_parser.add_argument(
|
||||||
|
"--sample-rate",
|
||||||
|
type=int,
|
||||||
|
default=16_000,
|
||||||
|
help="Expected audio sample rate in Hz (default: 16000)",
|
||||||
|
)
|
||||||
|
validate_parser.add_argument(
|
||||||
|
"--max-samples",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Maximum number of samples to validate (0 = all, default: 0)",
|
||||||
|
)
|
||||||
|
validate_parser.add_argument(
|
||||||
|
"--verbose", "-v", action="store_true", help="Print per-sample progress"
|
||||||
|
)
|
||||||
|
|
||||||
# Legacy root arguments
|
# Legacy root arguments
|
||||||
parser.add_argument("--input", "-i", help="Input text file (batch mode only)")
|
parser.add_argument("--input", "-i", help="Input text file (batch mode only)")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -584,6 +626,9 @@ def main():
|
|||||||
parser = _build_parser()
|
parser = _build_parser()
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.command == "validate":
|
||||||
|
return cmd_validate(args, parser)
|
||||||
|
|
||||||
validate_ranges(args, parser)
|
validate_ranges(args, parser)
|
||||||
|
|
||||||
if args.command == "design":
|
if args.command == "design":
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from .data import (
|
|||||||
BatchProcessor,
|
BatchProcessor,
|
||||||
)
|
)
|
||||||
from .state import TrainingState
|
from .state import TrainingState
|
||||||
|
from .validate import validate_manifest, ValidationResult
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Accelerator",
|
"Accelerator",
|
||||||
@@ -24,4 +25,6 @@ __all__ = [
|
|||||||
"TrainingState",
|
"TrainingState",
|
||||||
"load_audio_text_datasets",
|
"load_audio_text_datasets",
|
||||||
"build_dataloader",
|
"build_dataloader",
|
||||||
|
"validate_manifest",
|
||||||
|
"ValidationResult",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -0,0 +1,299 @@
|
|||||||
|
"""
|
||||||
|
Pre-flight validation for VoxCPM training data manifests.
|
||||||
|
|
||||||
|
Validates JSONL manifest files before starting expensive fine-tuning jobs,
|
||||||
|
catching format issues, missing files, and data quality problems early.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ValidationResult:
|
||||||
|
"""Structured result of a manifest validation run."""
|
||||||
|
|
||||||
|
total_samples: int = 0
|
||||||
|
valid_samples: int = 0
|
||||||
|
errors: List[str] = field(default_factory=list)
|
||||||
|
warnings: List[str] = field(default_factory=list)
|
||||||
|
audio_durations: List[float] = field(default_factory=list)
|
||||||
|
text_lengths: List[int] = field(default_factory=list)
|
||||||
|
has_ref_audio: int = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_valid(self) -> bool:
|
||||||
|
return len(self.errors) == 0 and self.valid_samples > 0
|
||||||
|
|
||||||
|
|
||||||
|
def _check_audio_file(audio_path: str, sample_rate: int) -> Optional[str]:
|
||||||
|
"""Check if an audio file exists and is readable. Returns error message or None."""
|
||||||
|
if not os.path.isfile(audio_path):
|
||||||
|
return f"Audio file not found: {audio_path}"
|
||||||
|
try:
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
info = sf.info(audio_path)
|
||||||
|
if info.frames == 0:
|
||||||
|
return f"Audio file is empty: {audio_path}"
|
||||||
|
return None
|
||||||
|
except ImportError:
|
||||||
|
# soundfile not available; just check existence
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
return f"Cannot read audio file {audio_path}: {e}"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_audio_duration(audio_path: str) -> Optional[float]:
|
||||||
|
"""Get audio duration in seconds. Returns None if unavailable."""
|
||||||
|
try:
|
||||||
|
import soundfile as sf
|
||||||
|
|
||||||
|
info = sf.info(audio_path)
|
||||||
|
return info.duration
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def validate_manifest(
|
||||||
|
manifest_path: str,
|
||||||
|
sample_rate: int = 16_000,
|
||||||
|
max_samples: int = 0,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> ValidationResult:
|
||||||
|
"""Validate a JSONL training manifest file.
|
||||||
|
|
||||||
|
Checks:
|
||||||
|
1. File exists and is readable
|
||||||
|
2. Each line is valid JSON
|
||||||
|
3. Required columns present (text, audio)
|
||||||
|
4. Audio files exist and are readable
|
||||||
|
5. Text content is non-empty
|
||||||
|
6. Collects duration and text length statistics
|
||||||
|
7. Validates optional ref_audio column
|
||||||
|
|
||||||
|
Args:
|
||||||
|
manifest_path: Path to the JSONL manifest file.
|
||||||
|
sample_rate: Expected audio sample rate (for informational purposes).
|
||||||
|
max_samples: Maximum number of samples to validate (0 = all).
|
||||||
|
verbose: Print per-sample progress.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ValidationResult with errors, warnings, and statistics.
|
||||||
|
"""
|
||||||
|
result = ValidationResult()
|
||||||
|
path = Path(manifest_path)
|
||||||
|
|
||||||
|
if not path.exists():
|
||||||
|
result.errors.append(f"Manifest file not found: {manifest_path}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
if not path.is_file():
|
||||||
|
result.errors.append(f"Manifest path is not a file: {manifest_path}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
manifest_dir = path.parent
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
except Exception as e:
|
||||||
|
result.errors.append(f"Cannot read manifest file: {e}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
if not lines:
|
||||||
|
result.errors.append("Manifest file is empty")
|
||||||
|
return result
|
||||||
|
|
||||||
|
samples_to_check = len(lines)
|
||||||
|
if max_samples > 0:
|
||||||
|
samples_to_check = min(samples_to_check, max_samples)
|
||||||
|
|
||||||
|
missing_audio_count = 0
|
||||||
|
empty_text_count = 0
|
||||||
|
|
||||||
|
for i, line in enumerate(lines[:samples_to_check]):
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
result.total_samples += 1
|
||||||
|
|
||||||
|
# Check JSON validity
|
||||||
|
try:
|
||||||
|
entry = json.loads(line)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
result.errors.append(f"Line {i + 1}: Invalid JSON — {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not isinstance(entry, dict):
|
||||||
|
result.errors.append(f"Line {i + 1}: Expected JSON object, got {type(entry).__name__}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Check required columns
|
||||||
|
has_error = False
|
||||||
|
|
||||||
|
if "text" not in entry:
|
||||||
|
result.errors.append(f"Line {i + 1}: Missing required column 'text'")
|
||||||
|
has_error = True
|
||||||
|
|
||||||
|
if "audio" not in entry:
|
||||||
|
result.errors.append(f"Line {i + 1}: Missing required column 'audio'")
|
||||||
|
has_error = True
|
||||||
|
|
||||||
|
if has_error:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Validate text
|
||||||
|
text = entry["text"]
|
||||||
|
if not isinstance(text, str) or not text.strip():
|
||||||
|
empty_text_count += 1
|
||||||
|
if empty_text_count <= 5:
|
||||||
|
result.warnings.append(f"Line {i + 1}: Empty or non-string text")
|
||||||
|
else:
|
||||||
|
result.text_lengths.append(len(text))
|
||||||
|
|
||||||
|
# Validate audio path
|
||||||
|
audio_path = entry["audio"]
|
||||||
|
if isinstance(audio_path, dict):
|
||||||
|
# HuggingFace Audio format with {"path": ..., "array": ...}
|
||||||
|
audio_path = audio_path.get("path", "")
|
||||||
|
|
||||||
|
if isinstance(audio_path, str) and audio_path:
|
||||||
|
# Resolve relative paths against manifest directory
|
||||||
|
if not os.path.isabs(audio_path):
|
||||||
|
audio_path = str(manifest_dir / audio_path)
|
||||||
|
|
||||||
|
audio_error = _check_audio_file(audio_path, sample_rate)
|
||||||
|
if audio_error:
|
||||||
|
missing_audio_count += 1
|
||||||
|
if missing_audio_count <= 5:
|
||||||
|
result.errors.append(f"Line {i + 1}: {audio_error}")
|
||||||
|
else:
|
||||||
|
duration = _get_audio_duration(audio_path)
|
||||||
|
if duration is not None:
|
||||||
|
result.audio_durations.append(duration)
|
||||||
|
if duration < 0.3:
|
||||||
|
result.warnings.append(
|
||||||
|
f"Line {i + 1}: Very short audio ({duration:.2f}s)"
|
||||||
|
)
|
||||||
|
elif duration > 30.0:
|
||||||
|
result.warnings.append(
|
||||||
|
f"Line {i + 1}: Very long audio ({duration:.1f}s), may cause OOM"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
result.errors.append(f"Line {i + 1}: Invalid audio path")
|
||||||
|
|
||||||
|
# Validate optional ref_audio
|
||||||
|
if "ref_audio" in entry:
|
||||||
|
ref_path = entry["ref_audio"]
|
||||||
|
if isinstance(ref_path, dict):
|
||||||
|
ref_path = ref_path.get("path", "")
|
||||||
|
if isinstance(ref_path, str) and ref_path:
|
||||||
|
if not os.path.isabs(ref_path):
|
||||||
|
ref_path = str(manifest_dir / ref_path)
|
||||||
|
if os.path.isfile(ref_path):
|
||||||
|
result.has_ref_audio += 1
|
||||||
|
elif result.has_ref_audio == 0:
|
||||||
|
result.warnings.append(
|
||||||
|
f"Line {i + 1}: ref_audio file not found: {ref_path}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not has_error:
|
||||||
|
result.valid_samples += 1
|
||||||
|
|
||||||
|
if verbose and (i + 1) % 100 == 0:
|
||||||
|
print(f" Validated {i + 1}/{samples_to_check} samples...", file=sys.stderr)
|
||||||
|
|
||||||
|
# Summarize truncated errors
|
||||||
|
if missing_audio_count > 5:
|
||||||
|
result.errors.append(
|
||||||
|
f"... and {missing_audio_count - 5} more missing audio files "
|
||||||
|
f"({missing_audio_count} total)"
|
||||||
|
)
|
||||||
|
if empty_text_count > 5:
|
||||||
|
result.warnings.append(
|
||||||
|
f"... and {empty_text_count - 5} more empty text entries "
|
||||||
|
f"({empty_text_count} total)"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def print_validation_report(result: ValidationResult, manifest_path: str) -> None:
|
||||||
|
"""Print a human-readable validation report to stderr."""
|
||||||
|
print(f"\n{'=' * 60}", file=sys.stderr)
|
||||||
|
print(f" VoxCPM Training Data Validation Report", file=sys.stderr)
|
||||||
|
print(f"{'=' * 60}", file=sys.stderr)
|
||||||
|
print(f" Manifest : {manifest_path}", file=sys.stderr)
|
||||||
|
print(f" Samples : {result.valid_samples}/{result.total_samples} valid", file=sys.stderr)
|
||||||
|
|
||||||
|
if result.has_ref_audio > 0:
|
||||||
|
print(
|
||||||
|
f" Ref Audio: {result.has_ref_audio} samples with reference audio",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Audio duration statistics
|
||||||
|
if result.audio_durations:
|
||||||
|
durations = sorted(result.audio_durations)
|
||||||
|
total_hrs = sum(durations) / 3600
|
||||||
|
print(f"\n Audio Duration Statistics:", file=sys.stderr)
|
||||||
|
print(f" Total : {total_hrs:.2f} hours", file=sys.stderr)
|
||||||
|
print(
|
||||||
|
f" Range : {durations[0]:.2f}s — {durations[-1]:.1f}s",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Mean : {sum(durations) / len(durations):.2f}s",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
median_idx = len(durations) // 2
|
||||||
|
print(f" Median : {durations[median_idx]:.2f}s", file=sys.stderr)
|
||||||
|
|
||||||
|
# Text length statistics
|
||||||
|
if result.text_lengths:
|
||||||
|
lengths = sorted(result.text_lengths)
|
||||||
|
print(f"\n Text Length Statistics (characters):", file=sys.stderr)
|
||||||
|
print(
|
||||||
|
f" Range : {lengths[0]} — {lengths[-1]}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f" Mean : {sum(lengths) / len(lengths):.0f}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Errors
|
||||||
|
if result.errors:
|
||||||
|
print(f"\n ERRORS ({len(result.errors)}):", file=sys.stderr)
|
||||||
|
for err in result.errors[:20]:
|
||||||
|
print(f" x {err}", file=sys.stderr)
|
||||||
|
if len(result.errors) > 20:
|
||||||
|
print(
|
||||||
|
f" ... ({len(result.errors) - 20} more errors omitted)",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Warnings
|
||||||
|
if result.warnings:
|
||||||
|
print(f"\n WARNINGS ({len(result.warnings)}):", file=sys.stderr)
|
||||||
|
for warn in result.warnings[:10]:
|
||||||
|
print(f" ! {warn}", file=sys.stderr)
|
||||||
|
if len(result.warnings) > 10:
|
||||||
|
print(
|
||||||
|
f" ... ({len(result.warnings) - 10} more warnings omitted)",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
print(f"\n{'=' * 60}", file=sys.stderr)
|
||||||
|
if result.is_valid:
|
||||||
|
print(" PASSED: Manifest is valid for training.", file=sys.stderr)
|
||||||
|
else:
|
||||||
|
print(" FAILED: Fix errors above before starting training.", file=sys.stderr)
|
||||||
|
print(f"{'=' * 60}\n", file=sys.stderr)
|
||||||
@@ -0,0 +1,185 @@
|
|||||||
|
"""Tests for the training data validation module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import types
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
|
||||||
|
# Stub voxcpm package so imports work without full dependencies
|
||||||
|
pkg = types.ModuleType("voxcpm")
|
||||||
|
pkg.__path__ = [str(ROOT / "src" / "voxcpm")]
|
||||||
|
sys.modules.setdefault("voxcpm", pkg)
|
||||||
|
|
||||||
|
training_pkg = types.ModuleType("voxcpm.training")
|
||||||
|
training_pkg.__path__ = [str(ROOT / "src" / "voxcpm" / "training")]
|
||||||
|
sys.modules.setdefault("voxcpm.training", training_pkg)
|
||||||
|
|
||||||
|
from voxcpm.training.validate import ValidationResult, validate_manifest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tmp_dir():
|
||||||
|
with tempfile.TemporaryDirectory() as d:
|
||||||
|
yield Path(d)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_wav(path: Path, duration_s: float = 1.0, sr: int = 16000):
|
||||||
|
"""Create a minimal valid WAV file."""
|
||||||
|
try:
|
||||||
|
import soundfile as sf
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
samples = int(duration_s * sr)
|
||||||
|
data = np.zeros(samples, dtype=np.float32)
|
||||||
|
sf.write(str(path), data, sr)
|
||||||
|
except ImportError:
|
||||||
|
# If soundfile is not available, create a minimal WAV header
|
||||||
|
import struct
|
||||||
|
|
||||||
|
samples = int(duration_s * sr)
|
||||||
|
data_size = samples * 2 # 16-bit PCM
|
||||||
|
with open(path, "wb") as f:
|
||||||
|
f.write(b"RIFF")
|
||||||
|
f.write(struct.pack("<I", 36 + data_size))
|
||||||
|
f.write(b"WAVEfmt ")
|
||||||
|
f.write(struct.pack("<IHHIIHH", 16, 1, 1, sr, sr * 2, 2, 16))
|
||||||
|
f.write(b"data")
|
||||||
|
f.write(struct.pack("<I", data_size))
|
||||||
|
f.write(b"\x00" * data_size)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_manifest(path: Path, entries: list[dict]):
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
|
for entry in entries:
|
||||||
|
f.write(json.dumps(entry, ensure_ascii=False) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
class TestValidateManifest:
|
||||||
|
def test_valid_manifest(self, tmp_dir):
|
||||||
|
audio1 = tmp_dir / "audio1.wav"
|
||||||
|
audio2 = tmp_dir / "audio2.wav"
|
||||||
|
_create_wav(audio1, 2.0)
|
||||||
|
_create_wav(audio2, 3.0)
|
||||||
|
|
||||||
|
manifest = tmp_dir / "train.jsonl"
|
||||||
|
_write_manifest(
|
||||||
|
manifest,
|
||||||
|
[
|
||||||
|
{"text": "Hello world", "audio": str(audio1)},
|
||||||
|
{"text": "Goodbye world", "audio": str(audio2)},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
result = validate_manifest(str(manifest))
|
||||||
|
assert result.total_samples == 2
|
||||||
|
assert result.valid_samples == 2
|
||||||
|
assert result.is_valid
|
||||||
|
assert len(result.errors) == 0
|
||||||
|
|
||||||
|
def test_missing_manifest(self):
|
||||||
|
result = validate_manifest("/nonexistent/path.jsonl")
|
||||||
|
assert not result.is_valid
|
||||||
|
assert any("not found" in e for e in result.errors)
|
||||||
|
|
||||||
|
def test_empty_manifest(self, tmp_dir):
|
||||||
|
manifest = tmp_dir / "empty.jsonl"
|
||||||
|
manifest.write_text("")
|
||||||
|
result = validate_manifest(str(manifest))
|
||||||
|
assert not result.is_valid
|
||||||
|
|
||||||
|
def test_invalid_json(self, tmp_dir):
|
||||||
|
manifest = tmp_dir / "bad.jsonl"
|
||||||
|
manifest.write_text("not json\n{bad json}\n")
|
||||||
|
result = validate_manifest(str(manifest))
|
||||||
|
assert len(result.errors) >= 2
|
||||||
|
assert any("Invalid JSON" in e for e in result.errors)
|
||||||
|
|
||||||
|
def test_missing_columns(self, tmp_dir):
|
||||||
|
manifest = tmp_dir / "missing.jsonl"
|
||||||
|
_write_manifest(
|
||||||
|
manifest,
|
||||||
|
[
|
||||||
|
{"text": "hello"}, # missing audio
|
||||||
|
{"audio": "test.wav"}, # missing text
|
||||||
|
],
|
||||||
|
)
|
||||||
|
result = validate_manifest(str(manifest))
|
||||||
|
assert len(result.errors) >= 2
|
||||||
|
assert any("'audio'" in e for e in result.errors)
|
||||||
|
assert any("'text'" in e for e in result.errors)
|
||||||
|
|
||||||
|
def test_missing_audio_file(self, tmp_dir):
|
||||||
|
manifest = tmp_dir / "missing_audio.jsonl"
|
||||||
|
_write_manifest(
|
||||||
|
manifest,
|
||||||
|
[{"text": "hello", "audio": "/nonexistent/audio.wav"}],
|
||||||
|
)
|
||||||
|
result = validate_manifest(str(manifest))
|
||||||
|
assert not result.is_valid
|
||||||
|
assert any("not found" in e for e in result.errors)
|
||||||
|
|
||||||
|
def test_empty_text_warning(self, tmp_dir):
|
||||||
|
audio = tmp_dir / "audio.wav"
|
||||||
|
_create_wav(audio)
|
||||||
|
manifest = tmp_dir / "empty_text.jsonl"
|
||||||
|
_write_manifest(
|
||||||
|
manifest,
|
||||||
|
[{"text": "", "audio": str(audio)}],
|
||||||
|
)
|
||||||
|
result = validate_manifest(str(manifest))
|
||||||
|
assert len(result.warnings) > 0
|
||||||
|
assert any("Empty" in w for w in result.warnings)
|
||||||
|
|
||||||
|
def test_relative_audio_path(self, tmp_dir):
|
||||||
|
audio = tmp_dir / "audio.wav"
|
||||||
|
_create_wav(audio)
|
||||||
|
manifest = tmp_dir / "rel.jsonl"
|
||||||
|
_write_manifest(
|
||||||
|
manifest,
|
||||||
|
[{"text": "hello", "audio": "audio.wav"}],
|
||||||
|
)
|
||||||
|
result = validate_manifest(str(manifest))
|
||||||
|
assert result.valid_samples == 1
|
||||||
|
assert result.is_valid
|
||||||
|
|
||||||
|
def test_max_samples_limit(self, tmp_dir):
|
||||||
|
audio = tmp_dir / "audio.wav"
|
||||||
|
_create_wav(audio)
|
||||||
|
manifest = tmp_dir / "many.jsonl"
|
||||||
|
_write_manifest(
|
||||||
|
manifest,
|
||||||
|
[{"text": f"sample {i}", "audio": str(audio)} for i in range(100)],
|
||||||
|
)
|
||||||
|
result = validate_manifest(str(manifest), max_samples=10)
|
||||||
|
assert result.total_samples == 10
|
||||||
|
|
||||||
|
def test_ref_audio_counted(self, tmp_dir):
|
||||||
|
audio = tmp_dir / "audio.wav"
|
||||||
|
ref = tmp_dir / "ref.wav"
|
||||||
|
_create_wav(audio)
|
||||||
|
_create_wav(ref)
|
||||||
|
manifest = tmp_dir / "ref.jsonl"
|
||||||
|
_write_manifest(
|
||||||
|
manifest,
|
||||||
|
[{"text": "hello", "audio": str(audio), "ref_audio": str(ref)}],
|
||||||
|
)
|
||||||
|
result = validate_manifest(str(manifest))
|
||||||
|
assert result.has_ref_audio == 1
|
||||||
|
|
||||||
|
def test_validation_result_properties(self):
|
||||||
|
r = ValidationResult(total_samples=5, valid_samples=5)
|
||||||
|
assert r.is_valid
|
||||||
|
|
||||||
|
r2 = ValidationResult(total_samples=5, valid_samples=5, errors=["err"])
|
||||||
|
assert not r2.is_valid
|
||||||
|
|
||||||
|
r3 = ValidationResult(total_samples=0, valid_samples=0)
|
||||||
|
assert not r3.is_valid
|
||||||
Reference in New Issue
Block a user