server: reformat whole project using black

Mathieu Virbel
2023-07-27 14:08:41 +02:00
parent 314321c603
commit 094ed696c4
12 changed files with 406 additions and 237 deletions
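The commit does not record the exact invocation, but the changes below match black's default settings (88-character lines, double-quote string normalization, magic trailing commas). As an illustrative sketch only, the same effect can be reproduced through black's format_str API; the sample source line is taken from the first hunk below, the rest of the snippet is an assumption, not part of the commit:

    # Illustration only, not part of the commit: run black's defaults on one line.
    import black

    src = "nltk.download('punkt', quiet=True)\n"
    # black.Mode() uses the defaults seen in this diff: 88-char lines and
    # normalization of single quotes to double quotes.
    formatted = black.format_str(src, mode=black.Mode())
    print(formatted, end="")  # prints: nltk.download("punkt", quiet=True)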


@@ -15,7 +15,7 @@ from transformers import BartForConditionalGeneration, BartTokenizer
 from log_utils import LOGGER
 from run_utils import CONFIG
-nltk.download('punkt', quiet=True)
+nltk.download("punkt", quiet=True)
 def preprocess_sentence(sentence: str) -> str:
@@ -24,11 +24,10 @@ def preprocess_sentence(sentence: str) -> str:
     :param sentence:
     :return:
     """
-    stop_words = set(stopwords.words('english'))
+    stop_words = set(stopwords.words("english"))
     tokens = word_tokenize(sentence.lower())
-    tokens = [token for token in tokens
-              if token.isalnum() and token not in stop_words]
-    return ' '.join(tokens)
+    tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
+    return " ".join(tokens)
 def compute_similarity(sent1: str, sent2: str) -> float:
@@ -67,14 +66,14 @@ def remove_almost_alike_sentences(sentences: List[str], threshold=0.7) -> List[s
             sentence1 = preprocess_sentence(sentences[i])
             sentence2 = preprocess_sentence(sentences[j])
             if len(sentence1) != 0 and len(sentence2) != 0:
-                similarity = compute_similarity(sentence1,
-                                                sentence2)
+                similarity = compute_similarity(sentence1, sentence2)
                 if similarity >= threshold:
                     removed_indices.add(max(i, j))
-    filtered_sentences = [sentences[i] for i in range(num_sentences)
-                          if i not in removed_indices]
+    filtered_sentences = [
+        sentences[i] for i in range(num_sentences) if i not in removed_indices
+    ]
     return filtered_sentences
@@ -90,7 +89,9 @@ def remove_outright_duplicate_sentences_from_chunk(chunk: str) -> List[str]:
     return nonduplicate_sentences
-def remove_whisper_repetitive_hallucination(nonduplicate_sentences: List[str]) -> List[str]:
+def remove_whisper_repetitive_hallucination(
+    nonduplicate_sentences: List[str],
+) -> List[str]:
     """
     Remove sentences that are repeated as a result of Whisper
     hallucinations
@@ -105,13 +106,16 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences: List[str]) -
         words = nltk.word_tokenize(sent)
         n_gram_filter = 3
         for i in range(len(words)):
-            if str(words[i:i + n_gram_filter]) in seen and \
-                    seen[str(words[i:i + n_gram_filter])] == \
-                    words[i + 1:i + n_gram_filter + 2]:
+            if (
+                str(words[i : i + n_gram_filter]) in seen
+                and seen[str(words[i : i + n_gram_filter])]
+                == words[i + 1 : i + n_gram_filter + 2]
+            ):
                 pass
             else:
-                seen[str(words[i:i + n_gram_filter])] = \
-                    words[i + 1:i + n_gram_filter + 2]
+                seen[str(words[i : i + n_gram_filter])] = words[
+                    i + 1 : i + n_gram_filter + 2
+                ]
                 temp_result += words[i]
                 temp_result += " "
         chunk_sentences.append(temp_result)
@@ -126,12 +130,11 @@ def post_process_transcription(whisper_result: dict) -> dict:
     """
     transcript_text = ""
     for chunk in whisper_result["chunks"]:
-        nonduplicate_sentences = \
-            remove_outright_duplicate_sentences_from_chunk(chunk)
-        chunk_sentences = \
-            remove_whisper_repetitive_hallucination(nonduplicate_sentences)
-        similarity_matched_sentences = \
-            remove_almost_alike_sentences(chunk_sentences)
+        nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk)
+        chunk_sentences = remove_whisper_repetitive_hallucination(
+            nonduplicate_sentences
+        )
+        similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences)
         chunk["text"] = " ".join(similarity_matched_sentences)
         transcript_text += chunk["text"]
     whisper_result["text"] = transcript_text
@@ -149,23 +152,24 @@ def summarize_chunks(chunks: List[str], tokenizer, model) -> List[str]:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     summaries = []
     for c in chunks:
-        input_ids = tokenizer.encode(c, return_tensors='pt')
+        input_ids = tokenizer.encode(c, return_tensors="pt")
         input_ids = input_ids.to(device)
         with torch.no_grad():
-            summary_ids = \
-                model.generate(input_ids,
-                               num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]),
-                               length_penalty=2.0,
-                               max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]),
-                               early_stopping=True)
-            summary = tokenizer.decode(summary_ids[0],
-                                       skip_special_tokens=True)
+            summary_ids = model.generate(
+                input_ids,
+                num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]),
+                length_penalty=2.0,
+                max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]),
+                early_stopping=True,
+            )
+            summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
         summaries.append(summary)
     return summaries
-def chunk_text(text: str,
-               max_chunk_length: int = int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])) -> List[str]:
+def chunk_text(
+    text: str, max_chunk_length: int = int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])
+) -> List[str]:
     """
     Split text into smaller chunks.
     :param text: Text to be chunked
@@ -185,9 +189,12 @@ def chunk_text(text: str,
     return chunks
-def summarize(transcript_text: str, timestamp: datetime.datetime.timestamp,
-              real_time: bool = False,
-              chunk_summarize: str = CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]):
+def summarize(
+    transcript_text: str,
+    timestamp: datetime.datetime.timestamp,
+    real_time: bool = False,
+    chunk_summarize: str = CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"],
+):
     """
     Summarize the given text either as a whole or as chunks as needed
     :param transcript_text:
@@ -213,39 +220,45 @@ def summarize(transcript_text: str, timestamp: datetime.datetime.timestamp,
     if chunk_summarize != "YES":
         max_length = int(CONFIG["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"])
-        inputs = tokenizer. \
-            batch_encode_plus([transcript_text], truncation=True,
-                              padding='longest',
-                              max_length=max_length,
-                              return_tensors='pt')
+        inputs = tokenizer.batch_encode_plus(
+            [transcript_text],
+            truncation=True,
+            padding="longest",
+            max_length=max_length,
+            return_tensors="pt",
+        )
         inputs = inputs.to(device)
         with torch.no_grad():
             num_beans = int(CONFIG["SUMMARIZER"]["BEAM_SIZE"])
             max_length = int(CONFIG["SUMMARIZER"]["MAX_LENGTH"])
-            summaries = model.generate(inputs['input_ids'],
-                                       num_beams=num_beans,
-                                       length_penalty=2.0,
-                                       max_length=max_length,
-                                       early_stopping=True)
+            summaries = model.generate(
+                inputs["input_ids"],
+                num_beams=num_beans,
+                length_penalty=2.0,
+                max_length=max_length,
+                early_stopping=True,
+            )

-        decoded_summaries = \
-            [tokenizer.decode(summary,
-                              skip_special_tokens=True,
-                              clean_up_tokenization_spaces=False)
-             for summary in summaries]
+        decoded_summaries = [
+            tokenizer.decode(
+                summary, skip_special_tokens=True, clean_up_tokenization_spaces=False
+            )
+            for summary in summaries
+        ]
         summary = " ".join(decoded_summaries)
-        with open("./artefacts/" + output_file, 'w', encoding="utf-8") as file:
+        with open("./artefacts/" + output_file, "w", encoding="utf-8") as file:
             file.write(summary.strip() + "\n")
     else:
         LOGGER.info("Breaking transcript into smaller chunks")
         chunks = chunk_text(transcript_text)
-        LOGGER.info(f"Transcript broken into {len(chunks)} "
-                    f"chunks of at most 500 words")
+        LOGGER.info(
+            f"Transcript broken into {len(chunks)} " f"chunks of at most 500 words"
+        )
         LOGGER.info(f"Writing summary text to: {output_file}")
-        with open(output_file, 'w') as f:
+        with open(output_file, "w") as f:
             summaries = summarize_chunks(chunks, tokenizer, model)
             for summary in summaries:
                 f.write(summary.strip() + " ")