Files
TalkEdit/backend/routers/transcribe.py

150 lines
4.6 KiB
Python
Raw Normal View History

"""Transcription endpoint using WhisperX."""
import logging
import os
import shutil
import subprocess
import tempfile
from typing import Optional

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from services.transcription import transcribe_audio
from services.diarization import diarize_and_label

# Logger named after this module so its records can be filtered on their own.
logger = logging.getLogger(name=__name__)

# Collects the endpoints defined below; presumably mounted on the app elsewhere.
router = APIRouter()
class TranscribeRequest(BaseModel):
    """Request body for ``POST /transcribe``."""

    # Path of the file to transcribe; forwarded to transcribe_audio as
    # ``file_path`` and, when diarizing, to diarize_and_label as ``audio_path``.
    file_path: str
    # Model name forwarded to transcribe_audio as ``model_name``.
    model: str = "base"
    # Forwarded to transcribe_audio; None presumably means auto-detect — TODO confirm.
    language: Optional[str] = None
    # Forwarded to both transcribe_audio and diarize_and_label.
    use_gpu: bool = True
    # Forwarded to transcribe_audio; presumably enables reuse of a cached result — TODO confirm.
    use_cache: bool = True
    # Request speaker diarization; the endpoint only runs it when hf_token is also set.
    diarize: bool = False
    # Token forwarded to diarize_and_label (presumably a Hugging Face access token).
    hf_token: Optional[str] = None
    # Forwarded to diarize_and_label; presumably the expected speaker count — TODO confirm.
    num_speakers: Optional[int] = None
@router.post("/transcribe")
async def transcribe(req: TranscribeRequest):
    """Transcribe ``req.file_path`` and, optionally, attach speaker labels.

    The file is transcribed with ``transcribe_audio``. When ``req.diarize`` is
    set *and* ``req.hf_token`` is provided, the result is passed through
    ``diarize_and_label``. If diarization is requested without a token, a
    warning is logged and the plain transcription is returned.

    Returns:
        Whatever ``transcribe_audio`` (or ``diarize_and_label``) returns.

    Raises:
        HTTPException: 404 when a ``FileNotFoundError`` occurs and
            ``req.file_path`` does not exist; 500 for any other failure,
            with the error text as ``detail``.
    """
    # NOTE(review): transcribe_audio / diarize_and_label are plain synchronous
    # calls inside an async handler, so they block the event loop while they
    # run. Consider a ``def`` handler or run_in_threadpool once concurrent GPU
    # use is confirmed safe.
    try:
        result = transcribe_audio(
            file_path=req.file_path,
            model_name=req.model,
            use_gpu=req.use_gpu,
            use_cache=req.use_cache,
            language=req.language,
        )
        if req.diarize:
            if req.hf_token:
                result = diarize_and_label(
                    transcription_result=result,
                    audio_path=req.file_path,
                    hf_token=req.hf_token,
                    num_speakers=req.num_speakers,
                    use_gpu=req.use_gpu,
                )
            else:
                # Diarization cannot run without a token; make the omission visible.
                logger.warning(
                    "Diarization requested for %s but no hf_token was provided; "
                    "returning transcription without speaker labels",
                    req.file_path,
                )
        return result
    except FileNotFoundError as e:
        if not os.path.exists(req.file_path):
            raise HTTPException(
                status_code=404, detail=f"File not found: {req.file_path}"
            ) from e
        # The input exists, so some other file is missing; report the real
        # error instead of blaming the input path.
        logger.exception("Transcription failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    except Exception as e:
        logger.exception("Transcription failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
2026-05-04 23:54:14 -06:00
class ReTranscribeSegmentRequest(BaseModel):
    """Request body for ``POST /transcribe/segment``.

    Identifies a time range of an existing file to re-transcribe.
    """

    # Path of the source file; handed to ffmpeg as its ``-i`` input.
    file_path: str
    # Segment start in seconds on the source timeline (ffmpeg ``-ss``); also
    # the offset added to the returned timestamps.
    start: float
    # Segment end in seconds on the source timeline (ffmpeg ``-to``).
    end: float
    # Model name forwarded to transcribe_audio as ``model_name``.
    model: str = "base"
    # Forwarded to transcribe_audio; None presumably means auto-detect — TODO confirm.
    language: Optional[str] = None
def _extract_audio_segment(src_path, start, end, dest_path):
    """Cut ``[start, end]`` seconds of *src_path* into a 16 kHz mono PCM WAV at *dest_path*.

    Raises:
        RuntimeError: if the ffmpeg executable cannot be found or exits non-zero.
    """
    cmd = [
        "ffmpeg", "-y",
        "-i", src_path,
        "-ss", str(start),
        "-to", str(end),
        "-vn",
        "-acodec", "pcm_s16le",
        "-ar", "16000",
        "-ac", "1",
        dest_path,
    ]
    try:
        proc = subprocess.run(
            cmd,
            stdin=subprocess.DEVNULL,  # never let ffmpeg read the server's stdin
            capture_output=True,
            text=True,
            errors="replace",  # stderr may hold bytes the locale codec can't decode
        )
    except FileNotFoundError as err:
        # Raised for a missing ffmpeg executable, not a missing input file.
        raise RuntimeError("ffmpeg executable not found; is it installed and on PATH?") from err
    if proc.returncode != 0:
        raise RuntimeError(f"Segment extraction failed: {proc.stderr[-300:]}")


def _transcribe_with_cpu_fallback(audio_path, model_name, language):
    """Transcribe *audio_path* on the GPU, retrying once on the CPU if that raises."""
    try:
        return transcribe_audio(
            file_path=audio_path,
            model_name=model_name,
            use_gpu=True,
            use_cache=False,
            language=language,
        )
    except Exception as gpu_err:
        logger.warning("GPU transcription failed (%s), falling back to CPU", gpu_err)
        return transcribe_audio(
            file_path=audio_path,
            model_name=model_name,
            use_gpu=False,
            use_cache=False,
            language=language,
        )


def _shift_timestamps(items, offset, already_shifted):
    """Add *offset* seconds to each dict's ``start``/``end`` in *items*, at most once per dict.

    *already_shifted* is a set of ``id()`` values for dicts moved so far. A word
    dict reachable from both the top-level ``words`` list and a segment's
    ``words`` list is therefore shifted only once. Missing or ``None``
    timestamps are left untouched.
    """
    for item in items:
        if id(item) in already_shifted:
            continue
        already_shifted.add(id(item))
        for key in ("start", "end"):
            value = item.get(key)
            if value is not None:
                item[key] = round(value + offset, 3)


@router.post("/transcribe/segment")
async def transcribe_segment(req: ReTranscribeSegmentRequest):
    """
    Re-transcribe a specific segment of audio.

    Extracts the segment with FFmpeg, transcribes it, and returns words
    with timestamps adjusted to the original file timeline.

    Returns:
        A dict with ``words``, ``segments`` and ``language``. ``language`` is
        ``"en"`` when the transcription result does not report one.

    Raises:
        HTTPException: 400 unless ``0 <= start < end``; 404 if ``file_path``
            is not an existing file; 500 for any other failure.
    """
    # NOTE(review): ffmpeg and transcribe_audio run synchronously inside this
    # async handler and block the event loop until they finish. Consider a
    # ``def`` handler or run_in_threadpool once concurrent GPU use is confirmed safe.

    # Written as a chained comparison so NaN (which fails every comparison) is rejected too.
    if not (0 <= req.start < req.end):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid segment range: start={req.start}, end={req.end}",
        )
    if not os.path.isfile(req.file_path):
        raise HTTPException(status_code=404, detail=f"File not found: {req.file_path}")

    tmp_dir = None
    try:
        tmp_dir = tempfile.mkdtemp(prefix="talkedit_segment_")
        segment_path = os.path.join(tmp_dir, "segment.wav")
        _extract_audio_segment(req.file_path, req.start, req.end, segment_path)
        segment_result = _transcribe_with_cpu_fallback(segment_path, req.model, req.language)

        # Move timestamps from the segment's timeline onto the original file's.
        words = list(segment_result.get("words") or [])
        segments = list(segment_result.get("segments") or [])
        shifted = set()
        _shift_timestamps(words, req.start, shifted)
        _shift_timestamps(segments, req.start, shifted)
        for seg in segments:
            _shift_timestamps(seg.get("words") or [], req.start, shifted)

        return {
            "words": words,
            "segments": segments,
            "language": segment_result.get("language", "en"),
        }
    except Exception as e:
        logger.exception("Segment transcription failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Best-effort cleanup on every path, including failures.
        if tmp_dir is not None:
            shutil.rmtree(tmp_dir, ignore_errors=True)