2026-03-03 06:31:04 -05:00
|
|
|
"""Transcription endpoint using WhisperX."""
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
|
|
from fastapi import APIRouter, HTTPException
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
|
|
|
|
from services.transcription import transcribe_audio
|
|
|
|
|
from services.diarization import diarize_and_label
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TranscribeRequest(BaseModel):
    """Request body accepted by ``POST /transcribe``."""

    # Path of the media file; forwarded to the transcription service and,
    # when diarization runs, reused as the diarization audio path.
    file_path: str

    # Model identifier forwarded to the transcription service as ``model_name``.
    model: str = "base"

    # Optional language hint forwarded to the transcription service.
    language: Optional[str] = None

    # Forwarded as ``use_gpu`` to both transcription and diarization.
    use_gpu: bool = True

    # Forwarded as ``use_cache`` to the transcription service.
    use_cache: bool = True

    # Request speaker labelling; the endpoint only runs it when ``hf_token``
    # is also provided.
    diarize: bool = False

    # Token forwarded to the diarization service.
    hf_token: Optional[str] = None

    # Optional speaker count forwarded to the diarization service.
    num_speakers: Optional[int] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@router.post("/transcribe")
async def transcribe(req: TranscribeRequest):
    """Transcribe an audio file and optionally label its speakers.

    Calls ``transcribe_audio`` with the request's file path, model, GPU,
    cache and language settings. When ``req.diarize`` is set and an
    ``hf_token`` is supplied, the transcription is passed through
    ``diarize_and_label`` and that result is returned instead.

    Args:
        req: Parsed request body.

    Returns:
        Whatever ``transcribe_audio`` (or ``diarize_and_label``) returns.

    Raises:
        HTTPException: 404 when a ``FileNotFoundError`` is raised while
            processing, 500 for any other failure (``detail`` carries the
            exception text).
    """
    try:
        result = transcribe_audio(
            file_path=req.file_path,
            model_name=req.model,
            use_gpu=req.use_gpu,
            use_cache=req.use_cache,
            language=req.language,
        )

        if req.diarize:
            if req.hf_token:
                result = diarize_and_label(
                    transcription_result=result,
                    audio_path=req.file_path,
                    hf_token=req.hf_token,
                    num_speakers=req.num_speakers,
                    use_gpu=req.use_gpu,
                )
            else:
                # Diarization cannot run without a token. The plain
                # transcription is still returned, but the skipped step is
                # recorded so it is not mistaken for a diarization failure.
                logger.warning(
                    "Diarization requested without hf_token; "
                    "returning transcription without speaker labels"
                )

        return result

    except FileNotFoundError as err:
        raise HTTPException(
            status_code=404, detail=f"File not found: {req.file_path}"
        ) from err
    except Exception as e:
        # Endpoint boundary: log with traceback, then surface as a 500.
        logger.exception("Transcription failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
2026-05-04 23:54:14 -06:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReTranscribeSegmentRequest(BaseModel):
    """Request body accepted by ``POST /transcribe/segment``."""

    # Path of the source media file given to FFmpeg as its input.
    file_path: str

    # Segment start, passed to FFmpeg's ``-ss`` and later added to every
    # returned timestamp as the offset into the original file.
    start: float

    # Segment end, passed to FFmpeg's ``-to``.
    end: float

    # Model identifier forwarded to the transcription service as ``model_name``.
    model: str = "base"

    # Optional language hint forwarded to the transcription service.
    language: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _shift_timestamps(entry, offset):
    """Return a copy of *entry* with ``start``/``end`` moved by *offset*.

    A copy is returned instead of editing *entry* in place so that a dict
    reachable from more than one list can never be shifted twice. Keys that
    are absent or ``None`` are left as they are.
    """
    shifted = dict(entry)
    for key in ("start", "end"):
        if shifted.get(key) is not None:
            shifted[key] = round(shifted[key] + offset, 3)
    return shifted


def _shift_segment(segment, offset):
    """Return a shifted copy of *segment*, including any nested ``words``."""
    shifted = _shift_timestamps(segment, offset)
    if segment.get("words"):
        shifted["words"] = [_shift_timestamps(w, offset) for w in segment["words"]]
    return shifted


@router.post("/transcribe/segment")
async def transcribe_segment(req: ReTranscribeSegmentRequest):
    """Re-transcribe a specific segment of audio.

    Extracts ``[req.start, req.end]`` from ``req.file_path`` with FFmpeg into
    a temporary 16 kHz mono PCM WAV file, transcribes that file (GPU first,
    CPU on failure), and returns words and segments whose timestamps are
    shifted by ``req.start`` so they line up with the original file.

    Args:
        req: Parsed request body.

    Returns:
        A dict with ``words``, ``segments`` and ``language`` (``"en"`` when
        the transcription result carries no language).

    Raises:
        HTTPException: 400 for an invalid time range, 404 when the input
            file does not exist, 500 for any other failure (including a
            missing ``ffmpeg`` executable).
    """
    import os
    import shutil
    import subprocess
    import tempfile

    # Validate before entering the try block so the 400 is not swallowed by
    # the generic handler below. Written as a negated chain so that NaN
    # values are rejected as well.
    if not (0 <= req.start < req.end):
        raise HTTPException(
            status_code=400,
            detail=f"Invalid segment range: start={req.start}, end={req.end}",
        )

    tmp_dir = None
    try:
        # Extract the segment to a temp file
        tmp_dir = tempfile.mkdtemp(prefix="talkedit_segment_")
        segment_path = os.path.join(tmp_dir, "segment.wav")

        cmd = [
            "ffmpeg", "-y",
            "-i", req.file_path,
            "-ss", str(req.start),
            "-to", str(req.end),
            "-vn",
            "-acodec", "pcm_s16le",
            "-ar", "16000",
            "-ac", "1",
            segment_path,
        ]
        try:
            proc = subprocess.run(cmd, capture_output=True, text=True)
        except FileNotFoundError as err:
            # Raised when the ffmpeg binary itself cannot be launched; this
            # must not be reported as "input file not found".
            raise RuntimeError("ffmpeg executable not found") from err

        if proc.returncode != 0:
            # ffmpeg only signals a missing input through its exit status,
            # so map that case to the 404 handler explicitly.
            if not os.path.exists(req.file_path):
                raise FileNotFoundError(req.file_path)
            raise RuntimeError(f"Segment extraction failed: {proc.stderr[-300:]}")

        # Transcribe the segment — try GPU first, fall back to CPU
        try:
            segment_result = transcribe_audio(
                file_path=segment_path,
                model_name=req.model,
                use_gpu=True,
                use_cache=False,
                language=req.language,
            )
        except Exception as gpu_err:
            logger.warning(
                "GPU transcription failed (%s), falling back to CPU", gpu_err
            )
            segment_result = transcribe_audio(
                file_path=segment_path,
                model_name=req.model,
                use_gpu=False,
                use_cache=False,
                language=req.language,
            )

        # Adjust timestamps to be relative to the original file
        offset = req.start
        adjusted_words = [
            _shift_timestamps(w, offset) for w in segment_result.get("words", [])
        ]
        adjusted_segments = [
            _shift_segment(seg, offset) for seg in segment_result.get("segments", [])
        ]

        return {
            "words": adjusted_words,
            "segments": adjusted_segments,
            "language": segment_result.get("language", "en"),
        }

    except FileNotFoundError as err:
        raise HTTPException(
            status_code=404, detail=f"File not found: {req.file_path}"
        ) from err
    except Exception as e:
        # Endpoint boundary: log with traceback, then surface as a 500.
        logger.exception("Segment transcription failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Best-effort cleanup on every exit path, not just on success.
        if tmp_dir is not None:
            shutil.rmtree(tmp_dir, ignore_errors=True)
|