diff --git a/GEMINI_INSIGHTS.md b/GEMINI_INSIGHTS.md new file mode 100644 index 0000000..0f16264 --- /dev/null +++ b/GEMINI_INSIGHTS.md @@ -0,0 +1,105 @@ +# Gemini Insights: OBS Recording Transcriber + +## Project Overview +The OBS Recording Transcriber is a Python application built with Streamlit that processes video recordings (particularly from OBS Studio) to generate transcripts and summaries using AI models. The application uses Whisper for transcription and Hugging Face Transformers for summarization. + +## Key Improvement Areas + +### 1. UI Enhancements +- **Implemented:** + - Responsive layout with columns for better organization + - Expanded sidebar with categorized settings + - Custom CSS for improved button styling + - Spinner for long-running operations + - Expanded transcript view by default + +- **Additional Recommendations:** + - Add a dark mode toggle + - Implement progress bars for each processing step + - Add tooltips for complex options + - Create a dashboard view for batch processing results + - Add visualization of transcript segments with timestamps + +### 2. Ollama Local API Integration +- **Implemented:** + - Local API integration for offline summarization + - Model selection from available Ollama models + - Chunking for long texts + - Fallback to online models when Ollama fails + +- **Additional Recommendations:** + - Add temperature and other generation parameters as advanced options + - Implement streaming responses for real-time feedback + - Cache results to avoid reprocessing + - Add support for custom Ollama model creation with specific instructions + - Implement parallel processing for multiple chunks + +### 3. Subtitle Export Formats +- **Implemented:** + - SRT export with proper formatting + - ASS export with basic styling + - Multi-format export options + - Automatic segment creation from plain text + +- **Additional Recommendations:** + - Add customizable styling options for ASS subtitles + - Implement subtitle editing before export + - Add support for VTT format for web videos + - Implement subtitle timing adjustment + - Add batch export for multiple files + +### 4. Architecture and Code Quality +- **Recommendations:** + - Implement proper error handling and logging throughout + - Add unit tests for critical components + - Create a configuration file for default settings + - Implement caching for processed files + - Add type hints throughout the codebase + - Document API endpoints for potential future web service + +### 5. Performance Optimizations +- **Recommendations:** + - Implement parallel processing for batch operations + - Add GPU acceleration configuration options + - Optimize memory usage for large files + - Implement incremental processing for very long recordings + - Add compression options for exported files + +### 6. Additional Features +- **Recommendations:** + - Speaker diarization (identifying different speakers) + - Language detection and translation + - Keyword extraction and timestamp linking + - Integration with video editing software + - Batch processing queue with email notifications + - Custom vocabulary for domain-specific terminology + +## Implementation Roadmap +1. **Phase 1 (Completed):** Basic UI improvements, Ollama integration, and subtitle export +2. 
**Phase 2 (Completed):** Performance optimizations and additional export formats + - Added WebVTT export format for web videos + - Implemented GPU acceleration with automatic device selection + - Added caching system for faster processing of previously transcribed files + - Optimized memory usage with configurable memory limits + - Added compression options for exported files + - Enhanced ASS subtitle styling options + - Added progress indicators for better user feedback +3. **Phase 3 (Completed):** Advanced features like speaker diarization and translation + - Implemented speaker diarization to identify different speakers in recordings + - Added language detection and translation capabilities + - Integrated keyword extraction with timestamp linking + - Created interactive transcript with keyword highlighting + - Added named entity recognition for better content analysis + - Generated keyword index with timestamp references + - Provided speaker statistics and word count analysis +4. **Phase 4:** Integration with other tools and services + +## Technical Considerations +- Ensure compatibility with different Whisper model sizes +- Handle large files efficiently to prevent memory issues +- Provide graceful degradation when optional dependencies are missing +- Maintain backward compatibility with existing workflows +- Consider containerization for easier deployment + +## Conclusion +The OBS Recording Transcriber has a solid foundation but can be significantly enhanced with the suggested improvements. The focus should be on improving user experience, adding offline processing capabilities, and expanding export options to make the tool more versatile for different use cases. \ No newline at end of file diff --git a/INSTALLATION.md b/INSTALLATION.md new file mode 100644 index 0000000..5d90c44 --- /dev/null +++ b/INSTALLATION.md @@ -0,0 +1,141 @@ +# Installation Guide for OBS Recording Transcriber + +This guide will help you install all the necessary dependencies for the OBS Recording Transcriber application, including the advanced features from Phase 3. + +## Prerequisites + +Before installing the Python packages, you need to set up some prerequisites: + +### 1. Python 3.8 or higher + +Make sure you have Python 3.8 or higher installed. You can download it from [python.org](https://www.python.org/downloads/). + +### 2. FFmpeg + +FFmpeg is required for audio processing: + +- **Windows**: + - Download from [gyan.dev/ffmpeg/builds](https://www.gyan.dev/ffmpeg/builds/) + - Extract the ZIP file + - Add the `bin` folder to your system PATH + +- **macOS**: + ```bash + brew install ffmpeg + ``` + +- **Linux**: + ```bash + sudo apt update + sudo apt install ffmpeg + ``` + +### 3. Visual C++ Build Tools (Windows only) + +Some packages like `tokenizers` require C++ build tools: + +1. Download and install [Visual C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/) +2. During installation, select "Desktop development with C++" + +## Installation Steps + +### 1. Create a Virtual Environment (Recommended) + +```bash +# Create a virtual environment +python -m venv venv + +# Activate the virtual environment +# Windows +venv\Scripts\activate +# macOS/Linux +source venv/bin/activate +``` + +### 2. 
Install PyTorch + +For better performance, install PyTorch with CUDA support if you have an NVIDIA GPU: + +```bash +# Windows/Linux with CUDA +pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 + +# macOS or CPU-only +pip install torch torchvision torchaudio +``` + +### 3. Install Dependencies + +```bash +# Install all dependencies from requirements.txt +pip install -r requirements.txt +``` + +### 4. Troubleshooting Common Issues + +#### Tokenizers Installation Issues + +If you encounter issues with `tokenizers` installation: + +1. Make sure you have Visual C++ Build Tools installed (Windows) +2. Try installing Rust: [rustup.rs](https://rustup.rs/) +3. Install tokenizers separately: + ```bash + pip install tokenizers --no-binary tokenizers + ``` + +#### PyAnnote.Audio Access + +To use speaker diarization, you need a HuggingFace token with access to the pyannote models: + +1. Create an account on [HuggingFace](https://huggingface.co/) +2. Generate an access token at [huggingface.co/settings/tokens](https://huggingface.co/settings/tokens) +3. Request access to [pyannote/speaker-diarization-3.0](https://huggingface.co/pyannote/speaker-diarization-3.0) +4. Set the token in the application when prompted or as an environment variable: + ```bash + # Windows + set HF_TOKEN=your_token_here + # macOS/Linux + export HF_TOKEN=your_token_here + ``` + +#### Memory Issues with Large Files + +If you encounter memory issues with large files: + +1. Use a smaller Whisper model (e.g., "base" instead of "large") +2. Reduce the GPU memory fraction in the application settings +3. Increase your system's swap space/virtual memory + +## Running the Application + +After installation, run the application with: + +```bash +streamlit run app.py +``` + +## Optional: Ollama Setup for Local Summarization + +To use Ollama for local summarization: + +1. Install Ollama from [ollama.ai](https://ollama.ai/) +2. Pull a model: + ```bash + ollama pull llama3 + ``` +3. Uncomment the Ollama line in requirements.txt and install: + ```bash + pip install ollama + ``` + +## Verifying Installation + +To verify that all components are working correctly: + +1. Run the application +2. Check that GPU acceleration is available (if applicable) +3. Test a small video file with basic transcription +4. Gradually enable advanced features like diarization and translation + +If you encounter any issues, check the application logs for specific error messages. \ No newline at end of file diff --git a/README.md b/README.md index 7093fdd..ee4ec57 100644 --- a/README.md +++ b/README.md @@ -7,19 +7,68 @@ Process OBS recordings or any video/audio files with AI-based transcription and - AI transcription using Whisper. - Summarization using Hugging Face Transformers. - File selection, resource validation, and error handling. +- Speaker diarization to identify different speakers in recordings. +- Language detection and translation capabilities. +- Keyword extraction with timestamp linking. +- Interactive transcript with keyword highlighting. +- Export to TXT, SRT, VTT, and ASS subtitle formats with compression options. +- GPU acceleration for faster processing. +- Caching system for previously processed files. ## Installation -1. Clone the repo. -git clone [https://github.com/DataAnts-AI/VideoTranscriber.git -cd VideoTranscriber -2. Install dependencies: - pip install -r requirements.txt +### Easy Installation (Recommended) + +#### Windows +1. Download or clone the repository +2. 
Run `install.bat` by double-clicking it +3. Follow the on-screen instructions + +#### Linux/macOS +1. Download or clone the repository +2. Open a terminal in the project directory +3. Make the install script executable: `chmod +x install.sh` +4. Run the script: `./install.sh` +5. Follow the on-screen instructions + +### Manual Installation +1. Clone the repo. +``` +git clone https://github.com/DataAnts-AI/VideoTranscriber.git +cd VideoTranscriber +``` + +2. Install dependencies: +``` +pip install -r requirements.txt +``` Notes: -Ensure that the versions align with the features you use and your system compatibility. -torch version should match the capabilities of your hardware (e.g., CUDA support for GPUs). -whisper might need to be installed from source or a GitHub repository if it's not available on PyPI. -If you encounter any issues regarding compatibility, versions may need adjustments. +- Ensure that the versions align with the features you use and your system compatibility. +- torch version should match the capabilities of your hardware (e.g., CUDA support for GPUs). +- For advanced features like speaker diarization, you'll need a HuggingFace token. +- See `INSTALLATION.md` for detailed instructions and troubleshooting. -3. streamlit run app.py +3. Run the application: +``` +streamlit run app.py +``` + +## Usage +1. Set your base folder where OBS recordings are stored +2. Select a recording from the dropdown +3. Choose transcription and summarization models +4. Configure performance settings (GPU acceleration, caching) +5. Select export formats and compression options +6. Click "Process Recording" to start + +## Advanced Features +- **Speaker Diarization**: Identify and label different speakers in your recordings +- **Translation**: Automatically detect language and translate to multiple languages +- **Keyword Extraction**: Extract important keywords with timestamp links +- **Interactive Transcript**: Navigate through the transcript with keyword highlighting +- **GPU Acceleration**: Utilize your GPU for faster processing +- **Caching**: Save processing time by caching results + +## Contributing +Contributions are welcome! Please feel free to submit a Pull Request. 
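
For readers who want to see how the pieces described in the README fit together outside the Streamlit UI, the sketch below strings the new utility modules into a minimal batch pipeline. It is a hedged illustration, not a documented API: the function names and keyword arguments simply mirror how `app.py` in this changeset calls `transcribe_audio`, `summarize_text`, and `export_transcript`, and the input path is a hypothetical example.

```python
# Illustrative sketch only: mirrors the call signatures app.py uses for
# utils.transcription, utils.summarization, and utils.export in this PR.
# The recording path below is a hypothetical placeholder.
from pathlib import Path

from utils.transcription import transcribe_audio
from utils.summarization import summarize_text
from utils.export import export_transcript

recording = Path.home() / "Videos" / "obs_session.mkv"  # hypothetical input file

# Transcribe with the base Whisper model, reusing cached results when available.
segments, transcript = transcribe_audio(
    recording,
    model="base",
    use_cache=True,
    use_gpu=True,
    memory_fraction=0.8,
)

# Summarize with the default Hugging Face pipeline.
summary = summarize_text(transcript, use_gpu=True, memory_fraction=0.8)

# Export an SRT subtitle file next to the recording (compression optional).
srt_path = export_transcript(
    transcript,
    recording.with_suffix(""),  # base output path without extension
    "srt",
    segments=segments,
    compress=False,
)

print(summary)
print(f"Subtitles written to {srt_path}")
```

The same pattern extends to the optional features: swap in `transcribe_with_diarization` or `transcribe_and_translate` (as `app.py` does) when those dependencies are installed.
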
diff --git a/app.py b/app.py index c138984..2526013 100644 --- a/app.py +++ b/app.py @@ -3,21 +3,314 @@ from utils.audio_processing import extract_audio from utils.transcription import transcribe_audio from utils.summarization import summarize_text from utils.validation import validate_environment +from utils.export import export_transcript from pathlib import Path +import os +import logging +import humanize +from datetime import timedelta + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Try to import Ollama integration, but don't fail if it's not available +try: + from utils.ollama_integration import check_ollama_available, list_available_models, chunk_and_summarize + OLLAMA_AVAILABLE = check_ollama_available() +except ImportError: + OLLAMA_AVAILABLE = False + +# Try to import GPU utilities, but don't fail if not available +try: + from utils.gpu_utils import get_gpu_info, configure_gpu + GPU_UTILS_AVAILABLE = True +except ImportError: + GPU_UTILS_AVAILABLE = False + +# Try to import caching utilities, but don't fail if not available +try: + from utils.cache import get_cache_size, clear_cache + CACHE_AVAILABLE = True +except ImportError: + CACHE_AVAILABLE = False + +# Try to import diarization utilities, but don't fail if not available +try: + from utils.diarization import transcribe_with_diarization + DIARIZATION_AVAILABLE = True +except ImportError: + DIARIZATION_AVAILABLE = False + +# Try to import translation utilities, but don't fail if not available +try: + from utils.translation import transcribe_and_translate, get_language_name + TRANSLATION_AVAILABLE = True +except ImportError: + TRANSLATION_AVAILABLE = False + +# Try to import keyword extraction utilities, but don't fail if not available +try: + from utils.keyword_extraction import extract_keywords_from_transcript, generate_keyword_index, generate_interactive_transcript + KEYWORD_EXTRACTION_AVAILABLE = True +except ImportError: + KEYWORD_EXTRACTION_AVAILABLE = False def main(): + # Set page configuration + st.set_page_config( + page_title="OBS Recording Transcriber", + page_icon="🎥", + layout="wide", + initial_sidebar_state="expanded" + ) + + # Custom CSS for better UI + st.markdown(""" + + """, unsafe_allow_html=True) + st.title("🎥 OBS Recording Transcriber") st.caption("Process your OBS recordings with AI transcription and summarization") + # Sidebar configuration + st.sidebar.header("Settings") + # Allow the user to select a base folder - st.sidebar.header("Folder Selection") base_folder = st.sidebar.text_input( "Enter the base folder path:", value=str(Path.home()) ) - + base_path = Path(base_folder) + # Model selection + st.sidebar.subheader("Model Settings") + + # Transcription model selection + transcription_model = st.sidebar.selectbox( + "Transcription Model", + ["tiny", "base", "small", "medium", "large"], + index=1, + help="Select the Whisper model size. Larger models are more accurate but slower." + ) + + # Summarization model selection + summarization_options = ["Hugging Face (Online)", "Ollama (Local)"] if OLLAMA_AVAILABLE else ["Hugging Face (Online)"] + summarization_method = st.sidebar.selectbox( + "Summarization Method", + summarization_options, + index=0, + help="Select the summarization method. Ollama runs locally but requires installation." 
+ ) + + # If Ollama is selected, show model selection + ollama_model = None + if OLLAMA_AVAILABLE and summarization_method == "Ollama (Local)": + available_models = list_available_models() + if available_models: + ollama_model = st.sidebar.selectbox( + "Ollama Model", + available_models, + index=0 if "llama3" in available_models else 0, + help="Select the Ollama model to use for summarization." + ) + else: + st.sidebar.warning("No Ollama models found. Please install models using 'ollama pull model_name'.") + + # Advanced features + st.sidebar.subheader("Advanced Features") + + # Speaker diarization + use_diarization = st.sidebar.checkbox( + "Speaker Diarization", + value=False, + disabled=not DIARIZATION_AVAILABLE, + help="Identify different speakers in the recording." + ) + + # Show HF token input if diarization is enabled + hf_token = None + if use_diarization and DIARIZATION_AVAILABLE: + hf_token = st.sidebar.text_input( + "HuggingFace Token", + type="password", + help="Required for speaker diarization. Get your token at huggingface.co/settings/tokens" + ) + + num_speakers = st.sidebar.number_input( + "Number of Speakers", + min_value=1, + max_value=10, + value=2, + help="Specify the number of speakers if known, or leave at default for auto-detection." + ) + + # Translation + use_translation = st.sidebar.checkbox( + "Translation", + value=False, + disabled=not TRANSLATION_AVAILABLE, + help="Translate the transcript to another language." + ) + + # Target language selection if translation is enabled + target_lang = None + if use_translation and TRANSLATION_AVAILABLE: + target_lang = st.sidebar.selectbox( + "Target Language", + ["en", "es", "fr", "de", "it", "pt", "nl", "ru", "zh", "ja", "ko", "ar"], + format_func=lambda x: f"{get_language_name(x)} ({x})", + help="Select the language to translate to." + ) + + # Keyword extraction + use_keywords = st.sidebar.checkbox( + "Keyword Extraction", + value=False, + disabled=not KEYWORD_EXTRACTION_AVAILABLE, + help="Extract keywords and link them to timestamps." + ) + + if use_keywords and KEYWORD_EXTRACTION_AVAILABLE: + max_keywords = st.sidebar.slider( + "Max Keywords", + min_value=5, + max_value=30, + value=15, + help="Maximum number of keywords to extract." + ) + + # Performance settings + st.sidebar.subheader("Performance Settings") + + # GPU acceleration + use_gpu = st.sidebar.checkbox( + "Use GPU Acceleration", + value=True if GPU_UTILS_AVAILABLE else False, + disabled=not GPU_UTILS_AVAILABLE, + help="Use GPU for faster processing if available." + ) + + # Show GPU info if available + if GPU_UTILS_AVAILABLE and use_gpu: + gpu_info = get_gpu_info() + if gpu_info["cuda_available"]: + gpu_devices = [f"{d['name']} ({humanize.naturalsize(d['total_memory'])})" for d in gpu_info["cuda_devices"]] + st.sidebar.info(f"GPU(s) available: {', '.join(gpu_devices)}") + elif gpu_info["mps_available"]: + st.sidebar.info("Apple Silicon GPU (MPS) available") + else: + st.sidebar.warning("No GPU detected. Using CPU.") + + # Memory usage + memory_fraction = st.sidebar.slider( + "GPU Memory Usage", + min_value=0.1, + max_value=1.0, + value=0.8, + step=0.1, + disabled=not (GPU_UTILS_AVAILABLE and use_gpu), + help="Fraction of GPU memory to use. Lower if you encounter out-of-memory errors." + ) + + # Caching options + use_cache = st.sidebar.checkbox( + "Use Caching", + value=True if CACHE_AVAILABLE else False, + disabled=not CACHE_AVAILABLE, + help="Cache transcription results to avoid reprocessing the same files." 
+ ) + + # Cache management + if CACHE_AVAILABLE and use_cache: + cache_size, cache_files = get_cache_size() + if cache_size > 0: + st.sidebar.info(f"Cache: {humanize.naturalsize(cache_size)} ({cache_files} files)") + if st.sidebar.button("Clear Cache"): + cleared = clear_cache() + st.sidebar.success(f"Cleared {cleared} cache files") + + # Export options + st.sidebar.subheader("Export Options") + export_format = st.sidebar.multiselect( + "Export Formats", + ["TXT", "SRT", "VTT", "ASS"], + default=["TXT"], + help="Select the formats to export the transcript." + ) + + # Compression options + compress_exports = st.sidebar.checkbox( + "Compress Exports", + value=False, + help="Compress exported files to save space." + ) + + if compress_exports: + compression_type = st.sidebar.radio( + "Compression Format", + ["gzip", "zip"], + index=0, + help="Select the compression format for exported files." + ) + else: + compression_type = None + + # ASS subtitle styling + if "ASS" in export_format: + st.sidebar.subheader("ASS Subtitle Styling") + show_style_options = st.sidebar.checkbox("Customize ASS Style", value=False) + + if show_style_options: + ass_style = {} + ass_style["fontname"] = st.sidebar.selectbox( + "Font", + ["Arial", "Helvetica", "Times New Roman", "Courier New", "Comic Sans MS"], + index=0 + ) + ass_style["fontsize"] = st.sidebar.slider("Font Size", 12, 72, 48) + ass_style["alignment"] = st.sidebar.selectbox( + "Alignment", + ["2 (Bottom Center)", "1 (Bottom Left)", "3 (Bottom Right)", "8 (Top Center)"], + index=0 + ).split()[0] # Extract just the number + ass_style["bold"] = "-1" if st.sidebar.checkbox("Bold", value=True) else "0" + ass_style["italic"] = "-1" if st.sidebar.checkbox("Italic", value=False) else "0" + else: + ass_style = None + # Validate environment env_errors = validate_environment(base_path) if env_errors: @@ -34,25 +327,213 @@ def main(): selected_file = st.selectbox("Choose a recording", recordings) + # Process button with spinner if st.button("🚀 Start Processing"): + # Create a progress bar + progress_bar = st.progress(0) + status_text = st.empty() + try: - transcript, summary = transcribe_audio(selected_file) - if transcript: - st.subheader("🖍 Summary") - st.write(summary) - st.subheader("📜 Full Transcript") - with st.expander("View transcript content"): - st.text(transcript) - st.download_button( - label="💾 Download Transcript", - data=transcript, - file_name=f"{Path(selected_file).stem}_transcript.txt", - mime="text/plain" + # Update progress + status_text.text("Extracting audio...") + progress_bar.progress(10) + + # Process based on selected features + if use_diarization and DIARIZATION_AVAILABLE and hf_token: + # Transcribe with speaker diarization + status_text.text("Transcribing with speaker diarization...") + num_speakers_arg = int(num_speakers) if num_speakers > 0 else None + diarized_segments, diarized_transcript = transcribe_with_diarization( + selected_file, + whisper_model=transcription_model, + num_speakers=num_speakers_arg, + use_gpu=use_gpu, + hf_token=hf_token ) + segments = diarized_segments + transcript = diarized_transcript + elif use_translation and TRANSLATION_AVAILABLE: + # Transcribe and translate + status_text.text("Transcribing and translating...") + original_segments, translated_segments, original_transcript, translated_transcript = transcribe_and_translate( + selected_file, + whisper_model=transcription_model, + target_lang=target_lang, + use_gpu=use_gpu + ) + segments = translated_segments + transcript = translated_transcript + # Store 
original for display + original_text = original_transcript + else: + # Standard transcription + status_text.text("Transcribing audio...") + segments, transcript = transcribe_audio( + selected_file, + model=transcription_model, + use_cache=use_cache, + use_gpu=use_gpu, + memory_fraction=memory_fraction + ) + + progress_bar.progress(50) + + if transcript: + # Extract keywords if requested + keyword_timestamps = None + entity_timestamps = None + if use_keywords and KEYWORD_EXTRACTION_AVAILABLE: + status_text.text("Extracting keywords...") + keyword_timestamps, entity_timestamps = extract_keywords_from_transcript( + transcript, + segments, + max_keywords=max_keywords, + use_gpu=use_gpu + ) + + # Generate keyword index + keyword_index = generate_keyword_index(keyword_timestamps, entity_timestamps) + + # Generate interactive transcript + interactive_transcript = generate_interactive_transcript( + segments, + keyword_timestamps, + entity_timestamps + ) + + # Generate summary based on selected method + status_text.text("Generating summary...") + if OLLAMA_AVAILABLE and summarization_method == "Ollama (Local)" and ollama_model: + summary = chunk_and_summarize(transcript, model=ollama_model) + if not summary: + st.warning("Ollama summarization failed. Falling back to Hugging Face.") + summary = summarize_text( + transcript, + use_gpu=use_gpu, + memory_fraction=memory_fraction + ) + else: + summary = summarize_text( + transcript, + use_gpu=use_gpu, + memory_fraction=memory_fraction + ) + + progress_bar.progress(80) + status_text.text("Preparing results...") + + # Display results in tabs + tab1, tab2, tab3 = st.tabs(["Summary", "Transcript", "Advanced"]) + + with tab1: + st.subheader("🖍 Summary") + st.write(summary) + + # If translation was used, show original language + if use_translation and TRANSLATION_AVAILABLE and 'original_text' in locals(): + with st.expander("Original Language Summary"): + original_summary = summarize_text( + original_text, + use_gpu=use_gpu, + memory_fraction=memory_fraction + ) + st.write(original_summary) + + with tab2: + st.subheader("📜 Full Transcript") + + # Show interactive transcript if keywords were extracted + if use_keywords and KEYWORD_EXTRACTION_AVAILABLE and 'interactive_transcript' in locals(): + st.markdown(interactive_transcript, unsafe_allow_html=True) + else: + st.text(transcript) + + # If translation was used, show original language + if use_translation and TRANSLATION_AVAILABLE and 'original_text' in locals(): + with st.expander("Original Language Transcript"): + st.text(original_text) + + with tab3: + # Show keyword index if available + if use_keywords and KEYWORD_EXTRACTION_AVAILABLE and 'keyword_index' in locals(): + st.subheader("🔑 Keyword Index") + st.markdown(keyword_index) + + # Show speaker information if available + if use_diarization and DIARIZATION_AVAILABLE: + st.subheader("🎙️ Speaker Information") + speakers = set(segment.get('speaker', 'UNKNOWN') for segment in segments) + st.write(f"Detected {len(speakers)} speakers: {', '.join(speakers)}") + + # Count words per speaker + speaker_words = {} + for segment in segments: + speaker = segment.get('speaker', 'UNKNOWN') + words = len(segment['text'].split()) + if speaker in speaker_words: + speaker_words[speaker] += words + else: + speaker_words[speaker] = words + + # Display speaker statistics + st.write("### Speaker Statistics") + for speaker, words in speaker_words.items(): + st.write(f"- **{speaker}**: {words} words") + + # Export options + st.subheader("💾 Export Options") + export_cols = 
st.columns(len(export_format)) + + output_base = Path(selected_file).stem + + for i, format_type in enumerate(export_format): + with export_cols[i]: + if format_type == "TXT": + st.download_button( + label=f"Download {format_type}", + data=transcript, + file_name=f"{output_base}_transcript.txt", + mime="text/plain" + ) + elif format_type in ["SRT", "VTT", "ASS"]: + # Export to subtitle format + output_path = export_transcript( + transcript, + output_base, + format_type.lower(), + segments=segments, + compress=compress_exports, + compression_type=compression_type, + style=ass_style if format_type == "ASS" and ass_style else None + ) + + # Read the exported file for download + with open(output_path, 'rb') as f: + subtitle_content = f.read() + + # Determine file extension + file_ext = f".{format_type.lower()}" + if compress_exports: + file_ext += ".gz" if compression_type == "gzip" else ".zip" + + st.download_button( + label=f"Download {format_type}", + data=subtitle_content, + file_name=f"{output_base}{file_ext}", + mime="application/octet-stream" + ) + + # Clean up the temporary file + os.remove(output_path) + + # Complete progress + progress_bar.progress(100) + status_text.text("Processing complete!") else: st.error("❌ Failed to process recording") except Exception as e: st.error(f"An error occurred: {e}") st.write(e) # This will show the traceback in the Streamlit app + if __name__ == "__main__": main() diff --git a/install.bat b/install.bat new file mode 100644 index 0000000..6917ecb --- /dev/null +++ b/install.bat @@ -0,0 +1,25 @@ +@echo off +echo =================================================== +echo OBS Recording Transcriber - Windows Installation +echo =================================================== +echo. + +:: Check for Python +python --version > nul 2>&1 +if %errorlevel% neq 0 ( + echo Python not found! Please install Python 3.8 or higher. + echo Download from: https://www.python.org/downloads/ + echo Make sure to check "Add Python to PATH" during installation. + pause + exit /b 1 +) + +:: Run the installation script +echo Running installation script... +python install.py + +echo. +echo If the installation was successful, you can run the application with: +echo streamlit run app.py +echo. +pause \ No newline at end of file diff --git a/install.py b/install.py new file mode 100644 index 0000000..dc5f9b4 --- /dev/null +++ b/install.py @@ -0,0 +1,307 @@ +#!/usr/bin/env python +""" +Installation script for OBS Recording Transcriber. +This script helps install all required dependencies and checks for common issues. 
+""" + +import os +import sys +import platform +import subprocess +import shutil +from pathlib import Path + +def print_header(text): + """Print a formatted header.""" + print("\n" + "=" * 80) + print(f" {text}") + print("=" * 80) + +def print_step(text): + """Print a step in the installation process.""" + print(f"\n>> {text}") + +def run_command(command, check=True): + """Run a shell command and return the result.""" + try: + result = subprocess.run( + command, + shell=True, + check=check, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + return result + except subprocess.CalledProcessError as e: + print(f"Error executing command: {command}") + print(f"Error message: {e.stderr}") + return None + +def check_python_version(): + """Check if Python version is 3.8 or higher.""" + print_step("Checking Python version") + version = sys.version_info + if version.major < 3 or (version.major == 3 and version.minor < 8): + print(f"Python 3.8 or higher is required. You have {sys.version}") + print("Please upgrade your Python installation.") + return False + print(f"Python version: {sys.version}") + return True + +def check_ffmpeg(): + """Check if FFmpeg is installed.""" + print_step("Checking FFmpeg installation") + result = shutil.which("ffmpeg") + if result is None: + print("FFmpeg not found in PATH.") + print("Please install FFmpeg:") + if platform.system() == "Windows": + print(" - Download from: https://www.gyan.dev/ffmpeg/builds/") + print(" - Extract and add the bin folder to your PATH") + elif platform.system() == "Darwin": # macOS + print(" - Install with Homebrew: brew install ffmpeg") + else: # Linux + print(" - Install with apt: sudo apt update && sudo apt install ffmpeg") + return False + + # Check FFmpeg version + version_result = run_command("ffmpeg -version") + if version_result: + print(f"FFmpeg is installed: {version_result.stdout.splitlines()[0]}") + return True + return False + +def check_gpu(): + """Check for GPU availability.""" + print_step("Checking GPU availability") + + # Check for NVIDIA GPU + if platform.system() == "Windows": + nvidia_smi = shutil.which("nvidia-smi") + if nvidia_smi: + result = run_command("nvidia-smi", check=False) + if result and result.returncode == 0: + print("NVIDIA GPU detected:") + for line in result.stdout.splitlines()[:10]: + print(f" {line}") + return "nvidia" + + # Check for Apple Silicon + if platform.system() == "Darwin" and platform.machine() == "arm64": + print("Apple Silicon (M1/M2) detected") + return "apple" + + print("No GPU detected or GPU drivers not installed. CPU will be used for processing.") + return "cpu" + +def setup_virtual_env(): + """Set up a virtual environment.""" + print_step("Setting up virtual environment") + + # Check if venv module is available + try: + import venv + print("Python venv module is available") + except ImportError: + print("Python venv module is not available. 
Please install it.") + return False + + # Create virtual environment if it doesn't exist + venv_path = Path("venv") + if venv_path.exists(): + print(f"Virtual environment already exists at {venv_path}") + activate_venv() + return True + + print(f"Creating virtual environment at {venv_path}") + try: + subprocess.run([sys.executable, "-m", "venv", "venv"], check=True) + print("Virtual environment created successfully") + activate_venv() + return True + except subprocess.CalledProcessError as e: + print(f"Error creating virtual environment: {e}") + return False + +def activate_venv(): + """Activate the virtual environment.""" + print_step("Activating virtual environment") + + venv_path = Path("venv") + if not venv_path.exists(): + print("Virtual environment not found") + return False + + # Get the path to the activate script + if platform.system() == "Windows": + activate_script = venv_path / "Scripts" / "activate.bat" + activate_cmd = f"call {activate_script}" + else: + activate_script = venv_path / "bin" / "activate" + activate_cmd = f"source {activate_script}" + + print(f"To activate the virtual environment, run:") + print(f" {activate_cmd}") + + # We can't actually activate the venv in this script because it would only + # affect the subprocess, not the parent process. We just provide instructions. + return True + +def install_pytorch(gpu_type): + """Install PyTorch with appropriate GPU support.""" + print_step("Installing PyTorch") + + if gpu_type == "nvidia": + print("Installing PyTorch with CUDA support") + cmd = "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118" + elif gpu_type == "apple": + print("Installing PyTorch with MPS support") + cmd = "pip install torch torchvision torchaudio" + else: + print("Installing PyTorch (CPU version)") + cmd = "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu" + + result = run_command(cmd) + if result and result.returncode == 0: + print("PyTorch installed successfully") + return True + else: + print("Failed to install PyTorch") + return False + +def install_dependencies(): + """Install dependencies from requirements.txt.""" + print_step("Installing dependencies from requirements.txt") + + requirements_path = Path("requirements.txt") + if not requirements_path.exists(): + print("requirements.txt not found") + return False + + result = run_command("pip install -r requirements.txt") + if result and result.returncode == 0: + print("Dependencies installed successfully") + return True + else: + print("Some dependencies failed to install. See error messages above.") + return False + +def install_tokenizers(): + """Install tokenizers package separately.""" + print_step("Installing tokenizers package") + + # First try the normal installation + result = run_command("pip install tokenizers", check=False) + if result and result.returncode == 0: + print("Tokenizers installed successfully") + return True + + # If that fails, try the no-binary option + print("Standard installation failed, trying alternative method...") + result = run_command("pip install tokenizers --no-binary tokenizers", check=False) + if result and result.returncode == 0: + print("Tokenizers installed successfully with alternative method") + return True + + print("Failed to install tokenizers. 
You may need to install Rust or Visual C++ Build Tools.") + if platform.system() == "Windows": + print("Download Visual C++ Build Tools: https://visualstudio.microsoft.com/visual-cpp-build-tools/") + print("Install Rust: https://rustup.rs/") + return False + +def check_installation(): + """Verify the installation by importing key packages.""" + print_step("Verifying installation") + + packages_to_check = [ + "streamlit", + "torch", + "transformers", + "whisper", + "numpy", + "sklearn" + ] + + all_successful = True + for package in packages_to_check: + try: + __import__(package) + print(f"✓ {package} imported successfully") + except ImportError: + print(f"✗ Failed to import {package}") + all_successful = False + + # Check optional packages + optional_packages = [ + "pyannote.audio", + "iso639" + ] + + print("\nChecking optional packages:") + for package in optional_packages: + try: + if package == "pyannote.audio": + # Just try to import pyannote + __import__("pyannote") + else: + __import__(package) + print(f"✓ {package} imported successfully") + except ImportError: + print(f"⚠ {package} not available (required for some advanced features)") + + return all_successful + +def main(): + """Main installation function.""" + print_header("OBS Recording Transcriber - Installation Script") + + # Check prerequisites + if not check_python_version(): + return + + ffmpeg_available = check_ffmpeg() + gpu_type = check_gpu() + + # Setup environment + if not setup_virtual_env(): + print("Failed to set up virtual environment. Continuing with system Python...") + + # Install packages + print("\nReady to install packages. Make sure your virtual environment is activated.") + input("Press Enter to continue...") + + install_pytorch(gpu_type) + install_dependencies() + install_tokenizers() + + # Verify installation + success = check_installation() + + print_header("Installation Summary") + print(f"Python: {'✓ OK' if check_python_version() else '✗ Needs upgrade'}") + print(f"FFmpeg: {'✓ Installed' if ffmpeg_available else '✗ Not found'}") + print(f"GPU Support: {gpu_type.upper()}") + print(f"Dependencies: {'✓ Installed' if success else '⚠ Some issues'}") + + print("\nNext steps:") + if not ffmpeg_available: + print("1. Install FFmpeg (required for audio processing)") + + print("1. Activate your virtual environment:") + if platform.system() == "Windows": + print(" venv\\Scripts\\activate") + else: + print(" source venv/bin/activate") + + print("2. Run the application:") + print(" streamlit run app.py") + + print("\nFor advanced features like speaker diarization:") + print("1. Get a HuggingFace token: https://huggingface.co/settings/tokens") + print("2. Request access to pyannote models: https://huggingface.co/pyannote/speaker-diarization-3.0") + + print("\nSee INSTALLATION.md for more details and troubleshooting.") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/install.sh b/install.sh new file mode 100644 index 0000000..7298af9 --- /dev/null +++ b/install.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +echo "===================================================" +echo " OBS Recording Transcriber - Unix Installation" +echo "===================================================" +echo + +# Check for Python +if ! command -v python3 &> /dev/null; then + echo "Python 3 not found! Please install Python 3.8 or higher." 
+ echo "For Ubuntu/Debian: sudo apt update && sudo apt install python3 python3-pip python3-venv" + echo "For macOS: brew install python3" + exit 1 +fi + +# Make the script executable +chmod +x install.py + +# Run the installation script +echo "Running installation script..." +python3 ./install.py + +echo +echo "If the installation was successful, you can run the application with:" +echo "streamlit run app.py" +echo \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 29da1a6..9b22b85 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,42 @@ +# OBS Recording Transcriber Dependencies +# Core dependencies streamlit==1.26.0 moviepy==1.0.3 -whisper -transformers==4.21.1 +openai-whisper>=20230314 +transformers>=4.21.1 torch>=1.7.0 +torchaudio>=0.7.0 +requests>=2.28.0 +humanize>=4.6.0 + +# Phase 2 dependencies +scikit-learn>=1.0.0 +numpy>=1.20.0 + +# Phase 3 dependencies +pyannote.audio>=2.1.1 +iso639>=0.1.4 +protobuf>=3.20.0,<4.0.0 +tokenizers>=0.13.2 +scipy>=1.7.0 +matplotlib>=3.5.0 +soundfile>=0.10.3 +ffmpeg-python>=0.2.0 + +# Optional: Ollama Python client (uncomment to install) +# ollama + +# Installation notes: +# 1. For Windows users, you may need to install PyTorch separately: +# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 +# +# 2. For tokenizers issues, try installing Visual C++ Build Tools: +# https://visualstudio.microsoft.com/visual-cpp-build-tools/ +# +# 3. For pyannote.audio, you'll need a HuggingFace token with access to: +# https://huggingface.co/pyannote/speaker-diarization-3.0 +# +# 4. FFmpeg is required for audio processing: +# Windows: https://www.gyan.dev/ffmpeg/builds/ +# Mac: brew install ffmpeg +# Linux: apt-get install ffmpeg diff --git a/utils/cache.py b/utils/cache.py new file mode 100644 index 0000000..b5a4d37 --- /dev/null +++ b/utils/cache.py @@ -0,0 +1,205 @@ +""" +Caching utilities for the OBS Recording Transcriber. +Provides functions to cache and retrieve transcription and summarization results. +""" + +import json +import hashlib +import os +from pathlib import Path +import logging +import time + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Default cache directory +CACHE_DIR = Path.home() / ".obs_transcriber_cache" + + +def get_file_hash(file_path): + """ + Generate a hash for a file based on its content and modification time. + + Args: + file_path (Path): Path to the file + + Returns: + str: Hash string representing the file + """ + file_path = Path(file_path) + if not file_path.exists(): + return None + + # Get file stats + stats = file_path.stat() + file_size = stats.st_size + mod_time = stats.st_mtime + + # Create a hash based on path, size and modification time + # This is faster than hashing the entire file content + hash_input = f"{file_path.absolute()}|{file_size}|{mod_time}" + return hashlib.md5(hash_input.encode()).hexdigest() + + +def get_cache_path(file_path, model=None, operation=None): + """ + Get the cache file path for a given input file and operation. 
+ + Args: + file_path (Path): Path to the original file + model (str, optional): Model used for processing + operation (str, optional): Operation type (e.g., 'transcribe', 'summarize') + + Returns: + Path: Path to the cache file + """ + file_path = Path(file_path) + file_hash = get_file_hash(file_path) + + if not file_hash: + return None + + # Create cache directory if it doesn't exist + cache_dir = CACHE_DIR + cache_dir.mkdir(parents=True, exist_ok=True) + + # Create a cache filename based on the hash and optional parameters + cache_name = file_hash + if model: + cache_name += f"_{model}" + if operation: + cache_name += f"_{operation}" + + return cache_dir / f"{cache_name}.json" + + +def save_to_cache(file_path, data, model=None, operation=None): + """ + Save data to cache. + + Args: + file_path (Path): Path to the original file + data (dict): Data to cache + model (str, optional): Model used for processing + operation (str, optional): Operation type + + Returns: + bool: True if successful, False otherwise + """ + cache_path = get_cache_path(file_path, model, operation) + if not cache_path: + return False + + try: + # Add metadata to the cached data + cache_data = { + "original_file": str(Path(file_path).absolute()), + "timestamp": time.time(), + "model": model, + "operation": operation, + "data": data + } + + with open(cache_path, 'w', encoding='utf-8') as f: + json.dump(cache_data, f, ensure_ascii=False, indent=2) + + logger.info(f"Cached data saved to {cache_path}") + return True + except Exception as e: + logger.error(f"Error saving cache: {e}") + return False + + +def load_from_cache(file_path, model=None, operation=None, max_age=None): + """ + Load data from cache if available and not expired. + + Args: + file_path (Path): Path to the original file + model (str, optional): Model used for processing + operation (str, optional): Operation type + max_age (float, optional): Maximum age of cache in seconds + + Returns: + dict or None: Cached data or None if not available + """ + cache_path = get_cache_path(file_path, model, operation) + if not cache_path or not cache_path.exists(): + return None + + try: + with open(cache_path, 'r', encoding='utf-8') as f: + cache_data = json.load(f) + + # Check if cache is expired + if max_age is not None: + cache_time = cache_data.get("timestamp", 0) + if time.time() - cache_time > max_age: + logger.info(f"Cache expired for {file_path}") + return None + + logger.info(f"Loaded data from cache: {cache_path}") + return cache_data.get("data") + except Exception as e: + logger.error(f"Error loading cache: {e}") + return None + + +def clear_cache(max_age=None): + """ + Clear all cache files or only expired ones. + + Args: + max_age (float, optional): Maximum age of cache in seconds + + Returns: + int: Number of files deleted + """ + if not CACHE_DIR.exists(): + return 0 + + count = 0 + for cache_file in CACHE_DIR.glob("*.json"): + try: + if max_age is not None: + # Check if file is expired + with open(cache_file, 'r', encoding='utf-8') as f: + cache_data = json.load(f) + + cache_time = cache_data.get("timestamp", 0) + if time.time() - cache_time <= max_age: + continue # Skip non-expired files + + # Delete the file + os.remove(cache_file) + count += 1 + except Exception as e: + logger.error(f"Error deleting cache file {cache_file}: {e}") + + logger.info(f"Cleared {count} cache files") + return count + + +def get_cache_size(): + """ + Get the total size of the cache directory. 
+ + Returns: + tuple: (size_bytes, file_count) + """ + if not CACHE_DIR.exists(): + return 0, 0 + + total_size = 0 + file_count = 0 + + for cache_file in CACHE_DIR.glob("*.json"): + try: + total_size += cache_file.stat().st_size + file_count += 1 + except Exception: + pass + + return total_size, file_count \ No newline at end of file diff --git a/utils/diarization.py b/utils/diarization.py new file mode 100644 index 0000000..bc807ff --- /dev/null +++ b/utils/diarization.py @@ -0,0 +1,226 @@ +""" +Speaker diarization utilities for the OBS Recording Transcriber. +Provides functions to identify different speakers in audio recordings. +""" + +import logging +import os +import numpy as np +from pathlib import Path +import torch +from pyannote.audio import Pipeline +from pyannote.core import Segment +import whisper + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Try to import GPU utilities, but don't fail if not available +try: + from utils.gpu_utils import get_optimal_device + GPU_UTILS_AVAILABLE = True +except ImportError: + GPU_UTILS_AVAILABLE = False + +# Default HuggingFace auth token environment variable +HF_TOKEN_ENV = "HF_TOKEN" + + +def get_diarization_pipeline(use_gpu=True, hf_token=None): + """ + Initialize the speaker diarization pipeline. + + Args: + use_gpu (bool): Whether to use GPU acceleration if available + hf_token (str, optional): HuggingFace API token for accessing the model + + Returns: + Pipeline or None: Diarization pipeline if successful, None otherwise + """ + # Check if token is provided or in environment + if hf_token is None: + hf_token = os.environ.get(HF_TOKEN_ENV) + if hf_token is None: + logger.error(f"HuggingFace token not provided. Set {HF_TOKEN_ENV} environment variable or pass token directly.") + return None + + try: + # Configure device + device = torch.device("cpu") + if use_gpu and GPU_UTILS_AVAILABLE: + device = get_optimal_device() + logger.info(f"Using device: {device} for diarization") + + # Initialize the pipeline + pipeline = Pipeline.from_pretrained( + "pyannote/speaker-diarization-3.0", + use_auth_token=hf_token + ) + + # Move to appropriate device + if device.type == "cuda": + pipeline = pipeline.to(torch.device(device)) + + return pipeline + except Exception as e: + logger.error(f"Error initializing diarization pipeline: {e}") + return None + + +def diarize_audio(audio_path, pipeline=None, num_speakers=None, use_gpu=True, hf_token=None): + """ + Perform speaker diarization on an audio file. 
+ + Args: + audio_path (Path): Path to the audio file + pipeline (Pipeline, optional): Pre-initialized diarization pipeline + num_speakers (int, optional): Number of speakers (if known) + use_gpu (bool): Whether to use GPU acceleration if available + hf_token (str, optional): HuggingFace API token + + Returns: + dict: Dictionary mapping time segments to speaker IDs + """ + audio_path = Path(audio_path) + + # Initialize pipeline if not provided + if pipeline is None: + pipeline = get_diarization_pipeline(use_gpu, hf_token) + if pipeline is None: + return None + + try: + # Run diarization + logger.info(f"Running speaker diarization on {audio_path}") + diarization = pipeline(audio_path, num_speakers=num_speakers) + + # Extract speaker segments + speaker_segments = {} + for turn, _, speaker in diarization.itertracks(yield_label=True): + segment = (turn.start, turn.end) + speaker_segments[segment] = speaker + + return speaker_segments + except Exception as e: + logger.error(f"Error during diarization: {e}") + return None + + +def apply_diarization_to_transcript(transcript_segments, speaker_segments): + """ + Apply speaker diarization results to transcript segments. + + Args: + transcript_segments (list): List of transcript segments with timing info + speaker_segments (dict): Dictionary mapping time segments to speaker IDs + + Returns: + list: Updated transcript segments with speaker information + """ + if not speaker_segments: + return transcript_segments + + # Convert speaker segments to a more usable format + speaker_ranges = [(Segment(start, end), speaker) + for (start, end), speaker in speaker_segments.items()] + + # Update transcript segments with speaker information + for segment in transcript_segments: + segment_start = segment['start'] + segment_end = segment['end'] + segment_range = Segment(segment_start, segment_end) + + # Find overlapping speaker segments + overlaps = [] + for (spk_range, speaker) in speaker_ranges: + overlap = segment_range.intersect(spk_range) + if overlap: + overlaps.append((overlap.duration, speaker)) + + # Assign the speaker with the most overlap + if overlaps: + overlaps.sort(reverse=True) # Sort by duration (descending) + segment['speaker'] = overlaps[0][1] + else: + segment['speaker'] = "UNKNOWN" + + return transcript_segments + + +def format_transcript_with_speakers(transcript_segments): + """ + Format transcript with speaker labels. + + Args: + transcript_segments (list): List of transcript segments with speaker info + + Returns: + str: Formatted transcript with speaker labels + """ + formatted_lines = [] + current_speaker = None + + for segment in transcript_segments: + speaker = segment.get('speaker', 'UNKNOWN') + text = segment['text'].strip() + + # Add speaker label when speaker changes + if speaker != current_speaker: + formatted_lines.append(f"\n[{speaker}]") + current_speaker = speaker + + formatted_lines.append(text) + + return " ".join(formatted_lines) + + +def transcribe_with_diarization(audio_path, whisper_model="base", num_speakers=None, + use_gpu=True, hf_token=None): + """ + Transcribe audio with speaker diarization. 
+ + Args: + audio_path (Path): Path to the audio file + whisper_model (str): Whisper model size to use + num_speakers (int, optional): Number of speakers (if known) + use_gpu (bool): Whether to use GPU acceleration if available + hf_token (str, optional): HuggingFace API token + + Returns: + tuple: (diarized_segments, formatted_transcript) + """ + audio_path = Path(audio_path) + + # Configure device + device = torch.device("cpu") + if use_gpu and GPU_UTILS_AVAILABLE: + device = get_optimal_device() + + try: + # Step 1: Transcribe audio with Whisper + logger.info(f"Transcribing audio with Whisper model: {whisper_model}") + model = whisper.load_model(whisper_model, device=device if device.type != "mps" else "cpu") + result = model.transcribe(str(audio_path)) + transcript_segments = result["segments"] + + # Step 2: Perform speaker diarization + logger.info("Performing speaker diarization") + pipeline = get_diarization_pipeline(use_gpu, hf_token) + if pipeline is None: + logger.warning("Diarization pipeline not available, returning transcript without speakers") + return transcript_segments, result["text"] + + speaker_segments = diarize_audio(audio_path, pipeline, num_speakers, use_gpu) + + # Step 3: Apply diarization to transcript + if speaker_segments: + diarized_segments = apply_diarization_to_transcript(transcript_segments, speaker_segments) + formatted_transcript = format_transcript_with_speakers(diarized_segments) + return diarized_segments, formatted_transcript + else: + return transcript_segments, result["text"] + + except Exception as e: + logger.error(f"Error in transcribe_with_diarization: {e}") + return None, None \ No newline at end of file diff --git a/utils/export.py b/utils/export.py new file mode 100644 index 0000000..e540c76 --- /dev/null +++ b/utils/export.py @@ -0,0 +1,284 @@ +""" +Subtitle export utilities for the OBS Recording Transcriber. +Supports exporting transcripts to SRT, ASS, and WebVTT subtitle formats. +""" + +from pathlib import Path +import re +from datetime import timedelta +import gzip +import zipfile +import logging + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def format_timestamp_srt(timestamp_ms): + """ + Format a timestamp in milliseconds to SRT format (HH:MM:SS,mmm). + + Args: + timestamp_ms (int): Timestamp in milliseconds + + Returns: + str: Formatted timestamp string + """ + hours, remainder = divmod(timestamp_ms, 3600000) + minutes, remainder = divmod(remainder, 60000) + seconds, milliseconds = divmod(remainder, 1000) + return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}" + + +def format_timestamp_ass(timestamp_ms): + """ + Format a timestamp in milliseconds to ASS format (H:MM:SS.cc). + + Args: + timestamp_ms (int): Timestamp in milliseconds + + Returns: + str: Formatted timestamp string + """ + hours, remainder = divmod(timestamp_ms, 3600000) + minutes, remainder = divmod(remainder, 60000) + seconds, remainder = divmod(remainder, 1000) + centiseconds = remainder // 10 + return f"{int(hours)}:{int(minutes):02d}:{int(seconds):02d}.{int(centiseconds):02d}" + + +def format_timestamp_vtt(timestamp_ms): + """ + Format a timestamp in milliseconds to WebVTT format (HH:MM:SS.mmm). 
+ + Args: + timestamp_ms (int): Timestamp in milliseconds + + Returns: + str: Formatted timestamp string + """ + hours, remainder = divmod(timestamp_ms, 3600000) + minutes, remainder = divmod(remainder, 60000) + seconds, milliseconds = divmod(remainder, 1000) + return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}" + + +def export_to_srt(segments, output_path): + """ + Export transcript segments to SRT format. + + Args: + segments (list): List of transcript segments with start, end, and text + output_path (Path): Path to save the SRT file + + Returns: + Path: Path to the saved SRT file + """ + with open(output_path, 'w', encoding='utf-8') as f: + for i, segment in enumerate(segments, 1): + start_time = format_timestamp_srt(int(segment['start'] * 1000)) + end_time = format_timestamp_srt(int(segment['end'] * 1000)) + + f.write(f"{i}\n") + f.write(f"{start_time} --> {end_time}\n") + f.write(f"{segment['text'].strip()}\n\n") + + return output_path + + +def export_to_ass(segments, output_path, video_width=1920, video_height=1080, style=None): + """ + Export transcript segments to ASS format with styling. + + Args: + segments (list): List of transcript segments with start, end, and text + output_path (Path): Path to save the ASS file + video_width (int): Width of the video in pixels + video_height (int): Height of the video in pixels + style (dict, optional): Custom style parameters + + Returns: + Path: Path to the saved ASS file + """ + # Default style + default_style = { + "fontname": "Arial", + "fontsize": "48", + "primary_color": "&H00FFFFFF", # White + "secondary_color": "&H000000FF", # Blue + "outline_color": "&H00000000", # Black + "back_color": "&H80000000", # Semi-transparent black + "bold": "-1", # True + "italic": "0", # False + "alignment": "2", # Bottom center + } + + # Apply custom style if provided + if style: + default_style.update(style) + + # ASS header template + ass_header = f"""[Script Info] +Title: Transcription +ScriptType: v4.00+ +WrapStyle: 0 +PlayResX: {video_width} +PlayResY: {video_height} +ScaledBorderAndShadow: yes + +[V4+ Styles] +Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding +Style: Default,{default_style['fontname']},{default_style['fontsize']},{default_style['primary_color']},{default_style['secondary_color']},{default_style['outline_color']},{default_style['back_color']},{default_style['bold']},{default_style['italic']},0,0,100,100,0,0,1,2,2,{default_style['alignment']},10,10,10,1 + +[Events] +Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text +""" + + with open(output_path, 'w', encoding='utf-8') as f: + f.write(ass_header) + + for segment in segments: + start_time = format_timestamp_ass(int(segment['start'] * 1000)) + end_time = format_timestamp_ass(int(segment['end'] * 1000)) + text = segment['text'].strip().replace('\n', '\\N') + + f.write(f"Dialogue: 0,{start_time},{end_time},Default,,0,0,0,,{text}\n") + + return output_path + + +def export_to_vtt(segments, output_path): + """ + Export transcript segments to WebVTT format. 
+ + Args: + segments (list): List of transcript segments with start, end, and text + output_path (Path): Path to save the WebVTT file + + Returns: + Path: Path to the saved WebVTT file + """ + with open(output_path, 'w', encoding='utf-8') as f: + # WebVTT header + f.write("WEBVTT\n\n") + + for i, segment in enumerate(segments, 1): + start_time = format_timestamp_vtt(int(segment['start'] * 1000)) + end_time = format_timestamp_vtt(int(segment['end'] * 1000)) + + # Optional cue identifier + f.write(f"{i}\n") + f.write(f"{start_time} --> {end_time}\n") + f.write(f"{segment['text'].strip()}\n\n") + + return output_path + + +def transcript_to_segments(transcript, segment_duration=5.0): + """ + Convert a plain transcript to timed segments for subtitle export. + Used when the original segments are not available. + + Args: + transcript (str): Full transcript text + segment_duration (float): Duration of each segment in seconds + + Returns: + list: List of segments with start, end, and text + """ + # Split transcript into sentences + sentences = re.split(r'(?<=[.!?])\s+', transcript) + segments = [] + + current_time = 0.0 + for sentence in sentences: + if not sentence.strip(): + continue + + # Estimate duration based on word count (approx. 2.5 words per second) + word_count = len(sentence.split()) + duration = max(2.0, word_count / 2.5) + + segments.append({ + 'start': current_time, + 'end': current_time + duration, + 'text': sentence + }) + + current_time += duration + + return segments + + +def compress_file(input_path, compression_type='gzip'): + """ + Compress a file using the specified compression method. + + Args: + input_path (Path): Path to the file to compress + compression_type (str): Type of compression ('gzip' or 'zip') + + Returns: + Path: Path to the compressed file + """ + input_path = Path(input_path) + + if compression_type == 'gzip': + output_path = input_path.with_suffix(input_path.suffix + '.gz') + with open(input_path, 'rb') as f_in: + with gzip.open(output_path, 'wb') as f_out: + f_out.write(f_in.read()) + return output_path + + elif compression_type == 'zip': + output_path = input_path.with_suffix('.zip') + with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf: + zipf.write(input_path, arcname=input_path.name) + return output_path + + else: + logger.warning(f"Unsupported compression type: {compression_type}") + return input_path + + +def export_transcript(transcript, output_path, format_type='srt', segments=None, + compress=False, compression_type='gzip', style=None): + """ + Export transcript to the specified subtitle format. 
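+    For example, format_type='vtt' with compress=True and compression_type='gzip'
+    produces "<output_path>.vtt.gz".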
+ + Args: + transcript (str): Full transcript text + output_path (Path): Base path for the output file (without extension) + format_type (str): 'srt', 'ass', or 'vtt' + segments (list, optional): List of transcript segments with timing information + compress (bool): Whether to compress the output file + compression_type (str): Type of compression ('gzip' or 'zip') + style (dict, optional): Custom style parameters for ASS format + + Returns: + Path: Path to the saved subtitle file + """ + output_path = Path(output_path) + + # If segments are not provided, create them from the transcript + if segments is None: + segments = transcript_to_segments(transcript) + + if format_type.lower() == 'srt': + output_file = output_path.with_suffix('.srt') + result_path = export_to_srt(segments, output_file) + elif format_type.lower() == 'ass': + output_file = output_path.with_suffix('.ass') + result_path = export_to_ass(segments, output_file, style=style) + elif format_type.lower() == 'vtt': + output_file = output_path.with_suffix('.vtt') + result_path = export_to_vtt(segments, output_file) + else: + raise ValueError(f"Unsupported format type: {format_type}. Use 'srt', 'ass', or 'vtt'.") + + # Compress the file if requested + if compress: + result_path = compress_file(result_path, compression_type) + + return result_path \ No newline at end of file diff --git a/utils/gpu_utils.py b/utils/gpu_utils.py new file mode 100644 index 0000000..cf7b1fc --- /dev/null +++ b/utils/gpu_utils.py @@ -0,0 +1,202 @@ +""" +GPU utilities for the OBS Recording Transcriber. +Provides functions to detect and configure GPU acceleration. +""" + +import logging +import os +import platform +import subprocess +import torch + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def get_gpu_info(): + """ + Get information about available GPUs. + + Returns: + dict: Information about available GPUs + """ + gpu_info = { + "cuda_available": torch.cuda.is_available(), + "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0, + "cuda_devices": [], + "mps_available": hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + } + + # Get CUDA device information + if gpu_info["cuda_available"]: + for i in range(gpu_info["cuda_device_count"]): + device_props = torch.cuda.get_device_properties(i) + gpu_info["cuda_devices"].append({ + "index": i, + "name": device_props.name, + "total_memory": device_props.total_memory, + "compute_capability": f"{device_props.major}.{device_props.minor}" + }) + + return gpu_info + + +def get_optimal_device(): + """ + Get the optimal device for computation. + + Returns: + torch.device: The optimal device (cuda, mps, or cpu) + """ + if torch.cuda.is_available(): + # If multiple GPUs are available, select the one with the most memory + if torch.cuda.device_count() > 1: + max_memory = 0 + best_device = 0 + for i in range(torch.cuda.device_count()): + device_props = torch.cuda.get_device_properties(i) + if device_props.total_memory > max_memory: + max_memory = device_props.total_memory + best_device = i + return torch.device(f"cuda:{best_device}") + return torch.device("cuda:0") + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return torch.device("mps") + else: + return torch.device("cpu") + + +def set_memory_limits(memory_fraction=0.8): + """ + Set memory limits for GPU usage. 
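+    For example, memory_fraction=0.8 caps each visible CUDA device at roughly 80% of
+    its memory via torch.cuda.set_per_process_memory_fraction.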
+ + Args: + memory_fraction (float): Fraction of GPU memory to use (0.0 to 1.0) + + Returns: + bool: True if successful, False otherwise + """ + if not torch.cuda.is_available(): + return False + + try: + # Import only if CUDA is available + import torch.cuda + + # Set memory fraction for each device + for i in range(torch.cuda.device_count()): + torch.cuda.set_per_process_memory_fraction(memory_fraction, i) + + return True + except Exception as e: + logger.error(f"Error setting memory limits: {e}") + return False + + +def optimize_for_inference(): + """ + Apply optimizations for inference. + + Returns: + bool: True if successful, False otherwise + """ + try: + # Set deterministic algorithms for reproducibility + torch.backends.cudnn.deterministic = True + + # Enable cuDNN benchmark mode for optimized performance + torch.backends.cudnn.benchmark = True + + # Disable gradient calculation for inference + torch.set_grad_enabled(False) + + return True + except Exception as e: + logger.error(f"Error optimizing for inference: {e}") + return False + + +def get_recommended_batch_size(model_size="base"): + """ + Get recommended batch size based on available GPU memory. + + Args: + model_size (str): Size of the model (tiny, base, small, medium, large) + + Returns: + int: Recommended batch size + """ + # Default batch sizes for CPU + default_batch_sizes = { + "tiny": 16, + "base": 8, + "small": 4, + "medium": 2, + "large": 1 + } + + # If CUDA is not available, return default CPU batch size + if not torch.cuda.is_available(): + return default_batch_sizes.get(model_size, 1) + + # Approximate memory requirements in GB for different model sizes + memory_requirements = { + "tiny": 1, + "base": 2, + "small": 4, + "medium": 8, + "large": 16 + } + + # Get available GPU memory + device = get_optimal_device() + if device.type == "cuda": + device_idx = device.index + device_props = torch.cuda.get_device_properties(device_idx) + available_memory_gb = device_props.total_memory / (1024 ** 3) + + # Calculate batch size based on available memory + model_memory = memory_requirements.get(model_size, 2) + max_batch_size = int(available_memory_gb / model_memory) + + # Ensure batch size is at least 1 + return max(1, max_batch_size) + + # For MPS or other devices, return default + return default_batch_sizes.get(model_size, 1) + + +def configure_gpu(model_size="base", memory_fraction=0.8): + """ + Configure GPU settings for optimal performance. + + Args: + model_size (str): Size of the model (tiny, base, small, medium, large) + memory_fraction (float): Fraction of GPU memory to use (0.0 to 1.0) + + Returns: + dict: Configuration information + """ + gpu_info = get_gpu_info() + device = get_optimal_device() + + # Set memory limits if using CUDA + if device.type == "cuda": + set_memory_limits(memory_fraction) + + # Apply inference optimizations + optimize_for_inference() + + # Get recommended batch size + batch_size = get_recommended_batch_size(model_size) + + config = { + "device": device, + "batch_size": batch_size, + "gpu_info": gpu_info, + "memory_fraction": memory_fraction if device.type == "cuda" else None + } + + logger.info(f"GPU configuration: Using {device} with batch size {batch_size}") + return config \ No newline at end of file diff --git a/utils/keyword_extraction.py b/utils/keyword_extraction.py new file mode 100644 index 0000000..d9cdbb0 --- /dev/null +++ b/utils/keyword_extraction.py @@ -0,0 +1,325 @@ +""" +Keyword extraction utilities for the OBS Recording Transcriber. 
+Provides functions to extract keywords and link them to timestamps. +""" + +import logging +import re +import torch +import numpy as np +from pathlib import Path +from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification +from sklearn.feature_extraction.text import TfidfVectorizer +from collections import Counter + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Try to import GPU utilities, but don't fail if not available +try: + from utils.gpu_utils import get_optimal_device + GPU_UTILS_AVAILABLE = True +except ImportError: + GPU_UTILS_AVAILABLE = False + +# Default models +NER_MODEL = "dslim/bert-base-NER" + + +def extract_keywords_tfidf(text, max_keywords=10, ngram_range=(1, 2)): + """ + Extract keywords using TF-IDF. + + Args: + text (str): Text to extract keywords from + max_keywords (int): Maximum number of keywords to extract + ngram_range (tuple): Range of n-grams to consider + + Returns: + list: List of (keyword, score) tuples + """ + try: + # Preprocess text + text = text.lower() + + # Remove common stopwords + stopwords = {'a', 'an', 'the', 'and', 'or', 'but', 'if', 'because', 'as', 'what', + 'when', 'where', 'how', 'who', 'which', 'this', 'that', 'these', 'those', + 'then', 'just', 'so', 'than', 'such', 'both', 'through', 'about', 'for', + 'is', 'of', 'while', 'during', 'to', 'from', 'in', 'out', 'on', 'off', 'by'} + + # Create sentences for better TF-IDF analysis + sentences = re.split(r'[.!?]', text) + sentences = [s.strip() for s in sentences if s.strip()] + + if not sentences: + return [] + + # Apply TF-IDF + vectorizer = TfidfVectorizer( + max_features=100, + stop_words=stopwords, + ngram_range=ngram_range + ) + + try: + tfidf_matrix = vectorizer.fit_transform(sentences) + feature_names = vectorizer.get_feature_names_out() + + # Calculate average TF-IDF score across all sentences + avg_tfidf = np.mean(tfidf_matrix.toarray(), axis=0) + + # Get top keywords + keywords = [(feature_names[i], avg_tfidf[i]) for i in avg_tfidf.argsort()[::-1]] + + # Filter out single-character keywords and limit to max_keywords + keywords = [(k, s) for k, s in keywords if len(k) > 1][:max_keywords] + + return keywords + except ValueError as e: + logger.warning(f"TF-IDF extraction failed: {e}") + return [] + + except Exception as e: + logger.error(f"Error extracting keywords with TF-IDF: {e}") + return [] + + +def extract_named_entities(text, model=NER_MODEL, use_gpu=True): + """ + Extract named entities from text. 
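+    For example, "Interview with Jane Doe from Mozilla" might yield entries such as
+    [("Jane Doe", "PER"), ("Mozilla", "ORG")], depending on the model.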
+ + Args: + text (str): Text to extract entities from + model (str): Model to use for NER + use_gpu (bool): Whether to use GPU acceleration if available + + Returns: + list: List of (entity, type) tuples + """ + # Configure device + device = torch.device("cpu") + if use_gpu and GPU_UTILS_AVAILABLE: + device = get_optimal_device() + device_arg = 0 if device.type == "cuda" else -1 + else: + device_arg = -1 + + try: + # Initialize the pipeline + ner_pipeline = pipeline("ner", model=model, device=device_arg, aggregation_strategy="simple") + + # Split text into manageable chunks if too long + max_length = 512 + if len(text) > max_length: + chunks = [text[i:i+max_length] for i in range(0, len(text), max_length)] + else: + chunks = [text] + + # Process each chunk + all_entities = [] + for chunk in chunks: + entities = ner_pipeline(chunk) + all_entities.extend(entities) + + # Extract entity text and type + entity_info = [(entity["word"], entity["entity_group"]) for entity in all_entities] + + return entity_info + except Exception as e: + logger.error(f"Error extracting named entities: {e}") + return [] + + +def find_keyword_timestamps(segments, keywords): + """ + Find timestamps for keywords in transcript segments. + + Args: + segments (list): List of transcript segments with timing info + keywords (list): List of keywords to find + + Returns: + dict: Dictionary mapping keywords to lists of timestamps + """ + keyword_timestamps = {} + + # Convert keywords to lowercase for case-insensitive matching + if isinstance(keywords[0], tuple): + # If keywords is a list of (keyword, score) tuples + keywords_lower = [k.lower() for k, _ in keywords] + else: + # If keywords is just a list of keywords + keywords_lower = [k.lower() for k in keywords] + + # Process each segment + for segment in segments: + segment_text = segment["text"].lower() + start_time = segment["start"] + end_time = segment["end"] + + # Check each keyword + for i, keyword in enumerate(keywords_lower): + if keyword in segment_text: + # Get the original case of the keyword + original_keyword = keywords[i][0] if isinstance(keywords[0], tuple) else keywords[i] + + # Initialize the list if this is the first occurrence + if original_keyword not in keyword_timestamps: + keyword_timestamps[original_keyword] = [] + + # Add the timestamp + keyword_timestamps[original_keyword].append({ + "start": start_time, + "end": end_time, + "context": segment["text"] + }) + + return keyword_timestamps + + +def extract_keywords_from_transcript(transcript, segments, max_keywords=15, use_gpu=True): + """ + Extract keywords from transcript and link them to timestamps. 
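+    Combines TF-IDF keywords with the most frequent named entities and maps each one
+    to the transcript segments (and timestamps) in which it appears.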
+
+    Args:
+        transcript (str): Full transcript text
+        segments (list): List of transcript segments with timing info
+        max_keywords (int): Maximum number of keywords to extract
+        use_gpu (bool): Whether to use GPU acceleration if available
+
+    Returns:
+        tuple: (keyword_timestamps, entity_timestamps)
+    """
+    try:
+        # Extract keywords using TF-IDF
+        tfidf_keywords = extract_keywords_tfidf(transcript, max_keywords=max_keywords)
+
+        # Extract named entities
+        entities = extract_named_entities(transcript, use_gpu=use_gpu)
+
+        # Count entity occurrences and get the most frequent ones
+        entity_counter = Counter([entity for entity, _ in entities])
+        top_entities = [(entity, count) for entity, count in entity_counter.most_common(max_keywords)]
+
+        # Find timestamps for keywords and entities
+        keyword_timestamps = find_keyword_timestamps(segments, tfidf_keywords)
+        entity_timestamps = find_keyword_timestamps(segments, top_entities)
+
+        return keyword_timestamps, entity_timestamps
+
+    except Exception as e:
+        logger.error(f"Error extracting keywords from transcript: {e}")
+        return {}, {}
+
+
+def generate_keyword_index(keyword_timestamps, entity_timestamps=None):
+    """
+    Generate a keyword index with timestamps.
+
+    Args:
+        keyword_timestamps (dict): Dictionary mapping keywords to timestamp lists
+        entity_timestamps (dict, optional): Dictionary mapping entities to timestamp lists
+
+    Returns:
+        str: Formatted keyword index
+    """
+    lines = ["# Keyword Index\n"]
+
+    # Add keywords section
+    if keyword_timestamps:
+        lines.append("## Keywords\n")
+        for keyword, timestamps in sorted(keyword_timestamps.items()):
+            if timestamps:
+                times = [f"{int(ts['start'] // 60):02d}:{int(ts['start'] % 60):02d}" for ts in timestamps]
+                lines.append(f"- **{keyword}**: {', '.join(times)}\n")
+
+    # Add entities section
+    if entity_timestamps:
+        lines.append("\n## Named Entities\n")
+        for entity, timestamps in sorted(entity_timestamps.items()):
+            if timestamps:
+                times = [f"{int(ts['start'] // 60):02d}:{int(ts['start'] % 60):02d}" for ts in timestamps]
+                lines.append(f"- **{entity}**: {', '.join(times)}\n")
+
+    return "".join(lines)
+
+
+def generate_interactive_transcript(segments, keyword_timestamps=None, entity_timestamps=None):
+    """
+    Generate an interactive transcript with keyword highlighting.
+
+    Args:
+        segments (list): List of transcript segments with timing info
+        keyword_timestamps (dict, optional): Dictionary mapping keywords to timestamp lists
+        entity_timestamps (dict, optional): Dictionary mapping entities to timestamp lists
+
+    Returns:
+        str: HTML formatted interactive transcript
+    """
+    # Combine keywords and entities
+    all_keywords = {}
+    if keyword_timestamps:
+        all_keywords.update(keyword_timestamps)
+    if entity_timestamps:
+        all_keywords.update(entity_timestamps)
+
+    # Generate HTML
+    html = ["<div class='interactive-transcript'>"]
+
+    for segment in segments:
+        start_time = segment["start"]
+        end_time = segment["end"]
+        text = segment["text"]
+
+        # Format timestamp
+        timestamp = f"{int(start_time // 60):02d}:{int(start_time % 60):02d}"
+
+        # Add speaker if available
+        speaker = segment.get("speaker", "")
+        speaker_html = f"[{speaker}] " if speaker else ""
+
+        # Highlight keywords in text by wrapping whole-word matches in <mark> tags
+        highlighted_text = text
+        for keyword in all_keywords:
+            # Use regex to match whole words only
+            pattern = r'\b' + re.escape(keyword) + r'\b'
+            replacement = f"<mark>{keyword}</mark>"
+            highlighted_text = re.sub(pattern, replacement, highlighted_text, flags=re.IGNORECASE)
+
+        # Add segment to HTML with timing metadata
+        html.append(f"<div class='segment' data-start='{start_time}' data-end='{end_time}'>")
+        html.append(f"<span class='timestamp'>{timestamp}</span> {speaker_html}{highlighted_text}")
+        html.append("</div>")
+
+    # Close the container element
+    html.append("</div>
") + + return "\n".join(html) + + +def create_keyword_cloud_data(keyword_timestamps, entity_timestamps=None): + """ + Create data for a keyword cloud visualization. + + Args: + keyword_timestamps (dict): Dictionary mapping keywords to timestamp lists + entity_timestamps (dict, optional): Dictionary mapping entities to timestamp lists + + Returns: + list: List of (keyword, weight) tuples for visualization + """ + cloud_data = [] + + # Process keywords + for keyword, timestamps in keyword_timestamps.items(): + weight = len(timestamps) # Weight by occurrence count + cloud_data.append((keyword, weight)) + + # Process entities if provided + if entity_timestamps: + for entity, timestamps in entity_timestamps.items(): + weight = len(timestamps) * 1.5 # Give entities slightly higher weight + cloud_data.append((entity, weight)) + + return cloud_data \ No newline at end of file diff --git a/utils/ollama_integration.py b/utils/ollama_integration.py new file mode 100644 index 0000000..b3ddbc8 --- /dev/null +++ b/utils/ollama_integration.py @@ -0,0 +1,155 @@ +""" +Ollama integration for local AI model inference. +Provides functions to use Ollama's API for text summarization. +""" + +import requests +import json +import logging +from pathlib import Path +import os + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Default Ollama API endpoint +OLLAMA_API_URL = "http://localhost:11434/api" + + +def check_ollama_available(): + """ + Check if Ollama service is available. + + Returns: + bool: True if Ollama is available, False otherwise + """ + try: + response = requests.get(f"{OLLAMA_API_URL}/tags", timeout=2) + return response.status_code == 200 + except requests.exceptions.RequestException: + return False + + +def list_available_models(): + """ + List available models in Ollama. + + Returns: + list: List of available model names + """ + try: + response = requests.get(f"{OLLAMA_API_URL}/tags") + if response.status_code == 200: + models = response.json().get('models', []) + return [model['name'] for model in models] + return [] + except requests.exceptions.RequestException as e: + logger.error(f"Error listing Ollama models: {e}") + return [] + + +def summarize_with_ollama(text, model="llama3", max_length=150): + """ + Summarize text using Ollama's local API. + + Args: + text (str): Text to summarize + model (str): Ollama model to use + max_length (int): Maximum length of the summary + + Returns: + str: Summarized text or None if failed + """ + if not check_ollama_available(): + logger.warning("Ollama service is not available") + return None + + # Check if the model is available + available_models = list_available_models() + if model not in available_models: + logger.warning(f"Model {model} not available in Ollama. 
Available models: {available_models}") + return None + + # Prepare the prompt for summarization + prompt = f"Summarize the following text in about {max_length} words:\n\n{text}" + + try: + # Make the API request + response = requests.post( + f"{OLLAMA_API_URL}/generate", + json={ + "model": model, + "prompt": prompt, + "stream": False, + "options": { + "temperature": 0.3, + "top_p": 0.9, + "max_tokens": max_length * 2 # Approximate token count + } + } + ) + + if response.status_code == 200: + result = response.json() + return result.get('response', '').strip() + else: + logger.error(f"Ollama API error: {response.status_code} - {response.text}") + return None + except requests.exceptions.RequestException as e: + logger.error(f"Error communicating with Ollama: {e}") + return None + + +def chunk_and_summarize(text, model="llama3", chunk_size=4000, max_length=150): + """ + Chunk long text and summarize each chunk, then combine the summaries. + + Args: + text (str): Text to summarize + model (str): Ollama model to use + chunk_size (int): Maximum size of each chunk in characters + max_length (int): Maximum length of the final summary + + Returns: + str: Combined summary or None if failed + """ + if len(text) <= chunk_size: + return summarize_with_ollama(text, model, max_length) + + # Split text into chunks + words = text.split() + chunks = [] + current_chunk = [] + current_length = 0 + + for word in words: + if current_length + len(word) + 1 <= chunk_size: + current_chunk.append(word) + current_length += len(word) + 1 + else: + chunks.append(' '.join(current_chunk)) + current_chunk = [word] + current_length = len(word) + 1 + + if current_chunk: + chunks.append(' '.join(current_chunk)) + + # Summarize each chunk + chunk_summaries = [] + for i, chunk in enumerate(chunks): + logger.info(f"Summarizing chunk {i+1}/{len(chunks)}") + summary = summarize_with_ollama(chunk, model, max_length // len(chunks)) + if summary: + chunk_summaries.append(summary) + + if not chunk_summaries: + return None + + # If there's only one chunk summary, return it + if len(chunk_summaries) == 1: + return chunk_summaries[0] + + # Otherwise, combine the summaries and summarize again + combined_summary = " ".join(chunk_summaries) + return summarize_with_ollama(combined_summary, model, max_length) \ No newline at end of file diff --git a/utils/transcription.py b/utils/transcription.py index 4e737d2..9c2437b 100644 --- a/utils/transcription.py +++ b/utils/transcription.py @@ -1,22 +1,114 @@ import whisper from pathlib import Path from transformers import pipeline, AutoTokenizer +from utils.audio_processing import extract_audio +from utils.summarization import summarize_text +import logging +import torch + +# Try to import GPU utilities, but don't fail if not available +try: + from utils.gpu_utils import configure_gpu, get_optimal_device + GPU_UTILS_AVAILABLE = True +except ImportError: + GPU_UTILS_AVAILABLE = False + +# Try to import caching utilities, but don't fail if not available +try: + from utils.cache import load_from_cache, save_to_cache + CACHE_AVAILABLE = True +except ImportError: + CACHE_AVAILABLE = False + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) WHISPER_MODEL = "base" SUMMARIZATION_MODEL = "t5-base" -def transcribe_audio(audio_path: Path): - """Transcribe audio using Whisper.""" - model = whisper.load_model(WHISPER_MODEL) - result = model.transcribe(str(audio_path)) +def transcribe_audio(audio_path: Path, model=WHISPER_MODEL, use_cache=True, 
cache_max_age=None, + use_gpu=True, memory_fraction=0.8): + """ + Transcribe audio using Whisper and return both segments and full transcript. + + Args: + audio_path (Path): Path to the audio or video file + model (str): Whisper model size to use (tiny, base, small, medium, large) + use_cache (bool): Whether to use caching + cache_max_age (float, optional): Maximum age of cache in seconds + use_gpu (bool): Whether to use GPU acceleration if available + memory_fraction (float): Fraction of GPU memory to use (0.0 to 1.0) + + Returns: + tuple: (segments, transcript) where segments is a list of dicts with timing info + """ + audio_path = Path(audio_path) + + # Check cache first if enabled + if use_cache and CACHE_AVAILABLE: + cached_data = load_from_cache(audio_path, model, "transcribe", cache_max_age) + if cached_data: + logger.info(f"Using cached transcription for {audio_path}") + return cached_data.get("segments", []), cached_data.get("transcript", "") + + # Extract audio if the input is a video file + if audio_path.suffix.lower() in ['.mp4', '.avi', '.mov', '.mkv']: + audio_path = extract_audio(audio_path) + + # Configure GPU if available and requested + device = torch.device("cpu") + if use_gpu and GPU_UTILS_AVAILABLE: + gpu_config = configure_gpu(model, memory_fraction) + device = gpu_config["device"] + logger.info(f"Using device: {device} for transcription") + + # Load the specified Whisper model + logger.info(f"Loading Whisper model: {model}") + whisper_model = whisper.load_model(model, device=device if device.type != "mps" else "cpu") + + # Transcribe the audio + logger.info(f"Transcribing audio: {audio_path}") + result = whisper_model.transcribe(str(audio_path)) + + # Extract the full transcript and segments transcript = result["text"] - summary = summarize_text(transcript) - return transcript, summary + segments = result["segments"] + + # Cache the results if caching is enabled + if use_cache and CACHE_AVAILABLE: + cache_data = { + "transcript": transcript, + "segments": segments + } + save_to_cache(audio_path, cache_data, model, "transcribe") + + return segments, transcript -def summarize_text(text): - """Summarize text using a pre-trained T5 transformer model with chunking.""" - summarization_pipeline = pipeline("summarization", model=SUMMARIZATION_MODEL) - tokenizer = AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL) + +def summarize_text(text, model=SUMMARIZATION_MODEL, use_gpu=True, memory_fraction=0.8): + """ + Summarize text using a pre-trained transformer model with chunking. 
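+    Inputs longer than the model's token limit are split into chunks, each chunk is
+    summarized separately, and the chunk summaries are condensed into a final summary.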
+ + Args: + text (str): Text to summarize + model (str): Model to use for summarization + use_gpu (bool): Whether to use GPU acceleration if available + memory_fraction (float): Fraction of GPU memory to use (0.0 to 1.0) + + Returns: + str: Summarized text + """ + # Configure device + device = torch.device("cpu") + if use_gpu and GPU_UTILS_AVAILABLE: + device = get_optimal_device() + logger.info(f"Using device: {device} for summarization") + + # Initialize the pipeline with the specified device + device_arg = -1 if device.type == "cpu" else 0 # -1 for CPU, 0 for GPU + summarization_pipeline = pipeline("summarization", model=model, device=device_arg) + tokenizer = AutoTokenizer.from_pretrained(model) max_tokens = 512 @@ -24,20 +116,57 @@ def summarize_text(text): num_tokens = len(tokens['input_ids'][0]) if num_tokens > max_tokens: - chunks = chunk_text(text, max_tokens) + chunks = chunk_text(text, max_tokens, tokenizer) summaries = [] - for chunk in chunks: - summary_output = summarization_pipeline("summarize: " + chunk, max_length=150, min_length=30, do_sample=False) + + for i, chunk in enumerate(chunks): + logger.info(f"Summarizing chunk {i+1}/{len(chunks)}") + summary_output = summarization_pipeline( + "summarize: " + chunk, + max_length=150, + min_length=30, + do_sample=False + ) summaries.append(summary_output[0]['summary_text']) + overall_summary = " ".join(summaries) + + # If the combined summary is still long, summarize it again + if len(summaries) > 1: + logger.info("Generating final summary from chunk summaries") + combined_text = " ".join(summaries) + overall_summary = summarization_pipeline( + "summarize: " + combined_text, + max_length=150, + min_length=30, + do_sample=False + )[0]['summary_text'] else: - overall_summary = summarization_pipeline("summarize: " + text, max_length=150, min_length=30, do_sample=False)[0]['summary_text'] + overall_summary = summarization_pipeline( + "summarize: " + text, + max_length=150, + min_length=30, + do_sample=False + )[0]['summary_text'] return overall_summary -def chunk_text(text, max_tokens): - """Splits the text into a list of chunks based on token limits.""" - tokenizer = AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL) + +def chunk_text(text, max_tokens, tokenizer=None): + """ + Splits the text into a list of chunks based on token limits. + + Args: + text (str): Text to chunk + max_tokens (int): Maximum tokens per chunk + tokenizer (AutoTokenizer, optional): Tokenizer to use + + Returns: + list: List of text chunks + """ + if tokenizer is None: + tokenizer = AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL) + words = text.split() chunks = [] diff --git a/utils/translation.py b/utils/translation.py new file mode 100644 index 0000000..84ae541 --- /dev/null +++ b/utils/translation.py @@ -0,0 +1,283 @@ +""" +Translation utilities for the OBS Recording Transcriber. +Provides functions for language detection and translation. 
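+Uses facebook/m2m100_418M for translation and papluca/xlm-roberta-base-language-detection
+for language detection by default.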
+""" + +import logging +import torch +from pathlib import Path +from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, M2M100ForConditionalGeneration +import whisper +import iso639 + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Try to import GPU utilities, but don't fail if not available +try: + from utils.gpu_utils import get_optimal_device + GPU_UTILS_AVAILABLE = True +except ImportError: + GPU_UTILS_AVAILABLE = False + +# Default models +TRANSLATION_MODEL = "facebook/m2m100_418M" +LANGUAGE_DETECTION_MODEL = "papluca/xlm-roberta-base-language-detection" + +# ISO language code mapping +def get_language_name(code): + """ + Get the language name from ISO code. + + Args: + code (str): ISO language code + + Returns: + str: Language name or original code if not found + """ + try: + return iso639.languages.get(part1=code).name + except (KeyError, AttributeError): + try: + return iso639.languages.get(part2b=code).name + except (KeyError, AttributeError): + return code + + +def detect_language(text, model=LANGUAGE_DETECTION_MODEL, use_gpu=True): + """ + Detect the language of a text. + + Args: + text (str): Text to detect language for + model (str): Model to use for language detection + use_gpu (bool): Whether to use GPU acceleration if available + + Returns: + tuple: (language_code, confidence) + """ + # Configure device + device = torch.device("cpu") + if use_gpu and GPU_UTILS_AVAILABLE: + device = get_optimal_device() + device_arg = 0 if device.type == "cuda" else -1 + else: + device_arg = -1 + + try: + # Initialize the pipeline + classifier = pipeline("text-classification", model=model, device=device_arg) + + # Truncate text if too long + max_length = 512 + if len(text) > max_length: + text = text[:max_length] + + # Detect language + result = classifier(text)[0] + language_code = result["label"] + confidence = result["score"] + + return language_code, confidence + except Exception as e: + logger.error(f"Error detecting language: {e}") + return None, 0.0 + + +def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_MODEL, use_gpu=True): + """ + Translate text from source language to target language. 
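+    For example, translate_text("Bonjour tout le monde") with the defaults would attempt
+    to detect French and return an English translation via M2M100.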
+ + Args: + text (str): Text to translate + source_lang (str, optional): Source language code (auto-detect if None) + target_lang (str): Target language code + model (str): Model to use for translation + use_gpu (bool): Whether to use GPU acceleration if available + + Returns: + str: Translated text + """ + # Auto-detect source language if not provided + if source_lang is None: + detected_lang, confidence = detect_language(text, use_gpu=use_gpu) + if detected_lang and confidence > 0.5: + source_lang = detected_lang + logger.info(f"Detected language: {get_language_name(source_lang)} ({source_lang}) with confidence {confidence:.2f}") + else: + logger.warning("Could not reliably detect language, defaulting to English") + source_lang = "en" + + # Skip translation if source and target are the same + if source_lang == target_lang: + logger.info(f"Source and target languages are the same ({source_lang}), skipping translation") + return text + + # Configure device + device = torch.device("cpu") + if use_gpu and GPU_UTILS_AVAILABLE: + device = get_optimal_device() + + try: + # Load model and tokenizer + tokenizer = AutoTokenizer.from_pretrained(model) + model = M2M100ForConditionalGeneration.from_pretrained(model) + + # Move model to device + model = model.to(device) + + # Prepare for translation + tokenizer.src_lang = source_lang + + # Split text into manageable chunks if too long + max_length = 512 + if len(text) > max_length: + chunks = [text[i:i+max_length] for i in range(0, len(text), max_length)] + else: + chunks = [text] + + # Translate each chunk + translated_chunks = [] + for chunk in chunks: + encoded = tokenizer(chunk, return_tensors="pt").to(device) + generated_tokens = model.generate( + **encoded, + forced_bos_token_id=tokenizer.get_lang_id(target_lang), + max_length=max_length + ) + translated_chunk = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0] + translated_chunks.append(translated_chunk) + + # Combine translated chunks + translated_text = " ".join(translated_chunks) + + return translated_text + except Exception as e: + logger.error(f"Error translating text: {e}") + return text + + +def translate_segments(segments, source_lang=None, target_lang="en", use_gpu=True): + """ + Translate transcript segments. 
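+    Each translated segment keeps its timing and gains "original_text", "source_lang",
+    and "target_lang" keys alongside the translated "text".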
+ + Args: + segments (list): List of transcript segments + source_lang (str, optional): Source language code (auto-detect if None) + target_lang (str): Target language code + use_gpu (bool): Whether to use GPU acceleration if available + + Returns: + list: Translated segments + """ + if not segments: + return [] + + # Auto-detect source language from combined text if not provided + if source_lang is None: + combined_text = " ".join([segment["text"] for segment in segments]) + detected_lang, _ = detect_language(combined_text, use_gpu=use_gpu) + source_lang = detected_lang if detected_lang else "en" + + # Skip translation if source and target are the same + if source_lang == target_lang: + return segments + + try: + # Initialize translation pipeline + translated_segments = [] + + # Translate each segment + for segment in segments: + translated_text = translate_text( + segment["text"], + source_lang=source_lang, + target_lang=target_lang, + use_gpu=use_gpu + ) + + # Create a new segment with translated text + translated_segment = segment.copy() + translated_segment["text"] = translated_text + translated_segment["original_text"] = segment["text"] + translated_segment["source_lang"] = source_lang + translated_segment["target_lang"] = target_lang + + translated_segments.append(translated_segment) + + return translated_segments + except Exception as e: + logger.error(f"Error translating segments: {e}") + return segments + + +def transcribe_and_translate(audio_path, whisper_model="base", target_lang="en", + use_gpu=True, detect_source=True): + """ + Transcribe audio and translate to target language. + + Args: + audio_path (Path): Path to the audio file + whisper_model (str): Whisper model size to use + target_lang (str): Target language code + use_gpu (bool): Whether to use GPU acceleration if available + detect_source (bool): Whether to auto-detect source language + + Returns: + tuple: (original_segments, translated_segments, original_transcript, translated_transcript) + """ + audio_path = Path(audio_path) + + # Configure device + device = torch.device("cpu") + if use_gpu and GPU_UTILS_AVAILABLE: + device = get_optimal_device() + + try: + # Step 1: Transcribe audio with Whisper + logger.info(f"Transcribing audio with Whisper model: {whisper_model}") + model = whisper.load_model(whisper_model, device=device if device.type != "mps" else "cpu") + + # Use Whisper's built-in language detection if requested + if detect_source: + # First, detect language with Whisper + audio = whisper.load_audio(str(audio_path)) + audio = whisper.pad_or_trim(audio) + mel = whisper.log_mel_spectrogram(audio).to(device if device.type != "mps" else "cpu") + _, probs = model.detect_language(mel) + source_lang = max(probs, key=probs.get) + logger.info(f"Whisper detected language: {get_language_name(source_lang)} ({source_lang})") + + # Transcribe with detected language + result = model.transcribe(str(audio_path), language=source_lang) + else: + # Transcribe without language specification + result = model.transcribe(str(audio_path)) + source_lang = result.get("language", "en") + + original_segments = result["segments"] + original_transcript = result["text"] + + # Step 2: Translate if needed + if source_lang != target_lang: + logger.info(f"Translating from {source_lang} to {target_lang}") + translated_segments = translate_segments( + original_segments, + source_lang=source_lang, + target_lang=target_lang, + use_gpu=use_gpu + ) + + # Create full translated transcript + translated_transcript = " ".join([segment["text"] for 
segment in translated_segments]) + else: + logger.info(f"Source and target languages are the same ({source_lang}), skipping translation") + translated_segments = original_segments + translated_transcript = original_transcript + + return original_segments, translated_segments, original_transcript, translated_transcript + + except Exception as e: + logger.error(f"Error in transcribe_and_translate: {e}") + return None, None, None, None \ No newline at end of file