minor refactor

Gokul Mohanarangan
2023-07-10 22:48:22 +05:30
parent 73c4270764
commit 3128813ca3
8 changed files with 82 additions and 85 deletions


@@ -1,23 +1,27 @@
-import torch
import configparser
import nltk
-from transformers import BartTokenizer, BartForConditionalGeneration
+import torch
from loguru import logger
from nltk.corpus import stopwords
-from sklearn.feature_extraction.text import TfidfVectorizer
from nltk.tokenize import word_tokenize
+from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
+from transformers import BartTokenizer, BartForConditionalGeneration
nltk.download('punkt', quiet=True)
config = configparser.ConfigParser()
config.read('config.ini')

def preprocess_sentence(sentence):
    stop_words = set(stopwords.words('english'))
    tokens = word_tokenize(sentence.lower())
    tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
    return ' '.join(tokens)

def compute_similarity(sent1, sent2):
    """
    Compute the similarity
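Side note, not part of this commit: preprocess_sentence lowercases, tokenizes, and strips stopwords and punctuation before any similarity scoring. A minimal sketch of its behavior, assuming the NLTK stopwords corpus is available alongside the punkt download above:

nltk.download('stopwords', quiet=True)  # sketch only, mirroring the punkt download
print(preprocess_sentence("The meeting was rescheduled to Monday!"))
# prints: meeting rescheduled monday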
@@ -28,6 +32,7 @@ def compute_similarity(sent1, sent2):
        return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
    return 0.0

def remove_almost_alike_sentences(sentences, threshold=0.7):
    num_sentences = len(sentences)
    removed_indices = set()
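For reference, outside the diff: the score comes from fitting one TfidfVectorizer over both preprocessed sentences so they share a vocabulary, then taking the cosine of the two row vectors. A standalone sketch:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Sketch: one vectorizer fitted on both sentences, then cosine of the two rows.
tfidf_matrix = TfidfVectorizer().fit_transform(["cats sit on mats", "cats sit on rugs"])
score = cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]  # roughly 0.6 here

Pairs scoring above the threshold (0.7 by default) get marked in removed_indices.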
@@ -55,12 +60,14 @@ def remove_almost_alike_sentences(sentences, threshold=0.7):
    filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices]
    return filtered_sentences

def remove_outright_duplicate_sentences_from_chunk(chunk):
    chunk_text = chunk["text"]
    sentences = nltk.sent_tokenize(chunk_text)
    nonduplicate_sentences = list(dict.fromkeys(sentences))
    return nonduplicate_sentences

def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
    chunk_sentences = []
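Aside, not from the diff: list(dict.fromkeys(sentences)) is the order-preserving dedup idiom, since dicts keep insertion order in Python 3.7+.

sentences = ["a b.", "c d.", "a b."]
print(list(dict.fromkeys(sentences)))  # ['a b.', 'c d.'], first occurrences kept in order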
@@ -80,6 +87,7 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
        chunk_sentences.append(temp_result)
    return chunk_sentences

def post_process_transcription(whisper_result):
    transcript_text = ""
    for chunk in whisper_result["chunks"]:
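Only the tail of remove_whisper_repetitive_hallucination is visible in this hunk, so the exact algorithm isn't shown. A common way to handle Whisper's repetition artifacts, offered purely as an assumed sketch (collapse_repeats is a hypothetical name, not from this repo), is to collapse phrases repeated back to back:

import re

def collapse_repeats(sentence, max_ngram=4):
    # Hypothetical sketch: turn "go to the go to the go to the store"
    # into "go to the store" by collapsing consecutive repeated n-grams,
    # longest n-grams first.
    for n in range(max_ngram, 0, -1):
        pattern = re.compile(r'\b((?:\w+\s+){%d}\w+)(?:\s+\1)+\b' % (n - 1))
        sentence = pattern.sub(r'\1', sentence)
    return sentence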
@@ -107,12 +115,13 @@ def summarize_chunks(chunks, tokenizer, model):
        input_ids = input_ids.to(device)
        with torch.no_grad():
            summary_ids = model.generate(input_ids,
-               num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
-               max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
+                num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
+                max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        summaries.append(summary)
    return summaries

def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
    """
    Split text into smaller chunks.
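For context, not part of the commit: the generate call above is standard beam-search summarization with BART. A self-contained sketch with literal values standing in for the config["DEFAULT"] lookups, and facebook/bart-large-cnn assumed as the checkpoint (the actual model isn't named in this diff):

import torch
from transformers import BartTokenizer, BartForConditionalGeneration

# Sketch only; 4 and 142 are placeholders for BEAM_SIZE and MAX_LENGTH.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
inputs = tokenizer("some long transcript chunk ...", return_tensors="pt", truncation=True)
with torch.no_grad():
    summary_ids = model.generate(inputs.input_ids, num_beams=4, length_penalty=2.0,
                                 max_length=142, early_stopping=True)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))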
@@ -132,6 +141,7 @@ def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])
        chunks.append(current_chunk.strip())
    return chunks

def summarize(transcript_text, timestamp,
              real_time=False, summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
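Across these hunks the code reads BEAM_SIZE, MAX_LENGTH, MAX_CHUNK_LENGTH, and SUMMARIZE_USING_CHUNKS from config.ini, so the file presumably looks something like the following (key names are taken from the diff; the values are invented placeholders):

[DEFAULT]
BEAM_SIZE = 4
MAX_LENGTH = 142
MAX_CHUNK_LENGTH = 1024
SUMMARIZE_USING_CHUNKS = true

Note that SUMMARIZE_USING_CHUNKS is passed through as a raw string, while the numeric keys are wrapped in int() at the call sites.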