feat: Add streaming Ollama support, model caching, and UI improvements
- Add streaming summarization via Ollama API (stream_summarize_with_ollama) - Cache ML models with @st.cache_resource (diarization, NER, translation, Whisper) - Add temp file cleanup for extracted audio - Add system capabilities detection (FFmpeg, GPU info) - Add get_video_duration utility - Improve validation with FFmpeg check - Rewrite app.py with streaming support and UI enhancements - Clean up redundant comments and unused imports across all utils
This commit is contained in:
@ -1,19 +1,55 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import tempfile
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
|
||||||
# moviepy 2.x removed moviepy.editor; import directly from moviepy
|
|
||||||
try:
|
try:
|
||||||
from moviepy import AudioFileClip
|
from moviepy import AudioFileClip
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# Fallback for moviepy 1.x
|
|
||||||
from moviepy.editor import AudioFileClip
|
from moviepy.editor import AudioFileClip
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_temp_audio_files = []
|
||||||
|
|
||||||
|
|
||||||
def extract_audio(video_path: Path):
|
def extract_audio(video_path: Path):
|
||||||
"""Extract audio from a video file."""
|
"""Extract audio from a video file into a temp directory for automatic cleanup."""
|
||||||
try:
|
try:
|
||||||
audio = AudioFileClip(str(video_path))
|
audio = AudioFileClip(str(video_path))
|
||||||
audio_path = video_path.parent / f"{video_path.stem}_audio.wav"
|
temp_dir = tempfile.mkdtemp(prefix="videotranscriber_")
|
||||||
|
audio_path = Path(temp_dir) / f"{video_path.stem}_audio.wav"
|
||||||
audio.write_audiofile(str(audio_path), verbose=False, logger=None)
|
audio.write_audiofile(str(audio_path), verbose=False, logger=None)
|
||||||
audio.close()
|
audio.close()
|
||||||
|
_temp_audio_files.append(str(audio_path))
|
||||||
return audio_path
|
return audio_path
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Audio extraction failed: {e}")
|
raise RuntimeError(f"Audio extraction failed: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_temp_audio():
|
||||||
|
"""Remove all temporary audio files created during processing."""
|
||||||
|
cleaned = 0
|
||||||
|
for fpath in _temp_audio_files:
|
||||||
|
try:
|
||||||
|
if os.path.exists(fpath):
|
||||||
|
os.remove(fpath)
|
||||||
|
parent = os.path.dirname(fpath)
|
||||||
|
if os.path.isdir(parent) and not os.listdir(parent):
|
||||||
|
os.rmdir(parent)
|
||||||
|
cleaned += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not remove temp file {fpath}: {e}")
|
||||||
|
_temp_audio_files.clear()
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
def get_video_duration(video_path: Path):
|
||||||
|
"""Get duration of a video/audio file in seconds."""
|
||||||
|
try:
|
||||||
|
clip = AudioFileClip(str(video_path))
|
||||||
|
duration = clip.duration
|
||||||
|
clip.close()
|
||||||
|
return duration
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Speaker diarization utilities for the OBS Recording Transcriber.
|
Speaker diarization utilities for the Video Transcriber.
|
||||||
Provides functions to identify different speakers in audio recordings.
|
Provides functions to identify different speakers in audio recordings.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -11,22 +11,34 @@ import torch
|
|||||||
from pyannote.audio import Pipeline
|
from pyannote.audio import Pipeline
|
||||||
from pyannote.core import Segment
|
from pyannote.core import Segment
|
||||||
import whisper
|
import whisper
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Try to import GPU utilities, but don't fail if not available
|
|
||||||
try:
|
try:
|
||||||
from utils.gpu_utils import get_optimal_device
|
from utils.gpu_utils import get_optimal_device
|
||||||
GPU_UTILS_AVAILABLE = True
|
GPU_UTILS_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
GPU_UTILS_AVAILABLE = False
|
GPU_UTILS_AVAILABLE = False
|
||||||
|
|
||||||
# Default HuggingFace auth token environment variable
|
|
||||||
HF_TOKEN_ENV = "HF_TOKEN"
|
HF_TOKEN_ENV = "HF_TOKEN"
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache_resource
|
||||||
|
def _load_diarization_pipeline(hf_token, device_str):
|
||||||
|
"""Load and cache the speaker diarization pipeline."""
|
||||||
|
logger.info(f"Loading diarization pipeline on {device_str}")
|
||||||
|
pipe = Pipeline.from_pretrained(
|
||||||
|
"pyannote/speaker-diarization-3.0",
|
||||||
|
use_auth_token=hf_token
|
||||||
|
)
|
||||||
|
device = torch.device(device_str)
|
||||||
|
if device.type == "cuda":
|
||||||
|
pipe = pipe.to(device)
|
||||||
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
def get_diarization_pipeline(use_gpu=True, hf_token=None):
|
def get_diarization_pipeline(use_gpu=True, hf_token=None):
|
||||||
"""
|
"""
|
||||||
Initialize the speaker diarization pipeline.
|
Initialize the speaker diarization pipeline.
|
||||||
@ -38,7 +50,6 @@ def get_diarization_pipeline(use_gpu=True, hf_token=None):
|
|||||||
Returns:
|
Returns:
|
||||||
Pipeline or None: Diarization pipeline if successful, None otherwise
|
Pipeline or None: Diarization pipeline if successful, None otherwise
|
||||||
"""
|
"""
|
||||||
# Check if token is provided or in environment
|
|
||||||
if hf_token is None:
|
if hf_token is None:
|
||||||
hf_token = os.environ.get(HF_TOKEN_ENV)
|
hf_token = os.environ.get(HF_TOKEN_ENV)
|
||||||
if hf_token is None:
|
if hf_token is None:
|
||||||
@ -46,23 +57,12 @@ def get_diarization_pipeline(use_gpu=True, hf_token=None):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Configure device
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if use_gpu and GPU_UTILS_AVAILABLE:
|
if use_gpu and GPU_UTILS_AVAILABLE:
|
||||||
device = get_optimal_device()
|
device = get_optimal_device()
|
||||||
logger.info(f"Using device: {device} for diarization")
|
logger.info(f"Using device: {device} for diarization")
|
||||||
|
|
||||||
# Initialize the pipeline
|
return _load_diarization_pipeline(hf_token, str(device))
|
||||||
pipeline = Pipeline.from_pretrained(
|
|
||||||
"pyannote/speaker-diarization-3.0",
|
|
||||||
use_auth_token=hf_token
|
|
||||||
)
|
|
||||||
|
|
||||||
# Move to appropriate device
|
|
||||||
if device.type == "cuda":
|
|
||||||
pipeline = pipeline.to(torch.device(device))
|
|
||||||
|
|
||||||
return pipeline
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error initializing diarization pipeline: {e}")
|
logger.error(f"Error initializing diarization pipeline: {e}")
|
||||||
return None
|
return None
|
||||||
@ -198,9 +198,9 @@ def transcribe_with_diarization(audio_path, whisper_model="base", num_speakers=N
|
|||||||
device = get_optimal_device()
|
device = get_optimal_device()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Step 1: Transcribe audio with Whisper
|
from utils.transcription import _load_whisper_model
|
||||||
logger.info(f"Transcribing audio with Whisper model: {whisper_model}")
|
logger.info(f"Transcribing audio with Whisper model: {whisper_model}")
|
||||||
model = whisper.load_model(whisper_model, device=device if device.type != "mps" else "cpu")
|
model = _load_whisper_model(whisper_model, str(device))
|
||||||
result = model.transcribe(str(audio_path))
|
result = model.transcribe(str(audio_path))
|
||||||
transcript_segments = result["segments"]
|
transcript_segments = result["segments"]
|
||||||
|
|
||||||
|
|||||||
@ -1,12 +1,9 @@
|
|||||||
"""
|
"""
|
||||||
GPU utilities for the OBS Recording Transcriber.
|
GPU utilities for the Video Transcriber.
|
||||||
Provides functions to detect and configure GPU acceleration.
|
Provides functions to detect and configure GPU acceleration.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
import subprocess
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
@ -68,8 +65,6 @@ def get_optimal_device():
|
|||||||
|
|
||||||
|
|
||||||
def set_memory_limits(memory_fraction=0.8):
|
def set_memory_limits(memory_fraction=0.8):
|
||||||
global torch
|
|
||||||
import torch
|
|
||||||
"""
|
"""
|
||||||
Set memory limits for GPU usage.
|
Set memory limits for GPU usage.
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
"""
|
"""
|
||||||
Keyword extraction utilities for the OBS Recording Transcriber.
|
Keyword extraction utilities for the Video Transcriber.
|
||||||
Provides functions to extract keywords and link them to timestamps.
|
Provides functions to extract keywords and link them to timestamps.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -8,25 +8,30 @@ import re
|
|||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
|
from transformers import pipeline
|
||||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Try to import GPU utilities, but don't fail if not available
|
|
||||||
try:
|
try:
|
||||||
from utils.gpu_utils import get_optimal_device
|
from utils.gpu_utils import get_optimal_device
|
||||||
GPU_UTILS_AVAILABLE = True
|
GPU_UTILS_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
GPU_UTILS_AVAILABLE = False
|
GPU_UTILS_AVAILABLE = False
|
||||||
|
|
||||||
# Default models
|
|
||||||
NER_MODEL = "dslim/bert-base-NER"
|
NER_MODEL = "dslim/bert-base-NER"
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache_resource
|
||||||
|
def _load_ner_pipeline(model_name, device_int):
|
||||||
|
"""Load and cache the NER pipeline."""
|
||||||
|
logger.info(f"Loading NER model: {model_name}")
|
||||||
|
return pipeline("ner", model=model_name, device=device_int, aggregation_strategy="simple")
|
||||||
|
|
||||||
|
|
||||||
def extract_keywords_tfidf(text, max_keywords=10, ngram_range=(1, 2)):
|
def extract_keywords_tfidf(text, max_keywords=10, ngram_range=(1, 2)):
|
||||||
"""
|
"""
|
||||||
Extract keywords using TF-IDF.
|
Extract keywords using TF-IDF.
|
||||||
@ -107,8 +112,7 @@ def extract_named_entities(text, model=NER_MODEL, use_gpu=True):
|
|||||||
device_arg = -1
|
device_arg = -1
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize the pipeline
|
ner_pipeline = _load_ner_pipeline(model, device_arg)
|
||||||
ner_pipeline = pipeline("ner", model=model, device=device_arg, aggregation_strategy="simple")
|
|
||||||
|
|
||||||
# Split text into manageable chunks if too long
|
# Split text into manageable chunks if too long
|
||||||
max_length = 512
|
max_length = 512
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
"""
|
"""
|
||||||
Ollama integration for local AI model inference.
|
Ollama integration for local AI model inference.
|
||||||
Provides functions to use Ollama's API for text summarization.
|
Provides functions to use Ollama's API for text summarization with streaming support.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@ -9,21 +9,14 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import os
|
import os
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Default Ollama API endpoint - configurable via environment variable
|
|
||||||
OLLAMA_API_URL = os.environ.get("OLLAMA_API_URL", "http://localhost:11434/api")
|
OLLAMA_API_URL = os.environ.get("OLLAMA_API_URL", "http://localhost:11434/api")
|
||||||
|
|
||||||
|
|
||||||
def check_ollama_available():
|
def check_ollama_available():
|
||||||
"""
|
"""Check if Ollama service is available."""
|
||||||
Check if Ollama service is available.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True if Ollama is available, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(f"{OLLAMA_API_URL}/tags", timeout=2)
|
response = requests.get(f"{OLLAMA_API_URL}/tags", timeout=2)
|
||||||
return response.status_code == 200
|
return response.status_code == 200
|
||||||
@ -32,12 +25,7 @@ def check_ollama_available():
|
|||||||
|
|
||||||
|
|
||||||
def list_available_models():
|
def list_available_models():
|
||||||
"""
|
"""List available models in Ollama."""
|
||||||
List available models in Ollama.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: List of available model names
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
response = requests.get(f"{OLLAMA_API_URL}/tags")
|
response = requests.get(f"{OLLAMA_API_URL}/tags")
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
@ -50,32 +38,14 @@ def list_available_models():
|
|||||||
|
|
||||||
|
|
||||||
def summarize_with_ollama(text, model="llama3", max_length=150):
|
def summarize_with_ollama(text, model="llama3", max_length=150):
|
||||||
"""
|
"""Summarize text using Ollama's local API (non-streaming)."""
|
||||||
Summarize text using Ollama's local API.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text (str): Text to summarize
|
|
||||||
model (str): Ollama model to use
|
|
||||||
max_length (int): Maximum length of the summary
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Summarized text or None if failed
|
|
||||||
"""
|
|
||||||
if not check_ollama_available():
|
if not check_ollama_available():
|
||||||
logger.warning("Ollama service is not available")
|
logger.warning("Ollama service is not available")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Check if the model is available
|
|
||||||
available_models = list_available_models()
|
|
||||||
if model not in available_models:
|
|
||||||
logger.warning(f"Model {model} not available in Ollama. Available models: {available_models}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Prepare the prompt for summarization
|
|
||||||
prompt = f"Summarize the following text in about {max_length} words:\n\n{text}"
|
prompt = f"Summarize the following text in about {max_length} words:\n\n{text}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Make the API request
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{OLLAMA_API_URL}/generate",
|
f"{OLLAMA_API_URL}/generate",
|
||||||
json={
|
json={
|
||||||
@ -85,7 +55,7 @@ def summarize_with_ollama(text, model="llama3", max_length=150):
|
|||||||
"options": {
|
"options": {
|
||||||
"temperature": 0.3,
|
"temperature": 0.3,
|
||||||
"top_p": 0.9,
|
"top_p": 0.9,
|
||||||
"max_tokens": max_length * 2 # Approximate token count
|
"max_tokens": max_length * 2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@ -101,23 +71,55 @@ def summarize_with_ollama(text, model="llama3", max_length=150):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def stream_summarize_with_ollama(text, model="llama3", max_length=150):
|
||||||
|
"""
|
||||||
|
Summarize text using Ollama with streaming. Yields tokens as they arrive.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
str: Individual response tokens
|
||||||
|
"""
|
||||||
|
if not check_ollama_available():
|
||||||
|
logger.warning("Ollama service is not available")
|
||||||
|
return
|
||||||
|
|
||||||
|
prompt = f"Summarize the following text in about {max_length} words:\n\n{text}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
|
f"{OLLAMA_API_URL}/generate",
|
||||||
|
json={
|
||||||
|
"model": model,
|
||||||
|
"prompt": prompt,
|
||||||
|
"stream": True,
|
||||||
|
"options": {
|
||||||
|
"temperature": 0.3,
|
||||||
|
"top_p": 0.9,
|
||||||
|
"max_tokens": max_length * 2
|
||||||
|
}
|
||||||
|
},
|
||||||
|
stream=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
for line in response.iter_lines():
|
||||||
|
if line:
|
||||||
|
data = json.loads(line)
|
||||||
|
token = data.get('response', '')
|
||||||
|
if token:
|
||||||
|
yield token
|
||||||
|
if data.get('done', False):
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
logger.error(f"Ollama API error: {response.status_code}")
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"Error communicating with Ollama: {e}")
|
||||||
|
|
||||||
|
|
||||||
def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150):
|
def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150):
|
||||||
"""
|
"""Chunk long text and summarize each chunk, then combine."""
|
||||||
Chunk long text and summarize each chunk, then combine the summaries.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text (str): Text to summarize
|
|
||||||
model (str): Ollama model to use
|
|
||||||
chunk_size (int): Maximum size of each chunk in characters
|
|
||||||
max_length (int): Maximum length of the final summary
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Combined summary or None if failed
|
|
||||||
"""
|
|
||||||
if len(text) <= chunk_size:
|
if len(text) <= chunk_size:
|
||||||
return summarize_with_ollama(text, model, max_length)
|
return summarize_with_ollama(text, model, max_length)
|
||||||
|
|
||||||
# Split text into chunks
|
|
||||||
words = text.split()
|
words = text.split()
|
||||||
chunks = []
|
chunks = []
|
||||||
current_chunk = []
|
current_chunk = []
|
||||||
@ -135,7 +137,6 @@ def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150):
|
|||||||
if current_chunk:
|
if current_chunk:
|
||||||
chunks.append(' '.join(current_chunk))
|
chunks.append(' '.join(current_chunk))
|
||||||
|
|
||||||
# Summarize each chunk
|
|
||||||
chunk_summaries = []
|
chunk_summaries = []
|
||||||
for i, chunk in enumerate(chunks):
|
for i, chunk in enumerate(chunks):
|
||||||
logger.info(f"Summarizing chunk {i+1}/{len(chunks)}")
|
logger.info(f"Summarizing chunk {i+1}/{len(chunks)}")
|
||||||
@ -146,10 +147,55 @@ def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150):
|
|||||||
if not chunk_summaries:
|
if not chunk_summaries:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# If there's only one chunk summary, return it
|
|
||||||
if len(chunk_summaries) == 1:
|
if len(chunk_summaries) == 1:
|
||||||
return chunk_summaries[0]
|
return chunk_summaries[0]
|
||||||
|
|
||||||
# Otherwise, combine the summaries and summarize again
|
|
||||||
combined_summary = " ".join(chunk_summaries)
|
combined_summary = " ".join(chunk_summaries)
|
||||||
return summarize_with_ollama(combined_summary, model, max_length)
|
return summarize_with_ollama(combined_summary, model, max_length)
|
||||||
|
|
||||||
|
|
||||||
|
def stream_chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150):
|
||||||
|
"""
|
||||||
|
Chunk and summarize with streaming on the final summary.
|
||||||
|
Returns non-streaming chunk summaries, then streams the final combination.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
str: Tokens from the final summary
|
||||||
|
"""
|
||||||
|
if len(text) <= chunk_size:
|
||||||
|
yield from stream_summarize_with_ollama(text, model, max_length)
|
||||||
|
return
|
||||||
|
|
||||||
|
words = text.split()
|
||||||
|
chunks = []
|
||||||
|
current_chunk = []
|
||||||
|
current_length = 0
|
||||||
|
|
||||||
|
for word in words:
|
||||||
|
if current_length + len(word) + 1 <= chunk_size:
|
||||||
|
current_chunk.append(word)
|
||||||
|
current_length += len(word) + 1
|
||||||
|
else:
|
||||||
|
chunks.append(' '.join(current_chunk))
|
||||||
|
current_chunk = [word]
|
||||||
|
current_length = len(word) + 1
|
||||||
|
|
||||||
|
if current_chunk:
|
||||||
|
chunks.append(' '.join(current_chunk))
|
||||||
|
|
||||||
|
chunk_summaries = []
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
logger.info(f"Summarizing chunk {i+1}/{len(chunks)}")
|
||||||
|
summary = summarize_with_ollama(chunk, model, max_length // len(chunks))
|
||||||
|
if summary:
|
||||||
|
chunk_summaries.append(summary)
|
||||||
|
|
||||||
|
if not chunk_summaries:
|
||||||
|
return
|
||||||
|
|
||||||
|
if len(chunk_summaries) == 1:
|
||||||
|
yield chunk_summaries[0]
|
||||||
|
return
|
||||||
|
|
||||||
|
combined_summary = " ".join(chunk_summaries)
|
||||||
|
yield from stream_summarize_with_ollama(combined_summary, model, max_length)
|
||||||
@ -1,45 +1,49 @@
|
|||||||
from transformers import pipeline, AutoTokenizer
|
from transformers import pipeline, AutoTokenizer
|
||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
SUMMARY_MODEL = "Falconsai/text_summarization"
|
SUMMARY_MODEL = "Falconsai/text_summarization"
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache_resource
|
||||||
|
def _load_summarizer(device_int):
|
||||||
|
"""Load and cache the summarization pipeline."""
|
||||||
|
logger.info(f"Loading summarization model on device {device_int}")
|
||||||
|
return pipeline("summarization", model=SUMMARY_MODEL, device=device_int)
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache_resource
|
||||||
|
def _load_summary_tokenizer():
|
||||||
|
"""Load and cache the summarization tokenizer."""
|
||||||
|
return AutoTokenizer.from_pretrained(SUMMARY_MODEL)
|
||||||
|
|
||||||
|
|
||||||
def chunk_text(text, max_tokens, tokenizer):
|
def chunk_text(text, max_tokens, tokenizer):
|
||||||
"""
|
"""
|
||||||
Splits the text into a list of chunks based on token limits.
|
Splits text into chunks by tokenizing once, then splitting by token windows.
|
||||||
|
Much faster than the per-word tokenization approach.
|
||||||
Args:
|
|
||||||
text (str): Text to chunk
|
|
||||||
max_tokens (int): Maximum tokens per chunk
|
|
||||||
tokenizer (AutoTokenizer): Tokenizer to use
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: List of text chunks
|
|
||||||
"""
|
"""
|
||||||
words = text.split()
|
all_ids = tokenizer(text, return_tensors='pt', truncation=False)['input_ids'][0]
|
||||||
|
content_ids = all_ids[1:-1] # strip BOS/EOS
|
||||||
|
usable_max = max_tokens - 2 # leave room for special tokens
|
||||||
|
|
||||||
chunks = []
|
chunks = []
|
||||||
current_chunk = []
|
for i in range(0, len(content_ids), usable_max):
|
||||||
current_length = 0
|
chunk_ids = content_ids[i : i + usable_max]
|
||||||
|
decoded = tokenizer.decode(chunk_ids, skip_special_tokens=True).strip()
|
||||||
|
if decoded:
|
||||||
|
chunks.append(decoded)
|
||||||
|
|
||||||
for word in words:
|
if not chunks:
|
||||||
hypothetical_length = current_length + len(tokenizer(word, return_tensors='pt')['input_ids'][0]) - 2
|
chunks.append(text)
|
||||||
if hypothetical_length <= max_tokens:
|
|
||||||
current_chunk.append(word)
|
|
||||||
current_length = hypothetical_length
|
|
||||||
else:
|
|
||||||
chunks.append(' '.join(current_chunk))
|
|
||||||
current_chunk = [word]
|
|
||||||
current_length = len(tokenizer(word, return_tensors='pt')['input_ids'][0]) - 2
|
|
||||||
|
|
||||||
if current_chunk:
|
|
||||||
chunks.append(' '.join(current_chunk))
|
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
def summarize_text(text, use_gpu=True, memory_fraction=0.8):
|
def summarize_text(text, use_gpu=True, memory_fraction=0.8):
|
||||||
"""
|
"""
|
||||||
Summarize text using a Hugging Face pipeline with chunking support.
|
Summarize text using a Hugging Face pipeline with chunking support.
|
||||||
@ -52,21 +56,17 @@ def summarize_text(text, use_gpu=True, memory_fraction=0.8):
|
|||||||
Returns:
|
Returns:
|
||||||
str: Summarized text
|
str: Summarized text
|
||||||
"""
|
"""
|
||||||
# Determine device
|
device = -1
|
||||||
device = -1 # Default to CPU
|
|
||||||
if use_gpu and torch.cuda.is_available():
|
if use_gpu and torch.cuda.is_available():
|
||||||
device = 0 # Use first GPU
|
device = 0
|
||||||
if torch.cuda.is_available():
|
torch.cuda.set_per_process_memory_fraction(memory_fraction)
|
||||||
torch.cuda.set_per_process_memory_fraction(memory_fraction)
|
|
||||||
|
|
||||||
logger.info(f"Using device {device} for summarization")
|
logger.info(f"Using device {device} for summarization")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize the pipeline and tokenizer
|
summarizer = _load_summarizer(device)
|
||||||
summarizer = pipeline("summarization", model=SUMMARY_MODEL, device=device)
|
tokenizer = _load_summary_tokenizer()
|
||||||
tokenizer = AutoTokenizer.from_pretrained(SUMMARY_MODEL)
|
|
||||||
|
|
||||||
# Check if text needs to be chunked
|
|
||||||
max_tokens = 512
|
max_tokens = 512
|
||||||
tokens = tokenizer(text, return_tensors='pt')
|
tokens = tokenizer(text, return_tensors='pt')
|
||||||
num_tokens = len(tokens['input_ids'][0])
|
num_tokens = len(tokens['input_ids'][0])
|
||||||
@ -85,7 +85,6 @@ def summarize_text(text, use_gpu=True, memory_fraction=0.8):
|
|||||||
)
|
)
|
||||||
summaries.append(summary_output[0]['summary_text'])
|
summaries.append(summary_output[0]['summary_text'])
|
||||||
|
|
||||||
# If multiple chunks, summarize the combined summaries
|
|
||||||
if len(summaries) > 1:
|
if len(summaries) > 1:
|
||||||
logger.info("Generating final summary from chunk summaries")
|
logger.info("Generating final summary from chunk summaries")
|
||||||
combined_text = " ".join(summaries)
|
combined_text = " ".join(summaries)
|
||||||
@ -106,7 +105,6 @@ def summarize_text(text, use_gpu=True, memory_fraction=0.8):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error during summarization: {e}")
|
logger.error(f"Error during summarization: {e}")
|
||||||
# Fallback to CPU if GPU fails
|
|
||||||
if device != -1:
|
if device != -1:
|
||||||
logger.info("Falling back to CPU")
|
logger.info("Falling back to CPU")
|
||||||
return summarize_text(text, use_gpu=False, memory_fraction=memory_fraction)
|
return summarize_text(text, use_gpu=False, memory_fraction=memory_fraction)
|
||||||
|
|||||||
@ -1,31 +1,36 @@
|
|||||||
import whisper
|
import whisper
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from transformers import pipeline, AutoTokenizer
|
|
||||||
from utils.audio_processing import extract_audio
|
from utils.audio_processing import extract_audio
|
||||||
from utils.summarization import summarize_text
|
|
||||||
import logging
|
import logging
|
||||||
import torch
|
import torch
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
# Try to import GPU utilities, but don't fail if not available
|
|
||||||
try:
|
try:
|
||||||
from utils.gpu_utils import configure_gpu, get_optimal_device
|
from utils.gpu_utils import configure_gpu, get_optimal_device
|
||||||
GPU_UTILS_AVAILABLE = True
|
GPU_UTILS_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
GPU_UTILS_AVAILABLE = False
|
GPU_UTILS_AVAILABLE = False
|
||||||
|
|
||||||
# Try to import caching utilities, but don't fail if not available
|
|
||||||
try:
|
try:
|
||||||
from utils.cache import load_from_cache, save_to_cache
|
from utils.cache import load_from_cache, save_to_cache
|
||||||
CACHE_AVAILABLE = True
|
CACHE_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
CACHE_AVAILABLE = False
|
CACHE_AVAILABLE = False
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
WHISPER_MODEL = "base"
|
WHISPER_MODEL = "base"
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache_resource
|
||||||
|
def _load_whisper_model(model_name, device_str):
|
||||||
|
"""Load and cache a Whisper model. Cached across reruns."""
|
||||||
|
logger.info(f"Loading Whisper model: {model_name} on {device_str}")
|
||||||
|
device = torch.device(device_str)
|
||||||
|
return whisper.load_model(model_name, device=device if device.type != "mps" else "cpu")
|
||||||
|
|
||||||
|
|
||||||
def transcribe_audio(audio_path: Path, model=WHISPER_MODEL, use_cache=True, cache_max_age=None,
|
def transcribe_audio(audio_path: Path, model=WHISPER_MODEL, use_cache=True, cache_max_age=None,
|
||||||
use_gpu=True, memory_fraction=0.8):
|
use_gpu=True, memory_fraction=0.8):
|
||||||
"""
|
"""
|
||||||
@ -44,38 +49,30 @@ def transcribe_audio(audio_path: Path, model=WHISPER_MODEL, use_cache=True, cach
|
|||||||
"""
|
"""
|
||||||
audio_path = Path(audio_path)
|
audio_path = Path(audio_path)
|
||||||
|
|
||||||
# Check cache first if enabled
|
|
||||||
if use_cache and CACHE_AVAILABLE:
|
if use_cache and CACHE_AVAILABLE:
|
||||||
cached_data = load_from_cache(audio_path, model, "transcribe", cache_max_age)
|
cached_data = load_from_cache(audio_path, model, "transcribe", cache_max_age)
|
||||||
if cached_data:
|
if cached_data:
|
||||||
logger.info(f"Using cached transcription for {audio_path}")
|
logger.info(f"Using cached transcription for {audio_path}")
|
||||||
return cached_data.get("segments", []), cached_data.get("transcript", "")
|
return cached_data.get("segments", []), cached_data.get("transcript", "")
|
||||||
|
|
||||||
# Extract audio if the input is a video file (M4A is already audio)
|
|
||||||
video_extensions = ['.mp4', '.avi', '.mov', '.mkv']
|
video_extensions = ['.mp4', '.avi', '.mov', '.mkv']
|
||||||
if audio_path.suffix.lower() in video_extensions:
|
if audio_path.suffix.lower() in video_extensions:
|
||||||
audio_path = extract_audio(audio_path)
|
audio_path = extract_audio(audio_path)
|
||||||
|
|
||||||
# Configure GPU if available and requested
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if use_gpu and GPU_UTILS_AVAILABLE:
|
if use_gpu and GPU_UTILS_AVAILABLE:
|
||||||
gpu_config = configure_gpu(model, memory_fraction)
|
gpu_config = configure_gpu(model, memory_fraction)
|
||||||
device = gpu_config["device"]
|
device = gpu_config["device"]
|
||||||
logger.info(f"Using device: {device} for transcription")
|
logger.info(f"Using device: {device} for transcription")
|
||||||
|
|
||||||
# Load the specified Whisper model
|
whisper_model = _load_whisper_model(model, str(device))
|
||||||
logger.info(f"Loading Whisper model: {model}")
|
|
||||||
whisper_model = whisper.load_model(model, device=device if device.type != "mps" else "cpu")
|
|
||||||
|
|
||||||
# Transcribe the audio
|
|
||||||
logger.info(f"Transcribing audio: {audio_path}")
|
logger.info(f"Transcribing audio: {audio_path}")
|
||||||
result = whisper_model.transcribe(str(audio_path))
|
result = whisper_model.transcribe(str(audio_path))
|
||||||
|
|
||||||
# Extract the full transcript and segments
|
|
||||||
transcript = result["text"]
|
transcript = result["text"]
|
||||||
segments = result["segments"]
|
segments = result["segments"]
|
||||||
|
|
||||||
# Cache the results if caching is enabled
|
|
||||||
if use_cache and CACHE_AVAILABLE:
|
if use_cache and CACHE_AVAILABLE:
|
||||||
cache_data = {
|
cache_data = {
|
||||||
"transcript": transcript,
|
"transcript": transcript,
|
||||||
|
|||||||
@ -1,41 +1,49 @@
|
|||||||
"""
|
"""
|
||||||
Translation utilities for the OBS Recording Transcriber.
|
Translation utilities for the Video Transcriber.
|
||||||
Provides functions for language detection and translation.
|
Provides functions for language detection and translation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import torch
|
import torch
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, M2M100ForConditionalGeneration
|
from transformers import pipeline, AutoTokenizer, M2M100ForConditionalGeneration
|
||||||
import whisper
|
import whisper
|
||||||
import iso639
|
import iso639
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Try to import GPU utilities, but don't fail if not available
|
|
||||||
try:
|
try:
|
||||||
from utils.gpu_utils import get_optimal_device
|
from utils.gpu_utils import get_optimal_device
|
||||||
GPU_UTILS_AVAILABLE = True
|
GPU_UTILS_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
GPU_UTILS_AVAILABLE = False
|
GPU_UTILS_AVAILABLE = False
|
||||||
|
|
||||||
# Default models
|
|
||||||
TRANSLATION_MODEL = "facebook/m2m100_418M"
|
TRANSLATION_MODEL = "facebook/m2m100_418M"
|
||||||
LANGUAGE_DETECTION_MODEL = "papluca/xlm-roberta-base-language-detection"
|
LANGUAGE_DETECTION_MODEL = "papluca/xlm-roberta-base-language-detection"
|
||||||
|
|
||||||
# ISO language code mapping
|
|
||||||
|
@st.cache_resource
def _load_language_detector(model_name, device_int):
    """Build a text-classification pipeline for language detection.

    Memoized with Streamlit's resource cache so the model is loaded at most
    once per (model_name, device_int) pair for the app's lifetime.

    Args:
        model_name: HuggingFace model identifier.
        device_int: transformers device index (-1 for CPU, >=0 for a GPU).

    Returns:
        A transformers text-classification pipeline.
    """
    logger.info(f"Loading language detection model: {model_name}")
    detector = pipeline("text-classification", model=model_name, device=device_int)
    return detector
|
||||||
|
|
||||||
|
|
||||||
|
@st.cache_resource
def _load_translation_model(model_name, device_str):
    """Fetch the M2M100 model/tokenizer pair and move the model to a device.

    Memoized with Streamlit's resource cache, so repeated calls with the same
    arguments reuse the already-loaded weights.

    Args:
        model_name: HuggingFace M2M100 checkpoint name.
        device_str: torch device string (e.g. "cpu", "cuda").

    Returns:
        tuple: (model placed on the target device, tokenizer).
    """
    logger.info(f"Loading translation model: {model_name} on {device_str}")
    target = torch.device(device_str)
    tok = AutoTokenizer.from_pretrained(model_name)
    m2m = M2M100ForConditionalGeneration.from_pretrained(model_name).to(target)
    return m2m, tok
|
||||||
|
|
||||||
|
|
||||||
def get_language_name(code):
|
def get_language_name(code):
|
||||||
"""
|
"""Get the language name from ISO code."""
|
||||||
Get the language name from ISO code.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
code (str): ISO language code
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: Language name or original code if not found
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
return iso639.languages.get(part1=code).name
|
return iso639.languages.get(part1=code).name
|
||||||
except (KeyError, AttributeError):
|
except (KeyError, AttributeError):
|
||||||
@ -57,7 +65,6 @@ def detect_language(text, model=LANGUAGE_DETECTION_MODEL, use_gpu=True):
|
|||||||
Returns:
|
Returns:
|
||||||
tuple: (language_code, confidence)
|
tuple: (language_code, confidence)
|
||||||
"""
|
"""
|
||||||
# Configure device
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if use_gpu and GPU_UTILS_AVAILABLE:
|
if use_gpu and GPU_UTILS_AVAILABLE:
|
||||||
device = get_optimal_device()
|
device = get_optimal_device()
|
||||||
@ -66,25 +73,43 @@ def detect_language(text, model=LANGUAGE_DETECTION_MODEL, use_gpu=True):
|
|||||||
device_arg = -1
|
device_arg = -1
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Initialize the pipeline
|
classifier = _load_language_detector(model, device_arg)
|
||||||
classifier = pipeline("text-classification", model=model, device=device_arg)
|
|
||||||
|
|
||||||
# Truncate text if too long
|
|
||||||
max_length = 512
|
max_length = 512
|
||||||
if len(text) > max_length:
|
if len(text) > max_length:
|
||||||
text = text[:max_length]
|
text = text[:max_length]
|
||||||
|
|
||||||
# Detect language
|
|
||||||
result = classifier(text)[0]
|
result = classifier(text)[0]
|
||||||
language_code = result["label"]
|
return result["label"], result["score"]
|
||||||
confidence = result["score"]
|
|
||||||
|
|
||||||
return language_code, confidence
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error detecting language: {e}")
|
logger.error(f"Error detecting language: {e}")
|
||||||
return None, 0.0
|
return None, 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def _translate_text_with_model(text, source_lang, target_lang, trans_model, tokenizer, device):
    """Translate *text* with an already-loaded M2M100 model and tokenizer.

    Long inputs are split into 512-character pieces that are translated
    independently and re-joined with single spaces.
    NOTE(review): the split is character-based, not token-based, and may cut
    words mid-way — presumably acceptable for this model; confirm if
    translation quality at chunk boundaries matters.

    Args:
        text: source-language text to translate.
        source_lang: ISO code of the source language (set on the tokenizer).
        target_lang: ISO code of the target language.
        trans_model: pre-loaded M2M100ForConditionalGeneration instance.
        tokenizer: matching M2M100 tokenizer.
        device: torch device the model lives on; inputs are moved there.

    Returns:
        str: translated text.
    """
    tokenizer.src_lang = source_lang

    limit = 512
    if len(text) <= limit:
        pieces = [text]
    else:
        pieces = [text[pos:pos + limit] for pos in range(0, len(text), limit)]

    outputs = []
    for piece in pieces:
        batch = tokenizer(piece, return_tensors="pt").to(device)
        tokens = trans_model.generate(
            **batch,
            forced_bos_token_id=tokenizer.get_lang_id(target_lang),
            max_length=limit,
        )
        outputs.append(tokenizer.batch_decode(tokens, skip_special_tokens=True)[0])

    return " ".join(outputs)
|
||||||
|
|
||||||
|
|
||||||
def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_MODEL, use_gpu=True):
|
def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_MODEL, use_gpu=True):
|
||||||
"""
|
"""
|
||||||
Translate text from source language to target language.
|
Translate text from source language to target language.
|
||||||
@ -99,7 +124,6 @@ def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_M
|
|||||||
Returns:
|
Returns:
|
||||||
str: Translated text
|
str: Translated text
|
||||||
"""
|
"""
|
||||||
# Auto-detect source language if not provided
|
|
||||||
if source_lang is None:
|
if source_lang is None:
|
||||||
detected_lang, confidence = detect_language(text, use_gpu=use_gpu)
|
detected_lang, confidence = detect_language(text, use_gpu=use_gpu)
|
||||||
if detected_lang and confidence > 0.5:
|
if detected_lang and confidence > 0.5:
|
||||||
@ -109,50 +133,17 @@ def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_M
|
|||||||
logger.warning("Could not reliably detect language, defaulting to English")
|
logger.warning("Could not reliably detect language, defaulting to English")
|
||||||
source_lang = "en"
|
source_lang = "en"
|
||||||
|
|
||||||
# Skip translation if source and target are the same
|
|
||||||
if source_lang == target_lang:
|
if source_lang == target_lang:
|
||||||
logger.info(f"Source and target languages are the same ({source_lang}), skipping translation")
|
logger.info(f"Source and target languages are the same ({source_lang}), skipping translation")
|
||||||
return text
|
return text
|
||||||
|
|
||||||
# Configure device
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if use_gpu and GPU_UTILS_AVAILABLE:
|
if use_gpu and GPU_UTILS_AVAILABLE:
|
||||||
device = get_optimal_device()
|
device = get_optimal_device()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Load model and tokenizer
|
trans_model, tokenizer = _load_translation_model(model, str(device))
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
return _translate_text_with_model(text, source_lang, target_lang, trans_model, tokenizer, device)
|
||||||
model = M2M100ForConditionalGeneration.from_pretrained(model)
|
|
||||||
|
|
||||||
# Move model to device
|
|
||||||
model = model.to(device)
|
|
||||||
|
|
||||||
# Prepare for translation
|
|
||||||
tokenizer.src_lang = source_lang
|
|
||||||
|
|
||||||
# Split text into manageable chunks if too long
|
|
||||||
max_length = 512
|
|
||||||
if len(text) > max_length:
|
|
||||||
chunks = [text[i:i+max_length] for i in range(0, len(text), max_length)]
|
|
||||||
else:
|
|
||||||
chunks = [text]
|
|
||||||
|
|
||||||
# Translate each chunk
|
|
||||||
translated_chunks = []
|
|
||||||
for chunk in chunks:
|
|
||||||
encoded = tokenizer(chunk, return_tensors="pt").to(device)
|
|
||||||
generated_tokens = model.generate(
|
|
||||||
**encoded,
|
|
||||||
forced_bos_token_id=tokenizer.get_lang_id(target_lang),
|
|
||||||
max_length=max_length
|
|
||||||
)
|
|
||||||
translated_chunk = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
|
|
||||||
translated_chunks.append(translated_chunk)
|
|
||||||
|
|
||||||
# Combine translated chunks
|
|
||||||
translated_text = " ".join(translated_chunks)
|
|
||||||
|
|
||||||
return translated_text
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error translating text: {e}")
|
logger.error(f"Error translating text: {e}")
|
||||||
return text
|
return text
|
||||||
@ -160,7 +151,7 @@ def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_M
|
|||||||
|
|
||||||
def translate_segments(segments, source_lang=None, target_lang="en", use_gpu=True):
|
def translate_segments(segments, source_lang=None, target_lang="en", use_gpu=True):
|
||||||
"""
|
"""
|
||||||
Translate transcript segments.
|
Translate transcript segments. Loads the model once and reuses for all segments.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
segments (list): List of transcript segments
|
segments (list): List of transcript segments
|
||||||
@ -174,36 +165,32 @@ def translate_segments(segments, source_lang=None, target_lang="en", use_gpu=Tru
|
|||||||
if not segments:
|
if not segments:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# Auto-detect source language from combined text if not provided
|
|
||||||
if source_lang is None:
|
if source_lang is None:
|
||||||
combined_text = " ".join([segment["text"] for segment in segments])
|
combined_text = " ".join([segment["text"] for segment in segments])
|
||||||
detected_lang, _ = detect_language(combined_text, use_gpu=use_gpu)
|
detected_lang, _ = detect_language(combined_text, use_gpu=use_gpu)
|
||||||
source_lang = detected_lang if detected_lang else "en"
|
source_lang = detected_lang if detected_lang else "en"
|
||||||
|
|
||||||
# Skip translation if source and target are the same
|
|
||||||
if source_lang == target_lang:
|
if source_lang == target_lang:
|
||||||
return segments
|
return segments
|
||||||
|
|
||||||
try:
|
device = torch.device("cpu")
|
||||||
# Initialize translation pipeline
|
if use_gpu and GPU_UTILS_AVAILABLE:
|
||||||
translated_segments = []
|
device = get_optimal_device()
|
||||||
|
|
||||||
# Translate each segment
|
try:
|
||||||
|
trans_model, tokenizer = _load_translation_model(TRANSLATION_MODEL, str(device))
|
||||||
|
|
||||||
|
translated_segments = []
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
translated_text = translate_text(
|
translated_text = _translate_text_with_model(
|
||||||
segment["text"],
|
segment["text"], source_lang, target_lang, trans_model, tokenizer, device
|
||||||
source_lang=source_lang,
|
|
||||||
target_lang=target_lang,
|
|
||||||
use_gpu=use_gpu
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a new segment with translated text
|
|
||||||
translated_segment = segment.copy()
|
translated_segment = segment.copy()
|
||||||
translated_segment["text"] = translated_text
|
translated_segment["text"] = translated_text
|
||||||
translated_segment["original_text"] = segment["text"]
|
translated_segment["original_text"] = segment["text"]
|
||||||
translated_segment["source_lang"] = source_lang
|
translated_segment["source_lang"] = source_lang
|
||||||
translated_segment["target_lang"] = target_lang
|
translated_segment["target_lang"] = target_lang
|
||||||
|
|
||||||
translated_segments.append(translated_segment)
|
translated_segments.append(translated_segment)
|
||||||
|
|
||||||
return translated_segments
|
return translated_segments
|
||||||
@ -227,39 +214,33 @@ def transcribe_and_translate(audio_path, whisper_model="base", target_lang="en",
|
|||||||
Returns:
|
Returns:
|
||||||
tuple: (original_segments, translated_segments, original_transcript, translated_transcript)
|
tuple: (original_segments, translated_segments, original_transcript, translated_transcript)
|
||||||
"""
|
"""
|
||||||
|
from utils.transcription import _load_whisper_model
|
||||||
|
|
||||||
audio_path = Path(audio_path)
|
audio_path = Path(audio_path)
|
||||||
|
|
||||||
# Configure device
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
if use_gpu and GPU_UTILS_AVAILABLE:
|
if use_gpu and GPU_UTILS_AVAILABLE:
|
||||||
device = get_optimal_device()
|
device = get_optimal_device()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Step 1: Transcribe audio with Whisper
|
|
||||||
logger.info(f"Transcribing audio with Whisper model: {whisper_model}")
|
logger.info(f"Transcribing audio with Whisper model: {whisper_model}")
|
||||||
model = whisper.load_model(whisper_model, device=device if device.type != "mps" else "cpu")
|
model = _load_whisper_model(whisper_model, str(device))
|
||||||
|
|
||||||
# Use Whisper's built-in language detection if requested
|
|
||||||
if detect_source:
|
if detect_source:
|
||||||
# First, detect language with Whisper
|
|
||||||
audio = whisper.load_audio(str(audio_path))
|
audio = whisper.load_audio(str(audio_path))
|
||||||
audio = whisper.pad_or_trim(audio)
|
audio = whisper.pad_or_trim(audio)
|
||||||
mel = whisper.log_mel_spectrogram(audio).to(device if device.type != "mps" else "cpu")
|
mel = whisper.log_mel_spectrogram(audio).to(device if device.type != "mps" else "cpu")
|
||||||
_, probs = model.detect_language(mel)
|
_, probs = model.detect_language(mel)
|
||||||
source_lang = max(probs, key=probs.get)
|
source_lang = max(probs, key=probs.get)
|
||||||
logger.info(f"Whisper detected language: {get_language_name(source_lang)} ({source_lang})")
|
logger.info(f"Whisper detected language: {get_language_name(source_lang)} ({source_lang})")
|
||||||
|
|
||||||
# Transcribe with detected language
|
|
||||||
result = model.transcribe(str(audio_path), language=source_lang)
|
result = model.transcribe(str(audio_path), language=source_lang)
|
||||||
else:
|
else:
|
||||||
# Transcribe without language specification
|
|
||||||
result = model.transcribe(str(audio_path))
|
result = model.transcribe(str(audio_path))
|
||||||
source_lang = result.get("language", "en")
|
source_lang = result.get("language", "en")
|
||||||
|
|
||||||
original_segments = result["segments"]
|
original_segments = result["segments"]
|
||||||
original_transcript = result["text"]
|
original_transcript = result["text"]
|
||||||
|
|
||||||
# Step 2: Translate if needed
|
|
||||||
if source_lang != target_lang:
|
if source_lang != target_lang:
|
||||||
logger.info(f"Translating from {source_lang} to {target_lang}")
|
logger.info(f"Translating from {source_lang} to {target_lang}")
|
||||||
translated_segments = translate_segments(
|
translated_segments = translate_segments(
|
||||||
@ -268,8 +249,6 @@ def transcribe_and_translate(audio_path, whisper_model="base", target_lang="en",
|
|||||||
target_lang=target_lang,
|
target_lang=target_lang,
|
||||||
use_gpu=use_gpu
|
use_gpu=use_gpu
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create full translated transcript
|
|
||||||
translated_transcript = " ".join([segment["text"] for segment in translated_segments])
|
translated_transcript = " ".join([segment["text"] for segment in translated_segments])
|
||||||
else:
|
else:
|
||||||
logger.info(f"Source and target languages are the same ({source_lang}), skipping translation")
|
logger.info(f"Source and target languages are the same ({source_lang}), skipping translation")
|
||||||
|
|||||||
@ -1,8 +1,38 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import shutil
|
||||||
|
import logging
|
||||||
|
|
||||||
def validate_environment(obs_path: Path):
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_environment(obs_path: Path = None):
    """Collect environment problems and return them as human-readable messages.

    Args:
        obs_path: optional recordings directory whose existence is checked.
                  NOTE(review): annotation is effectively Optional[Path] —
                  the default is None and a None value is handled.

    Returns:
        list[str]: error messages; empty when all prerequisites are satisfied.
    """
    problems = []

    # Only validate the directory when a path was actually supplied.
    if obs_path and not obs_path.exists():
        problems.append(f"Directory not found: {obs_path}")

    # FFmpeg is required for audio extraction; fail early with guidance.
    if shutil.which("ffmpeg") is None:
        problems.append("FFmpeg is not installed or not in PATH. Install it from https://ffmpeg.org/download.html")

    return problems
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_capabilities():
    """Probe the host for capabilities relevant to transcription.

    Returns:
        dict: ``ffmpeg``, ``cuda`` and ``mps`` booleans, plus ``gpu_name`` and
        ``gpu_memory`` which stay None unless a CUDA device is detected
        (``gpu_memory`` is the device's total memory in bytes, from torch).
    """
    # Imported lazily so merely importing this module stays cheap.
    import torch

    has_cuda = torch.cuda.is_available()
    caps = {
        "ffmpeg": shutil.which("ffmpeg") is not None,
        "cuda": has_cuda,
        # mps backend only exists on newer torch builds, hence the hasattr guard.
        "mps": hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
        "gpu_name": None,
        "gpu_memory": None,
    }

    if has_cuda and torch.cuda.device_count() > 0:
        first_gpu = torch.cuda.get_device_properties(0)
        caps["gpu_name"] = first_gpu.name
        caps["gpu_memory"] = first_gpu.total_memory

    return caps
|
||||||
|
|||||||
Reference in New Issue
Block a user