Fix summarization issues and improve GPU handling. Update .gitignore for venv
This commit is contained in:
@ -25,7 +25,6 @@ logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WHISPER_MODEL = "base"
|
||||
SUMMARIZATION_MODEL = "t5-base"
|
||||
|
||||
def transcribe_audio(audio_path: Path, model=WHISPER_MODEL, use_cache=True, cache_max_age=None,
|
||||
use_gpu=True, memory_fraction=0.8):
|
||||
@ -83,107 +82,4 @@ def transcribe_audio(audio_path: Path, model=WHISPER_MODEL, use_cache=True, cach
|
||||
}
|
||||
save_to_cache(audio_path, cache_data, model, "transcribe")
|
||||
|
||||
return segments, transcript
|
||||
|
||||
|
||||
def summarize_text(text, model=SUMMARIZATION_MODEL, use_gpu=True, memory_fraction=0.8):
|
||||
"""
|
||||
Summarize text using a pre-trained transformer model with chunking.
|
||||
|
||||
Args:
|
||||
text (str): Text to summarize
|
||||
model (str): Model to use for summarization
|
||||
use_gpu (bool): Whether to use GPU acceleration if available
|
||||
memory_fraction (float): Fraction of GPU memory to use (0.0 to 1.0)
|
||||
|
||||
Returns:
|
||||
str: Summarized text
|
||||
"""
|
||||
# Configure device
|
||||
device = torch.device("cpu")
|
||||
if use_gpu and GPU_UTILS_AVAILABLE:
|
||||
device = get_optimal_device()
|
||||
logger.info(f"Using device: {device} for summarization")
|
||||
|
||||
# Initialize the pipeline with the specified device
|
||||
device_arg = -1 if device.type == "cpu" else 0 # -1 for CPU, 0 for GPU
|
||||
summarization_pipeline = pipeline("summarization", model=model, device=device_arg)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model)
|
||||
|
||||
max_tokens = 512
|
||||
|
||||
tokens = tokenizer(text, return_tensors='pt')
|
||||
num_tokens = len(tokens['input_ids'][0])
|
||||
|
||||
if num_tokens > max_tokens:
|
||||
chunks = chunk_text(text, max_tokens, tokenizer)
|
||||
summaries = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
logger.info(f"Summarizing chunk {i+1}/{len(chunks)}")
|
||||
summary_output = summarization_pipeline(
|
||||
"summarize: " + chunk,
|
||||
max_length=150,
|
||||
min_length=30,
|
||||
do_sample=False
|
||||
)
|
||||
summaries.append(summary_output[0]['summary_text'])
|
||||
|
||||
overall_summary = " ".join(summaries)
|
||||
|
||||
# If the combined summary is still long, summarize it again
|
||||
if len(summaries) > 1:
|
||||
logger.info("Generating final summary from chunk summaries")
|
||||
combined_text = " ".join(summaries)
|
||||
overall_summary = summarization_pipeline(
|
||||
"summarize: " + combined_text,
|
||||
max_length=150,
|
||||
min_length=30,
|
||||
do_sample=False
|
||||
)[0]['summary_text']
|
||||
else:
|
||||
overall_summary = summarization_pipeline(
|
||||
"summarize: " + text,
|
||||
max_length=150,
|
||||
min_length=30,
|
||||
do_sample=False
|
||||
)[0]['summary_text']
|
||||
|
||||
return overall_summary
|
||||
|
||||
|
||||
def chunk_text(text, max_tokens, tokenizer=None):
|
||||
"""
|
||||
Splits the text into a list of chunks based on token limits.
|
||||
|
||||
Args:
|
||||
text (str): Text to chunk
|
||||
max_tokens (int): Maximum tokens per chunk
|
||||
tokenizer (AutoTokenizer, optional): Tokenizer to use
|
||||
|
||||
Returns:
|
||||
list: List of text chunks
|
||||
"""
|
||||
if tokenizer is None:
|
||||
tokenizer = AutoTokenizer.from_pretrained(SUMMARIZATION_MODEL)
|
||||
|
||||
words = text.split()
|
||||
|
||||
chunks = []
|
||||
current_chunk = []
|
||||
current_length = 0
|
||||
|
||||
for word in words:
|
||||
hypothetical_length = current_length + len(tokenizer(word, return_tensors='pt')['input_ids'][0]) - 2
|
||||
if hypothetical_length <= max_tokens:
|
||||
current_chunk.append(word)
|
||||
current_length = hypothetical_length
|
||||
else:
|
||||
chunks.append(' '.join(current_chunk))
|
||||
current_chunk = [word]
|
||||
current_length = len(tokenizer(word, return_tensors='pt')['input_ids'][0]) - 2
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(' '.join(current_chunk))
|
||||
|
||||
return chunks
|
||||
return segments, transcript
|
||||
Reference in New Issue
Block a user