""" WhisperX-based transcription service with word-level alignment. Falls back to standard Whisper if WhisperX is not available. """ import logging from pathlib import Path from typing import Optional import torch from utils.gpu_utils import get_optimal_device, configure_gpu from utils.audio_processing import extract_audio from utils.cache import load_from_cache, save_to_cache logger = logging.getLogger(__name__) _model_cache: dict = {} try: import whisperx WHISPERX_AVAILABLE = True except ImportError: WHISPERX_AVAILABLE = False import whisper try: HF_TOKEN = None import os HF_TOKEN = os.environ.get("HF_TOKEN") except Exception: pass def _get_device(use_gpu: bool = True) -> torch.device: if use_gpu: return get_optimal_device() return torch.device("cpu") def _load_model(model_name: str, device: torch.device): cache_key = f"{model_name}_{device}" if cache_key in _model_cache: return _model_cache[cache_key] logger.info(f"Loading model: {model_name} on {device}") if WHISPERX_AVAILABLE: compute_type = "float16" if device.type == "cuda" else "int8" model = whisperx.load_model( model_name, device=device.type, # use "cuda" not "cuda:0" — some WhisperX versions don't support device ordinal compute_type=compute_type, ) else: model = whisper.load_model(model_name, device=str(device)) _model_cache[cache_key] = model return model def transcribe_audio( file_path: str, model_name: str = "base", use_gpu: bool = True, use_cache: bool = True, language: Optional[str] = None, ) -> dict: """ Transcribe audio/video file and return word-level timestamps. Returns: dict with keys: words, segments, language """ file_path = Path(file_path) if use_cache: cached = load_from_cache(file_path, model_name, "transcribe_wx") if cached: logger.info("Using cached transcription") return cached video_extensions = {".mp4", ".avi", ".mov", ".mkv", ".webm"} if file_path.suffix.lower() in video_extensions: audio_path = extract_audio(file_path) else: audio_path = file_path device = _get_device(use_gpu) model = _load_model(model_name, device) logger.info(f"Transcribing: {file_path}") if WHISPERX_AVAILABLE: result = _transcribe_whisperx(model, str(audio_path), device, language) else: result = _transcribe_standard(model, str(audio_path), language) if use_cache: save_to_cache(file_path, result, model_name, "transcribe_wx") return result def _transcribe_whisperx(model, audio_path: str, device: torch.device, language: Optional[str]) -> dict: audio = whisperx.load_audio(audio_path) transcribe_opts = {} if language: transcribe_opts["language"] = language result = model.transcribe(audio, batch_size=16, **transcribe_opts) detected_language = result.get("language", "en") align_model, align_metadata = whisperx.load_align_model( language_code=detected_language, device=device.type, ) aligned = whisperx.align( result["segments"], align_model, align_metadata, audio, str(device), return_char_alignments=False, ) words = [] for seg in aligned.get("segments", []): for w in seg.get("words", []): words.append({ "word": w.get("word", ""), "start": round(w.get("start", 0), 3), "end": round(w.get("end", 0), 3), "confidence": round(w.get("score", 0), 3), }) segments = [] for i, seg in enumerate(aligned.get("segments", [])): seg_words = [] for w in seg.get("words", []): seg_words.append({ "word": w.get("word", ""), "start": round(w.get("start", 0), 3), "end": round(w.get("end", 0), 3), "confidence": round(w.get("score", 0), 3), }) segments.append({ "id": i, "start": round(seg.get("start", 0), 3), "end": round(seg.get("end", 0), 3), "text": seg.get("text", "").strip(), "words": seg_words, }) return { "words": words, "segments": segments, "language": detected_language, } def _transcribe_standard(model, audio_path: str, language: Optional[str]) -> dict: """Fallback: standard Whisper (segment-level only, synthesized word timestamps).""" opts = {} if language: opts["language"] = language result = model.transcribe(audio_path, **opts) detected_language = result.get("language", "en") words = [] segments = [] for i, seg in enumerate(result.get("segments", [])): text = seg.get("text", "").strip() seg_start = seg.get("start", 0) seg_end = seg.get("end", 0) seg_words_text = text.split() duration = seg_end - seg_start seg_words = [] for j, w_text in enumerate(seg_words_text): w_start = seg_start + (j / max(len(seg_words_text), 1)) * duration w_end = seg_start + ((j + 1) / max(len(seg_words_text), 1)) * duration word_obj = { "word": w_text, "start": round(w_start, 3), "end": round(w_end, 3), "confidence": 0.5, } words.append(word_obj) seg_words.append(word_obj) segments.append({ "id": i, "start": round(seg_start, 3), "end": round(seg_end, 3), "text": text, "words": seg_words, }) return { "words": words, "segments": segments, "language": detected_language, }