forgot to add stuff

This commit is contained in:
2026-04-15 18:02:25 -06:00
parent 17874587a4
commit 7479acd3ee
10 changed files with 940 additions and 1 deletions

View File

View File

@ -0,0 +1,57 @@
import tempfile
import time
import unittest
from pathlib import Path
from backend.utils import cache as cache_utils
class CacheUtilsTests(unittest.TestCase):
def setUp(self) -> None:
self._tmp_dir = tempfile.TemporaryDirectory()
self._old_cache_dir = cache_utils.CACHE_DIR
cache_utils.CACHE_DIR = Path(self._tmp_dir.name) / "cache"
self._work_dir = Path(self._tmp_dir.name) / "work"
self._work_dir.mkdir(parents=True, exist_ok=True)
self._src_file = self._work_dir / "sample.txt"
self._src_file.write_text("hello", encoding="utf-8")
def tearDown(self) -> None:
cache_utils.CACHE_DIR = self._old_cache_dir
self._tmp_dir.cleanup()
def test_get_file_hash_returns_none_for_missing_file(self) -> None:
missing = self._work_dir / "missing.txt"
self.assertIsNone(cache_utils.get_file_hash(missing))
def test_save_and_load_round_trip(self) -> None:
payload = {"value": 123, "ok": True}
saved = cache_utils.save_to_cache(self._src_file, payload, model="m1", operation="transcribe")
self.assertTrue(saved)
loaded = cache_utils.load_from_cache(self._src_file, model="m1", operation="transcribe")
self.assertEqual(payload, loaded)
def test_load_from_cache_respects_max_age(self) -> None:
payload = {"value": 999}
self.assertTrue(cache_utils.save_to_cache(self._src_file, payload, operation="transcribe"))
time.sleep(0.02)
expired = cache_utils.load_from_cache(self._src_file, operation="transcribe", max_age=0.001)
self.assertIsNone(expired)
def test_clear_cache_deletes_files(self) -> None:
self.assertTrue(cache_utils.save_to_cache(self._src_file, {"a": 1}, operation="transcribe"))
self.assertTrue(cache_utils.save_to_cache(self._src_file, {"a": 2}, operation="summarize"))
deleted_count = cache_utils.clear_cache()
self.assertGreaterEqual(deleted_count, 1)
size_bytes, file_count = cache_utils.get_cache_size()
self.assertEqual(size_bytes, 0)
self.assertEqual(file_count, 0)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,451 @@
import unittest
from unittest.mock import patch
from pathlib import Path
from tempfile import TemporaryDirectory
import os
from types import SimpleNamespace
from fastapi.testclient import TestClient
from backend.main import app
from routers import audio as audio_router
class RouterContractTests(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
cls.client = TestClient(app)
def setUp(self) -> None:
audio_router._waveform_cache.clear()
def test_health_endpoint(self) -> None:
res = self.client.get("/health")
self.assertEqual(res.status_code, 200)
self.assertEqual(res.json(), {"status": "ok"})
def test_file_endpoint_full_content(self) -> None:
with TemporaryDirectory() as tmp:
file_path = Path(tmp) / "sample.wav"
file_path.write_bytes(b"abcdefghij")
res = self.client.get("/file", params={"path": str(file_path)})
self.assertEqual(res.status_code, 200)
self.assertEqual(res.content, b"abcdefghij")
self.assertEqual(res.headers.get("accept-ranges"), "bytes")
def test_file_endpoint_range_request(self) -> None:
with TemporaryDirectory() as tmp:
file_path = Path(tmp) / "sample.wav"
file_path.write_bytes(b"abcdefghij")
res = self.client.get(
"/file",
params={"path": str(file_path)},
headers={"Range": "bytes=2-5"},
)
self.assertEqual(res.status_code, 206)
self.assertEqual(res.content, b"cdef")
self.assertEqual(res.headers.get("content-range"), "bytes 2-5/10")
def test_file_endpoint_missing_file(self) -> None:
res = self.client.get("/file", params={"path": "/tmp/does-not-exist.wav"})
self.assertEqual(res.status_code, 404)
self.assertIn("File not found", res.json()["detail"])
@patch("routers.audio.subprocess.run")
def test_audio_waveform_cache_miss_then_hit(self, mock_subprocess_run) -> None:
with TemporaryDirectory() as tmp:
media_file = Path(tmp) / "input.mp4"
media_file.write_bytes(b"fake-media")
def fake_ffmpeg(cmd, capture_output, text):
out_path = Path(cmd[-1])
out_path.parent.mkdir(parents=True, exist_ok=True)
out_path.write_bytes(b"fake-wav")
return SimpleNamespace(returncode=0, stderr="")
mock_subprocess_run.side_effect = fake_ffmpeg
res1 = self.client.get("/audio/waveform", params={"path": str(media_file)})
self.assertEqual(res1.status_code, 200)
self.assertTrue(res1.headers.get("content-type", "").startswith("audio/wav"))
res2 = self.client.get("/audio/waveform", params={"path": str(media_file)})
self.assertEqual(res2.status_code, 200)
self.assertTrue(res2.headers.get("content-type", "").startswith("audio/wav"))
self.assertEqual(mock_subprocess_run.call_count, 1)
@patch("routers.audio.subprocess.run")
def test_audio_waveform_ffmpeg_failure_returns_500(self, mock_subprocess_run) -> None:
with TemporaryDirectory() as tmp:
media_file = Path(tmp) / "input.mp4"
media_file.write_bytes(b"fake-media")
mock_subprocess_run.return_value = SimpleNamespace(returncode=1, stderr="ffmpeg failed")
res = self.client.get("/audio/waveform", params={"path": str(media_file)})
self.assertEqual(res.status_code, 500)
self.assertIn("Failed to extract audio", res.json()["detail"])
@patch("routers.ai.detect_filler_words")
def test_ai_filler_removal_contract(self, mock_detect_filler_words) -> None:
mock_detect_filler_words.return_value = {
"wordIndices": [2, 5],
"fillerWords": [
{"index": 2, "word": "um", "reason": "filler"},
{"index": 5, "word": "uh", "reason": "filler"},
],
}
payload = {
"transcript": "Hello um world uh",
"words": [
{"index": 0, "word": "Hello"},
{"index": 1, "word": "um"},
{"index": 2, "word": "world"},
],
"provider": "ollama",
"model": "llama3",
}
res = self.client.post("/ai/filler-removal", json=payload)
self.assertEqual(res.status_code, 200)
self.assertIn("wordIndices", res.json())
mock_detect_filler_words.assert_called_once()
@patch("routers.ai.detect_filler_words")
def test_ai_filler_removal_error_returns_500(self, mock_detect_filler_words) -> None:
mock_detect_filler_words.side_effect = RuntimeError("ai-filler-fail")
payload = {
"transcript": "Hello world",
"words": [{"index": 0, "word": "Hello"}],
"provider": "ollama",
}
res = self.client.post("/ai/filler-removal", json=payload)
self.assertEqual(res.status_code, 500)
self.assertEqual(res.json()["detail"], "ai-filler-fail")
@patch("routers.ai.create_clip_suggestion")
def test_ai_create_clip_contract(self, mock_create_clip_suggestion) -> None:
mock_create_clip_suggestion.return_value = {
"title": "Best Moment",
"startWordIndex": 10,
"endWordIndex": 40,
"startTime": 12.3,
"endTime": 48.8,
"reason": "Strong hook",
}
payload = {
"transcript": "Long transcript...",
"words": [{"index": 0, "word": "hello"}],
"provider": "ollama",
"target_duration": 45,
}
res = self.client.post("/ai/create-clip", json=payload)
self.assertEqual(res.status_code, 200)
self.assertEqual(res.json()["title"], "Best Moment")
mock_create_clip_suggestion.assert_called_once()
@patch("routers.ai.create_clip_suggestion")
def test_ai_create_clip_error_returns_500(self, mock_create_clip_suggestion) -> None:
mock_create_clip_suggestion.side_effect = RuntimeError("ai-clip-fail")
payload = {
"transcript": "Hello world",
"words": [{"index": 0, "word": "hello"}],
"provider": "ollama",
}
res = self.client.post("/ai/create-clip", json=payload)
self.assertEqual(res.status_code, 500)
self.assertEqual(res.json()["detail"], "ai-clip-fail")
@patch("routers.ai.AIProvider.list_ollama_models")
def test_ai_ollama_models_contract(self, mock_list_ollama_models) -> None:
mock_list_ollama_models.return_value = ["llama3", "qwen2.5"]
res = self.client.get("/ai/ollama-models?base_url=http://localhost:11434")
self.assertEqual(res.status_code, 200)
self.assertEqual(res.json(), {"models": ["llama3", "qwen2.5"]})
mock_list_ollama_models.assert_called_once_with("http://localhost:11434")
@patch("routers.ai.AIProvider.list_ollama_models")
def test_ai_ollama_models_unhandled_error_returns_500(self, mock_list_ollama_models) -> None:
mock_list_ollama_models.side_effect = RuntimeError("ollama-unreachable")
local_client = TestClient(app, raise_server_exceptions=False)
res = local_client.get("/ai/ollama-models")
self.assertEqual(res.status_code, 500)
@patch("routers.transcribe.transcribe_audio")
def test_transcribe_success(self, mock_transcribe) -> None:
mock_transcribe.return_value = {"words": [], "segments": [], "language": "en"}
payload = {
"file_path": "/tmp/input.wav",
"model": "base",
"use_gpu": False,
"use_cache": True,
}
res = self.client.post("/transcribe", json=payload)
self.assertEqual(res.status_code, 200)
self.assertEqual(res.json(), {"words": [], "segments": [], "language": "en"})
mock_transcribe.assert_called_once()
@patch("routers.transcribe.diarize_and_label")
@patch("routers.transcribe.transcribe_audio")
def test_transcribe_with_diarization(self, mock_transcribe, mock_diarize) -> None:
mock_transcribe.return_value = {"words": [{"word": "hi", "start": 0.0, "end": 0.2}], "segments": []}
mock_diarize.return_value = {"words": [{"word": "hi", "start": 0.0, "end": 0.2, "speaker": "SPEAKER_00"}], "segments": []}
payload = {
"file_path": "/tmp/input.wav",
"model": "base",
"diarize": True,
"hf_token": "hf_xxx",
"num_speakers": 2,
}
res = self.client.post("/transcribe", json=payload)
self.assertEqual(res.status_code, 200)
self.assertIn("words", res.json())
mock_transcribe.assert_called_once()
mock_diarize.assert_called_once()
@patch("routers.transcribe.transcribe_audio")
def test_transcribe_file_not_found_returns_404(self, mock_transcribe) -> None:
mock_transcribe.side_effect = FileNotFoundError("missing")
payload = {
"file_path": "/tmp/missing.wav",
"model": "base",
}
res = self.client.post("/transcribe", json=payload)
self.assertEqual(res.status_code, 404)
self.assertIn("File not found", res.json()["detail"])
@patch("routers.transcribe.transcribe_audio")
def test_transcribe_runtime_failure_returns_500(self, mock_transcribe) -> None:
mock_transcribe.side_effect = RuntimeError("boom")
payload = {
"file_path": "/tmp/in.wav",
"model": "base",
}
res = self.client.post("/transcribe", json=payload)
self.assertEqual(res.status_code, 500)
self.assertEqual(res.json()["detail"], "boom")
@patch("routers.captions.generate_srt")
def test_captions_plain_response(self, mock_generate_srt) -> None:
mock_generate_srt.return_value = "1\n00:00:00,000 --> 00:00:01,000\nHello\n"
payload = {
"words": [{"word": "Hello", "start": 0.0, "end": 1.0}],
"format": "srt",
}
res = self.client.post("/captions", json=payload)
self.assertEqual(res.status_code, 200)
self.assertIn("Hello", res.text)
mock_generate_srt.assert_called_once()
@patch("routers.captions.save_captions")
@patch("routers.captions.generate_srt")
def test_captions_save_output_path(self, mock_generate_srt, mock_save) -> None:
mock_generate_srt.return_value = "1\n00:00:00,000 --> 00:00:01,000\nHello\n"
mock_save.return_value = "/tmp/out.srt"
payload = {
"words": [{"word": "Hello", "start": 0.0, "end": 1.0}],
"format": "srt",
"output_path": "/tmp/out.srt",
}
res = self.client.post("/captions", json=payload)
self.assertEqual(res.status_code, 200)
self.assertEqual(res.json(), {"status": "ok", "output_path": "/tmp/out.srt"})
mock_save.assert_called_once()
def test_captions_unknown_format_returns_400(self) -> None:
payload = {
"words": [{"word": "Hello", "start": 0.0, "end": 1.0}],
"format": "txt",
}
res = self.client.post("/captions", json=payload)
self.assertEqual(res.status_code, 400)
self.assertIn("Unknown format", res.json()["detail"])
@patch("routers.captions.generate_srt")
def test_captions_internal_error_returns_500(self, mock_generate_srt) -> None:
mock_generate_srt.side_effect = RuntimeError("caption-fail")
payload = {
"words": [{"word": "Hello", "start": 0.0, "end": 1.0}],
"format": "srt",
}
res = self.client.post("/captions", json=payload)
self.assertEqual(res.status_code, 500)
self.assertEqual(res.json()["detail"], "caption-fail")
@patch("routers.audio.is_deepfilter_available")
@patch("routers.audio.clean_audio")
def test_audio_clean_contract(self, mock_clean_audio, mock_is_deepfilter_available) -> None:
mock_clean_audio.return_value = "/tmp/cleaned.wav"
mock_is_deepfilter_available.return_value = True
payload = {
"input_path": "/tmp/in.wav",
"output_path": "/tmp/cleaned.wav",
}
res = self.client.post("/audio/clean", json=payload)
self.assertEqual(res.status_code, 200)
body = res.json()
self.assertEqual(body["status"], "ok")
self.assertEqual(body["output_path"], "/tmp/cleaned.wav")
self.assertEqual(body["engine"], "deepfilternet")
@patch("routers.audio.clean_audio")
def test_audio_clean_error_returns_500(self, mock_clean_audio) -> None:
mock_clean_audio.side_effect = RuntimeError("clean-fail")
payload = {
"input_path": "/tmp/in.wav",
"output_path": "/tmp/cleaned.wav",
}
res = self.client.post("/audio/clean", json=payload)
self.assertEqual(res.status_code, 500)
self.assertEqual(res.json()["detail"], "clean-fail")
@patch("routers.audio.detect_silence_ranges")
def test_audio_detect_silence_contract(self, mock_detect_silence_ranges) -> None:
mock_detect_silence_ranges.return_value = [{"start": 1.2, "end": 2.1, "duration": 0.9}]
payload = {
"input_path": "/tmp/in.wav",
"min_silence_ms": 500,
"silence_db": -35.0,
}
res = self.client.post("/audio/detect-silence", json=payload)
self.assertEqual(res.status_code, 200)
body = res.json()
self.assertEqual(body["status"], "ok")
self.assertEqual(body["count"], 1)
self.assertEqual(len(body["ranges"]), 1)
@patch("routers.audio.detect_silence_ranges")
def test_audio_detect_silence_error_returns_500(self, mock_detect_silence_ranges) -> None:
mock_detect_silence_ranges.side_effect = RuntimeError("silence-fail")
payload = {
"input_path": "/tmp/in.wav",
"min_silence_ms": 500,
"silence_db": -35.0,
}
res = self.client.post("/audio/detect-silence", json=payload)
self.assertEqual(res.status_code, 500)
self.assertEqual(res.json()["detail"], "silence-fail")
@patch("routers.audio.is_deepfilter_available")
def test_audio_capabilities_contract(self, mock_is_deepfilter_available) -> None:
mock_is_deepfilter_available.return_value = False
res = self.client.get("/audio/capabilities")
self.assertEqual(res.status_code, 200)
self.assertEqual(res.json(), {"deepfilternet_available": False})
@patch("routers.export.export_stream_copy")
def test_export_fast_contract(self, mock_export_stream_copy) -> None:
mock_export_stream_copy.return_value = "/tmp/out.mp4"
payload = {
"input_path": "/tmp/in.mp4",
"output_path": "/tmp/out.mp4",
"keep_segments": [{"start": 0.0, "end": 2.0}],
"mode": "fast",
"captions": "none",
}
res = self.client.post("/export", json=payload)
self.assertEqual(res.status_code, 200)
self.assertEqual(res.json(), {"status": "ok", "output_path": "/tmp/out.mp4"})
mock_export_stream_copy.assert_called_once()
@patch("routers.export.save_captions")
@patch("routers.export.generate_srt")
@patch("routers.export.export_stream_copy")
def test_export_sidecar_caption_contract(self, mock_export_stream_copy, mock_generate_srt, mock_save_captions) -> None:
mock_export_stream_copy.return_value = "/tmp/out.mp4"
mock_generate_srt.return_value = "1\n00:00:00,000 --> 00:00:01,000\nHello\n"
payload = {
"input_path": "/tmp/in.mp4",
"output_path": "/tmp/out.mp4",
"keep_segments": [{"start": 0.0, "end": 2.0}],
"mode": "fast",
"captions": "sidecar",
"words": [{"word": "Hello", "start": 0.0, "end": 1.0}],
"deleted_indices": [],
}
res = self.client.post("/export", json=payload)
self.assertEqual(res.status_code, 200)
body = res.json()
self.assertEqual(body["status"], "ok")
self.assertEqual(body["output_path"], "/tmp/out.mp4")
self.assertEqual(body["srt_path"], "/tmp/out.srt")
mock_save_captions.assert_called_once()
def test_export_missing_segments_returns_400(self) -> None:
payload = {
"input_path": "/tmp/in.mp4",
"output_path": "/tmp/out.mp4",
"keep_segments": [],
"mode": "fast",
"captions": "none",
}
res = self.client.post("/export", json=payload)
self.assertEqual(res.status_code, 400)
self.assertIn("No segments to export", res.json()["detail"])
@patch("routers.export.export_stream_copy")
def test_export_runtime_error_returns_500(self, mock_export_stream_copy) -> None:
mock_export_stream_copy.side_effect = RuntimeError("export-fail")
payload = {
"input_path": "/tmp/in.mp4",
"output_path": "/tmp/out.mp4",
"keep_segments": [{"start": 0.0, "end": 2.0}],
"mode": "fast",
"captions": "none",
}
res = self.client.post("/export", json=payload)
self.assertEqual(res.status_code, 500)
self.assertEqual(res.json()["detail"], "export-fail")
if __name__ == "__main__":
unittest.main()