feat: Add streaming Ollama support, model caching, and UI improvements

- Add streaming summarization via Ollama API (stream_summarize_with_ollama)

- Cache ML models with @st.cache_resource (diarization, NER, translation, Whisper)

- Add temp file cleanup for extracted audio

- Add system capabilities detection (FFmpeg, GPU info)

- Add get_video_duration utility

- Improve validation with FFmpeg check

- Rewrite app.py with streaming support and UI enhancements

- Clean up redundant comments and unused imports across all utils
Your Name
2026-02-18 10:26:09 -05:00
parent ce398ae1d4
commit 70c5d32413
10 changed files with 998 additions and 707 deletions

View File

@@ -1,19 +1,55 @@
from pathlib import Path
import tempfile
import os
import logging
# moviepy 2.x removed moviepy.editor; import directly from moviepy
try:
from moviepy import AudioFileClip
except ImportError:
# Fallback for moviepy 1.x
from moviepy.editor import AudioFileClip
logger = logging.getLogger(__name__)
_temp_audio_files = []
def extract_audio(video_path: Path):
"""Extract audio from a video file."""
"""Extract audio from a video file into a temp directory for automatic cleanup."""
try:
audio = AudioFileClip(str(video_path))
audio_path = video_path.parent / f"{video_path.stem}_audio.wav"
temp_dir = tempfile.mkdtemp(prefix="videotranscriber_")
audio_path = Path(temp_dir) / f"{video_path.stem}_audio.wav"
audio.write_audiofile(str(audio_path), logger=None)  # moviepy 2.x removed the verbose kwarg
audio.close()
_temp_audio_files.append(str(audio_path))
return audio_path
except Exception as e:
raise RuntimeError(f"Audio extraction failed: {e}")
def cleanup_temp_audio():
"""Remove all temporary audio files created during processing."""
cleaned = 0
for fpath in _temp_audio_files:
try:
if os.path.exists(fpath):
os.remove(fpath)
parent = os.path.dirname(fpath)
if os.path.isdir(parent) and not os.listdir(parent):
os.rmdir(parent)
cleaned += 1
except Exception as e:
logger.warning(f"Could not remove temp file {fpath}: {e}")
_temp_audio_files.clear()
return cleaned
def get_video_duration(video_path: Path):
"""Get duration of a video/audio file in seconds."""
try:
clip = AudioFileClip(str(video_path))
duration = clip.duration
clip.close()
return duration
except Exception:
return None
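
A minimal driver for the new temp-file lifecycle, as a sketch: the `recording.mp4` path is hypothetical and the `utils.audio_processing` module path is assumed from the repo layout.

# Illustrative usage: extract into the managed temp dir, process, clean up.
from pathlib import Path
from utils.audio_processing import extract_audio, cleanup_temp_audio, get_video_duration

video = Path("recording.mp4")  # hypothetical input file
duration = get_video_duration(video)
if duration is not None:
    print(f"Duration: {duration:.1f}s")
audio_path = extract_audio(video)  # written to a videotranscriber_* temp dir
try:
    pass  # transcribe / diarize audio_path here
finally:
    removed = cleanup_temp_audio()  # removes tracked files and empty temp dirs
    print(f"Removed {removed} temp file(s)")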

View File

@@ -1,5 +1,5 @@
"""
Speaker diarization utilities for the OBS Recording Transcriber.
Speaker diarization utilities for the Video Transcriber.
Provides functions to identify different speakers in audio recordings.
"""
@@ -11,22 +11,34 @@ import torch
from pyannote.audio import Pipeline
from pyannote.core import Segment
import whisper
import streamlit as st
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Try to import GPU utilities, but don't fail if not available
try:
from utils.gpu_utils import get_optimal_device
GPU_UTILS_AVAILABLE = True
except ImportError:
GPU_UTILS_AVAILABLE = False
# Default HuggingFace auth token environment variable
HF_TOKEN_ENV = "HF_TOKEN"
@st.cache_resource
def _load_diarization_pipeline(hf_token, device_str):
"""Load and cache the speaker diarization pipeline."""
logger.info(f"Loading diarization pipeline on {device_str}")
pipe = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.0",
use_auth_token=hf_token
)
device = torch.device(device_str)
if device.type == "cuda":
pipe = pipe.to(device)
return pipe
def get_diarization_pipeline(use_gpu=True, hf_token=None):
"""
Initialize the speaker diarization pipeline.
@@ -38,7 +50,6 @@ def get_diarization_pipeline(use_gpu=True, hf_token=None):
Returns:
Pipeline or None: Diarization pipeline if successful, None otherwise
"""
# Check if token is provided or in environment
if hf_token is None:
hf_token = os.environ.get(HF_TOKEN_ENV)
if hf_token is None:
@@ -46,23 +57,12 @@ def get_diarization_pipeline(use_gpu=True, hf_token=None):
return None
try:
# Configure device
device = torch.device("cpu")
if use_gpu and GPU_UTILS_AVAILABLE:
device = get_optimal_device()
logger.info(f"Using device: {device} for diarization")
# Initialize the pipeline
pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.0",
use_auth_token=hf_token
)
# Move to appropriate device
if device.type == "cuda":
pipeline = pipeline.to(torch.device(device))
return pipeline
return _load_diarization_pipeline(hf_token, str(device))
except Exception as e:
logger.error(f"Error initializing diarization pipeline: {e}")
return None
@@ -198,9 +198,9 @@ def transcribe_with_diarization(audio_path, whisper_model="base", num_speakers=N
device = get_optimal_device()
try:
# Step 1: Transcribe audio with Whisper
from utils.transcription import _load_whisper_model
logger.info(f"Transcribing audio with Whisper model: {whisper_model}")
model = whisper.load_model(whisper_model, device=device if device.type != "mps" else "cpu")
model = _load_whisper_model(whisper_model, str(device))
result = model.transcribe(str(audio_path))
transcript_segments = result["segments"]
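
A usage sketch for the cached pipeline, assuming `HF_TOKEN` is exported and the module lives at `utils/diarization.py`; the audio filename is hypothetical.

# Illustrative: @st.cache_resource keys on (hf_token, device_str), so
# Streamlit reruns reuse the loaded pyannote pipeline instead of reloading it.
from utils.diarization import get_diarization_pipeline

pipe = get_diarization_pipeline(use_gpu=True)  # reads HF_TOKEN from the environment
if pipe is not None:
    annotation = pipe("extracted_audio.wav")
    for turn, _, speaker in annotation.itertracks(yield_label=True):
        print(f"{turn.start:.1f}s-{turn.end:.1f}s: {speaker}")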

View File

@@ -1,12 +1,9 @@
"""
GPU utilities for the OBS Recording Transcriber.
GPU utilities for the Video Transcriber.
Provides functions to detect and configure GPU acceleration.
"""
import logging
import os
import platform
import subprocess
import torch
# Configure logging
@@ -68,8 +65,6 @@ def get_optimal_device():
def set_memory_limits(memory_fraction=0.8):
global torch
import torch
"""
Set memory limits for GPU usage.

View File

@@ -1,5 +1,5 @@
"""
Keyword extraction utilities for the OBS Recording Transcriber.
Keyword extraction utilities for the Video Transcriber.
Provides functions to extract keywords and link them to timestamps.
"""
@@ -8,25 +8,30 @@ import re
import torch
import numpy as np
from pathlib import Path
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from collections import Counter
import streamlit as st
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Try to import GPU utilities, but don't fail if not available
try:
from utils.gpu_utils import get_optimal_device
GPU_UTILS_AVAILABLE = True
except ImportError:
GPU_UTILS_AVAILABLE = False
# Default models
NER_MODEL = "dslim/bert-base-NER"
@st.cache_resource
def _load_ner_pipeline(model_name, device_int):
"""Load and cache the NER pipeline."""
logger.info(f"Loading NER model: {model_name}")
return pipeline("ner", model=model_name, device=device_int, aggregation_strategy="simple")
def extract_keywords_tfidf(text, max_keywords=10, ngram_range=(1, 2)):
"""
Extract keywords using TF-IDF.
@@ -107,8 +112,7 @@ def extract_named_entities(text, model=NER_MODEL, use_gpu=True):
device_arg = -1
try:
# Initialize the pipeline
ner_pipeline = pipeline("ner", model=model, device=device_arg, aggregation_strategy="simple")
ner_pipeline = _load_ner_pipeline(model, device_arg)
# Split text into manageable chunks if too long
max_length = 512
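
A quick sketch of the cached NER path; the `utils.keyword_extraction` module path is an assumption, and `use_gpu=False` forces `device=-1` so the example runs anywhere.

# Illustrative: the first call loads dslim/bert-base-NER through the cached
# loader; later calls in the same Streamlit session reuse the pipeline.
from utils.keyword_extraction import extract_named_entities  # assumed path

entities = extract_named_entities(
    "Ada Lovelace worked with Charles Babbage in London.",
    use_gpu=False,
)
print(entities)  # entity spans found by the cached pipeline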

View File

@@ -1,6 +1,6 @@
"""
Ollama integration for local AI model inference.
Provides functions to use Ollama's API for text summarization.
Provides functions to use Ollama's API for text summarization with streaming support.
"""
import requests
@@ -9,21 +9,14 @@ import logging
from pathlib import Path
import os
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Default Ollama API endpoint - configurable via environment variable
OLLAMA_API_URL = os.environ.get("OLLAMA_API_URL", "http://localhost:11434/api")
def check_ollama_available():
"""
Check if Ollama service is available.
Returns:
bool: True if Ollama is available, False otherwise
"""
"""Check if Ollama service is available."""
try:
response = requests.get(f"{OLLAMA_API_URL}/tags", timeout=2)
return response.status_code == 200
@@ -32,12 +25,7 @@ def check_ollama_available():
def list_available_models():
"""
List available models in Ollama.
Returns:
list: List of available model names
"""
"""List available models in Ollama."""
try:
response = requests.get(f"{OLLAMA_API_URL}/tags")
if response.status_code == 200:
@@ -50,32 +38,14 @@ def list_available_models():
def summarize_with_ollama(text, model="llama3", max_length=150):
"""
Summarize text using Ollama's local API.
Args:
text (str): Text to summarize
model (str): Ollama model to use
max_length (int): Maximum length of the summary
Returns:
str: Summarized text or None if failed
"""
"""Summarize text using Ollama's local API (non-streaming)."""
if not check_ollama_available():
logger.warning("Ollama service is not available")
return None
# Check if the model is available
available_models = list_available_models()
if model not in available_models:
logger.warning(f"Model {model} not available in Ollama. Available models: {available_models}")
return None
# Prepare the prompt for summarization
prompt = f"Summarize the following text in about {max_length} words:\n\n{text}"
try:
# Make the API request
response = requests.post(
f"{OLLAMA_API_URL}/generate",
json={
@@ -85,7 +55,7 @@ def summarize_with_ollama(text, model="llama3", max_length=150):
"options": {
"temperature": 0.3,
"top_p": 0.9,
"max_tokens": max_length * 2 # Approximate token count
"max_tokens": max_length * 2
}
}
)
@@ -101,23 +71,55 @@ def summarize_with_ollama(text, model="llama3", max_length=150):
return None
def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150):
def stream_summarize_with_ollama(text, model="llama3", max_length=150):
"""
Chunk long text and summarize each chunk, then combine the summaries.
Summarize text using Ollama with streaming. Yields tokens as they arrive.
Args:
text (str): Text to summarize
model (str): Ollama model to use
chunk_size (int): Maximum size of each chunk in characters
max_length (int): Maximum length of the final summary
Returns:
str: Combined summary or None if failed
Yields:
str: Individual response tokens
"""
if not check_ollama_available():
logger.warning("Ollama service is not available")
return
prompt = f"Summarize the following text in about {max_length} words:\n\n{text}"
try:
response = requests.post(
f"{OLLAMA_API_URL}/generate",
json={
"model": model,
"prompt": prompt,
"stream": True,
"options": {
"temperature": 0.3,
"top_p": 0.9,
"max_tokens": max_length * 2
}
},
stream=True
)
if response.status_code == 200:
for line in response.iter_lines():
if line:
data = json.loads(line)
token = data.get('response', '')
if token:
yield token
if data.get('done', False):
break
else:
logger.error(f"Ollama API error: {response.status_code}")
except requests.exceptions.RequestException as e:
logger.error(f"Error communicating with Ollama: {e}")
def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150):
"""Chunk long text and summarize each chunk, then combine."""
if len(text) <= chunk_size:
return summarize_with_ollama(text, model, max_length)
# Split text into chunks
words = text.split()
chunks = []
current_chunk = []
@@ -135,7 +137,6 @@ def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150):
if current_chunk:
chunks.append(' '.join(current_chunk))
# Summarize each chunk
chunk_summaries = []
for i, chunk in enumerate(chunks):
logger.info(f"Summarizing chunk {i+1}/{len(chunks)}")
@@ -146,10 +147,55 @@ def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150):
if not chunk_summaries:
return None
# If there's only one chunk summary, return it
if len(chunk_summaries) == 1:
return chunk_summaries[0]
# Otherwise, combine the summaries and summarize again
combined_summary = " ".join(chunk_summaries)
return summarize_with_ollama(combined_summary, model, max_length)
return summarize_with_ollama(combined_summary, model, max_length)
def stream_chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150):
"""
Chunk long text and stream the final summary.
Each chunk is summarized non-streaming; only the final combined summary streams.
Yields:
str: Tokens from the final summary
"""
if len(text) <= chunk_size:
yield from stream_summarize_with_ollama(text, model, max_length)
return
words = text.split()
chunks = []
current_chunk = []
current_length = 0
for word in words:
if current_length + len(word) + 1 <= chunk_size:
current_chunk.append(word)
current_length += len(word) + 1
else:
chunks.append(' '.join(current_chunk))
current_chunk = [word]
current_length = len(word) + 1
if current_chunk:
chunks.append(' '.join(current_chunk))
chunk_summaries = []
for i, chunk in enumerate(chunks):
logger.info(f"Summarizing chunk {i+1}/{len(chunks)}")
summary = summarize_with_ollama(chunk, model, max_length // len(chunks))
if summary:
chunk_summaries.append(summary)
if not chunk_summaries:
return
if len(chunk_summaries) == 1:
yield chunk_summaries[0]
return
combined_summary = " ".join(chunk_summaries)
yield from stream_summarize_with_ollama(combined_summary, model, max_length)
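
A consumption sketch for the new generators; the `utils.ollama_integration` module path is assumed and `transcript` is a placeholder string.

# Illustrative: print tokens as they arrive from Ollama.
from utils.ollama_integration import stream_summarize_with_ollama  # assumed path

transcript = "..."  # placeholder for the full transcript text
for token in stream_summarize_with_ollama(transcript, model="llama3", max_length=150):
    print(token, end="", flush=True)
print()

In the rewritten app the same generators can feed Streamlit directly, e.g. summary = st.write_stream(stream_chunk_and_summarize(transcript)), which renders tokens live and returns the concatenated text.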

View File

@@ -1,45 +1,49 @@
from transformers import pipeline, AutoTokenizer
import torch
import logging
import streamlit as st
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
SUMMARY_MODEL = "Falconsai/text_summarization"
@st.cache_resource
def _load_summarizer(device_int):
"""Load and cache the summarization pipeline."""
logger.info(f"Loading summarization model on device {device_int}")
return pipeline("summarization", model=SUMMARY_MODEL, device=device_int)
@st.cache_resource
def _load_summary_tokenizer():
"""Load and cache the summarization tokenizer."""
return AutoTokenizer.from_pretrained(SUMMARY_MODEL)
def chunk_text(text, max_tokens, tokenizer):
"""
Splits the text into a list of chunks based on token limits.
Args:
text (str): Text to chunk
max_tokens (int): Maximum tokens per chunk
tokenizer (AutoTokenizer): Tokenizer to use
Returns:
list: List of text chunks
Splits text by tokenizing once and slicing the token ids into windows.
Much faster than the old per-word tokenization approach.
"""
words = text.split()
all_ids = tokenizer(text, return_tensors='pt', truncation=False)['input_ids'][0]
content_ids = all_ids[1:-1] # strip BOS/EOS
usable_max = max_tokens - 2 # leave room for special tokens
chunks = []
current_chunk = []
current_length = 0
for word in words:
hypothetical_length = current_length + len(tokenizer(word, return_tensors='pt')['input_ids'][0]) - 2
if hypothetical_length <= max_tokens:
current_chunk.append(word)
current_length = hypothetical_length
else:
chunks.append(' '.join(current_chunk))
current_chunk = [word]
current_length = len(tokenizer(word, return_tensors='pt')['input_ids'][0]) - 2
if current_chunk:
chunks.append(' '.join(current_chunk))
for i in range(0, len(content_ids), usable_max):
chunk_ids = content_ids[i : i + usable_max]
decoded = tokenizer.decode(chunk_ids, skip_special_tokens=True).strip()
if decoded:
chunks.append(decoded)
if not chunks:
chunks.append(text)
return chunks
def summarize_text(text, use_gpu=True, memory_fraction=0.8):
"""
Summarize text using a Hugging Face pipeline with chunking support.
@@ -52,21 +56,17 @@ def summarize_text(text, use_gpu=True, memory_fraction=0.8):
Returns:
str: Summarized text
"""
# Determine device
device = -1 # Default to CPU
device = -1
if use_gpu and torch.cuda.is_available():
device = 0 # Use first GPU
if torch.cuda.is_available():
torch.cuda.set_per_process_memory_fraction(memory_fraction)
device = 0
torch.cuda.set_per_process_memory_fraction(memory_fraction)
logger.info(f"Using device {device} for summarization")
try:
# Initialize the pipeline and tokenizer
summarizer = pipeline("summarization", model=SUMMARY_MODEL, device=device)
tokenizer = AutoTokenizer.from_pretrained(SUMMARY_MODEL)
summarizer = _load_summarizer(device)
tokenizer = _load_summary_tokenizer()
# Check if text needs to be chunked
max_tokens = 512
tokens = tokenizer(text, return_tensors='pt')
num_tokens = len(tokens['input_ids'][0])
@@ -85,7 +85,6 @@ def summarize_text(text, use_gpu=True, memory_fraction=0.8):
)
summaries.append(summary_output[0]['summary_text'])
# If multiple chunks, summarize the combined summaries
if len(summaries) > 1:
logger.info("Generating final summary from chunk summaries")
combined_text = " ".join(summaries)
@@ -106,7 +105,6 @@ def summarize_text(text, use_gpu=True, memory_fraction=0.8):
except Exception as e:
logger.error(f"Error during summarization: {e}")
# Fallback to CPU if GPU fails
if device != -1:
logger.info("Falling back to CPU")
return summarize_text(text, use_gpu=False, memory_fraction=memory_fraction)
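
To sanity-check the faster chunker, a small sketch; decode/re-encode can drift a token or two at chunk boundaries, so counts are printed rather than asserted.

# Illustrative: chunk a long text and report per-chunk token counts.
from transformers import AutoTokenizer
from utils.summarization import chunk_text  # assumed path

tok = AutoTokenizer.from_pretrained("Falconsai/text_summarization")
long_text = "transcripts often repeat filler words " * 400
chunks = chunk_text(long_text, max_tokens=512, tokenizer=tok)
for i, c in enumerate(chunks):
    print(i, len(tok(c)["input_ids"]))  # should stay around the 512 budget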

View File

@@ -1,31 +1,36 @@
import whisper
from pathlib import Path
from transformers import pipeline, AutoTokenizer
from utils.audio_processing import extract_audio
from utils.summarization import summarize_text
import logging
import torch
import streamlit as st
# Try to import GPU utilities, but don't fail if not available
try:
from utils.gpu_utils import configure_gpu, get_optimal_device
GPU_UTILS_AVAILABLE = True
except ImportError:
GPU_UTILS_AVAILABLE = False
# Try to import caching utilities, but don't fail if not available
try:
from utils.cache import load_from_cache, save_to_cache
CACHE_AVAILABLE = True
except ImportError:
CACHE_AVAILABLE = False
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
WHISPER_MODEL = "base"
@st.cache_resource
def _load_whisper_model(model_name, device_str):
"""Load and cache a Whisper model. Cached across reruns."""
logger.info(f"Loading Whisper model: {model_name} on {device_str}")
device = torch.device(device_str)
return whisper.load_model(model_name, device=device if device.type != "mps" else "cpu")
def transcribe_audio(audio_path: Path, model=WHISPER_MODEL, use_cache=True, cache_max_age=None,
use_gpu=True, memory_fraction=0.8):
"""
@@ -44,38 +49,30 @@ def transcribe_audio(audio_path: Path, model=WHISPER_MODEL, use_cache=True, cach
"""
audio_path = Path(audio_path)
# Check cache first if enabled
if use_cache and CACHE_AVAILABLE:
cached_data = load_from_cache(audio_path, model, "transcribe", cache_max_age)
if cached_data:
logger.info(f"Using cached transcription for {audio_path}")
return cached_data.get("segments", []), cached_data.get("transcript", "")
# Extract audio if the input is a video file (M4A is already audio)
video_extensions = ['.mp4', '.avi', '.mov', '.mkv']
if audio_path.suffix.lower() in video_extensions:
audio_path = extract_audio(audio_path)
# Configure GPU if available and requested
device = torch.device("cpu")
if use_gpu and GPU_UTILS_AVAILABLE:
gpu_config = configure_gpu(model, memory_fraction)
device = gpu_config["device"]
logger.info(f"Using device: {device} for transcription")
# Load the specified Whisper model
logger.info(f"Loading Whisper model: {model}")
whisper_model = whisper.load_model(model, device=device if device.type != "mps" else "cpu")
whisper_model = _load_whisper_model(model, str(device))
# Transcribe the audio
logger.info(f"Transcribing audio: {audio_path}")
result = whisper_model.transcribe(str(audio_path))
# Extract the full transcript and segments
transcript = result["text"]
segments = result["segments"]
# Cache the results if caching is enabled
if use_cache and CACHE_AVAILABLE:
cache_data = {
"transcript": transcript,

View File

@@ -1,41 +1,49 @@
"""
Translation utilities for the OBS Recording Transcriber.
Translation utilities for the Video Transcriber.
Provides functions for language detection and translation.
"""
import logging
import torch
from pathlib import Path
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, M2M100ForConditionalGeneration
from transformers import pipeline, AutoTokenizer, M2M100ForConditionalGeneration
import whisper
import iso639
import streamlit as st
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Try to import GPU utilities, but don't fail if not available
try:
from utils.gpu_utils import get_optimal_device
GPU_UTILS_AVAILABLE = True
except ImportError:
GPU_UTILS_AVAILABLE = False
# Default models
TRANSLATION_MODEL = "facebook/m2m100_418M"
LANGUAGE_DETECTION_MODEL = "papluca/xlm-roberta-base-language-detection"
# ISO language code mapping
@st.cache_resource
def _load_language_detector(model_name, device_int):
"""Load and cache the language detection pipeline."""
logger.info(f"Loading language detection model: {model_name}")
return pipeline("text-classification", model=model_name, device=device_int)
@st.cache_resource
def _load_translation_model(model_name, device_str):
"""Load and cache the M2M100 translation model and tokenizer."""
logger.info(f"Loading translation model: {model_name} on {device_str}")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = M2M100ForConditionalGeneration.from_pretrained(model_name)
device = torch.device(device_str)
model = model.to(device)
return model, tokenizer
def get_language_name(code):
"""
Get the language name from ISO code.
Args:
code (str): ISO language code
Returns:
str: Language name or original code if not found
"""
"""Get the language name from ISO code."""
try:
return iso639.languages.get(part1=code).name
except (KeyError, AttributeError):
@@ -57,7 +65,6 @@ def detect_language(text, model=LANGUAGE_DETECTION_MODEL, use_gpu=True):
Returns:
tuple: (language_code, confidence)
"""
# Configure device
device = torch.device("cpu")
if use_gpu and GPU_UTILS_AVAILABLE:
device = get_optimal_device()
@@ -66,25 +73,43 @@ def detect_language(text, model=LANGUAGE_DETECTION_MODEL, use_gpu=True):
device_arg = -1
try:
# Initialize the pipeline
classifier = pipeline("text-classification", model=model, device=device_arg)
classifier = _load_language_detector(model, device_arg)
# Truncate text if too long
max_length = 512
if len(text) > max_length:
text = text[:max_length]
# Detect language
result = classifier(text)[0]
language_code = result["label"]
confidence = result["score"]
return language_code, confidence
return result["label"], result["score"]
except Exception as e:
logger.error(f"Error detecting language: {e}")
return None, 0.0
def _translate_text_with_model(text, source_lang, target_lang, trans_model, tokenizer, device):
"""Translate text using a pre-loaded model and tokenizer."""
tokenizer.src_lang = source_lang
max_length = 512
if len(text) > max_length:
chunks = [text[i:i+max_length] for i in range(0, len(text), max_length)]
else:
chunks = [text]
translated_chunks = []
for chunk in chunks:
encoded = tokenizer(chunk, return_tensors="pt").to(device)
generated_tokens = trans_model.generate(
**encoded,
forced_bos_token_id=tokenizer.get_lang_id(target_lang),
max_length=max_length
)
translated_chunk = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
translated_chunks.append(translated_chunk)
return " ".join(translated_chunks)
def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_MODEL, use_gpu=True):
"""
Translate text from source language to target language.
@@ -99,7 +124,6 @@ def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_M
Returns:
str: Translated text
"""
# Auto-detect source language if not provided
if source_lang is None:
detected_lang, confidence = detect_language(text, use_gpu=use_gpu)
if detected_lang and confidence > 0.5:
@@ -109,50 +133,17 @@ def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_M
logger.warning("Could not reliably detect language, defaulting to English")
source_lang = "en"
# Skip translation if source and target are the same
if source_lang == target_lang:
logger.info(f"Source and target languages are the same ({source_lang}), skipping translation")
return text
# Configure device
device = torch.device("cpu")
if use_gpu and GPU_UTILS_AVAILABLE:
device = get_optimal_device()
try:
# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model)
model = M2M100ForConditionalGeneration.from_pretrained(model)
# Move model to device
model = model.to(device)
# Prepare for translation
tokenizer.src_lang = source_lang
# Split text into manageable chunks if too long
max_length = 512
if len(text) > max_length:
chunks = [text[i:i+max_length] for i in range(0, len(text), max_length)]
else:
chunks = [text]
# Translate each chunk
translated_chunks = []
for chunk in chunks:
encoded = tokenizer(chunk, return_tensors="pt").to(device)
generated_tokens = model.generate(
**encoded,
forced_bos_token_id=tokenizer.get_lang_id(target_lang),
max_length=max_length
)
translated_chunk = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
translated_chunks.append(translated_chunk)
# Combine translated chunks
translated_text = " ".join(translated_chunks)
return translated_text
trans_model, tokenizer = _load_translation_model(model, str(device))
return _translate_text_with_model(text, source_lang, target_lang, trans_model, tokenizer, device)
except Exception as e:
logger.error(f"Error translating text: {e}")
return text
@@ -160,7 +151,7 @@ def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_M
def translate_segments(segments, source_lang=None, target_lang="en", use_gpu=True):
"""
Translate transcript segments.
Translate transcript segments. Loads the model once and reuses it for all segments.
Args:
segments (list): List of transcript segments
@@ -174,36 +165,32 @@ def translate_segments(segments, source_lang=None, target_lang="en", use_gpu=Tru
if not segments:
return []
# Auto-detect source language from combined text if not provided
if source_lang is None:
combined_text = " ".join([segment["text"] for segment in segments])
detected_lang, _ = detect_language(combined_text, use_gpu=use_gpu)
source_lang = detected_lang if detected_lang else "en"
# Skip translation if source and target are the same
if source_lang == target_lang:
return segments
device = torch.device("cpu")
if use_gpu and GPU_UTILS_AVAILABLE:
device = get_optimal_device()
try:
# Initialize translation pipeline
translated_segments = []
trans_model, tokenizer = _load_translation_model(TRANSLATION_MODEL, str(device))
# Translate each segment
translated_segments = []
for segment in segments:
translated_text = translate_text(
segment["text"],
source_lang=source_lang,
target_lang=target_lang,
use_gpu=use_gpu
translated_text = _translate_text_with_model(
segment["text"], source_lang, target_lang, trans_model, tokenizer, device
)
# Create a new segment with translated text
translated_segment = segment.copy()
translated_segment["text"] = translated_text
translated_segment["original_text"] = segment["text"]
translated_segment["source_lang"] = source_lang
translated_segment["target_lang"] = target_lang
translated_segments.append(translated_segment)
return translated_segments
@@ -227,39 +214,33 @@ def transcribe_and_translate(audio_path, whisper_model="base", target_lang="en",
Returns:
tuple: (original_segments, translated_segments, original_transcript, translated_transcript)
"""
from utils.transcription import _load_whisper_model
audio_path = Path(audio_path)
# Configure device
device = torch.device("cpu")
if use_gpu and GPU_UTILS_AVAILABLE:
device = get_optimal_device()
try:
# Step 1: Transcribe audio with Whisper
logger.info(f"Transcribing audio with Whisper model: {whisper_model}")
model = whisper.load_model(whisper_model, device=device if device.type != "mps" else "cpu")
model = _load_whisper_model(whisper_model, str(device))
# Use Whisper's built-in language detection if requested
if detect_source:
# First, detect language with Whisper
audio = whisper.load_audio(str(audio_path))
audio = whisper.pad_or_trim(audio)
mel = whisper.log_mel_spectrogram(audio).to(device if device.type != "mps" else "cpu")
_, probs = model.detect_language(mel)
source_lang = max(probs, key=probs.get)
logger.info(f"Whisper detected language: {get_language_name(source_lang)} ({source_lang})")
# Transcribe with detected language
result = model.transcribe(str(audio_path), language=source_lang)
else:
# Transcribe without language specification
result = model.transcribe(str(audio_path))
source_lang = result.get("language", "en")
original_segments = result["segments"]
original_transcript = result["text"]
# Step 2: Translate if needed
if source_lang != target_lang:
logger.info(f"Translating from {source_lang} to {target_lang}")
translated_segments = translate_segments(
@@ -268,8 +249,6 @@ def transcribe_and_translate(audio_path, whisper_model="base", target_lang="en",
target_lang=target_lang,
use_gpu=use_gpu
)
# Create full translated transcript
translated_transcript = " ".join([segment["text"] for segment in translated_segments])
else:
logger.info(f"Source and target languages are the same ({source_lang}), skipping translation")
@@ -280,4 +259,4 @@ def transcribe_and_translate(audio_path, whisper_model="base", target_lang="en",
except Exception as e:
logger.error(f"Error in transcribe_and_translate: {e}")
return None, None, None, None
return None, None, None, None
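
A segment-translation sketch showing the single-load path; the module path is assumed and the segments are made up.

# Illustrative: M2M100 loads once via the cached loader and is reused for
# every segment instead of being reloaded per translate_text call.
from utils.translation import translate_segments  # assumed path

segments = [
    {"start": 0.0, "end": 2.5, "text": "Hola a todos."},
    {"start": 2.5, "end": 5.0, "text": "Bienvenidos."},
]
out = translate_segments(segments, source_lang="es", target_lang="en", use_gpu=False)
for seg in out:
    print(seg["original_text"], "->", seg["text"])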

View File

@@ -1,8 +1,38 @@
from pathlib import Path
import shutil
import logging
def validate_environment(obs_path: Path):
logger = logging.getLogger(__name__)
def validate_environment(obs_path: Path = None):
"""Validate environment and prerequisites."""
errors = []
if not obs_path.exists():
errors.append(f"OBS directory not found: {obs_path}")
if obs_path and not obs_path.exists():
errors.append(f"Directory not found: {obs_path}")
if not shutil.which("ffmpeg"):
errors.append("FFmpeg is not installed or not in PATH. Install it from https://ffmpeg.org/download.html")
return errors
def get_system_capabilities():
"""Return a dict of detected system capabilities for display."""
import torch
caps = {
"ffmpeg": shutil.which("ffmpeg") is not None,
"cuda": torch.cuda.is_available(),
"mps": hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
"gpu_name": None,
"gpu_memory": None,
}
if caps["cuda"] and torch.cuda.device_count() > 0:
props = torch.cuda.get_device_properties(0)
caps["gpu_name"] = props.name
caps["gpu_memory"] = props.total_memory
return caps
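
A sketch of how the capability probe might surface in the sidebar; app.py's actual rendering is not shown in this diff, so the widget calls are assumptions.

# Illustrative Streamlit sidebar rendering of get_system_capabilities().
import streamlit as st
from utils.validation import get_system_capabilities  # assumed path

caps = get_system_capabilities()
st.sidebar.write("FFmpeg:", "found" if caps["ffmpeg"] else "missing")
if caps["cuda"]:
    st.sidebar.write(f"GPU: {caps['gpu_name']} ({caps['gpu_memory'] / 1024**3:.1f} GB)")
elif caps["mps"]:
    st.sidebar.write("GPU: Apple Silicon (MPS)")
else:
    st.sidebar.write("GPU: none (CPU only)")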