Files
TalkEdit/backend/routers/transcribe.py
2026-05-04 23:54:14 -06:00

150 lines
4.6 KiB
Python

"""Transcription endpoint using WhisperX."""
import logging
from typing import Optional
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from services.transcription import transcribe_audio
from services.diarization import diarize_and_label
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router on which the transcription endpoints in this module are registered.
router = APIRouter()
class TranscribeRequest(BaseModel):
    """Request body for ``POST /transcribe``."""

    # Path of the media file; forwarded to the transcription service and,
    # when diarization runs, to the diarization service as the audio path.
    file_path: str

    # Model name handed to the transcription service as ``model_name``.
    model: str = "base"

    # Language forwarded to the transcription service; ``None`` leaves the
    # choice to that service.
    language: Optional[str] = None

    # Forwarded to both the transcription and the diarization service.
    use_gpu: bool = True

    # Forwarded to the transcription service.
    use_cache: bool = True

    # Speaker labelling runs only when this is true AND ``hf_token`` is set.
    diarize: bool = False

    # Token forwarded to the diarization service.
    # NOTE(review): presumably a Hugging Face access token; confirm.
    hf_token: Optional[str] = None

    # Forwarded to the diarization service.
    num_speakers: Optional[int] = None
@router.post("/transcribe")
async def transcribe(req: TranscribeRequest):
    """Transcribe the file at ``req.file_path`` and optionally label speakers.

    Calls ``transcribe_audio`` with the requested model, language, GPU and
    cache settings. When ``req.diarize`` is set and ``req.hf_token`` is
    provided, the transcription result is then passed through
    ``diarize_and_label``. When diarization is requested without a token it
    is skipped and a warning is logged.

    Args:
        req: Validated request body.

    Returns:
        The value returned by ``transcribe_audio``, or by
        ``diarize_and_label`` when diarization ran.

    Raises:
        HTTPException: 404 when a ``FileNotFoundError`` is raised while
            processing, 500 for any other failure.
    """
    try:
        result = transcribe_audio(
            file_path=req.file_path,
            model_name=req.model,
            use_gpu=req.use_gpu,
            use_cache=req.use_cache,
            language=req.language,
        )

        if req.diarize:
            if req.hf_token:
                result = diarize_and_label(
                    transcription_result=result,
                    audio_path=req.file_path,
                    hf_token=req.hf_token,
                    num_speakers=req.num_speakers,
                    use_gpu=req.use_gpu,
                )
            else:
                # Still return the plain transcription, but leave a trace so
                # the missing speaker labels are explainable.
                logger.warning(
                    "Diarization requested without hf_token; "
                    "returning transcription without speaker labels"
                )

        return result
    except FileNotFoundError as err:
        raise HTTPException(
            status_code=404, detail=f"File not found: {req.file_path}"
        ) from err
    except Exception as err:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception("Transcription failed: %s", err)
        raise HTTPException(status_code=500, detail=str(err)) from err
class ReTranscribeSegmentRequest(BaseModel):
    """Request body for ``POST /transcribe/segment``."""

    # Path of the source media file; given to FFmpeg as its input.
    file_path: str

    # Segment start, passed to FFmpeg's ``-ss`` and later added to every
    # returned timestamp.
    start: float

    # Segment end, passed to FFmpeg's ``-to``.
    end: float

    # Model name handed to the transcription service as ``model_name``.
    model: str = "base"

    # Language forwarded to the transcription service; ``None`` leaves the
    # choice to that service.
    language: Optional[str] = None
def _shift_timestamps(item, offset):
    """Return a shallow copy of *item* with ``start``/``end`` moved by *offset*.

    A copy is returned instead of mutating *item*, so a dict that is reachable
    from more than one place in a transcription result can never be shifted
    twice. A key that is absent or ``None`` is left untouched.
    """
    shifted = dict(item)
    for key in ("start", "end"):
        value = shifted.get(key)
        if value is not None:
            shifted[key] = round(value + offset, 3)
    return shifted


def _extract_segment_wav(source_path, start, end, dest_path):
    """Write the ``start``..``end`` span of *source_path* to *dest_path*.

    Runs FFmpeg to produce a 16 kHz, mono, 16-bit PCM WAV without video.

    Raises:
        FileNotFoundError: FFmpeg failed and *source_path* does not exist.
        RuntimeError: The ``ffmpeg`` executable is missing, or FFmpeg exited
            with a non-zero status for any other reason.
    """
    import os
    import subprocess

    cmd = [
        "ffmpeg", "-y",
        "-i", source_path,
        "-ss", str(start),
        "-to", str(end),
        "-vn",
        "-acodec", "pcm_s16le",
        "-ar", "16000",
        "-ac", "1",
        dest_path,
    ]
    try:
        # errors="replace": undecodable bytes in FFmpeg's stderr must not
        # mask the real failure with a UnicodeDecodeError.
        proc = subprocess.run(
            cmd, capture_output=True, text=True, errors="replace"
        )
    except FileNotFoundError as err:
        # Here the missing file is the ffmpeg binary, not the user's media.
        raise RuntimeError(
            "Segment extraction failed: ffmpeg executable not found"
        ) from err

    if proc.returncode != 0:
        if not os.path.exists(source_path):
            raise FileNotFoundError(source_path)
        raise RuntimeError(f"Segment extraction failed: {proc.stderr[-300:]}")


@router.post("/transcribe/segment")
async def transcribe_segment(req: ReTranscribeSegmentRequest):
    """
    Re-transcribe a specific segment of audio.

    Extracts the segment with FFmpeg into a temporary directory, transcribes
    it (GPU first, CPU as fallback, cache disabled), and returns words and
    segments with timestamps adjusted to the original file timeline. The
    temporary directory is removed whether or not transcription succeeds.

    Args:
        req: Validated request body.

    Returns:
        A dict with ``words``, ``segments`` and ``language`` (``"en"`` when
        the transcription result carries no language).

    Raises:
        HTTPException: 400 when ``req.end`` is not greater than
            ``req.start``, 404 when the input file does not exist, 500 for
            any other failure.
    """
    import shutil
    import tempfile
    import os

    # Validate before the try block so this 400 is not rewritten into a 500
    # by the generic handler below.
    if req.end <= req.start:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Invalid segment range: end ({req.end}) must be greater "
                f"than start ({req.start})"
            ),
        )

    try:
        tmp_dir = tempfile.mkdtemp(prefix="talkedit_segment_")
        try:
            segment_path = os.path.join(tmp_dir, "segment.wav")
            _extract_segment_wav(req.file_path, req.start, req.end, segment_path)

            # Transcribe the segment — try GPU first, fall back to CPU
            try:
                segment_result = transcribe_audio(
                    file_path=segment_path,
                    model_name=req.model,
                    use_gpu=True,
                    use_cache=False,
                    language=req.language,
                )
            except Exception as gpu_err:
                logger.warning(
                    "GPU transcription failed (%s), falling back to CPU",
                    gpu_err,
                )
                segment_result = transcribe_audio(
                    file_path=segment_path,
                    model_name=req.model,
                    use_gpu=False,
                    use_cache=False,
                    language=req.language,
                )
        finally:
            # Best-effort cleanup on success and failure alike.
            shutil.rmtree(tmp_dir, ignore_errors=True)

        # Adjust timestamps to be relative to the original file
        offset = req.start
        adjusted_words = [
            _shift_timestamps(word, offset)
            for word in segment_result.get("words", [])
        ]

        adjusted_segments = []
        for seg in segment_result.get("segments", []):
            shifted_seg = _shift_timestamps(seg, offset)
            # Also adjust words within each segment
            if shifted_seg.get("words"):
                shifted_seg["words"] = [
                    _shift_timestamps(word, offset)
                    for word in shifted_seg["words"]
                ]
            adjusted_segments.append(shifted_seg)

        return {
            "words": adjusted_words,
            "segments": adjusted_segments,
            "language": segment_result.get("language", "en"),
        }
    except FileNotFoundError as err:
        raise HTTPException(
            status_code=404, detail=f"File not found: {req.file_path}"
        ) from err
    except Exception as err:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception("Segment transcription failed: %s", err)
        raise HTTPException(status_code=500, detail=str(err)) from err