# mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-20)
"""
|
|
Utility file for all text processing related functionalities
|
|
"""
|
|
|
|
import nltk
|
|
import torch
|
|
from nltk.corpus import stopwords
|
|
from nltk.tokenize import word_tokenize
|
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
from sklearn.metrics.pairwise import cosine_similarity
|
|
from transformers import BartForConditionalGeneration, BartTokenizer
|
|
|
|
from log_utils import LOGGER
|
|
from run_utils import CONFIG
|
|
|
|
nltk.download('punkt', quiet=True)
|
|
|
|
|
|
def preprocess_sentence(sentence):
    """
    Lowercase and tokenize a sentence, keeping only alphanumeric,
    non-stopword tokens.
    """
    stop_words = set(stopwords.words('english'))
    tokens = word_tokenize(sentence.lower())
    tokens = [token for token in tokens
              if token.isalnum() and token not in stop_words]
    return ' '.join(tokens)


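# Example (illustrative, not part of the original module):
#   preprocess_sentence("The cat sat on the mat!")  ->  "cat sat mat"

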
def compute_similarity(sent1, sent2):
    """
    Compute the TF-IDF cosine similarity between two sentences,
    returning 0.0 if either sentence is None.
    """
    tfidf_vectorizer = TfidfVectorizer()
    if sent1 is not None and sent2 is not None:
        tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2])
        return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
    return 0.0


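# Example (illustrative): identical sentences score 1.0, sentences with no
# shared vocabulary score 0.0.
#   compute_similarity("cats chase mice", "cats chase mice")    ->  1.0
#   compute_similarity("cats chase mice", "dogs fetch sticks")  ->  0.0

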
def remove_almost_alike_sentences(sentences, threshold=0.7):
    """
    Filter out sentences whose similarity to an earlier sentence meets or
    exceeds the given threshold.

    :param sentences: list of sentences to filter
    :param threshold: cosine-similarity threshold at which the later
        sentence of a pair is dropped
    :return: list of sentences with near-duplicates removed
    """
    num_sentences = len(sentences)
    removed_indices = set()

    for i in range(num_sentences):
        if i not in removed_indices:
            for j in range(i + 1, num_sentences):
                if j not in removed_indices:
                    l_i = len(sentences[i])
                    l_j = len(sentences[j])
                    if l_i == 0 or l_j == 0:
                        # Empty sentences carry no content; drop them outright.
                        if l_i == 0:
                            removed_indices.add(i)
                        if l_j == 0:
                            removed_indices.add(j)
                    else:
                        sentence1 = preprocess_sentence(sentences[i])
                        sentence2 = preprocess_sentence(sentences[j])
                        if len(sentence1) != 0 and len(sentence2) != 0:
                            similarity = compute_similarity(sentence1,
                                                            sentence2)
                            if similarity >= threshold:
                                # Keep the earlier sentence, drop the later one.
                                removed_indices.add(max(i, j))

    filtered_sentences = [sentences[i] for i in range(num_sentences)
                          if i not in removed_indices]
    return filtered_sentences


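# Example (illustrative; the score is approximate):
#   remove_almost_alike_sentences(["The meeting starts at noon.",
#                                  "The meeting starts at noon today.",
#                                  "Lunch is at one."])
# drops the second sentence, whose preprocessed TF-IDF similarity to the
# first (~0.78) exceeds the default 0.7 threshold.

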
def remove_outright_duplicate_sentences_from_chunk(chunk):
    """
    Remove exact duplicate sentences from a transcript chunk while
    preserving their original order.

    :param chunk: transcript chunk dict with a "text" key
    :return: list of unique sentences
    """
    chunk_text = chunk["text"]
    sentences = nltk.sent_tokenize(chunk_text)
    nonduplicate_sentences = list(dict.fromkeys(sentences))
    return nonduplicate_sentences


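# Example (illustrative):
#   remove_outright_duplicate_sentences_from_chunk(
#       {"text": "Hello there. Hello there. How are you?"})
#   ->  ["Hello there.", "How are you?"]

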
def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
    """
    Remove word runs that are repeated within a sentence as a result of
    Whisper hallucinations.

    :param nonduplicate_sentences: list of sentences already stripped of
        exact duplicates
    :return: list of sentences with repeated n-gram runs collapsed
    """
    chunk_sentences = []

    for sent in nonduplicate_sentences:
        temp_result = ""
        seen = {}
        words = nltk.word_tokenize(sent)
        n_gram_filter = 3
        for i in range(len(words)):
            key = str(words[i:i + n_gram_filter])
            window = words[i + 1:i + n_gram_filter + 2]
            # Skip the word if the 3-gram starting here, together with its
            # following context window, has already been seen; otherwise
            # record the window and keep the word.
            if key in seen and seen[key] == window:
                continue
            seen[key] = window
            temp_result += words[i] + " "
        chunk_sentences.append(temp_result)
    return chunk_sentences


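# Example (illustrative, approximate): a hallucinated repetition such as
#   "thank you thank you thank you thank you"
# is trimmed to roughly "thank you thank you thank you", since interior
# repeats of the same 3-gram-plus-context window are dropped.

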
def post_process_transcription(whisper_result):
    """
    Parent function to perform post-processing on the transcription result:
    drop exact duplicates, collapse hallucinated repetitions, and remove
    near-duplicate sentences from each chunk.

    :param whisper_result: Whisper output dict with "chunks" and "text" keys
    :return: the same dict with cleaned chunk texts and a rebuilt "text"
    """
    transcript_text = ""
    for chunk in whisper_result["chunks"]:
        nonduplicate_sentences = \
            remove_outright_duplicate_sentences_from_chunk(chunk)
        chunk_sentences = \
            remove_whisper_repetitive_hallucination(nonduplicate_sentences)
        similarity_matched_sentences = \
            remove_almost_alike_sentences(chunk_sentences)
        chunk["text"] = " ".join(similarity_matched_sentences)
        transcript_text += chunk["text"]
    whisper_result["text"] = transcript_text
    return whisper_result


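# Expected input shape (illustrative sketch, inferred from the keys used above):
#   whisper_result = {"text": "...", "chunks": [{"text": "..."}, ...]}
#   cleaned = post_process_transcription(whisper_result)

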
def summarize_chunks(chunks, tokenizer, model):
    """
    Summarize each chunk using the given summarizer model.

    :param chunks: list of text chunks
    :param tokenizer: tokenizer matching the summarizer model
    :param model: summarization model
    :return: list of chunk summaries
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    summaries = []
    for c in chunks:
        input_ids = tokenizer.encode(c, return_tensors='pt')
        input_ids = input_ids.to(device)
        with torch.no_grad():
            summary_ids = \
                model.generate(input_ids,
                               num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]),
                               length_penalty=2.0,
                               max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]),
                               early_stopping=True)
        summary = tokenizer.decode(summary_ids[0],
                                   skip_special_tokens=True)
        summaries.append(summary)
    return summaries


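# Usage sketch (illustrative; "facebook/bart-large-cnn" is the module's
# default summary model):
#   tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
#   model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
#   summaries = summarize_chunks(chunk_text(transcript_text), tokenizer, model)

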
def chunk_text(text,
               max_chunk_length=int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])):
    """
    Split text into smaller chunks along sentence boundaries.

    :param text: text to be chunked
    :param max_chunk_length: maximum chunk length in characters
    :return: list of text chunks
    """
    sentences = nltk.sent_tokenize(text)
    chunks = []
    current_chunk = ""
    for sentence in sentences:
        if len(current_chunk) + len(sentence) < max_chunk_length:
            current_chunk += f" {sentence.strip()}"
        else:
            chunks.append(current_chunk.strip())
            current_chunk = f"{sentence.strip()}"
    chunks.append(current_chunk.strip())
    return chunks


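# Example (illustrative): sentences are grouped until the character budget
# is exceeded, e.g. with a budget of 40 characters:
#   chunk_text("First sentence here. Second sentence here. Third one.", 40)
#   ->  ["First sentence here.", "Second sentence here. Third one."]

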
def summarize(transcript_text, timestamp,
              real_time=False,
              chunk_summarize=CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]):
    """
    Summarize the given text either as a whole or chunk by chunk, as
    configured.

    :param transcript_text: transcript text to summarize
    :param timestamp: timestamp used to name the output file
    :param real_time: whether the transcript was produced in real time
    :param chunk_summarize: "YES" to summarize chunk by chunk
    :return: None; the summary is written to the output file
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    summary_model = CONFIG["SUMMARIZER"]["SUMMARY_MODEL"]
    if not summary_model:
        summary_model = "facebook/bart-large-cnn"

    # Summarize the generated transcript using the BART model
    LOGGER.info(f"Loading BART model: {summary_model}")
    tokenizer = BartTokenizer.from_pretrained(summary_model)
    model = BartForConditionalGeneration.from_pretrained(summary_model)
    model = model.to(device)

    output_file = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
    if real_time:
        output_file = "real_time_" + output_file

    if chunk_summarize != "YES":
        max_length = int(CONFIG["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"])
        inputs = tokenizer.batch_encode_plus([transcript_text],
                                             truncation=True,
                                             padding='longest',
                                             max_length=max_length,
                                             return_tensors='pt')
        inputs = inputs.to(device)

        with torch.no_grad():
            num_beams = int(CONFIG["SUMMARIZER"]["BEAM_SIZE"])
            max_length = int(CONFIG["SUMMARIZER"]["MAX_LENGTH"])
            summaries = model.generate(inputs['input_ids'],
                                       num_beams=num_beams,
                                       length_penalty=2.0,
                                       max_length=max_length,
                                       early_stopping=True)

        decoded_summaries = \
            [tokenizer.decode(summary,
                              skip_special_tokens=True,
                              clean_up_tokenization_spaces=False)
             for summary in summaries]
        summary = " ".join(decoded_summaries)
        with open("./artefacts/" + output_file, 'w', encoding="utf-8") as file:
            file.write(summary.strip() + "\n")
    else:
        LOGGER.info("Breaking transcript into smaller chunks")
        chunks = chunk_text(transcript_text)

        LOGGER.info(f"Transcript broken into {len(chunks)} chunks of at most "
                    f"{int(CONFIG['SUMMARIZER']['MAX_CHUNK_LENGTH'])} characters")

        LOGGER.info(f"Writing summary text to: {output_file}")
        with open(output_file, 'w', encoding="utf-8") as f:
            summaries = summarize_chunks(chunks, tokenizer, model)
            for summary in summaries:
                f.write(summary.strip() + " ")


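# Usage sketch (illustrative; argument values are assumptions):
#   from datetime import datetime
#   summarize(transcript_text, datetime.now(), real_time=False)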