283 lines
9.9 KiB
Python
283 lines
9.9 KiB
Python
|
|
"""
|
||
|
|
Translation utilities for the OBS Recording Transcriber.
|
||
|
|
Provides functions for language detection and translation.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import logging
|
||
|
|
import torch
|
||
|
|
from pathlib import Path
|
||
|
|
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer, M2M100ForConditionalGeneration
|
||
|
|
import whisper
|
||
|
|
import iso639
|
||
|
|
|
||
|
|
# Configure logging
|
||
|
|
logging.basicConfig(level=logging.INFO)
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
# Try to import GPU utilities, but don't fail if not available
|
||
|
|
try:
|
||
|
|
from utils.gpu_utils import get_optimal_device
|
||
|
|
GPU_UTILS_AVAILABLE = True
|
||
|
|
except ImportError:
|
||
|
|
GPU_UTILS_AVAILABLE = False
|
||
|
|
|
||
|
|
# Default models
|
||
|
|
TRANSLATION_MODEL = "facebook/m2m100_418M"
|
||
|
|
LANGUAGE_DETECTION_MODEL = "papluca/xlm-roberta-base-language-detection"
|
||
|
|
|
||
|
|
# ISO language code mapping
|
||
|
|
def get_language_name(code):
    """
    Resolve an ISO language code to its English name.

    Args:
        code (str): ISO language code (ISO 639-1 two-letter or
            ISO 639-2/B three-letter)

    Returns:
        str: Language name, or the original code when no match is found
    """
    # Try the two-letter (part1) lookup first, then the bibliographic
    # three-letter (part2b) lookup; echo the code back if neither matches.
    for lookup in ({"part1": code}, {"part2b": code}):
        try:
            return iso639.languages.get(**lookup).name
        except (KeyError, AttributeError):
            continue
    return code
|
||
|
|
|
||
|
|
|
||
|
|
def detect_language(text, model=LANGUAGE_DETECTION_MODEL, use_gpu=True):
    """
    Detect the language of a text.

    Args:
        text (str): Text to detect language for
        model (str): Model to use for language detection
        use_gpu (bool): Whether to use GPU acceleration if available

    Returns:
        tuple: (language_code, confidence), or (None, 0.0) on failure
    """
    # Map the optimal torch device onto the transformers pipeline device
    # convention: 0 selects the first CUDA device, -1 forces CPU.
    device_arg = -1
    if use_gpu and GPU_UTILS_AVAILABLE:
        if get_optimal_device().type == "cuda":
            device_arg = 0

    try:
        classifier = pipeline("text-classification", model=model, device=device_arg)

        # The classifier has a bounded input size; keep only the first
        # 512 characters of longer texts.
        prediction = classifier(text[:512])[0]
        return prediction["label"], prediction["score"]
    except Exception as exc:
        logger.error(f"Error detecting language: {exc}")
        return None, 0.0
|
||
|
|
|
||
|
|
|
||
|
|
def _split_text_into_chunks(text, max_length):
    """
    Split text into chunks of at most max_length characters.

    Splits on whitespace so that words are never cut in half; a single
    word longer than max_length is hard-split as a last resort.

    Args:
        text (str): Text to split
        max_length (int): Maximum chunk length in characters

    Returns:
        list: Non-empty list of chunk strings (original text if short enough)
    """
    if len(text) <= max_length:
        return [text]

    chunks = []
    current = []
    for word in text.split():
        # Pathological single words longer than max_length: flush the
        # pending chunk, then hard-split the word.
        while len(word) > max_length:
            if current:
                chunks.append(" ".join(current))
                current = []
            chunks.append(word[:max_length])
            word = word[max_length:]
        if not word:
            continue
        if current and len(" ".join(current)) + 1 + len(word) > max_length:
            chunks.append(" ".join(current))
            current = [word]
        else:
            current.append(word)
    if current:
        chunks.append(" ".join(current))
    return chunks if chunks else [text]


def translate_text(text, source_lang=None, target_lang="en", model=TRANSLATION_MODEL, use_gpu=True):
    """
    Translate text from source language to target language.

    Args:
        text (str): Text to translate
        source_lang (str, optional): Source language code (auto-detect if None)
        target_lang (str): Target language code
        model (str): Model to use for translation
        use_gpu (bool): Whether to use GPU acceleration if available

    Returns:
        str: Translated text, or the original text on failure or when no
            translation is needed
    """
    # Auto-detect source language if not provided.
    if source_lang is None:
        detected_lang, confidence = detect_language(text, use_gpu=use_gpu)
        if detected_lang and confidence > 0.5:
            source_lang = detected_lang
            logger.info(f"Detected language: {get_language_name(source_lang)} ({source_lang}) with confidence {confidence:.2f}")
        else:
            logger.warning("Could not reliably detect language, defaulting to English")
            source_lang = "en"

    # Skip translation if source and target are the same.
    if source_lang == target_lang:
        logger.info(f"Source and target languages are the same ({source_lang}), skipping translation")
        return text

    # Configure device.
    device = torch.device("cpu")
    if use_gpu and GPU_UTILS_AVAILABLE:
        device = get_optimal_device()

    try:
        # Load tokenizer and model; keep the loaded model under its own
        # name instead of shadowing the `model` parameter.
        tokenizer = AutoTokenizer.from_pretrained(model)
        translation_model = M2M100ForConditionalGeneration.from_pretrained(model).to(device)

        # Tell the tokenizer which language the input is in (M2M100 is
        # many-to-many and needs an explicit source language).
        tokenizer.src_lang = source_lang

        # Chunk long inputs on word boundaries rather than hard 512-char
        # slices, so words are not split mid-token (which degrades
        # translation quality).
        max_length = 512
        chunks = _split_text_into_chunks(text, max_length)

        # Translate each chunk and stitch the results back together.
        translated_chunks = []
        for chunk in chunks:
            encoded = tokenizer(chunk, return_tensors="pt").to(device)
            generated_tokens = translation_model.generate(
                **encoded,
                forced_bos_token_id=tokenizer.get_lang_id(target_lang),
                max_length=max_length
            )
            translated_chunks.append(
                tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
            )

        return " ".join(translated_chunks)
    except Exception as e:
        logger.error(f"Error translating text: {e}")
        # Best-effort: fall back to the untranslated text.
        return text
|
||
|
|
|
||
|
|
|
||
|
|
def translate_segments(segments, source_lang=None, target_lang="en", use_gpu=True):
    """
    Translate transcript segments.

    Args:
        segments (list): List of transcript segments
        source_lang (str, optional): Source language code (auto-detect if None)
        target_lang (str): Target language code
        use_gpu (bool): Whether to use GPU acceleration if available

    Returns:
        list: Translated segments (original segments on failure or when
            source and target languages match)
    """
    if not segments:
        return []

    # Auto-detect the source language from the concatenated segment text
    # when the caller did not supply one; default to English on failure.
    if source_lang is None:
        full_text = " ".join(segment["text"] for segment in segments)
        detected, _ = detect_language(full_text, use_gpu=use_gpu)
        source_lang = detected or "en"

    # Nothing to do when source and target match.
    if source_lang == target_lang:
        return segments

    try:
        results = []
        for segment in segments:
            # Copy the segment, preserving the original text alongside
            # the translation and recording the language pair.
            new_segment = segment.copy()
            new_segment["original_text"] = segment["text"]
            new_segment["text"] = translate_text(
                segment["text"],
                source_lang=source_lang,
                target_lang=target_lang,
                use_gpu=use_gpu
            )
            new_segment["source_lang"] = source_lang
            new_segment["target_lang"] = target_lang
            results.append(new_segment)
        return results
    except Exception as exc:
        logger.error(f"Error translating segments: {exc}")
        return segments
|
||
|
|
|
||
|
|
|
||
|
|
def transcribe_and_translate(audio_path, whisper_model="base", target_lang="en",
                             use_gpu=True, detect_source=True):
    """
    Transcribe audio and translate to target language.

    Args:
        audio_path (Path): Path to the audio file
        whisper_model (str): Whisper model size to use
        target_lang (str): Target language code
        use_gpu (bool): Whether to use GPU acceleration if available
        detect_source (bool): Whether to auto-detect source language

    Returns:
        tuple: (original_segments, translated_segments, original_transcript,
            translated_transcript), or (None, None, None, None) on failure
    """
    audio_path = Path(audio_path)

    # Configure device.
    device = torch.device("cpu")
    if use_gpu and GPU_UTILS_AVAILABLE:
        device = get_optimal_device()

    # Whisper does not run on Apple MPS here; fall back to CPU for it.
    whisper_device = device if device.type != "mps" else "cpu"

    try:
        # Step 1: transcribe with Whisper.
        logger.info(f"Transcribing audio with Whisper model: {whisper_model}")
        model = whisper.load_model(whisper_model, device=whisper_device)

        if detect_source:
            # Detect the language up front with Whisper's own detector,
            # then transcribe with that language fixed.
            audio = whisper.pad_or_trim(whisper.load_audio(str(audio_path)))
            mel = whisper.log_mel_spectrogram(audio).to(whisper_device)
            _, probs = model.detect_language(mel)
            source_lang = max(probs, key=probs.get)
            logger.info(f"Whisper detected language: {get_language_name(source_lang)} ({source_lang})")
            result = model.transcribe(str(audio_path), language=source_lang)
        else:
            # Let Whisper infer the language during transcription.
            result = model.transcribe(str(audio_path))
            source_lang = result.get("language", "en")

        original_segments = result["segments"]
        original_transcript = result["text"]

        # Step 2: translate only when the languages differ.
        if source_lang == target_lang:
            logger.info(f"Source and target languages are the same ({source_lang}), skipping translation")
            return original_segments, original_segments, original_transcript, original_transcript

        logger.info(f"Translating from {source_lang} to {target_lang}")
        translated_segments = translate_segments(
            original_segments,
            source_lang=source_lang,
            target_lang=target_lang,
            use_gpu=use_gpu
        )
        translated_transcript = " ".join(segment["text"] for segment in translated_segments)

        return original_segments, translated_segments, original_transcript, translated_transcript
    except Exception as exc:
        logger.error(f"Error in transcribe_and_translate: {exc}")
        return None, None, None, None
|