""" Speaker diarization utilities for the Video Transcriber. Provides functions to identify different speakers in audio recordings. """ import logging import os import numpy as np from pathlib import Path import torch from pyannote.audio import Pipeline from pyannote.core import Segment import whisper import streamlit as st logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) try: from utils.gpu_utils import get_optimal_device GPU_UTILS_AVAILABLE = True except ImportError: GPU_UTILS_AVAILABLE = False HF_TOKEN_ENV = "HF_TOKEN" @st.cache_resource def _load_diarization_pipeline(hf_token, device_str): """Load and cache the speaker diarization pipeline.""" logger.info(f"Loading diarization pipeline on {device_str}") pipe = Pipeline.from_pretrained( "pyannote/speaker-diarization-3.0", use_auth_token=hf_token ) device = torch.device(device_str) if device.type == "cuda": pipe = pipe.to(device) return pipe def get_diarization_pipeline(use_gpu=True, hf_token=None): """ Initialize the speaker diarization pipeline. Args: use_gpu (bool): Whether to use GPU acceleration if available hf_token (str, optional): HuggingFace API token for accessing the model Returns: Pipeline or None: Diarization pipeline if successful, None otherwise """ if hf_token is None: hf_token = os.environ.get(HF_TOKEN_ENV) if hf_token is None: logger.error(f"HuggingFace token not provided. Set {HF_TOKEN_ENV} environment variable or pass token directly.") return None try: device = torch.device("cpu") if use_gpu and GPU_UTILS_AVAILABLE: device = get_optimal_device() logger.info(f"Using device: {device} for diarization") return _load_diarization_pipeline(hf_token, str(device)) except Exception as e: logger.error(f"Error initializing diarization pipeline: {e}") return None def diarize_audio(audio_path, pipeline=None, num_speakers=None, use_gpu=True, hf_token=None): """ Perform speaker diarization on an audio file. Args: audio_path (Path): Path to the audio file pipeline (Pipeline, optional): Pre-initialized diarization pipeline num_speakers (int, optional): Number of speakers (if known) use_gpu (bool): Whether to use GPU acceleration if available hf_token (str, optional): HuggingFace API token Returns: dict: Dictionary mapping time segments to speaker IDs """ audio_path = Path(audio_path) # Initialize pipeline if not provided if pipeline is None: pipeline = get_diarization_pipeline(use_gpu, hf_token) if pipeline is None: return None try: # Run diarization logger.info(f"Running speaker diarization on {audio_path}") diarization = pipeline(audio_path, num_speakers=num_speakers) # Extract speaker segments speaker_segments = {} for turn, _, speaker in diarization.itertracks(yield_label=True): segment = (turn.start, turn.end) speaker_segments[segment] = speaker return speaker_segments except Exception as e: logger.error(f"Error during diarization: {e}") return None def apply_diarization_to_transcript(transcript_segments, speaker_segments): """ Apply speaker diarization results to transcript segments. Args: transcript_segments (list): List of transcript segments with timing info speaker_segments (dict): Dictionary mapping time segments to speaker IDs Returns: list: Updated transcript segments with speaker information """ if not speaker_segments: return transcript_segments # Convert speaker segments to a more usable format speaker_ranges = [(Segment(start, end), speaker) for (start, end), speaker in speaker_segments.items()] # Update transcript segments with speaker information for segment in transcript_segments: segment_start = segment['start'] segment_end = segment['end'] segment_range = Segment(segment_start, segment_end) # Find overlapping speaker segments overlaps = [] for (spk_range, speaker) in speaker_ranges: overlap = segment_range.intersect(spk_range) if overlap: overlaps.append((overlap.duration, speaker)) # Assign the speaker with the most overlap if overlaps: overlaps.sort(reverse=True) # Sort by duration (descending) segment['speaker'] = overlaps[0][1] else: segment['speaker'] = "UNKNOWN" return transcript_segments def format_transcript_with_speakers(transcript_segments): """ Format transcript with speaker labels. Args: transcript_segments (list): List of transcript segments with speaker info Returns: str: Formatted transcript with speaker labels """ formatted_lines = [] current_speaker = None for segment in transcript_segments: speaker = segment.get('speaker', 'UNKNOWN') text = segment['text'].strip() # Add speaker label when speaker changes if speaker != current_speaker: formatted_lines.append(f"\n[{speaker}]") current_speaker = speaker formatted_lines.append(text) return " ".join(formatted_lines) def transcribe_with_diarization(audio_path, whisper_model="base", num_speakers=None, use_gpu=True, hf_token=None): """ Transcribe audio with speaker diarization. Args: audio_path (Path): Path to the audio file whisper_model (str): Whisper model size to use num_speakers (int, optional): Number of speakers (if known) use_gpu (bool): Whether to use GPU acceleration if available hf_token (str, optional): HuggingFace API token Returns: tuple: (diarized_segments, formatted_transcript) """ audio_path = Path(audio_path) # Configure device device = torch.device("cpu") if use_gpu and GPU_UTILS_AVAILABLE: device = get_optimal_device() try: from utils.transcription import _load_whisper_model logger.info(f"Transcribing audio with Whisper model: {whisper_model}") model = _load_whisper_model(whisper_model, str(device)) result = model.transcribe(str(audio_path)) transcript_segments = result["segments"] # Step 2: Perform speaker diarization logger.info("Performing speaker diarization") pipeline = get_diarization_pipeline(use_gpu, hf_token) if pipeline is None: logger.warning("Diarization pipeline not available, returning transcript without speakers") return transcript_segments, result["text"] speaker_segments = diarize_audio(audio_path, pipeline, num_speakers, use_gpu) # Step 3: Apply diarization to transcript if speaker_segments: diarized_segments = apply_diarization_to_transcript(transcript_segments, speaker_segments) formatted_transcript = format_transcript_with_speakers(diarized_segments) return diarized_segments, formatted_transcript else: return transcript_segments, result["text"] except Exception as e: logger.error(f"Error in transcribe_with_diarization: {e}") return None, None