# TalkEdit/backend/services/diarization.py
"""
Speaker diarization service using pyannote.audio.
Refactored from the original repo -- removed Streamlit dependency.
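
The functions here assume a Whisper-style transcription dict. A sketch of
the expected shape, inferred from the field accesses below rather than a
documented schema (the "word" and "text" values are illustrative):

    {
        "words":    [{"start": 0.0, "end": 0.4, "word": "hello"}, ...],
        "segments": [{"start": 0.0, "end": 2.1, "text": "hello there",
                      "words": [{"start": 0.0, "end": 0.4}, ...]}],
    }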
"""
import logging
import os
from pathlib import Path
from typing import Optional

import torch

from utils.gpu_utils import get_optimal_device

logger = logging.getLogger(__name__)

# Loaded pipelines are cached per device so repeated calls skip the slow
# model download and initialization.
_pipeline_cache = {}


def _get_pipeline(hf_token: str, device: torch.device):
    cache_key = str(device)
    if cache_key in _pipeline_cache:
        return _pipeline_cache[cache_key]
    try:
        # Imported lazily so the service still loads when pyannote.audio
        # is not installed; callers simply lose diarization.
        from pyannote.audio import Pipeline

        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.0",
            use_auth_token=hf_token,
        )
        # Only CUDA devices are moved; on MPS or CPU the pipeline stays
        # on the CPU default.
        if device.type == "cuda":
            pipeline = pipeline.to(device)
        _pipeline_cache[cache_key] = pipeline
        return pipeline
    except Exception as e:
        logger.error(f"Failed to load diarization pipeline: {e}")
        return None


def diarize_and_label(
    transcription_result: dict,
    audio_path: str,
    hf_token: Optional[str] = None,
    num_speakers: Optional[int] = None,
    use_gpu: bool = True,
) -> dict:
    """
    Apply speaker diarization to an existing transcription result.

    Adds a 'speaker' field to each word and segment, and returns the
    mutated transcription_result. If no HuggingFace token is available,
    the pipeline fails to load, or diarization itself fails, the result
    is returned unchanged.
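
    Example (illustrative only; the shape is assumed from the accesses
    below and the values are made up):

        >>> result = {"segments": [{"start": 0.0, "end": 2.0, "words": []}]}
        >>> result = diarize_and_label(result, "meeting.wav", hf_token="hf_...")
        >>> result["segments"][0]["speaker"]
        'SPEAKER_00'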
"""
    hf_token = hf_token or os.environ.get("HF_TOKEN")
    if not hf_token:
        logger.warning("No HuggingFace token provided; skipping diarization")
        return transcription_result

    device = get_optimal_device() if use_gpu else torch.device("cpu")
    pipeline = _get_pipeline(hf_token, device)
    if pipeline is None:
        return transcription_result

    audio_path = Path(audio_path)
    logger.info(f"Running diarization on {audio_path}")
    try:
        diarization = pipeline(str(audio_path), num_speakers=num_speakers)
    except Exception as e:
        logger.error(f"Diarization failed: {e}")
        return transcription_result

    # Flatten the annotation into (start, end, speaker) tuples for the
    # overlap search below.
    speaker_map = []
    for turn, _, speaker in diarization.itertracks(yield_label=True):
        speaker_map.append((turn.start, turn.end, speaker))

    def _find_speaker(start: float, end: float) -> str:
        """Return the speaker whose turn overlaps [start, end] the most."""
        best_overlap = 0.0
        best_speaker = "UNKNOWN"
        for s_start, s_end, speaker in speaker_map:
            overlap = max(0.0, min(end, s_end) - max(start, s_start))
            if overlap > best_overlap:
                best_overlap = overlap
                best_speaker = speaker
        return best_speaker

    # Label top-level words, then segments and their nested words.
    for word in transcription_result.get("words", []):
        word["speaker"] = _find_speaker(word["start"], word["end"])
    for segment in transcription_result.get("segments", []):
        segment["speaker"] = _find_speaker(segment["start"], segment["end"])
        for w in segment.get("words", []):
            w["speaker"] = _find_speaker(w["start"], w["end"])
    return transcription_result
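

# Minimal manual smoke test (a sketch, not part of the service API):
# assumes a local "sample.wav" and, for real diarization, an HF_TOKEN in
# the environment. Without a token the call logs a warning and returns
# the stub unchanged, so this also works as a dry run.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    stub = {
        "words": [{"start": 0.0, "end": 0.5}],
        "segments": [{"start": 0.0, "end": 0.5, "words": []}],
    }
    print(diarize_and_label(stub, "sample.wav"))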