forgot to add stuff
This commit is contained in:
451
backend/tests/test_router_contracts.py
Normal file
451
backend/tests/test_router_contracts.py
Normal file
@ -0,0 +1,451 @@
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from backend.main import app
|
||||
from routers import audio as audio_router
|
||||
|
||||
|
||||
class RouterContractTests(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.client = TestClient(app)
|
||||
|
||||
def setUp(self) -> None:
|
||||
audio_router._waveform_cache.clear()
|
||||
|
||||
def test_health_endpoint(self) -> None:
|
||||
res = self.client.get("/health")
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.json(), {"status": "ok"})
|
||||
|
||||
def test_file_endpoint_full_content(self) -> None:
|
||||
with TemporaryDirectory() as tmp:
|
||||
file_path = Path(tmp) / "sample.wav"
|
||||
file_path.write_bytes(b"abcdefghij")
|
||||
|
||||
res = self.client.get("/file", params={"path": str(file_path)})
|
||||
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.content, b"abcdefghij")
|
||||
self.assertEqual(res.headers.get("accept-ranges"), "bytes")
|
||||
|
||||
def test_file_endpoint_range_request(self) -> None:
|
||||
with TemporaryDirectory() as tmp:
|
||||
file_path = Path(tmp) / "sample.wav"
|
||||
file_path.write_bytes(b"abcdefghij")
|
||||
|
||||
res = self.client.get(
|
||||
"/file",
|
||||
params={"path": str(file_path)},
|
||||
headers={"Range": "bytes=2-5"},
|
||||
)
|
||||
|
||||
self.assertEqual(res.status_code, 206)
|
||||
self.assertEqual(res.content, b"cdef")
|
||||
self.assertEqual(res.headers.get("content-range"), "bytes 2-5/10")
|
||||
|
||||
def test_file_endpoint_missing_file(self) -> None:
|
||||
res = self.client.get("/file", params={"path": "/tmp/does-not-exist.wav"})
|
||||
|
||||
self.assertEqual(res.status_code, 404)
|
||||
self.assertIn("File not found", res.json()["detail"])
|
||||
|
||||
@patch("routers.audio.subprocess.run")
|
||||
def test_audio_waveform_cache_miss_then_hit(self, mock_subprocess_run) -> None:
|
||||
with TemporaryDirectory() as tmp:
|
||||
media_file = Path(tmp) / "input.mp4"
|
||||
media_file.write_bytes(b"fake-media")
|
||||
|
||||
def fake_ffmpeg(cmd, capture_output, text):
|
||||
out_path = Path(cmd[-1])
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
out_path.write_bytes(b"fake-wav")
|
||||
return SimpleNamespace(returncode=0, stderr="")
|
||||
|
||||
mock_subprocess_run.side_effect = fake_ffmpeg
|
||||
|
||||
res1 = self.client.get("/audio/waveform", params={"path": str(media_file)})
|
||||
self.assertEqual(res1.status_code, 200)
|
||||
self.assertTrue(res1.headers.get("content-type", "").startswith("audio/wav"))
|
||||
|
||||
res2 = self.client.get("/audio/waveform", params={"path": str(media_file)})
|
||||
self.assertEqual(res2.status_code, 200)
|
||||
self.assertTrue(res2.headers.get("content-type", "").startswith("audio/wav"))
|
||||
|
||||
self.assertEqual(mock_subprocess_run.call_count, 1)
|
||||
|
||||
@patch("routers.audio.subprocess.run")
|
||||
def test_audio_waveform_ffmpeg_failure_returns_500(self, mock_subprocess_run) -> None:
|
||||
with TemporaryDirectory() as tmp:
|
||||
media_file = Path(tmp) / "input.mp4"
|
||||
media_file.write_bytes(b"fake-media")
|
||||
|
||||
mock_subprocess_run.return_value = SimpleNamespace(returncode=1, stderr="ffmpeg failed")
|
||||
|
||||
res = self.client.get("/audio/waveform", params={"path": str(media_file)})
|
||||
|
||||
self.assertEqual(res.status_code, 500)
|
||||
self.assertIn("Failed to extract audio", res.json()["detail"])
|
||||
|
||||
@patch("routers.ai.detect_filler_words")
|
||||
def test_ai_filler_removal_contract(self, mock_detect_filler_words) -> None:
|
||||
mock_detect_filler_words.return_value = {
|
||||
"wordIndices": [2, 5],
|
||||
"fillerWords": [
|
||||
{"index": 2, "word": "um", "reason": "filler"},
|
||||
{"index": 5, "word": "uh", "reason": "filler"},
|
||||
],
|
||||
}
|
||||
|
||||
payload = {
|
||||
"transcript": "Hello um world uh",
|
||||
"words": [
|
||||
{"index": 0, "word": "Hello"},
|
||||
{"index": 1, "word": "um"},
|
||||
{"index": 2, "word": "world"},
|
||||
],
|
||||
"provider": "ollama",
|
||||
"model": "llama3",
|
||||
}
|
||||
res = self.client.post("/ai/filler-removal", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertIn("wordIndices", res.json())
|
||||
mock_detect_filler_words.assert_called_once()
|
||||
|
||||
@patch("routers.ai.detect_filler_words")
|
||||
def test_ai_filler_removal_error_returns_500(self, mock_detect_filler_words) -> None:
|
||||
mock_detect_filler_words.side_effect = RuntimeError("ai-filler-fail")
|
||||
|
||||
payload = {
|
||||
"transcript": "Hello world",
|
||||
"words": [{"index": 0, "word": "Hello"}],
|
||||
"provider": "ollama",
|
||||
}
|
||||
res = self.client.post("/ai/filler-removal", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 500)
|
||||
self.assertEqual(res.json()["detail"], "ai-filler-fail")
|
||||
|
||||
@patch("routers.ai.create_clip_suggestion")
|
||||
def test_ai_create_clip_contract(self, mock_create_clip_suggestion) -> None:
|
||||
mock_create_clip_suggestion.return_value = {
|
||||
"title": "Best Moment",
|
||||
"startWordIndex": 10,
|
||||
"endWordIndex": 40,
|
||||
"startTime": 12.3,
|
||||
"endTime": 48.8,
|
||||
"reason": "Strong hook",
|
||||
}
|
||||
|
||||
payload = {
|
||||
"transcript": "Long transcript...",
|
||||
"words": [{"index": 0, "word": "hello"}],
|
||||
"provider": "ollama",
|
||||
"target_duration": 45,
|
||||
}
|
||||
res = self.client.post("/ai/create-clip", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.json()["title"], "Best Moment")
|
||||
mock_create_clip_suggestion.assert_called_once()
|
||||
|
||||
@patch("routers.ai.create_clip_suggestion")
|
||||
def test_ai_create_clip_error_returns_500(self, mock_create_clip_suggestion) -> None:
|
||||
mock_create_clip_suggestion.side_effect = RuntimeError("ai-clip-fail")
|
||||
|
||||
payload = {
|
||||
"transcript": "Hello world",
|
||||
"words": [{"index": 0, "word": "hello"}],
|
||||
"provider": "ollama",
|
||||
}
|
||||
res = self.client.post("/ai/create-clip", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 500)
|
||||
self.assertEqual(res.json()["detail"], "ai-clip-fail")
|
||||
|
||||
@patch("routers.ai.AIProvider.list_ollama_models")
|
||||
def test_ai_ollama_models_contract(self, mock_list_ollama_models) -> None:
|
||||
mock_list_ollama_models.return_value = ["llama3", "qwen2.5"]
|
||||
|
||||
res = self.client.get("/ai/ollama-models?base_url=http://localhost:11434")
|
||||
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.json(), {"models": ["llama3", "qwen2.5"]})
|
||||
mock_list_ollama_models.assert_called_once_with("http://localhost:11434")
|
||||
|
||||
@patch("routers.ai.AIProvider.list_ollama_models")
|
||||
def test_ai_ollama_models_unhandled_error_returns_500(self, mock_list_ollama_models) -> None:
|
||||
mock_list_ollama_models.side_effect = RuntimeError("ollama-unreachable")
|
||||
|
||||
local_client = TestClient(app, raise_server_exceptions=False)
|
||||
res = local_client.get("/ai/ollama-models")
|
||||
|
||||
self.assertEqual(res.status_code, 500)
|
||||
|
||||
@patch("routers.transcribe.transcribe_audio")
|
||||
def test_transcribe_success(self, mock_transcribe) -> None:
|
||||
mock_transcribe.return_value = {"words": [], "segments": [], "language": "en"}
|
||||
|
||||
payload = {
|
||||
"file_path": "/tmp/input.wav",
|
||||
"model": "base",
|
||||
"use_gpu": False,
|
||||
"use_cache": True,
|
||||
}
|
||||
res = self.client.post("/transcribe", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.json(), {"words": [], "segments": [], "language": "en"})
|
||||
mock_transcribe.assert_called_once()
|
||||
|
||||
@patch("routers.transcribe.diarize_and_label")
|
||||
@patch("routers.transcribe.transcribe_audio")
|
||||
def test_transcribe_with_diarization(self, mock_transcribe, mock_diarize) -> None:
|
||||
mock_transcribe.return_value = {"words": [{"word": "hi", "start": 0.0, "end": 0.2}], "segments": []}
|
||||
mock_diarize.return_value = {"words": [{"word": "hi", "start": 0.0, "end": 0.2, "speaker": "SPEAKER_00"}], "segments": []}
|
||||
|
||||
payload = {
|
||||
"file_path": "/tmp/input.wav",
|
||||
"model": "base",
|
||||
"diarize": True,
|
||||
"hf_token": "hf_xxx",
|
||||
"num_speakers": 2,
|
||||
}
|
||||
res = self.client.post("/transcribe", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertIn("words", res.json())
|
||||
mock_transcribe.assert_called_once()
|
||||
mock_diarize.assert_called_once()
|
||||
|
||||
@patch("routers.transcribe.transcribe_audio")
|
||||
def test_transcribe_file_not_found_returns_404(self, mock_transcribe) -> None:
|
||||
mock_transcribe.side_effect = FileNotFoundError("missing")
|
||||
|
||||
payload = {
|
||||
"file_path": "/tmp/missing.wav",
|
||||
"model": "base",
|
||||
}
|
||||
res = self.client.post("/transcribe", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 404)
|
||||
self.assertIn("File not found", res.json()["detail"])
|
||||
|
||||
@patch("routers.transcribe.transcribe_audio")
|
||||
def test_transcribe_runtime_failure_returns_500(self, mock_transcribe) -> None:
|
||||
mock_transcribe.side_effect = RuntimeError("boom")
|
||||
|
||||
payload = {
|
||||
"file_path": "/tmp/in.wav",
|
||||
"model": "base",
|
||||
}
|
||||
res = self.client.post("/transcribe", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 500)
|
||||
self.assertEqual(res.json()["detail"], "boom")
|
||||
|
||||
@patch("routers.captions.generate_srt")
|
||||
def test_captions_plain_response(self, mock_generate_srt) -> None:
|
||||
mock_generate_srt.return_value = "1\n00:00:00,000 --> 00:00:01,000\nHello\n"
|
||||
|
||||
payload = {
|
||||
"words": [{"word": "Hello", "start": 0.0, "end": 1.0}],
|
||||
"format": "srt",
|
||||
}
|
||||
res = self.client.post("/captions", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertIn("Hello", res.text)
|
||||
mock_generate_srt.assert_called_once()
|
||||
|
||||
@patch("routers.captions.save_captions")
|
||||
@patch("routers.captions.generate_srt")
|
||||
def test_captions_save_output_path(self, mock_generate_srt, mock_save) -> None:
|
||||
mock_generate_srt.return_value = "1\n00:00:00,000 --> 00:00:01,000\nHello\n"
|
||||
mock_save.return_value = "/tmp/out.srt"
|
||||
|
||||
payload = {
|
||||
"words": [{"word": "Hello", "start": 0.0, "end": 1.0}],
|
||||
"format": "srt",
|
||||
"output_path": "/tmp/out.srt",
|
||||
}
|
||||
res = self.client.post("/captions", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.json(), {"status": "ok", "output_path": "/tmp/out.srt"})
|
||||
mock_save.assert_called_once()
|
||||
|
||||
def test_captions_unknown_format_returns_400(self) -> None:
|
||||
payload = {
|
||||
"words": [{"word": "Hello", "start": 0.0, "end": 1.0}],
|
||||
"format": "txt",
|
||||
}
|
||||
res = self.client.post("/captions", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 400)
|
||||
self.assertIn("Unknown format", res.json()["detail"])
|
||||
|
||||
@patch("routers.captions.generate_srt")
|
||||
def test_captions_internal_error_returns_500(self, mock_generate_srt) -> None:
|
||||
mock_generate_srt.side_effect = RuntimeError("caption-fail")
|
||||
|
||||
payload = {
|
||||
"words": [{"word": "Hello", "start": 0.0, "end": 1.0}],
|
||||
"format": "srt",
|
||||
}
|
||||
res = self.client.post("/captions", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 500)
|
||||
self.assertEqual(res.json()["detail"], "caption-fail")
|
||||
|
||||
@patch("routers.audio.is_deepfilter_available")
|
||||
@patch("routers.audio.clean_audio")
|
||||
def test_audio_clean_contract(self, mock_clean_audio, mock_is_deepfilter_available) -> None:
|
||||
mock_clean_audio.return_value = "/tmp/cleaned.wav"
|
||||
mock_is_deepfilter_available.return_value = True
|
||||
|
||||
payload = {
|
||||
"input_path": "/tmp/in.wav",
|
||||
"output_path": "/tmp/cleaned.wav",
|
||||
}
|
||||
res = self.client.post("/audio/clean", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 200)
|
||||
body = res.json()
|
||||
self.assertEqual(body["status"], "ok")
|
||||
self.assertEqual(body["output_path"], "/tmp/cleaned.wav")
|
||||
self.assertEqual(body["engine"], "deepfilternet")
|
||||
|
||||
@patch("routers.audio.clean_audio")
|
||||
def test_audio_clean_error_returns_500(self, mock_clean_audio) -> None:
|
||||
mock_clean_audio.side_effect = RuntimeError("clean-fail")
|
||||
|
||||
payload = {
|
||||
"input_path": "/tmp/in.wav",
|
||||
"output_path": "/tmp/cleaned.wav",
|
||||
}
|
||||
res = self.client.post("/audio/clean", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 500)
|
||||
self.assertEqual(res.json()["detail"], "clean-fail")
|
||||
|
||||
@patch("routers.audio.detect_silence_ranges")
|
||||
def test_audio_detect_silence_contract(self, mock_detect_silence_ranges) -> None:
|
||||
mock_detect_silence_ranges.return_value = [{"start": 1.2, "end": 2.1, "duration": 0.9}]
|
||||
|
||||
payload = {
|
||||
"input_path": "/tmp/in.wav",
|
||||
"min_silence_ms": 500,
|
||||
"silence_db": -35.0,
|
||||
}
|
||||
res = self.client.post("/audio/detect-silence", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 200)
|
||||
body = res.json()
|
||||
self.assertEqual(body["status"], "ok")
|
||||
self.assertEqual(body["count"], 1)
|
||||
self.assertEqual(len(body["ranges"]), 1)
|
||||
|
||||
@patch("routers.audio.detect_silence_ranges")
|
||||
def test_audio_detect_silence_error_returns_500(self, mock_detect_silence_ranges) -> None:
|
||||
mock_detect_silence_ranges.side_effect = RuntimeError("silence-fail")
|
||||
|
||||
payload = {
|
||||
"input_path": "/tmp/in.wav",
|
||||
"min_silence_ms": 500,
|
||||
"silence_db": -35.0,
|
||||
}
|
||||
res = self.client.post("/audio/detect-silence", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 500)
|
||||
self.assertEqual(res.json()["detail"], "silence-fail")
|
||||
|
||||
@patch("routers.audio.is_deepfilter_available")
|
||||
def test_audio_capabilities_contract(self, mock_is_deepfilter_available) -> None:
|
||||
mock_is_deepfilter_available.return_value = False
|
||||
|
||||
res = self.client.get("/audio/capabilities")
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.json(), {"deepfilternet_available": False})
|
||||
|
||||
@patch("routers.export.export_stream_copy")
|
||||
def test_export_fast_contract(self, mock_export_stream_copy) -> None:
|
||||
mock_export_stream_copy.return_value = "/tmp/out.mp4"
|
||||
|
||||
payload = {
|
||||
"input_path": "/tmp/in.mp4",
|
||||
"output_path": "/tmp/out.mp4",
|
||||
"keep_segments": [{"start": 0.0, "end": 2.0}],
|
||||
"mode": "fast",
|
||||
"captions": "none",
|
||||
}
|
||||
res = self.client.post("/export", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 200)
|
||||
self.assertEqual(res.json(), {"status": "ok", "output_path": "/tmp/out.mp4"})
|
||||
mock_export_stream_copy.assert_called_once()
|
||||
|
||||
@patch("routers.export.save_captions")
|
||||
@patch("routers.export.generate_srt")
|
||||
@patch("routers.export.export_stream_copy")
|
||||
def test_export_sidecar_caption_contract(self, mock_export_stream_copy, mock_generate_srt, mock_save_captions) -> None:
|
||||
mock_export_stream_copy.return_value = "/tmp/out.mp4"
|
||||
mock_generate_srt.return_value = "1\n00:00:00,000 --> 00:00:01,000\nHello\n"
|
||||
|
||||
payload = {
|
||||
"input_path": "/tmp/in.mp4",
|
||||
"output_path": "/tmp/out.mp4",
|
||||
"keep_segments": [{"start": 0.0, "end": 2.0}],
|
||||
"mode": "fast",
|
||||
"captions": "sidecar",
|
||||
"words": [{"word": "Hello", "start": 0.0, "end": 1.0}],
|
||||
"deleted_indices": [],
|
||||
}
|
||||
res = self.client.post("/export", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 200)
|
||||
body = res.json()
|
||||
self.assertEqual(body["status"], "ok")
|
||||
self.assertEqual(body["output_path"], "/tmp/out.mp4")
|
||||
self.assertEqual(body["srt_path"], "/tmp/out.srt")
|
||||
mock_save_captions.assert_called_once()
|
||||
|
||||
def test_export_missing_segments_returns_400(self) -> None:
|
||||
payload = {
|
||||
"input_path": "/tmp/in.mp4",
|
||||
"output_path": "/tmp/out.mp4",
|
||||
"keep_segments": [],
|
||||
"mode": "fast",
|
||||
"captions": "none",
|
||||
}
|
||||
res = self.client.post("/export", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 400)
|
||||
self.assertIn("No segments to export", res.json()["detail"])
|
||||
|
||||
@patch("routers.export.export_stream_copy")
|
||||
def test_export_runtime_error_returns_500(self, mock_export_stream_copy) -> None:
|
||||
mock_export_stream_copy.side_effect = RuntimeError("export-fail")
|
||||
|
||||
payload = {
|
||||
"input_path": "/tmp/in.mp4",
|
||||
"output_path": "/tmp/out.mp4",
|
||||
"keep_segments": [{"start": 0.0, "end": 2.0}],
|
||||
"mode": "fast",
|
||||
"captions": "none",
|
||||
}
|
||||
res = self.client.post("/export", json=payload)
|
||||
|
||||
self.assertEqual(res.status_code, 500)
|
||||
self.assertEqual(res.json()["detail"], "export-fail")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user