Add installation scripts and update documentation for Phase 3 features
utils/cache.py (new file, 205 lines)
@@ -0,0 +1,205 @@
"""
Caching utilities for the OBS Recording Transcriber.
Provides functions to cache and retrieve transcription and summarization results.
"""

import json
import hashlib
import os
from pathlib import Path
import logging
import time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default cache directory
CACHE_DIR = Path.home() / ".obs_transcriber_cache"


def get_file_hash(file_path):
    """
    Generate a hash for a file based on its path, size, and modification time.

    Args:
        file_path (Path): Path to the file

    Returns:
        str: Hash string representing the file
    """
    file_path = Path(file_path)
    if not file_path.exists():
        return None

    # Get file stats
    stats = file_path.stat()
    file_size = stats.st_size
    mod_time = stats.st_mtime

    # Create a hash based on path, size and modification time
    # This is faster than hashing the entire file content
    hash_input = f"{file_path.absolute()}|{file_size}|{mod_time}"
    return hashlib.md5(hash_input.encode()).hexdigest()


def get_cache_path(file_path, model=None, operation=None):
    """
    Get the cache file path for a given input file and operation.

    Args:
        file_path (Path): Path to the original file
        model (str, optional): Model used for processing
        operation (str, optional): Operation type (e.g., 'transcribe', 'summarize')

    Returns:
        Path: Path to the cache file
    """
    file_path = Path(file_path)
    file_hash = get_file_hash(file_path)

    if not file_hash:
        return None

    # Create cache directory if it doesn't exist
    cache_dir = CACHE_DIR
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Create a cache filename based on the hash and optional parameters
    cache_name = file_hash
    if model:
        cache_name += f"_{model}"
    if operation:
        cache_name += f"_{operation}"

    return cache_dir / f"{cache_name}.json"


def save_to_cache(file_path, data, model=None, operation=None):
    """
    Save data to cache.

    Args:
        file_path (Path): Path to the original file
        data (dict): Data to cache
        model (str, optional): Model used for processing
        operation (str, optional): Operation type

    Returns:
        bool: True if successful, False otherwise
    """
    cache_path = get_cache_path(file_path, model, operation)
    if not cache_path:
        return False

    try:
        # Add metadata to the cached data
        cache_data = {
            "original_file": str(Path(file_path).absolute()),
            "timestamp": time.time(),
            "model": model,
            "operation": operation,
            "data": data
        }

        with open(cache_path, 'w', encoding='utf-8') as f:
            json.dump(cache_data, f, ensure_ascii=False, indent=2)

        logger.info(f"Cached data saved to {cache_path}")
        return True
    except Exception as e:
        logger.error(f"Error saving cache: {e}")
        return False


def load_from_cache(file_path, model=None, operation=None, max_age=None):
    """
    Load data from cache if available and not expired.

    Args:
        file_path (Path): Path to the original file
        model (str, optional): Model used for processing
        operation (str, optional): Operation type
        max_age (float, optional): Maximum age of cache in seconds

    Returns:
        dict or None: Cached data or None if not available
    """
    cache_path = get_cache_path(file_path, model, operation)
    if not cache_path or not cache_path.exists():
        return None

    try:
        with open(cache_path, 'r', encoding='utf-8') as f:
            cache_data = json.load(f)

        # Check if cache is expired
        if max_age is not None:
            cache_time = cache_data.get("timestamp", 0)
            if time.time() - cache_time > max_age:
                logger.info(f"Cache expired for {file_path}")
                return None

        logger.info(f"Loaded data from cache: {cache_path}")
        return cache_data.get("data")
    except Exception as e:
        logger.error(f"Error loading cache: {e}")
        return None


def clear_cache(max_age=None):
    """
    Clear all cache files or only expired ones.

    Args:
        max_age (float, optional): Maximum age of cache in seconds

    Returns:
        int: Number of files deleted
    """
    if not CACHE_DIR.exists():
        return 0

    count = 0
    for cache_file in CACHE_DIR.glob("*.json"):
        try:
            if max_age is not None:
                # Check if file is expired
                with open(cache_file, 'r', encoding='utf-8') as f:
                    cache_data = json.load(f)

                cache_time = cache_data.get("timestamp", 0)
                if time.time() - cache_time <= max_age:
                    continue  # Skip non-expired files

            # Delete the file
            os.remove(cache_file)
            count += 1
        except Exception as e:
            logger.error(f"Error deleting cache file {cache_file}: {e}")

    logger.info(f"Cleared {count} cache files")
    return count


def get_cache_size():
    """
    Get the total size of the cache directory.

    Returns:
        tuple: (size_bytes, file_count)
    """
    if not CACHE_DIR.exists():
        return 0, 0

    total_size = 0
    file_count = 0

    for cache_file in CACHE_DIR.glob("*.json"):
        try:
            total_size += cache_file.stat().st_size
            file_count += 1
        except Exception:
            pass

    return total_size, file_count
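A minimal usage sketch for the cache module above. The media file and payload are hypothetical, and the input file must exist on disk, since the cache key hashes its path, size, and mtime:

from utils.cache import save_to_cache, load_from_cache, clear_cache

# Cache a (hypothetical) transcription result keyed by file, model, and operation.
save_to_cache("recording.mp4", {"transcript": "hello world"}, model="base", operation="transcribe")

# Returns the cached payload, or None if missing or older than one day.
cached = load_from_cache("recording.mp4", model="base", operation="transcribe", max_age=86400)

# Drop cache entries older than a week.
clear_cache(max_age=7 * 86400)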
utils/diarization.py (new file, 226 lines)
@@ -0,0 +1,226 @@
"""
Speaker diarization utilities for the OBS Recording Transcriber.
Provides functions to identify different speakers in audio recordings.
"""

import logging
import os
import numpy as np
from pathlib import Path
import torch
from pyannote.audio import Pipeline
from pyannote.core import Segment
import whisper

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Try to import GPU utilities, but don't fail if not available
try:
    from utils.gpu_utils import get_optimal_device
    GPU_UTILS_AVAILABLE = True
except ImportError:
    GPU_UTILS_AVAILABLE = False

# Default HuggingFace auth token environment variable
HF_TOKEN_ENV = "HF_TOKEN"


def get_diarization_pipeline(use_gpu=True, hf_token=None):
    """
    Initialize the speaker diarization pipeline.

    Args:
        use_gpu (bool): Whether to use GPU acceleration if available
        hf_token (str, optional): HuggingFace API token for accessing the model

    Returns:
        Pipeline or None: Diarization pipeline if successful, None otherwise
    """
    # Check if token is provided or in environment
    if hf_token is None:
        hf_token = os.environ.get(HF_TOKEN_ENV)
        if hf_token is None:
            logger.error(f"HuggingFace token not provided. Set {HF_TOKEN_ENV} environment variable or pass token directly.")
            return None

    try:
        # Configure device
        device = torch.device("cpu")
        if use_gpu and GPU_UTILS_AVAILABLE:
            device = get_optimal_device()
            logger.info(f"Using device: {device} for diarization")

        # Initialize the pipeline
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.0",
            use_auth_token=hf_token
        )

        # Move to appropriate device
        if device.type == "cuda":
            pipeline = pipeline.to(torch.device(device))

        return pipeline
    except Exception as e:
        logger.error(f"Error initializing diarization pipeline: {e}")
        return None


def diarize_audio(audio_path, pipeline=None, num_speakers=None, use_gpu=True, hf_token=None):
    """
    Perform speaker diarization on an audio file.

    Args:
        audio_path (Path): Path to the audio file
        pipeline (Pipeline, optional): Pre-initialized diarization pipeline
        num_speakers (int, optional): Number of speakers (if known)
        use_gpu (bool): Whether to use GPU acceleration if available
        hf_token (str, optional): HuggingFace API token

    Returns:
        dict: Dictionary mapping time segments to speaker IDs
    """
    audio_path = Path(audio_path)

    # Initialize pipeline if not provided
    if pipeline is None:
        pipeline = get_diarization_pipeline(use_gpu, hf_token)
        if pipeline is None:
            return None

    try:
        # Run diarization
        logger.info(f"Running speaker diarization on {audio_path}")
        diarization = pipeline(audio_path, num_speakers=num_speakers)

        # Extract speaker segments
        speaker_segments = {}
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            segment = (turn.start, turn.end)
            speaker_segments[segment] = speaker

        return speaker_segments
    except Exception as e:
        logger.error(f"Error during diarization: {e}")
        return None


def apply_diarization_to_transcript(transcript_segments, speaker_segments):
    """
    Apply speaker diarization results to transcript segments.

    Args:
        transcript_segments (list): List of transcript segments with timing info
        speaker_segments (dict): Dictionary mapping time segments to speaker IDs

    Returns:
        list: Updated transcript segments with speaker information
    """
    if not speaker_segments:
        return transcript_segments

    # Convert speaker segments to a more usable format
    speaker_ranges = [(Segment(start, end), speaker)
                      for (start, end), speaker in speaker_segments.items()]

    # Update transcript segments with speaker information
    for segment in transcript_segments:
        segment_start = segment['start']
        segment_end = segment['end']
        segment_range = Segment(segment_start, segment_end)

        # Find overlapping speaker segments
        overlaps = []
        for (spk_range, speaker) in speaker_ranges:
            overlap = segment_range.intersect(spk_range)
            if overlap:
                overlaps.append((overlap.duration, speaker))

        # Assign the speaker with the most overlap
        if overlaps:
            overlaps.sort(reverse=True)  # Sort by duration (descending)
            segment['speaker'] = overlaps[0][1]
        else:
            segment['speaker'] = "UNKNOWN"

    return transcript_segments


def format_transcript_with_speakers(transcript_segments):
    """
    Format transcript with speaker labels.

    Args:
        transcript_segments (list): List of transcript segments with speaker info

    Returns:
        str: Formatted transcript with speaker labels
    """
    formatted_lines = []
    current_speaker = None

    for segment in transcript_segments:
        speaker = segment.get('speaker', 'UNKNOWN')
        text = segment['text'].strip()

        # Add speaker label when speaker changes
        if speaker != current_speaker:
            formatted_lines.append(f"\n[{speaker}]")
            current_speaker = speaker

        formatted_lines.append(text)

    return " ".join(formatted_lines)


def transcribe_with_diarization(audio_path, whisper_model="base", num_speakers=None,
                                use_gpu=True, hf_token=None):
    """
    Transcribe audio with speaker diarization.

    Args:
        audio_path (Path): Path to the audio file
        whisper_model (str): Whisper model size to use
        num_speakers (int, optional): Number of speakers (if known)
        use_gpu (bool): Whether to use GPU acceleration if available
        hf_token (str, optional): HuggingFace API token

    Returns:
        tuple: (diarized_segments, formatted_transcript)
    """
    audio_path = Path(audio_path)

    # Configure device
    device = torch.device("cpu")
    if use_gpu and GPU_UTILS_AVAILABLE:
        device = get_optimal_device()

    try:
        # Step 1: Transcribe audio with Whisper
        logger.info(f"Transcribing audio with Whisper model: {whisper_model}")
        model = whisper.load_model(whisper_model, device=device if device.type != "mps" else "cpu")
        result = model.transcribe(str(audio_path))
        transcript_segments = result["segments"]

        # Step 2: Perform speaker diarization
        logger.info("Performing speaker diarization")
        pipeline = get_diarization_pipeline(use_gpu, hf_token)
        if pipeline is None:
            logger.warning("Diarization pipeline not available, returning transcript without speakers")
            return transcript_segments, result["text"]

        speaker_segments = diarize_audio(audio_path, pipeline, num_speakers, use_gpu)

        # Step 3: Apply diarization to transcript
        if speaker_segments:
            diarized_segments = apply_diarization_to_transcript(transcript_segments, speaker_segments)
            formatted_transcript = format_transcript_with_speakers(diarized_segments)
            return diarized_segments, formatted_transcript
        else:
            return transcript_segments, result["text"]

    except Exception as e:
        logger.error(f"Error in transcribe_with_diarization: {e}")
        return None, None
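A sketch of how the diarization helpers compose, assuming pyannote model access is granted via an HF_TOKEN environment variable; "meeting.wav" is a hypothetical recording:

from utils.diarization import transcribe_with_diarization

# Whisper transcription followed by pyannote diarization; falls back to a
# plain transcript when the diarization pipeline cannot be initialized.
segments, transcript = transcribe_with_diarization("meeting.wav", whisper_model="base", num_speakers=2)
if transcript:
    print(transcript)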
utils/export.py (new file, 284 lines)
@@ -0,0 +1,284 @@
"""
Subtitle export utilities for the OBS Recording Transcriber.
Supports exporting transcripts to SRT, ASS, and WebVTT subtitle formats.
"""

from pathlib import Path
import re
from datetime import timedelta
import gzip
import zipfile
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def format_timestamp_srt(timestamp_ms):
    """
    Format a timestamp in milliseconds to SRT format (HH:MM:SS,mmm).

    Args:
        timestamp_ms (int): Timestamp in milliseconds

    Returns:
        str: Formatted timestamp string
    """
    hours, remainder = divmod(timestamp_ms, 3600000)
    minutes, remainder = divmod(remainder, 60000)
    seconds, milliseconds = divmod(remainder, 1000)
    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"


def format_timestamp_ass(timestamp_ms):
    """
    Format a timestamp in milliseconds to ASS format (H:MM:SS.cc).

    Args:
        timestamp_ms (int): Timestamp in milliseconds

    Returns:
        str: Formatted timestamp string
    """
    hours, remainder = divmod(timestamp_ms, 3600000)
    minutes, remainder = divmod(remainder, 60000)
    seconds, remainder = divmod(remainder, 1000)
    centiseconds = remainder // 10
    return f"{int(hours)}:{int(minutes):02d}:{int(seconds):02d}.{int(centiseconds):02d}"


def format_timestamp_vtt(timestamp_ms):
    """
    Format a timestamp in milliseconds to WebVTT format (HH:MM:SS.mmm).

    Args:
        timestamp_ms (int): Timestamp in milliseconds

    Returns:
        str: Formatted timestamp string
    """
    hours, remainder = divmod(timestamp_ms, 3600000)
    minutes, remainder = divmod(remainder, 60000)
    seconds, milliseconds = divmod(remainder, 1000)
    return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"


def export_to_srt(segments, output_path):
    """
    Export transcript segments to SRT format.

    Args:
        segments (list): List of transcript segments with start, end, and text
        output_path (Path): Path to save the SRT file

    Returns:
        Path: Path to the saved SRT file
    """
    with open(output_path, 'w', encoding='utf-8') as f:
        for i, segment in enumerate(segments, 1):
            start_time = format_timestamp_srt(int(segment['start'] * 1000))
            end_time = format_timestamp_srt(int(segment['end'] * 1000))

            f.write(f"{i}\n")
            f.write(f"{start_time} --> {end_time}\n")
            f.write(f"{segment['text'].strip()}\n\n")

    return output_path


def export_to_ass(segments, output_path, video_width=1920, video_height=1080, style=None):
    """
    Export transcript segments to ASS format with styling.

    Args:
        segments (list): List of transcript segments with start, end, and text
        output_path (Path): Path to save the ASS file
        video_width (int): Width of the video in pixels
        video_height (int): Height of the video in pixels
        style (dict, optional): Custom style parameters

    Returns:
        Path: Path to the saved ASS file
    """
    # Default style
    default_style = {
        "fontname": "Arial",
        "fontsize": "48",
        "primary_color": "&H00FFFFFF",   # White
        "secondary_color": "&H000000FF", # Blue
        "outline_color": "&H00000000",   # Black
        "back_color": "&H80000000",      # Semi-transparent black
        "bold": "-1",                    # True
        "italic": "0",                   # False
        "alignment": "2",                # Bottom center
    }

    # Apply custom style if provided
    if style:
        default_style.update(style)

    # ASS header template
    ass_header = f"""[Script Info]
Title: Transcription
ScriptType: v4.00+
WrapStyle: 0
PlayResX: {video_width}
PlayResY: {video_height}
ScaledBorderAndShadow: yes

[V4+ Styles]
Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding
Style: Default,{default_style['fontname']},{default_style['fontsize']},{default_style['primary_color']},{default_style['secondary_color']},{default_style['outline_color']},{default_style['back_color']},{default_style['bold']},{default_style['italic']},0,0,100,100,0,0,1,2,2,{default_style['alignment']},10,10,10,1

[Events]
Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
"""

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(ass_header)

        for segment in segments:
            start_time = format_timestamp_ass(int(segment['start'] * 1000))
            end_time = format_timestamp_ass(int(segment['end'] * 1000))
            text = segment['text'].strip().replace('\n', '\\N')

            f.write(f"Dialogue: 0,{start_time},{end_time},Default,,0,0,0,,{text}\n")

    return output_path


def export_to_vtt(segments, output_path):
    """
    Export transcript segments to WebVTT format.

    Args:
        segments (list): List of transcript segments with start, end, and text
        output_path (Path): Path to save the WebVTT file

    Returns:
        Path: Path to the saved WebVTT file
    """
    with open(output_path, 'w', encoding='utf-8') as f:
        # WebVTT header
        f.write("WEBVTT\n\n")

        for i, segment in enumerate(segments, 1):
            start_time = format_timestamp_vtt(int(segment['start'] * 1000))
            end_time = format_timestamp_vtt(int(segment['end'] * 1000))

            # Optional cue identifier
            f.write(f"{i}\n")
            f.write(f"{start_time} --> {end_time}\n")
            f.write(f"{segment['text'].strip()}\n\n")

    return output_path


def transcript_to_segments(transcript, segment_duration=5.0):
    """
    Convert a plain transcript to timed segments for subtitle export.
    Used when the original segments are not available.

    Args:
        transcript (str): Full transcript text
        segment_duration (float): Duration of each segment in seconds

    Returns:
        list: List of segments with start, end, and text
    """
    # Split transcript into sentences
    sentences = re.split(r'(?<=[.!?])\s+', transcript)
    segments = []

    current_time = 0.0
    for sentence in sentences:
        if not sentence.strip():
            continue

        # Estimate duration based on word count (approx. 2.5 words per second)
        word_count = len(sentence.split())
        duration = max(2.0, word_count / 2.5)

        segments.append({
            'start': current_time,
            'end': current_time + duration,
            'text': sentence
        })

        current_time += duration

    return segments


def compress_file(input_path, compression_type='gzip'):
    """
    Compress a file using the specified compression method.

    Args:
        input_path (Path): Path to the file to compress
        compression_type (str): Type of compression ('gzip' or 'zip')

    Returns:
        Path: Path to the compressed file
    """
    input_path = Path(input_path)

    if compression_type == 'gzip':
        output_path = input_path.with_suffix(input_path.suffix + '.gz')
        with open(input_path, 'rb') as f_in:
            with gzip.open(output_path, 'wb') as f_out:
                f_out.write(f_in.read())
        return output_path

    elif compression_type == 'zip':
        output_path = input_path.with_suffix('.zip')
        with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            zipf.write(input_path, arcname=input_path.name)
        return output_path

    else:
        logger.warning(f"Unsupported compression type: {compression_type}")
        return input_path


def export_transcript(transcript, output_path, format_type='srt', segments=None,
                      compress=False, compression_type='gzip', style=None):
    """
    Export transcript to the specified subtitle format.

    Args:
        transcript (str): Full transcript text
        output_path (Path): Base path for the output file (without extension)
        format_type (str): 'srt', 'ass', or 'vtt'
        segments (list, optional): List of transcript segments with timing information
        compress (bool): Whether to compress the output file
        compression_type (str): Type of compression ('gzip' or 'zip')
        style (dict, optional): Custom style parameters for ASS format

    Returns:
        Path: Path to the saved subtitle file
    """
    output_path = Path(output_path)

    # If segments are not provided, create them from the transcript
    if segments is None:
        segments = transcript_to_segments(transcript)

    if format_type.lower() == 'srt':
        output_file = output_path.with_suffix('.srt')
        result_path = export_to_srt(segments, output_file)
    elif format_type.lower() == 'ass':
        output_file = output_path.with_suffix('.ass')
        result_path = export_to_ass(segments, output_file, style=style)
    elif format_type.lower() == 'vtt':
        output_file = output_path.with_suffix('.vtt')
        result_path = export_to_vtt(segments, output_file)
    else:
        raise ValueError(f"Unsupported format type: {format_type}. Use 'srt', 'ass', or 'vtt'.")

    # Compress the file if requested
    if compress:
        result_path = compress_file(result_path, compression_type)

    return result_path
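A short sketch of the export entry point above; the segments are hypothetical stand-ins for Whisper output, and "demo" is a hypothetical base path in the current directory:

from utils.export import export_transcript

segments = [
    {"start": 0.0, "end": 2.5, "text": "Hello and welcome."},
    {"start": 2.5, "end": 5.0, "text": "Let's get started."},
]

# Writes demo.srt; passing segments preserves real timings instead of estimates.
srt_path = export_transcript("", "demo", format_type="srt", segments=segments)

# Same segments as gzip-compressed WebVTT (demo.vtt.gz).
vtt_path = export_transcript("", "demo", format_type="vtt", segments=segments,
                             compress=True, compression_type="gzip")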
utils/gpu_utils.py (new file, 202 lines)
@@ -0,0 +1,202 @@
"""
GPU utilities for the OBS Recording Transcriber.
Provides functions to detect and configure GPU acceleration.
"""

import logging
import os
import platform
import subprocess
import torch

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def get_gpu_info():
    """
    Get information about available GPUs.

    Returns:
        dict: Information about available GPUs
    """
    gpu_info = {
        "cuda_available": torch.cuda.is_available(),
        "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
        "cuda_devices": [],
        "mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    }

    # Get CUDA device information
    if gpu_info["cuda_available"]:
        for i in range(gpu_info["cuda_device_count"]):
            device_props = torch.cuda.get_device_properties(i)
            gpu_info["cuda_devices"].append({
                "index": i,
                "name": device_props.name,
                "total_memory": device_props.total_memory,
                "compute_capability": f"{device_props.major}.{device_props.minor}"
            })

    return gpu_info


def get_optimal_device():
    """
    Get the optimal device for computation.

    Returns:
        torch.device: The optimal device (cuda, mps, or cpu)
    """
    if torch.cuda.is_available():
        # If multiple GPUs are available, select the one with the most memory
        if torch.cuda.device_count() > 1:
            max_memory = 0
            best_device = 0
            for i in range(torch.cuda.device_count()):
                device_props = torch.cuda.get_device_properties(i)
                if device_props.total_memory > max_memory:
                    max_memory = device_props.total_memory
                    best_device = i
            return torch.device(f"cuda:{best_device}")
        return torch.device("cuda:0")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    else:
        return torch.device("cpu")


def set_memory_limits(memory_fraction=0.8):
    """
    Set memory limits for GPU usage.

    Args:
        memory_fraction (float): Fraction of GPU memory to use (0.0 to 1.0)

    Returns:
        bool: True if successful, False otherwise
    """
    if not torch.cuda.is_available():
        return False

    try:
        # Import only if CUDA is available
        import torch.cuda

        # Set memory fraction for each device
        for i in range(torch.cuda.device_count()):
            torch.cuda.set_per_process_memory_fraction(memory_fraction, i)

        return True
    except Exception as e:
        logger.error(f"Error setting memory limits: {e}")
        return False


def optimize_for_inference():
    """
    Apply optimizations for inference.

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        # Set deterministic algorithms for reproducibility
        torch.backends.cudnn.deterministic = True

        # Enable cuDNN benchmark mode for optimized performance
        torch.backends.cudnn.benchmark = True

        # Disable gradient calculation for inference
        torch.set_grad_enabled(False)

        return True
    except Exception as e:
        logger.error(f"Error optimizing for inference: {e}")
        return False


def get_recommended_batch_size(model_size="base"):
    """
    Get recommended batch size based on available GPU memory.

    Args:
        model_size (str): Size of the model (tiny, base, small, medium, large)

    Returns:
        int: Recommended batch size
    """
    # Default batch sizes for CPU
    default_batch_sizes = {
        "tiny": 16,
        "base": 8,
        "small": 4,
        "medium": 2,
        "large": 1
    }

    # If CUDA is not available, return default CPU batch size
    if not torch.cuda.is_available():
        return default_batch_sizes.get(model_size, 1)

    # Approximate memory requirements in GB for different model sizes
    memory_requirements = {
        "tiny": 1,
        "base": 2,
        "small": 4,
        "medium": 8,
        "large": 16
    }

    # Get available GPU memory
    device = get_optimal_device()
    if device.type == "cuda":
        device_idx = device.index
        device_props = torch.cuda.get_device_properties(device_idx)
        available_memory_gb = device_props.total_memory / (1024 ** 3)

        # Calculate batch size based on available memory
        model_memory = memory_requirements.get(model_size, 2)
        max_batch_size = int(available_memory_gb / model_memory)

        # Ensure batch size is at least 1
        return max(1, max_batch_size)

    # For MPS or other devices, return default
    return default_batch_sizes.get(model_size, 1)


def configure_gpu(model_size="base", memory_fraction=0.8):
    """
    Configure GPU settings for optimal performance.

    Args:
        model_size (str): Size of the model (tiny, base, small, medium, large)
        memory_fraction (float): Fraction of GPU memory to use (0.0 to 1.0)

    Returns:
        dict: Configuration information
    """
    gpu_info = get_gpu_info()
    device = get_optimal_device()

    # Set memory limits if using CUDA
    if device.type == "cuda":
        set_memory_limits(memory_fraction)

    # Apply inference optimizations
    optimize_for_inference()

    # Get recommended batch size
    batch_size = get_recommended_batch_size(model_size)

    config = {
        "device": device,
        "batch_size": batch_size,
        "gpu_info": gpu_info,
        "memory_fraction": memory_fraction if device.type == "cuda" else None
    }

    logger.info(f"GPU configuration: Using {device} with batch size {batch_size}")
    return config
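A sketch of how the rest of the codebase is expected to consume this module:

from utils.gpu_utils import configure_gpu

# Picks the best device (CUDA > MPS > CPU), caps CUDA memory use, applies
# inference-mode settings, and suggests a batch size for the chosen model.
config = configure_gpu(model_size="base", memory_fraction=0.8)
print(config["device"], config["batch_size"])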
utils/keyword_extraction.py (new file, 325 lines)
@@ -0,0 +1,325 @@
"""
Keyword extraction utilities for the OBS Recording Transcriber.
Provides functions to extract keywords and link them to timestamps.
"""

import logging
import re
import torch
import numpy as np
from pathlib import Path
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
from sklearn.feature_extraction.text import TfidfVectorizer
from collections import Counter

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Try to import GPU utilities, but don't fail if not available
try:
    from utils.gpu_utils import get_optimal_device
    GPU_UTILS_AVAILABLE = True
except ImportError:
    GPU_UTILS_AVAILABLE = False

# Default models
NER_MODEL = "dslim/bert-base-NER"


def extract_keywords_tfidf(text, max_keywords=10, ngram_range=(1, 2)):
    """
    Extract keywords using TF-IDF.

    Args:
        text (str): Text to extract keywords from
        max_keywords (int): Maximum number of keywords to extract
        ngram_range (tuple): Range of n-grams to consider

    Returns:
        list: List of (keyword, score) tuples
    """
    try:
        # Preprocess text
        text = text.lower()

        # Remove common stopwords
        stopwords = {'a', 'an', 'the', 'and', 'or', 'but', 'if', 'because', 'as', 'what',
                     'when', 'where', 'how', 'who', 'which', 'this', 'that', 'these', 'those',
                     'then', 'just', 'so', 'than', 'such', 'both', 'through', 'about', 'for',
                     'is', 'of', 'while', 'during', 'to', 'from', 'in', 'out', 'on', 'off', 'by'}

        # Create sentences for better TF-IDF analysis
        sentences = re.split(r'[.!?]', text)
        sentences = [s.strip() for s in sentences if s.strip()]

        if not sentences:
            return []

        # Apply TF-IDF
        vectorizer = TfidfVectorizer(
            max_features=100,
            stop_words=stopwords,
            ngram_range=ngram_range
        )

        try:
            tfidf_matrix = vectorizer.fit_transform(sentences)
            feature_names = vectorizer.get_feature_names_out()

            # Calculate average TF-IDF score across all sentences
            avg_tfidf = np.mean(tfidf_matrix.toarray(), axis=0)

            # Get top keywords
            keywords = [(feature_names[i], avg_tfidf[i]) for i in avg_tfidf.argsort()[::-1]]

            # Filter out single-character keywords and limit to max_keywords
            keywords = [(k, s) for k, s in keywords if len(k) > 1][:max_keywords]

            return keywords
        except ValueError as e:
            logger.warning(f"TF-IDF extraction failed: {e}")
            return []

    except Exception as e:
        logger.error(f"Error extracting keywords with TF-IDF: {e}")
        return []


def extract_named_entities(text, model=NER_MODEL, use_gpu=True):
    """
    Extract named entities from text.

    Args:
        text (str): Text to extract entities from
        model (str): Model to use for NER
        use_gpu (bool): Whether to use GPU acceleration if available

    Returns:
        list: List of (entity, type) tuples
    """
    # Configure device
    device = torch.device("cpu")
    if use_gpu and GPU_UTILS_AVAILABLE:
        device = get_optimal_device()
        device_arg = 0 if device.type == "cuda" else -1
    else:
        device_arg = -1

    try:
        # Initialize the pipeline
        ner_pipeline = pipeline("ner", model=model, device=device_arg, aggregation_strategy="simple")

        # Split text into manageable chunks if too long
        max_length = 512
        if len(text) > max_length:
            chunks = [text[i:i+max_length] for i in range(0, len(text), max_length)]
        else:
            chunks = [text]

        # Process each chunk
        all_entities = []
        for chunk in chunks:
            entities = ner_pipeline(chunk)
            all_entities.extend(entities)

        # Extract entity text and type
        entity_info = [(entity["word"], entity["entity_group"]) for entity in all_entities]

        return entity_info
    except Exception as e:
        logger.error(f"Error extracting named entities: {e}")
        return []


def find_keyword_timestamps(segments, keywords):
    """
    Find timestamps for keywords in transcript segments.

    Args:
        segments (list): List of transcript segments with timing info
        keywords (list): List of keywords to find

    Returns:
        dict: Dictionary mapping keywords to lists of timestamps
    """
    keyword_timestamps = {}

    # Convert keywords to lowercase for case-insensitive matching
    if isinstance(keywords[0], tuple):
        # If keywords is a list of (keyword, score) tuples
        keywords_lower = [k.lower() for k, _ in keywords]
    else:
        # If keywords is just a list of keywords
        keywords_lower = [k.lower() for k in keywords]

    # Process each segment
    for segment in segments:
        segment_text = segment["text"].lower()
        start_time = segment["start"]
        end_time = segment["end"]

        # Check each keyword
        for i, keyword in enumerate(keywords_lower):
            if keyword in segment_text:
                # Get the original case of the keyword
                original_keyword = keywords[i][0] if isinstance(keywords[0], tuple) else keywords[i]

                # Initialize the list if this is the first occurrence
                if original_keyword not in keyword_timestamps:
                    keyword_timestamps[original_keyword] = []

                # Add the timestamp
                keyword_timestamps[original_keyword].append({
                    "start": start_time,
                    "end": end_time,
                    "context": segment["text"]
                })

    return keyword_timestamps


def extract_keywords_from_transcript(transcript, segments, max_keywords=15, use_gpu=True):
    """
    Extract keywords from transcript and link them to timestamps.

    Args:
        transcript (str): Full transcript text
        segments (list): List of transcript segments with timing info
        max_keywords (int): Maximum number of keywords to extract
        use_gpu (bool): Whether to use GPU acceleration if available

    Returns:
        tuple: (keyword_timestamps, entities_with_timestamps)
    """
    try:
        # Extract keywords using TF-IDF
        tfidf_keywords = extract_keywords_tfidf(transcript, max_keywords=max_keywords)

        # Extract named entities
        entities = extract_named_entities(transcript, use_gpu=use_gpu)

        # Count entity occurrences and get the most frequent ones
        entity_counter = Counter([entity for entity, _ in entities])
        top_entities = [(entity, count) for entity, count in entity_counter.most_common(max_keywords)]

        # Find timestamps for keywords and entities
        keyword_timestamps = find_keyword_timestamps(segments, tfidf_keywords)
        entity_timestamps = find_keyword_timestamps(segments, top_entities)

        return keyword_timestamps, entity_timestamps

    except Exception as e:
        logger.error(f"Error extracting keywords from transcript: {e}")
        return {}, {}


def generate_keyword_index(keyword_timestamps, entity_timestamps=None):
    """
    Generate a keyword index with timestamps.

    Args:
        keyword_timestamps (dict): Dictionary mapping keywords to timestamp lists
        entity_timestamps (dict, optional): Dictionary mapping entities to timestamp lists

    Returns:
        str: Formatted keyword index
    """
    lines = ["# Keyword Index\n"]

    # Add keywords section
    if keyword_timestamps:
        lines.append("## Keywords\n")
        for keyword, timestamps in sorted(keyword_timestamps.items()):
            if timestamps:
                times = [f"{int(ts['start'] // 60):02d}:{int(ts['start'] % 60):02d}" for ts in timestamps]
                lines.append(f"- **{keyword}**: {', '.join(times)}\n")

    # Add entities section
    if entity_timestamps:
        lines.append("\n## Named Entities\n")
        for entity, timestamps in sorted(entity_timestamps.items()):
            if timestamps:
                times = [f"{int(ts['start'] // 60):02d}:{int(ts['start'] % 60):02d}" for ts in timestamps]
                lines.append(f"- **{entity}**: {', '.join(times)}\n")

    return "".join(lines)


def generate_interactive_transcript(segments, keyword_timestamps=None, entity_timestamps=None):
    """
    Generate an interactive transcript with keyword highlighting.

    Args:
        segments (list): List of transcript segments with timing info
        keyword_timestamps (dict, optional): Dictionary mapping keywords to timestamp lists
        entity_timestamps (dict, optional): Dictionary mapping entities to timestamp lists

    Returns:
        str: HTML formatted interactive transcript
    """
    # Combine keywords and entities
    all_keywords = {}
    if keyword_timestamps:
        all_keywords.update(keyword_timestamps)
    if entity_timestamps:
        all_keywords.update(entity_timestamps)

    # Generate HTML
    html = ["<div class='interactive-transcript'>"]

    for segment in segments:
        start_time = segment["start"]
        end_time = segment["end"]
        text = segment["text"]

        # Format timestamp
        timestamp = f"{int(start_time // 60):02d}:{int(start_time % 60):02d}"

        # Add speaker if available
        speaker = segment.get("speaker", "")
        speaker_html = f"<span class='speaker'>[{speaker}]</span> " if speaker else ""

        # Highlight keywords in text
        highlighted_text = text
        for keyword in all_keywords:
            # Use regex to match whole words only
            pattern = r'\b' + re.escape(keyword) + r'\b'
            replacement = f"<span class='keyword' data-keyword='{keyword}'>{keyword}</span>"
            highlighted_text = re.sub(pattern, replacement, highlighted_text, flags=re.IGNORECASE)

        # Add segment to HTML
        html.append(f"<p class='segment' data-start='{start_time}' data-end='{end_time}'>")
        html.append(f"<span class='timestamp'>{timestamp}</span> {speaker_html}{highlighted_text}")
        html.append("</p>")

    html.append("</div>")

    return "\n".join(html)


def create_keyword_cloud_data(keyword_timestamps, entity_timestamps=None):
    """
    Create data for a keyword cloud visualization.

    Args:
        keyword_timestamps (dict): Dictionary mapping keywords to timestamp lists
        entity_timestamps (dict, optional): Dictionary mapping entities to timestamp lists

    Returns:
        list: List of (keyword, weight) tuples for visualization
    """
    cloud_data = []

    # Process keywords
    for keyword, timestamps in keyword_timestamps.items():
        weight = len(timestamps)  # Weight by occurrence count
        cloud_data.append((keyword, weight))

    # Process entities if provided
    if entity_timestamps:
        for entity, timestamps in entity_timestamps.items():
            weight = len(timestamps) * 1.5  # Give entities slightly higher weight
            cloud_data.append((entity, weight))

    return cloud_data
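A minimal sketch of the keyword pipeline; the transcript and segments are hypothetical, and the NER model is downloaded from HuggingFace on first use:

from utils.keyword_extraction import extract_keywords_from_transcript, generate_keyword_index

transcript = "OBS Studio records the screen. OBS Studio also captures audio."
segments = [{"start": 0.0, "end": 4.0, "text": "OBS Studio records the screen."},
            {"start": 4.0, "end": 8.0, "text": "OBS Studio also captures audio."}]

keywords, entities = extract_keywords_from_transcript(transcript, segments, max_keywords=5)
print(generate_keyword_index(keywords, entities))  # Markdown index with MM:SS timestamps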
utils/ollama_integration.py (new file, 155 lines)
@@ -0,0 +1,155 @@
"""
Ollama integration for local AI model inference.
Provides functions to use Ollama's API for text summarization.
"""

import requests
import json
import logging
from pathlib import Path
import os

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default Ollama API endpoint
OLLAMA_API_URL = "http://localhost:11434/api"


def check_ollama_available():
    """
    Check if Ollama service is available.

    Returns:
        bool: True if Ollama is available, False otherwise
    """
    try:
        response = requests.get(f"{OLLAMA_API_URL}/tags", timeout=2)
        return response.status_code == 200
    except requests.exceptions.RequestException:
        return False


def list_available_models():
    """
    List available models in Ollama.

    Returns:
        list: List of available model names
    """
    try:
        response = requests.get(f"{OLLAMA_API_URL}/tags")
        if response.status_code == 200:
            models = response.json().get('models', [])
            return [model['name'] for model in models]
        return []
    except requests.exceptions.RequestException as e:
        logger.error(f"Error listing Ollama models: {e}")
        return []


def summarize_with_ollama(text, model="llama3", max_length=150):
    """
    Summarize text using Ollama's local API.

    Args:
        text (str): Text to summarize
        model (str): Ollama model to use
        max_length (int): Maximum length of the summary

    Returns:
        str: Summarized text or None if failed
    """
    if not check_ollama_available():
        logger.warning("Ollama service is not available")
        return None

    # Check if the model is available
    available_models = list_available_models()
    if model not in available_models:
        logger.warning(f"Model {model} not available in Ollama. Available models: {available_models}")
        return None

    # Prepare the prompt for summarization
    prompt = f"Summarize the following text in about {max_length} words:\n\n{text}"

    try:
        # Make the API request
        response = requests.post(
            f"{OLLAMA_API_URL}/generate",
            json={
                "model": model,
                "prompt": prompt,
                "stream": False,
                "options": {
                    "temperature": 0.3,
                    "top_p": 0.9,
                    "max_tokens": max_length * 2  # Approximate token count
                }
            }
        )

        if response.status_code == 200:
            result = response.json()
            return result.get('response', '').strip()
        else:
            logger.error(f"Ollama API error: {response.status_code} - {response.text}")
            return None
    except requests.exceptions.RequestException as e:
        logger.error(f"Error communicating with Ollama: {e}")
        return None


def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150):
    """
    Chunk long text and summarize each chunk, then combine the summaries.

    Args:
        text (str): Text to summarize
        model (str): Ollama model to use
        chunk_size (int): Maximum size of each chunk in characters
        max_length (int): Maximum length of the final summary

    Returns:
        str: Combined summary or None if failed
    """
    if len(text) <= chunk_size:
        return summarize_with_ollama(text, model, max_length)

    # Split text into chunks
    words = text.split()
    chunks = []
    current_chunk = []
    current_length = 0

    for word in words:
        if current_length + len(word) + 1 <= chunk_size:
            current_chunk.append(word)
            current_length += len(word) + 1
        else:
            chunks.append(' '.join(current_chunk))
            current_chunk = [word]
            current_length = len(word) + 1

    if current_chunk:
        chunks.append(' '.join(current_chunk))

    # Summarize each chunk
    chunk_summaries = []
    for i, chunk in enumerate(chunks):
        logger.info(f"Summarizing chunk {i+1}/{len(chunks)}")
        summary = summarize_with_ollama(chunk, model, max_length // len(chunks))
        if summary:
            chunk_summaries.append(summary)

    if not chunk_summaries:
        return None

    # If there's only one chunk summary, return it
    if len(chunk_summaries) == 1:
        return chunk_summaries[0]

    # Otherwise, combine the summaries and summarize again
    combined_summary = " ".join(chunk_summaries)
    return summarize_with_ollama(combined_summary, model, max_length)
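A usage sketch, assuming a local Ollama server is running with a "llama3" model already pulled:

from utils.ollama_integration import check_ollama_available, chunk_and_summarize

long_text = "Some transcript sentence. " * 500  # Hypothetical stand-in for a real transcript.
if check_ollama_available():
    summary = chunk_and_summarize(long_text, model="llama3", chunk_size=4000, max_length=150)
    print(summary)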
@@ -1,22 +1,114 @@
 import whisper
 from pathlib import Path
 from transformers import pipeline, AutoTokenizer
 from utils.audio_processing import extract_audio
 from utils.summarization import summarize_text
+import logging
+import torch
+
+# Try to import GPU utilities, but don't fail if not available
+try:
+    from utils.gpu_utils import configure_gpu, get_optimal_device
+    GPU_UTILS_AVAILABLE = True
+except ImportError:
+    GPU_UTILS_AVAILABLE = False
+
+# Try to import caching utilities, but don't fail if not available
+try:
+    from utils.cache import load_from_cache, save_to_cache
+    CACHE_AVAILABLE = True
+except ImportError:
+    CACHE_AVAILABLE = False
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 WHISPER_MODEL = "base"
 SUMMARIZATION_MODEL = "t5-base"

-def transcribe_audio(audio_path: Path):
-    """Transcribe audio using Whisper."""
-    model = whisper.load_model(WHISPER_MODEL)
-    result = model.transcribe(str(audio_path))
+def transcribe_audio(audio_path: Path, model=WHISPER_MODEL, use_cache=True, cache_max_age=None,
+                     use_gpu=True, memory_fraction=0.8):
+    """
+    Transcribe audio using Whisper and return both segments and full transcript.
+
+    Args:
+        audio_path (Path): Path to the audio or video file
+        model (str): Whisper model size to use (tiny, base, small, medium, large)
+        use_cache (bool): Whether to use caching
+        cache_max_age (float, optional): Maximum age of cache in seconds
+        use_gpu (bool): Whether to use GPU acceleration if available
+        memory_fraction (float): Fraction of GPU memory to use (0.0 to 1.0)
+
+    Returns:
+        tuple: (segments, transcript) where segments is a list of dicts with timing info
+    """
+    audio_path = Path(audio_path)
+
+    # Check cache first if enabled
+    if use_cache and CACHE_AVAILABLE:
+        cached_data = load_from_cache(audio_path, model, "transcribe", cache_max_age)
+        if cached_data:
+            logger.info(f"Using cached transcription for {audio_path}")
+            return cached_data.get("segments", []), cached_data.get("transcript", "")
+
+    # Extract audio if the input is a video file
+    if audio_path.suffix.lower() in ['.mp4', '.avi', '.mov', '.mkv']:
+        audio_path = extract_audio(audio_path)
+
+    # Configure GPU if available and requested
+    device = torch.device("cpu")
+    if use_gpu and GPU_UTILS_AVAILABLE:
+        gpu_config = configure_gpu(model, memory_fraction)
+        device = gpu_config["device"]
+        logger.info(f"Using device: {device} for transcription")
+
+    # Load the specified Whisper model
+    logger.info(f"Loading Whisper model: {model}")
+    whisper_model = whisper.load_model(model, device=device if device.type != "mps" else "cpu")
+
+    # Transcribe the audio
+    logger.info(f"Transcribing audio: {audio_path}")
+    result = whisper_model.transcribe(str(audio_path))
+
+    # Extract the full transcript and segments
     transcript = result["text"]
-    summary = summarize_text(transcript)
-    return transcript, summary
+    segments = result["segments"]
+
+    # Cache the results if caching is enabled
+    if use_cache and CACHE_AVAILABLE:
+        cache_data = {
+            "transcript": transcript,
+            "segments": segments
+        }
+        save_to_cache(audio_path, cache_data, model, "transcribe")
+
+    return segments, transcript

-def summarize_text(text):
-    """Summarize text using a pre-trained T5 transformer model with chunking."""
-    summarization_pipeline = pipeline("summarization", model=SUMMARIZATION_MODEL)
-    tokenizer = AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL)
+def summarize_text(text, model=SUMMARIZATION_MODEL, use_gpu=True, memory_fraction=0.8):
+    """
+    Summarize text using a pre-trained transformer model with chunking.
+
+    Args:
+        text (str): Text to summarize
+        model (str): Model to use for summarization
+        use_gpu (bool): Whether to use GPU acceleration if available
+        memory_fraction (float): Fraction of GPU memory to use (0.0 to 1.0)
+
+    Returns:
+        str: Summarized text
+    """
+    # Configure device
+    device = torch.device("cpu")
+    if use_gpu and GPU_UTILS_AVAILABLE:
+        device = get_optimal_device()
+        logger.info(f"Using device: {device} for summarization")
+
+    # Initialize the pipeline with the specified device
+    device_arg = -1 if device.type == "cpu" else 0  # -1 for CPU, 0 for GPU
+    summarization_pipeline = pipeline("summarization", model=model, device=device_arg)
+    tokenizer = AutoTokenizer.from_pretrained(model)

     max_tokens = 512

@@ -24,20 +116,57 @@ def summarize_text(text):
     num_tokens = len(tokens['input_ids'][0])

     if num_tokens > max_tokens:
-        chunks = chunk_text(text, max_tokens)
+        chunks = chunk_text(text, max_tokens, tokenizer)
         summaries = []
-        for chunk in chunks:
-            summary_output = summarization_pipeline("summarize: " + chunk, max_length=150, min_length=30, do_sample=False)
+
+        for i, chunk in enumerate(chunks):
+            logger.info(f"Summarizing chunk {i+1}/{len(chunks)}")
+            summary_output = summarization_pipeline(
+                "summarize: " + chunk,
+                max_length=150,
+                min_length=30,
+                do_sample=False
+            )
             summaries.append(summary_output[0]['summary_text'])

         overall_summary = " ".join(summaries)
+
+        # If the combined summary is still long, summarize it again
+        if len(summaries) > 1:
+            logger.info("Generating final summary from chunk summaries")
+            combined_text = " ".join(summaries)
+            overall_summary = summarization_pipeline(
+                "summarize: " + combined_text,
+                max_length=150,
+                min_length=30,
+                do_sample=False
+            )[0]['summary_text']
     else:
-        overall_summary = summarization_pipeline("summarize: " + text, max_length=150, min_length=30, do_sample=False)[0]['summary_text']
+        overall_summary = summarization_pipeline(
+            "summarize: " + text,
+            max_length=150,
+            min_length=30,
+            do_sample=False
+        )[0]['summary_text']

     return overall_summary

-def chunk_text(text, max_tokens):
-    """Splits the text into a list of chunks based on token limits."""
-    tokenizer = AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL)
+
+def chunk_text(text, max_tokens, tokenizer=None):
+    """
+    Splits the text into a list of chunks based on token limits.
+
+    Args:
+        text (str): Text to chunk
+        max_tokens (int): Maximum tokens per chunk
+        tokenizer (AutoTokenizer, optional): Tokenizer to use
+
+    Returns:
+        list: List of text chunks
+    """
+    if tokenizer is None:
+        tokenizer = AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL)
+
     words = text.split()

     chunks = []
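A sketch of the updated call signatures above; note that transcribe_audio now returns (segments, transcript) rather than (transcript, summary). The modified module's path is not shown in this hunk, so the import is hypothetical:

from pathlib import Path
# from utils.transcription import transcribe_audio, summarize_text  # hypothetical module path

segments, transcript = transcribe_audio(Path("recording.mp4"), model="base", use_cache=True)
summary = summarize_text(transcript, use_gpu=True)
print(summary)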
utils/translation.py (new file, 283 lines)
@ -0,0 +1,283 @@
|
||||
"""
|
||||
Translation utilities for the OBS Recording Transcriber.
|
||||
Provides functions for language detection and translation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
from pathlib import Path
|
||||
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, M2M100ForConditionalGeneration
|
||||
import whisper
|
||||
import iso639
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Try to import GPU utilities, but don't fail if not available
|
||||
try:
|
||||
from utils.gpu_utils import get_optimal_device
|
||||
GPU_UTILS_AVAILABLE = True
|
||||
except ImportError:
|
||||
GPU_UTILS_AVAILABLE = False
|
||||
|
||||
# Default models
|
||||
TRANSLATION_MODEL = "facebook/m2m100_418M"
|
||||
LANGUAGE_DETECTION_MODEL = "papluca/xlm-roberta-base-language-detection"
|
||||
|
||||
# ISO language code mapping
|
||||
def get_language_name(code):
|
||||
"""
|
||||
Get the language name from ISO code.
|
||||
|
||||
Args:
|
||||
code (str): ISO language code
|
||||
|
||||
Returns:
|
||||
str: Language name or original code if not found
|
||||
"""
|
||||
try:
|
||||
return iso639.languages.get(part1=code).name
|
||||
except (KeyError, AttributeError):
|
||||
try:
|
||||
return iso639.languages.get(part2b=code).name
|
||||
except (KeyError, AttributeError):
|
||||
return code
|
||||


def detect_language(text, model=LANGUAGE_DETECTION_MODEL, use_gpu=True):
    """
    Detect the language of a text.

    Args:
        text (str): Text to detect language for
        model (str): Model to use for language detection
        use_gpu (bool): Whether to use GPU acceleration if available

    Returns:
        tuple: (language_code, confidence)
    """
    # Configure device
    device = torch.device("cpu")
    if use_gpu and GPU_UTILS_AVAILABLE:
        device = get_optimal_device()
        device_arg = 0 if device.type == "cuda" else -1
    else:
        device_arg = -1

    try:
        # Initialize the pipeline
        classifier = pipeline("text-classification", model=model, device=device_arg)

        # Truncate text if too long
        max_length = 512
        if len(text) > max_length:
            text = text[:max_length]

        # Detect language
        result = classifier(text)[0]
        language_code = result["label"]
        confidence = result["score"]

        return language_code, confidence
    except Exception as e:
        logger.error(f"Error detecting language: {e}")
        return None, 0.0
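A minimal usage sketch for the detector (illustrative only; the model is downloaded on first use, and the sample text and output are assumptions rather than a recorded run):

# Hypothetical usage, not part of the committed file:
code, score = detect_language("Bonjour tout le monde", use_gpu=False)
if code is not None:
    print(f"{get_language_name(code)} ({code}), confidence {score:.2f}")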


def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_MODEL, use_gpu=True):
    """
    Translate text from source language to target language.

    Args:
        text (str): Text to translate
        source_lang (str, optional): Source language code (auto-detect if None)
        target_lang (str): Target language code
        model (str): Model to use for translation
        use_gpu (bool): Whether to use GPU acceleration if available

    Returns:
        str: Translated text
    """
    # Auto-detect source language if not provided
    if source_lang is None:
        detected_lang, confidence = detect_language(text, use_gpu=use_gpu)
        if detected_lang and confidence > 0.5:
            source_lang = detected_lang
            logger.info(f"Detected language: {get_language_name(source_lang)} ({source_lang}) with confidence {confidence:.2f}")
        else:
            logger.warning("Could not reliably detect language, defaulting to English")
            source_lang = "en"

    # Skip translation if source and target are the same
    if source_lang == target_lang:
        logger.info(f"Source and target languages are the same ({source_lang}), skipping translation")
        return text

    # Configure device
    device = torch.device("cpu")
    if use_gpu and GPU_UTILS_AVAILABLE:
        device = get_optimal_device()

    try:
        # Load model and tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model)
        model = M2M100ForConditionalGeneration.from_pretrained(model)

        # Move model to device
        model = model.to(device)

        # Prepare for translation
        tokenizer.src_lang = source_lang

        # Split text into manageable chunks if too long
        max_length = 512
        if len(text) > max_length:
            chunks = [text[i:i+max_length] for i in range(0, len(text), max_length)]
        else:
            chunks = [text]

        # Translate each chunk
        translated_chunks = []
        for chunk in chunks:
            encoded = tokenizer(chunk, return_tensors="pt").to(device)
            generated_tokens = model.generate(
                **encoded,
                forced_bos_token_id=tokenizer.get_lang_id(target_lang),
                max_length=max_length
            )
            translated_chunk = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
            translated_chunks.append(translated_chunk)

        # Combine translated chunks
        translated_text = " ".join(translated_chunks)

        return translated_text
    except Exception as e:
        logger.error(f"Error translating text: {e}")
        return text
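A usage sketch for both calling modes (the sample strings and outputs are illustrative assumptions). Note that the 512-character chunking above splits on raw character offsets, so very long inputs may be cut mid-word before translation:

# Hypothetical usage, not part of the committed file:
print(translate_text("Bonjour tout le monde", source_lang="fr", target_lang="en"))
print(translate_text("Hola mundo"))  # source language auto-detected via detect_language()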


def translate_segments(segments, source_lang=None, target_lang="en", use_gpu=True):
    """
    Translate transcript segments.

    Args:
        segments (list): List of transcript segments
        source_lang (str, optional): Source language code (auto-detect if None)
        target_lang (str): Target language code
        use_gpu (bool): Whether to use GPU acceleration if available

    Returns:
        list: Translated segments
    """
    if not segments:
        return []

    # Auto-detect source language from combined text if not provided
    if source_lang is None:
        combined_text = " ".join([segment["text"] for segment in segments])
        detected_lang, _ = detect_language(combined_text, use_gpu=use_gpu)
        source_lang = detected_lang if detected_lang else "en"

    # Skip translation if source and target are the same
    if source_lang == target_lang:
        return segments

    try:
        # Collect translated segments
        translated_segments = []

        # Translate each segment
        for segment in segments:
            translated_text = translate_text(
                segment["text"],
                source_lang=source_lang,
                target_lang=target_lang,
                use_gpu=use_gpu
            )

            # Create a new segment with translated text
            translated_segment = segment.copy()
            translated_segment["text"] = translated_text
            translated_segment["original_text"] = segment["text"]
            translated_segment["source_lang"] = source_lang
            translated_segment["target_lang"] = target_lang

            translated_segments.append(translated_segment)

        return translated_segments
    except Exception as e:
        logger.error(f"Error translating segments: {e}")
        return segments
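Each returned segment keeps its original keys (e.g. start/end timestamps from Whisper) and gains provenance fields, which callers can use to render bilingual transcripts. A sketch, assuming Whisper-style segment dicts:

# Hypothetical usage, not part of the committed file:
segments = [{"start": 0.0, "end": 2.5, "text": "Bonjour tout le monde"}]
out = translate_segments(segments, target_lang="en", use_gpu=False)
# out[0]["text"]          -> translated text
# out[0]["original_text"] -> "Bonjour tout le monde"
# out[0]["source_lang"]   -> detected code, e.g. "fr"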


def transcribe_and_translate(audio_path, whisper_model="base", target_lang="en",
                             use_gpu=True, detect_source=True):
    """
    Transcribe audio and translate to target language.

    Args:
        audio_path (Path): Path to the audio file
        whisper_model (str): Whisper model size to use
        target_lang (str): Target language code
        use_gpu (bool): Whether to use GPU acceleration if available
        detect_source (bool): Whether to auto-detect source language

    Returns:
        tuple: (original_segments, translated_segments, original_transcript, translated_transcript)
    """
    audio_path = Path(audio_path)

    # Configure device
    device = torch.device("cpu")
    if use_gpu and GPU_UTILS_AVAILABLE:
        device = get_optimal_device()

    try:
        # Step 1: Transcribe audio with Whisper
        logger.info(f"Transcribing audio with Whisper model: {whisper_model}")
        model = whisper.load_model(whisper_model, device=device if device.type != "mps" else "cpu")

        # Use Whisper's built-in language detection if requested
        if detect_source:
            # First, detect language with Whisper
            audio = whisper.load_audio(str(audio_path))
            audio = whisper.pad_or_trim(audio)
            mel = whisper.log_mel_spectrogram(audio).to(device if device.type != "mps" else "cpu")
            _, probs = model.detect_language(mel)
            source_lang = max(probs, key=probs.get)
            logger.info(f"Whisper detected language: {get_language_name(source_lang)} ({source_lang})")

            # Transcribe with detected language
            result = model.transcribe(str(audio_path), language=source_lang)
        else:
            # Transcribe without language specification
            result = model.transcribe(str(audio_path))
            source_lang = result.get("language", "en")

        original_segments = result["segments"]
        original_transcript = result["text"]

        # Step 2: Translate if needed
        if source_lang != target_lang:
            logger.info(f"Translating from {source_lang} to {target_lang}")
            translated_segments = translate_segments(
                original_segments,
                source_lang=source_lang,
                target_lang=target_lang,
                use_gpu=use_gpu
            )

            # Create full translated transcript
            translated_transcript = " ".join([segment["text"] for segment in translated_segments])
        else:
            logger.info(f"Source and target languages are the same ({source_lang}), skipping translation")
            translated_segments = original_segments
            translated_transcript = original_transcript

        return original_segments, translated_segments, original_transcript, translated_transcript

    except Exception as e:
        logger.error(f"Error in transcribe_and_translate: {e}")
        return None, None, None, None
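An end-to-end sketch of the two-step pipeline ("recording.mp4" is a placeholder path, not a file from this repository):

# Hypothetical usage, not part of the committed file:
orig_segs, trans_segs, orig_text, trans_text = transcribe_and_translate(
    "recording.mp4",
    whisper_model="base",
    target_lang="en",
)
if orig_text is not None:  # all four values are None on failure
    print(trans_text)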