feat: Add streaming Ollama support, model caching, and UI improvements
- Add streaming summarization via Ollama API (stream_summarize_with_ollama); usage sketch below
- Cache ML models with @st.cache_resource (diarization, NER, translation, Whisper)
- Add temp file cleanup for extracted audio
- Add system capabilities detection (FFmpeg, GPU info)
- Add get_video_duration utility
- Improve validation with FFmpeg check
- Rewrite app.py with streaming support and UI enhancements
- Clean up redundant comments and unused imports across all utils
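For context, a minimal sketch of how the new streaming summarizer and temp-file cleanup might be consumed from the Streamlit app. The utils.ollama module path and the use of st.write_stream (Streamlit 1.31+) are assumptions for illustration, not part of this commit:

# Usage sketch (not part of this commit). Assumes the streaming helper lives in
# utils.ollama and that Streamlit >= 1.31 provides st.write_stream.
import streamlit as st
from pathlib import Path
from utils.audio_processing import cleanup_temp_audio
from utils.ollama import stream_summarize_with_ollama, check_ollama_available  # assumed module path
from utils.transcription import transcribe_audio

uploaded = st.file_uploader("Video file", type=["mp4", "mov", "mkv"])
if uploaded is not None:
    video_path = Path("/tmp") / uploaded.name
    video_path.write_bytes(uploaded.getbuffer())

    # transcribe_audio extracts audio internally for video files and returns (segments, transcript)
    segments, transcript = transcribe_audio(video_path)
    st.subheader("Transcript")
    st.write(transcript)

    if check_ollama_available():
        st.subheader("Summary (streaming)")
        # st.write_stream consumes the generator and renders tokens as they arrive
        st.write_stream(stream_summarize_with_ollama(transcript, model="llama3", max_length=150))

    cleanup_temp_audio()  # remove temp WAV files created during extraction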
@@ -1,19 +1,55 @@
from pathlib import Path
import tempfile
import os
import logging

# moviepy 2.x removed moviepy.editor; import directly from moviepy
try:
    from moviepy import AudioFileClip
except ImportError:
    # Fallback for moviepy 1.x
    from moviepy.editor import AudioFileClip

logger = logging.getLogger(__name__)

_temp_audio_files = []


def extract_audio(video_path: Path):
    """Extract audio from a video file."""
    """Extract audio from a video file into a temp directory for automatic cleanup."""
    try:
        audio = AudioFileClip(str(video_path))
        audio_path = video_path.parent / f"{video_path.stem}_audio.wav"
        temp_dir = tempfile.mkdtemp(prefix="videotranscriber_")
        audio_path = Path(temp_dir) / f"{video_path.stem}_audio.wav"
        audio.write_audiofile(str(audio_path), verbose=False, logger=None)
        audio.close()
        _temp_audio_files.append(str(audio_path))
        return audio_path
    except Exception as e:
        raise RuntimeError(f"Audio extraction failed: {e}")


def cleanup_temp_audio():
    """Remove all temporary audio files created during processing."""
    cleaned = 0
    for fpath in _temp_audio_files:
        try:
            if os.path.exists(fpath):
                os.remove(fpath)
                parent = os.path.dirname(fpath)
                if os.path.isdir(parent) and not os.listdir(parent):
                    os.rmdir(parent)
                cleaned += 1
        except Exception as e:
            logger.warning(f"Could not remove temp file {fpath}: {e}")
    _temp_audio_files.clear()
    return cleaned


def get_video_duration(video_path: Path):
    """Get duration of a video/audio file in seconds."""
    try:
        clip = AudioFileClip(str(video_path))
        duration = clip.duration
        clip.close()
        return duration
    except Exception:
        return None

@@ -1,5 +1,5 @@
"""
Speaker diarization utilities for the OBS Recording Transcriber.
Speaker diarization utilities for the Video Transcriber.
Provides functions to identify different speakers in audio recordings.
"""

@@ -11,22 +11,34 @@ import torch
from pyannote.audio import Pipeline
from pyannote.core import Segment
import whisper
import streamlit as st

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Try to import GPU utilities, but don't fail if not available
try:
    from utils.gpu_utils import get_optimal_device
    GPU_UTILS_AVAILABLE = True
except ImportError:
    GPU_UTILS_AVAILABLE = False

# Default HuggingFace auth token environment variable
HF_TOKEN_ENV = "HF_TOKEN"


@st.cache_resource
def _load_diarization_pipeline(hf_token, device_str):
    """Load and cache the speaker diarization pipeline."""
    logger.info(f"Loading diarization pipeline on {device_str}")
    pipe = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.0",
        use_auth_token=hf_token
    )
    device = torch.device(device_str)
    if device.type == "cuda":
        pipe = pipe.to(device)
    return pipe


def get_diarization_pipeline(use_gpu=True, hf_token=None):
    """
    Initialize the speaker diarization pipeline.
@@ -38,7 +50,6 @@ def get_diarization_pipeline(use_gpu=True, hf_token=None):
    Returns:
        Pipeline or None: Diarization pipeline if successful, None otherwise
    """
    # Check if token is provided or in environment
    if hf_token is None:
        hf_token = os.environ.get(HF_TOKEN_ENV)
        if hf_token is None:
@@ -46,23 +57,12 @@ def get_diarization_pipeline(use_gpu=True, hf_token=None):
            return None

    try:
        # Configure device
        device = torch.device("cpu")
        if use_gpu and GPU_UTILS_AVAILABLE:
            device = get_optimal_device()
        logger.info(f"Using device: {device} for diarization")

        # Initialize the pipeline
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.0",
            use_auth_token=hf_token
        )

        # Move to appropriate device
        if device.type == "cuda":
            pipeline = pipeline.to(torch.device(device))

        return pipeline
        return _load_diarization_pipeline(hf_token, str(device))
    except Exception as e:
        logger.error(f"Error initializing diarization pipeline: {e}")
        return None
@@ -198,9 +198,9 @@ def transcribe_with_diarization(audio_path, whisper_model="base", num_speakers=N
        device = get_optimal_device()

    try:
        # Step 1: Transcribe audio with Whisper
        from utils.transcription import _load_whisper_model
        logger.info(f"Transcribing audio with Whisper model: {whisper_model}")
        model = whisper.load_model(whisper_model, device=device if device.type != "mps" else "cpu")
        model = _load_whisper_model(whisper_model, str(device))
        result = model.transcribe(str(audio_path))
        transcript_segments = result["segments"]

@@ -1,12 +1,9 @@
"""
GPU utilities for the OBS Recording Transcriber.
GPU utilities for the Video Transcriber.
Provides functions to detect and configure GPU acceleration.
"""

import logging
import os
import platform
import subprocess
import torch

# Configure logging
@@ -68,8 +65,6 @@ def get_optimal_device():


def set_memory_limits(memory_fraction=0.8):
    global torch
    import torch
    """
    Set memory limits for GPU usage.

@@ -1,5 +1,5 @@
"""
Keyword extraction utilities for the OBS Recording Transcriber.
Keyword extraction utilities for the Video Transcriber.
Provides functions to extract keywords and link them to timestamps.
"""

@@ -8,25 +8,30 @@ import re
import torch
import numpy as np
from pathlib import Path
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from collections import Counter
import streamlit as st

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Try to import GPU utilities, but don't fail if not available
try:
    from utils.gpu_utils import get_optimal_device
    GPU_UTILS_AVAILABLE = True
except ImportError:
    GPU_UTILS_AVAILABLE = False

# Default models
NER_MODEL = "dslim/bert-base-NER"


@st.cache_resource
def _load_ner_pipeline(model_name, device_int):
    """Load and cache the NER pipeline."""
    logger.info(f"Loading NER model: {model_name}")
    return pipeline("ner", model=model_name, device=device_int, aggregation_strategy="simple")


def extract_keywords_tfidf(text, max_keywords=10, ngram_range=(1, 2)):
    """
    Extract keywords using TF-IDF.
@@ -107,8 +112,7 @@ def extract_named_entities(text, model=NER_MODEL, use_gpu=True):
        device_arg = -1

    try:
        # Initialize the pipeline
        ner_pipeline = pipeline("ner", model=model, device=device_arg, aggregation_strategy="simple")
        ner_pipeline = _load_ner_pipeline(model, device_arg)

        # Split text into manageable chunks if too long
        max_length = 512

@@ -1,6 +1,6 @@
"""
Ollama integration for local AI model inference.
Provides functions to use Ollama's API for text summarization.
Provides functions to use Ollama's API for text summarization with streaming support.
"""

import requests
@@ -9,21 +9,14 @@ import logging
from pathlib import Path
import os

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default Ollama API endpoint - configurable via environment variable
OLLAMA_API_URL = os.environ.get("OLLAMA_API_URL", "http://localhost:11434/api")


def check_ollama_available():
    """
    Check if Ollama service is available.

    Returns:
        bool: True if Ollama is available, False otherwise
    """
    """Check if Ollama service is available."""
    try:
        response = requests.get(f"{OLLAMA_API_URL}/tags", timeout=2)
        return response.status_code == 200
@@ -32,12 +25,7 @@ def check_ollama_available():


def list_available_models():
    """
    List available models in Ollama.

    Returns:
        list: List of available model names
    """
    """List available models in Ollama."""
    try:
        response = requests.get(f"{OLLAMA_API_URL}/tags")
        if response.status_code == 200:
@@ -50,32 +38,14 @@ def list_available_models():


def summarize_with_ollama(text, model="llama3", max_length=150):
    """
    Summarize text using Ollama's local API.

    Args:
        text (str): Text to summarize
        model (str): Ollama model to use
        max_length (int): Maximum length of the summary

    Returns:
        str: Summarized text or None if failed
    """
    """Summarize text using Ollama's local API (non-streaming)."""
    if not check_ollama_available():
        logger.warning("Ollama service is not available")
        return None

    # Check if the model is available
    available_models = list_available_models()
    if model not in available_models:
        logger.warning(f"Model {model} not available in Ollama. Available models: {available_models}")
        return None

    # Prepare the prompt for summarization
    prompt = f"Summarize the following text in about {max_length} words:\n\n{text}"

    try:
        # Make the API request
        response = requests.post(
            f"{OLLAMA_API_URL}/generate",
            json={
@@ -85,7 +55,7 @@ def summarize_with_ollama(text, model="llama3", max_length=150):
                "options": {
                    "temperature": 0.3,
                    "top_p": 0.9,
                    "max_tokens": max_length * 2 # Approximate token count
                    "max_tokens": max_length * 2
                }
            }
        )
@@ -101,23 +71,55 @@ def summarize_with_ollama(text, model="llama3", max_length=150):
        return None


def stream_summarize_with_ollama(text, model="llama3", max_length=150):
    """
    Summarize text using Ollama with streaming. Yields tokens as they arrive.

    Yields:
        str: Individual response tokens
    """
    if not check_ollama_available():
        logger.warning("Ollama service is not available")
        return

    prompt = f"Summarize the following text in about {max_length} words:\n\n{text}"

    try:
        response = requests.post(
            f"{OLLAMA_API_URL}/generate",
            json={
                "model": model,
                "prompt": prompt,
                "stream": True,
                "options": {
                    "temperature": 0.3,
                    "top_p": 0.9,
                    "max_tokens": max_length * 2
                }
            },
            stream=True
        )

        if response.status_code == 200:
            for line in response.iter_lines():
                if line:
                    data = json.loads(line)
                    token = data.get('response', '')
                    if token:
                        yield token
                    if data.get('done', False):
                        break
        else:
            logger.error(f"Ollama API error: {response.status_code}")
    except requests.exceptions.RequestException as e:
        logger.error(f"Error communicating with Ollama: {e}")


def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150):
    """
    Chunk long text and summarize each chunk, then combine the summaries.

    Args:
        text (str): Text to summarize
        model (str): Ollama model to use
        chunk_size (int): Maximum size of each chunk in characters
        max_length (int): Maximum length of the final summary

    Returns:
        str: Combined summary or None if failed
    """
    """Chunk long text and summarize each chunk, then combine."""
    if len(text) <= chunk_size:
        return summarize_with_ollama(text, model, max_length)

    # Split text into chunks
    words = text.split()
    chunks = []
    current_chunk = []
@@ -135,7 +137,6 @@ def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150):
    if current_chunk:
        chunks.append(' '.join(current_chunk))

    # Summarize each chunk
    chunk_summaries = []
    for i, chunk in enumerate(chunks):
        logger.info(f"Summarizing chunk {i+1}/{len(chunks)}")
@@ -146,10 +147,55 @@ def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150):
    if not chunk_summaries:
        return None

    # If there's only one chunk summary, return it
    if len(chunk_summaries) == 1:
        return chunk_summaries[0]

    # Otherwise, combine the summaries and summarize again
    combined_summary = " ".join(chunk_summaries)
    return summarize_with_ollama(combined_summary, model, max_length)


def stream_chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150):
    """
    Chunk and summarize with streaming on the final summary.
    Returns non-streaming chunk summaries, then streams the final combination.

    Yields:
        str: Tokens from the final summary
    """
    if len(text) <= chunk_size:
        yield from stream_summarize_with_ollama(text, model, max_length)
        return

    words = text.split()
    chunks = []
    current_chunk = []
    current_length = 0

    for word in words:
        if current_length + len(word) + 1 <= chunk_size:
            current_chunk.append(word)
            current_length += len(word) + 1
        else:
            chunks.append(' '.join(current_chunk))
            current_chunk = [word]
            current_length = len(word) + 1

    if current_chunk:
        chunks.append(' '.join(current_chunk))

    chunk_summaries = []
    for i, chunk in enumerate(chunks):
        logger.info(f"Summarizing chunk {i+1}/{len(chunks)}")
        summary = summarize_with_ollama(chunk, model, max_length // len(chunks))
        if summary:
            chunk_summaries.append(summary)

    if not chunk_summaries:
        return

    if len(chunk_summaries) == 1:
        yield chunk_summaries[0]
        return

    combined_summary = " ".join(chunk_summaries)
    yield from stream_summarize_with_ollama(combined_summary, model, max_length)

@@ -1,45 +1,49 @@
from transformers import pipeline, AutoTokenizer
import torch
import logging
import streamlit as st

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

SUMMARY_MODEL = "Falconsai/text_summarization"


@st.cache_resource
def _load_summarizer(device_int):
    """Load and cache the summarization pipeline."""
    logger.info(f"Loading summarization model on device {device_int}")
    return pipeline("summarization", model=SUMMARY_MODEL, device=device_int)


@st.cache_resource
def _load_summary_tokenizer():
    """Load and cache the summarization tokenizer."""
    return AutoTokenizer.from_pretrained(SUMMARY_MODEL)


def chunk_text(text, max_tokens, tokenizer):
    """
    Splits the text into a list of chunks based on token limits.

    Args:
        text (str): Text to chunk
        max_tokens (int): Maximum tokens per chunk
        tokenizer (AutoTokenizer): Tokenizer to use

    Returns:
        list: List of text chunks
    Splits text into chunks by tokenizing once, then splitting by token windows.
    Much faster than the per-word tokenization approach.
    """
    words = text.split()
    all_ids = tokenizer(text, return_tensors='pt', truncation=False)['input_ids'][0]
    content_ids = all_ids[1:-1]  # strip BOS/EOS
    usable_max = max_tokens - 2  # leave room for special tokens

    chunks = []
    current_chunk = []
    current_length = 0
    for i in range(0, len(content_ids), usable_max):
        chunk_ids = content_ids[i : i + usable_max]
        decoded = tokenizer.decode(chunk_ids, skip_special_tokens=True).strip()
        if decoded:
            chunks.append(decoded)

    for word in words:
        hypothetical_length = current_length + len(tokenizer(word, return_tensors='pt')['input_ids'][0]) - 2
        if hypothetical_length <= max_tokens:
            current_chunk.append(word)
            current_length = hypothetical_length
        else:
            chunks.append(' '.join(current_chunk))
            current_chunk = [word]
            current_length = len(tokenizer(word, return_tensors='pt')['input_ids'][0]) - 2

    if current_chunk:
        chunks.append(' '.join(current_chunk))
    if not chunks:
        chunks.append(text)

    return chunks


def summarize_text(text, use_gpu=True, memory_fraction=0.8):
    """
    Summarize text using a Hugging Face pipeline with chunking support.
@@ -52,21 +56,17 @@ def summarize_text(text, use_gpu=True, memory_fraction=0.8):
    Returns:
        str: Summarized text
    """
    # Determine device
    device = -1 # Default to CPU
    device = -1
    if use_gpu and torch.cuda.is_available():
        device = 0 # Use first GPU
    if torch.cuda.is_available():
        device = 0
        torch.cuda.set_per_process_memory_fraction(memory_fraction)

    logger.info(f"Using device {device} for summarization")

    try:
        # Initialize the pipeline and tokenizer
        summarizer = pipeline("summarization", model=SUMMARY_MODEL, device=device)
        tokenizer = AutoTokenizer.from_pretrained(SUMMARY_MODEL)
        summarizer = _load_summarizer(device)
        tokenizer = _load_summary_tokenizer()

        # Check if text needs to be chunked
        max_tokens = 512
        tokens = tokenizer(text, return_tensors='pt')
        num_tokens = len(tokens['input_ids'][0])
@@ -85,7 +85,6 @@ def summarize_text(text, use_gpu=True, memory_fraction=0.8):
            )
            summaries.append(summary_output[0]['summary_text'])

        # If multiple chunks, summarize the combined summaries
        if len(summaries) > 1:
            logger.info("Generating final summary from chunk summaries")
            combined_text = " ".join(summaries)
@@ -106,7 +105,6 @@ def summarize_text(text, use_gpu=True, memory_fraction=0.8):

    except Exception as e:
        logger.error(f"Error during summarization: {e}")
        # Fallback to CPU if GPU fails
        if device != -1:
            logger.info("Falling back to CPU")
            return summarize_text(text, use_gpu=False, memory_fraction=memory_fraction)

@@ -1,31 +1,36 @@
import whisper
from pathlib import Path
from transformers import pipeline, AutoTokenizer
from utils.audio_processing import extract_audio
from utils.summarization import summarize_text
import logging
import torch
import streamlit as st

# Try to import GPU utilities, but don't fail if not available
try:
    from utils.gpu_utils import configure_gpu, get_optimal_device
    GPU_UTILS_AVAILABLE = True
except ImportError:
    GPU_UTILS_AVAILABLE = False

# Try to import caching utilities, but don't fail if not available
try:
    from utils.cache import load_from_cache, save_to_cache
    CACHE_AVAILABLE = True
except ImportError:
    CACHE_AVAILABLE = False

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

WHISPER_MODEL = "base"


@st.cache_resource
def _load_whisper_model(model_name, device_str):
    """Load and cache a Whisper model. Cached across reruns."""
    logger.info(f"Loading Whisper model: {model_name} on {device_str}")
    device = torch.device(device_str)
    return whisper.load_model(model_name, device=device if device.type != "mps" else "cpu")


def transcribe_audio(audio_path: Path, model=WHISPER_MODEL, use_cache=True, cache_max_age=None,
                     use_gpu=True, memory_fraction=0.8):
    """
@@ -44,38 +49,30 @@ def transcribe_audio(audio_path: Path, model=WHISPER_MODEL, use_cache=True, cach
    """
    audio_path = Path(audio_path)

    # Check cache first if enabled
    if use_cache and CACHE_AVAILABLE:
        cached_data = load_from_cache(audio_path, model, "transcribe", cache_max_age)
        if cached_data:
            logger.info(f"Using cached transcription for {audio_path}")
            return cached_data.get("segments", []), cached_data.get("transcript", "")

    # Extract audio if the input is a video file (M4A is already audio)
    video_extensions = ['.mp4', '.avi', '.mov', '.mkv']
    if audio_path.suffix.lower() in video_extensions:
        audio_path = extract_audio(audio_path)

    # Configure GPU if available and requested
    device = torch.device("cpu")
    if use_gpu and GPU_UTILS_AVAILABLE:
        gpu_config = configure_gpu(model, memory_fraction)
        device = gpu_config["device"]
        logger.info(f"Using device: {device} for transcription")

    # Load the specified Whisper model
    logger.info(f"Loading Whisper model: {model}")
    whisper_model = whisper.load_model(model, device=device if device.type != "mps" else "cpu")
    whisper_model = _load_whisper_model(model, str(device))

    # Transcribe the audio
    logger.info(f"Transcribing audio: {audio_path}")
    result = whisper_model.transcribe(str(audio_path))

    # Extract the full transcript and segments
    transcript = result["text"]
    segments = result["segments"]

    # Cache the results if caching is enabled
    if use_cache and CACHE_AVAILABLE:
        cache_data = {
            "transcript": transcript,

@@ -1,41 +1,49 @@
"""
Translation utilities for the OBS Recording Transcriber.
Translation utilities for the Video Transcriber.
Provides functions for language detection and translation.
"""

import logging
import torch
from pathlib import Path
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, M2M100ForConditionalGeneration
from transformers import pipeline, AutoTokenizer, M2M100ForConditionalGeneration
import whisper
import iso639
import streamlit as st

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Try to import GPU utilities, but don't fail if not available
try:
    from utils.gpu_utils import get_optimal_device
    GPU_UTILS_AVAILABLE = True
except ImportError:
    GPU_UTILS_AVAILABLE = False

# Default models
TRANSLATION_MODEL = "facebook/m2m100_418M"
LANGUAGE_DETECTION_MODEL = "papluca/xlm-roberta-base-language-detection"

# ISO language code mapping

@st.cache_resource
def _load_language_detector(model_name, device_int):
    """Load and cache the language detection pipeline."""
    logger.info(f"Loading language detection model: {model_name}")
    return pipeline("text-classification", model=model_name, device=device_int)


@st.cache_resource
def _load_translation_model(model_name, device_str):
    """Load and cache the M2M100 translation model and tokenizer."""
    logger.info(f"Loading translation model: {model_name} on {device_str}")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = M2M100ForConditionalGeneration.from_pretrained(model_name)
    device = torch.device(device_str)
    model = model.to(device)
    return model, tokenizer


def get_language_name(code):
    """
    Get the language name from ISO code.

    Args:
        code (str): ISO language code

    Returns:
        str: Language name or original code if not found
    """
    """Get the language name from ISO code."""
    try:
        return iso639.languages.get(part1=code).name
    except (KeyError, AttributeError):
@@ -57,7 +65,6 @@ def detect_language(text, model=LANGUAGE_DETECTION_MODEL, use_gpu=True):
    Returns:
        tuple: (language_code, confidence)
    """
    # Configure device
    device = torch.device("cpu")
    if use_gpu and GPU_UTILS_AVAILABLE:
        device = get_optimal_device()
@@ -66,25 +73,43 @@ def detect_language(text, model=LANGUAGE_DETECTION_MODEL, use_gpu=True):
        device_arg = -1

    try:
        # Initialize the pipeline
        classifier = pipeline("text-classification", model=model, device=device_arg)
        classifier = _load_language_detector(model, device_arg)

        # Truncate text if too long
        max_length = 512
        if len(text) > max_length:
            text = text[:max_length]

        # Detect language
        result = classifier(text)[0]
        language_code = result["label"]
        confidence = result["score"]

        return language_code, confidence
        return result["label"], result["score"]
    except Exception as e:
        logger.error(f"Error detecting language: {e}")
        return None, 0.0


def _translate_text_with_model(text, source_lang, target_lang, trans_model, tokenizer, device):
    """Translate text using a pre-loaded model and tokenizer."""
    tokenizer.src_lang = source_lang

    max_length = 512
    if len(text) > max_length:
        chunks = [text[i:i+max_length] for i in range(0, len(text), max_length)]
    else:
        chunks = [text]

    translated_chunks = []
    for chunk in chunks:
        encoded = tokenizer(chunk, return_tensors="pt").to(device)
        generated_tokens = trans_model.generate(
            **encoded,
            forced_bos_token_id=tokenizer.get_lang_id(target_lang),
            max_length=max_length
        )
        translated_chunk = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        translated_chunks.append(translated_chunk)

    return " ".join(translated_chunks)


def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_MODEL, use_gpu=True):
    """
    Translate text from source language to target language.
@@ -99,7 +124,6 @@ def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_M
    Returns:
        str: Translated text
    """
    # Auto-detect source language if not provided
    if source_lang is None:
        detected_lang, confidence = detect_language(text, use_gpu=use_gpu)
        if detected_lang and confidence > 0.5:
@@ -109,50 +133,17 @@ def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_M
            logger.warning("Could not reliably detect language, defaulting to English")
            source_lang = "en"

    # Skip translation if source and target are the same
    if source_lang == target_lang:
        logger.info(f"Source and target languages are the same ({source_lang}), skipping translation")
        return text

    # Configure device
    device = torch.device("cpu")
    if use_gpu and GPU_UTILS_AVAILABLE:
        device = get_optimal_device()

    try:
        # Load model and tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model)
        model = M2M100ForConditionalGeneration.from_pretrained(model)

        # Move model to device
        model = model.to(device)

        # Prepare for translation
        tokenizer.src_lang = source_lang

        # Split text into manageable chunks if too long
        max_length = 512
        if len(text) > max_length:
            chunks = [text[i:i+max_length] for i in range(0, len(text), max_length)]
        else:
            chunks = [text]

        # Translate each chunk
        translated_chunks = []
        for chunk in chunks:
            encoded = tokenizer(chunk, return_tensors="pt").to(device)
            generated_tokens = model.generate(
                **encoded,
                forced_bos_token_id=tokenizer.get_lang_id(target_lang),
                max_length=max_length
            )
            translated_chunk = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
            translated_chunks.append(translated_chunk)

        # Combine translated chunks
        translated_text = " ".join(translated_chunks)

        return translated_text
        trans_model, tokenizer = _load_translation_model(model, str(device))
        return _translate_text_with_model(text, source_lang, target_lang, trans_model, tokenizer, device)
    except Exception as e:
        logger.error(f"Error translating text: {e}")
        return text
@@ -160,7 +151,7 @@ def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_M

def translate_segments(segments, source_lang=None, target_lang="en", use_gpu=True):
    """
    Translate transcript segments.
    Translate transcript segments. Loads the model once and reuses for all segments.

    Args:
        segments (list): List of transcript segments
@@ -174,36 +165,32 @@ def translate_segments(segments, source_lang=None, target_lang="en", use_gpu=Tru
    if not segments:
        return []

    # Auto-detect source language from combined text if not provided
    if source_lang is None:
        combined_text = " ".join([segment["text"] for segment in segments])
        detected_lang, _ = detect_language(combined_text, use_gpu=use_gpu)
        source_lang = detected_lang if detected_lang else "en"

    # Skip translation if source and target are the same
    if source_lang == target_lang:
        return segments

    try:
        # Initialize translation pipeline
        translated_segments = []
    device = torch.device("cpu")
    if use_gpu and GPU_UTILS_AVAILABLE:
        device = get_optimal_device()

        # Translate each segment
    try:
        trans_model, tokenizer = _load_translation_model(TRANSLATION_MODEL, str(device))

        translated_segments = []
        for segment in segments:
            translated_text = translate_text(
                segment["text"],
                source_lang=source_lang,
                target_lang=target_lang,
                use_gpu=use_gpu
            translated_text = _translate_text_with_model(
                segment["text"], source_lang, target_lang, trans_model, tokenizer, device
            )

            # Create a new segment with translated text
            translated_segment = segment.copy()
            translated_segment["text"] = translated_text
            translated_segment["original_text"] = segment["text"]
            translated_segment["source_lang"] = source_lang
            translated_segment["target_lang"] = target_lang

            translated_segments.append(translated_segment)

        return translated_segments
@@ -227,39 +214,33 @@ def transcribe_and_translate(audio_path, whisper_model="base", target_lang="en",
    Returns:
        tuple: (original_segments, translated_segments, original_transcript, translated_transcript)
    """
    from utils.transcription import _load_whisper_model

    audio_path = Path(audio_path)

    # Configure device
    device = torch.device("cpu")
    if use_gpu and GPU_UTILS_AVAILABLE:
        device = get_optimal_device()

    try:
        # Step 1: Transcribe audio with Whisper
        logger.info(f"Transcribing audio with Whisper model: {whisper_model}")
        model = whisper.load_model(whisper_model, device=device if device.type != "mps" else "cpu")
        model = _load_whisper_model(whisper_model, str(device))

        # Use Whisper's built-in language detection if requested
        if detect_source:
            # First, detect language with Whisper
            audio = whisper.load_audio(str(audio_path))
            audio = whisper.pad_or_trim(audio)
            mel = whisper.log_mel_spectrogram(audio).to(device if device.type != "mps" else "cpu")
            _, probs = model.detect_language(mel)
            source_lang = max(probs, key=probs.get)
            logger.info(f"Whisper detected language: {get_language_name(source_lang)} ({source_lang})")

            # Transcribe with detected language
            result = model.transcribe(str(audio_path), language=source_lang)
        else:
            # Transcribe without language specification
            result = model.transcribe(str(audio_path))
            source_lang = result.get("language", "en")

        original_segments = result["segments"]
        original_transcript = result["text"]

        # Step 2: Translate if needed
        if source_lang != target_lang:
            logger.info(f"Translating from {source_lang} to {target_lang}")
            translated_segments = translate_segments(
@@ -268,8 +249,6 @@ def transcribe_and_translate(audio_path, whisper_model="base", target_lang="en",
                target_lang=target_lang,
                use_gpu=use_gpu
            )

            # Create full translated transcript
            translated_transcript = " ".join([segment["text"] for segment in translated_segments])
        else:
            logger.info(f"Source and target languages are the same ({source_lang}), skipping translation")

@@ -1,8 +1,38 @@
from pathlib import Path
import shutil
import logging

def validate_environment(obs_path: Path):
logger = logging.getLogger(__name__)


def validate_environment(obs_path: Path = None):
    """Validate environment and prerequisites."""
    errors = []
    if not obs_path.exists():
        errors.append(f"OBS directory not found: {obs_path}")

    if obs_path and not obs_path.exists():
        errors.append(f"Directory not found: {obs_path}")

    if not shutil.which("ffmpeg"):
        errors.append("FFmpeg is not installed or not in PATH. Install it from https://ffmpeg.org/download.html")

    return errors


def get_system_capabilities():
    """Return a dict of detected system capabilities for display."""
    import torch

    caps = {
        "ffmpeg": shutil.which("ffmpeg") is not None,
        "cuda": torch.cuda.is_available(),
        "mps": hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
        "gpu_name": None,
        "gpu_memory": None,
    }

    if caps["cuda"] and torch.cuda.device_count() > 0:
        props = torch.cuda.get_device_properties(0)
        caps["gpu_name"] = props.name
        caps["gpu_memory"] = props.total_memory

    return caps
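A minimal sketch of how the validation and capabilities helpers above might be surfaced in the app's sidebar; the utils.validation module path and the exact layout are assumptions for illustration, not part of this commit:

# Usage sketch (not part of this commit). Assumes the helpers above live in utils.validation.
import streamlit as st
from utils.validation import validate_environment, get_system_capabilities  # assumed module path

for err in validate_environment():   # obs_path is now optional
    st.sidebar.error(err)

caps = get_system_capabilities()
st.sidebar.write("FFmpeg:", "found" if caps["ffmpeg"] else "missing")
if caps["cuda"] and caps["gpu_name"]:
    # gpu_memory is reported in bytes (torch total_memory); convert to GiB for display
    st.sidebar.write(f"GPU: {caps['gpu_name']} ({caps['gpu_memory'] / 1024**3:.1f} GiB)")
elif caps["mps"]:
    st.sidebar.write("GPU: Apple Silicon (MPS)")
else:
    st.sidebar.write("GPU: none (CPU only)")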