Initial CutScript release - Open-source AI-powered text-based video editor

CutScript is a local-first, Descript-like video editor where you edit video by editing text.
Delete a word from the transcript and it's cut from the video.
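
How this works under the hood: every transcribed word carries start/end timestamps, so the words that survive editing define exactly which time ranges to keep. A minimal sketch of the idea (the helper name is illustrative, not code from this commit):

def words_to_keep_segments(words, deleted_indices):
    """Collapse the kept words into contiguous {"start", "end"} segments."""
    segments = []
    prev_kept = False
    for i, w in enumerate(words):
        if i in deleted_indices:
            prev_kept = False  # a deletion ends the current segment
            continue
        if prev_kept:
            segments[-1]["end"] = w["end"]  # extend the running segment
        else:
            segments.append({"start": w["start"], "end": w["end"]})
        prev_kept = True
    return segments

The export endpoint below consumes segments in exactly this shape.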

Features:
- Word-level transcription with WhisperX
- Text-based video editing with undo/redo
- AI filler word removal (Ollama/OpenAI/Claude)
- AI clip creation for shorts
- Waveform timeline with virtualized transcript
- FFmpeg stream-copy (fast) and re-encode (4K) export
- Caption burn-in and sidecar SRT generation
- Studio Sound audio enhancement (DeepFilterNet)
- Keyboard shortcuts (J/K/L, Space, Delete, Ctrl+Z/S/E)
- Encrypted API key storage
- Project save/load (.aive files)

Architecture:
- Electron + React + Tailwind (frontend)
- FastAPI + Python (backend)
- WhisperX for transcription
- FFmpeg for video processing
- Multi-provider AI support

Performance optimizations:
- RAF-throttled time updates
- Zustand selectors for granular subscriptions
- Dual-canvas waveform rendering
- Virtualized transcript with react-virtuoso

Built on top of DataAnts-AI/VideoTranscriber, completely rewritten as a desktop application.

License: MIT
Author: Your Name
Date: 2026-03-03 06:31:04 -05:00
parent d1e1fedcae
commit 33cca5f552
73 changed files with 7463 additions and 3906 deletions

backend/main.py (new file)
@@ -0,0 +1,117 @@
import logging
from contextlib import asynccontextmanager
from pathlib import Path
from fastapi import FastAPI, Query, Request, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from routers import transcribe, export, ai, captions, audio
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI):
logger.info("AI Video Editor backend starting up")
yield
logger.info("AI Video Editor backend shutting down")
app = FastAPI(
title="AI Video Editor Backend",
version="0.1.0",
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["Content-Range", "Accept-Ranges", "Content-Length"],
)
app.include_router(transcribe.router)
app.include_router(export.router)
app.include_router(ai.router)
app.include_router(captions.router)
app.include_router(audio.router)
MIME_MAP = {
".mp4": "video/mp4",
".mkv": "video/x-matroska",
".mov": "video/quicktime",
".avi": "video/x-msvideo",
".webm": "video/webm",
".m4a": "audio/mp4",
".wav": "audio/wav",
".mp3": "audio/mpeg",
".flac": "audio/flac",
}
@app.get("/file")
async def serve_local_file(request: Request, path: str = Query(...)):
"""Stream a local file with HTTP Range support (required for video seeking)."""
file_path = Path(path)
if not file_path.is_file():
raise HTTPException(status_code=404, detail=f"File not found: {path}")
file_size = file_path.stat().st_size
content_type = MIME_MAP.get(file_path.suffix.lower(), "application/octet-stream")
range_header = request.headers.get("range")
if range_header:
        range_spec = range_header.replace("bytes=", "")
        range_start_str, _, range_end_str = range_spec.partition("-")
        if range_start_str:
            range_start = int(range_start_str)
            range_end = int(range_end_str) if range_end_str else file_size - 1
        else:
            # Suffix range ("bytes=-N"): serve the last N bytes
            range_start = max(0, file_size - int(range_end_str))
            range_end = file_size - 1
        range_end = min(range_end, file_size - 1)
content_length = range_end - range_start + 1
def iter_range():
with open(file_path, "rb") as f:
f.seek(range_start)
remaining = content_length
while remaining > 0:
chunk = f.read(min(65536, remaining))
if not chunk:
break
remaining -= len(chunk)
yield chunk
return StreamingResponse(
iter_range(),
status_code=206,
media_type=content_type,
headers={
"Content-Range": f"bytes {range_start}-{range_end}/{file_size}",
"Accept-Ranges": "bytes",
"Content-Length": str(content_length),
},
)
def iter_file():
with open(file_path, "rb") as f:
while chunk := f.read(65536):
yield chunk
return StreamingResponse(
iter_file(),
media_type=content_type,
headers={
"Accept-Ranges": "bytes",
"Content-Length": str(file_size),
},
)
@app.get("/health")
async def health():
return {"status": "ok"}

backend/requirements.txt (new file)
@@ -0,0 +1,33 @@
# FastAPI backend
fastapi>=0.115.0
uvicorn[standard]>=0.32.0
websockets>=14.0
python-multipart>=0.0.12
# Transcription (WhisperX for word-level alignment)
whisperx>=3.1.0
faster-whisper>=1.0.0
# Audio / Video processing
moviepy>=1.0.3
ffmpeg-python>=0.2.0
soundfile>=0.10.3
# ML / GPU
torch>=2.0.0
torchaudio>=2.0.0
numpy>=1.24.0
# Speaker diarization
pyannote.audio>=3.1.1
# AI providers
openai>=1.50.0
anthropic>=0.39.0
requests>=2.28.0
# Audio cleanup
deepfilternet>=0.5.0
# Utilities
pydantic>=2.0.0

backend/routers/ai.py (new file)
@@ -0,0 +1,83 @@
"""AI feature endpoints: filler word detection, clip creation, Ollama model listing."""
import logging
from typing import List, Optional
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from services.ai_provider import AIProvider, detect_filler_words, create_clip_suggestion
logger = logging.getLogger(__name__)
router = APIRouter()
class WordInfo(BaseModel):
index: int
word: str
start: Optional[float] = None
end: Optional[float] = None
class FillerRequest(BaseModel):
transcript: str
words: List[WordInfo]
provider: str = "ollama"
model: Optional[str] = None
api_key: Optional[str] = None
base_url: Optional[str] = None
custom_filler_words: Optional[str] = None
class ClipRequest(BaseModel):
transcript: str
words: List[WordInfo]
provider: str = "ollama"
model: Optional[str] = None
api_key: Optional[str] = None
base_url: Optional[str] = None
target_duration: int = 60
@router.post("/ai/filler-removal")
async def filler_removal(req: FillerRequest):
try:
words_dicts = [w.model_dump() for w in req.words]
result = detect_filler_words(
transcript=req.transcript,
words=words_dicts,
provider=req.provider,
model=req.model,
api_key=req.api_key,
base_url=req.base_url,
custom_filler_words=req.custom_filler_words,
)
return result
except Exception as e:
logger.error(f"Filler detection failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.post("/ai/create-clip")
async def create_clip(req: ClipRequest):
try:
words_dicts = [w.model_dump() for w in req.words]
result = create_clip_suggestion(
transcript=req.transcript,
words=words_dicts,
target_duration=req.target_duration,
provider=req.provider,
model=req.model,
api_key=req.api_key,
base_url=req.base_url,
)
return result
except Exception as e:
logger.error(f"Clip creation failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/ai/ollama-models")
async def ollama_models(base_url: str = "http://localhost:11434"):
models = AIProvider.list_ollama_models(base_url)
return {"models": models}

backend/routers/audio.py (new file)
@@ -0,0 +1,38 @@
"""Audio processing endpoint (noise reduction / Studio Sound)."""
import logging
from typing import Optional
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from services.audio_cleaner import clean_audio, is_deepfilter_available
logger = logging.getLogger(__name__)
router = APIRouter()
class AudioCleanRequest(BaseModel):
input_path: str
output_path: Optional[str] = None
@router.post("/audio/clean")
async def clean_audio_endpoint(req: AudioCleanRequest):
try:
output = clean_audio(req.input_path, req.output_path or "")
return {
"status": "ok",
"output_path": output,
"engine": "deepfilternet" if is_deepfilter_available() else "ffmpeg_anlmdn",
}
except Exception as e:
logger.error(f"Audio cleaning failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/audio/capabilities")
async def audio_capabilities():
return {
"deepfilternet_available": is_deepfilter_available(),
}

backend/routers/captions.py (new file)
@@ -0,0 +1,65 @@
"""Caption generation endpoint."""
import logging
from typing import List, Optional
from fastapi import APIRouter, HTTPException
from fastapi.responses import PlainTextResponse
from pydantic import BaseModel
from services.caption_generator import generate_srt, generate_vtt, generate_ass, save_captions
logger = logging.getLogger(__name__)
router = APIRouter()
class CaptionWord(BaseModel):
word: str
start: float
end: float
confidence: float = 0.0
class CaptionStyle(BaseModel):
fontName: str = "Arial"
fontSize: int = 48
fontColor: str = "&H00FFFFFF"
backgroundColor: str = "&H80000000"
position: str = "bottom"
bold: bool = True
class CaptionRequest(BaseModel):
words: List[CaptionWord]
deleted_indices: List[int] = []
format: str = "srt"
words_per_line: int = 8
style: Optional[CaptionStyle] = None
output_path: Optional[str] = None
@router.post("/captions")
async def generate_captions(req: CaptionRequest):
try:
words_dicts = [w.model_dump() for w in req.words]
deleted_set = set(req.deleted_indices)
if req.format == "srt":
content = generate_srt(words_dicts, deleted_set, req.words_per_line)
elif req.format == "vtt":
content = generate_vtt(words_dicts, deleted_set, req.words_per_line)
elif req.format == "ass":
style_dict = req.style.model_dump() if req.style else None
content = generate_ass(words_dicts, deleted_set, req.words_per_line, style_dict)
else:
raise HTTPException(status_code=400, detail=f"Unknown format: {req.format}")
if req.output_path:
saved = save_captions(content, req.output_path)
return {"status": "ok", "output_path": saved}
return PlainTextResponse(content, media_type="text/plain")
    except HTTPException:
        # Propagate explicit 4xx errors (unknown format) instead of re-wrapping them as 500s
        raise
    except Exception as e:
        logger.error(f"Caption generation failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
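
A usage sketch for the inline branch (request shape follows CaptionRequest; with no output_path the endpoint returns the caption text directly):

import requests

resp = requests.post("http://127.0.0.1:8000/captions", json={
    "words": [
        {"word": "Hello", "start": 0.0, "end": 0.4},
        {"word": "world", "start": 0.4, "end": 0.8},
    ],
    "format": "srt",
})
print(resp.text)  # one SRT cue: 00:00:00,000 --> 00:00:00,800 "Hello world"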

backend/routers/export.py (new file)
@@ -0,0 +1,156 @@
"""Export endpoint for video cutting and rendering."""
import logging
import tempfile
import os
from typing import List, Optional
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from services.video_editor import export_stream_copy, export_reencode, export_reencode_with_subs
from services.audio_cleaner import clean_audio
from services.caption_generator import generate_srt, generate_ass, save_captions
logger = logging.getLogger(__name__)
router = APIRouter()
class SegmentModel(BaseModel):
start: float
end: float
class ExportWordModel(BaseModel):
word: str
start: float
end: float
confidence: float = 0.0
class ExportRequest(BaseModel):
input_path: str
output_path: str
keep_segments: List[SegmentModel]
mode: str = "fast"
resolution: str = "1080p"
format: str = "mp4"
enhanceAudio: bool = False
captions: str = "none"
words: Optional[List[ExportWordModel]] = None
deleted_indices: Optional[List[int]] = None
def _mux_audio(video_path: str, audio_path: str, output_path: str) -> str:
"""Replace video's audio track with cleaned audio using FFmpeg."""
import subprocess
cmd = [
"ffmpeg", "-y",
"-i", video_path,
"-i", audio_path,
"-c:v", "copy",
"-map", "0:v:0",
"-map", "1:a:0",
"-shortest",
output_path,
]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
raise RuntimeError(f"Audio mux failed: {result.stderr[-300:]}")
return output_path
@router.post("/export")
async def export_video(req: ExportRequest):
try:
segments = [{"start": s.start, "end": s.end} for s in req.keep_segments]
if not segments:
raise HTTPException(status_code=400, detail="No segments to export")
use_stream_copy = req.mode == "fast" and len(segments) == 1
needs_reencode_for_subs = req.captions == "burn-in"
# Burn-in captions require re-encode
if needs_reencode_for_subs:
use_stream_copy = False
words_dicts = [w.model_dump() for w in req.words] if req.words else []
deleted_set = set(req.deleted_indices or [])
# Generate ASS file for burn-in
ass_path = None
if req.captions == "burn-in" and words_dicts:
ass_content = generate_ass(words_dicts, deleted_set)
tmp = tempfile.NamedTemporaryFile(suffix=".ass", delete=False, mode="w", encoding="utf-8")
tmp.write(ass_content)
tmp.close()
ass_path = tmp.name
try:
if use_stream_copy:
output = export_stream_copy(req.input_path, req.output_path, segments)
elif ass_path:
output = export_reencode_with_subs(
req.input_path,
req.output_path,
segments,
ass_path,
resolution=req.resolution,
format_hint=req.format,
)
else:
output = export_reencode(
req.input_path,
req.output_path,
segments,
resolution=req.resolution,
format_hint=req.format,
)
finally:
if ass_path and os.path.exists(ass_path):
os.unlink(ass_path)
# Audio enhancement: clean, then mux back into the exported video
if req.enhanceAudio:
try:
tmp_dir = tempfile.mkdtemp(prefix="cutscript_audio_")
cleaned_audio = os.path.join(tmp_dir, "cleaned.wav")
clean_audio(output, cleaned_audio)
                # Keep the original extension so the mux container matches the export format
                muxed_path = output + ".muxed" + os.path.splitext(output)[1]
_mux_audio(output, cleaned_audio, muxed_path)
os.replace(muxed_path, output)
logger.info(f"Audio enhanced and muxed into {output}")
# Cleanup
try:
os.remove(cleaned_audio)
os.rmdir(tmp_dir)
except OSError:
pass
except Exception as e:
logger.warning(f"Audio enhancement failed (non-fatal): {e}")
# Sidecar SRT: generate and save alongside video
srt_path = None
if req.captions == "sidecar" and words_dicts:
srt_content = generate_srt(words_dicts, deleted_set)
srt_path = req.output_path.rsplit(".", 1)[0] + ".srt"
save_captions(srt_content, srt_path)
logger.info(f"Sidecar SRT saved to {srt_path}")
result = {"status": "ok", "output_path": output}
if srt_path:
result["srt_path"] = srt_path
return result
    except HTTPException:
        # Propagate explicit HTTP errors (e.g. the empty-segment 400) without re-wrapping as 500s
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
except RuntimeError as e:
logger.error(f"Export failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
except Exception as e:
logger.error(f"Export error: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
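
An example export request for reference (paths are placeholders; keep_segments is the kept-word structure sketched at the top of this message):

import requests

resp = requests.post("http://127.0.0.1:8000/export", json={
    "input_path": "/videos/raw.mp4",
    "output_path": "/videos/cut.mp4",
    "keep_segments": [
        {"start": 0.0, "end": 12.3},
        {"start": 14.1, "end": 58.0},  # 12.3s-14.1s is the deleted text
    ],
    "mode": "fast",
    "captions": "none",
})
print(resp.json())  # {"status": "ok", "output_path": "/videos/cut.mp4"}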

backend/routers/transcribe.py (new file)
@@ -0,0 +1,53 @@
"""Transcription endpoint using WhisperX."""
import logging
from typing import Optional
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from services.transcription import transcribe_audio
from services.diarization import diarize_and_label
logger = logging.getLogger(__name__)
router = APIRouter()
class TranscribeRequest(BaseModel):
file_path: str
model: str = "base"
language: Optional[str] = None
use_gpu: bool = True
use_cache: bool = True
diarize: bool = False
hf_token: Optional[str] = None
num_speakers: Optional[int] = None
@router.post("/transcribe")
async def transcribe(req: TranscribeRequest):
try:
result = transcribe_audio(
file_path=req.file_path,
model_name=req.model,
use_gpu=req.use_gpu,
use_cache=req.use_cache,
language=req.language,
)
if req.diarize and req.hf_token:
result = diarize_and_label(
transcription_result=result,
audio_path=req.file_path,
hf_token=req.hf_token,
num_speakers=req.num_speakers,
use_gpu=req.use_gpu,
)
return result
except FileNotFoundError:
raise HTTPException(status_code=404, detail=f"File not found: {req.file_path}")
except Exception as e:
logger.error(f"Transcription failed: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
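
A minimal client call, assuming the backend on uvicorn's default port and a placeholder file path:

import requests

resp = requests.post("http://127.0.0.1:8000/transcribe", json={
    "file_path": "/videos/raw.mp4",
    "model": "base",
    "use_gpu": True,
})
result = resp.json()
print(result["language"], len(result["words"]))  # e.g. "en" and the word count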

backend/services/ai_provider.py (new file)
@@ -0,0 +1,211 @@
"""
Unified AI provider interface for Ollama, OpenAI, and Claude.
"""
import json
import logging
from typing import Optional, List
import requests
logger = logging.getLogger(__name__)
class AIProvider:
"""Routes completion requests to the configured provider."""
@staticmethod
def complete(
prompt: str,
provider: str = "ollama",
model: Optional[str] = None,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
system_prompt: Optional[str] = None,
temperature: float = 0.3,
) -> str:
if provider == "ollama":
return _ollama_complete(prompt, model or "llama3", base_url or "http://localhost:11434", system_prompt, temperature)
elif provider == "openai":
return _openai_complete(prompt, model or "gpt-4o", api_key or "", system_prompt, temperature)
elif provider == "claude":
return _claude_complete(prompt, model or "claude-sonnet-4-20250514", api_key or "", system_prompt, temperature)
else:
raise ValueError(f"Unknown provider: {provider}")
@staticmethod
def list_ollama_models(base_url: str = "http://localhost:11434") -> List[str]:
try:
resp = requests.get(f"{base_url}/api/tags", timeout=3)
if resp.status_code == 200:
return [m["name"] for m in resp.json().get("models", [])]
except Exception:
pass
return []
def _ollama_complete(prompt: str, model: str, base_url: str, system_prompt: Optional[str], temperature: float) -> str:
body = {
"model": model,
"prompt": prompt,
"stream": False,
"options": {"temperature": temperature},
}
if system_prompt:
body["system"] = system_prompt
try:
resp = requests.post(f"{base_url}/api/generate", json=body, timeout=120)
resp.raise_for_status()
return resp.json().get("response", "").strip()
except Exception as e:
logger.error(f"Ollama error: {e}")
raise
def _openai_complete(prompt: str, model: str, api_key: str, system_prompt: Optional[str], temperature: float) -> str:
try:
from openai import OpenAI
client = OpenAI(api_key=api_key)
messages = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
messages.append({"role": "user", "content": prompt})
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
)
return response.choices[0].message.content.strip()
except Exception as e:
logger.error(f"OpenAI error: {e}")
raise
def _claude_complete(prompt: str, model: str, api_key: str, system_prompt: Optional[str], temperature: float) -> str:
try:
import anthropic
client = anthropic.Anthropic(api_key=api_key)
kwargs = {
"model": model,
"max_tokens": 4096,
"temperature": temperature,
"messages": [{"role": "user", "content": prompt}],
}
if system_prompt:
kwargs["system"] = system_prompt
response = client.messages.create(**kwargs)
return response.content[0].text.strip()
except Exception as e:
logger.error(f"Claude error: {e}")
raise
def detect_filler_words(
transcript: str,
words: List[dict],
provider: str = "ollama",
model: Optional[str] = None,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
custom_filler_words: Optional[str] = None,
) -> dict:
"""
Use an LLM to identify filler words in the transcript.
Returns {"wordIndices": [...], "fillerWords": [{"index": N, "word": "...", "reason": "..."}]}
"""
word_list = "\n".join(f"{w['index']}: {w['word']}" for w in words)
custom_line = ""
if custom_filler_words and custom_filler_words.strip():
custom_line = f"\n\nAdditionally, flag these user-specified filler words/phrases: {custom_filler_words.strip()}"
prompt = f"""Analyze this transcript for filler words and verbal hesitations.
Filler words include: um, uh, uh huh, hmm, like (when used as filler), you know, so (when starting sentences unnecessarily), basically, actually, literally, right, I mean, kind of, sort of, well (when used as filler).
Also flag repeated words that indicate stammering (e.g., "I I I" or "the the").{custom_line}
Here are the words with their indices:
{word_list}
Return ONLY a valid JSON object with this exact structure:
{{"wordIndices": [list of integer indices to remove], "fillerWords": [{{"index": integer, "word": "the word", "reason": "brief reason"}}]}}
Be conservative -- only flag clear filler words, not words that are part of meaningful sentences."""
system = "You are a precise text analysis tool. Return only valid JSON, no explanation."
result_text = AIProvider.complete(
prompt=prompt,
provider=provider,
model=model,
api_key=api_key,
base_url=base_url,
system_prompt=system,
temperature=0.1,
)
try:
start = result_text.find("{")
end = result_text.rfind("}") + 1
if start >= 0 and end > start:
return json.loads(result_text[start:end])
except json.JSONDecodeError:
logger.error(f"Failed to parse AI response as JSON: {result_text[:200]}")
return {"wordIndices": [], "fillerWords": []}
def create_clip_suggestion(
transcript: str,
words: List[dict],
target_duration: int = 60,
provider: str = "ollama",
model: Optional[str] = None,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
) -> dict:
"""
Use an LLM to find the best clip segments in a transcript.
"""
word_list = "\n".join(
f"{w['index']}: \"{w['word']}\" ({w.get('start', 0):.1f}s - {w.get('end', 0):.1f}s)"
for w in words
)
prompt = f"""Analyze this transcript and find the most engaging {target_duration}-second segment(s) that would work well as a YouTube Short or social media clip.
Look for: compelling stories, surprising facts, emotional moments, clear explanations, humor, or quotable statements.
Words with indices and timestamps:
{word_list}
Return ONLY a valid JSON object:
{{"clips": [{{"title": "short catchy title", "startWordIndex": integer, "endWordIndex": integer, "startTime": float, "endTime": float, "reason": "why this segment is engaging"}}]}}
Suggest 1-3 clips, each approximately {target_duration} seconds long."""
system = "You are a viral content expert. Return only valid JSON, no explanation."
result_text = AIProvider.complete(
prompt=prompt,
provider=provider,
model=model,
api_key=api_key,
base_url=base_url,
system_prompt=system,
temperature=0.5,
)
try:
start = result_text.find("{")
end = result_text.rfind("}") + 1
if start >= 0 and end > start:
return json.loads(result_text[start:end])
except json.JSONDecodeError:
logger.error(f"Failed to parse clip suggestions: {result_text[:200]}")
return {"clips": []}

backend/services/audio_cleaner.py (new file)
@@ -0,0 +1,79 @@
"""
Audio noise reduction using DeepFilterNet.
Falls back to a basic FFmpeg noise filter if DeepFilterNet is not installed.
"""
import logging
import subprocess
import tempfile
from pathlib import Path
logger = logging.getLogger(__name__)
try:
from df.enhance import enhance, init_df, load_audio, save_audio
DEEPFILTER_AVAILABLE = True
except ImportError:
DEEPFILTER_AVAILABLE = False
_df_model = None
_df_state = None
def _init_deepfilter():
global _df_model, _df_state
if _df_model is None:
logger.info("Initializing DeepFilterNet model")
_df_model, _df_state, _ = init_df()
return _df_model, _df_state
def clean_audio(
input_path: str,
output_path: str = "",
) -> str:
"""
Apply noise reduction to an audio file.
If DeepFilterNet is available, uses it for high-quality results.
Otherwise falls back to FFmpeg's anlmdn filter.
Returns: path to the cleaned audio file.
"""
input_path = Path(input_path)
if not output_path:
output_path = str(input_path.with_stem(input_path.stem + "_clean"))
if DEEPFILTER_AVAILABLE:
return _clean_with_deepfilter(str(input_path), output_path)
else:
return _clean_with_ffmpeg(str(input_path), output_path)
def _clean_with_deepfilter(input_path: str, output_path: str) -> str:
model, state = _init_deepfilter()
audio, info = load_audio(input_path, sr=state.sr())
enhanced = enhance(model, state, audio)
save_audio(output_path, enhanced, sr=state.sr())
logger.info(f"DeepFilterNet cleaned audio saved to {output_path}")
return output_path
def _clean_with_ffmpeg(input_path: str, output_path: str) -> str:
"""Fallback: basic noise reduction using FFmpeg's anlmdn filter."""
cmd = [
"ffmpeg", "-y",
"-i", input_path,
"-af", "anlmdn=s=7:p=0.002:r=0.002:m=15",
output_path,
]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
raise RuntimeError(f"FFmpeg audio cleaning failed: {result.stderr[-300:]}")
logger.info(f"FFmpeg cleaned audio saved to {output_path}")
return output_path
def is_deepfilter_available() -> bool:
return DEEPFILTER_AVAILABLE

@@ -0,0 +1,59 @@
"""
AI background removal (Phase 5 - future).
Uses MediaPipe or Robust Video Matting for person segmentation.
Export-only -- no real-time preview.
"""
import logging
logger = logging.getLogger(__name__)
# Placeholder for Phase 5 implementation
# Will use mediapipe or rvm for segmentation at export time
MEDIAPIPE_AVAILABLE = False
RVM_AVAILABLE = False
try:
import mediapipe as mp
MEDIAPIPE_AVAILABLE = True
except ImportError:
pass
# Robust Video Matting import would go here (Phase 5); set RVM_AVAILABLE on success
def is_available() -> bool:
return MEDIAPIPE_AVAILABLE or RVM_AVAILABLE
def remove_background_on_export(
input_path: str,
output_path: str,
replacement: str = "blur",
replacement_value: str = "",
) -> str:
"""
Process video frame-by-frame to remove/replace background.
Only runs during export (not real-time).
Args:
input_path: source video
output_path: destination
replacement: 'blur', 'color', 'image', or 'video'
replacement_value: hex color, image path, or video path
Returns:
output_path
"""
if not is_available():
raise RuntimeError(
"Background removal requires mediapipe or robust-video-matting. "
"Install with: pip install mediapipe"
)
# Phase 5 implementation will go here
raise NotImplementedError("Background removal is planned for Phase 5")

backend/services/caption_generator.py (new file)
@@ -0,0 +1,148 @@
"""
Generate caption files (SRT, VTT, ASS) from word-level timestamps.
"""
import logging
from pathlib import Path
from typing import List, Optional
logger = logging.getLogger(__name__)
def _format_srt_time(seconds: float) -> str:
h = int(seconds // 3600)
m = int((seconds % 3600) // 60)
s = int(seconds % 60)
ms = int((seconds % 1) * 1000)
return f"{h:02d}:{m:02d}:{s:02d},{ms:03d}"
def _format_vtt_time(seconds: float) -> str:
h = int(seconds // 3600)
m = int((seconds % 3600) // 60)
s = int(seconds % 60)
ms = int((seconds % 1) * 1000)
return f"{h:02d}:{m:02d}:{s:02d}.{ms:03d}"
def _format_ass_time(seconds: float) -> str:
h = int(seconds // 3600)
m = int((seconds % 3600) // 60)
s = int(seconds % 60)
cs = int((seconds % 1) * 100)
return f"{h}:{m:02d}:{s:02d}.{cs:02d}"
def generate_srt(
words: List[dict],
deleted_indices: Optional[set] = None,
words_per_line: int = 8,
) -> str:
"""Generate SRT caption content from word-level timestamps."""
deleted_indices = deleted_indices or set()
active_words = [(i, w) for i, w in enumerate(words) if i not in deleted_indices]
lines = []
counter = 1
for chunk_start in range(0, len(active_words), words_per_line):
chunk = active_words[chunk_start:chunk_start + words_per_line]
if not chunk:
continue
start_time = chunk[0][1]["start"]
end_time = chunk[-1][1]["end"]
text = " ".join(w["word"] for _, w in chunk)
lines.append(str(counter))
lines.append(f"{_format_srt_time(start_time)} --> {_format_srt_time(end_time)}")
lines.append(text)
lines.append("")
counter += 1
return "\n".join(lines)
def generate_vtt(
words: List[dict],
deleted_indices: Optional[set] = None,
words_per_line: int = 8,
) -> str:
"""Generate WebVTT caption content."""
deleted_indices = deleted_indices or set()
active_words = [(i, w) for i, w in enumerate(words) if i not in deleted_indices]
lines = ["WEBVTT", ""]
for chunk_start in range(0, len(active_words), words_per_line):
chunk = active_words[chunk_start:chunk_start + words_per_line]
if not chunk:
continue
start_time = chunk[0][1]["start"]
end_time = chunk[-1][1]["end"]
text = " ".join(w["word"] for _, w in chunk)
lines.append(f"{_format_vtt_time(start_time)} --> {_format_vtt_time(end_time)}")
lines.append(text)
lines.append("")
return "\n".join(lines)
def generate_ass(
words: List[dict],
deleted_indices: Optional[set] = None,
words_per_line: int = 8,
style: Optional[dict] = None,
) -> str:
"""Generate ASS subtitle content with styling."""
deleted_indices = deleted_indices or set()
active_words = [(i, w) for i, w in enumerate(words) if i not in deleted_indices]
    s = style or {}
    font = s.get("fontName", "Arial")
    size = s.get("fontSize", 48)
    color = s.get("fontColor", "&H00FFFFFF")
    back_color = s.get("backgroundColor", "&H80000000")
    bold = "-1" if s.get("bold", True) else "0"
    # ASS numpad alignment: 2 = bottom-center, 5 = middle-center, 8 = top-center
    alignment = {"bottom": 2, "middle": 5, "top": 8}.get(s.get("position", "bottom"), 2)
header = f"""[Script Info]
Title: AI Video Editor Captions
ScriptType: v4.00+
PlayResX: 1920
PlayResY: 1080
[V4+ Styles]
Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding
Style: Default,{font},{size},{color},&H000000FF,&H00000000,{back_color},{bold},0,0,0,100,100,0,0,1,2,1,{alignment},20,20,40,1
[Events]
Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
"""
events = []
for chunk_start in range(0, len(active_words), words_per_line):
chunk = active_words[chunk_start:chunk_start + words_per_line]
if not chunk:
continue
start_time = chunk[0][1]["start"]
end_time = chunk[-1][1]["end"]
text = " ".join(w["word"] for _, w in chunk)
events.append(
f"Dialogue: 0,{_format_ass_time(start_time)},{_format_ass_time(end_time)},Default,,0,0,0,,{text}"
)
return header + "\n".join(events) + "\n"
def save_captions(
content: str,
output_path: str,
) -> str:
"""Write caption content to a file."""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(content, encoding="utf-8")
logger.info(f"Saved captions to {output_path}")
return str(output_path)
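
A small worked example of the SRT path (the word list is illustrative; deleting index 1 drops "there," from the cue):

from services.caption_generator import generate_srt

words = [
    {"word": "Hello", "start": 0.0, "end": 0.4},
    {"word": "there,", "start": 0.4, "end": 0.9},
    {"word": "world", "start": 0.9, "end": 1.3},
]
print(generate_srt(words, deleted_indices={1}))
# 1
# 00:00:00,000 --> 00:00:01,300
# Hello world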

backend/services/diarization.py (new file)
@@ -0,0 +1,98 @@
"""
Speaker diarization service using pyannote.audio.
Refactored from the original repo -- removed Streamlit dependency.
"""
import logging
import os
from pathlib import Path
from typing import Optional
import torch
from utils.gpu_utils import get_optimal_device
logger = logging.getLogger(__name__)
_pipeline_cache = {}
def _get_pipeline(hf_token: str, device: torch.device):
cache_key = str(device)
if cache_key in _pipeline_cache:
return _pipeline_cache[cache_key]
try:
from pyannote.audio import Pipeline
pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.0",
use_auth_token=hf_token,
)
if device.type == "cuda":
pipeline = pipeline.to(device)
_pipeline_cache[cache_key] = pipeline
return pipeline
except Exception as e:
logger.error(f"Failed to load diarization pipeline: {e}")
return None
def diarize_and_label(
transcription_result: dict,
audio_path: str,
hf_token: Optional[str] = None,
num_speakers: Optional[int] = None,
use_gpu: bool = True,
) -> dict:
"""
Apply speaker diarization to an existing transcription result.
Adds 'speaker' field to each word and segment.
Returns the mutated transcription_result with speaker labels.
"""
hf_token = hf_token or os.environ.get("HF_TOKEN")
if not hf_token:
logger.warning("No HuggingFace token provided; skipping diarization")
return transcription_result
device = get_optimal_device() if use_gpu else torch.device("cpu")
pipeline = _get_pipeline(hf_token, device)
if pipeline is None:
return transcription_result
audio_path = Path(audio_path)
logger.info(f"Running diarization on {audio_path}")
try:
diarization = pipeline(str(audio_path), num_speakers=num_speakers)
except Exception as e:
logger.error(f"Diarization failed: {e}")
return transcription_result
speaker_map = []
for turn, _, speaker in diarization.itertracks(yield_label=True):
speaker_map.append((turn.start, turn.end, speaker))
def _find_speaker(start: float, end: float) -> str:
best_overlap = 0
best_speaker = "UNKNOWN"
for s_start, s_end, speaker in speaker_map:
overlap_start = max(start, s_start)
overlap_end = min(end, s_end)
overlap = max(0, overlap_end - overlap_start)
if overlap > best_overlap:
best_overlap = overlap
best_speaker = speaker
return best_speaker
for word in transcription_result.get("words", []):
word["speaker"] = _find_speaker(word["start"], word["end"])
for segment in transcription_result.get("segments", []):
segment["speaker"] = _find_speaker(segment["start"], segment["end"])
for w in segment.get("words", []):
w["speaker"] = _find_speaker(w["start"], w["end"])
return transcription_result

backend/services/transcription.py (new file)
@@ -0,0 +1,205 @@
"""
WhisperX-based transcription service with word-level alignment.
Falls back to standard Whisper if WhisperX is not available.
"""
import logging
from pathlib import Path
from typing import Optional
import torch
from utils.gpu_utils import get_optimal_device, configure_gpu
from utils.audio_processing import extract_audio
from utils.cache import load_from_cache, save_to_cache
logger = logging.getLogger(__name__)
_model_cache: dict = {}
try:
import whisperx
WHISPERX_AVAILABLE = True
except ImportError:
WHISPERX_AVAILABLE = False
import whisper
import os

HF_TOKEN = os.environ.get("HF_TOKEN")
def _get_device(use_gpu: bool = True) -> torch.device:
if use_gpu:
return get_optimal_device()
return torch.device("cpu")
def _load_model(model_name: str, device: torch.device):
cache_key = f"{model_name}_{device}"
if cache_key in _model_cache:
return _model_cache[cache_key]
logger.info(f"Loading model: {model_name} on {device}")
if WHISPERX_AVAILABLE:
compute_type = "float16" if device.type == "cuda" else "int8"
model = whisperx.load_model(
model_name,
device=str(device),
compute_type=compute_type,
)
else:
model = whisper.load_model(model_name, device=device)
_model_cache[cache_key] = model
return model
def transcribe_audio(
file_path: str,
model_name: str = "base",
use_gpu: bool = True,
use_cache: bool = True,
language: Optional[str] = None,
) -> dict:
"""
Transcribe audio/video file and return word-level timestamps.
Returns:
dict with keys: words, segments, language
"""
file_path = Path(file_path)
if use_cache:
cached = load_from_cache(file_path, model_name, "transcribe_wx")
if cached:
logger.info("Using cached transcription")
return cached
video_extensions = {".mp4", ".avi", ".mov", ".mkv", ".webm"}
if file_path.suffix.lower() in video_extensions:
audio_path = extract_audio(file_path)
else:
audio_path = file_path
device = _get_device(use_gpu)
model = _load_model(model_name, device)
logger.info(f"Transcribing: {file_path}")
if WHISPERX_AVAILABLE:
result = _transcribe_whisperx(model, str(audio_path), device, language)
else:
result = _transcribe_standard(model, str(audio_path), language)
if use_cache:
save_to_cache(file_path, result, model_name, "transcribe_wx")
return result
def _transcribe_whisperx(model, audio_path: str, device: torch.device, language: Optional[str]) -> dict:
audio = whisperx.load_audio(audio_path)
transcribe_opts = {}
if language:
transcribe_opts["language"] = language
result = model.transcribe(audio, batch_size=16, **transcribe_opts)
detected_language = result.get("language", "en")
align_model, align_metadata = whisperx.load_align_model(
language_code=detected_language,
device=str(device),
)
aligned = whisperx.align(
result["segments"],
align_model,
align_metadata,
audio,
str(device),
return_char_alignments=False,
)
words = []
for seg in aligned.get("segments", []):
for w in seg.get("words", []):
words.append({
"word": w.get("word", ""),
"start": round(w.get("start", 0), 3),
"end": round(w.get("end", 0), 3),
"confidence": round(w.get("score", 0), 3),
})
segments = []
for i, seg in enumerate(aligned.get("segments", [])):
seg_words = []
for w in seg.get("words", []):
seg_words.append({
"word": w.get("word", ""),
"start": round(w.get("start", 0), 3),
"end": round(w.get("end", 0), 3),
"confidence": round(w.get("score", 0), 3),
})
segments.append({
"id": i,
"start": round(seg.get("start", 0), 3),
"end": round(seg.get("end", 0), 3),
"text": seg.get("text", "").strip(),
"words": seg_words,
})
return {
"words": words,
"segments": segments,
"language": detected_language,
}
def _transcribe_standard(model, audio_path: str, language: Optional[str]) -> dict:
"""Fallback: standard Whisper (segment-level only, synthesized word timestamps)."""
opts = {}
if language:
opts["language"] = language
result = model.transcribe(audio_path, **opts)
detected_language = result.get("language", "en")
words = []
segments = []
for i, seg in enumerate(result.get("segments", [])):
text = seg.get("text", "").strip()
seg_start = seg.get("start", 0)
seg_end = seg.get("end", 0)
seg_words_text = text.split()
duration = seg_end - seg_start
seg_words = []
for j, w_text in enumerate(seg_words_text):
w_start = seg_start + (j / max(len(seg_words_text), 1)) * duration
w_end = seg_start + ((j + 1) / max(len(seg_words_text), 1)) * duration
word_obj = {
"word": w_text,
"start": round(w_start, 3),
"end": round(w_end, 3),
"confidence": 0.5,
}
words.append(word_obj)
seg_words.append(word_obj)
segments.append({
"id": i,
"start": round(seg_start, 3),
"end": round(seg_end, 3),
"text": text,
"words": seg_words,
})
return {
"words": words,
"segments": segments,
"language": detected_language,
}

backend/services/video_editor.py (new file)
@@ -0,0 +1,271 @@
"""
FFmpeg-based video cutting engine.
Uses stream copy for fast, lossless cuts and falls back to re-encode when needed.
"""
import logging
import subprocess
import tempfile
import os
from pathlib import Path
from typing import List
logger = logging.getLogger(__name__)
def _find_ffmpeg() -> str:
"""Locate ffmpeg binary."""
for cmd in ["ffmpeg", "ffmpeg.exe"]:
try:
subprocess.run([cmd, "-version"], capture_output=True, check=True)
return cmd
except (FileNotFoundError, subprocess.CalledProcessError):
continue
raise RuntimeError("FFmpeg not found. Install it or add it to PATH.")
def export_stream_copy(
input_path: str,
output_path: str,
keep_segments: List[dict],
) -> str:
"""
    Export video using FFmpeg stream copy: each kept segment is copied into an
    MPEG-TS intermediate, then the pieces are joined with the concat protocol.
    Roughly 100x faster than re-encoding, with no generation loss, though cut
    points snap to keyframes and can be slightly imprecise.
Args:
input_path: source video file
output_path: destination file
keep_segments: list of {"start": float, "end": float} to keep
Returns:
output_path on success
"""
ffmpeg = _find_ffmpeg()
input_path = str(Path(input_path).resolve())
output_path = str(Path(output_path).resolve())
if not keep_segments:
raise ValueError("No segments to export")
temp_dir = tempfile.mkdtemp(prefix="aive_export_")
try:
segment_files = []
for i, seg in enumerate(keep_segments):
seg_file = os.path.join(temp_dir, f"seg_{i:04d}.ts")
cmd = [
ffmpeg, "-y",
"-ss", str(seg["start"]),
"-to", str(seg["end"]),
"-i", input_path,
"-c", "copy",
"-avoid_negative_ts", "make_zero",
"-f", "mpegts",
seg_file,
]
logger.info(f"Extracting segment {i}: {seg['start']:.2f}s - {seg['end']:.2f}s")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
logger.warning(f"Stream copy segment {i} failed, will try re-encode: {result.stderr[-200:]}")
return export_reencode(input_path, output_path, keep_segments)
segment_files.append(seg_file)
concat_str = "|".join(segment_files)
cmd = [
ffmpeg, "-y",
"-i", f"concat:{concat_str}",
"-c", "copy",
"-movflags", "+faststart",
output_path,
]
logger.info(f"Concatenating {len(segment_files)} segments -> {output_path}")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
logger.warning(f"Concat failed, falling back to re-encode: {result.stderr[-200:]}")
return export_reencode(input_path, output_path, keep_segments)
return output_path
finally:
for f in os.listdir(temp_dir):
try:
os.remove(os.path.join(temp_dir, f))
except OSError:
pass
try:
os.rmdir(temp_dir)
except OSError:
pass
def export_reencode(
input_path: str,
output_path: str,
keep_segments: List[dict],
resolution: str = "1080p",
format_hint: str = "mp4",
) -> str:
"""
Export video with full re-encode. Slower but supports resolution changes,
format conversion, and avoids stream-copy edge cases.
"""
ffmpeg = _find_ffmpeg()
input_path = str(Path(input_path).resolve())
output_path = str(Path(output_path).resolve())
if not keep_segments:
raise ValueError("No segments to export")
scale_map = {
"720p": "scale=-2:720",
"1080p": "scale=-2:1080",
"4k": "scale=-2:2160",
}
filter_parts = []
for i, seg in enumerate(keep_segments):
filter_parts.append(
f"[0:v]trim=start={seg['start']}:end={seg['end']},setpts=PTS-STARTPTS[v{i}];"
f"[0:a]atrim=start={seg['start']}:end={seg['end']},asetpts=PTS-STARTPTS[a{i}];"
)
n = len(keep_segments)
concat_inputs = "".join(f"[v{i}][a{i}]" for i in range(n))
filter_parts.append(f"{concat_inputs}concat=n={n}:v=1:a=1[outv][outa]")
filter_complex = "".join(filter_parts)
scale = scale_map.get(resolution, "")
if scale:
filter_complex += f";[outv]{scale}[outv_scaled]"
video_map = "[outv_scaled]"
else:
video_map = "[outv]"
codec_args = ["-c:v", "libx264", "-preset", "medium", "-crf", "18", "-c:a", "aac", "-b:a", "192k"]
if format_hint == "webm":
codec_args = ["-c:v", "libvpx-vp9", "-crf", "30", "-b:v", "0", "-c:a", "libopus"]
cmd = [
ffmpeg, "-y",
"-i", input_path,
"-filter_complex", filter_complex,
"-map", video_map,
"-map", "[outa]",
*codec_args,
"-movflags", "+faststart",
output_path,
]
logger.info(f"Re-encoding {n} segments -> {output_path} ({resolution})")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
raise RuntimeError(f"FFmpeg re-encode failed: {result.stderr[-500:]}")
return output_path
def export_reencode_with_subs(
input_path: str,
output_path: str,
keep_segments: List[dict],
subtitle_path: str,
resolution: str = "1080p",
format_hint: str = "mp4",
) -> str:
"""
Export video with re-encode and burn-in subtitles (ASS format).
Applies trim+concat first, then overlays the subtitle file.
"""
ffmpeg = _find_ffmpeg()
input_path = str(Path(input_path).resolve())
output_path = str(Path(output_path).resolve())
subtitle_path = str(Path(subtitle_path).resolve())
if not keep_segments:
raise ValueError("No segments to export")
scale_map = {
"720p": "scale=-2:720",
"1080p": "scale=-2:1080",
"4k": "scale=-2:2160",
}
filter_parts = []
for i, seg in enumerate(keep_segments):
filter_parts.append(
f"[0:v]trim=start={seg['start']}:end={seg['end']},setpts=PTS-STARTPTS[v{i}];"
f"[0:a]atrim=start={seg['start']}:end={seg['end']},asetpts=PTS-STARTPTS[a{i}];"
)
n = len(keep_segments)
concat_inputs = "".join(f"[v{i}][a{i}]" for i in range(n))
filter_parts.append(f"{concat_inputs}concat=n={n}:v=1:a=1[outv][outa]")
filter_complex = "".join(filter_parts)
# Escape path for FFmpeg subtitle filter (Windows backslashes need escaping)
escaped_sub = subtitle_path.replace("\\", "/").replace(":", "\\:")
scale = scale_map.get(resolution, "")
if scale:
filter_complex += f";[outv]{scale},ass='{escaped_sub}'[outv_final]"
else:
filter_complex += f";[outv]ass='{escaped_sub}'[outv_final]"
video_map = "[outv_final]"
codec_args = ["-c:v", "libx264", "-preset", "medium", "-crf", "18", "-c:a", "aac", "-b:a", "192k"]
if format_hint == "webm":
codec_args = ["-c:v", "libvpx-vp9", "-crf", "30", "-b:v", "0", "-c:a", "libopus"]
cmd = [
ffmpeg, "-y",
"-i", input_path,
"-filter_complex", filter_complex,
"-map", video_map,
"-map", "[outa]",
*codec_args,
"-movflags", "+faststart",
output_path,
]
logger.info(f"Re-encoding {n} segments with subtitles -> {output_path} ({resolution})")
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
raise RuntimeError(f"FFmpeg re-encode with subs failed: {result.stderr[-500:]}")
return output_path
def get_video_info(input_path: str) -> dict:
"""Get basic video metadata using ffprobe."""
ffmpeg = _find_ffmpeg()
ffprobe = ffmpeg.replace("ffmpeg", "ffprobe")
cmd = [
ffprobe, "-v", "quiet",
"-print_format", "json",
"-show_format", "-show_streams",
str(input_path),
]
try:
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        import json
        data = json.loads(result.stdout)
        fmt = data.get("format", {})
        video_stream = next((s for s in data.get("streams", []) if s.get("codec_type") == "video"), {})
        # Parse "num/den" frame rates safely instead of calling eval() on ffprobe output
        num, _, den = video_stream.get("r_frame_rate", "0/1").partition("/")
        try:
            fps = float(num) / float(den) if den else float(num)
        except (ValueError, ZeroDivisionError):
            fps = 0.0
        return {
            "duration": float(fmt.get("duration", 0)),
            "size": int(fmt.get("size", 0)),
            "format": fmt.get("format_name", ""),
            "width": int(video_stream.get("width", 0)),
            "height": int(video_stream.get("height", 0)),
            "codec": video_stream.get("codec_name", ""),
            "fps": fps,
        }
except Exception as e:
logger.error(f"Failed to get video info: {e}")
return {}
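
Usage sketch for the cutting engine (paths are placeholders; as implemented above, stream copy falls back to a re-encode if any segment fails):

from services.video_editor import export_stream_copy, get_video_info

info = get_video_info("/videos/raw.mp4")
print(info.get("duration"), info.get("fps"))

export_stream_copy(
    "/videos/raw.mp4",
    "/videos/cut.mp4",
    keep_segments=[{"start": 0.0, "end": 12.3}, {"start": 14.1, "end": 58.0}],
)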

backend/utils/audio_processing.py (new file)
@@ -0,0 +1,59 @@
from pathlib import Path
import tempfile
import os
import logging
try:
from moviepy import AudioFileClip
except ImportError:
from moviepy.editor import AudioFileClip
logger = logging.getLogger(__name__)
_temp_audio_files = []
def extract_audio(video_path: Path):
"""Extract audio from a video file into a temp directory for automatic cleanup."""
try:
audio = AudioFileClip(str(video_path))
temp_dir = tempfile.mkdtemp(prefix="videotranscriber_")
audio_path = Path(temp_dir) / f"{video_path.stem}_audio.wav"
try:
audio.write_audiofile(str(audio_path), logger=None)
except TypeError:
# moviepy 1.x uses verbose parameter; moviepy 2.x removed it
audio.write_audiofile(str(audio_path), verbose=False, logger=None)
audio.close()
_temp_audio_files.append(str(audio_path))
return audio_path
except Exception as e:
raise RuntimeError(f"Audio extraction failed: {e}")
def cleanup_temp_audio():
"""Remove all temporary audio files created during processing."""
cleaned = 0
for fpath in _temp_audio_files:
try:
if os.path.exists(fpath):
os.remove(fpath)
parent = os.path.dirname(fpath)
if os.path.isdir(parent) and not os.listdir(parent):
os.rmdir(parent)
cleaned += 1
except Exception as e:
logger.warning(f"Could not remove temp file {fpath}: {e}")
_temp_audio_files.clear()
return cleaned
def get_video_duration(video_path: Path):
"""Get duration of a video/audio file in seconds."""
try:
clip = AudioFileClip(str(video_path))
duration = clip.duration
clip.close()
return duration
except Exception:
return None

backend/utils/cache.py (new file)
@@ -0,0 +1,205 @@
"""
Caching utilities for the OBS Recording Transcriber.
Provides functions to cache and retrieve transcription and summarization results.
"""
import json
import hashlib
import os
from pathlib import Path
import logging
import time
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Default cache directory
CACHE_DIR = Path.home() / ".obs_transcriber_cache"
def get_file_hash(file_path):
"""
Generate a hash for a file based on its content and modification time.
Args:
file_path (Path): Path to the file
Returns:
str: Hash string representing the file
"""
file_path = Path(file_path)
if not file_path.exists():
return None
# Get file stats
stats = file_path.stat()
file_size = stats.st_size
mod_time = stats.st_mtime
# Create a hash based on path, size and modification time
# This is faster than hashing the entire file content
hash_input = f"{file_path.absolute()}|{file_size}|{mod_time}"
return hashlib.md5(hash_input.encode()).hexdigest()
def get_cache_path(file_path, model=None, operation=None):
"""
Get the cache file path for a given input file and operation.
Args:
file_path (Path): Path to the original file
model (str, optional): Model used for processing
operation (str, optional): Operation type (e.g., 'transcribe', 'summarize')
Returns:
Path: Path to the cache file
"""
file_path = Path(file_path)
file_hash = get_file_hash(file_path)
if not file_hash:
return None
# Create cache directory if it doesn't exist
cache_dir = CACHE_DIR
cache_dir.mkdir(parents=True, exist_ok=True)
# Create a cache filename based on the hash and optional parameters
cache_name = file_hash
if model:
cache_name += f"_{model}"
if operation:
cache_name += f"_{operation}"
return cache_dir / f"{cache_name}.json"
def save_to_cache(file_path, data, model=None, operation=None):
"""
Save data to cache.
Args:
file_path (Path): Path to the original file
data (dict): Data to cache
model (str, optional): Model used for processing
operation (str, optional): Operation type
Returns:
bool: True if successful, False otherwise
"""
cache_path = get_cache_path(file_path, model, operation)
if not cache_path:
return False
try:
# Add metadata to the cached data
cache_data = {
"original_file": str(Path(file_path).absolute()),
"timestamp": time.time(),
"model": model,
"operation": operation,
"data": data
}
with open(cache_path, 'w', encoding='utf-8') as f:
json.dump(cache_data, f, ensure_ascii=False, indent=2)
logger.info(f"Cached data saved to {cache_path}")
return True
except Exception as e:
logger.error(f"Error saving cache: {e}")
return False
def load_from_cache(file_path, model=None, operation=None, max_age=None):
"""
Load data from cache if available and not expired.
Args:
file_path (Path): Path to the original file
model (str, optional): Model used for processing
operation (str, optional): Operation type
max_age (float, optional): Maximum age of cache in seconds
Returns:
dict or None: Cached data or None if not available
"""
cache_path = get_cache_path(file_path, model, operation)
if not cache_path or not cache_path.exists():
return None
try:
with open(cache_path, 'r', encoding='utf-8') as f:
cache_data = json.load(f)
# Check if cache is expired
if max_age is not None:
cache_time = cache_data.get("timestamp", 0)
if time.time() - cache_time > max_age:
logger.info(f"Cache expired for {file_path}")
return None
logger.info(f"Loaded data from cache: {cache_path}")
return cache_data.get("data")
except Exception as e:
logger.error(f"Error loading cache: {e}")
return None
def clear_cache(max_age=None):
"""
Clear all cache files or only expired ones.
Args:
max_age (float, optional): Maximum age of cache in seconds
Returns:
int: Number of files deleted
"""
if not CACHE_DIR.exists():
return 0
count = 0
for cache_file in CACHE_DIR.glob("*.json"):
try:
if max_age is not None:
# Check if file is expired
with open(cache_file, 'r', encoding='utf-8') as f:
cache_data = json.load(f)
cache_time = cache_data.get("timestamp", 0)
if time.time() - cache_time <= max_age:
continue # Skip non-expired files
# Delete the file
os.remove(cache_file)
count += 1
except Exception as e:
logger.error(f"Error deleting cache file {cache_file}: {e}")
logger.info(f"Cleared {count} cache files")
return count
def get_cache_size():
"""
Get the total size of the cache directory.
Returns:
tuple: (size_bytes, file_count)
"""
if not CACHE_DIR.exists():
return 0, 0
total_size = 0
file_count = 0
for cache_file in CACHE_DIR.glob("*.json"):
try:
total_size += cache_file.stat().st_size
file_count += 1
except Exception:
pass
return total_size, file_count
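
A round trip through the cache helpers, assuming talk.mp4 exists on disk (the key hashes path + size + mtime, so editing the source file invalidates old entries automatically):

from utils.cache import save_to_cache, load_from_cache

save_to_cache("talk.mp4", {"words": []}, model="base", operation="transcribe_wx")
print(load_from_cache("talk.mp4", model="base", operation="transcribe_wx"))  # {"words": []}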

backend/utils/gpu_utils.py (new file)
@@ -0,0 +1,196 @@
"""
GPU utilities for the Video Transcriber.
Provides functions to detect and configure GPU acceleration.
"""
import logging
import torch
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def get_gpu_info():
"""
Get information about available GPUs.
Returns:
dict: Information about available GPUs
"""
gpu_info = {
"cuda_available": torch.cuda.is_available(),
"cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
"cuda_devices": [],
"mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
}
# Get CUDA device information
if gpu_info["cuda_available"]:
for i in range(gpu_info["cuda_device_count"]):
device_props = torch.cuda.get_device_properties(i)
gpu_info["cuda_devices"].append({
"index": i,
"name": device_props.name,
"total_memory": device_props.total_memory,
"compute_capability": f"{device_props.major}.{device_props.minor}"
})
return gpu_info
def get_optimal_device():
"""
Get the optimal device for computation.
Returns:
torch.device: The optimal device (cuda, mps, or cpu)
"""
if torch.cuda.is_available():
# If multiple GPUs are available, select the one with the most memory
if torch.cuda.device_count() > 1:
max_memory = 0
best_device = 0
for i in range(torch.cuda.device_count()):
device_props = torch.cuda.get_device_properties(i)
if device_props.total_memory > max_memory:
max_memory = device_props.total_memory
best_device = i
return torch.device(f"cuda:{best_device}")
return torch.device("cuda:0")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
return torch.device("mps")
else:
return torch.device("cpu")
def set_memory_limits(memory_fraction=0.8):
"""
Set memory limits for GPU usage.
Args:
memory_fraction (float): Fraction of GPU memory to use (0.0 to 1.0)
Returns:
bool: True if successful, False otherwise
"""
if not torch.cuda.is_available():
return False
try:
# Set memory fraction for each device
for i in range(torch.cuda.device_count()):
torch.cuda.set_per_process_memory_fraction(memory_fraction, i)
return True
except Exception as e:
logger.error(f"Error setting memory limits: {e}")
return False
def optimize_for_inference():
"""
Apply optimizations for inference.
Returns:
bool: True if successful, False otherwise
"""
try:
# Set deterministic algorithms for reproducibility
torch.backends.cudnn.deterministic = True
# Enable cuDNN benchmark mode for optimized performance
torch.backends.cudnn.benchmark = True
# Disable gradient calculation for inference
torch.set_grad_enabled(False)
return True
except Exception as e:
logger.error(f"Error optimizing for inference: {e}")
return False
def get_recommended_batch_size(model_size="base"):
"""
Get recommended batch size based on available GPU memory.
Args:
model_size (str): Size of the model (tiny, base, small, medium, large)
Returns:
int: Recommended batch size
"""
# Default batch sizes for CPU
default_batch_sizes = {
"tiny": 16,
"base": 8,
"small": 4,
"medium": 2,
"large": 1
}
# If CUDA is not available, return default CPU batch size
if not torch.cuda.is_available():
return default_batch_sizes.get(model_size, 1)
# Approximate memory requirements in GB for different model sizes
memory_requirements = {
"tiny": 1,
"base": 2,
"small": 4,
"medium": 8,
"large": 16
}
# Get available GPU memory
device = get_optimal_device()
if device.type == "cuda":
device_idx = device.index
device_props = torch.cuda.get_device_properties(device_idx)
available_memory_gb = device_props.total_memory / (1024 ** 3)
# Calculate batch size based on available memory
model_memory = memory_requirements.get(model_size, 2)
max_batch_size = int(available_memory_gb / model_memory)
# Ensure batch size is at least 1
return max(1, max_batch_size)
# For MPS or other devices, return default
return default_batch_sizes.get(model_size, 1)
def configure_gpu(model_size="base", memory_fraction=0.8):
"""
Configure GPU settings for optimal performance.
Args:
model_size (str): Size of the model (tiny, base, small, medium, large)
memory_fraction (float): Fraction of GPU memory to use (0.0 to 1.0)
Returns:
dict: Configuration information
"""
gpu_info = get_gpu_info()
device = get_optimal_device()
# Set memory limits if using CUDA
if device.type == "cuda":
set_memory_limits(memory_fraction)
# Apply inference optimizations
optimize_for_inference()
# Get recommended batch size
batch_size = get_recommended_batch_size(model_size)
config = {
"device": device,
"batch_size": batch_size,
"gpu_info": gpu_info,
"memory_fraction": memory_fraction if device.type == "cuda" else None
}
logger.info(f"GPU configuration: Using {device} with batch size {batch_size}")
return config
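
configure_gpu ties these helpers together; a typical call from the transcription service might look like:

from utils.gpu_utils import configure_gpu

cfg = configure_gpu(model_size="base", memory_fraction=0.8)
print(cfg["device"], cfg["batch_size"])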