diff --git a/src/voxcpm/training/validate.py b/src/voxcpm/training/validate.py index d16b3b2..bad32a0 100644 --- a/src/voxcpm/training/validate.py +++ b/src/voxcpm/training/validate.py @@ -31,7 +31,10 @@ class ValidationResult: def _check_audio_file(audio_path: str, sample_rate: int) -> Optional[str]: - """Check if an audio file exists and is readable. Returns error message or None.""" + """Check if an audio file exists, is readable, and matches expected sample rate. + + Returns an error message, or None if the file is valid. + """ if not os.path.isfile(audio_path): return f"Audio file not found: {audio_path}" try: @@ -40,6 +43,11 @@ def _check_audio_file(audio_path: str, sample_rate: int) -> Optional[str]: info = sf.info(audio_path) if info.frames == 0: return f"Audio file is empty: {audio_path}" + if info.samplerate != sample_rate: + return ( + f"Sample rate mismatch in {audio_path}: " + f"expected {sample_rate} Hz, got {info.samplerate} Hz" + ) return None except ImportError: # soundfile not available; just check existence @@ -173,6 +181,7 @@ def validate_manifest( missing_audio_count += 1 if missing_audio_count <= 5: result.errors.append(f"Line {i + 1}: {audio_error}") + has_error = True else: duration = _get_audio_duration(audio_path) if duration is not None: @@ -187,6 +196,7 @@ def validate_manifest( ) else: result.errors.append(f"Line {i + 1}: Invalid audio path") + has_error = True # Validate optional ref_audio if "ref_audio" in entry: @@ -198,7 +208,7 @@ def validate_manifest( ref_path = str(manifest_dir / ref_path) if os.path.isfile(ref_path): result.has_ref_audio += 1 - elif result.has_ref_audio == 0: + else: result.warnings.append( f"Line {i + 1}: ref_audio file not found: {ref_path}" ) diff --git a/tests/test_validate.py b/tests/test_validate.py index 3c364be..0c15a61 100644 --- a/tests/test_validate.py +++ b/tests/test_validate.py @@ -183,3 +183,68 @@ class TestValidateManifest: r3 = ValidationResult(total_samples=0, valid_samples=0) assert not r3.is_valid + + def test_invalid_audio_not_counted_as_valid(self, tmp_dir): + """A row with a bad audio path must not increment valid_samples.""" + manifest = tmp_dir / "bad_audio.jsonl" + _write_manifest( + manifest, + [{"text": "hello", "audio": "/nonexistent/audio.wav"}], + ) + result = validate_manifest(str(manifest)) + assert result.total_samples == 1 + assert result.valid_samples == 0 + assert not result.is_valid + assert any("not found" in e for e in result.errors) + + def test_sample_rate_mismatch(self, tmp_dir): + """A file with a different sample rate should be reported as an error.""" + try: + import soundfile as sf + import numpy as np + except ImportError: + pytest.skip("soundfile not available") + + audio = tmp_dir / "audio_8k.wav" + import numpy as np + samples = np.zeros(8000, dtype=np.float32) + sf.write(str(audio), samples, 8000) + + manifest = tmp_dir / "sr_mismatch.jsonl" + _write_manifest(manifest, [{"text": "hello", "audio": str(audio)}]) + + result = validate_manifest(str(manifest), sample_rate=16000) + assert result.valid_samples == 0 + assert not result.is_valid + assert any("Sample rate mismatch" in e or "sample rate" in e.lower() for e in result.errors) + + def test_mixed_ref_audio_warns_for_each_missing(self, tmp_dir): + """Missing ref_audio entries should each generate a warning independently.""" + audio = tmp_dir / "audio.wav" + ref_good = tmp_dir / "ref_good.wav" + _create_wav(audio) + _create_wav(ref_good) + + manifest = tmp_dir / "mixed_ref.jsonl" + _write_manifest( + manifest, + [ + {"text": "row1", "audio": str(audio), "ref_audio": str(ref_good)}, + {"text": "row2", "audio": str(audio), "ref_audio": "/nonexistent/ref.wav"}, + ], + ) + result = validate_manifest(str(manifest)) + assert result.has_ref_audio == 1 + assert any("ref_audio file not found" in w for w in result.warnings) + + def test_cli_validate_exit_code(self, tmp_dir): + """validate subcommand must exit non-zero on error.""" + import subprocess + manifest = tmp_dir / "bad.jsonl" + _write_manifest(manifest, [{"text": "hi", "audio": "/nonexistent/x.wav"}]) + + proc = subprocess.run( + [sys.executable, "-m", "voxcpm.cli", "validate", str(manifest)], + capture_output=True, + ) + assert proc.returncode != 0