"""Transcription endpoint using WhisperX."""

import logging
import os
import shutil
import subprocess
import tempfile
from typing import Optional

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from services.transcription import transcribe_audio
from services.diarization import diarize_and_label

logger = logging.getLogger(__name__)
router = APIRouter()


class TranscribeRequest(BaseModel):
    """Request body for ``POST /transcribe``."""

    # Path handed to transcribe_audio (and to diarize_and_label as audio_path).
    file_path: str
    # Forwarded to transcribe_audio as ``model_name``.
    model: str = "base"
    # Forwarded to transcribe_audio unchanged; None leaves the choice to it.
    language: Optional[str] = None
    # Forwarded to both transcribe_audio and diarize_and_label.
    use_gpu: bool = True
    # Forwarded to transcribe_audio.
    use_cache: bool = True
    # Diarization only runs when this is True AND hf_token is set.
    diarize: bool = False
    # Forwarded to diarize_and_label.
    hf_token: Optional[str] = None
    # Forwarded to diarize_and_label.
    num_speakers: Optional[int] = None


@router.post("/transcribe")
async def transcribe(req: TranscribeRequest):
    """Transcribe a file and optionally run diarization on the result.

    Returns whatever ``transcribe_audio`` returns, or the output of
    ``diarize_and_label`` when ``req.diarize`` is set and ``req.hf_token``
    is provided.

    Raises:
        HTTPException: 404 when a ``FileNotFoundError`` is raised while
            processing, 500 for any other failure.
    """
    try:
        result = transcribe_audio(
            file_path=req.file_path,
            model_name=req.model,
            use_gpu=req.use_gpu,
            use_cache=req.use_cache,
            language=req.language,
        )

        if req.diarize:
            if req.hf_token:
                result = diarize_and_label(
                    transcription_result=result,
                    audio_path=req.file_path,
                    hf_token=req.hf_token,
                    num_speakers=req.num_speakers,
                    use_gpu=req.use_gpu,
                )
            else:
                # Keep the original best-effort behaviour (plain transcription
                # is still returned) but make the skipped step visible.
                logger.warning(
                    "Diarization requested without hf_token; "
                    "skipping diarize_and_label"
                )

        return result
    except FileNotFoundError as exc:
        raise HTTPException(
            status_code=404, detail=f"File not found: {req.file_path}"
        ) from exc
    except Exception as exc:
        logger.exception("Transcription failed: %s", exc)
        raise HTTPException(status_code=500, detail=str(exc)) from exc


class ReTranscribeSegmentRequest(BaseModel):
    """Request body for ``POST /transcribe/segment``."""

    # Input passed to ffmpeg via ``-i``.
    file_path: str
    # Passed to ffmpeg as ``-ss``; also the offset added to returned timestamps.
    start: float
    # Passed to ffmpeg as ``-to``; must be greater than ``start``.
    end: float
    # Forwarded to transcribe_audio as ``model_name``.
    model: str = "base"
    # Forwarded to transcribe_audio unchanged.
    language: Optional[str] = None


def _extract_segment(
    source_path: str, start: float, end: float, target_path: str
) -> None:
    """Write ``source_path[start:end]`` to *target_path* as 16 kHz mono PCM WAV.

    Raises:
        RuntimeError: if ffmpeg cannot be launched or exits non-zero.
    """
    cmd = [
        "ffmpeg", "-y",
        "-i", source_path,
        "-ss", str(start),
        "-to", str(end),
        "-vn",
        "-acodec", "pcm_s16le",
        "-ar", "16000",
        "-ac", "1",
        target_path,
    ]
    try:
        # errors="replace": stderr may contain bytes that are not valid in the
        # locale encoding (e.g. echoed file names); never fail on decoding it.
        proc = subprocess.run(
            cmd, capture_output=True, text=True, errors="replace"
        )
    except FileNotFoundError as exc:
        # subprocess raises FileNotFoundError when the *executable* is missing.
        # Convert it so the endpoint does not report it as a missing input
        # file (404).
        raise RuntimeError(
            "Segment extraction failed: ffmpeg executable not found"
        ) from exc
    if proc.returncode != 0:
        raise RuntimeError(f"Segment extraction failed: {proc.stderr[-300:]}")


def _transcribe_with_cpu_fallback(
    segment_path: str, model_name: str, language: Optional[str]
):
    """Transcribe with ``use_gpu=True``; on any error retry with ``use_gpu=False``."""
    try:
        return transcribe_audio(
            file_path=segment_path,
            model_name=model_name,
            use_gpu=True,
            use_cache=False,
            language=language,
        )
    except Exception as gpu_err:
        logger.warning(
            "GPU transcription failed (%s), falling back to CPU", gpu_err
        )
        return transcribe_audio(
            file_path=segment_path,
            model_name=model_name,
            use_gpu=False,
            use_cache=False,
            language=language,
        )


def _shift_times(item: dict, offset: float) -> dict:
    """Return a shallow copy of *item* with ``start``/``end`` moved by *offset*.

    Keys that are absent or ``None`` are left untouched.
    NOTE(review): presumably some words can come back without timings when
    they could not be aligned; confirm against transcribe_audio's output.
    """
    shifted = dict(item)
    for key in ("start", "end"):
        if shifted.get(key) is not None:
            shifted[key] = round(shifted[key] + offset, 3)
    return shifted


def _shift_segment(segment: dict, offset: float) -> dict:
    """Return a copy of *segment*, and of its nested words, moved by *offset*."""
    shifted = _shift_times(segment, offset)
    if segment.get("words"):
        shifted["words"] = [_shift_times(w, offset) for w in segment["words"]]
    return shifted


@router.post("/transcribe/segment")
async def transcribe_segment(req: ReTranscribeSegmentRequest):
    """
    Re-transcribe a specific segment of audio.

    Extracts the segment with FFmpeg, transcribes it, and returns words with
    timestamps adjusted to the original file timeline.

    Returns:
        A dict with ``words``, ``segments`` and ``language`` (``"en"`` when
        the transcription result carries no language).

    Raises:
        HTTPException: 400 when ``end`` is not greater than ``start``,
            404 when a ``FileNotFoundError`` is raised while processing,
            500 for any other failure (including ffmpeg errors).

    NOTE(review): subprocess.run and transcribe_audio are called synchronously
    inside this async handler; consider offloading them to a worker thread if
    request concurrency matters.
    """
    if req.end <= req.start:
        # Raised outside the try block so it is not rewrapped as a 500.
        raise HTTPException(
            status_code=400, detail="end must be greater than start"
        )

    try:
        tmp_dir = tempfile.mkdtemp(prefix="talkedit_segment_")
        try:
            segment_path = os.path.join(tmp_dir, "segment.wav")
            _extract_segment(req.file_path, req.start, req.end, segment_path)
            segment_result = _transcribe_with_cpu_fallback(
                segment_path, req.model, req.language
            )
        finally:
            # Best-effort cleanup on success AND failure paths.
            shutil.rmtree(tmp_dir, ignore_errors=True)

        # Timestamps are relative to the extracted clip; move them back onto
        # the original file's timeline. Copies are shifted rather than the
        # result dicts themselves, so a word dict that appears both in the
        # top-level list and inside a segment is never shifted twice.
        offset = req.start
        return {
            "words": [
                _shift_times(w, offset)
                for w in segment_result.get("words", [])
            ],
            "segments": [
                _shift_segment(seg, offset)
                for seg in segment_result.get("segments", [])
            ],
            "language": segment_result.get("language", "en"),
        }
    except FileNotFoundError as exc:
        raise HTTPException(
            status_code=404, detail=f"File not found: {req.file_path}"
        ) from exc
    except Exception as exc:
        logger.exception("Segment transcription failed: %s", exc)
        raise HTTPException(status_code=500, detail=str(exc)) from exc