flake8 / pylint updates

Gokul Mohanarangan
2023-07-26 11:28:14 +05:30
parent c970fc89dd
commit e512b4dca5
15 changed files with 279 additions and 146 deletions


@@ -1,3 +1,7 @@
"""
Utility file for all text processing related functionalities
"""
import nltk
import torch
from nltk.corpus import stopwords
@@ -6,8 +10,8 @@ from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import BartForConditionalGeneration, BartTokenizer
-from log_utils import logger
-from run_utils import config
+from log_utils import LOGGER
+from run_utils import CONFIG
nltk.download('punkt', quiet=True)
@@ -32,6 +36,12 @@ def compute_similarity(sent1, sent2):
def remove_almost_alike_sentences(sentences, threshold=0.7):
"""
Filter sentences that are similar beyond a set threshold
+:param sentences:
+:param threshold:
+:return:
"""
num_sentences = len(sentences)
removed_indices = set()
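
This hunk shows only the setup; the filter itself compares every sentence against every later one and marks the later sentence for removal when their similarity exceeds the threshold. A minimal sketch of that pattern, assuming compute_similarity is the TF-IDF cosine similarity defined earlier in this file (the exact loop structure is an assumption, since the function body is truncated here):

def filter_similar_sentences(sentences, threshold=0.7):
    removed_indices = set()
    for i in range(len(sentences)):
        if i in removed_indices:
            continue
        for j in range(i + 1, len(sentences)):
            # Drop the later sentence of any near-identical pair.
            if compute_similarity(sentences[i], sentences[j]) > threshold:
                removed_indices.add(j)
    return [s for k, s in enumerate(sentences) if k not in removed_indices]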
@@ -62,6 +72,11 @@ def remove_almost_alike_sentences(sentences, threshold=0.7):
def remove_outright_duplicate_sentences_from_chunk(chunk):
"""
Remove repetitive sentences
+:param chunk:
+:return:
"""
chunk_text = chunk["text"]
sentences = nltk.sent_tokenize(chunk_text)
nonduplicate_sentences = list(dict.fromkeys(sentences))
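
The dict.fromkeys idiom on the line above is an order-preserving dedup: dicts keep insertion order (guaranteed since Python 3.7), so the first occurrence of each sentence survives, whereas set() would scramble the transcript. For example:

sentences = ["Thanks for watching.", "Next topic.", "Thanks for watching."]
list(dict.fromkeys(sentences))
# -> ['Thanks for watching.', 'Next topic.']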
@@ -69,6 +84,12 @@ def remove_outright_duplicate_sentences_from_chunk(chunk):
def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
"""
Remove sentences that are repeated as a result of Whisper
hallucinations
+:param nonduplicate_sentences:
+:return:
"""
chunk_sentences = []
for sent in nonduplicate_sentences:
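
The rest of this loop is cut off by the hunk, but a common way to handle Whisper's looping output is to collapse immediate repeats within a sentence. A sketch under that assumption (the repository's actual logic may differ):

def collapse_word_repeats(sentence):
    # "so so so the meeting ended" -> "so the meeting ended"
    words = sentence.split()
    kept = [w for i, w in enumerate(words) if i == 0 or w != words[i - 1]]
    return " ".join(kept)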
@@ -91,6 +112,11 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
def post_process_transcription(whisper_result):
"""
Parent function to perform post-processing on the transcription result
+:param whisper_result:
+:return:
"""
transcript_text = ""
for chunk in whisper_result["chunks"]:
nonduplicate_sentences = \
@@ -121,9 +147,9 @@ def summarize_chunks(chunks, tokenizer, model):
with torch.no_grad():
summary_ids = \
model.generate(input_ids,
-num_beams=int(config["SUMMARIZER"]["BEAM_SIZE"]),
+num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]),
length_penalty=2.0,
-max_length=int(config["SUMMARIZER"]["MAX_LENGTH"]),
+max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]),
early_stopping=True)
summary = tokenizer.decode(summary_ids[0],
skip_special_tokens=True)
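
The CONFIG lookups above and below imply an INI-style [SUMMARIZER] section loaded by run_utils. A plausible shape, demonstrated with configparser; the key names are all taken from this file, but the values are illustrative assumptions:

import configparser

CONFIG = configparser.ConfigParser()
CONFIG.read_string("""
[SUMMARIZER]
SUMMARY_MODEL = facebook/bart-large-cnn
SUMMARIZE_USING_CHUNKS = YES
MAX_CHUNK_LENGTH = 500
INPUT_ENCODING_MAX_LENGTH = 1024
BEAM_SIZE = 4
MAX_LENGTH = 150
""")
print(int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]))  # -> 4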
@@ -132,7 +158,7 @@ def summarize_chunks(chunks, tokenizer, model):
def chunk_text(text,
-max_chunk_length=int(config["SUMMARIZER"]["MAX_CHUNK_LENGTH"])):
+max_chunk_length=int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])):
"""
Split text into smaller chunks.
:param text: Text to be chunked
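
The splitting logic falls outside this hunk; a word-count chunker consistent with the signature and with the "at most 500 words" log message later in the file might look like this (a sketch, not necessarily the repository's implementation):

def chunk_by_word_count(text, max_chunk_length=500):
    # Greedily pack whole words into chunks of at most
    # max_chunk_length words each.
    words = text.split()
    return [" ".join(words[i:i + max_chunk_length])
            for i in range(0, len(words), max_chunk_length)]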
@@ -154,14 +180,22 @@ def chunk_text(text,
def summarize(transcript_text, timestamp,
real_time=False,
-chunk_summarize=config["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]):
+chunk_summarize=CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]):
"""
Summarize the given text either as a whole or as chunks as needed
+:param transcript_text:
+:param timestamp:
+:param real_time:
+:param chunk_summarize:
+:return:
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-summary_model = config["SUMMARIZER"]["SUMMARY_MODEL"]
+summary_model = CONFIG["SUMMARIZER"]["SUMMARY_MODEL"]
if not summary_model:
summary_model = "facebook/bart-large-cnn"
# Summarize the generated transcript using the BART model
logger.info(f"Loading BART model: {summary_model}")
LOGGER.info(f"Loading BART model: {summary_model}")
tokenizer = BartTokenizer.from_pretrained(summary_model)
model = BartForConditionalGeneration.from_pretrained(summary_model)
model = model.to(device)
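
For orientation, a call into this entry point might look like the following. This is a hypothetical invocation: whisper_result stands for the dict returned by the transcription step, the timestamp value is made up, and chunk_summarize="NO" selects the whole-text branch below, since anything other than "YES" skips chunking:

transcript_text = post_process_transcription(whisper_result)
summarize(transcript_text, timestamp="2023-07-26_11-28",
          real_time=False, chunk_summarize="NO")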
@@ -171,7 +205,7 @@ def summarize(transcript_text, timestamp,
output_file = "real_time_" + output_file
if chunk_summarize != "YES":
-max_length = int(config["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"])
+max_length = int(CONFIG["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"])
inputs = tokenizer. \
batch_encode_plus([transcript_text], truncation=True,
padding='longest',
@@ -180,8 +214,8 @@ def summarize(transcript_text, timestamp,
inputs = inputs.to(device)
with torch.no_grad():
-num_beams = int(config["SUMMARIZER"]["BEAM_SIZE"])
-max_length = int(config["SUMMARIZER"]["MAX_LENGTH"])
+num_beams = int(CONFIG["SUMMARIZER"]["BEAM_SIZE"])
+max_length = int(CONFIG["SUMMARIZER"]["MAX_LENGTH"])
summaries = model.generate(inputs['input_ids'],
num_beams=num_beams,
length_penalty=2.0,
@@ -194,16 +228,16 @@ def summarize(transcript_text, timestamp,
clean_up_tokenization_spaces=False)
for summary in summaries]
summary = " ".join(decoded_summaries)
with open("./artefacts/" + output_file, 'w') as f:
f.write(summary.strip() + "\n")
with open("./artefacts/" + output_file, 'w', encoding="utf-8") as file:
file.write(summary.strip() + "\n")
else:
logger.info("Breaking transcript into smaller chunks")
LOGGER.info("Breaking transcript into smaller chunks")
chunks = chunk_text(transcript_text)
logger.info(f"Transcript broken into {len(chunks)} "
LOGGER.info(f"Transcript broken into {len(chunks)} "
f"chunks of at most 500 words")
logger.info(f"Writing summary text to: {output_file}")
LOGGER.info(f"Writing summary text to: {output_file}")
with open(output_file, 'w') as f:
summaries = summarize_chunks(chunks, tokenizer, model)
for summary in summaries: