From 70c5d3241349f5af9eebdbab58df8a83155d38b7 Mon Sep 17 00:00:00 2001 From: Your Name Date: Wed, 18 Feb 2026 10:26:09 -0500 Subject: [PATCH] feat: Add streaming Ollama support, model caching, and UI improvements - Add streaming summarization via Ollama API (stream_summarize_with_ollama) - Cache ML models with @st.cache_resource (diarization, NER, translation, Whisper) - Add temp file cleanup for extracted audio - Add system capabilities detection (FFmpeg, GPU info) - Add get_video_duration utility - Improve validation with FFmpeg check - Rewrite app.py with streaming support and UI enhancements - Clean up redundant comments and unused imports across all utils --- app.py | 1168 ++++++++++++++++++++--------------- utils/audio_processing.py | 44 +- utils/diarization.py | 38 +- utils/gpu_utils.py | 7 +- utils/keyword_extraction.py | 18 +- utils/ollama_integration.py | 146 +++-- utils/summarization.py | 74 ++- utils/transcription.py | 25 +- utils/translation.py | 149 ++--- utils/validation.py | 36 +- 10 files changed, 998 insertions(+), 707 deletions(-) diff --git a/app.py b/app.py index b53e39f..4c28d7f 100644 --- a/app.py +++ b/app.py @@ -1,544 +1,750 @@ import streamlit as st -from utils.audio_processing import extract_audio +from utils.audio_processing import extract_audio, cleanup_temp_audio, get_video_duration from utils.transcription import transcribe_audio from utils.summarization import summarize_text -from utils.validation import validate_environment +from utils.validation import validate_environment, get_system_capabilities from utils.export import export_transcript from pathlib import Path import os import logging import humanize -from datetime import timedelta +import time +import tempfile -# Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Try to import Ollama integration, but don't fail if it's not available try: - from utils.ollama_integration import check_ollama_available, list_available_models, chunk_and_summarize + from utils.ollama_integration import ( + check_ollama_available, list_available_models, + chunk_and_summarize, stream_chunk_and_summarize + ) OLLAMA_AVAILABLE = check_ollama_available() except ImportError: OLLAMA_AVAILABLE = False -# Try to import GPU utilities, but don't fail if not available try: - from utils.gpu_utils import get_gpu_info, configure_gpu + from utils.gpu_utils import get_gpu_info, configure_gpu, optimize_for_inference GPU_UTILS_AVAILABLE = True + optimize_for_inference() except ImportError: GPU_UTILS_AVAILABLE = False -# Try to import caching utilities, but don't fail if not available try: from utils.cache import get_cache_size, clear_cache CACHE_AVAILABLE = True except ImportError: CACHE_AVAILABLE = False -# Try to import diarization utilities, but don't fail if not available try: from utils.diarization import transcribe_with_diarization DIARIZATION_AVAILABLE = True except ImportError: DIARIZATION_AVAILABLE = False -# Try to import translation utilities, but don't fail if not available try: from utils.translation import transcribe_and_translate, get_language_name TRANSLATION_AVAILABLE = True except ImportError: TRANSLATION_AVAILABLE = False -# Try to import keyword extraction utilities, but don't fail if not available try: - from utils.keyword_extraction import extract_keywords_from_transcript, generate_keyword_index, generate_interactive_transcript + from utils.keyword_extraction import ( + extract_keywords_from_transcript, generate_keyword_index, + generate_interactive_transcript + ) 
KEYWORD_EXTRACTION_AVAILABLE = True except ImportError: KEYWORD_EXTRACTION_AVAILABLE = False + +def init_session_state(): + """Initialize session state with defaults for persistence across reruns.""" + defaults = { + "transcription_model": "base", + "summarization_method": "Hugging Face (Online)", + "use_diarization": False, + "use_translation": False, + "use_keywords": False, + "use_gpu": GPU_UTILS_AVAILABLE, + "use_cache": CACHE_AVAILABLE, + "memory_fraction": 0.8, + "export_formats": ["TXT"], + "compress_exports": False, + "base_folder": str(Path.home()), + "recursive_search": False, + "results": None, + "processing": False, + } + for key, val in defaults.items(): + if key not in st.session_state: + st.session_state[key] = val + + +def format_duration(seconds): + """Format seconds into MM:SS or HH:MM:SS.""" + if seconds is None: + return "Unknown" + m, s = divmod(int(seconds), 60) + h, m = divmod(m, 60) + if h > 0: + return f"{h}:{m:02d}:{s:02d}" + return f"{m}:{s:02d}" + + +def save_uploaded_file(uploaded_file): + """Save an uploaded file to a temp directory and return its path.""" + temp_dir = tempfile.mkdtemp(prefix="vt_upload_") + file_path = Path(temp_dir) / uploaded_file.name + with open(file_path, "wb") as f: + f.write(uploaded_file.getbuffer()) + return file_path + + +def render_sidebar(): + """Render the sidebar with collapsible settings groups.""" + st.sidebar.markdown("### Settings") + + # -- Model Settings (expanded by default) -- + with st.sidebar.expander("Model Settings", expanded=True): + st.session_state.transcription_model = st.selectbox( + "Whisper Model", + ["tiny", "base", "small", "medium", "large"], + index=["tiny", "base", "small", "medium", "large"].index( + st.session_state.transcription_model + ), + help="Larger models are more accurate but slower.", + key="sb_whisper_model", + ) + + summarization_options = ( + ["Hugging Face (Online)", "Ollama (Local)"] + if OLLAMA_AVAILABLE + else ["Hugging Face (Online)"] + ) + st.session_state.summarization_method = st.selectbox( + "Summarization", + summarization_options, + index=0, + help="Ollama runs locally but requires installation.", + key="sb_summarization", + ) + + ollama_model = None + if OLLAMA_AVAILABLE and st.session_state.summarization_method == "Ollama (Local)": + available_models = list_available_models() + if available_models: + ollama_model = st.selectbox( + "Ollama Model", + available_models, + index=0, + key="sb_ollama_model", + ) + else: + st.warning("No Ollama models found. Run `ollama pull `.") + + # -- Advanced Features (collapsed) -- + with st.sidebar.expander("Advanced Features"): + st.session_state.use_diarization = st.checkbox( + "Speaker Diarization", + value=st.session_state.use_diarization, + disabled=not DIARIZATION_AVAILABLE, + help="Identify different speakers in the recording.", + key="sb_diarization", + ) + + hf_token = None + num_speakers = 2 + if st.session_state.use_diarization and DIARIZATION_AVAILABLE: + hf_token = st.text_input( + "HuggingFace Token", + type="password", + help="Required for diarization. 
Get token at huggingface.co/settings/tokens", + key="sb_hf_token", + ) + num_speakers = st.number_input( + "Number of Speakers", min_value=1, max_value=10, value=2, + key="sb_num_speakers", + ) + + st.session_state.use_translation = st.checkbox( + "Translation", + value=st.session_state.use_translation, + disabled=not TRANSLATION_AVAILABLE, + help="Translate the transcript to another language.", + key="sb_translation", + ) + + target_lang = None + if st.session_state.use_translation and TRANSLATION_AVAILABLE: + target_lang = st.selectbox( + "Target Language", + ["en", "es", "fr", "de", "it", "pt", "nl", "ru", "zh", "ja", "ko", "ar"], + format_func=lambda x: f"{get_language_name(x)} ({x})", + key="sb_target_lang", + ) + + st.session_state.use_keywords = st.checkbox( + "Keyword Extraction", + value=st.session_state.use_keywords, + disabled=not KEYWORD_EXTRACTION_AVAILABLE, + help="Extract keywords and link them to timestamps.", + key="sb_keywords", + ) + + max_keywords = 15 + if st.session_state.use_keywords and KEYWORD_EXTRACTION_AVAILABLE: + max_keywords = st.slider( + "Max Keywords", min_value=5, max_value=30, value=15, + key="sb_max_keywords", + ) + + # -- Performance (collapsed) -- + with st.sidebar.expander("Performance"): + st.session_state.use_gpu = st.checkbox( + "GPU Acceleration", + value=st.session_state.use_gpu, + disabled=not GPU_UTILS_AVAILABLE, + help="Use GPU for faster processing if available.", + key="sb_gpu", + ) + + if GPU_UTILS_AVAILABLE and st.session_state.use_gpu: + gpu_info = get_gpu_info() + if gpu_info["cuda_available"]: + gpu_devices = [ + f"{d['name']} ({humanize.naturalsize(d['total_memory'])})" + for d in gpu_info["cuda_devices"] + ] + st.info(f"GPU: {', '.join(gpu_devices)}") + elif gpu_info["mps_available"]: + st.info("Apple Silicon GPU (MPS)") + else: + st.warning("No GPU detected. 
Using CPU.") + + st.session_state.memory_fraction = st.slider( + "GPU Memory %", + min_value=0.1, max_value=1.0, + value=st.session_state.memory_fraction, step=0.1, + disabled=not (GPU_UTILS_AVAILABLE and st.session_state.use_gpu), + key="sb_memory", + ) + + st.session_state.use_cache = st.checkbox( + "Cache Results", + value=st.session_state.use_cache, + disabled=not CACHE_AVAILABLE, + help="Cache transcriptions to avoid reprocessing.", + key="sb_cache", + ) + + if CACHE_AVAILABLE and st.session_state.use_cache: + cache_size, cache_files = get_cache_size() + if cache_size > 0: + st.caption(f"Cache: {humanize.naturalsize(cache_size)} ({cache_files} files)") + if st.button("Clear Cache", key="sb_clear_cache"): + cleared = clear_cache() + st.success(f"Cleared {cleared} files") + + # -- Export (collapsed) -- + with st.sidebar.expander("Export Options"): + st.session_state.export_formats = st.multiselect( + "Formats", + ["TXT", "SRT", "VTT", "ASS"], + default=st.session_state.export_formats, + key="sb_export_formats", + ) + + st.session_state.compress_exports = st.checkbox( + "Compress Exports", + value=st.session_state.compress_exports, + key="sb_compress", + ) + + compression_type = None + if st.session_state.compress_exports: + compression_type = st.radio( + "Compression", ["gzip", "zip"], index=0, + key="sb_compression_type", + ) + + ass_style = None + if "ASS" in st.session_state.export_formats: + if st.checkbox("Customize ASS Style", value=False, key="sb_ass_custom"): + ass_style = { + "fontname": st.selectbox( + "Font", + ["Arial", "Helvetica", "Times New Roman", "Courier New"], + key="sb_ass_font", + ), + "fontsize": str(st.slider("Font Size", 12, 72, 48, key="sb_ass_size")), + "alignment": st.selectbox( + "Alignment", + ["2 (Bottom Center)", "1 (Bottom Left)", "3 (Bottom Right)", "8 (Top Center)"], + key="sb_ass_align", + ).split()[0], + "bold": "-1" if st.checkbox("Bold", value=True, key="sb_ass_bold") else "0", + "italic": "-1" if st.checkbox("Italic", value=False, key="sb_ass_italic") else "0", + } + + # -- System Info (collapsed) -- + with st.sidebar.expander("System Info"): + caps = get_system_capabilities() + st.markdown(f"- **FFmpeg:** {'Installed' if caps['ffmpeg'] else 'Not found'}") + st.markdown(f"- **CUDA:** {'Available' if caps['cuda'] else 'Not available'}") + st.markdown(f"- **MPS:** {'Available' if caps['mps'] else 'Not available'}") + if caps["gpu_name"]: + st.markdown(f"- **GPU:** {caps['gpu_name']} ({humanize.naturalsize(caps['gpu_memory'])})") + st.markdown(f"- **Ollama:** {'Connected' if OLLAMA_AVAILABLE else 'Not available'}") + st.markdown(f"- **Diarization:** {'Ready' if DIARIZATION_AVAILABLE else 'Not available'}") + + return { + "ollama_model": ollama_model, + "hf_token": hf_token, + "num_speakers": num_speakers, + "target_lang": target_lang, + "max_keywords": max_keywords, + "compression_type": compression_type, + "ass_style": ass_style, + } + + +def render_file_input(): + """Render the file input section with upload + folder browse tabs.""" + upload_tab, browse_tab = st.tabs(["Upload Files", "Browse Folder"]) + + selected_file = None + + with upload_tab: + uploaded_files = st.file_uploader( + "Drag and drop your recordings here", + type=["mp4", "avi", "mov", "mkv", "m4a"], + accept_multiple_files=True, + key="file_uploader", + ) + if uploaded_files: + if len(uploaded_files) == 1: + selected_file = ("upload", uploaded_files[0]) + else: + file_names = [f.name for f in uploaded_files] + chosen = st.selectbox("Choose a recording", file_names, 
key="upload_select") + idx = file_names.index(chosen) + selected_file = ("upload", uploaded_files[idx]) + + with browse_tab: + col1, col2 = st.columns([4, 1]) + with col1: + st.session_state.base_folder = st.text_input( + "Folder path", + value=st.session_state.base_folder, + key="folder_input", + ) + with col2: + st.session_state.recursive_search = st.checkbox( + "Recursive", value=st.session_state.recursive_search, + key="recursive_check", + ) + + base_path = Path(st.session_state.base_folder) + env_errors = validate_environment(base_path) + if env_errors: + for error in env_errors: + st.warning(error) + else: + extensions = ["*.mp4", "*.avi", "*.mov", "*.mkv", "*.m4a"] + recordings = [] + glob_fn = base_path.rglob if st.session_state.recursive_search else base_path.glob + for ext in extensions: + recordings.extend(glob_fn(ext)) + + if recordings: + chosen = st.selectbox( + "Choose a recording", + recordings, + format_func=lambda p: str(p.relative_to(base_path)) if str(p).startswith(str(base_path)) else str(p), + key="folder_select", + ) + selected_file = ("path", chosen) + else: + st.info("No recordings found. Supported formats: MP4, AVI, MOV, MKV, M4A") + + return selected_file + + +def render_file_preview(selected_file): + """Show file metadata before processing.""" + if selected_file is None: + return + + source_type, file_ref = selected_file + + if source_type == "upload": + file_size = file_ref.size + file_name = file_ref.name + duration = None + else: + file_size = file_ref.stat().st_size + file_name = file_ref.name + duration = get_video_duration(file_ref) + + cols = st.columns(4) + cols[0].metric("File", file_name) + cols[1].metric("Size", humanize.naturalsize(file_size)) + cols[2].metric("Format", Path(file_name).suffix.upper().lstrip(".")) + cols[3].metric("Duration", format_duration(duration)) + + +def resolve_file_path(selected_file): + """Convert the selected file reference to an actual file path.""" + source_type, file_ref = selected_file + if source_type == "upload": + return save_uploaded_file(file_ref) + return file_ref + + +def process_recording(file_path, sidebar_opts): + """Run the full processing pipeline with granular status updates.""" + results = {} + start_time = time.time() + + with st.status("Processing recording...", expanded=True) as status: + + # Step 1: Transcription + st.write(f"Transcribing with Whisper ({st.session_state.transcription_model} model)...") + t0 = time.time() + + if st.session_state.use_diarization and DIARIZATION_AVAILABLE and sidebar_opts["hf_token"]: + num_spk = int(sidebar_opts["num_speakers"]) if sidebar_opts["num_speakers"] > 0 else None + segments, transcript = transcribe_with_diarization( + file_path, + whisper_model=st.session_state.transcription_model, + num_speakers=num_spk, + use_gpu=st.session_state.use_gpu, + hf_token=sidebar_opts["hf_token"], + ) + results["diarized"] = True + elif st.session_state.use_translation and TRANSLATION_AVAILABLE: + st.write("Transcribing and translating...") + orig_seg, trans_seg, orig_text, trans_text = transcribe_and_translate( + file_path, + whisper_model=st.session_state.transcription_model, + target_lang=sidebar_opts["target_lang"], + use_gpu=st.session_state.use_gpu, + ) + segments = trans_seg + transcript = trans_text + results["original_text"] = orig_text + results["original_segments"] = orig_seg + results["translated"] = True + else: + segments, transcript = transcribe_audio( + file_path, + model=st.session_state.transcription_model, + use_cache=st.session_state.use_cache, + 
use_gpu=st.session_state.use_gpu, + memory_fraction=st.session_state.memory_fraction, + ) + + transcription_time = time.time() - t0 + st.write(f"Transcription complete ({transcription_time:.1f}s)") + + if not transcript: + status.update(label="Processing failed", state="error") + return None + + results["segments"] = segments + results["transcript"] = transcript + + # Step 2: Keyword extraction + if st.session_state.use_keywords and KEYWORD_EXTRACTION_AVAILABLE: + st.write("Extracting keywords...") + t0 = time.time() + kw_ts, ent_ts = extract_keywords_from_transcript( + transcript, segments, + max_keywords=sidebar_opts["max_keywords"], + use_gpu=st.session_state.use_gpu, + ) + results["keyword_timestamps"] = kw_ts + results["entity_timestamps"] = ent_ts + results["keyword_index"] = generate_keyword_index(kw_ts, ent_ts) + results["interactive_transcript"] = generate_interactive_transcript(segments, kw_ts, ent_ts) + st.write(f"Keywords extracted ({time.time() - t0:.1f}s)") + + # Step 3: Summarization + st.write("Generating summary...") + t0 = time.time() + + use_ollama = ( + OLLAMA_AVAILABLE + and st.session_state.summarization_method == "Ollama (Local)" + and sidebar_opts["ollama_model"] + ) + + if use_ollama: + summary = chunk_and_summarize(transcript, model=sidebar_opts["ollama_model"]) + if summary: + # Offer streaming re-generation only when Ollama actually produced the summary + results["ollama_streaming"] = True + else: + st.write("Ollama failed, falling back to Hugging Face...") + summary = summarize_text( + transcript, + use_gpu=st.session_state.use_gpu, + memory_fraction=st.session_state.memory_fraction, + ) + else: + summary = summarize_text( + transcript, + use_gpu=st.session_state.use_gpu, + memory_fraction=st.session_state.memory_fraction, + ) + + results["summary"] = summary + st.write(f"Summary generated ({time.time() - t0:.1f}s)") + + # Cleanup temp audio files + cleanup_temp_audio() + + total_time = time.time() - start_time + results["processing_time"] = total_time + results["word_count"] = len(transcript.split()) + + status.update(label=f"Complete in {total_time:.1f}s", state="complete") + + return results + + +def render_results(results, sidebar_opts): + """Display processing results with metrics, tabs, and export options.""" + if results is None: + st.error("Processing failed. 
Check logs for details.") + return + + # Metric cards + st.markdown("---") + metric_cols = st.columns(4) + metric_cols[0].metric("Words", f"{results['word_count']:,}") + metric_cols[1].metric("Segments", str(len(results.get("segments", [])))) + metric_cols[2].metric("Processing Time", f"{results['processing_time']:.1f}s") + + if results.get("diarized"): + speakers = set(seg.get("speaker", "UNKNOWN") for seg in results["segments"]) + metric_cols[3].metric("Speakers", str(len(speakers))) + elif results.get("translated"): + metric_cols[3].metric("Translated", "Yes") + else: + metric_cols[3].metric("Model", st.session_state.transcription_model.capitalize()) + + # Results tabs + tab_names = ["Summary", "Transcript", "Advanced"] + tab1, tab2, tab3 = st.tabs(tab_names) + + with tab1: + st.subheader("Summary") + if results.get("ollama_streaming") and OLLAMA_AVAILABLE and sidebar_opts["ollama_model"]: + st.write(results["summary"]) + with st.expander("Re-generate with streaming"): + if st.button("Stream Summary", key="stream_btn"): + st.write_stream( + stream_chunk_and_summarize( + results["transcript"], + model=sidebar_opts["ollama_model"], + ) + ) + else: + st.write(results["summary"]) + + if results.get("original_text"): + with st.expander("Original Language Summary"): + original_summary = summarize_text( + results["original_text"], + use_gpu=st.session_state.use_gpu, + memory_fraction=st.session_state.memory_fraction, + ) + st.write(original_summary) + + with tab2: + st.subheader("Full Transcript") + + if results.get("interactive_transcript"): + st.markdown(results["interactive_transcript"], unsafe_allow_html=True) + else: + st.markdown( + f"
<div>{_format_segments_html(results['segments'])}</div>
", + unsafe_allow_html=True, + ) + + st.download_button( + "Copy Transcript (Download TXT)", + data=results["transcript"], + file_name="transcript.txt", + mime="text/plain", + key="copy_transcript", + ) + + if results.get("original_text"): + with st.expander("Original Language Transcript"): + st.text(results["original_text"]) + + with tab3: + if results.get("keyword_index"): + st.subheader("Keyword Index") + st.markdown(results["keyword_index"]) + + if results.get("diarized"): + st.subheader("Speaker Information") + speakers = set(seg.get("speaker", "UNKNOWN") for seg in results["segments"]) + st.write(f"Detected {len(speakers)} speakers: {', '.join(speakers)}") + + speaker_words = {} + for seg in results["segments"]: + spk = seg.get("speaker", "UNKNOWN") + speaker_words[spk] = speaker_words.get(spk, 0) + len(seg["text"].split()) + + for spk, words in speaker_words.items(): + st.write(f"- **{spk}**: {words} words") + + # Export section + export_formats = st.session_state.export_formats + if export_formats: + st.markdown("---") + st.subheader("Export") + export_cols = st.columns(len(export_formats)) + + output_base = Path(results.get("file_name", "transcript")).stem + + for i, fmt in enumerate(export_formats): + with export_cols[i]: + if fmt == "TXT": + st.download_button( + label=f"Download {fmt}", + data=results["transcript"], + file_name=f"{output_base}_transcript.txt", + mime="text/plain", + key=f"dl_{fmt}", + ) + elif fmt in ["SRT", "VTT", "ASS"]: + output_path = export_transcript( + results["transcript"], + output_base, + fmt.lower(), + segments=results["segments"], + compress=st.session_state.compress_exports, + compression_type=sidebar_opts["compression_type"], + style=sidebar_opts["ass_style"] if fmt == "ASS" else None, + ) + + with open(output_path, "rb") as f: + content = f.read() + + file_ext = f".{fmt.lower()}" + if st.session_state.compress_exports: + file_ext += ".gz" if sidebar_opts["compression_type"] == "gzip" else ".zip" + + st.download_button( + label=f"Download {fmt}", + data=content, + file_name=f"{output_base}{file_ext}", + mime="application/octet-stream", + key=f"dl_{fmt}", + ) + + try: + os.remove(output_path) + except OSError: + pass + + +def _format_segments_html(segments): + """Format transcript segments as HTML with timestamps.""" + if not segments: + return "

<div>No segments available.</div>
" + + lines = [] + for seg in segments: + start = seg.get("start", 0) + ts = f"{int(start // 60):02d}:{int(start % 60):02d}" + speaker = seg.get("speaker", "") + speaker_html = f"[{speaker}] " if speaker else "" + text = seg.get("text", "").strip() + lines.append( + f"

" + f"{ts}" + f"{speaker_html}{text}

" + ) + return "\n".join(lines) + + def main(): - # Set page configuration st.set_page_config( - page_title="OBS Recording Transcriber", - page_icon="🎥", + page_title="Video Transcriber", + page_icon="🎬", layout="wide", - initial_sidebar_state="expanded" + initial_sidebar_state="expanded", ) - - # Custom CSS for better UI + st.markdown(""" """, unsafe_allow_html=True) - - st.title("🎥 OBS Recording Transcriber") - st.caption("Process your OBS recordings with AI transcription and summarization") - # Sidebar configuration - st.sidebar.header("Settings") - - # Allow the user to select a base folder - base_folder = st.sidebar.text_input( - "Enter the base folder path:", - value=str(Path.home()) - ) - - base_path = Path(base_folder) + init_session_state() - # Model selection - st.sidebar.subheader("Model Settings") - - # Transcription model selection - transcription_model = st.sidebar.selectbox( - "Transcription Model", - ["tiny", "base", "small", "medium", "large"], - index=1, - help="Select the Whisper model size. Larger models are more accurate but slower." - ) - - # Summarization model selection - summarization_options = ["Hugging Face (Online)", "Ollama (Local)"] if OLLAMA_AVAILABLE else ["Hugging Face (Online)"] - summarization_method = st.sidebar.selectbox( - "Summarization Method", - summarization_options, - index=0, - help="Select the summarization method. Ollama runs locally but requires installation." - ) - - # If Ollama is selected, show model selection - ollama_model = None - if OLLAMA_AVAILABLE and summarization_method == "Ollama (Local)": - available_models = list_available_models() - if available_models: - ollama_model = st.sidebar.selectbox( - "Ollama Model", - available_models, - index=0 if "llama3" in available_models else 0, - help="Select the Ollama model to use for summarization." - ) - else: - st.sidebar.warning("No Ollama models found. Please install models using 'ollama pull model_name'.") - - # Advanced features - st.sidebar.subheader("Advanced Features") - - # Speaker diarization - use_diarization = st.sidebar.checkbox( - "Speaker Diarization", - value=False, - disabled=not DIARIZATION_AVAILABLE, - help="Identify different speakers in the recording." - ) - - # Show HF token input if diarization is enabled - hf_token = None - if use_diarization and DIARIZATION_AVAILABLE: - hf_token = st.sidebar.text_input( - "HuggingFace Token", - type="password", - help="Required for speaker diarization. Get your token at huggingface.co/settings/tokens" - ) - - num_speakers = st.sidebar.number_input( - "Number of Speakers", - min_value=1, - max_value=10, - value=2, - help="Specify the number of speakers if known, or leave at default for auto-detection." - ) - - # Translation - use_translation = st.sidebar.checkbox( - "Translation", - value=False, - disabled=not TRANSLATION_AVAILABLE, - help="Translate the transcript to another language." - ) - - # Target language selection if translation is enabled - target_lang = None - if use_translation and TRANSLATION_AVAILABLE: - target_lang = st.sidebar.selectbox( - "Target Language", - ["en", "es", "fr", "de", "it", "pt", "nl", "ru", "zh", "ja", "ko", "ar"], - format_func=lambda x: f"{get_language_name(x)} ({x})", - help="Select the language to translate to." - ) - - # Keyword extraction - use_keywords = st.sidebar.checkbox( - "Keyword Extraction", - value=False, - disabled=not KEYWORD_EXTRACTION_AVAILABLE, - help="Extract keywords and link them to timestamps." 
- ) - - if use_keywords and KEYWORD_EXTRACTION_AVAILABLE: - max_keywords = st.sidebar.slider( - "Max Keywords", - min_value=5, - max_value=30, - value=15, - help="Maximum number of keywords to extract." - ) - - # Performance settings - st.sidebar.subheader("Performance Settings") - - # GPU acceleration - use_gpu = st.sidebar.checkbox( - "Use GPU Acceleration", - value=True if GPU_UTILS_AVAILABLE else False, - disabled=not GPU_UTILS_AVAILABLE, - help="Use GPU for faster processing if available." - ) - - # Show GPU info if available - if GPU_UTILS_AVAILABLE and use_gpu: - gpu_info = get_gpu_info() - if gpu_info["cuda_available"]: - gpu_devices = [f"{d['name']} ({humanize.naturalsize(d['total_memory'])})" for d in gpu_info["cuda_devices"]] - st.sidebar.info(f"GPU(s) available: {', '.join(gpu_devices)}") - elif gpu_info["mps_available"]: - st.sidebar.info("Apple Silicon GPU (MPS) available") - else: - st.sidebar.warning("No GPU detected. Using CPU.") - - # Memory usage - memory_fraction = st.sidebar.slider( - "GPU Memory Usage", - min_value=0.1, - max_value=1.0, - value=0.8, - step=0.1, - disabled=not (GPU_UTILS_AVAILABLE and use_gpu), - help="Fraction of GPU memory to use. Lower if you encounter out-of-memory errors." - ) - - # Caching options - use_cache = st.sidebar.checkbox( - "Use Caching", - value=True if CACHE_AVAILABLE else False, - disabled=not CACHE_AVAILABLE, - help="Cache transcription results to avoid reprocessing the same files." - ) - - # Cache management - if CACHE_AVAILABLE and use_cache: - cache_size, cache_files = get_cache_size() - if cache_size > 0: - st.sidebar.info(f"Cache: {humanize.naturalsize(cache_size)} ({cache_files} files)") - if st.sidebar.button("Clear Cache"): - cleared = clear_cache() - st.sidebar.success(f"Cleared {cleared} cache files") - - # Export options - st.sidebar.subheader("Export Options") - export_format = st.sidebar.multiselect( - "Export Formats", - ["TXT", "SRT", "VTT", "ASS"], - default=["TXT"], - help="Select the formats to export the transcript." - ) - - # Compression options - compress_exports = st.sidebar.checkbox( - "Compress Exports", - value=False, - help="Compress exported files to save space." - ) - - if compress_exports: - compression_type = st.sidebar.radio( - "Compression Format", - ["gzip", "zip"], - index=0, - help="Select the compression format for exported files." 
- ) - else: - compression_type = None - - # ASS subtitle styling - if "ASS" in export_format: - st.sidebar.subheader("ASS Subtitle Styling") - show_style_options = st.sidebar.checkbox("Customize ASS Style", value=False) - - if show_style_options: - ass_style = {} - ass_style["fontname"] = st.sidebar.selectbox( - "Font", - ["Arial", "Helvetica", "Times New Roman", "Courier New", "Comic Sans MS"], - index=0 - ) - ass_style["fontsize"] = st.sidebar.slider("Font Size", 12, 72, 48) - ass_style["alignment"] = st.sidebar.selectbox( - "Alignment", - ["2 (Bottom Center)", "1 (Bottom Left)", "3 (Bottom Right)", "8 (Top Center)"], - index=0 - ).split()[0] # Extract just the number - ass_style["bold"] = "-1" if st.sidebar.checkbox("Bold", value=True) else "0" - ass_style["italic"] = "-1" if st.sidebar.checkbox("Italic", value=False) else "0" - else: - ass_style = None + st.title("Video Transcriber") + st.caption("AI-powered transcription, summarization, and analysis for video and audio recordings") - # Validate environment - env_errors = validate_environment(base_path) - if env_errors: - st.error("## Environment Issues") - for error in env_errors: - st.markdown(f"- {error}") - return + sidebar_opts = render_sidebar() - # File selection - support multiple video and audio formats - supported_extensions = ["*.mp4", "*.avi", "*.mov", "*.mkv", "*.m4a"] - recordings = [] - for extension in supported_extensions: - recordings.extend(base_path.glob(extension)) - - if not recordings: - st.warning(f"📂 No recordings found in the folder: {base_folder}!") - st.info("💡 Supported formats: MP4, AVI, MOV, MKV, M4A") - return + # FFmpeg check + ffmpeg_errors = validate_environment() + if ffmpeg_errors: + for err in ffmpeg_errors: + st.warning(err) - selected_file = st.selectbox("Choose a recording", recordings) + selected_file = render_file_input() + + if selected_file: + render_file_preview(selected_file) + + st.markdown("") + if st.button("Start Processing", type="primary", use_container_width=True): + file_path = resolve_file_path(selected_file) + + results = process_recording(file_path, sidebar_opts) + + if results: + source_type, file_ref = selected_file + results["file_name"] = file_ref.name if source_type == "upload" else file_ref.name + st.session_state.results = results + st.toast("Processing complete!", icon="✅") + + # Clean up uploaded temp files + if selected_file[0] == "upload": + try: + os.remove(file_path) + os.rmdir(file_path.parent) + except OSError: + pass + + # Show persisted results from session state + if st.session_state.results: + render_results(st.session_state.results, sidebar_opts) - # Process button with spinner - if st.button("🚀 Start Processing"): - # Create a progress bar - progress_bar = st.progress(0) - status_text = st.empty() - - try: - # Update progress - status_text.text("Extracting audio...") - progress_bar.progress(10) - - # Process based on selected features - if use_diarization and DIARIZATION_AVAILABLE and hf_token: - # Transcribe with speaker diarization - status_text.text("Transcribing with speaker diarization...") - num_speakers_arg = int(num_speakers) if num_speakers > 0 else None - diarized_segments, diarized_transcript = transcribe_with_diarization( - selected_file, - whisper_model=transcription_model, - num_speakers=num_speakers_arg, - use_gpu=use_gpu, - hf_token=hf_token - ) - segments = diarized_segments - transcript = diarized_transcript - elif use_translation and TRANSLATION_AVAILABLE: - # Transcribe and translate - status_text.text("Transcribing and translating...") - 
original_segments, translated_segments, original_transcript, translated_transcript = transcribe_and_translate( - selected_file, - whisper_model=transcription_model, - target_lang=target_lang, - use_gpu=use_gpu - ) - segments = translated_segments - transcript = translated_transcript - # Store original for display - original_text = original_transcript - else: - # Standard transcription - status_text.text("Transcribing audio...") - segments, transcript = transcribe_audio( - selected_file, - model=transcription_model, - use_cache=use_cache, - use_gpu=use_gpu, - memory_fraction=memory_fraction - ) - - progress_bar.progress(50) - - if transcript: - # Extract keywords if requested - keyword_timestamps = None - entity_timestamps = None - if use_keywords and KEYWORD_EXTRACTION_AVAILABLE: - status_text.text("Extracting keywords...") - keyword_timestamps, entity_timestamps = extract_keywords_from_transcript( - transcript, - segments, - max_keywords=max_keywords, - use_gpu=use_gpu - ) - - # Generate keyword index - keyword_index = generate_keyword_index(keyword_timestamps, entity_timestamps) - - # Generate interactive transcript - interactive_transcript = generate_interactive_transcript( - segments, - keyword_timestamps, - entity_timestamps - ) - - # Generate summary based on selected method - status_text.text("Generating summary...") - if OLLAMA_AVAILABLE and summarization_method == "Ollama (Local)" and ollama_model: - summary = chunk_and_summarize(transcript, model=ollama_model) - if not summary: - st.warning("Ollama summarization failed. Falling back to Hugging Face.") - summary = summarize_text( - transcript, - use_gpu=use_gpu, - memory_fraction=memory_fraction - ) - else: - summary = summarize_text( - transcript, - use_gpu=use_gpu, - memory_fraction=memory_fraction - ) - - progress_bar.progress(80) - status_text.text("Preparing results...") - - # Display results in tabs - tab1, tab2, tab3 = st.tabs(["Summary", "Transcript", "Advanced"]) - - with tab1: - st.subheader("🖍 Summary") - st.write(summary) - - # If translation was used, show original language - if use_translation and TRANSLATION_AVAILABLE and 'original_text' in locals(): - with st.expander("Original Language Summary"): - original_summary = summarize_text( - original_text, - use_gpu=use_gpu, - memory_fraction=memory_fraction - ) - st.write(original_summary) - - with tab2: - st.subheader("📜 Full Transcript") - - # Show interactive transcript if keywords were extracted - if use_keywords and KEYWORD_EXTRACTION_AVAILABLE and 'interactive_transcript' in locals(): - st.markdown(interactive_transcript, unsafe_allow_html=True) - else: - st.text(transcript) - - # If translation was used, show original language - if use_translation and TRANSLATION_AVAILABLE and 'original_text' in locals(): - with st.expander("Original Language Transcript"): - st.text(original_text) - - with tab3: - # Show keyword index if available - if use_keywords and KEYWORD_EXTRACTION_AVAILABLE and 'keyword_index' in locals(): - st.subheader("🔑 Keyword Index") - st.markdown(keyword_index) - - # Show speaker information if available - if use_diarization and DIARIZATION_AVAILABLE: - st.subheader("🎙️ Speaker Information") - speakers = set(segment.get('speaker', 'UNKNOWN') for segment in segments) - st.write(f"Detected {len(speakers)} speakers: {', '.join(speakers)}") - - # Count words per speaker - speaker_words = {} - for segment in segments: - speaker = segment.get('speaker', 'UNKNOWN') - words = len(segment['text'].split()) - if speaker in speaker_words: - 
speaker_words[speaker] += words - else: - speaker_words[speaker] = words - - # Display speaker statistics - st.write("### Speaker Statistics") - for speaker, words in speaker_words.items(): - st.write(f"- **{speaker}**: {words} words") - - # Export options - st.subheader("💾 Export Options") - export_cols = st.columns(len(export_format)) - - output_base = Path(selected_file).stem - - for i, format_type in enumerate(export_format): - with export_cols[i]: - if format_type == "TXT": - st.download_button( - label=f"Download {format_type}", - data=transcript, - file_name=f"{output_base}_transcript.txt", - mime="text/plain" - ) - elif format_type in ["SRT", "VTT", "ASS"]: - # Export to subtitle format - output_path = export_transcript( - transcript, - output_base, - format_type.lower(), - segments=segments, - compress=compress_exports, - compression_type=compression_type, - style=ass_style if format_type == "ASS" and ass_style else None - ) - - # Read the exported file for download - with open(output_path, 'rb') as f: - subtitle_content = f.read() - - # Determine file extension - file_ext = f".{format_type.lower()}" - if compress_exports: - file_ext += ".gz" if compression_type == "gzip" else ".zip" - - st.download_button( - label=f"Download {format_type}", - data=subtitle_content, - file_name=f"{output_base}{file_ext}", - mime="application/octet-stream" - ) - - # Clean up the temporary file - os.remove(output_path) - - # Complete progress - progress_bar.progress(100) - status_text.text("Processing complete!") - else: - st.error("❌ Failed to process recording") - except Exception as e: - st.error(f"An error occurred: {e}") - st.write(e) # This will show the traceback in the Streamlit app if __name__ == "__main__": main() diff --git a/utils/audio_processing.py b/utils/audio_processing.py index 307ab9f..e9530f3 100644 --- a/utils/audio_processing.py +++ b/utils/audio_processing.py @@ -1,19 +1,55 @@ from pathlib import Path +import tempfile +import os +import logging -# moviepy 2.x removed moviepy.editor; import directly from moviepy try: from moviepy import AudioFileClip except ImportError: - # Fallback for moviepy 1.x from moviepy.editor import AudioFileClip +logger = logging.getLogger(__name__) + +_temp_audio_files = [] + + def extract_audio(video_path: Path): - """Extract audio from a video file.""" + """Extract audio from a video file into a temp directory for automatic cleanup.""" try: audio = AudioFileClip(str(video_path)) - audio_path = video_path.parent / f"{video_path.stem}_audio.wav" + temp_dir = tempfile.mkdtemp(prefix="videotranscriber_") + audio_path = Path(temp_dir) / f"{video_path.stem}_audio.wav" audio.write_audiofile(str(audio_path), verbose=False, logger=None) audio.close() + _temp_audio_files.append(str(audio_path)) return audio_path except Exception as e: raise RuntimeError(f"Audio extraction failed: {e}") + + +def cleanup_temp_audio(): + """Remove all temporary audio files created during processing.""" + cleaned = 0 + for fpath in _temp_audio_files: + try: + if os.path.exists(fpath): + os.remove(fpath) + parent = os.path.dirname(fpath) + if os.path.isdir(parent) and not os.listdir(parent): + os.rmdir(parent) + cleaned += 1 + except Exception as e: + logger.warning(f"Could not remove temp file {fpath}: {e}") + _temp_audio_files.clear() + return cleaned + + +def get_video_duration(video_path: Path): + """Get duration of a video/audio file in seconds.""" + try: + clip = AudioFileClip(str(video_path)) + duration = clip.duration + clip.close() + return duration + except Exception: + 
return None diff --git a/utils/diarization.py b/utils/diarization.py index bc807ff..2326dc0 100644 --- a/utils/diarization.py +++ b/utils/diarization.py @@ -1,5 +1,5 @@ """ -Speaker diarization utilities for the OBS Recording Transcriber. +Speaker diarization utilities for the Video Transcriber. Provides functions to identify different speakers in audio recordings. """ @@ -11,22 +11,34 @@ import torch from pyannote.audio import Pipeline from pyannote.core import Segment import whisper +import streamlit as st -# Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Try to import GPU utilities, but don't fail if not available try: from utils.gpu_utils import get_optimal_device GPU_UTILS_AVAILABLE = True except ImportError: GPU_UTILS_AVAILABLE = False -# Default HuggingFace auth token environment variable HF_TOKEN_ENV = "HF_TOKEN" +@st.cache_resource +def _load_diarization_pipeline(hf_token, device_str): + """Load and cache the speaker diarization pipeline.""" + logger.info(f"Loading diarization pipeline on {device_str}") + pipe = Pipeline.from_pretrained( + "pyannote/speaker-diarization-3.0", + use_auth_token=hf_token + ) + device = torch.device(device_str) + if device.type == "cuda": + pipe = pipe.to(device) + return pipe + + def get_diarization_pipeline(use_gpu=True, hf_token=None): """ Initialize the speaker diarization pipeline. @@ -38,7 +50,6 @@ def get_diarization_pipeline(use_gpu=True, hf_token=None): Returns: Pipeline or None: Diarization pipeline if successful, None otherwise """ - # Check if token is provided or in environment if hf_token is None: hf_token = os.environ.get(HF_TOKEN_ENV) if hf_token is None: @@ -46,23 +57,12 @@ def get_diarization_pipeline(use_gpu=True, hf_token=None): return None try: - # Configure device device = torch.device("cpu") if use_gpu and GPU_UTILS_AVAILABLE: device = get_optimal_device() logger.info(f"Using device: {device} for diarization") - # Initialize the pipeline - pipeline = Pipeline.from_pretrained( - "pyannote/speaker-diarization-3.0", - use_auth_token=hf_token - ) - - # Move to appropriate device - if device.type == "cuda": - pipeline = pipeline.to(torch.device(device)) - - return pipeline + return _load_diarization_pipeline(hf_token, str(device)) except Exception as e: logger.error(f"Error initializing diarization pipeline: {e}") return None @@ -198,9 +198,9 @@ def transcribe_with_diarization(audio_path, whisper_model="base", num_speakers=N device = get_optimal_device() try: - # Step 1: Transcribe audio with Whisper + from utils.transcription import _load_whisper_model logger.info(f"Transcribing audio with Whisper model: {whisper_model}") - model = whisper.load_model(whisper_model, device=device if device.type != "mps" else "cpu") + model = _load_whisper_model(whisper_model, str(device)) result = model.transcribe(str(audio_path)) transcript_segments = result["segments"] diff --git a/utils/gpu_utils.py b/utils/gpu_utils.py index 1c08ec3..8c5dcd2 100644 --- a/utils/gpu_utils.py +++ b/utils/gpu_utils.py @@ -1,12 +1,9 @@ """ -GPU utilities for the OBS Recording Transcriber. +GPU utilities for the Video Transcriber. Provides functions to detect and configure GPU acceleration. """ import logging -import os -import platform -import subprocess import torch # Configure logging @@ -68,8 +65,6 @@ def get_optimal_device(): def set_memory_limits(memory_fraction=0.8): - global torch - import torch """ Set memory limits for GPU usage. 
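Note on the caching pattern introduced in these utils: @st.cache_resource keys the cache on the function's arguments, which must be stable and hashable — which is why the cached loaders in this patch take plain strings (model_name, device_str, hf_token) rather than torch.device or model objects. A minimal, runnable sketch of the same pattern; the name _load_model_cached and the nn.Linear stand-in are illustrative, not part of this patch:

    import streamlit as st
    import torch

    @st.cache_resource
    def _load_model_cached(model_name: str, device_str: str):
        # Created once per (model_name, device_str) pair, then reused
        # across Streamlit reruns and sessions instead of reloading.
        model = torch.nn.Linear(8, 8)  # stand-in for an expensive model load
        return model.to(torch.device(device_str))

    model = _load_model_cached("demo", "cuda" if torch.cuda.is_available() else "cpu")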
diff --git a/utils/keyword_extraction.py b/utils/keyword_extraction.py index 498283c..733ec4e 100644 --- a/utils/keyword_extraction.py +++ b/utils/keyword_extraction.py @@ -1,5 +1,5 @@ """ -Keyword extraction utilities for the OBS Recording Transcriber. +Keyword extraction utilities for the Video Transcriber. Provides functions to extract keywords and link them to timestamps. """ @@ -8,25 +8,30 @@ import re import torch import numpy as np from pathlib import Path -from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification +from transformers import pipeline from sklearn.feature_extraction.text import TfidfVectorizer from collections import Counter +import streamlit as st -# Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Try to import GPU utilities, but don't fail if not available try: from utils.gpu_utils import get_optimal_device GPU_UTILS_AVAILABLE = True except ImportError: GPU_UTILS_AVAILABLE = False -# Default models NER_MODEL = "dslim/bert-base-NER" +@st.cache_resource +def _load_ner_pipeline(model_name, device_int): + """Load and cache the NER pipeline.""" + logger.info(f"Loading NER model: {model_name}") + return pipeline("ner", model=model_name, device=device_int, aggregation_strategy="simple") + + def extract_keywords_tfidf(text, max_keywords=10, ngram_range=(1, 2)): """ Extract keywords using TF-IDF. @@ -107,8 +112,7 @@ def extract_named_entities(text, model=NER_MODEL, use_gpu=True): device_arg = -1 try: - # Initialize the pipeline - ner_pipeline = pipeline("ner", model=model, device=device_arg, aggregation_strategy="simple") + ner_pipeline = _load_ner_pipeline(model, device_arg) # Split text into manageable chunks if too long max_length = 512 diff --git a/utils/ollama_integration.py b/utils/ollama_integration.py index a63bac8..d71caa4 100644 --- a/utils/ollama_integration.py +++ b/utils/ollama_integration.py @@ -1,6 +1,6 @@ """ Ollama integration for local AI model inference. -Provides functions to use Ollama's API for text summarization. +Provides functions to use Ollama's API for text summarization with streaming support. """ import requests @@ -9,21 +9,14 @@ import logging from pathlib import Path import os -# Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Default Ollama API endpoint - configurable via environment variable OLLAMA_API_URL = os.environ.get("OLLAMA_API_URL", "http://localhost:11434/api") def check_ollama_available(): - """ - Check if Ollama service is available. - - Returns: - bool: True if Ollama is available, False otherwise - """ + """Check if Ollama service is available.""" try: response = requests.get(f"{OLLAMA_API_URL}/tags", timeout=2) return response.status_code == 200 @@ -32,12 +25,7 @@ def check_ollama_available(): def list_available_models(): - """ - List available models in Ollama. - - Returns: - list: List of available model names - """ + """List available models in Ollama.""" try: response = requests.get(f"{OLLAMA_API_URL}/tags") if response.status_code == 200: @@ -50,32 +38,14 @@ def list_available_models(): def summarize_with_ollama(text, model="llama3", max_length=150): - """ - Summarize text using Ollama's local API. 
- - Args: - text (str): Text to summarize - model (str): Ollama model to use - max_length (int): Maximum length of the summary - - Returns: - str: Summarized text or None if failed - """ + """Summarize text using Ollama's local API (non-streaming).""" if not check_ollama_available(): logger.warning("Ollama service is not available") return None - # Check if the model is available - available_models = list_available_models() - if model not in available_models: - logger.warning(f"Model {model} not available in Ollama. Available models: {available_models}") - return None - - # Prepare the prompt for summarization prompt = f"Summarize the following text in about {max_length} words:\n\n{text}" try: - # Make the API request response = requests.post( f"{OLLAMA_API_URL}/generate", json={ @@ -85,7 +55,7 @@ def summarize_with_ollama(text, model="llama3", max_length=150): "options": { "temperature": 0.3, "top_p": 0.9, - "max_tokens": max_length * 2 # Approximate token count + "max_tokens": max_length * 2 } } ) @@ -101,23 +71,55 @@ def summarize_with_ollama(text, model="llama3", max_length=150): return None -def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150): +def stream_summarize_with_ollama(text, model="llama3", max_length=150): """ - Chunk long text and summarize each chunk, then combine the summaries. + Summarize text using Ollama with streaming. Yields tokens as they arrive. - Args: - text (str): Text to summarize - model (str): Ollama model to use - chunk_size (int): Maximum size of each chunk in characters - max_length (int): Maximum length of the final summary - - Returns: - str: Combined summary or None if failed + Yields: + str: Individual response tokens """ + if not check_ollama_available(): + logger.warning("Ollama service is not available") + return + + prompt = f"Summarize the following text in about {max_length} words:\n\n{text}" + + try: + response = requests.post( + f"{OLLAMA_API_URL}/generate", + json={ + "model": model, + "prompt": prompt, + "stream": True, + "options": { + "temperature": 0.3, + "top_p": 0.9, + "max_tokens": max_length * 2 + } + }, + stream=True + ) + + if response.status_code == 200: + for line in response.iter_lines(): + if line: + data = json.loads(line) + token = data.get('response', '') + if token: + yield token + if data.get('done', False): + break + else: + logger.error(f"Ollama API error: {response.status_code}") + except requests.exceptions.RequestException as e: + logger.error(f"Error communicating with Ollama: {e}") + + +def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150): + """Chunk long text and summarize each chunk, then combine.""" if len(text) <= chunk_size: return summarize_with_ollama(text, model, max_length) - # Split text into chunks words = text.split() chunks = [] current_chunk = [] @@ -135,7 +137,6 @@ def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150): if current_chunk: chunks.append(' '.join(current_chunk)) - # Summarize each chunk chunk_summaries = [] for i, chunk in enumerate(chunks): logger.info(f"Summarizing chunk {i+1}/{len(chunks)}") @@ -146,10 +147,55 @@ def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150): if not chunk_summaries: return None - # If there's only one chunk summary, return it if len(chunk_summaries) == 1: return chunk_summaries[0] - # Otherwise, combine the summaries and summarize again combined_summary = " ".join(chunk_summaries) - return summarize_with_ollama(combined_summary, model, max_length) \ No newline at 
end of file + return summarize_with_ollama(combined_summary, model, max_length) + + +def stream_chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150): + """ + Chunk and summarize with streaming on the final summary. + Returns non-streaming chunk summaries, then streams the final combination. + + Yields: + str: Tokens from the final summary + """ + if len(text) <= chunk_size: + yield from stream_summarize_with_ollama(text, model, max_length) + return + + words = text.split() + chunks = [] + current_chunk = [] + current_length = 0 + + for word in words: + if current_length + len(word) + 1 <= chunk_size: + current_chunk.append(word) + current_length += len(word) + 1 + else: + chunks.append(' '.join(current_chunk)) + current_chunk = [word] + current_length = len(word) + 1 + + if current_chunk: + chunks.append(' '.join(current_chunk)) + + chunk_summaries = [] + for i, chunk in enumerate(chunks): + logger.info(f"Summarizing chunk {i+1}/{len(chunks)}") + summary = summarize_with_ollama(chunk, model, max_length // len(chunks)) + if summary: + chunk_summaries.append(summary) + + if not chunk_summaries: + return + + if len(chunk_summaries) == 1: + yield chunk_summaries[0] + return + + combined_summary = " ".join(chunk_summaries) + yield from stream_summarize_with_ollama(combined_summary, model, max_length) \ No newline at end of file diff --git a/utils/summarization.py b/utils/summarization.py index a5c5ddc..28012c4 100644 --- a/utils/summarization.py +++ b/utils/summarization.py @@ -1,45 +1,49 @@ from transformers import pipeline, AutoTokenizer import torch import logging +import streamlit as st -# Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) SUMMARY_MODEL = "Falconsai/text_summarization" + +@st.cache_resource +def _load_summarizer(device_int): + """Load and cache the summarization pipeline.""" + logger.info(f"Loading summarization model on device {device_int}") + return pipeline("summarization", model=SUMMARY_MODEL, device=device_int) + + +@st.cache_resource +def _load_summary_tokenizer(): + """Load and cache the summarization tokenizer.""" + return AutoTokenizer.from_pretrained(SUMMARY_MODEL) + + def chunk_text(text, max_tokens, tokenizer): """ - Splits the text into a list of chunks based on token limits. - - Args: - text (str): Text to chunk - max_tokens (int): Maximum tokens per chunk - tokenizer (AutoTokenizer): Tokenizer to use - - Returns: - list: List of text chunks + Splits text into chunks by tokenizing once, then splitting by token windows. + Much faster than the per-word tokenization approach. 
""" - words = text.split() + all_ids = tokenizer(text, return_tensors='pt', truncation=False)['input_ids'][0] + content_ids = all_ids[1:-1] # strip BOS/EOS + usable_max = max_tokens - 2 # leave room for special tokens + chunks = [] - current_chunk = [] - current_length = 0 - - for word in words: - hypothetical_length = current_length + len(tokenizer(word, return_tensors='pt')['input_ids'][0]) - 2 - if hypothetical_length <= max_tokens: - current_chunk.append(word) - current_length = hypothetical_length - else: - chunks.append(' '.join(current_chunk)) - current_chunk = [word] - current_length = len(tokenizer(word, return_tensors='pt')['input_ids'][0]) - 2 - - if current_chunk: - chunks.append(' '.join(current_chunk)) - + for i in range(0, len(content_ids), usable_max): + chunk_ids = content_ids[i : i + usable_max] + decoded = tokenizer.decode(chunk_ids, skip_special_tokens=True).strip() + if decoded: + chunks.append(decoded) + + if not chunks: + chunks.append(text) + return chunks + def summarize_text(text, use_gpu=True, memory_fraction=0.8): """ Summarize text using a Hugging Face pipeline with chunking support. @@ -52,21 +56,17 @@ def summarize_text(text, use_gpu=True, memory_fraction=0.8): Returns: str: Summarized text """ - # Determine device - device = -1 # Default to CPU + device = -1 if use_gpu and torch.cuda.is_available(): - device = 0 # Use first GPU - if torch.cuda.is_available(): - torch.cuda.set_per_process_memory_fraction(memory_fraction) + device = 0 + torch.cuda.set_per_process_memory_fraction(memory_fraction) logger.info(f"Using device {device} for summarization") try: - # Initialize the pipeline and tokenizer - summarizer = pipeline("summarization", model=SUMMARY_MODEL, device=device) - tokenizer = AutoTokenizer.from_pretrained(SUMMARY_MODEL) + summarizer = _load_summarizer(device) + tokenizer = _load_summary_tokenizer() - # Check if text needs to be chunked max_tokens = 512 tokens = tokenizer(text, return_tensors='pt') num_tokens = len(tokens['input_ids'][0]) @@ -85,7 +85,6 @@ def summarize_text(text, use_gpu=True, memory_fraction=0.8): ) summaries.append(summary_output[0]['summary_text']) - # If multiple chunks, summarize the combined summaries if len(summaries) > 1: logger.info("Generating final summary from chunk summaries") combined_text = " ".join(summaries) @@ -106,7 +105,6 @@ def summarize_text(text, use_gpu=True, memory_fraction=0.8): except Exception as e: logger.error(f"Error during summarization: {e}") - # Fallback to CPU if GPU fails if device != -1: logger.info("Falling back to CPU") return summarize_text(text, use_gpu=False, memory_fraction=memory_fraction) diff --git a/utils/transcription.py b/utils/transcription.py index 68626b1..0c53b5f 100644 --- a/utils/transcription.py +++ b/utils/transcription.py @@ -1,31 +1,36 @@ import whisper from pathlib import Path -from transformers import pipeline, AutoTokenizer from utils.audio_processing import extract_audio -from utils.summarization import summarize_text import logging import torch +import streamlit as st -# Try to import GPU utilities, but don't fail if not available try: from utils.gpu_utils import configure_gpu, get_optimal_device GPU_UTILS_AVAILABLE = True except ImportError: GPU_UTILS_AVAILABLE = False -# Try to import caching utilities, but don't fail if not available try: from utils.cache import load_from_cache, save_to_cache CACHE_AVAILABLE = True except ImportError: CACHE_AVAILABLE = False -# Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) WHISPER_MODEL 
= "base" + +@st.cache_resource +def _load_whisper_model(model_name, device_str): + """Load and cache a Whisper model. Cached across reruns.""" + logger.info(f"Loading Whisper model: {model_name} on {device_str}") + device = torch.device(device_str) + return whisper.load_model(model_name, device=device if device.type != "mps" else "cpu") + + def transcribe_audio(audio_path: Path, model=WHISPER_MODEL, use_cache=True, cache_max_age=None, use_gpu=True, memory_fraction=0.8): """ @@ -44,38 +49,30 @@ def transcribe_audio(audio_path: Path, model=WHISPER_MODEL, use_cache=True, cach """ audio_path = Path(audio_path) - # Check cache first if enabled if use_cache and CACHE_AVAILABLE: cached_data = load_from_cache(audio_path, model, "transcribe", cache_max_age) if cached_data: logger.info(f"Using cached transcription for {audio_path}") return cached_data.get("segments", []), cached_data.get("transcript", "") - # Extract audio if the input is a video file (M4A is already audio) video_extensions = ['.mp4', '.avi', '.mov', '.mkv'] if audio_path.suffix.lower() in video_extensions: audio_path = extract_audio(audio_path) - # Configure GPU if available and requested device = torch.device("cpu") if use_gpu and GPU_UTILS_AVAILABLE: gpu_config = configure_gpu(model, memory_fraction) device = gpu_config["device"] logger.info(f"Using device: {device} for transcription") - # Load the specified Whisper model - logger.info(f"Loading Whisper model: {model}") - whisper_model = whisper.load_model(model, device=device if device.type != "mps" else "cpu") + whisper_model = _load_whisper_model(model, str(device)) - # Transcribe the audio logger.info(f"Transcribing audio: {audio_path}") result = whisper_model.transcribe(str(audio_path)) - # Extract the full transcript and segments transcript = result["text"] segments = result["segments"] - # Cache the results if caching is enabled if use_cache and CACHE_AVAILABLE: cache_data = { "transcript": transcript, diff --git a/utils/translation.py b/utils/translation.py index 84ae541..5d8094b 100644 --- a/utils/translation.py +++ b/utils/translation.py @@ -1,41 +1,49 @@ """ -Translation utilities for the OBS Recording Transcriber. +Translation utilities for the Video Transcriber. Provides functions for language detection and translation. 
""" import logging import torch from pathlib import Path -from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, M2M100ForConditionalGeneration +from transformers import pipeline, AutoTokenizer, M2M100ForConditionalGeneration import whisper import iso639 +import streamlit as st -# Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Try to import GPU utilities, but don't fail if not available try: from utils.gpu_utils import get_optimal_device GPU_UTILS_AVAILABLE = True except ImportError: GPU_UTILS_AVAILABLE = False -# Default models TRANSLATION_MODEL = "facebook/m2m100_418M" LANGUAGE_DETECTION_MODEL = "papluca/xlm-roberta-base-language-detection" -# ISO language code mapping + +@st.cache_resource +def _load_language_detector(model_name, device_int): + """Load and cache the language detection pipeline.""" + logger.info(f"Loading language detection model: {model_name}") + return pipeline("text-classification", model=model_name, device=device_int) + + +@st.cache_resource +def _load_translation_model(model_name, device_str): + """Load and cache the M2M100 translation model and tokenizer.""" + logger.info(f"Loading translation model: {model_name} on {device_str}") + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = M2M100ForConditionalGeneration.from_pretrained(model_name) + device = torch.device(device_str) + model = model.to(device) + return model, tokenizer + + def get_language_name(code): - """ - Get the language name from ISO code. - - Args: - code (str): ISO language code - - Returns: - str: Language name or original code if not found - """ + """Get the language name from ISO code.""" try: return iso639.languages.get(part1=code).name except (KeyError, AttributeError): @@ -57,7 +65,6 @@ def detect_language(text, model=LANGUAGE_DETECTION_MODEL, use_gpu=True): Returns: tuple: (language_code, confidence) """ - # Configure device device = torch.device("cpu") if use_gpu and GPU_UTILS_AVAILABLE: device = get_optimal_device() @@ -66,25 +73,43 @@ def detect_language(text, model=LANGUAGE_DETECTION_MODEL, use_gpu=True): device_arg = -1 try: - # Initialize the pipeline - classifier = pipeline("text-classification", model=model, device=device_arg) + classifier = _load_language_detector(model, device_arg) - # Truncate text if too long max_length = 512 if len(text) > max_length: text = text[:max_length] - # Detect language result = classifier(text)[0] - language_code = result["label"] - confidence = result["score"] - - return language_code, confidence + return result["label"], result["score"] except Exception as e: logger.error(f"Error detecting language: {e}") return None, 0.0 +def _translate_text_with_model(text, source_lang, target_lang, trans_model, tokenizer, device): + """Translate text using a pre-loaded model and tokenizer.""" + tokenizer.src_lang = source_lang + + max_length = 512 + if len(text) > max_length: + chunks = [text[i:i+max_length] for i in range(0, len(text), max_length)] + else: + chunks = [text] + + translated_chunks = [] + for chunk in chunks: + encoded = tokenizer(chunk, return_tensors="pt").to(device) + generated_tokens = trans_model.generate( + **encoded, + forced_bos_token_id=tokenizer.get_lang_id(target_lang), + max_length=max_length + ) + translated_chunk = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0] + translated_chunks.append(translated_chunk) + + return " ".join(translated_chunks) + + def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_MODEL, 
     """
     Translate text from source language to target language.
@@ -99,7 +124,6 @@ def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_M
     Returns:
         str: Translated text
     """
-    # Auto-detect source language if not provided
     if source_lang is None:
         detected_lang, confidence = detect_language(text, use_gpu=use_gpu)
         if detected_lang and confidence > 0.5:
@@ -109,50 +133,17 @@
             logger.warning("Could not reliably detect language, defaulting to English")
             source_lang = "en"
 
-    # Skip translation if source and target are the same
     if source_lang == target_lang:
         logger.info(f"Source and target languages are the same ({source_lang}), skipping translation")
         return text
 
-    # Configure device
     device = torch.device("cpu")
     if use_gpu and GPU_UTILS_AVAILABLE:
         device = get_optimal_device()
 
     try:
-        # Load model and tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(model)
-        model = M2M100ForConditionalGeneration.from_pretrained(model)
-
-        # Move model to device
-        model = model.to(device)
-
-        # Prepare for translation
-        tokenizer.src_lang = source_lang
-
-        # Split text into manageable chunks if too long
-        max_length = 512
-        if len(text) > max_length:
-            chunks = [text[i:i+max_length] for i in range(0, len(text), max_length)]
-        else:
-            chunks = [text]
-
-        # Translate each chunk
-        translated_chunks = []
-        for chunk in chunks:
-            encoded = tokenizer(chunk, return_tensors="pt").to(device)
-            generated_tokens = model.generate(
-                **encoded,
-                forced_bos_token_id=tokenizer.get_lang_id(target_lang),
-                max_length=max_length
-            )
-            translated_chunk = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
-            translated_chunks.append(translated_chunk)
-
-        # Combine translated chunks
-        translated_text = " ".join(translated_chunks)
-
-        return translated_text
+        trans_model, tokenizer = _load_translation_model(model, str(device))
+        return _translate_text_with_model(text, source_lang, target_lang, trans_model, tokenizer, device)
 
     except Exception as e:
         logger.error(f"Error translating text: {e}")
         return text
@@ -160,7 +151,7 @@ def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_M
 
 def translate_segments(segments, source_lang=None, target_lang="en", use_gpu=True):
     """
-    Translate transcript segments.
+    Translate transcript segments. Loads the model once and reuses it for all segments.
 
     Args:
         segments (list): List of transcript segments
@@ -174,36 +165,32 @@ def translate_segments(segments, source_lang=None, target_lang="en", use_gpu=Tru
     if not segments:
         return []
 
-    # Auto-detect source language from combined text if not provided
     if source_lang is None:
         combined_text = " ".join([segment["text"] for segment in segments])
         detected_lang, _ = detect_language(combined_text, use_gpu=use_gpu)
         source_lang = detected_lang if detected_lang else "en"
 
-    # Skip translation if source and target are the same
     if source_lang == target_lang:
         return segments
 
+    device = torch.device("cpu")
+    if use_gpu and GPU_UTILS_AVAILABLE:
+        device = get_optimal_device()
+
     try:
-        # Initialize translation pipeline
-        translated_segments = []
+        trans_model, tokenizer = _load_translation_model(TRANSLATION_MODEL, str(device))
 
-        # Translate each segment
+        translated_segments = []
         for segment in segments:
-            translated_text = translate_text(
-                segment["text"],
-                source_lang=source_lang,
-                target_lang=target_lang,
-                use_gpu=use_gpu
+            translated_text = _translate_text_with_model(
+                segment["text"], source_lang, target_lang, trans_model, tokenizer, device
             )
-            # Create a new segment with translated text
             translated_segment = segment.copy()
             translated_segment["text"] = translated_text
             translated_segment["original_text"] = segment["text"]
             translated_segment["source_lang"] = source_lang
             translated_segment["target_lang"] = target_lang
-
             translated_segments.append(translated_segment)
 
         return translated_segments
@@ -227,39 +214,33 @@
     Returns:
         tuple: (original_segments, translated_segments, original_transcript, translated_transcript)
     """
+    from utils.transcription import _load_whisper_model
+
     audio_path = Path(audio_path)
 
-    # Configure device
     device = torch.device("cpu")
     if use_gpu and GPU_UTILS_AVAILABLE:
         device = get_optimal_device()
 
     try:
-        # Step 1: Transcribe audio with Whisper
         logger.info(f"Transcribing audio with Whisper model: {whisper_model}")
-        model = whisper.load_model(whisper_model, device=device if device.type != "mps" else "cpu")
+        model = _load_whisper_model(whisper_model, str(device))
 
-        # Use Whisper's built-in language detection if requested
         if detect_source:
-            # First, detect language with Whisper
             audio = whisper.load_audio(str(audio_path))
             audio = whisper.pad_or_trim(audio)
             mel = whisper.log_mel_spectrogram(audio).to(device if device.type != "mps" else "cpu")
             _, probs = model.detect_language(mel)
             source_lang = max(probs, key=probs.get)
             logger.info(f"Whisper detected language: {get_language_name(source_lang)} ({source_lang})")
-
-            # Transcribe with detected language
             result = model.transcribe(str(audio_path), language=source_lang)
         else:
-            # Transcribe without language specification
             result = model.transcribe(str(audio_path))
             source_lang = result.get("language", "en")
 
         original_segments = result["segments"]
         original_transcript = result["text"]
 
-        # Step 2: Translate if needed
         if source_lang != target_lang:
             logger.info(f"Translating from {source_lang} to {target_lang}")
             translated_segments = translate_segments(
@@ -268,8 +249,6 @@ def transcribe_and_translate(audio_path, whisper_model="base", target_lang="en",
                 target_lang=target_lang,
                 use_gpu=use_gpu
             )
-
-            # Create full translated transcript
             translated_transcript = " ".join([segment["text"] for segment in translated_segments])
         else:
             logger.info(f"Source and target languages are the same ({source_lang}), skipping translation")
@@ -280,4 +259,4 @@ def transcribe_and_translate(audio_path, whisper_model="base", target_lang="en",
 
     except Exception as e:
         logger.error(f"Error in transcribe_and_translate: {e}")
-        return None, None, None, None
\ No newline at end of file
+        return None, None, None, None
\ No newline at end of file
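Note (annotation, outside the diff): with the loader cached, translate_segments now pays the model-load cost once per session instead of once per segment. A hypothetical call, using the keys this patch actually sets on each copied segment:

    segments = [
        {"start": 0.0, "end": 2.5, "text": "Hola a todos"},
        {"start": 2.5, "end": 5.0, "text": "bienvenidos al video"},
    ]
    translated = translate_segments(segments, source_lang="es", target_lang="en")
    for seg in translated:
        # each copy keeps its timing keys and gains original_text,
        # source_lang, and target_lang
        print(seg["original_text"], "->", seg["text"])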
diff --git a/utils/validation.py b/utils/validation.py
index 5079fd1..8b579af 100644
--- a/utils/validation.py
+++ b/utils/validation.py
@@ -1,8 +1,38 @@
 from pathlib import Path
+import shutil
+import logging
 
-def validate_environment(obs_path: Path):
+logger = logging.getLogger(__name__)
+
+
+def validate_environment(obs_path: Path = None):
     """Validate environment and prerequisites."""
     errors = []
-    if not obs_path.exists():
-        errors.append(f"OBS directory not found: {obs_path}")
+
+    if obs_path and not obs_path.exists():
+        errors.append(f"Directory not found: {obs_path}")
+
+    if not shutil.which("ffmpeg"):
+        errors.append("FFmpeg is not installed or not in PATH. Install it from https://ffmpeg.org/download.html")
+
     return errors
+
+
+def get_system_capabilities():
+    """Return a dict of detected system capabilities for display."""
+    import torch
+
+    caps = {
+        "ffmpeg": shutil.which("ffmpeg") is not None,
+        "cuda": torch.cuda.is_available(),
+        "mps": hasattr(torch.backends, "mps") and torch.backends.mps.is_available(),
+        "gpu_name": None,
+        "gpu_memory": None,
+    }
+
+    if caps["cuda"] and torch.cuda.device_count() > 0:
+        props = torch.cuda.get_device_properties(0)
+        caps["gpu_name"] = props.name
+        caps["gpu_memory"] = props.total_memory  # bytes
+
+    return caps
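Note (annotation, outside the diff): a hypothetical startup sequence wiring the two new helpers together; the division by 1e9 assumes total_memory is in bytes, which is how torch reports it:

    from utils.validation import validate_environment, get_system_capabilities

    for err in validate_environment():   # obs_path is now optional
        print(f"[error] {err}")

    caps = get_system_capabilities()
    if caps["cuda"]:
        print(f"GPU: {caps['gpu_name']} ({caps['gpu_memory'] / 1e9:.1f} GB)")
    elif caps["mps"]:
        print("Apple Silicon GPU (MPS) available")
    if not caps["ffmpeg"]:
        print("FFmpeg missing - audio extraction will fail")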