mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 12:49:06 +00:00
minor refactor
This commit is contained in:
@@ -1,23 +1,27 @@
|
||||
import torch
|
||||
import configparser
|
||||
|
||||
import nltk
|
||||
from transformers import BartTokenizer, BartForConditionalGeneration
|
||||
import torch
|
||||
from loguru import logger
|
||||
from nltk.corpus import stopwords
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from nltk.tokenize import word_tokenize
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from transformers import BartTokenizer, BartForConditionalGeneration
|
||||
|
||||
nltk.download('punkt', quiet=True)
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read('config.ini')
|
||||
|
||||
|
||||
def preprocess_sentence(sentence):
    """
    Normalise a sentence before it is compared with another one
    :param sentence: sentence text
    :return: the lower-cased tokens that are purely alphanumeric and are
             not English stop words, joined by single spaces
    """
    english_stop_words = set(stopwords.words('english'))
    kept_tokens = []
    for word in word_tokenize(sentence.lower()):
        # Drop punctuation / mixed-symbol tokens.
        if not word.isalnum():
            continue
        # Drop stop words.
        if word in english_stop_words:
            continue
        kept_tokens.append(word)
    return ' '.join(kept_tokens)
|
||||
|
||||
|
||||
def compute_similarity(sent1, sent2):
|
||||
"""
|
||||
Compute the similarity
|
||||
@@ -28,6 +32,7 @@ def compute_similarity(sent1, sent2):
|
||||
return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
|
||||
return 0.0
|
||||
|
||||
|
||||
def remove_almost_alike_sentences(sentences, threshold=0.7):
|
||||
num_sentences = len(sentences)
|
||||
removed_indices = set()
|
||||
@@ -55,12 +60,14 @@ def remove_almost_alike_sentences(sentences, threshold=0.7):
|
||||
filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices]
|
||||
return filtered_sentences
|
||||
|
||||
|
||||
def remove_outright_duplicate_sentences_from_chunk(chunk):
    """
    Remove sentences that are exact repeats of an earlier sentence
    :param chunk: mapping whose "text" entry holds the chunk's text
    :return: list of the chunk's sentences with exact duplicates removed,
             keeping the first occurrence of each in original order
    """
    already_seen = set()
    unique_sentences = []
    for sentence in nltk.sent_tokenize(chunk["text"]):
        if sentence in already_seen:
            continue
        already_seen.add(sentence)
        unique_sentences.append(sentence)
    return unique_sentences
|
||||
|
||||
|
||||
def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
|
||||
chunk_sentences = []
|
||||
|
||||
@@ -80,6 +87,7 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
|
||||
chunk_sentences.append(temp_result)
|
||||
return chunk_sentences
|
||||
|
||||
|
||||
def post_process_transcription(whisper_result):
|
||||
transcript_text = ""
|
||||
for chunk in whisper_result["chunks"]:
|
||||
@@ -107,12 +115,13 @@ def summarize_chunks(chunks, tokenizer, model):
|
||||
input_ids = input_ids.to(device)
|
||||
with torch.no_grad():
|
||||
summary_ids = model.generate(input_ids,
|
||||
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
|
||||
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
|
||||
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
|
||||
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
|
||||
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
||||
summaries.append(summary)
|
||||
return summaries
|
||||
|
||||
|
||||
def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
|
||||
"""
|
||||
Split text into smaller chunks.
|
||||
@@ -132,6 +141,7 @@ def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])
|
||||
chunks.append(current_chunk.strip())
|
||||
return chunks
|
||||
|
||||
|
||||
def summarize(transcript_text, timestamp,
|
||||
real_time=False, summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
Reference in New Issue
Block a user