Files
TalkEdit/backend/services/transcription.py

206 lines
5.8 KiB
Python
Raw Normal View History

"""
WhisperX-based transcription service with word-level alignment.
Falls back to standard Whisper if WhisperX is not available.
"""
import logging
from pathlib import Path
from typing import Optional
import torch
from utils.gpu_utils import get_optimal_device, configure_gpu
from utils.audio_processing import extract_audio
from utils.cache import load_from_cache, save_to_cache
logger = logging.getLogger(__name__)
_model_cache: dict = {}
try:
import whisperx
WHISPERX_AVAILABLE = True
except ImportError:
WHISPERX_AVAILABLE = False
import whisper
try:
HF_TOKEN = None
import os
HF_TOKEN = os.environ.get("HF_TOKEN")
except Exception:
pass
def _get_device(use_gpu: bool = True) -> torch.device:
if use_gpu:
return get_optimal_device()
return torch.device("cpu")
def _load_model(model_name: str, device: torch.device):
cache_key = f"{model_name}_{device}"
if cache_key in _model_cache:
return _model_cache[cache_key]
logger.info(f"Loading model: {model_name} on {device}")
if WHISPERX_AVAILABLE:
compute_type = "float16" if device.type == "cuda" else "int8"
model = whisperx.load_model(
model_name,
2026-05-04 23:54:14 -06:00
device=device.type, # use "cuda" not "cuda:0" — some WhisperX versions don't support device ordinal
compute_type=compute_type,
)
else:
2026-05-04 23:54:14 -06:00
model = whisper.load_model(model_name, device=str(device))
_model_cache[cache_key] = model
return model
def transcribe_audio(
file_path: str,
model_name: str = "base",
use_gpu: bool = True,
use_cache: bool = True,
language: Optional[str] = None,
) -> dict:
"""
Transcribe audio/video file and return word-level timestamps.
Returns:
dict with keys: words, segments, language
"""
file_path = Path(file_path)
if use_cache:
cached = load_from_cache(file_path, model_name, "transcribe_wx")
if cached:
logger.info("Using cached transcription")
return cached
video_extensions = {".mp4", ".avi", ".mov", ".mkv", ".webm"}
if file_path.suffix.lower() in video_extensions:
audio_path = extract_audio(file_path)
else:
audio_path = file_path
device = _get_device(use_gpu)
model = _load_model(model_name, device)
logger.info(f"Transcribing: {file_path}")
if WHISPERX_AVAILABLE:
result = _transcribe_whisperx(model, str(audio_path), device, language)
else:
result = _transcribe_standard(model, str(audio_path), language)
if use_cache:
save_to_cache(file_path, result, model_name, "transcribe_wx")
return result
def _transcribe_whisperx(model, audio_path: str, device: torch.device, language: Optional[str]) -> dict:
audio = whisperx.load_audio(audio_path)
transcribe_opts = {}
if language:
transcribe_opts["language"] = language
result = model.transcribe(audio, batch_size=16, **transcribe_opts)
detected_language = result.get("language", "en")
align_model, align_metadata = whisperx.load_align_model(
language_code=detected_language,
2026-05-04 23:54:14 -06:00
device=device.type,
)
aligned = whisperx.align(
result["segments"],
align_model,
align_metadata,
audio,
str(device),
return_char_alignments=False,
)
words = []
for seg in aligned.get("segments", []):
for w in seg.get("words", []):
words.append({
"word": w.get("word", ""),
"start": round(w.get("start", 0), 3),
"end": round(w.get("end", 0), 3),
"confidence": round(w.get("score", 0), 3),
})
segments = []
for i, seg in enumerate(aligned.get("segments", [])):
seg_words = []
for w in seg.get("words", []):
seg_words.append({
"word": w.get("word", ""),
"start": round(w.get("start", 0), 3),
"end": round(w.get("end", 0), 3),
"confidence": round(w.get("score", 0), 3),
})
segments.append({
"id": i,
"start": round(seg.get("start", 0), 3),
"end": round(seg.get("end", 0), 3),
"text": seg.get("text", "").strip(),
"words": seg_words,
})
return {
"words": words,
"segments": segments,
"language": detected_language,
}
def _transcribe_standard(model, audio_path: str, language: Optional[str]) -> dict:
"""Fallback: standard Whisper (segment-level only, synthesized word timestamps)."""
opts = {}
if language:
opts["language"] = language
result = model.transcribe(audio_path, **opts)
detected_language = result.get("language", "en")
words = []
segments = []
for i, seg in enumerate(result.get("segments", [])):
text = seg.get("text", "").strip()
seg_start = seg.get("start", 0)
seg_end = seg.get("end", 0)
seg_words_text = text.split()
duration = seg_end - seg_start
seg_words = []
for j, w_text in enumerate(seg_words_text):
w_start = seg_start + (j / max(len(seg_words_text), 1)) * duration
w_end = seg_start + ((j + 1) / max(len(seg_words_text), 1)) * duration
word_obj = {
"word": w_text,
"start": round(w_start, 3),
"end": round(w_end, 3),
"confidence": 0.5,
}
words.append(word_obj)
seg_words.append(word_obj)
segments.append({
"id": i,
"start": round(seg_start, 3),
"end": round(seg_end, 3),
"text": text,
"words": seg_words,
})
return {
"words": words,
"segments": segments,
"language": detected_language,
}