452 lines
17 KiB
Python
452 lines
17 KiB
Python
import unittest
|
|
from unittest.mock import patch
|
|
from pathlib import Path
|
|
from tempfile import TemporaryDirectory
|
|
import os
|
|
from types import SimpleNamespace
|
|
|
|
from fastapi.testclient import TestClient
|
|
|
|
from backend.main import app
|
|
from routers import audio as audio_router
|
|
|
|
|
|
class RouterContractTests(unittest.TestCase):
|
|
@classmethod
|
|
def setUpClass(cls) -> None:
|
|
cls.client = TestClient(app)
|
|
|
|
def setUp(self) -> None:
|
|
audio_router._waveform_cache.clear()
|
|
|
|
def test_health_endpoint(self) -> None:
|
|
res = self.client.get("/health")
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertEqual(res.json(), {"status": "ok"})
|
|
|
|
def test_file_endpoint_full_content(self) -> None:
|
|
with TemporaryDirectory() as tmp:
|
|
file_path = Path(tmp) / "sample.wav"
|
|
file_path.write_bytes(b"abcdefghij")
|
|
|
|
res = self.client.get("/file", params={"path": str(file_path)})
|
|
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertEqual(res.content, b"abcdefghij")
|
|
self.assertEqual(res.headers.get("accept-ranges"), "bytes")
|
|
|
|
def test_file_endpoint_range_request(self) -> None:
|
|
with TemporaryDirectory() as tmp:
|
|
file_path = Path(tmp) / "sample.wav"
|
|
file_path.write_bytes(b"abcdefghij")
|
|
|
|
res = self.client.get(
|
|
"/file",
|
|
params={"path": str(file_path)},
|
|
headers={"Range": "bytes=2-5"},
|
|
)
|
|
|
|
self.assertEqual(res.status_code, 206)
|
|
self.assertEqual(res.content, b"cdef")
|
|
self.assertEqual(res.headers.get("content-range"), "bytes 2-5/10")
|
|
|
|
def test_file_endpoint_missing_file(self) -> None:
|
|
res = self.client.get("/file", params={"path": "/tmp/does-not-exist.wav"})
|
|
|
|
self.assertEqual(res.status_code, 404)
|
|
self.assertIn("File not found", res.json()["detail"])
|
|
|
|
@patch("routers.audio.subprocess.run")
|
|
def test_audio_waveform_cache_miss_then_hit(self, mock_subprocess_run) -> None:
|
|
with TemporaryDirectory() as tmp:
|
|
media_file = Path(tmp) / "input.mp4"
|
|
media_file.write_bytes(b"fake-media")
|
|
|
|
def fake_ffmpeg(cmd, capture_output, text):
|
|
out_path = Path(cmd[-1])
|
|
out_path.parent.mkdir(parents=True, exist_ok=True)
|
|
out_path.write_bytes(b"fake-wav")
|
|
return SimpleNamespace(returncode=0, stderr="")
|
|
|
|
mock_subprocess_run.side_effect = fake_ffmpeg
|
|
|
|
res1 = self.client.get("/audio/waveform", params={"path": str(media_file)})
|
|
self.assertEqual(res1.status_code, 200)
|
|
self.assertTrue(res1.headers.get("content-type", "").startswith("audio/wav"))
|
|
|
|
res2 = self.client.get("/audio/waveform", params={"path": str(media_file)})
|
|
self.assertEqual(res2.status_code, 200)
|
|
self.assertTrue(res2.headers.get("content-type", "").startswith("audio/wav"))
|
|
|
|
self.assertEqual(mock_subprocess_run.call_count, 1)
|
|
|
|
@patch("routers.audio.subprocess.run")
|
|
def test_audio_waveform_ffmpeg_failure_returns_500(self, mock_subprocess_run) -> None:
|
|
with TemporaryDirectory() as tmp:
|
|
media_file = Path(tmp) / "input.mp4"
|
|
media_file.write_bytes(b"fake-media")
|
|
|
|
mock_subprocess_run.return_value = SimpleNamespace(returncode=1, stderr="ffmpeg failed")
|
|
|
|
res = self.client.get("/audio/waveform", params={"path": str(media_file)})
|
|
|
|
self.assertEqual(res.status_code, 500)
|
|
self.assertIn("Failed to extract audio", res.json()["detail"])
|
|
|
|
@patch("routers.ai.detect_filler_words")
|
|
def test_ai_filler_removal_contract(self, mock_detect_filler_words) -> None:
|
|
mock_detect_filler_words.return_value = {
|
|
"wordIndices": [2, 5],
|
|
"fillerWords": [
|
|
{"index": 2, "word": "um", "reason": "filler"},
|
|
{"index": 5, "word": "uh", "reason": "filler"},
|
|
],
|
|
}
|
|
|
|
payload = {
|
|
"transcript": "Hello um world uh",
|
|
"words": [
|
|
{"index": 0, "word": "Hello"},
|
|
{"index": 1, "word": "um"},
|
|
{"index": 2, "word": "world"},
|
|
],
|
|
"provider": "ollama",
|
|
"model": "llama3",
|
|
}
|
|
res = self.client.post("/ai/filler-removal", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertIn("wordIndices", res.json())
|
|
mock_detect_filler_words.assert_called_once()
|
|
|
|
@patch("routers.ai.detect_filler_words")
|
|
def test_ai_filler_removal_error_returns_500(self, mock_detect_filler_words) -> None:
|
|
mock_detect_filler_words.side_effect = RuntimeError("ai-filler-fail")
|
|
|
|
payload = {
|
|
"transcript": "Hello world",
|
|
"words": [{"index": 0, "word": "Hello"}],
|
|
"provider": "ollama",
|
|
}
|
|
res = self.client.post("/ai/filler-removal", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 500)
|
|
self.assertEqual(res.json()["detail"], "ai-filler-fail")
|
|
|
|
@patch("routers.ai.create_clip_suggestion")
|
|
def test_ai_create_clip_contract(self, mock_create_clip_suggestion) -> None:
|
|
mock_create_clip_suggestion.return_value = {
|
|
"title": "Best Moment",
|
|
"startWordIndex": 10,
|
|
"endWordIndex": 40,
|
|
"startTime": 12.3,
|
|
"endTime": 48.8,
|
|
"reason": "Strong hook",
|
|
}
|
|
|
|
payload = {
|
|
"transcript": "Long transcript...",
|
|
"words": [{"index": 0, "word": "hello"}],
|
|
"provider": "ollama",
|
|
"target_duration": 45,
|
|
}
|
|
res = self.client.post("/ai/create-clip", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertEqual(res.json()["title"], "Best Moment")
|
|
mock_create_clip_suggestion.assert_called_once()
|
|
|
|
@patch("routers.ai.create_clip_suggestion")
|
|
def test_ai_create_clip_error_returns_500(self, mock_create_clip_suggestion) -> None:
|
|
mock_create_clip_suggestion.side_effect = RuntimeError("ai-clip-fail")
|
|
|
|
payload = {
|
|
"transcript": "Hello world",
|
|
"words": [{"index": 0, "word": "hello"}],
|
|
"provider": "ollama",
|
|
}
|
|
res = self.client.post("/ai/create-clip", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 500)
|
|
self.assertEqual(res.json()["detail"], "ai-clip-fail")
|
|
|
|
@patch("routers.ai.AIProvider.list_ollama_models")
|
|
def test_ai_ollama_models_contract(self, mock_list_ollama_models) -> None:
|
|
mock_list_ollama_models.return_value = ["llama3", "qwen2.5"]
|
|
|
|
res = self.client.get("/ai/ollama-models?base_url=http://localhost:11434")
|
|
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertEqual(res.json(), {"models": ["llama3", "qwen2.5"]})
|
|
mock_list_ollama_models.assert_called_once_with("http://localhost:11434")
|
|
|
|
@patch("routers.ai.AIProvider.list_ollama_models")
|
|
def test_ai_ollama_models_unhandled_error_returns_500(self, mock_list_ollama_models) -> None:
|
|
mock_list_ollama_models.side_effect = RuntimeError("ollama-unreachable")
|
|
|
|
local_client = TestClient(app, raise_server_exceptions=False)
|
|
res = local_client.get("/ai/ollama-models")
|
|
|
|
self.assertEqual(res.status_code, 500)
|
|
|
|
@patch("routers.transcribe.transcribe_audio")
|
|
def test_transcribe_success(self, mock_transcribe) -> None:
|
|
mock_transcribe.return_value = {"words": [], "segments": [], "language": "en"}
|
|
|
|
payload = {
|
|
"file_path": "/tmp/input.wav",
|
|
"model": "base",
|
|
"use_gpu": False,
|
|
"use_cache": True,
|
|
}
|
|
res = self.client.post("/transcribe", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertEqual(res.json(), {"words": [], "segments": [], "language": "en"})
|
|
mock_transcribe.assert_called_once()
|
|
|
|
@patch("routers.transcribe.diarize_and_label")
|
|
@patch("routers.transcribe.transcribe_audio")
|
|
def test_transcribe_with_diarization(self, mock_transcribe, mock_diarize) -> None:
|
|
mock_transcribe.return_value = {"words": [{"word": "hi", "start": 0.0, "end": 0.2}], "segments": []}
|
|
mock_diarize.return_value = {"words": [{"word": "hi", "start": 0.0, "end": 0.2, "speaker": "SPEAKER_00"}], "segments": []}
|
|
|
|
payload = {
|
|
"file_path": "/tmp/input.wav",
|
|
"model": "base",
|
|
"diarize": True,
|
|
"hf_token": "hf_xxx",
|
|
"num_speakers": 2,
|
|
}
|
|
res = self.client.post("/transcribe", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertIn("words", res.json())
|
|
mock_transcribe.assert_called_once()
|
|
mock_diarize.assert_called_once()
|
|
|
|
@patch("routers.transcribe.transcribe_audio")
|
|
def test_transcribe_file_not_found_returns_404(self, mock_transcribe) -> None:
|
|
mock_transcribe.side_effect = FileNotFoundError("missing")
|
|
|
|
payload = {
|
|
"file_path": "/tmp/missing.wav",
|
|
"model": "base",
|
|
}
|
|
res = self.client.post("/transcribe", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 404)
|
|
self.assertIn("File not found", res.json()["detail"])
|
|
|
|
@patch("routers.transcribe.transcribe_audio")
|
|
def test_transcribe_runtime_failure_returns_500(self, mock_transcribe) -> None:
|
|
mock_transcribe.side_effect = RuntimeError("boom")
|
|
|
|
payload = {
|
|
"file_path": "/tmp/in.wav",
|
|
"model": "base",
|
|
}
|
|
res = self.client.post("/transcribe", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 500)
|
|
self.assertEqual(res.json()["detail"], "boom")
|
|
|
|
@patch("routers.captions.generate_srt")
|
|
def test_captions_plain_response(self, mock_generate_srt) -> None:
|
|
mock_generate_srt.return_value = "1\n00:00:00,000 --> 00:00:01,000\nHello\n"
|
|
|
|
payload = {
|
|
"words": [{"word": "Hello", "start": 0.0, "end": 1.0}],
|
|
"format": "srt",
|
|
}
|
|
res = self.client.post("/captions", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertIn("Hello", res.text)
|
|
mock_generate_srt.assert_called_once()
|
|
|
|
@patch("routers.captions.save_captions")
|
|
@patch("routers.captions.generate_srt")
|
|
def test_captions_save_output_path(self, mock_generate_srt, mock_save) -> None:
|
|
mock_generate_srt.return_value = "1\n00:00:00,000 --> 00:00:01,000\nHello\n"
|
|
mock_save.return_value = "/tmp/out.srt"
|
|
|
|
payload = {
|
|
"words": [{"word": "Hello", "start": 0.0, "end": 1.0}],
|
|
"format": "srt",
|
|
"output_path": "/tmp/out.srt",
|
|
}
|
|
res = self.client.post("/captions", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertEqual(res.json(), {"status": "ok", "output_path": "/tmp/out.srt"})
|
|
mock_save.assert_called_once()
|
|
|
|
def test_captions_unknown_format_returns_400(self) -> None:
|
|
payload = {
|
|
"words": [{"word": "Hello", "start": 0.0, "end": 1.0}],
|
|
"format": "txt",
|
|
}
|
|
res = self.client.post("/captions", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 400)
|
|
self.assertIn("Unknown format", res.json()["detail"])
|
|
|
|
@patch("routers.captions.generate_srt")
|
|
def test_captions_internal_error_returns_500(self, mock_generate_srt) -> None:
|
|
mock_generate_srt.side_effect = RuntimeError("caption-fail")
|
|
|
|
payload = {
|
|
"words": [{"word": "Hello", "start": 0.0, "end": 1.0}],
|
|
"format": "srt",
|
|
}
|
|
res = self.client.post("/captions", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 500)
|
|
self.assertEqual(res.json()["detail"], "caption-fail")
|
|
|
|
@patch("routers.audio.is_deepfilter_available")
|
|
@patch("routers.audio.clean_audio")
|
|
def test_audio_clean_contract(self, mock_clean_audio, mock_is_deepfilter_available) -> None:
|
|
mock_clean_audio.return_value = "/tmp/cleaned.wav"
|
|
mock_is_deepfilter_available.return_value = True
|
|
|
|
payload = {
|
|
"input_path": "/tmp/in.wav",
|
|
"output_path": "/tmp/cleaned.wav",
|
|
}
|
|
res = self.client.post("/audio/clean", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 200)
|
|
body = res.json()
|
|
self.assertEqual(body["status"], "ok")
|
|
self.assertEqual(body["output_path"], "/tmp/cleaned.wav")
|
|
self.assertEqual(body["engine"], "deepfilternet")
|
|
|
|
@patch("routers.audio.clean_audio")
|
|
def test_audio_clean_error_returns_500(self, mock_clean_audio) -> None:
|
|
mock_clean_audio.side_effect = RuntimeError("clean-fail")
|
|
|
|
payload = {
|
|
"input_path": "/tmp/in.wav",
|
|
"output_path": "/tmp/cleaned.wav",
|
|
}
|
|
res = self.client.post("/audio/clean", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 500)
|
|
self.assertEqual(res.json()["detail"], "clean-fail")
|
|
|
|
@patch("routers.audio.detect_silence_ranges")
|
|
def test_audio_detect_silence_contract(self, mock_detect_silence_ranges) -> None:
|
|
mock_detect_silence_ranges.return_value = [{"start": 1.2, "end": 2.1, "duration": 0.9}]
|
|
|
|
payload = {
|
|
"input_path": "/tmp/in.wav",
|
|
"min_silence_ms": 500,
|
|
"silence_db": -35.0,
|
|
}
|
|
res = self.client.post("/audio/detect-silence", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 200)
|
|
body = res.json()
|
|
self.assertEqual(body["status"], "ok")
|
|
self.assertEqual(body["count"], 1)
|
|
self.assertEqual(len(body["ranges"]), 1)
|
|
|
|
@patch("routers.audio.detect_silence_ranges")
|
|
def test_audio_detect_silence_error_returns_500(self, mock_detect_silence_ranges) -> None:
|
|
mock_detect_silence_ranges.side_effect = RuntimeError("silence-fail")
|
|
|
|
payload = {
|
|
"input_path": "/tmp/in.wav",
|
|
"min_silence_ms": 500,
|
|
"silence_db": -35.0,
|
|
}
|
|
res = self.client.post("/audio/detect-silence", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 500)
|
|
self.assertEqual(res.json()["detail"], "silence-fail")
|
|
|
|
@patch("routers.audio.is_deepfilter_available")
|
|
def test_audio_capabilities_contract(self, mock_is_deepfilter_available) -> None:
|
|
mock_is_deepfilter_available.return_value = False
|
|
|
|
res = self.client.get("/audio/capabilities")
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertEqual(res.json(), {"deepfilternet_available": False})
|
|
|
|
@patch("routers.export.export_stream_copy")
|
|
def test_export_fast_contract(self, mock_export_stream_copy) -> None:
|
|
mock_export_stream_copy.return_value = "/tmp/out.mp4"
|
|
|
|
payload = {
|
|
"input_path": "/tmp/in.mp4",
|
|
"output_path": "/tmp/out.mp4",
|
|
"keep_segments": [{"start": 0.0, "end": 2.0}],
|
|
"mode": "fast",
|
|
"captions": "none",
|
|
}
|
|
res = self.client.post("/export", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 200)
|
|
self.assertEqual(res.json(), {"status": "ok", "output_path": "/tmp/out.mp4"})
|
|
mock_export_stream_copy.assert_called_once()
|
|
|
|
@patch("routers.export.save_captions")
|
|
@patch("routers.export.generate_srt")
|
|
@patch("routers.export.export_stream_copy")
|
|
def test_export_sidecar_caption_contract(self, mock_export_stream_copy, mock_generate_srt, mock_save_captions) -> None:
|
|
mock_export_stream_copy.return_value = "/tmp/out.mp4"
|
|
mock_generate_srt.return_value = "1\n00:00:00,000 --> 00:00:01,000\nHello\n"
|
|
|
|
payload = {
|
|
"input_path": "/tmp/in.mp4",
|
|
"output_path": "/tmp/out.mp4",
|
|
"keep_segments": [{"start": 0.0, "end": 2.0}],
|
|
"mode": "fast",
|
|
"captions": "sidecar",
|
|
"words": [{"word": "Hello", "start": 0.0, "end": 1.0}],
|
|
"deleted_indices": [],
|
|
}
|
|
res = self.client.post("/export", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 200)
|
|
body = res.json()
|
|
self.assertEqual(body["status"], "ok")
|
|
self.assertEqual(body["output_path"], "/tmp/out.mp4")
|
|
self.assertEqual(body["srt_path"], "/tmp/out.srt")
|
|
mock_save_captions.assert_called_once()
|
|
|
|
def test_export_missing_segments_returns_400(self) -> None:
|
|
payload = {
|
|
"input_path": "/tmp/in.mp4",
|
|
"output_path": "/tmp/out.mp4",
|
|
"keep_segments": [],
|
|
"mode": "fast",
|
|
"captions": "none",
|
|
}
|
|
res = self.client.post("/export", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 400)
|
|
self.assertIn("No segments to export", res.json()["detail"])
|
|
|
|
@patch("routers.export.export_stream_copy")
|
|
def test_export_runtime_error_returns_500(self, mock_export_stream_copy) -> None:
|
|
mock_export_stream_copy.side_effect = RuntimeError("export-fail")
|
|
|
|
payload = {
|
|
"input_path": "/tmp/in.mp4",
|
|
"output_path": "/tmp/out.mp4",
|
|
"keep_segments": [{"start": 0.0, "end": 2.0}],
|
|
"mode": "fast",
|
|
"captions": "none",
|
|
}
|
|
res = self.client.post("/export", json=payload)
|
|
|
|
self.assertEqual(res.status_code, 500)
|
|
self.assertEqual(res.json()["detail"], "export-fail")
|
|
|
|
|
|
# Allow running this module directly as a script, in addition to discovery
# by a test runner.
if __name__ == "__main__":
    unittest.main()
|