Initial CutScript release - Open-source AI-powered text-based video editor
CutScript is a local-first, Descript-like video editor where you edit video by editing text. Delete a word from the transcript and it's cut from the video. Features: - Word-level transcription with WhisperX - Text-based video editing with undo/redo - AI filler word removal (Ollama/OpenAI/Claude) - AI clip creation for shorts - Waveform timeline with virtualized transcript - FFmpeg stream-copy (fast) and re-encode (4K) export - Caption burn-in and sidecar SRT generation - Studio Sound audio enhancement (DeepFilterNet) - Keyboard shortcuts (J/K/L, Space, Delete, Ctrl+Z/S/E) - Encrypted API key storage - Project save/load (.aive files) Architecture: - Electron + React + Tailwind (frontend) - FastAPI + Python (backend) - WhisperX for transcription - FFmpeg for video processing - Multi-provider AI support Performance optimizations: - RAF-throttled time updates - Zustand selectors for granular subscriptions - Dual-canvas waveform rendering - Virtualized transcript with react-virtuoso Built on top of DataAnts-AI/VideoTranscriber, completely rewritten as a desktop application. License: MIT
This commit is contained in:
117
backend/main.py
Normal file
117
backend/main.py
Normal file
@ -0,0 +1,117 @@
|
||||
import logging
|
||||
import os
|
||||
import stat
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import FastAPI, Query, Request, HTTPException
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from routers import transcribe, export, ai, captions, audio
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan context: log a marker at startup and shutdown.

    The body before ``yield`` runs once at application startup, the body
    after it once at shutdown (not guaranteed on hard crashes).
    """
    logger.info("AI Video Editor backend starting up")
    yield  # the application serves requests while suspended here
    logger.info("AI Video Editor backend shutting down")
|
||||
|
||||
|
||||
# Application instance; `lifespan` above handles startup/shutdown logging.
app = FastAPI(
    title="AI Video Editor Backend",
    version="0.1.0",
    lifespan=lifespan,
)

# NOTE(review): allow_origins=["*"] together with allow_credentials=True is
# very permissive, and per the CORS spec a wildcard origin cannot be combined
# with credentials — browsers will reject such responses. Presumably fine for
# a loopback-only desktop backend; confirm this server is never bound to a
# non-local interface.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    # Exposed so the renderer can read range/length metadata when seeking video.
    expose_headers=["Content-Range", "Accept-Ranges", "Content-Length"],
)

# Feature routers: transcription, export, AI tools, captions, audio cleanup.
app.include_router(transcribe.router)
app.include_router(export.router)
app.include_router(ai.router)
app.include_router(captions.router)
app.include_router(audio.router)
|
||||
|
||||
|
||||
# Lowercase file extension -> MIME type, used by the /file streaming endpoint.
# Unknown extensions fall back to application/octet-stream there.
MIME_MAP = {
    ".mp4": "video/mp4",
    ".mkv": "video/x-matroska",
    ".mov": "video/quicktime",
    ".avi": "video/x-msvideo",
    ".webm": "video/webm",
    ".m4a": "audio/mp4",
    ".wav": "audio/wav",
    ".mp3": "audio/mpeg",
    ".flac": "audio/flac",
}
|
||||
|
||||
|
||||
@app.get("/file")
async def serve_local_file(request: Request, path: str = Query(...)):
    """Stream a local file with HTTP Range support (required for video seeking).

    Handles a single range spec per request: ``bytes=a-b``, ``bytes=a-`` and
    the suffix form ``bytes=-n`` (last *n* bytes, RFC 7233). Responds 206 for
    range requests, 200 for whole-file requests, 400 for malformed Range
    headers and 416 for unsatisfiable ranges.

    SECURITY NOTE(review): ``path`` is an arbitrary client-supplied filesystem
    path, so any readable file can be fetched. Acceptable for a loopback-only
    desktop backend, but this endpoint must never be exposed to untrusted
    networks — confirm the server binds to localhost only.
    """
    file_path = Path(path)
    if not file_path.is_file():
        raise HTTPException(status_code=404, detail=f"File not found: {path}")

    file_size = file_path.stat().st_size
    content_type = MIME_MAP.get(file_path.suffix.lower(), "application/octet-stream")

    range_header = request.headers.get("range")
    if range_header:
        # Only the first range of a (rare) multi-range request is honored.
        range_spec = range_header.replace("bytes=", "").split(",")[0].strip()
        try:
            range_start_str, _, range_end_str = range_spec.partition("-")
            if range_start_str:
                range_start = int(range_start_str)
                range_end = int(range_end_str) if range_end_str else file_size - 1
            else:
                # Suffix form "bytes=-n": the LAST n bytes of the file
                # (previously mis-parsed as bytes 0..n).
                suffix_length = int(range_end_str)
                range_start = max(file_size - suffix_length, 0)
                range_end = file_size - 1
        except ValueError:
            # Non-numeric or empty spec: reject instead of crashing with a 500.
            raise HTTPException(status_code=400, detail=f"Malformed Range header: {range_header}")

        range_end = min(range_end, file_size - 1)
        if range_start > range_end or range_start >= file_size:
            # RFC 7233: unsatisfiable range -> 416, advertising the full length.
            raise HTTPException(
                status_code=416,
                detail="Requested range not satisfiable",
                headers={"Content-Range": f"bytes */{file_size}"},
            )
        content_length = range_end - range_start + 1

        def iter_range():
            # Stream the requested window in 64 KiB chunks.
            with open(file_path, "rb") as f:
                f.seek(range_start)
                remaining = content_length
                while remaining > 0:
                    chunk = f.read(min(65536, remaining))
                    if not chunk:
                        break
                    remaining -= len(chunk)
                    yield chunk

        return StreamingResponse(
            iter_range(),
            status_code=206,
            media_type=content_type,
            headers={
                "Content-Range": f"bytes {range_start}-{range_end}/{file_size}",
                "Accept-Ranges": "bytes",
                "Content-Length": str(content_length),
            },
        )

    def iter_file():
        # No Range header: stream the entire file.
        with open(file_path, "rb") as f:
            while chunk := f.read(65536):
                yield chunk

    return StreamingResponse(
        iter_file(),
        media_type=content_type,
        headers={
            "Accept-Ranges": "bytes",
            "Content-Length": str(file_size),
        },
    )
|
||||
|
||||
|
||||
@app.get("/health")
async def health():
    """Simple liveness probe; always reports an OK status."""
    payload = {"status": "ok"}
    return payload
|
||||
33
backend/requirements.txt
Normal file
33
backend/requirements.txt
Normal file
@ -0,0 +1,33 @@
|
||||
# FastAPI backend
|
||||
fastapi>=0.115.0
|
||||
uvicorn[standard]>=0.32.0
|
||||
websockets>=14.0
|
||||
python-multipart>=0.0.12
|
||||
|
||||
# Transcription (WhisperX for word-level alignment)
|
||||
whisperx>=3.1.0
|
||||
faster-whisper>=1.0.0
|
||||
|
||||
# Audio / Video processing
|
||||
moviepy>=1.0.3
|
||||
ffmpeg-python>=0.2.0
|
||||
soundfile>=0.10.3
|
||||
|
||||
# ML / GPU
|
||||
torch>=2.0.0
|
||||
torchaudio>=2.0.0
|
||||
numpy>=1.24.0
|
||||
|
||||
# Speaker diarization
|
||||
pyannote.audio>=3.1.1
|
||||
|
||||
# AI providers
|
||||
openai>=1.50.0
|
||||
anthropic>=0.39.0
|
||||
requests>=2.28.0
|
||||
|
||||
# Audio cleanup
|
||||
deepfilternet>=0.5.0
|
||||
|
||||
# Utilities
|
||||
pydantic>=2.0.0
|
||||
0
backend/routers/__init__.py
Normal file
0
backend/routers/__init__.py
Normal file
83
backend/routers/ai.py
Normal file
83
backend/routers/ai.py
Normal file
@ -0,0 +1,83 @@
|
||||
"""AI feature endpoints: filler word detection, clip creation, Ollama model listing."""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.ai_provider import AIProvider, detect_filler_words, create_clip_suggestion
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class WordInfo(BaseModel):
    """A single transcript word with its position and optional timing."""

    index: int  # position of the word within the transcript word list
    word: str  # the word text itself
    start: Optional[float] = None  # start time in seconds, when aligned
    end: Optional[float] = None  # end time in seconds, when aligned
|
||||
|
||||
|
||||
class FillerRequest(BaseModel):
    """Request body for POST /ai/filler-removal."""

    transcript: str  # full transcript text
    words: List[WordInfo]  # indexed word list given to the LLM
    provider: str = "ollama"  # AI backend selector (e.g. "ollama", "openai", "claude")
    model: Optional[str] = None  # provider model name; provider default when None
    api_key: Optional[str] = None  # credential for hosted providers
    base_url: Optional[str] = None  # Ollama server URL override
    custom_filler_words: Optional[str] = None  # extra user-specified fillers to flag
|
||||
|
||||
|
||||
class ClipRequest(BaseModel):
    """Request body for POST /ai/create-clip."""

    transcript: str  # full transcript text
    words: List[WordInfo]  # indexed, timestamped word list given to the LLM
    provider: str = "ollama"  # AI backend selector
    model: Optional[str] = None  # provider model name; provider default when None
    api_key: Optional[str] = None  # credential for hosted providers
    base_url: Optional[str] = None  # Ollama server URL override
    target_duration: int = 60  # desired clip length in seconds
|
||||
|
||||
|
||||
@router.post("/ai/filler-removal")
async def filler_removal(req: FillerRequest):
    """Run LLM-based filler-word detection over the supplied transcript."""
    try:
        detection = detect_filler_words(
            transcript=req.transcript,
            words=[word.model_dump() for word in req.words],
            provider=req.provider,
            model=req.model,
            api_key=req.api_key,
            base_url=req.base_url,
            custom_filler_words=req.custom_filler_words,
        )
    except Exception as exc:
        logger.error(f"Filler detection failed: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
    return detection
|
||||
|
||||
|
||||
@router.post("/ai/create-clip")
async def create_clip(req: ClipRequest):
    """Ask the configured LLM for short-form clip suggestions."""
    try:
        suggestion = create_clip_suggestion(
            transcript=req.transcript,
            words=[word.model_dump() for word in req.words],
            target_duration=req.target_duration,
            provider=req.provider,
            model=req.model,
            api_key=req.api_key,
            base_url=req.base_url,
        )
    except Exception as exc:
        logger.error(f"Clip creation failed: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
    return suggestion
|
||||
|
||||
|
||||
@router.get("/ai/ollama-models")
async def ollama_models(base_url: str = "http://localhost:11434"):
    """List model names from a local Ollama server ([] when unreachable)."""
    models = AIProvider.list_ollama_models(base_url)
    return {"models": models}
|
||||
38
backend/routers/audio.py
Normal file
38
backend/routers/audio.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""Audio processing endpoint (noise reduction / Studio Sound)."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.audio_cleaner import clean_audio, is_deepfilter_available
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class AudioCleanRequest(BaseModel):
    """Request body for POST /audio/clean."""

    input_path: str  # local path of the audio file to denoise
    output_path: Optional[str] = None  # destination; service picks one when None
|
||||
|
||||
|
||||
@router.post("/audio/clean")
async def clean_audio_endpoint(req: AudioCleanRequest):
    """Apply noise reduction to the file referenced by the request."""
    try:
        cleaned_path = clean_audio(req.input_path, req.output_path or "")
        # Report which backend actually did the work.
        engine_name = "deepfilternet" if is_deepfilter_available() else "ffmpeg_anlmdn"
        return {
            "status": "ok",
            "output_path": cleaned_path,
            "engine": engine_name,
        }
    except Exception as exc:
        logger.error(f"Audio cleaning failed: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/audio/capabilities")
async def audio_capabilities():
    """Report whether the optional DeepFilterNet backend is installed."""
    return {
        "deepfilternet_available": is_deepfilter_available(),
    }
|
||||
65
backend/routers/captions.py
Normal file
65
backend/routers/captions.py
Normal file
@ -0,0 +1,65 @@
|
||||
"""Caption generation endpoint."""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import PlainTextResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.caption_generator import generate_srt, generate_vtt, generate_ass, save_captions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class CaptionWord(BaseModel):
    """One timed word used as caption input."""

    word: str  # the word text
    start: float  # start time in seconds
    end: float  # end time in seconds
    confidence: float = 0.0  # transcription confidence (unused by generation)
|
||||
|
||||
|
||||
class CaptionStyle(BaseModel):
    """Visual styling for ASS captions (colors in ASS-style &H hex notation)."""

    fontName: str = "Arial"
    fontSize: int = 48
    fontColor: str = "&H00FFFFFF"  # presumably opaque white — &HAABBGGRR order; verify in generate_ass
    backgroundColor: str = "&H80000000"  # presumably semi-transparent black
    position: str = "bottom"  # vertical placement of the caption block
    bold: bool = True
|
||||
|
||||
|
||||
class CaptionRequest(BaseModel):
    """Request body for POST /captions."""

    words: List[CaptionWord]  # full word list with timings
    deleted_indices: List[int] = []  # indices of words cut from the edit
    format: str = "srt"  # "srt" | "vtt" | "ass"
    words_per_line: int = 8  # words grouped into each caption cue
    style: Optional[CaptionStyle] = None  # only used for the "ass" format
    output_path: Optional[str] = None  # when set, write to disk instead of returning text
|
||||
|
||||
|
||||
@router.post("/captions")
async def generate_captions(req: CaptionRequest):
    """Generate captions (SRT/VTT/ASS) from word-level timestamps.

    Returns the caption text as plain text, or — when ``req.output_path`` is
    set — writes it to disk and returns the saved path.

    Raises:
        HTTPException 400: unknown caption format.
        HTTPException 500: generation or save failure.
    """
    try:
        words_dicts = [w.model_dump() for w in req.words]
        deleted_set = set(req.deleted_indices)

        if req.format == "srt":
            content = generate_srt(words_dicts, deleted_set, req.words_per_line)
        elif req.format == "vtt":
            content = generate_vtt(words_dicts, deleted_set, req.words_per_line)
        elif req.format == "ass":
            style_dict = req.style.model_dump() if req.style else None
            content = generate_ass(words_dicts, deleted_set, req.words_per_line, style_dict)
        else:
            raise HTTPException(status_code=400, detail=f"Unknown format: {req.format}")

        if req.output_path:
            saved = save_captions(content, req.output_path)
            return {"status": "ok", "output_path": saved}

        return PlainTextResponse(content, media_type="text/plain")

    except HTTPException:
        # BUG FIX: re-raise unchanged. Previously the 400 above fell into the
        # generic handler below and was surfaced to clients as a 500.
        raise
    except Exception as e:
        logger.error(f"Caption generation failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
156
backend/routers/export.py
Normal file
156
backend/routers/export.py
Normal file
@ -0,0 +1,156 @@
|
||||
"""Export endpoint for video cutting and rendering."""
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.video_editor import export_stream_copy, export_reencode, export_reencode_with_subs
|
||||
from services.audio_cleaner import clean_audio
|
||||
from services.caption_generator import generate_srt, generate_ass, save_captions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class SegmentModel(BaseModel):
    """A time range (seconds) of the source video to keep."""

    start: float  # segment start in seconds
    end: float  # segment end in seconds
|
||||
|
||||
|
||||
class ExportWordModel(BaseModel):
    """Timed word passed along for caption generation during export."""

    word: str  # the word text
    start: float  # start time in seconds
    end: float  # end time in seconds
    confidence: float = 0.0  # transcription confidence (not used for cutting)
|
||||
|
||||
|
||||
class ExportRequest(BaseModel):
    """Request body for POST /export; `keep_segments` are the ranges retained."""

    input_path: str  # source video on local disk
    output_path: str  # destination file to write
    keep_segments: List[SegmentModel]  # time ranges (seconds) to keep
    mode: str = "fast"  # "fast" -> stream copy when a single segment allows it
    resolution: str = "1080p"  # target resolution for re-encode exports
    format: str = "mp4"  # container format hint passed to the encoder
    enhanceAudio: bool = False  # run noise reduction and remux the audio track
    captions: str = "none"  # "none" | "burn-in" | "sidecar"
    words: Optional[List[ExportWordModel]] = None  # needed when captions != "none"
    deleted_indices: Optional[List[int]] = None  # word indices cut from the transcript
|
||||
|
||||
|
||||
def _mux_audio(video_path: str, audio_path: str, output_path: str) -> str:
    """Replace video's audio track with cleaned audio using FFmpeg."""
    import subprocess

    # Copy the video stream untouched; take the audio from the second input.
    ffmpeg_cmd = ["ffmpeg", "-y", "-i", video_path, "-i", audio_path]
    ffmpeg_cmd += ["-c:v", "copy", "-map", "0:v:0", "-map", "1:a:0"]
    ffmpeg_cmd += ["-shortest", output_path]

    proc = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
    if proc.returncode != 0:
        # Surface only the tail of stderr to keep the error message readable.
        raise RuntimeError(f"Audio mux failed: {proc.stderr[-300:]}")
    return output_path
|
||||
|
||||
|
||||
@router.post("/export")
async def export_video(req: ExportRequest):
    """Cut and render a video according to the kept segments.

    Chooses fast stream-copy (single segment, no burn-in) or re-encode,
    optionally burns in ASS captions, enhances audio (best-effort), and can
    write a sidecar SRT next to the output.

    Raises:
        HTTPException 400: no segments, or invalid parameters (ValueError).
        HTTPException 500: FFmpeg failure or unexpected error.
    """
    try:
        segments = [{"start": s.start, "end": s.end} for s in req.keep_segments]

        if not segments:
            raise HTTPException(status_code=400, detail="No segments to export")

        use_stream_copy = req.mode == "fast" and len(segments) == 1
        needs_reencode_for_subs = req.captions == "burn-in"

        # Burn-in captions require re-encode
        if needs_reencode_for_subs:
            use_stream_copy = False

        words_dicts = [w.model_dump() for w in req.words] if req.words else []
        deleted_set = set(req.deleted_indices or [])

        # Generate a temporary ASS file for burn-in captions.
        ass_path = None
        if req.captions == "burn-in" and words_dicts:
            ass_content = generate_ass(words_dicts, deleted_set)
            tmp = tempfile.NamedTemporaryFile(suffix=".ass", delete=False, mode="w", encoding="utf-8")
            tmp.write(ass_content)
            tmp.close()
            ass_path = tmp.name

        try:
            if use_stream_copy:
                output = export_stream_copy(req.input_path, req.output_path, segments)
            elif ass_path:
                output = export_reencode_with_subs(
                    req.input_path,
                    req.output_path,
                    segments,
                    ass_path,
                    resolution=req.resolution,
                    format_hint=req.format,
                )
            else:
                output = export_reencode(
                    req.input_path,
                    req.output_path,
                    segments,
                    resolution=req.resolution,
                    format_hint=req.format,
                )
        finally:
            # The ASS file is only needed during the FFmpeg run.
            if ass_path and os.path.exists(ass_path):
                os.unlink(ass_path)

        # Audio enhancement: clean, then mux back into the exported video.
        if req.enhanceAudio:
            try:
                tmp_dir = tempfile.mkdtemp(prefix="cutscript_audio_")
                cleaned_audio = os.path.join(tmp_dir, "cleaned.wav")
                clean_audio(output, cleaned_audio)

                muxed_path = output + ".muxed.mp4"
                _mux_audio(output, cleaned_audio, muxed_path)

                os.replace(muxed_path, output)
                logger.info(f"Audio enhanced and muxed into {output}")

                # Cleanup of the scratch directory is best-effort.
                try:
                    os.remove(cleaned_audio)
                    os.rmdir(tmp_dir)
                except OSError:
                    pass
            except Exception as e:
                # Deliberately non-fatal: a failed enhancement keeps the export.
                logger.warning(f"Audio enhancement failed (non-fatal): {e}")

        # Sidecar SRT: generate and save alongside the video.
        srt_path = None
        if req.captions == "sidecar" and words_dicts:
            srt_content = generate_srt(words_dicts, deleted_set)
            srt_path = req.output_path.rsplit(".", 1)[0] + ".srt"
            save_captions(srt_content, srt_path)
            logger.info(f"Sidecar SRT saved to {srt_path}")

        result = {"status": "ok", "output_path": output}
        if srt_path:
            result["srt_path"] = srt_path
        return result

    except HTTPException:
        # BUG FIX: re-raise unchanged. Previously the 400 "No segments to
        # export" was caught by the generic Exception handler below and
        # reported to clients as a 500.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except RuntimeError as e:
        logger.error(f"Export failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    except Exception as e:
        logger.error(f"Export error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
53
backend/routers/transcribe.py
Normal file
53
backend/routers/transcribe.py
Normal file
@ -0,0 +1,53 @@
|
||||
"""Transcription endpoint using WhisperX."""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.transcription import transcribe_audio
|
||||
from services.diarization import diarize_and_label
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class TranscribeRequest(BaseModel):
    """Request body for POST /transcribe."""

    file_path: str  # local media file to transcribe
    model: str = "base"  # Whisper model size name
    language: Optional[str] = None  # force a language; auto-detect when None
    use_gpu: bool = True  # allow CUDA when available
    use_cache: bool = True  # reuse a cached transcription if present
    diarize: bool = False  # add speaker labels (needs hf_token)
    hf_token: Optional[str] = None  # HuggingFace token for diarization models
    num_speakers: Optional[int] = None  # fixed speaker count hint for diarization
|
||||
|
||||
|
||||
@router.post("/transcribe")
async def transcribe(req: TranscribeRequest):
    """Transcribe a media file with WhisperX, optionally adding speaker labels."""
    try:
        transcription = transcribe_audio(
            file_path=req.file_path,
            model_name=req.model,
            use_gpu=req.use_gpu,
            use_cache=req.use_cache,
            language=req.language,
        )

        # Diarization is opt-in and requires a HuggingFace token.
        if req.diarize and req.hf_token:
            transcription = diarize_and_label(
                transcription_result=transcription,
                audio_path=req.file_path,
                hf_token=req.hf_token,
                num_speakers=req.num_speakers,
                use_gpu=req.use_gpu,
            )

        return transcription

    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"File not found: {req.file_path}")
    except Exception as exc:
        logger.error(f"Transcription failed: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
0
backend/services/__init__.py
Normal file
0
backend/services/__init__.py
Normal file
211
backend/services/ai_provider.py
Normal file
211
backend/services/ai_provider.py
Normal file
@ -0,0 +1,211 @@
|
||||
"""
|
||||
Unified AI provider interface for Ollama, OpenAI, and Claude.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, List
|
||||
|
||||
import requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AIProvider:
    """Routes completion requests to the configured provider."""

    @staticmethod
    def complete(
        prompt: str,
        provider: str = "ollama",
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        system_prompt: Optional[str] = None,
        temperature: float = 0.3,
    ) -> str:
        """Send `prompt` to the selected backend and return its text reply."""
        if provider == "ollama":
            endpoint = base_url or "http://localhost:11434"
            return _ollama_complete(prompt, model or "llama3", endpoint, system_prompt, temperature)
        if provider == "openai":
            return _openai_complete(prompt, model or "gpt-4o", api_key or "", system_prompt, temperature)
        if provider == "claude":
            return _claude_complete(prompt, model or "claude-sonnet-4-20250514", api_key or "", system_prompt, temperature)
        raise ValueError(f"Unknown provider: {provider}")

    @staticmethod
    def list_ollama_models(base_url: str = "http://localhost:11434") -> List[str]:
        """Return model names from a local Ollama server, or [] on any failure."""
        try:
            tags_response = requests.get(f"{base_url}/api/tags", timeout=3)
            if tags_response.status_code == 200:
                return [entry["name"] for entry in tags_response.json().get("models", [])]
        except Exception:
            # Unreachable server / bad payload: treated as "no models".
            pass
        return []
|
||||
|
||||
|
||||
def _ollama_complete(prompt: str, model: str, base_url: str, system_prompt: Optional[str], temperature: float) -> str:
    """POST to Ollama's /api/generate and return the stripped response text."""
    request_body = {
        "model": model,
        "prompt": prompt,
        "stream": False,  # single JSON response, not a token stream
        "options": {"temperature": temperature},
    }
    if system_prompt:
        request_body["system"] = system_prompt

    try:
        response = requests.post(f"{base_url}/api/generate", json=request_body, timeout=120)
        response.raise_for_status()
        return response.json().get("response", "").strip()
    except Exception as exc:
        logger.error(f"Ollama error: {exc}")
        raise
|
||||
|
||||
|
||||
def _openai_complete(prompt: str, model: str, api_key: str, system_prompt: Optional[str], temperature: float) -> str:
    """Call the OpenAI chat completions API and return the stripped reply text.

    Raises whatever the OpenAI SDK raises after logging it (the caller maps
    exceptions to HTTP errors).
    """
    try:
        # Imported lazily so the backend starts even without the SDK installed.
        from openai import OpenAI
        client = OpenAI(api_key=api_key)
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        response = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
        )
        return response.choices[0].message.content.strip()
    except Exception as e:
        logger.error(f"OpenAI error: {e}")
        raise
|
||||
|
||||
|
||||
def _claude_complete(prompt: str, model: str, api_key: str, system_prompt: Optional[str], temperature: float) -> str:
    """Call the Anthropic Messages API and return the stripped reply text.

    Raises whatever the Anthropic SDK raises after logging it (the caller maps
    exceptions to HTTP errors).
    """
    try:
        # Imported lazily so the backend starts even without the SDK installed.
        import anthropic
        client = anthropic.Anthropic(api_key=api_key)
        kwargs = {
            "model": model,
            "max_tokens": 4096,
            "temperature": temperature,
            "messages": [{"role": "user", "content": prompt}],
        }
        if system_prompt:
            # Anthropic takes the system prompt as a top-level parameter.
            kwargs["system"] = system_prompt

        response = client.messages.create(**kwargs)
        return response.content[0].text.strip()
    except Exception as e:
        logger.error(f"Claude error: {e}")
        raise
|
||||
|
||||
|
||||
def detect_filler_words(
    transcript: str,
    words: List[dict],
    provider: str = "ollama",
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    custom_filler_words: Optional[str] = None,
) -> dict:
    """
    Use an LLM to identify filler words in the transcript.

    Each entry of `words` must carry "index" and "word" keys.
    Returns {"wordIndices": [...], "fillerWords": [{"index": N, "word": "...", "reason": "..."}]}
    On an unparseable model response, returns the empty structure (logged).
    """
    # One "<index>: <word>" line per word so the model can reference indices.
    word_list = "\n".join(f"{w['index']}: {w['word']}" for w in words)

    custom_line = ""
    if custom_filler_words and custom_filler_words.strip():
        custom_line = f"\n\nAdditionally, flag these user-specified filler words/phrases: {custom_filler_words.strip()}"

    prompt = f"""Analyze this transcript for filler words and verbal hesitations.

Filler words include: um, uh, uh huh, hmm, like (when used as filler), you know, so (when starting sentences unnecessarily), basically, actually, literally, right, I mean, kind of, sort of, well (when used as filler).

Also flag repeated words that indicate stammering (e.g., "I I I" or "the the").{custom_line}

Here are the words with their indices:
{word_list}

Return ONLY a valid JSON object with this exact structure:
{{"wordIndices": [list of integer indices to remove], "fillerWords": [{{"index": integer, "word": "the word", "reason": "brief reason"}}]}}

Be conservative -- only flag clear filler words, not words that are part of meaningful sentences."""

    system = "You are a precise text analysis tool. Return only valid JSON, no explanation."

    # Low temperature: we want deterministic classification, not creativity.
    result_text = AIProvider.complete(
        prompt=prompt,
        provider=provider,
        model=model,
        api_key=api_key,
        base_url=base_url,
        system_prompt=system,
        temperature=0.1,
    )

    # Extract the first {...} span in case the model wrapped the JSON in prose.
    try:
        start = result_text.find("{")
        end = result_text.rfind("}") + 1
        if start >= 0 and end > start:
            return json.loads(result_text[start:end])
    except json.JSONDecodeError:
        logger.error(f"Failed to parse AI response as JSON: {result_text[:200]}")

    # Fallback: empty result rather than an exception for the caller.
    return {"wordIndices": [], "fillerWords": []}
|
||||
|
||||
|
||||
def create_clip_suggestion(
    transcript: str,
    words: List[dict],
    target_duration: int = 60,
    provider: str = "ollama",
    model: Optional[str] = None,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
) -> dict:
    """
    Use an LLM to find the best clip segments in a transcript.

    Each entry of `words` must carry "index" and "word" keys; "start"/"end"
    default to 0 when missing. Returns {"clips": [...]}; empty list when the
    model response cannot be parsed as JSON (logged).
    """
    # One line per word: "<index>: "<word>" (<start>s - <end>s)".
    word_list = "\n".join(
        f"{w['index']}: \"{w['word']}\" ({w.get('start', 0):.1f}s - {w.get('end', 0):.1f}s)"
        for w in words
    )

    prompt = f"""Analyze this transcript and find the most engaging {target_duration}-second segment(s) that would work well as a YouTube Short or social media clip.

Look for: compelling stories, surprising facts, emotional moments, clear explanations, humor, or quotable statements.

Words with indices and timestamps:
{word_list}

Return ONLY a valid JSON object:
{{"clips": [{{"title": "short catchy title", "startWordIndex": integer, "endWordIndex": integer, "startTime": float, "endTime": float, "reason": "why this segment is engaging"}}]}}

Suggest 1-3 clips, each approximately {target_duration} seconds long."""

    system = "You are a viral content expert. Return only valid JSON, no explanation."

    # Higher temperature than filler detection: some creativity is desirable.
    result_text = AIProvider.complete(
        prompt=prompt,
        provider=provider,
        model=model,
        api_key=api_key,
        base_url=base_url,
        system_prompt=system,
        temperature=0.5,
    )

    # Extract the first {...} span in case the model wrapped the JSON in prose.
    try:
        start = result_text.find("{")
        end = result_text.rfind("}") + 1
        if start >= 0 and end > start:
            return json.loads(result_text[start:end])
    except json.JSONDecodeError:
        logger.error(f"Failed to parse clip suggestions: {result_text[:200]}")

    # Fallback: empty result rather than an exception for the caller.
    return {"clips": []}
|
||||
79
backend/services/audio_cleaner.py
Normal file
79
backend/services/audio_cleaner.py
Normal file
@ -0,0 +1,79 @@
|
||||
"""
|
||||
Audio noise reduction using DeepFilterNet.
|
||||
Falls back to a basic FFmpeg noise filter if DeepFilterNet is not installed.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from df.enhance import enhance, init_df, load_audio, save_audio
|
||||
DEEPFILTER_AVAILABLE = True
|
||||
except ImportError:
|
||||
DEEPFILTER_AVAILABLE = False
|
||||
|
||||
|
||||
_df_model = None
|
||||
_df_state = None
|
||||
|
||||
|
||||
def _init_deepfilter():
    """Lazily initialize and cache the DeepFilterNet model/state singletons.

    The model is loaded once per process on first use and reused afterwards.
    """
    global _df_model, _df_state
    if _df_model is None:
        logger.info("Initializing DeepFilterNet model")
        _df_model, _df_state, _ = init_df()
    return _df_model, _df_state
|
||||
|
||||
|
||||
def clean_audio(
    input_path: str,
    output_path: str = "",
) -> str:
    """
    Apply noise reduction to an audio file.

    Uses DeepFilterNet for high-quality results when installed; otherwise
    falls back to FFmpeg's anlmdn filter.

    Returns: path to the cleaned audio file.
    """
    source = Path(input_path)
    if not output_path:
        # Default destination: "<stem>_clean<suffix>" next to the input file.
        output_path = str(source.with_stem(source.stem + "_clean"))

    cleaner = _clean_with_deepfilter if DEEPFILTER_AVAILABLE else _clean_with_ffmpeg
    return cleaner(str(source), output_path)
|
||||
|
||||
|
||||
def _clean_with_deepfilter(input_path: str, output_path: str) -> str:
    """Denoise `input_path` with DeepFilterNet and write it to `output_path`.

    Resamples on load to the model's expected sample rate (state.sr()).
    """
    model, state = _init_deepfilter()
    audio, info = load_audio(input_path, sr=state.sr())
    enhanced = enhance(model, state, audio)
    save_audio(output_path, enhanced, sr=state.sr())
    logger.info(f"DeepFilterNet cleaned audio saved to {output_path}")
    return output_path
|
||||
|
||||
|
||||
def _clean_with_ffmpeg(input_path: str, output_path: str) -> str:
    """Fallback: basic noise reduction using FFmpeg's anlmdn filter."""
    ffmpeg_cmd = [
        "ffmpeg", "-y",
        "-i", input_path,
        "-af", "anlmdn=s=7:p=0.002:r=0.002:m=15",
        output_path,
    ]
    proc = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
    if proc.returncode != 0:
        # Surface only the tail of stderr to keep the error message readable.
        raise RuntimeError(f"FFmpeg audio cleaning failed: {proc.stderr[-300:]}")
    logger.info(f"FFmpeg cleaned audio saved to {output_path}")
    return output_path
|
||||
|
||||
|
||||
def is_deepfilter_available() -> bool:
    """Report whether the optional DeepFilterNet import succeeded at load time."""
    return bool(DEEPFILTER_AVAILABLE)
|
||||
59
backend/services/background_removal.py
Normal file
59
backend/services/background_removal.py
Normal file
@ -0,0 +1,59 @@
|
||||
"""
|
||||
AI background removal (Phase 5 - future).
|
||||
Uses MediaPipe or Robust Video Matting for person segmentation.
|
||||
Export-only -- no real-time preview.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Placeholder for Phase 5 implementation
|
||||
# Will use mediapipe or rvm for segmentation at export time
|
||||
|
||||
MEDIAPIPE_AVAILABLE = False
|
||||
RVM_AVAILABLE = False
|
||||
|
||||
try:
|
||||
import mediapipe as mp
|
||||
MEDIAPIPE_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
pass # rvm import would go here
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def is_available() -> bool:
    """Whether any supported segmentation backend imported successfully."""
    return bool(MEDIAPIPE_AVAILABLE or RVM_AVAILABLE)
|
||||
|
||||
|
||||
def remove_background_on_export(
    input_path: str,
    output_path: str,
    replacement: str = "blur",
    replacement_value: str = "",
) -> str:
    """
    Process video frame-by-frame to remove/replace background.
    Only runs during export (not real-time).

    NOTE: placeholder — currently always raises (RuntimeError when no backend
    is installed, NotImplementedError otherwise) until Phase 5 lands.

    Args:
        input_path: source video
        output_path: destination
        replacement: 'blur', 'color', 'image', or 'video'
        replacement_value: hex color, image path, or video path

    Returns:
        output_path
    """
    if not is_available():
        raise RuntimeError(
            "Background removal requires mediapipe or robust-video-matting. "
            "Install with: pip install mediapipe"
        )

    # Phase 5 implementation will go here
    raise NotImplementedError("Background removal is planned for Phase 5")
|
||||
148
backend/services/caption_generator.py
Normal file
148
backend/services/caption_generator.py
Normal file
@ -0,0 +1,148 @@
|
||||
"""
|
||||
Generate caption files (SRT, VTT, ASS) from word-level timestamps.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _format_srt_time(seconds: float) -> str:
|
||||
h = int(seconds // 3600)
|
||||
m = int((seconds % 3600) // 60)
|
||||
s = int(seconds % 60)
|
||||
ms = int((seconds % 1) * 1000)
|
||||
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
|
||||
|
||||
|
||||
def _format_vtt_time(seconds: float) -> str:
|
||||
h = int(seconds // 3600)
|
||||
m = int((seconds % 3600) // 60)
|
||||
s = int(seconds % 60)
|
||||
ms = int((seconds % 1) * 1000)
|
||||
return f"{h:02d}:{m:02d}:{s:02d}.{ms:03d}"
|
||||
|
||||
|
||||
def _format_ass_time(seconds: float) -> str:
|
||||
h = int(seconds // 3600)
|
||||
m = int((seconds % 3600) // 60)
|
||||
s = int(seconds % 60)
|
||||
cs = int((seconds % 1) * 100)
|
||||
return f"{h}:{m:02d}:{s:02d}.{cs:02d}"
|
||||
|
||||
|
||||
def generate_srt(
    words: List[dict],
    deleted_indices: Optional[set] = None,
    words_per_line: int = 8,
) -> str:
    """Build SRT caption text from word-level timestamps.

    Words whose index appears in *deleted_indices* are skipped; the
    remaining words are grouped into numbered cues of at most
    *words_per_line* words, timed from the first word's start to the
    last word's end.
    """
    excluded = deleted_indices or set()
    kept = [(idx, entry) for idx, entry in enumerate(words) if idx not in excluded]

    out: List[str] = []
    cue_number = 1
    for offset in range(0, len(kept), words_per_line):
        group = kept[offset:offset + words_per_line]
        if not group:
            continue

        cue_start = group[0][1]["start"]
        cue_end = group[-1][1]["end"]
        caption = " ".join(entry["word"] for _, entry in group)

        out.extend([
            str(cue_number),
            f"{_format_srt_time(cue_start)} --> {_format_srt_time(cue_end)}",
            caption,
            "",
        ])
        cue_number += 1

    return "\n".join(out)
|
||||
|
||||
|
||||
def generate_vtt(
    words: List[dict],
    deleted_indices: Optional[set] = None,
    words_per_line: int = 8,
) -> str:
    """Build WebVTT caption text from word-level timestamps.

    Same grouping scheme as generate_srt, but with the WEBVTT preamble
    and no cue counters.
    """
    excluded = deleted_indices or set()
    kept = [(idx, entry) for idx, entry in enumerate(words) if idx not in excluded]

    out: List[str] = ["WEBVTT", ""]
    for offset in range(0, len(kept), words_per_line):
        group = kept[offset:offset + words_per_line]
        if not group:
            continue

        cue_start = group[0][1]["start"]
        cue_end = group[-1][1]["end"]
        caption = " ".join(entry["word"] for _, entry in group)

        out.extend([
            f"{_format_vtt_time(cue_start)} --> {_format_vtt_time(cue_end)}",
            caption,
            "",
        ])

    return "\n".join(out)
|
||||
|
||||
|
||||
def generate_ass(
    words: List[dict],
    deleted_indices: Optional[set] = None,
    words_per_line: int = 8,
    style: Optional[dict] = None,
) -> str:
    """Build ASS subtitle text (header + Dialogue events) with styling.

    *style* may override fontName, fontSize, fontColor (ASS &HAABBGGRR
    format) and bold; everything else in the Style line is fixed.
    """
    excluded = deleted_indices or set()
    kept = [(idx, entry) for idx, entry in enumerate(words) if idx not in excluded]

    opts = style or {}
    font_name = opts.get("fontName", "Arial")
    font_size = opts.get("fontSize", 48)
    primary_color = opts.get("fontColor", "&H00FFFFFF")
    bold_flag = "-1" if opts.get("bold", True) else "0"
    align = 2  # bottom-center in ASS numpad alignment

    header = f"""[Script Info]
Title: AI Video Editor Captions
ScriptType: v4.00+
PlayResX: 1920
PlayResY: 1080

[V4+ Styles]
Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding
Style: Default,{font_name},{font_size},{primary_color},&H000000FF,&H00000000,&H80000000,{bold_flag},0,0,0,100,100,0,0,1,2,1,{align},20,20,40,1

[Events]
Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
"""

    dialogue_lines: List[str] = []
    for offset in range(0, len(kept), words_per_line):
        group = kept[offset:offset + words_per_line]
        if not group:
            continue

        cue_start = group[0][1]["start"]
        cue_end = group[-1][1]["end"]
        caption = " ".join(entry["word"] for _, entry in group)

        dialogue_lines.append(
            f"Dialogue: 0,{_format_ass_time(cue_start)},{_format_ass_time(cue_end)},Default,,0,0,0,,{caption}"
        )

    return header + "\n".join(dialogue_lines) + "\n"
|
||||
|
||||
|
||||
def save_captions(
    content: str,
    output_path: str,
) -> str:
    """Persist caption text to *output_path*, creating parent dirs as needed.

    Returns the written path as a string.
    """
    target = Path(output_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(content, encoding="utf-8")
    logger.info(f"Saved captions to {target}")
    return str(target)
|
||||
98
backend/services/diarization.py
Normal file
98
backend/services/diarization.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""
|
||||
Speaker diarization service using pyannote.audio.
|
||||
Refactored from the original repo -- removed Streamlit dependency.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from utils.gpu_utils import get_optimal_device
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_pipeline_cache = {}
|
||||
|
||||
|
||||
def _get_pipeline(hf_token: str, device: torch.device):
    """Load (or fetch from cache) the pyannote diarization pipeline.

    Pipelines are cached per device string. Returns None when pyannote is
    not installed or the model cannot be downloaded/loaded.
    """
    key = str(device)
    cached = _pipeline_cache.get(key)
    if cached is not None:
        return cached

    try:
        from pyannote.audio import Pipeline

        pipe = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.0",
            use_auth_token=hf_token,
        )
        # pyannote only needs an explicit move for CUDA devices.
        if device.type == "cuda":
            pipe = pipe.to(device)

        _pipeline_cache[key] = pipe
        return pipe
    except Exception as e:
        logger.error(f"Failed to load diarization pipeline: {e}")
        return None
|
||||
|
||||
|
||||
def diarize_and_label(
    transcription_result: dict,
    audio_path: str,
    hf_token: Optional[str] = None,
    num_speakers: Optional[int] = None,
    use_gpu: bool = True,
) -> dict:
    """
    Apply speaker diarization to an existing transcription result.
    Adds 'speaker' field to each word and segment.

    Best-effort: any failure (missing token, pipeline load error, inference
    error) logs and returns the input unchanged rather than raising.

    Args:
        transcription_result: dict with 'words' and 'segments' lists; each
            entry must carry 'start' and 'end' (seconds). Mutated in place.
        audio_path: path to the audio file to diarize.
        hf_token: HuggingFace token; falls back to the HF_TOKEN env var.
        num_speakers: optional fixed speaker count passed to pyannote.
        use_gpu: when True, use the best available accelerator.

    Returns the mutated transcription_result with speaker labels.
    """
    hf_token = hf_token or os.environ.get("HF_TOKEN")
    if not hf_token:
        logger.warning("No HuggingFace token provided; skipping diarization")
        return transcription_result

    device = get_optimal_device() if use_gpu else torch.device("cpu")
    pipeline = _get_pipeline(hf_token, device)
    if pipeline is None:
        # _get_pipeline already logged the failure reason.
        return transcription_result

    audio_path = Path(audio_path)
    logger.info(f"Running diarization on {audio_path}")

    try:
        diarization = pipeline(str(audio_path), num_speakers=num_speakers)
    except Exception as e:
        logger.error(f"Diarization failed: {e}")
        return transcription_result

    # Flatten pyannote's annotation into (start, end, label) tuples once,
    # so each word/segment lookup below is a simple linear scan.
    speaker_map = []
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        speaker_map.append((turn.start, turn.end, speaker))

    def _find_speaker(start: float, end: float) -> str:
        # Assign the speaker whose turn overlaps [start, end] the most;
        # "UNKNOWN" when no turn overlaps at all.
        best_overlap = 0
        best_speaker = "UNKNOWN"
        for s_start, s_end, speaker in speaker_map:
            overlap_start = max(start, s_start)
            overlap_end = min(end, s_end)
            overlap = max(0, overlap_end - overlap_start)
            if overlap > best_overlap:
                best_overlap = overlap
                best_speaker = speaker
        return best_speaker

    # Label the flat word list plus every segment and its nested words.
    # NOTE(review): flat words and segment words appear to be distinct dicts
    # upstream, so both passes are needed -- confirm against transcription.py.
    for word in transcription_result.get("words", []):
        word["speaker"] = _find_speaker(word["start"], word["end"])

    for segment in transcription_result.get("segments", []):
        segment["speaker"] = _find_speaker(segment["start"], segment["end"])
        for w in segment.get("words", []):
            w["speaker"] = _find_speaker(w["start"], w["end"])

    return transcription_result
|
||||
205
backend/services/transcription.py
Normal file
205
backend/services/transcription.py
Normal file
@ -0,0 +1,205 @@
|
||||
"""
|
||||
WhisperX-based transcription service with word-level alignment.
|
||||
Falls back to standard Whisper if WhisperX is not available.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from utils.gpu_utils import get_optimal_device, configure_gpu
|
||||
from utils.audio_processing import extract_audio
|
||||
from utils.cache import load_from_cache, save_to_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Cache of loaded ASR models keyed by "<model_name>_<device>" so repeated
# transcriptions reuse the same weights.
_model_cache: dict = {}

# Prefer WhisperX (word-level alignment); fall back to plain Whisper.
try:
    import whisperx
    WHISPERX_AVAILABLE = True
except ImportError:
    WHISPERX_AVAILABLE = False
    import whisper

import os

# Default HuggingFace token for gated models. os.environ.get never raises,
# so the original try/except guard around this lookup was dead code.
HF_TOKEN = os.environ.get("HF_TOKEN")
|
||||
|
||||
|
||||
def _get_device(use_gpu: bool = True) -> torch.device:
|
||||
if use_gpu:
|
||||
return get_optimal_device()
|
||||
return torch.device("cpu")
|
||||
|
||||
|
||||
def _load_model(model_name: str, device: torch.device):
    """Load a Whisper/WhisperX model onto *device*, memoized in _model_cache."""
    key = f"{model_name}_{device}"
    cached = _model_cache.get(key)
    if cached is not None:
        return cached

    logger.info(f"Loading model: {model_name} on {device}")
    if WHISPERX_AVAILABLE:
        # float16 for CUDA throughput; int8 keeps CPU inference light.
        precision = "float16" if device.type == "cuda" else "int8"
        loaded = whisperx.load_model(
            model_name,
            device=str(device),
            compute_type=precision,
        )
    else:
        loaded = whisper.load_model(model_name, device=device)

    _model_cache[key] = loaded
    return loaded
|
||||
|
||||
|
||||
def transcribe_audio(
    file_path: str,
    model_name: str = "base",
    use_gpu: bool = True,
    use_cache: bool = True,
    language: Optional[str] = None,
) -> dict:
    """
    Transcribe an audio/video file and return word-level timestamps.

    Video containers get their audio extracted first; results are cached
    per file/model under the "transcribe_wx" operation key.

    Returns:
        dict with keys: words, segments, language
    """
    source = Path(file_path)

    # Serve a cached result when one exists for this file/model combo.
    if use_cache:
        cached = load_from_cache(source, model_name, "transcribe_wx")
        if cached:
            logger.info("Using cached transcription")
            return cached

    # Video containers need their audio track pulled out before ASR.
    if source.suffix.lower() in {".mp4", ".avi", ".mov", ".mkv", ".webm"}:
        audio_source = extract_audio(source)
    else:
        audio_source = source

    device = _get_device(use_gpu)
    model = _load_model(model_name, device)

    logger.info(f"Transcribing: {source}")

    if WHISPERX_AVAILABLE:
        result = _transcribe_whisperx(model, str(audio_source), device, language)
    else:
        result = _transcribe_standard(model, str(audio_source), language)

    if use_cache:
        save_to_cache(source, result, model_name, "transcribe_wx")

    return result
|
||||
|
||||
|
||||
def _transcribe_whisperx(model, audio_path: str, device: torch.device, language: Optional[str]) -> dict:
    """Transcribe with WhisperX and force-align for word-level timestamps.

    Args:
        model: a loaded whisperx model (from _load_model).
        audio_path: path to a decoded audio file.
        device: torch device used for the alignment model.
        language: optional language override; otherwise auto-detected.

    Returns:
        dict with keys: words (flat list), segments, language.
    """
    audio = whisperx.load_audio(audio_path)
    transcribe_opts = {}
    if language:
        transcribe_opts["language"] = language

    result = model.transcribe(audio, batch_size=16, **transcribe_opts)
    detected_language = result.get("language", "en")

    # Second pass: phoneme-level forced alignment for word timestamps.
    align_model, align_metadata = whisperx.load_align_model(
        language_code=detected_language,
        device=str(device),
    )
    aligned = whisperx.align(
        result["segments"],
        align_model,
        align_metadata,
        audio,
        str(device),
        return_char_alignments=False,
    )

    def _word_entry(w: dict) -> dict:
        # Normalize one aligned word to the app's word schema. The original
        # duplicated this dict literal in two separate loops over the same
        # aligned segments; build it once and reuse.
        return {
            "word": w.get("word", ""),
            "start": round(w.get("start", 0), 3),
            "end": round(w.get("end", 0), 3),
            "confidence": round(w.get("score", 0), 3),
        }

    words = []
    segments = []
    for i, seg in enumerate(aligned.get("segments", [])):
        seg_words = [_word_entry(w) for w in seg.get("words", [])]
        words.extend(seg_words)
        segments.append({
            "id": i,
            "start": round(seg.get("start", 0), 3),
            "end": round(seg.get("end", 0), 3),
            "text": seg.get("text", "").strip(),
            "words": seg_words,
        })

    return {
        "words": words,
        "segments": segments,
        "language": detected_language,
    }
|
||||
|
||||
|
||||
def _transcribe_standard(model, audio_path: str, language: Optional[str]) -> dict:
|
||||
"""Fallback: standard Whisper (segment-level only, synthesized word timestamps)."""
|
||||
opts = {}
|
||||
if language:
|
||||
opts["language"] = language
|
||||
|
||||
result = model.transcribe(audio_path, **opts)
|
||||
detected_language = result.get("language", "en")
|
||||
|
||||
words = []
|
||||
segments = []
|
||||
|
||||
for i, seg in enumerate(result.get("segments", [])):
|
||||
text = seg.get("text", "").strip()
|
||||
seg_start = seg.get("start", 0)
|
||||
seg_end = seg.get("end", 0)
|
||||
seg_words_text = text.split()
|
||||
duration = seg_end - seg_start
|
||||
|
||||
seg_words = []
|
||||
for j, w_text in enumerate(seg_words_text):
|
||||
w_start = seg_start + (j / max(len(seg_words_text), 1)) * duration
|
||||
w_end = seg_start + ((j + 1) / max(len(seg_words_text), 1)) * duration
|
||||
word_obj = {
|
||||
"word": w_text,
|
||||
"start": round(w_start, 3),
|
||||
"end": round(w_end, 3),
|
||||
"confidence": 0.5,
|
||||
}
|
||||
words.append(word_obj)
|
||||
seg_words.append(word_obj)
|
||||
|
||||
segments.append({
|
||||
"id": i,
|
||||
"start": round(seg_start, 3),
|
||||
"end": round(seg_end, 3),
|
||||
"text": text,
|
||||
"words": seg_words,
|
||||
})
|
||||
|
||||
return {
|
||||
"words": words,
|
||||
"segments": segments,
|
||||
"language": detected_language,
|
||||
}
|
||||
271
backend/services/video_editor.py
Normal file
271
backend/services/video_editor.py
Normal file
@ -0,0 +1,271 @@
|
||||
"""
|
||||
FFmpeg-based video cutting engine.
|
||||
Uses stream copy for fast, lossless cuts and falls back to re-encode when needed.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import subprocess
|
||||
import tempfile
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _find_ffmpeg() -> str:
|
||||
"""Locate ffmpeg binary."""
|
||||
for cmd in ["ffmpeg", "ffmpeg.exe"]:
|
||||
try:
|
||||
subprocess.run([cmd, "-version"], capture_output=True, check=True)
|
||||
return cmd
|
||||
except (FileNotFoundError, subprocess.CalledProcessError):
|
||||
continue
|
||||
raise RuntimeError("FFmpeg not found. Install it or add it to PATH.")
|
||||
|
||||
|
||||
def export_stream_copy(
    input_path: str,
    output_path: str,
    keep_segments: List[dict],
) -> str:
    """
    Export video using FFmpeg concat demuxer with stream copy.
    ~100x faster than re-encoding. No quality loss.

    Each kept segment is first copied into an intermediate MPEG-TS file
    (TS tolerates concatenation at non-keyframe boundaries better than
    MP4), then all pieces are joined with the concat: protocol. Any
    FFmpeg failure falls back to export_reencode(). Temp files are
    always removed, even on the fallback path.

    Args:
        input_path: source video file
        output_path: destination file
        keep_segments: list of {"start": float, "end": float} to keep

    Returns:
        output_path on success

    Raises:
        ValueError: when keep_segments is empty.
        RuntimeError: propagated from _find_ffmpeg or the re-encode fallback.
    """
    ffmpeg = _find_ffmpeg()
    input_path = str(Path(input_path).resolve())
    output_path = str(Path(output_path).resolve())

    if not keep_segments:
        raise ValueError("No segments to export")

    temp_dir = tempfile.mkdtemp(prefix="aive_export_")

    try:
        segment_files = []
        for i, seg in enumerate(keep_segments):
            seg_file = os.path.join(temp_dir, f"seg_{i:04d}.ts")
            cmd = [
                ffmpeg, "-y",
                # -ss/-to before -i: fast input-side seeking.
                "-ss", str(seg["start"]),
                "-to", str(seg["end"]),
                "-i", input_path,
                "-c", "copy",
                # Normalize timestamps so concatenation doesn't drift.
                "-avoid_negative_ts", "make_zero",
                "-f", "mpegts",
                seg_file,
            ]
            logger.info(f"Extracting segment {i}: {seg['start']:.2f}s - {seg['end']:.2f}s")
            result = subprocess.run(cmd, capture_output=True, text=True)
            if result.returncode != 0:
                # Stream copy can fail on odd codecs/containers; re-encode instead.
                logger.warning(f"Stream copy segment {i} failed, will try re-encode: {result.stderr[-200:]}")
                return export_reencode(input_path, output_path, keep_segments)
            segment_files.append(seg_file)

        concat_str = "|".join(segment_files)
        cmd = [
            ffmpeg, "-y",
            "-i", f"concat:{concat_str}",
            "-c", "copy",
            # Move the moov atom up front for instant playback start.
            "-movflags", "+faststart",
            output_path,
        ]
        logger.info(f"Concatenating {len(segment_files)} segments -> {output_path}")
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            logger.warning(f"Concat failed, falling back to re-encode: {result.stderr[-200:]}")
            return export_reencode(input_path, output_path, keep_segments)

        return output_path

    finally:
        # Best-effort cleanup of the intermediate .ts files and their dir.
        for f in os.listdir(temp_dir):
            try:
                os.remove(os.path.join(temp_dir, f))
            except OSError:
                pass
        try:
            os.rmdir(temp_dir)
        except OSError:
            pass
|
||||
|
||||
|
||||
def export_reencode(
    input_path: str,
    output_path: str,
    keep_segments: List[dict],
    resolution: str = "1080p",
    format_hint: str = "mp4",
) -> str:
    """
    Export video with full re-encode. Slower but supports resolution changes,
    format conversion, and avoids stream-copy edge cases.

    Builds one filter_complex graph: each kept segment is trimmed from the
    video and audio streams (with timestamps reset via setpts/asetpts),
    all pieces are concatenated, and the result is optionally downscaled.

    Args:
        input_path: source video file
        output_path: destination file
        keep_segments: list of {"start": float, "end": float} to keep
        resolution: "720p", "1080p", or "4k"; any other value keeps the
            source resolution (no scale filter is applied)
        format_hint: "mp4" (H.264/AAC) or "webm" (VP9/Opus)

    Raises:
        ValueError: when keep_segments is empty.
        RuntimeError: when FFmpeg exits non-zero.
    """
    ffmpeg = _find_ffmpeg()
    input_path = str(Path(input_path).resolve())
    output_path = str(Path(output_path).resolve())

    if not keep_segments:
        raise ValueError("No segments to export")

    # -2 keeps the width divisible by 2 (required by most encoders) while
    # preserving aspect ratio.
    scale_map = {
        "720p": "scale=-2:720",
        "1080p": "scale=-2:1080",
        "4k": "scale=-2:2160",
    }

    # One trim/atrim pair per kept segment, each labeled [vN]/[aN].
    filter_parts = []
    for i, seg in enumerate(keep_segments):
        filter_parts.append(
            f"[0:v]trim=start={seg['start']}:end={seg['end']},setpts=PTS-STARTPTS[v{i}];"
            f"[0:a]atrim=start={seg['start']}:end={seg['end']},asetpts=PTS-STARTPTS[a{i}];"
        )

    n = len(keep_segments)
    concat_inputs = "".join(f"[v{i}][a{i}]" for i in range(n))
    filter_parts.append(f"{concat_inputs}concat=n={n}:v=1:a=1[outv][outa]")

    filter_complex = "".join(filter_parts)

    scale = scale_map.get(resolution, "")
    if scale:
        filter_complex += f";[outv]{scale}[outv_scaled]"
        video_map = "[outv_scaled]"
    else:
        video_map = "[outv]"

    # CRF 18 is visually lossless for H.264; VP9 uses constant-quality mode
    # (-b:v 0) with its recommended CRF 30.
    codec_args = ["-c:v", "libx264", "-preset", "medium", "-crf", "18", "-c:a", "aac", "-b:a", "192k"]
    if format_hint == "webm":
        codec_args = ["-c:v", "libvpx-vp9", "-crf", "30", "-b:v", "0", "-c:a", "libopus"]

    cmd = [
        ffmpeg, "-y",
        "-i", input_path,
        "-filter_complex", filter_complex,
        "-map", video_map,
        "-map", "[outa]",
        *codec_args,
        "-movflags", "+faststart",
        output_path,
    ]

    logger.info(f"Re-encoding {n} segments -> {output_path} ({resolution})")
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"FFmpeg re-encode failed: {result.stderr[-500:]}")

    return output_path
|
||||
|
||||
|
||||
def export_reencode_with_subs(
    input_path: str,
    output_path: str,
    keep_segments: List[dict],
    subtitle_path: str,
    resolution: str = "1080p",
    format_hint: str = "mp4",
) -> str:
    """
    Export video with re-encode and burn-in subtitles (ASS format).
    Applies trim+concat first, then overlays the subtitle file.

    Same filter graph as export_reencode, with the libass `ass` filter
    appended after the optional scale step so captions are rendered at
    the output resolution.

    Args:
        input_path: source video file
        output_path: destination file
        keep_segments: list of {"start": float, "end": float} to keep
        subtitle_path: path to an .ass subtitle file to burn in
        resolution: "720p", "1080p", or "4k"; other values keep source size
        format_hint: "mp4" (H.264/AAC) or "webm" (VP9/Opus)

    Raises:
        ValueError: when keep_segments is empty.
        RuntimeError: when FFmpeg exits non-zero.
    """
    ffmpeg = _find_ffmpeg()
    input_path = str(Path(input_path).resolve())
    output_path = str(Path(output_path).resolve())
    subtitle_path = str(Path(subtitle_path).resolve())

    if not keep_segments:
        raise ValueError("No segments to export")

    scale_map = {
        "720p": "scale=-2:720",
        "1080p": "scale=-2:1080",
        "4k": "scale=-2:2160",
    }

    filter_parts = []
    for i, seg in enumerate(keep_segments):
        filter_parts.append(
            f"[0:v]trim=start={seg['start']}:end={seg['end']},setpts=PTS-STARTPTS[v{i}];"
            f"[0:a]atrim=start={seg['start']}:end={seg['end']},asetpts=PTS-STARTPTS[a{i}];"
        )

    n = len(keep_segments)
    concat_inputs = "".join(f"[v{i}][a{i}]" for i in range(n))
    filter_parts.append(f"{concat_inputs}concat=n={n}:v=1:a=1[outv][outa]")

    filter_complex = "".join(filter_parts)

    # Escape path for FFmpeg subtitle filter (Windows backslashes need escaping)
    escaped_sub = subtitle_path.replace("\\", "/").replace(":", "\\:")

    scale = scale_map.get(resolution, "")
    if scale:
        filter_complex += f";[outv]{scale},ass='{escaped_sub}'[outv_final]"
    else:
        filter_complex += f";[outv]ass='{escaped_sub}'[outv_final]"
    video_map = "[outv_final]"

    codec_args = ["-c:v", "libx264", "-preset", "medium", "-crf", "18", "-c:a", "aac", "-b:a", "192k"]
    if format_hint == "webm":
        codec_args = ["-c:v", "libvpx-vp9", "-crf", "30", "-b:v", "0", "-c:a", "libopus"]

    cmd = [
        ffmpeg, "-y",
        "-i", input_path,
        "-filter_complex", filter_complex,
        "-map", video_map,
        "-map", "[outa]",
        *codec_args,
        "-movflags", "+faststart",
        output_path,
    ]

    logger.info(f"Re-encoding {n} segments with subtitles -> {output_path} ({resolution})")
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(f"FFmpeg re-encode with subs failed: {result.stderr[-500:]}")

    return output_path
|
||||
|
||||
|
||||
def get_video_info(input_path: str) -> dict:
    """Get basic video metadata using ffprobe.

    Returns a dict with duration, size, format, width, height, codec and
    fps, or {} when probing fails entirely.
    """
    ffmpeg = _find_ffmpeg()
    ffprobe = ffmpeg.replace("ffmpeg", "ffprobe")

    cmd = [
        ffprobe, "-v", "quiet",
        "-print_format", "json",
        "-show_format", "-show_streams",
        str(input_path),
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        import json
        data = json.loads(result.stdout)
        fmt = data.get("format", {})
        video_stream = next((s for s in data.get("streams", []) if s.get("codec_type") == "video"), {})

        # Parse the "num/den" frame rate safely. The original used eval()
        # on ffprobe output, which executes arbitrary text from the probed
        # file's metadata pipeline.
        fps = 0.0
        rate = video_stream.get("r_frame_rate", "")
        if "/" in rate:
            num_s, _, den_s = rate.partition("/")
            try:
                den = float(den_s)
                fps = float(num_s) / den if den else 0.0
            except ValueError:
                fps = 0.0

        return {
            "duration": float(fmt.get("duration", 0)),
            "size": int(fmt.get("size", 0)),
            "format": fmt.get("format_name", ""),
            "width": int(video_stream.get("width", 0)),
            "height": int(video_stream.get("height", 0)),
            "codec": video_stream.get("codec_name", ""),
            "fps": fps,
        }
    except Exception as e:
        logger.error(f"Failed to get video info: {e}")
        return {}
|
||||
0
backend/utils/__init__.py
Normal file
0
backend/utils/__init__.py
Normal file
59
backend/utils/audio_processing.py
Normal file
59
backend/utils/audio_processing.py
Normal file
@ -0,0 +1,59 @@
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import os
|
||||
import logging
|
||||
|
||||
try:
|
||||
from moviepy import AudioFileClip
|
||||
except ImportError:
|
||||
from moviepy.editor import AudioFileClip
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_temp_audio_files = []
|
||||
|
||||
|
||||
def extract_audio(video_path: Path):
    """Extract the audio track of *video_path* to a WAV in a fresh temp dir.

    The written file is registered in _temp_audio_files so that
    cleanup_temp_audio() can remove it later.

    Returns:
        Path: the extracted .wav file.

    Raises:
        RuntimeError: when the clip cannot be opened or the WAV cannot be
            written (chained to the original exception).
    """
    try:
        audio = AudioFileClip(str(video_path))
        try:
            temp_dir = tempfile.mkdtemp(prefix="videotranscriber_")
            audio_path = Path(temp_dir) / f"{video_path.stem}_audio.wav"
            try:
                audio.write_audiofile(str(audio_path), logger=None)
            except TypeError:
                # moviepy 1.x still takes verbose=; moviepy 2.x removed it
                audio.write_audiofile(str(audio_path), verbose=False, logger=None)
        finally:
            # Always release the reader -- the original leaked it when
            # write_audiofile failed.
            audio.close()
        _temp_audio_files.append(str(audio_path))
        return audio_path
    except Exception as e:
        raise RuntimeError(f"Audio extraction failed: {e}") from e
|
||||
|
||||
|
||||
def cleanup_temp_audio():
    """Delete every registered temp audio file (and its dir, when emptied).

    Returns the number of files successfully removed. Failures are logged
    and skipped so one stuck file cannot block the rest.
    """
    removed = 0
    for path in _temp_audio_files:
        try:
            if os.path.exists(path):
                os.remove(path)
                folder = os.path.dirname(path)
                if os.path.isdir(folder) and not os.listdir(folder):
                    os.rmdir(folder)
                removed += 1
        except Exception as e:
            logger.warning(f"Could not remove temp file {path}: {e}")
    _temp_audio_files.clear()
    return removed
|
||||
|
||||
|
||||
def get_video_duration(video_path: Path):
    """Return the duration of a video/audio file in seconds, or None when it cannot be read."""
    try:
        media = AudioFileClip(str(video_path))
        length = media.duration
        media.close()
        return length
    except Exception:
        return None
|
||||
205
backend/utils/cache.py
Normal file
205
backend/utils/cache.py
Normal file
@ -0,0 +1,205 @@
|
||||
"""
|
||||
Caching utilities for the OBS Recording Transcriber.
|
||||
Provides functions to cache and retrieve transcription and summarization results.
|
||||
"""
|
||||
|
||||
import json
|
||||
import hashlib
|
||||
import os
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import time
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default cache directory
|
||||
CACHE_DIR = Path.home() / ".obs_transcriber_cache"
|
||||
|
||||
|
||||
def get_file_hash(file_path):
    """
    Generate a hash for a file based on its content and modification time.

    Args:
        file_path (Path): Path to the file

    Returns:
        str: MD5 hex digest identifying the file, or None when it does
        not exist
    """
    target = Path(file_path)
    if not target.exists():
        return None

    info = target.stat()
    # Fingerprint path + size + mtime instead of the file body: changes
    # are still detected, but large media files never have to be read.
    fingerprint = f"{target.absolute()}|{info.st_size}|{info.st_mtime}"
    return hashlib.md5(fingerprint.encode()).hexdigest()
|
||||
|
||||
|
||||
def get_cache_path(file_path, model=None, operation=None):
    """
    Get the cache file path for a given input file and operation.

    Args:
        file_path (Path): Path to the original file
        model (str, optional): Model used for processing
        operation (str, optional): Operation type (e.g., 'transcribe', 'summarize')

    Returns:
        Path: location of the cache JSON, or None when the source file
        does not exist
    """
    source = Path(file_path)
    digest = get_file_hash(source)
    if not digest:
        return None

    # Make sure the cache directory exists before handing back a path into it.
    CACHE_DIR.mkdir(parents=True, exist_ok=True)

    # Cache name is "<hash>[_<model>][_<operation>].json".
    parts = [digest] + [tag for tag in (model, operation) if tag]
    return CACHE_DIR / ("_".join(parts) + ".json")
|
||||
|
||||
|
||||
def save_to_cache(file_path, data, model=None, operation=None):
    """
    Save data to cache.

    Args:
        file_path (Path): Path to the original file
        data (dict): Data to cache
        model (str, optional): Model used for processing
        operation (str, optional): Operation type

    Returns:
        bool: True if successful, False otherwise
    """
    cache_path = get_cache_path(file_path, model, operation)
    if not cache_path:
        return False

    # Wrap the payload with provenance metadata for expiry checks later.
    envelope = {
        "original_file": str(Path(file_path).absolute()),
        "timestamp": time.time(),
        "model": model,
        "operation": operation,
        "data": data,
    }

    try:
        with open(cache_path, 'w', encoding='utf-8') as f:
            json.dump(envelope, f, ensure_ascii=False, indent=2)
    except Exception as e:
        logger.error(f"Error saving cache: {e}")
        return False

    logger.info(f"Cached data saved to {cache_path}")
    return True
|
||||
|
||||
|
||||
def load_from_cache(file_path, model=None, operation=None, max_age=None):
    """
    Load data from cache if available and not expired.

    Args:
        file_path (Path): Path to the original file
        model (str, optional): Model used for processing
        operation (str, optional): Operation type
        max_age (float, optional): Maximum age of cache in seconds

    Returns:
        dict or None: Cached data or None if not available
    """
    cache_path = get_cache_path(file_path, model, operation)
    if cache_path is None or not cache_path.exists():
        return None

    try:
        with open(cache_path, 'r', encoding='utf-8') as f:
            envelope = json.load(f)
    except Exception as e:
        logger.error(f"Error loading cache: {e}")
        return None

    # Honor the optional freshness window.
    if max_age is not None and time.time() - envelope.get("timestamp", 0) > max_age:
        logger.info(f"Cache expired for {file_path}")
        return None

    logger.info(f"Loaded data from cache: {cache_path}")
    return envelope.get("data")
|
||||
|
||||
|
||||
def clear_cache(max_age=None):
    """
    Clear all cache files or only expired ones.

    Args:
        max_age (float, optional): Maximum age of cache in seconds; when
            given, only files older than this are removed

    Returns:
        int: Number of files deleted
    """
    if not CACHE_DIR.exists():
        return 0

    deleted = 0
    for cache_file in CACHE_DIR.glob("*.json"):
        try:
            if max_age is not None:
                # Read the stored timestamp to decide whether it expired.
                with open(cache_file, 'r', encoding='utf-8') as f:
                    envelope = json.load(f)
                if time.time() - envelope.get("timestamp", 0) <= max_age:
                    continue  # still fresh -- keep it

            os.remove(cache_file)
            deleted += 1
        except Exception as e:
            logger.error(f"Error deleting cache file {cache_file}: {e}")

    logger.info(f"Cleared {deleted} cache files")
    return deleted
|
||||
|
||||
|
||||
def get_cache_size():
    """
    Get the total size of the cache directory.

    Returns:
        tuple: (size_bytes, file_count)
    """
    if not CACHE_DIR.exists():
        return 0, 0

    # Collect per-file sizes, silently skipping files that vanish mid-scan.
    sizes = []
    for cache_file in CACHE_DIR.glob("*.json"):
        try:
            sizes.append(cache_file.stat().st_size)
        except Exception:
            pass

    return sum(sizes), len(sizes)
|
||||
196
backend/utils/gpu_utils.py
Normal file
196
backend/utils/gpu_utils.py
Normal file
@ -0,0 +1,196 @@
|
||||
"""
|
||||
GPU utilities for the Video Transcriber.
|
||||
Provides functions to detect and configure GPU acceleration.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_gpu_info():
    """
    Get information about available GPUs.

    Returns:
        dict: CUDA/MPS availability flags plus per-CUDA-device properties
            (index, name, total memory in bytes, compute capability).
    """
    cuda_ok = torch.cuda.is_available()
    mps_ok = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()

    devices = []
    if cuda_ok:
        for idx in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(idx)
            devices.append({
                "index": idx,
                "name": props.name,
                "total_memory": props.total_memory,
                "compute_capability": f"{props.major}.{props.minor}"
            })

    return {
        "cuda_available": cuda_ok,
        "cuda_device_count": torch.cuda.device_count() if cuda_ok else 0,
        "cuda_devices": devices,
        "mps_available": mps_ok
    }
|
||||
|
||||
|
||||
def get_optimal_device():
    """
    Get the optimal device for computation.

    Returns:
        torch.device: The optimal device (cuda, mps, or cpu)
    """
    if torch.cuda.is_available():
        n_devices = torch.cuda.device_count()
        if n_devices > 1:
            # Pick the GPU with the most total memory; max() returns the
            # first maximal index, matching a strict-greater-than scan.
            best = max(
                range(n_devices),
                key=lambda i: torch.cuda.get_device_properties(i).total_memory,
            )
            return torch.device(f"cuda:{best}")
        return torch.device("cuda:0")

    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")

    return torch.device("cpu")
|
||||
|
||||
|
||||
def set_memory_limits(memory_fraction=0.8):
    """
    Set memory limits for GPU usage.

    Args:
        memory_fraction (float): Fraction of GPU memory to use (0.0 to 1.0)

    Returns:
        bool: True if successful, False otherwise (including when CUDA
            is unavailable).
    """
    if not torch.cuda.is_available():
        return False

    try:
        # Apply the same per-process cap to every visible CUDA device.
        for device_idx in range(torch.cuda.device_count()):
            torch.cuda.set_per_process_memory_fraction(memory_fraction, device_idx)
    except Exception as e:
        logger.error(f"Error setting memory limits: {e}")
        return False
    return True
|
||||
|
||||
|
||||
def optimize_for_inference():
    """
    Apply optimizations for inference.

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        # Deterministic cuDNN algorithm selection for reproducibility.
        # NOTE(review): combined with benchmark=True below this restricts
        # auto-tuning to deterministic kernels — confirm that is intended.
        torch.backends.cudnn.deterministic = True

        # Let cuDNN auto-tune kernels for the observed input shapes.
        torch.backends.cudnn.benchmark = True

        # Inference never needs autograd; disabling it process-wide
        # avoids graph bookkeeping overhead.
        torch.set_grad_enabled(False)
    except Exception as e:
        logger.error(f"Error optimizing for inference: {e}")
        return False
    return True
|
||||
|
||||
|
||||
def get_recommended_batch_size(model_size="base"):
    """
    Get recommended batch size based on available GPU memory.

    Args:
        model_size (str): Size of the model (tiny, base, small, medium, large)

    Returns:
        int: Recommended batch size (always >= 1)
    """
    # Conservative defaults used whenever no CUDA device is present.
    cpu_defaults = {
        "tiny": 16,
        "base": 8,
        "small": 4,
        "medium": 2,
        "large": 1
    }

    if not torch.cuda.is_available():
        return cpu_defaults.get(model_size, 1)

    # Rough per-model memory footprint in GB.
    model_footprint_gb = {
        "tiny": 1,
        "base": 2,
        "small": 4,
        "medium": 8,
        "large": 16
    }

    device = get_optimal_device()
    if device.type != "cuda":
        # MPS or other backends fall back to the CPU defaults.
        return cpu_defaults.get(model_size, 1)

    props = torch.cuda.get_device_properties(device.index)
    # NOTE(review): sized against total memory, not currently-free memory
    # — confirm that matches how the model is loaded.
    available_gb = props.total_memory / (1024 ** 3)
    batch = int(available_gb / model_footprint_gb.get(model_size, 2))

    # Never recommend less than a single sample per batch.
    return max(1, batch)
|
||||
|
||||
|
||||
def configure_gpu(model_size="base", memory_fraction=0.8):
    """
    Configure GPU settings for optimal performance.

    Args:
        model_size (str): Size of the model (tiny, base, small, medium, large)
        memory_fraction (float): Fraction of GPU memory to use (0.0 to 1.0)

    Returns:
        dict: Configuration information (device, batch_size, gpu_info,
            memory_fraction — the latter is None off-CUDA).
    """
    gpu_info = get_gpu_info()
    device = get_optimal_device()
    on_cuda = device.type == "cuda"

    # Per-process memory caps only apply to CUDA devices.
    if on_cuda:
        set_memory_limits(memory_fraction)

    optimize_for_inference()
    batch_size = get_recommended_batch_size(model_size)

    logger.info(f"GPU configuration: Using {device} with batch size {batch_size}")
    return {
        "device": device,
        "batch_size": batch_size,
        "gpu_info": gpu_info,
        "memory_fraction": memory_fraction if on_cuda else None
    }
|
||||
Reference in New Issue
Block a user