TalkEdit/utils/summarization.py

from transformers import pipeline, AutoTokenizer
import torch
import logging
import streamlit as st
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

SUMMARY_MODEL = "Falconsai/text_summarization"


@st.cache_resource
def _load_summarizer(device_int):
    """Load and cache the summarization pipeline."""
    # cache_resource keys on the argument, so the CPU (-1) and GPU (0)
    # pipelines are cached independently.
    logger.info(f"Loading summarization model on device {device_int}")
    return pipeline("summarization", model=SUMMARY_MODEL, device=device_int)


@st.cache_resource
def _load_summary_tokenizer():
    """Load and cache the summarization tokenizer."""
    return AutoTokenizer.from_pretrained(SUMMARY_MODEL)


def chunk_text(text, max_tokens, tokenizer):
    """
    Split text into chunks by tokenizing the full text once, then slicing
    fixed-size token windows. Much faster than tokenizing word by word.
    """
    # Tokenize without special tokens so the window math counts only content.
    # (Slicing [1:-1] to "strip BOS/EOS" would silently drop the first content
    # token with T5-style tokenizers, which add no BOS.)
    content_ids = tokenizer(text, truncation=False, add_special_tokens=False)['input_ids']
    usable_max = max_tokens - 2  # headroom for special tokens re-added downstream
    chunks = []
    for i in range(0, len(content_ids), usable_max):
        chunk_ids = content_ids[i : i + usable_max]
        decoded = tokenizer.decode(chunk_ids, skip_special_tokens=True).strip()
        if decoded:
            chunks.append(decoded)
    if not chunks:
        chunks.append(text)
    return chunks
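

# A minimal usage sketch (not part of the original module; the sample text and
# token counts are illustrative): shows how chunk_text splits an over-length
# input into windows of at most max_tokens - 2 content tokens. Call it
# manually, e.g. from a REPL; outside a running Streamlit app,
# st.cache_resource simply logs that no runtime is active.
def _demo_chunk_text():
    tokenizer = _load_summary_tokenizer()
    long_text = "The quick brown fox jumps over the lazy dog. " * 400
    chunks = chunk_text(long_text, max_tokens=512, tokenizer=tokenizer)
    for i, chunk in enumerate(chunks):
        n = len(tokenizer(chunk, add_special_tokens=False)['input_ids'])
        logger.info(f"chunk {i + 1}/{len(chunks)}: {n} content tokens")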


def summarize_text(text, use_gpu=True, memory_fraction=0.8):
    """
    Summarize text with a Hugging Face pipeline, chunking inputs that exceed
    the model's context window.

    Args:
        text (str): Text to summarize
        use_gpu (bool): Whether to use the GPU if one is available
        memory_fraction (float): Fraction of GPU memory the process may use

    Returns:
        str: Summarized text
    """
    device = -1
    if use_gpu and torch.cuda.is_available():
        device = 0
        torch.cuda.set_per_process_memory_fraction(memory_fraction)
    logger.info(f"Using device {device} for summarization")
    try:
        summarizer = _load_summarizer(device)
        tokenizer = _load_summary_tokenizer()
        max_tokens = 512  # context window of the underlying model

        num_tokens = len(tokenizer(text, truncation=False)['input_ids'])
        if num_tokens > max_tokens:
            chunks = chunk_text(text, max_tokens, tokenizer)
            summaries = []
            for i, chunk in enumerate(chunks):
                logger.info(f"Summarizing chunk {i + 1}/{len(chunks)}")
                summary_output = summarizer(
                    "summarize: " + chunk,
                    max_length=150,
                    min_length=30,
                    do_sample=False
                )
                summaries.append(summary_output[0]['summary_text'])
            if len(summaries) > 1:
                # Second pass: condense the per-chunk summaries into one.
                logger.info("Generating final summary from chunk summaries")
                combined_text = " ".join(summaries)
                return summarizer(
                    "summarize: " + combined_text,
                    max_length=150,
                    min_length=30,
                    do_sample=False
                )[0]['summary_text']
            return summaries[0]
        else:
            return summarizer(
                "summarize: " + text,
                max_length=150,
                min_length=30,
                do_sample=False
            )[0]['summary_text']
    except Exception as e:
        logger.error(f"Error during summarization: {e}")
        if device != -1:
            logger.info("Falling back to CPU")
            return summarize_text(text, use_gpu=False, memory_fraction=memory_fraction)
        raise
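

# Hedged smoke test (a sketch, not part of the original file; the sample text
# and CPU-only setting are assumptions). Running the module directly works
# outside `streamlit run`; st.cache_resource just warns that no Streamlit
# runtime is active.
if __name__ == "__main__":
    sample = (
        "Machine translation is the task of automatically converting text "
        "from one language into another. Modern systems are built on neural "
        "networks trained on large parallel corpora, and they are evaluated "
        "with metrics such as BLEU that compare output against reference "
        "translations produced by humans."
    )
    print(summarize_text(sample, use_gpu=False))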