mirror of https://github.com/Monadical-SAS/reflector.git
flake8 warnings fix
@@ -15,7 +15,8 @@ nltk.download('punkt', quiet=True)
 def preprocess_sentence(sentence):
     stop_words = set(stopwords.words('english'))
     tokens = word_tokenize(sentence.lower())
-    tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
+    tokens = [token for token in tokens
+              if token.isalnum() and token not in stop_words]
     return ' '.join(tokens)
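For reference, the wrapped helper behaves like this minimal sketch in isolation (imports assumed to match the module's own nltk usage; the example output is illustrative):

import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)


def preprocess_sentence(sentence):
    # Lowercase, tokenize, then keep only alphanumeric non-stopword tokens.
    stop_words = set(stopwords.words('english'))
    tokens = word_tokenize(sentence.lower())
    tokens = [token for token in tokens
              if token.isalnum() and token not in stop_words]
    return ' '.join(tokens)


print(preprocess_sentence("The meeting was, in fact, about the budget."))
# prints: meeting fact budget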
@@ -49,12 +50,14 @@ def remove_almost_alike_sentences(sentences, threshold=0.7):
             sentence1 = preprocess_sentence(sentences[i])
             sentence2 = preprocess_sentence(sentences[j])
             if len(sentence1) != 0 and len(sentence2) != 0:
-                similarity = compute_similarity(sentence1, sentence2)
+                similarity = compute_similarity(sentence1,
+                                                sentence2)

                 if similarity >= threshold:
                     removed_indices.add(max(i, j))

-    filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices]
+    filtered_sentences = [sentences[i] for i in range(num_sentences)
+                          if i not in removed_indices]
     return filtered_sentences
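compute_similarity is called here but not defined in this diff; a plausible stand-in for scoring near-duplicate preprocessed sentences is TF-IDF cosine similarity (a hypothetical sketch, not the project's actual implementation):

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


def compute_similarity(sentence1, sentence2):
    # Hypothetical stand-in: vectorize both sentences together so they
    # share a vocabulary, then compare the two TF-IDF rows.
    tfidf = TfidfVectorizer().fit_transform([sentence1, sentence2])
    return cosine_similarity(tfidf[0], tfidf[1])[0][0]

Note that removed_indices.add(max(i, j)) always drops the later of two near-duplicate sentences, so the first occurrence survives.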
@@ -74,11 +77,13 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
         words = nltk.word_tokenize(sent)
         n_gram_filter = 3
         for i in range(len(words)):
-            if str(words[i:i + n_gram_filter]) in seen and seen[str(words[i:i + n_gram_filter])] == words[
-                    i + 1:i + n_gram_filter + 2]:
+            if str(words[i:i + n_gram_filter]) in seen and \
+                    seen[str(words[i:i + n_gram_filter])] == \
+                    words[i + 1:i + n_gram_filter + 2]:
                 pass
             else:
-                seen[str(words[i:i + n_gram_filter])] = words[i + 1:i + n_gram_filter + 2]
+                seen[str(words[i:i + n_gram_filter])] = \
+                    words[i + 1:i + n_gram_filter + 2]
                 temp_result += words[i]
                 temp_result += " "
         chunk_sentences.append(temp_result)
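The wrapped condition implements a sliding-window filter: each stringified 3-gram is remembered together with the window that follows it, and a word is dropped when its 3-gram recurs with an identical continuation. A self-contained reconstruction (the setup of seen and temp_result lies outside this hunk, so that framing is an assumption):

import nltk

nltk.download('punkt', quiet=True)


def dedupe_ngrams(sent, n_gram_filter=3):
    seen = {}
    temp_result = ""
    words = nltk.word_tokenize(sent)
    for i in range(len(words)):
        key = str(words[i:i + n_gram_filter])
        continuation = words[i + 1:i + n_gram_filter + 2]
        if key in seen and seen[key] == continuation:
            pass  # repeated n-gram with the same continuation: skip word
        else:
            seen[key] = continuation
            temp_result += words[i]
            temp_result += " "
    return temp_result


print(dedupe_ngrams("thank you " * 10))
# the ten repetitions collapse to "thank you thank you thank you "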
@@ -88,9 +93,12 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
 def post_process_transcription(whisper_result):
     transcript_text = ""
     for chunk in whisper_result["chunks"]:
-        nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk)
-        chunk_sentences = remove_whisper_repetitive_hallucination(nonduplicate_sentences)
-        similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences)
+        nonduplicate_sentences = \
+            remove_outright_duplicate_sentences_from_chunk(chunk)
+        chunk_sentences = \
+            remove_whisper_repetitive_hallucination(nonduplicate_sentences)
+        similarity_matched_sentences = \
+            remove_almost_alike_sentences(chunk_sentences)
         chunk["text"] = " ".join(similarity_matched_sentences)
         transcript_text += chunk["text"]
     whisper_result["text"] = transcript_text
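From this function, whisper_result is evidently a dict with a "chunks" list whose entries carry a "text" field; a toy invocation (the shape is inferred from this hunk rather than documented):

whisper_result = {
    "chunks": [
        {"text": "Hello everyone. Hello everyone. Welcome to the call."},
        {"text": "Thanks for joining. Thanks so much for joining."},
    ],
}
post_process_transcription(whisper_result)
print(whisper_result["text"])  # cleaned, deduplicated transcript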
@@ -111,18 +119,23 @@ def summarize_chunks(chunks, tokenizer, model):
         input_ids = tokenizer.encode(c, return_tensors='pt')
         input_ids = input_ids.to(device)
         with torch.no_grad():
-            summary_ids = model.generate(input_ids,
-                                         num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
-                                         max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
-        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+            summary_ids = \
+                model.generate(input_ids,
+                               num_beams=int(config["DEFAULT"]["BEAM_SIZE"]),
+                               length_penalty=2.0,
+                               max_length=int(config["DEFAULT"]["MAX_LENGTH"]),
+                               early_stopping=True)
+        summary = tokenizer.decode(summary_ids[0],
+                                   skip_special_tokens=True)
         summaries.append(summary)
     return summaries


-def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
+def chunk_text(text,
+               max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
     """
     Split text into smaller chunks.
-    :param txt: Text to be chunked
+    :param text: Text to be chunked
     :param max_chunk_length: length of chunk
     :return: chunked texts
     """
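The config keys consumed across these functions (BEAM_SIZE, MAX_LENGTH, MAX_CHUNK_LENGTH, and so on) imply an ini layout roughly like the one below; the values here are illustrative guesses, not the project's shipped defaults:

import configparser

config = configparser.ConfigParser()
config.read_string("""
[DEFAULT]
SUMMARY_MODEL = facebook/bart-large-cnn
BEAM_SIZE = 4
MAX_LENGTH = 150
MAX_CHUNK_LENGTH = 500
INPUT_ENCODING_MAX_LENGTH = 1024
SUMMARIZE_USING_CHUNKS = YES
""")
print(int(config["DEFAULT"]["BEAM_SIZE"]))  # 4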
@@ -140,7 +153,8 @@ def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])


 def summarize(transcript_text, timestamp,
-              real_time=False, summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
+              real_time=False,
+              summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
     if not summary_model:
@@ -157,9 +171,11 @@ def summarize(transcript_text, timestamp,
         output_filename = "real_time_" + output_filename

     if summarize_using_chunks != "YES":
-        inputs = tokenizer.batch_encode_plus([transcript_text], truncation=True, padding='longest',
-                                             max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]),
-                                             return_tensors='pt')
+        inputs = tokenizer.\
+            batch_encode_plus([transcript_text], truncation=True,
+                              padding='longest',
+                              max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]),
+                              return_tensors='pt')
         inputs = inputs.to(device)

         with torch.no_grad():
@@ -167,8 +183,8 @@ def summarize(transcript_text, timestamp,
                                        num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
                                        max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)

-        decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) for
-                             summary in summaries]
+        decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+                             for summary in summaries]
         summary = " ".join(decoded_summaries)
         with open(output_filename, 'w') as f:
             f.write(summary.strip() + "\n")
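The non-chunked path boils down to a standard encode -> generate -> decode round trip with Hugging Face transformers; a minimal sketch under an assumed model name and parameters (the real values come from the config):

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-cnn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-cnn").to(device)

# Encode the transcript, generate with beam search, decode the summary.
inputs = tokenizer.batch_encode_plus(["Some long transcript text ..."],
                                     truncation=True, padding='longest',
                                     max_length=1024, return_tensors='pt')
inputs = inputs.to(device)
with torch.no_grad():
    summaries = model.generate(inputs["input_ids"], num_beams=4,
                               length_penalty=2.0, max_length=150,
                               early_stopping=True)
print(tokenizer.decode(summaries[0], skip_special_tokens=True))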
@@ -176,7 +192,8 @@ def summarize(transcript_text, timestamp,
         logger.info("Breaking transcript into smaller chunks")
         chunks = chunk_text(transcript_text)

-        logger.info(f"Transcript broken into {len(chunks)} chunks of at most 500 words")  # TODO fix variable
+        logger.info(f"Transcript broken into {len(chunks)} "
+                    f"chunks of at most 500 words")

         logger.info(f"Writing summary text to: {output_filename}")
         with open(output_filename, 'w') as f: