Add installation scripts and update documentation for Phase 3 features

This commit is contained in:
Your Name
2025-03-01 20:37:52 -05:00
parent a653ac7f28
commit 7ea098bd05
16 changed files with 3023 additions and 43 deletions

205
utils/cache.py Normal file
View File

@ -0,0 +1,205 @@
"""
Caching utilities for the OBS Recording Transcriber.
Provides functions to cache and retrieve transcription and summarization results.
"""
import json
import hashlib
import os
from pathlib import Path
import logging
import time
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Default cache directory
CACHE_DIR = Path.home() / ".obs_transcriber_cache"
def get_file_hash(file_path):
    """
    Generate a hash identifying a file's current state.

    Args:
        file_path (Path): Path to the file

    Returns:
        str or None: MD5 hex digest representing the file, or None when
        the file does not exist.
    """
    path = Path(file_path)
    if not path.exists():
        return None
    # Fingerprint the absolute path plus size and mtime. This avoids
    # reading the whole file yet still changes whenever the file does.
    info = path.stat()
    fingerprint = "|".join([str(path.absolute()), str(info.st_size), str(info.st_mtime)])
    return hashlib.md5(fingerprint.encode()).hexdigest()
def get_cache_path(file_path, model=None, operation=None):
    """
    Get the cache file path for a given input file and operation.

    Args:
        file_path (Path): Path to the original file
        model (str, optional): Model used for processing
        operation (str, optional): Operation type (e.g., 'transcribe', 'summarize')

    Returns:
        Path or None: Path to the cache file, or None when the source
        file does not exist.
    """
    file_hash = get_file_hash(Path(file_path))
    if not file_hash:
        return None
    # Make sure the cache directory exists before returning a path in it.
    CACHE_DIR.mkdir(parents=True, exist_ok=True)
    # Compose the filename from the hash plus any optional qualifiers.
    parts = [file_hash]
    if model:
        parts.append(model)
    if operation:
        parts.append(operation)
    return CACHE_DIR / ("_".join(parts) + ".json")
def save_to_cache(file_path, data, model=None, operation=None):
    """
    Save data to cache.

    Args:
        file_path (Path): Path to the original file
        data (dict): Data to cache
        model (str, optional): Model used for processing
        operation (str, optional): Operation type

    Returns:
        bool: True if successful, False otherwise
    """
    cache_path = get_cache_path(file_path, model, operation)
    if not cache_path:
        return False
    try:
        # Wrap the payload with provenance metadata so expiry checks and
        # debugging work without re-reading the original file.
        envelope = {
            "original_file": str(Path(file_path).absolute()),
            "timestamp": time.time(),
            "model": model,
            "operation": operation,
            "data": data,
        }
        with open(cache_path, 'w', encoding='utf-8') as f:
            json.dump(envelope, f, ensure_ascii=False, indent=2)
        logger.info(f"Cached data saved to {cache_path}")
        return True
    except Exception as e:
        logger.error(f"Error saving cache: {e}")
        return False
def load_from_cache(file_path, model=None, operation=None, max_age=None):
    """
    Load data from cache if available and not expired.

    Args:
        file_path (Path): Path to the original file
        model (str, optional): Model used for processing
        operation (str, optional): Operation type
        max_age (float, optional): Maximum age of cache in seconds

    Returns:
        dict or None: Cached data or None if not available
    """
    cache_path = get_cache_path(file_path, model, operation)
    if cache_path is None or not cache_path.exists():
        return None
    try:
        with open(cache_path, 'r', encoding='utf-8') as f:
            cache_data = json.load(f)
        # Reject entries older than max_age seconds, when a limit is set.
        if max_age is not None and time.time() - cache_data.get("timestamp", 0) > max_age:
            logger.info(f"Cache expired for {file_path}")
            return None
        logger.info(f"Loaded data from cache: {cache_path}")
        return cache_data.get("data")
    except Exception as e:
        logger.error(f"Error loading cache: {e}")
        return None
def clear_cache(max_age=None):
    """
    Clear all cache files or only expired ones.

    Args:
        max_age (float, optional): Maximum age of cache in seconds; when
            omitted, every cache file is removed.

    Returns:
        int: Number of files deleted
    """
    if not CACHE_DIR.exists():
        return 0
    # Hoisted so every file is judged against the same instant.
    now = time.time()
    count = 0
    for cache_file in CACHE_DIR.glob("*.json"):
        try:
            if max_age is not None:
                # Keep files still within the allowed age; a corrupt JSON
                # file is logged and left in place (same as before).
                with open(cache_file, 'r', encoding='utf-8') as f:
                    cache_data = json.load(f)
                if now - cache_data.get("timestamp", 0) <= max_age:
                    continue  # Skip non-expired files
            # pathlib idiom; equivalent to os.remove(cache_file).
            cache_file.unlink()
            count += 1
        except Exception as e:
            logger.error(f"Error deleting cache file {cache_file}: {e}")
    logger.info(f"Cleared {count} cache files")
    return count
def get_cache_size():
    """
    Get the total size of the cache directory.

    Returns:
        tuple: (size_bytes, file_count)
    """
    if not CACHE_DIR.exists():
        return 0, 0
    total_size = 0
    file_count = 0
    for cache_file in CACHE_DIR.glob("*.json"):
        try:
            size = cache_file.stat().st_size
        except Exception:
            # The file may have been removed concurrently; skip it.
            continue
        total_size += size
        file_count += 1
    return total_size, file_count

226
utils/diarization.py Normal file
View File

@ -0,0 +1,226 @@
"""
Speaker diarization utilities for the OBS Recording Transcriber.
Provides functions to identify different speakers in audio recordings.
"""
import logging
import os
import numpy as np
from pathlib import Path
import torch
from pyannote.audio import Pipeline
from pyannote.core import Segment
import whisper
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Try to import GPU utilities, but don't fail if not available
try:
from utils.gpu_utils import get_optimal_device
GPU_UTILS_AVAILABLE = True
except ImportError:
GPU_UTILS_AVAILABLE = False
# Default HuggingFace auth token environment variable
HF_TOKEN_ENV = "HF_TOKEN"
def get_diarization_pipeline(use_gpu=True, hf_token=None):
    """
    Initialize the speaker diarization pipeline.

    Args:
        use_gpu (bool): Whether to use GPU acceleration if available
        hf_token (str, optional): HuggingFace API token for accessing the model

    Returns:
        Pipeline or None: Diarization pipeline if successful, None otherwise
    """
    # Fall back to the environment variable when no token was passed in.
    if hf_token is None:
        hf_token = os.environ.get(HF_TOKEN_ENV)
    if hf_token is None:
        logger.error(f"HuggingFace token not provided. Set {HF_TOKEN_ENV} environment variable or pass token directly.")
        return None
    try:
        # Pick the compute device (CPU unless the GPU helpers are usable).
        device = torch.device("cpu")
        if use_gpu and GPU_UTILS_AVAILABLE:
            device = get_optimal_device()
        logger.info(f"Using device: {device} for diarization")
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.0",
            use_auth_token=hf_token
        )
        # Only CUDA devices are moved explicitly; MPS/CPU stay as loaded.
        if device.type == "cuda":
            pipeline = pipeline.to(torch.device(device))
        return pipeline
    except Exception as e:
        logger.error(f"Error initializing diarization pipeline: {e}")
        return None
def diarize_audio(audio_path, pipeline=None, num_speakers=None, use_gpu=True, hf_token=None):
    """
    Perform speaker diarization on an audio file.

    Args:
        audio_path (Path): Path to the audio file
        pipeline (Pipeline, optional): Pre-initialized diarization pipeline
        num_speakers (int, optional): Number of speakers (if known)
        use_gpu (bool): Whether to use GPU acceleration if available
        hf_token (str, optional): HuggingFace API token

    Returns:
        dict or None: Mapping of (start, end) time tuples to speaker labels,
        or None on failure.
    """
    audio_path = Path(audio_path)
    # Lazily build a pipeline when the caller did not supply one.
    if pipeline is None:
        pipeline = get_diarization_pipeline(use_gpu, hf_token)
        if pipeline is None:
            return None
    try:
        logger.info(f"Running speaker diarization on {audio_path}")
        diarization = pipeline(audio_path, num_speakers=num_speakers)
        # Flatten the annotation into {(start, end): speaker_label}.
        return {
            (turn.start, turn.end): speaker
            for turn, _, speaker in diarization.itertracks(yield_label=True)
        }
    except Exception as e:
        logger.error(f"Error during diarization: {e}")
        return None
def apply_diarization_to_transcript(transcript_segments, speaker_segments):
    """
    Apply speaker diarization results to transcript segments.

    Args:
        transcript_segments (list): List of transcript segments with timing info
        speaker_segments (dict): Dictionary mapping time segments to speaker IDs

    Returns:
        list: Updated transcript segments with a 'speaker' key added.
    """
    if not speaker_segments:
        return transcript_segments
    # Pre-build pyannote Segment objects once for the overlap computations.
    speaker_ranges = [
        (Segment(start, end), speaker)
        for (start, end), speaker in speaker_segments.items()
    ]
    for segment in transcript_segments:
        window = Segment(segment['start'], segment['end'])
        # Collect (overlap_duration, speaker) for every overlapping range.
        overlaps = []
        for spk_range, speaker in speaker_ranges:
            common = window.intersect(spk_range)
            if common:
                overlaps.append((common.duration, speaker))
        # The speaker covering the largest share of the window wins;
        # segments with no diarization coverage are labeled UNKNOWN.
        segment['speaker'] = max(overlaps)[1] if overlaps else "UNKNOWN"
    return transcript_segments
def format_transcript_with_speakers(transcript_segments):
    """
    Format transcript with speaker labels.

    Args:
        transcript_segments (list): List of transcript segments with speaker info

    Returns:
        str: Formatted transcript with speaker labels
    """
    pieces = []
    previous_speaker = None
    for segment in transcript_segments:
        speaker = segment.get('speaker', 'UNKNOWN')
        # Emit a label only at speaker transitions (and before the first line).
        if speaker != previous_speaker:
            pieces.append(f"\n[{speaker}]")
            previous_speaker = speaker
        pieces.append(segment['text'].strip())
    return " ".join(pieces)
def transcribe_with_diarization(audio_path, whisper_model="base", num_speakers=None,
                                use_gpu=True, hf_token=None):
    """
    Transcribe audio with speaker diarization.

    Args:
        audio_path (Path): Path to the audio file
        whisper_model (str): Whisper model size to use
        num_speakers (int, optional): Number of speakers (if known)
        use_gpu (bool): Whether to use GPU acceleration if available
        hf_token (str, optional): HuggingFace API token

    Returns:
        tuple: (diarized_segments, formatted_transcript), or (None, None)
        on failure.
    """
    audio_path = Path(audio_path)
    # Choose a compute device; Whisper is loaded on CPU when MPS is selected.
    device = torch.device("cpu")
    if use_gpu and GPU_UTILS_AVAILABLE:
        device = get_optimal_device()
    try:
        # Step 1: plain transcription with Whisper.
        logger.info(f"Transcribing audio with Whisper model: {whisper_model}")
        whisper_device = "cpu" if device.type == "mps" else device
        model = whisper.load_model(whisper_model, device=whisper_device)
        result = model.transcribe(str(audio_path))
        transcript_segments = result["segments"]
        # Step 2: speaker diarization — degrade gracefully when unavailable.
        logger.info("Performing speaker diarization")
        pipeline = get_diarization_pipeline(use_gpu, hf_token)
        if pipeline is None:
            logger.warning("Diarization pipeline not available, returning transcript without speakers")
            return transcript_segments, result["text"]
        speaker_segments = diarize_audio(audio_path, pipeline, num_speakers, use_gpu)
        if not speaker_segments:
            return transcript_segments, result["text"]
        # Step 3: merge the speaker labels into the transcript and render it.
        diarized_segments = apply_diarization_to_transcript(transcript_segments, speaker_segments)
        return diarized_segments, format_transcript_with_speakers(diarized_segments)
    except Exception as e:
        logger.error(f"Error in transcribe_with_diarization: {e}")
        return None, None

284
utils/export.py Normal file
View File

@ -0,0 +1,284 @@
"""
Subtitle export utilities for the OBS Recording Transcriber.
Supports exporting transcripts to SRT, ASS, and WebVTT subtitle formats.
"""
from pathlib import Path
import re
from datetime import timedelta
import gzip
import zipfile
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def format_timestamp_srt(timestamp_ms):
    """
    Format a timestamp in milliseconds to SRT format (HH:MM:SS,mmm).

    Args:
        timestamp_ms (int): Timestamp in milliseconds

    Returns:
        str: Formatted timestamp string
    """
    total_seconds, milliseconds = divmod(timestamp_ms, 1000)
    total_minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    # SRT separates the milliseconds field with a comma.
    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"
def format_timestamp_ass(timestamp_ms):
    """
    Format a timestamp in milliseconds to ASS format (H:MM:SS.cc).

    Args:
        timestamp_ms (int): Timestamp in milliseconds

    Returns:
        str: Formatted timestamp string
    """
    total_seconds, remainder_ms = divmod(timestamp_ms, 1000)
    total_minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    # ASS uses two-digit centiseconds and an unpadded hours field.
    centiseconds = remainder_ms // 10
    return f"{int(hours)}:{int(minutes):02d}:{int(seconds):02d}.{int(centiseconds):02d}"
def format_timestamp_vtt(timestamp_ms):
    """
    Format a timestamp in milliseconds to WebVTT format (HH:MM:SS.mmm).

    Args:
        timestamp_ms (int): Timestamp in milliseconds

    Returns:
        str: Formatted timestamp string
    """
    total_seconds, milliseconds = divmod(timestamp_ms, 1000)
    total_minutes, seconds = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)
    # WebVTT uses a dot before the milliseconds (unlike SRT's comma).
    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"
def export_to_srt(segments, output_path):
    """
    Export transcript segments to SRT format.

    Args:
        segments (list): List of transcript segments with start, end, and text
        output_path (Path): Path to save the SRT file

    Returns:
        Path: Path to the saved SRT file
    """
    with open(output_path, 'w', encoding='utf-8') as f:
        for index, segment in enumerate(segments, 1):
            # Each cue: counter line, timing line, text, blank separator.
            begin = format_timestamp_srt(int(segment['start'] * 1000))
            finish = format_timestamp_srt(int(segment['end'] * 1000))
            f.write(f"{index}\n{begin} --> {finish}\n{segment['text'].strip()}\n\n")
    return output_path
def export_to_ass(segments, output_path, video_width=1920, video_height=1080, style=None):
    """
    Export transcript segments to ASS format with styling.

    Args:
        segments (list): List of transcript segments with start, end, and text
        output_path (Path): Path to save the ASS file
        video_width (int): Width of the video in pixels
        video_height (int): Height of the video in pixels
        style (dict, optional): Custom style parameters; keys override the
            defaults below (e.g. 'fontname', 'fontsize', 'alignment')

    Returns:
        Path: Path to the saved ASS file
    """
    # Default style — ASS colour values are &HAABBGGRR strings; booleans
    # are encoded as "-1" (true) / "0" (false) per the v4+ style spec.
    default_style = {
        "fontname": "Arial",
        "fontsize": "48",
        "primary_color": "&H00FFFFFF",  # White
        "secondary_color": "&H000000FF",  # Blue
        "outline_color": "&H00000000",  # Black
        "back_color": "&H80000000",  # Semi-transparent black
        "bold": "-1",  # True
        "italic": "0",  # False
        "alignment": "2",  # Bottom center
    }
    # Apply custom style if provided
    if style:
        default_style.update(style)
    # ASS header template. The string content is intentionally unindented:
    # it is written to the output file verbatim.
    ass_header = f"""[Script Info]
Title: Transcription
ScriptType: v4.00+
WrapStyle: 0
PlayResX: {video_width}
PlayResY: {video_height}
ScaledBorderAndShadow: yes
[V4+ Styles]
Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding
Style: Default,{default_style['fontname']},{default_style['fontsize']},{default_style['primary_color']},{default_style['secondary_color']},{default_style['outline_color']},{default_style['back_color']},{default_style['bold']},{default_style['italic']},0,0,100,100,0,0,1,2,2,{default_style['alignment']},10,10,10,1
[Events]
Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
"""
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(ass_header)
        for segment in segments:
            start_time = format_timestamp_ass(int(segment['start'] * 1000))
            end_time = format_timestamp_ass(int(segment['end'] * 1000))
            # ASS dialogue text uses \N for hard line breaks.
            text = segment['text'].strip().replace('\n', '\\N')
            f.write(f"Dialogue: 0,{start_time},{end_time},Default,,0,0,0,,{text}\n")
    return output_path
def export_to_vtt(segments, output_path):
    """
    Export transcript segments to WebVTT format.

    Args:
        segments (list): List of transcript segments with start, end, and text
        output_path (Path): Path to save the WebVTT file

    Returns:
        Path: Path to the saved WebVTT file
    """
    with open(output_path, 'w', encoding='utf-8') as f:
        # Mandatory WebVTT signature line.
        f.write("WEBVTT\n\n")
        for cue_number, segment in enumerate(segments, 1):
            begin = format_timestamp_vtt(int(segment['start'] * 1000))
            finish = format_timestamp_vtt(int(segment['end'] * 1000))
            # Cue identifier, timing line, payload, blank separator.
            f.write(f"{cue_number}\n{begin} --> {finish}\n{segment['text'].strip()}\n\n")
    return output_path
def transcript_to_segments(transcript, segment_duration=5.0):
    """
    Convert a plain transcript to timed segments for subtitle export.
    Used when the original segments are not available.

    Args:
        transcript (str): Full transcript text
        segment_duration (float): Duration of each segment in seconds

    Returns:
        list: List of segments with start, end, and text
    """
    # Sentence-split on terminal punctuation followed by whitespace.
    sentences = re.split(r'(?<=[.!?])\s+', transcript)
    segments = []
    cursor = 0.0
    for sentence in sentences:
        if not sentence.strip():
            continue
        # Assume ~2.5 words/second speaking rate, never shorter than 2 s.
        duration = max(2.0, len(sentence.split()) / 2.5)
        segments.append({
            'start': cursor,
            'end': cursor + duration,
            'text': sentence
        })
        cursor += duration
    return segments
def compress_file(input_path, compression_type='gzip'):
    """
    Compress a file using the specified compression method.

    Args:
        input_path (Path): Path to the file to compress
        compression_type (str): Type of compression ('gzip' or 'zip')

    Returns:
        Path: Path to the compressed file, or the original path unchanged
        when the compression type is not supported.
    """
    input_path = Path(input_path)
    if compression_type == 'gzip':
        output_path = input_path.with_suffix(input_path.suffix + '.gz')
        # Stream in 1 MiB chunks rather than loading the entire file into
        # memory (the previous implementation read it in one shot).
        with open(input_path, 'rb') as f_in, gzip.open(output_path, 'wb') as f_out:
            for chunk in iter(lambda: f_in.read(1024 * 1024), b''):
                f_out.write(chunk)
        return output_path
    elif compression_type == 'zip':
        output_path = input_path.with_suffix('.zip')
        with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            zipf.write(input_path, arcname=input_path.name)
        return output_path
    else:
        logger.warning(f"Unsupported compression type: {compression_type}")
        return input_path
def export_transcript(transcript, output_path, format_type='srt', segments=None,
                      compress=False, compression_type='gzip', style=None):
    """
    Export transcript to the specified subtitle format.

    Args:
        transcript (str): Full transcript text
        output_path (Path): Base path for the output file (without extension)
        format_type (str): 'srt', 'ass', or 'vtt'
        segments (list, optional): List of transcript segments with timing information
        compress (bool): Whether to compress the output file
        compression_type (str): Type of compression ('gzip' or 'zip')
        style (dict, optional): Custom style parameters for ASS format

    Returns:
        Path: Path to the saved subtitle file

    Raises:
        ValueError: If format_type is not one of 'srt', 'ass', 'vtt'.
    """
    output_path = Path(output_path)
    # Synthesize timing info when real segments were not supplied.
    if segments is None:
        segments = transcript_to_segments(transcript)
    fmt = format_type.lower()
    if fmt == 'srt':
        result_path = export_to_srt(segments, output_path.with_suffix('.srt'))
    elif fmt == 'ass':
        result_path = export_to_ass(segments, output_path.with_suffix('.ass'), style=style)
    elif fmt == 'vtt':
        result_path = export_to_vtt(segments, output_path.with_suffix('.vtt'))
    else:
        raise ValueError(f"Unsupported format type: {format_type}. Use 'srt', 'ass', or 'vtt'.")
    # Optionally wrap the subtitle file with gzip/zip compression.
    if compress:
        result_path = compress_file(result_path, compression_type)
    return result_path

202
utils/gpu_utils.py Normal file
View File

@ -0,0 +1,202 @@
"""
GPU utilities for the OBS Recording Transcriber.
Provides functions to detect and configure GPU acceleration.
"""
import logging
import os
import platform
import subprocess
import torch
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def get_gpu_info():
    """
    Get information about available GPUs.

    Returns:
        dict: Information about available GPUs (CUDA availability, device
        count, per-device details, and MPS availability).
    """
    cuda_ok = torch.cuda.is_available()
    info = {
        "cuda_available": cuda_ok,
        "cuda_device_count": torch.cuda.device_count() if cuda_ok else 0,
        "cuda_devices": [],
        "mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
    }
    # Describe each CUDA device (loop body never runs when count is 0).
    for index in range(info["cuda_device_count"]):
        props = torch.cuda.get_device_properties(index)
        info["cuda_devices"].append({
            "index": index,
            "name": props.name,
            "total_memory": props.total_memory,
            "compute_capability": f"{props.major}.{props.minor}",
        })
    return info
def get_optimal_device():
    """
    Get the optimal device for computation.

    Returns:
        torch.device: The optimal device (cuda, mps, or cpu)
    """
    if torch.cuda.is_available():
        count = torch.cuda.device_count()
        if count > 1:
            # With several GPUs, pick the one with the most total memory
            # (ties resolve to the lowest index, as before).
            best = max(
                range(count),
                key=lambda i: torch.cuda.get_device_properties(i).total_memory,
            )
            return torch.device(f"cuda:{best}")
        return torch.device("cuda:0")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
def set_memory_limits(memory_fraction=0.8):
"""
Set memory limits for GPU usage.
Args:
memory_fraction (float): Fraction of GPU memory to use (0.0 to 1.0)
Returns:
bool: True if successful, False otherwise
"""
if not torch.cuda.is_available():
return False
try:
# Import only if CUDA is available
import torch.cuda
# Set memory fraction for each device
for i in range(torch.cuda.device_count()):
torch.cuda.set_per_process_memory_fraction(memory_fraction, i)
return True
except Exception as e:
logger.error(f"Error setting memory limits: {e}")
return False
def optimize_for_inference():
    """
    Apply optimizations for inference.

    Side effects: mutates global torch/cuDNN flags and disables autograd
    process-wide.

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        # NOTE(review): deterministic=True and benchmark=True pull in
        # opposite directions (benchmark autotunes algorithms); kept
        # exactly as the original behaved — confirm the intent.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = True
        # Inference never needs autograd bookkeeping.
        torch.set_grad_enabled(False)
        return True
    except Exception as e:
        logger.error(f"Error optimizing for inference: {e}")
        return False
def get_recommended_batch_size(model_size="base"):
    """
    Get recommended batch size based on available GPU memory.

    Args:
        model_size (str): Size of the model (tiny, base, small, medium, large)

    Returns:
        int: Recommended batch size (always >= 1)
    """
    # Conservative defaults used when no CUDA device is present.
    cpu_defaults = {
        "tiny": 16,
        "base": 8,
        "small": 4,
        "medium": 2,
        "large": 1,
    }
    if not torch.cuda.is_available():
        return cpu_defaults.get(model_size, 1)
    # Rough per-model memory footprints in GB.
    gb_needed = {
        "tiny": 1,
        "base": 2,
        "small": 4,
        "medium": 8,
        "large": 16,
    }
    device = get_optimal_device()
    if device.type == "cuda":
        props = torch.cuda.get_device_properties(device.index)
        available_gb = props.total_memory / (1024 ** 3)
        # Fit as many model copies as memory allows, but at least one.
        return max(1, int(available_gb / gb_needed.get(model_size, 2)))
    # MPS or anything unexpected: fall back to the CPU defaults.
    return cpu_defaults.get(model_size, 1)
def configure_gpu(model_size="base", memory_fraction=0.8):
    """
    Configure GPU settings for optimal performance.

    Args:
        model_size (str): Size of the model (tiny, base, small, medium, large)
        memory_fraction (float): Fraction of GPU memory to use (0.0 to 1.0)

    Returns:
        dict: Configuration information (device, batch_size, gpu_info,
        memory_fraction — the latter is None on non-CUDA devices).
    """
    gpu_info = get_gpu_info()
    device = get_optimal_device()
    # Memory capping only applies to CUDA devices.
    if device.type == "cuda":
        set_memory_limits(memory_fraction)
    optimize_for_inference()
    batch_size = get_recommended_batch_size(model_size)
    config = {
        "device": device,
        "batch_size": batch_size,
        "gpu_info": gpu_info,
        "memory_fraction": memory_fraction if device.type == "cuda" else None,
    }
    logger.info(f"GPU configuration: Using {device} with batch size {batch_size}")
    return config

325
utils/keyword_extraction.py Normal file
View File

@ -0,0 +1,325 @@
"""
Keyword extraction utilities for the OBS Recording Transcriber.
Provides functions to extract keywords and link them to timestamps.
"""
import logging
import re
import torch
import numpy as np
from pathlib import Path
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
from sklearn.feature_extraction.text import TfidfVectorizer
from collections import Counter
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Try to import GPU utilities, but don't fail if not available
try:
from utils.gpu_utils import get_optimal_device
GPU_UTILS_AVAILABLE = True
except ImportError:
GPU_UTILS_AVAILABLE = False
# Default models
NER_MODEL = "dslim/bert-base-NER"
def extract_keywords_tfidf(text, max_keywords=10, ngram_range=(1, 2)):
    """
    Extract keywords using TF-IDF.

    Args:
        text (str): Text to extract keywords from
        max_keywords (int): Maximum number of keywords to extract
        ngram_range (tuple): Range of n-grams to consider

    Returns:
        list: List of (keyword, score) tuples, best first.
    """
    try:
        lowered = text.lower()
        # Small custom stopword list filtered out by the vectorizer.
        stopwords = {'a', 'an', 'the', 'and', 'or', 'but', 'if', 'because', 'as', 'what',
                     'when', 'where', 'how', 'who', 'which', 'this', 'that', 'these', 'those',
                     'then', 'just', 'so', 'than', 'such', 'both', 'through', 'about', 'for',
                     'is', 'of', 'while', 'during', 'to', 'from', 'in', 'out', 'on', 'off', 'by'}
        # Treat each sentence as a document for the TF-IDF statistics.
        sentences = [s.strip() for s in re.split(r'[.!?]', lowered) if s.strip()]
        if not sentences:
            return []
        vectorizer = TfidfVectorizer(
            max_features=100,
            stop_words=stopwords,
            ngram_range=ngram_range
        )
        try:
            tfidf_matrix = vectorizer.fit_transform(sentences)
            feature_names = vectorizer.get_feature_names_out()
            # Average the per-sentence scores, then rank descending.
            avg_scores = np.mean(tfidf_matrix.toarray(), axis=0)
            ranked = [(feature_names[i], avg_scores[i]) for i in avg_scores.argsort()[::-1]]
            # Drop single-character terms and cap the result count.
            return [(term, score) for term, score in ranked if len(term) > 1][:max_keywords]
        except ValueError as e:
            # Happens e.g. when every token is a stopword.
            logger.warning(f"TF-IDF extraction failed: {e}")
            return []
    except Exception as e:
        logger.error(f"Error extracting keywords with TF-IDF: {e}")
        return []
def extract_named_entities(text, model=NER_MODEL, use_gpu=True):
    """
    Extract named entities from text.

    Args:
        text (str): Text to extract entities from
        model (str): Model to use for NER
        use_gpu (bool): Whether to use GPU acceleration if available

    Returns:
        list: List of (entity, type) tuples
    """
    # transformers pipelines take a CUDA index, or -1 for CPU.
    device_arg = -1
    if use_gpu and GPU_UTILS_AVAILABLE and get_optimal_device().type == "cuda":
        device_arg = 0
    try:
        ner_pipeline = pipeline("ner", model=model, device=device_arg, aggregation_strategy="simple")
        # Chunk long inputs so they fit within the model's context window.
        max_length = 512
        if len(text) > max_length:
            chunks = [text[i:i + max_length] for i in range(0, len(text), max_length)]
        else:
            chunks = [text]
        all_entities = []
        for chunk in chunks:
            all_entities.extend(ner_pipeline(chunk))
        # Keep just the surface form and the aggregated entity class.
        return [(entity["word"], entity["entity_group"]) for entity in all_entities]
    except Exception as e:
        logger.error(f"Error extracting named entities: {e}")
        return []
def find_keyword_timestamps(segments, keywords):
    """
    Find timestamps for keywords in transcript segments.

    Args:
        segments (list): List of transcript segments with timing info
        keywords (list): List of keyword strings, or (keyword, score) tuples

    Returns:
        dict: Dictionary mapping each keyword (original casing) to a list
        of {'start', 'end', 'context'} dicts for the segments it occurs in.
    """
    # Guard: the previous implementation indexed keywords[0] and raised
    # IndexError when the keyword list was empty.
    if not keywords:
        return {}
    # Normalize to a list of original-case keywords paired with lowercase
    # forms for case-insensitive substring matching.
    scored = isinstance(keywords[0], tuple)
    originals = [k for k, _ in keywords] if scored else list(keywords)
    lowered = [k.lower() for k in originals]
    keyword_timestamps = {}
    for segment in segments:
        segment_text = segment["text"].lower()
        for original, needle in zip(originals, lowered):
            if needle in segment_text:
                keyword_timestamps.setdefault(original, []).append({
                    "start": segment["start"],
                    "end": segment["end"],
                    "context": segment["text"]
                })
    return keyword_timestamps
def extract_keywords_from_transcript(transcript, segments, max_keywords=15, use_gpu=True):
    """
    Extract keywords from transcript and link them to timestamps.

    Args:
        transcript (str): Full transcript text
        segments (list): List of transcript segments with timing info
        max_keywords (int): Maximum number of keywords to extract
        use_gpu (bool): Whether to use GPU acceleration if available

    Returns:
        tuple: (keyword_timestamps, entities_with_timestamps); both are
        empty dicts on failure.
    """
    try:
        # Statistical keywords and named entities are gathered separately.
        tfidf_keywords = extract_keywords_tfidf(transcript, max_keywords=max_keywords)
        entities = extract_named_entities(transcript, use_gpu=use_gpu)
        # Keep only the most frequently mentioned entities.
        top_entities = Counter(name for name, _ in entities).most_common(max_keywords)
        # Attach segment timestamps to both keyword sets.
        return (
            find_keyword_timestamps(segments, tfidf_keywords),
            find_keyword_timestamps(segments, top_entities),
        )
    except Exception as e:
        logger.error(f"Error extracting keywords from transcript: {e}")
        return {}, {}
def generate_keyword_index(keyword_timestamps, entity_timestamps=None):
    """
    Generate a keyword index with timestamps.

    Args:
        keyword_timestamps (dict): Dictionary mapping keywords to timestamp lists
        entity_timestamps (dict, optional): Dictionary mapping entities to timestamp lists

    Returns:
        str: Formatted keyword index (Markdown)
    """
    def _mmss(ts):
        # Render a segment start time as MM:SS.
        return f"{int(ts['start'] // 60):02d}:{int(ts['start'] % 60):02d}"

    lines = ["# Keyword Index\n"]
    if keyword_timestamps:
        lines.append("## Keywords\n")
        for keyword, timestamps in sorted(keyword_timestamps.items()):
            if timestamps:
                lines.append(f"- **{keyword}**: {', '.join(_mmss(t) for t in timestamps)}\n")
    if entity_timestamps:
        lines.append("\n## Named Entities\n")
        for entity, timestamps in sorted(entity_timestamps.items()):
            if timestamps:
                lines.append(f"- **{entity}**: {', '.join(_mmss(t) for t in timestamps)}\n")
    return "".join(lines)
def generate_interactive_transcript(segments, keyword_timestamps=None, entity_timestamps=None):
    """
    Generate an interactive transcript with keyword highlighting.

    Args:
        segments (list): List of transcript segments with timing info
            (each needs "start", "end", "text"; optional "speaker")
        keyword_timestamps (dict, optional): Dictionary mapping keywords to timestamp lists
        entity_timestamps (dict, optional): Dictionary mapping entities to timestamp lists

    Returns:
        str: HTML formatted interactive transcript
    """
    # Combine keywords and entities into one highlight set.
    all_keywords = {}
    if keyword_timestamps:
        all_keywords.update(keyword_timestamps)
    if entity_timestamps:
        all_keywords.update(entity_timestamps)
    # Build one combined alternation pattern, longest keyword first, and do a
    # single substitution pass per segment.  This fixes two defects of the
    # per-keyword loop: a shorter keyword can no longer rewrite the inside of
    # HTML already inserted for a longer one, and the matched text keeps its
    # original casing instead of being replaced by the lowercased dict key.
    pattern = None
    canonical = {}
    if all_keywords:
        canonical = {kw.lower(): kw for kw in all_keywords}
        alternation = "|".join(
            re.escape(kw) for kw in sorted(all_keywords, key=len, reverse=True)
        )
        pattern = re.compile(r'\b(' + alternation + r')\b', re.IGNORECASE)

    def _wrap(match):
        # Preserve the casing that appears in the transcript text; use the
        # canonical dict key for the data-keyword attribute.
        surface = match.group(0)
        key = canonical.get(surface.lower(), surface)
        return f"<span class='keyword' data-keyword='{key}'>{surface}</span>"

    # NOTE(review): segment text is interpolated into HTML without escaping;
    # transcripts containing '<' or '&' will produce malformed markup —
    # confirm whether inputs are guaranteed plain text.
    html = ["<div class='interactive-transcript'>"]
    for segment in segments:
        start_time = segment["start"]
        end_time = segment["end"]
        text = segment["text"]
        # mm:ss timestamp for display
        timestamp = f"{int(start_time // 60):02d}:{int(start_time % 60):02d}"
        # Add speaker if available
        speaker = segment.get("speaker", "")
        speaker_html = f"<span class='speaker'>[{speaker}]</span> " if speaker else ""
        highlighted_text = pattern.sub(_wrap, text) if pattern else text
        # Add segment to HTML
        html.append(f"<p class='segment' data-start='{start_time}' data-end='{end_time}'>")
        html.append(f"<span class='timestamp'>{timestamp}</span> {speaker_html}{highlighted_text}")
        html.append("</p>")
    html.append("</div>")
    return "\n".join(html)
def create_keyword_cloud_data(keyword_timestamps, entity_timestamps=None):
    """
    Create data for a keyword cloud visualization.

    Args:
        keyword_timestamps (dict): Dictionary mapping keywords to timestamp lists
        entity_timestamps (dict, optional): Dictionary mapping entities to timestamp lists

    Returns:
        list: List of (keyword, weight) tuples for visualization
    """
    # Keywords are weighted by how often they occur.
    cloud = [(term, len(stamps)) for term, stamps in keyword_timestamps.items()]
    # Entities get a 1.5x boost over plain keywords.
    if entity_timestamps:
        cloud.extend(
            (term, len(stamps) * 1.5) for term, stamps in entity_timestamps.items()
        )
    return cloud

155
utils/ollama_integration.py Normal file
View File

@ -0,0 +1,155 @@
"""
Ollama integration for local AI model inference.
Provides functions to use Ollama's API for text summarization.
"""
import requests
import json
import logging
from pathlib import Path
import os
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Default Ollama API endpoint
OLLAMA_API_URL = "http://localhost:11434/api"
def check_ollama_available():
    """
    Check if the local Ollama service is reachable.

    Returns:
        bool: True if Ollama is available, False otherwise
    """
    try:
        # A lightweight GET against /tags doubles as a liveness probe.
        resp = requests.get(f"{OLLAMA_API_URL}/tags", timeout=2)
    except requests.exceptions.RequestException:
        return False
    return resp.status_code == 200
def list_available_models():
    """
    List available models in Ollama.

    Returns:
        list: List of available model names (empty list on any error)
    """
    try:
        # Use the same 2-second timeout as check_ollama_available so a dead
        # service cannot hang the caller indefinitely (the original request
        # had no timeout).
        response = requests.get(f"{OLLAMA_API_URL}/tags", timeout=2)
        if response.status_code == 200:
            models = response.json().get('models', [])
            return [model['name'] for model in models]
        return []
    except requests.exceptions.RequestException as e:
        logger.error(f"Error listing Ollama models: {e}")
        return []
def summarize_with_ollama(text, model="llama3", max_length=150, timeout=300):
    """
    Summarize text using Ollama's local API.

    Args:
        text (str): Text to summarize
        model (str): Ollama model to use
        max_length (int): Approximate maximum length of the summary, in words
        timeout (float): Seconds to wait for the generation request.
            Generation can be slow on large models; the original request had
            no timeout and could hang forever.

    Returns:
        str: Summarized text or None if failed
    """
    if not check_ollama_available():
        logger.warning("Ollama service is not available")
        return None
    # Refuse early if the requested model has not been pulled into Ollama.
    available_models = list_available_models()
    if model not in available_models:
        logger.warning(f"Model {model} not available in Ollama. Available models: {available_models}")
        return None
    # Prepare the prompt for summarization
    prompt = f"Summarize the following text in about {max_length} words:\n\n{text}"
    try:
        # Make the API request (non-streaming: one JSON response).
        response = requests.post(
            f"{OLLAMA_API_URL}/generate",
            json={
                "model": model,
                "prompt": prompt,
                "stream": False,
                "options": {
                    "temperature": 0.3,  # low temperature for focused summaries
                    "top_p": 0.9,
                    "max_tokens": max_length * 2  # Approximate token count
                }
            },
            timeout=timeout
        )
        if response.status_code == 200:
            result = response.json()
            return result.get('response', '').strip()
        logger.error(f"Ollama API error: {response.status_code} - {response.text}")
        return None
    except requests.exceptions.RequestException as e:
        logger.error(f"Error communicating with Ollama: {e}")
        return None
def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150):
    """
    Chunk long text and summarize each chunk, then combine the summaries.

    Args:
        text (str): Text to summarize
        model (str): Ollama model to use
        chunk_size (int): Maximum size of each chunk in characters
        max_length (int): Maximum length of the final summary

    Returns:
        str: Combined summary or None if failed
    """
    # Short texts can be summarized in a single pass.
    if len(text) <= chunk_size:
        return summarize_with_ollama(text, model, max_length)
    # Split text into whitespace-delimited chunks of at most chunk_size chars.
    words = text.split()
    chunks = []
    current_chunk = []
    current_length = 0
    for word in words:
        if current_length + len(word) + 1 <= chunk_size:
            current_chunk.append(word)
            current_length += len(word) + 1
        else:
            chunks.append(' '.join(current_chunk))
            current_chunk = [word]
            current_length = len(word) + 1
    if current_chunk:
        chunks.append(' '.join(current_chunk))
    # Give every chunk at least a 1-word budget: the original used
    # max_length // len(chunks), which collapses to 0 when there are more
    # chunks than max_length words.
    per_chunk_length = max(1, max_length // len(chunks))
    # Summarize each chunk independently.
    chunk_summaries = []
    for i, chunk in enumerate(chunks):
        logger.info(f"Summarizing chunk {i+1}/{len(chunks)}")
        summary = summarize_with_ollama(chunk, model, per_chunk_length)
        if summary:
            chunk_summaries.append(summary)
    if not chunk_summaries:
        return None
    # A single surviving summary needs no further condensing.
    if len(chunk_summaries) == 1:
        return chunk_summaries[0]
    # Otherwise, combine the partial summaries and summarize the combination.
    combined_summary = " ".join(chunk_summaries)
    return summarize_with_ollama(combined_summary, model, max_length)

View File

@ -1,22 +1,114 @@
import whisper
from pathlib import Path
from transformers import pipeline, AutoTokenizer
from utils.audio_processing import extract_audio
from utils.summarization import summarize_text
import logging
import torch
# Try to import GPU utilities, but don't fail if not available
try:
from utils.gpu_utils import configure_gpu, get_optimal_device
GPU_UTILS_AVAILABLE = True
except ImportError:
GPU_UTILS_AVAILABLE = False
# Try to import caching utilities, but don't fail if not available
try:
from utils.cache import load_from_cache, save_to_cache
CACHE_AVAILABLE = True
except ImportError:
CACHE_AVAILABLE = False
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
WHISPER_MODEL = "base"
SUMMARIZATION_MODEL = "t5-base"
# NOTE(review): two back-to-back definitions of transcribe_audio appear here.
# This looks like merge/diff residue: the first (legacy) definition is
# immediately shadowed by the second and is never callable — confirm and delete.
def transcribe_audio(audio_path: Path):
    """Transcribe audio using Whisper."""
    # Legacy body: fixed model name, no caching, no device selection.
    model = whisper.load_model(WHISPER_MODEL)
    result = model.transcribe(str(audio_path))
def transcribe_audio(audio_path: Path, model=WHISPER_MODEL, use_cache=True, cache_max_age=None,
                     use_gpu=True, memory_fraction=0.8):
    """
    Transcribe audio using Whisper and return both segments and full transcript.
    Args:
        audio_path (Path): Path to the audio or video file
        model (str): Whisper model size to use (tiny, base, small, medium, large)
        use_cache (bool): Whether to use caching
        cache_max_age (float, optional): Maximum age of cache in seconds
        use_gpu (bool): Whether to use GPU acceleration if available
        memory_fraction (float): Fraction of GPU memory to use (0.0 to 1.0)
    Returns:
        tuple: (segments, transcript) where segments is a list of dicts with timing info
    """
    audio_path = Path(audio_path)
    # Check cache first if enabled (cache helpers are optional imports; see
    # CACHE_AVAILABLE guard at module top).
    if use_cache and CACHE_AVAILABLE:
        cached_data = load_from_cache(audio_path, model, "transcribe", cache_max_age)
        if cached_data:
            logger.info(f"Using cached transcription for {audio_path}")
            return cached_data.get("segments", []), cached_data.get("transcript", "")
    # Extract audio if the input is a video file
    if audio_path.suffix.lower() in ['.mp4', '.avi', '.mov', '.mkv']:
        audio_path = extract_audio(audio_path)
    # Configure GPU if available and requested
    device = torch.device("cpu")
    if use_gpu and GPU_UTILS_AVAILABLE:
        gpu_config = configure_gpu(model, memory_fraction)
        device = gpu_config["device"]
    logger.info(f"Using device: {device} for transcription")
    # Load the specified Whisper model.  Whisper is loaded on CPU when the
    # optimal device is Apple MPS — presumably a compatibility workaround;
    # confirm before changing.
    logger.info(f"Loading Whisper model: {model}")
    whisper_model = whisper.load_model(model, device=device if device.type != "mps" else "cpu")
    # Transcribe the audio
    logger.info(f"Transcribing audio: {audio_path}")
    result = whisper_model.transcribe(str(audio_path))
    # Extract the full transcript and segments
    transcript = result["text"]
    # NOTE(review): the next two lines look like leftovers from the legacy
    # version.  summarize_text is called and the function returns
    # (transcript, summary) here, so everything below — including caching and
    # the documented (segments, transcript) return — is unreachable.  Confirm
    # and remove these two lines.
    summary = summarize_text(transcript)
    return transcript, summary
    segments = result["segments"]
    # Cache the results if caching is enabled
    if use_cache and CACHE_AVAILABLE:
        cache_data = {
            "transcript": transcript,
            "segments": segments
        }
        save_to_cache(audio_path, cache_data, model, "transcribe")
    return segments, transcript
# NOTE(review): this region interleaves pre- and post-refactor lines and even
# contains a raw diff hunk header ("@ -24,20 +116,57 @@"), so it is not valid
# Python as written — it appears to be unresolved merge/diff residue.  The
# comments below annotate intent only; the duplicated lines must be
# reconciled by hand.
def summarize_text(text):
    """Summarize text using a pre-trained T5 transformer model with chunking."""
    # Legacy definition: fixed model, default-device pipeline; shadowed by the
    # parameterized definition below.
    summarization_pipeline = pipeline("summarization", model=SUMMARIZATION_MODEL)
    tokenizer = AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL)
def summarize_text(text, model=SUMMARIZATION_MODEL, use_gpu=True, memory_fraction=0.8):
    """
    Summarize text using a pre-trained transformer model with chunking.
    Args:
        text (str): Text to summarize
        model (str): Model to use for summarization
        use_gpu (bool): Whether to use GPU acceleration if available
        memory_fraction (float): Fraction of GPU memory to use (0.0 to 1.0)
    Returns:
        str: Summarized text
    """
    # Configure device
    device = torch.device("cpu")
    if use_gpu and GPU_UTILS_AVAILABLE:
        device = get_optimal_device()
    logger.info(f"Using device: {device} for summarization")
    # Initialize the pipeline with the specified device
    device_arg = -1 if device.type == "cpu" else 0  # -1 for CPU, 0 for GPU
    summarization_pipeline = pipeline("summarization", model=model, device=device_arg)
    tokenizer = AutoTokenizer.from_pretrained(model)
    max_tokens = 512
@ -24,20 +116,57 @@ def summarize_text(text):
    # NOTE(review): `tokens` is assigned in lines elided by the diff above.
    num_tokens = len(tokens['input_ids'][0])
    if num_tokens > max_tokens:
        # Old call (no tokenizer argument) immediately followed by the new one.
        chunks = chunk_text(text, max_tokens)
        chunks = chunk_text(text, max_tokens, tokenizer)
        summaries = []
        # Old loop header and body line, followed by the new enumerate loop.
        for chunk in chunks:
            summary_output = summarization_pipeline("summarize: " + chunk, max_length=150, min_length=30, do_sample=False)
        for i, chunk in enumerate(chunks):
            logger.info(f"Summarizing chunk {i+1}/{len(chunks)}")
            summary_output = summarization_pipeline(
                "summarize: " + chunk,
                max_length=150,
                min_length=30,
                do_sample=False
            )
            summaries.append(summary_output[0]['summary_text'])
        overall_summary = " ".join(summaries)
        # If the combined summary is still long, summarize it again
        if len(summaries) > 1:
            logger.info("Generating final summary from chunk summaries")
            combined_text = " ".join(summaries)
            overall_summary = summarization_pipeline(
                "summarize: " + combined_text,
                max_length=150,
                min_length=30,
                do_sample=False
            )[0]['summary_text']
    else:
        # Old single-line call immediately followed by the reformatted call.
        overall_summary = summarization_pipeline("summarize: " + text, max_length=150, min_length=30, do_sample=False)[0]['summary_text']
        overall_summary = summarization_pipeline(
            "summarize: " + text,
            max_length=150,
            min_length=30,
            do_sample=False
        )[0]['summary_text']
    return overall_summary
def chunk_text(text, max_tokens):
"""Splits the text into a list of chunks based on token limits."""
tokenizer = AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL)
def chunk_text(text, max_tokens, tokenizer=None):
"""
Splits the text into a list of chunks based on token limits.
Args:
text (str): Text to chunk
max_tokens (int): Maximum tokens per chunk
tokenizer (AutoTokenizer, optional): Tokenizer to use
Returns:
list: List of text chunks
"""
if tokenizer is None:
tokenizer = AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL)
words = text.split()
chunks = []

283
utils/translation.py Normal file
View File

@ -0,0 +1,283 @@
"""
Translation utilities for the OBS Recording Transcriber.
Provides functions for language detection and translation.
"""
import logging
import torch
from pathlib import Path
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, M2M100ForConditionalGeneration
import whisper
import iso639
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Try to import GPU utilities, but don't fail if not available
try:
from utils.gpu_utils import get_optimal_device
GPU_UTILS_AVAILABLE = True
except ImportError:
GPU_UTILS_AVAILABLE = False
# Default models
TRANSLATION_MODEL = "facebook/m2m100_418M"
LANGUAGE_DETECTION_MODEL = "papluca/xlm-roberta-base-language-detection"
# ISO language code mapping
def get_language_name(code):
    """
    Resolve an ISO language code to its language name.

    Args:
        code (str): ISO language code

    Returns:
        str: Language name, or the original code if no registry entry matches
    """
    # Try the two-letter (part1) registry first, then fall back to the
    # three-letter bibliographic (part2b) registry.
    for field in ("part1", "part2b"):
        try:
            return iso639.languages.get(**{field: code}).name
        except (KeyError, AttributeError):
            continue
    return code
def detect_language(text, model=LANGUAGE_DETECTION_MODEL, use_gpu=True):
    """
    Detect the language of a text.

    Args:
        text (str): Text to detect language for
        model (str): Model to use for language detection
        use_gpu (bool): Whether to use GPU acceleration if available

    Returns:
        tuple: (language_code, confidence), or (None, 0.0) on failure
    """
    # Pipeline device index: 0 selects the first CUDA GPU, -1 the CPU.
    device_arg = -1
    if use_gpu and GPU_UTILS_AVAILABLE:
        if get_optimal_device().type == "cuda":
            device_arg = 0
    try:
        classifier = pipeline("text-classification", model=model, device=device_arg)
        # Truncation is by characters, not tokens — assumed to keep the input
        # under the model's 512-token limit for typical text; confirm.
        sample = text[:512]
        prediction = classifier(sample)[0]
        return prediction["label"], prediction["score"]
    except Exception as e:
        logger.error(f"Error detecting language: {e}")
        return None, 0.0
def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_MODEL, use_gpu=True):
    """
    Translate text from source language to target language.

    Args:
        text (str): Text to translate
        source_lang (str, optional): Source language code (auto-detect if None)
        target_lang (str): Target language code
        model (str): Model to use for translation
        use_gpu (bool): Whether to use GPU acceleration if available

    Returns:
        str: Translated text (the original text is returned on any error)
    """
    # Auto-detect source language if not provided
    if source_lang is None:
        detected_lang, confidence = detect_language(text, use_gpu=use_gpu)
        if detected_lang and confidence > 0.5:
            source_lang = detected_lang
            logger.info(f"Detected language: {get_language_name(source_lang)} ({source_lang}) with confidence {confidence:.2f}")
        else:
            logger.warning("Could not reliably detect language, defaulting to English")
            source_lang = "en"
    # Skip translation if source and target are the same
    if source_lang == target_lang:
        logger.info(f"Source and target languages are the same ({source_lang}), skipping translation")
        return text
    # Configure device
    device = torch.device("cpu")
    if use_gpu and GPU_UTILS_AVAILABLE:
        device = get_optimal_device()
    try:
        # Load model and tokenizer.  Keep the model id separate from the
        # loaded model object — the original shadowed the `model` parameter.
        tokenizer = AutoTokenizer.from_pretrained(model)
        translation_model = M2M100ForConditionalGeneration.from_pretrained(model).to(device)
        # Tell the tokenizer which language the input is in.
        tokenizer.src_lang = source_lang
        # Split long text at word boundaries: the original sliced raw
        # character ranges, cutting words in half mid-chunk.
        max_length = 512
        chunks = _chunk_on_word_boundaries(text, max_length)
        # Translate each chunk
        translated_chunks = []
        for chunk in chunks:
            encoded = tokenizer(chunk, return_tensors="pt").to(device)
            generated_tokens = translation_model.generate(
                **encoded,
                forced_bos_token_id=tokenizer.get_lang_id(target_lang),
                max_length=max_length
            )
            translated_chunk = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
            translated_chunks.append(translated_chunk)
        # Combine translated chunks
        return " ".join(translated_chunks)
    except Exception as e:
        logger.error(f"Error translating text: {e}")
        return text

def _chunk_on_word_boundaries(text, max_chars):
    """Split text into chunks of at most max_chars, breaking at whitespace.

    A single word longer than max_chars is hard-split as a last resort.
    """
    if len(text) <= max_chars:
        return [text]
    chunks = []
    current = []
    current_len = 0
    for word in text.split():
        extra = len(word) + (1 if current else 0)  # +1 for the joining space
        if current_len + extra > max_chars and current:
            chunks.append(" ".join(current))
            current = [word]
            current_len = len(word)
        else:
            current.append(word)
            current_len += extra
    if current:
        chunks.append(" ".join(current))
    # Hard-split any oversized single word.
    final = []
    for chunk in chunks:
        while len(chunk) > max_chars:
            final.append(chunk[:max_chars])
            chunk = chunk[max_chars:]
        final.append(chunk)
    return final
def translate_segments(segments, source_lang=None, target_lang="en", use_gpu=True):
    """
    Translate a list of transcript segments into the target language.

    Args:
        segments (list): Transcript segments (dicts with a "text" key)
        source_lang (str, optional): Source language code (auto-detect if None)
        target_lang (str): Target language code
        use_gpu (bool): Whether to use GPU acceleration if available

    Returns:
        list: Translated segments; the input is returned unchanged when the
        languages already match, or on error.  Empty input yields [].
    """
    if not segments:
        return []
    # Auto-detect the source language from the full transcript when not given.
    if source_lang is None:
        combined_text = " ".join(seg["text"] for seg in segments)
        detected_lang, _ = detect_language(combined_text, use_gpu=use_gpu)
        source_lang = detected_lang or "en"
    # Nothing to do when the languages already match.
    if source_lang == target_lang:
        return segments
    try:
        results = []
        for seg in segments:
            # Copy the segment, swap in the translation, and record the
            # original text plus the language pair for downstream consumers.
            new_seg = seg.copy()
            new_seg["text"] = translate_text(
                seg["text"],
                source_lang=source_lang,
                target_lang=target_lang,
                use_gpu=use_gpu,
            )
            new_seg["original_text"] = seg["text"]
            new_seg["source_lang"] = source_lang
            new_seg["target_lang"] = target_lang
            results.append(new_seg)
        return results
    except Exception as e:
        logger.error(f"Error translating segments: {e}")
        return segments
def transcribe_and_translate(audio_path, whisper_model="base", target_lang="en",
                             use_gpu=True, detect_source=True):
    """
    Transcribe audio and translate to target language.
    Args:
        audio_path (Path): Path to the audio file
        whisper_model (str): Whisper model size to use
        target_lang (str): Target language code
        use_gpu (bool): Whether to use GPU acceleration if available
        detect_source (bool): Whether to auto-detect source language
    Returns:
        tuple: (original_segments, translated_segments, original_transcript, translated_transcript)
            All four values are None if any step fails.
    """
    audio_path = Path(audio_path)
    # Configure device
    device = torch.device("cpu")
    if use_gpu and GPU_UTILS_AVAILABLE:
        device = get_optimal_device()
    try:
        # Step 1: Transcribe audio with Whisper.  Whisper runs on CPU when the
        # optimal device is Apple MPS — presumably a compatibility workaround;
        # confirm before changing.
        logger.info(f"Transcribing audio with Whisper model: {whisper_model}")
        model = whisper.load_model(whisper_model, device=device if device.type != "mps" else "cpu")
        # Use Whisper's built-in language detection if requested
        if detect_source:
            # First, detect language with Whisper.  pad_or_trim limits the
            # sample to Whisper's fixed input window (30 s — confirm), so
            # detection is based only on the start of the recording.
            audio = whisper.load_audio(str(audio_path))
            audio = whisper.pad_or_trim(audio)
            mel = whisper.log_mel_spectrogram(audio).to(device if device.type != "mps" else "cpu")
            _, probs = model.detect_language(mel)
            # Pick the language with the highest probability.
            source_lang = max(probs, key=probs.get)
            logger.info(f"Whisper detected language: {get_language_name(source_lang)} ({source_lang})")
            # Transcribe with detected language
            result = model.transcribe(str(audio_path), language=source_lang)
        else:
            # Transcribe without language specification; Whisper reports the
            # language it settled on in the result dict.
            result = model.transcribe(str(audio_path))
            source_lang = result.get("language", "en")
        original_segments = result["segments"]
        original_transcript = result["text"]
        # Step 2: Translate if needed
        if source_lang != target_lang:
            logger.info(f"Translating from {source_lang} to {target_lang}")
            translated_segments = translate_segments(
                original_segments,
                source_lang=source_lang,
                target_lang=target_lang,
                use_gpu=use_gpu
            )
            # Create full translated transcript
            translated_transcript = " ".join([segment["text"] for segment in translated_segments])
        else:
            # No translation needed: reuse the originals for both outputs.
            logger.info(f"Source and target languages are the same ({source_lang}), skipping translation")
            translated_segments = original_segments
            translated_transcript = original_transcript
        return original_segments, translated_segments, original_transcript, translated_transcript
    except Exception as e:
        logger.error(f"Error in transcribe_and_translate: {e}")
        return None, None, None, None