flake8 / pylint updates
@@ -1,16 +1,21 @@
+"""
+Utility file for file handling related functions, including file downloads and
+uploads to cloud storage
+"""
+
 import sys
 
 import boto3
 import botocore
 
-from .log_utils import logger
-from .run_utils import config
+from .log_utils import LOGGER
+from .run_utils import CONFIG
 
-BUCKET_NAME = config["AWS"]["BUCKET_NAME"]
+BUCKET_NAME = CONFIG["AWS"]["BUCKET_NAME"]
 
 s3 = boto3.client('s3',
-                  aws_access_key_id=config["AWS"]["AWS_ACCESS_KEY"],
-                  aws_secret_access_key=config["AWS"]["AWS_SECRET_KEY"])
+                  aws_access_key_id=CONFIG["AWS"]["AWS_ACCESS_KEY"],
+                  aws_secret_access_key=CONFIG["AWS"]["AWS_SECRET_KEY"])
 
 
 def upload_files(files_to_upload):
@@ -19,12 +24,12 @@ def upload_files(files_to_upload):
     :param files_to_upload: List of files to upload
     :return: None
     """
-    for KEY in files_to_upload:
-        logger.info("Uploading file " + KEY)
+    for key in files_to_upload:
+        LOGGER.info("Uploading file " + key)
         try:
-            s3.upload_file(KEY, BUCKET_NAME, KEY)
-        except botocore.exceptions.ClientError as e:
-            print(e.response)
+            s3.upload_file(key, BUCKET_NAME, key)
+        except botocore.exceptions.ClientError as exception:
+            print(exception.response)
 
 
 def download_files(files_to_download):
@@ -33,12 +38,12 @@ def download_files(files_to_download):
     :param files_to_download: List of files to download
     :return: None
     """
-    for KEY in files_to_download:
-        logger.info("Downloading file " + KEY)
+    for key in files_to_download:
+        LOGGER.info("Downloading file " + key)
         try:
-            s3.download_file(BUCKET_NAME, KEY, KEY)
-        except botocore.exceptions.ClientError as e:
-            if e.response['Error']['Code'] == "404":
+            s3.download_file(BUCKET_NAME, key, key)
+        except botocore.exceptions.ClientError as exception:
+            if exception.response['Error']['Code'] == "404":
                 print("The object does not exist.")
             else:
                 raise
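The two helpers above wrap boto3's transfer API: each entry in the list is used both as the local filename and as the S3 object key. A rough usage sketch follows; the module path and the file names are placeholders, not files from the project:

# Hypothetical usage of upload_files/download_files; the import path
# and the object keys are assumptions for illustration only.
from utils.file_utils import upload_files, download_files

upload_files(["meeting_summary.txt"])    # uploads each file under its own name as the key
download_files(["meeting_summary.txt"])  # fetches each key to a local file of the same name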
@@ -1,13 +1,24 @@
+"""
+Utility function to format the artefacts created during Reflector run
+"""
+
 import json
 
-with open("../artefacts/meeting_titles_and_summaries.txt", "r") as f:
+with open("../artefacts/meeting_titles_and_summaries.txt", "r",
+          encoding='utf-8') as f:
     outputs = f.read()
 
 outputs = json.loads(outputs)
 
-transcript_file = open("../artefacts/meeting_transcript.txt", "a")
-title_desc_file = open("../artefacts/meeting_title_description.txt", "a")
-summary_file = open("../artefacts/meeting_summary.txt", "a")
+transcript_file = open("../artefacts/meeting_transcript.txt",
+                       "a",
+                       encoding='utf-8')
+title_desc_file = open("../artefacts/meeting_title_description.txt",
+                       "a",
+                       encoding='utf-8')
+summary_file = open("../artefacts/meeting_summary.txt",
+                    "a",
+                    encoding='utf-8')
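Note the commit adds encodings but keeps the bare open() calls, which pylint's consider-using-with would still flag. A sketch of the same logic under an ExitStack, assuming the artefact JSON keeps the "topics"/"transcript" layout shown in this script:

import contextlib
import json

with open("../artefacts/meeting_titles_and_summaries.txt", "r",
          encoding="utf-8") as f:
    outputs = json.load(f)

with contextlib.ExitStack() as stack:
    # ExitStack closes every registered file on exit, even on error
    transcript_file = stack.enter_context(
        open("../artefacts/meeting_transcript.txt", "a", encoding="utf-8"))
    for item in outputs["topics"]:
        transcript_file.write(item["transcript"])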
@@ -1,7 +1,15 @@
+"""
+Utility file for logging
+"""
+
 import loguru
 
 
 class SingletonLogger:
+    """
+    Use Singleton design pattern to create a logger object and share it
+    across the entire project
+    """
     __instance = None
 
     @staticmethod
@@ -15,4 +23,4 @@ class SingletonLogger:
         return SingletonLogger.__instance
 
 
-logger = SingletonLogger.get_logger()
+LOGGER = SingletonLogger.get_logger()
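The diff elides get_logger's body. A minimal sketch of how such a singleton accessor is typically written around loguru; the body here is an assumption, not the project's code:

import loguru


class SingletonLogger:
    """Share one loguru logger across the project."""
    __instance = None

    @staticmethod
    def get_logger():
        # Create the shared instance on first use, then reuse it
        if SingletonLogger.__instance is None:
            SingletonLogger.__instance = loguru.logger
        return SingletonLogger.__instance


LOGGER = SingletonLogger.get_logger()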
@@ -1,3 +1,7 @@
+"""
+Utility file for server side asynchronous task running and config objects
+"""
+
 import asyncio
 import configparser
 import contextlib
@@ -7,6 +11,9 @@ from typing import ContextManager, Generic, TypeVar
 
 
 class ReflectorConfig:
+    """
+    Create a single config object to share across the project
+    """
     __config = None
 
     @staticmethod
@@ -17,7 +24,7 @@ class ReflectorConfig:
         return ReflectorConfig.__config
 
 
-config = ReflectorConfig.get_config()
+CONFIG = ReflectorConfig.get_config()
 
 
 def run_in_executor(func, *args, executor=None, **kwargs):
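Like SingletonLogger, ReflectorConfig presumably builds a configparser.ConfigParser lazily and hands back the same object from get_config. run_in_executor's body is also outside the hunk; a plausible implementation for that signature, using functools.partial so keyword arguments survive the trip into the executor (a sketch, not the project's actual body):

import asyncio
import functools


def run_in_executor(func, *args, executor=None, **kwargs):
    # Bind args/kwargs now; loop.run_in_executor only forwards positionals
    loop = asyncio.get_event_loop()
    return loop.run_in_executor(executor,
                                functools.partial(func, *args, **kwargs))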
@@ -1,3 +1,7 @@
+"""
+Utility file for all text processing related functionalities
+"""
+
 import nltk
 import torch
 from nltk.corpus import stopwords
@@ -6,8 +10,8 @@ from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
 from transformers import BartForConditionalGeneration, BartTokenizer
 
-from log_utils import logger
-from run_utils import config
+from log_utils import LOGGER
+from run_utils import CONFIG
 
 nltk.download('punkt', quiet=True)
 
@@ -32,6 +36,12 @@ def compute_similarity(sent1, sent2):
 
 
 def remove_almost_alike_sentences(sentences, threshold=0.7):
+    """
+    Filter sentences that are similar beyond a set threshold
+    :param sentences:
+    :param threshold:
+    :return:
+    """
     num_sentences = len(sentences)
     removed_indices = set()
 
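compute_similarity itself sits outside the hunk; given the TfidfVectorizer and cosine_similarity imports above, it plausibly reduces to something like this sketch:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


def compute_similarity(sent1, sent2):
    # Fit TF-IDF on the two sentences and compare the resulting vectors
    tfidf = TfidfVectorizer().fit_transform([sent1, sent2])
    return cosine_similarity(tfidf[0:1], tfidf[1:2])[0][0]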
@@ -62,6 +72,11 @@ def remove_almost_alike_sentences(sentences, threshold=0.7):
 
 
 def remove_outright_duplicate_sentences_from_chunk(chunk):
+    """
+    Remove repetitive sentences
+    :param chunk:
+    :return:
+    """
     chunk_text = chunk["text"]
     sentences = nltk.sent_tokenize(chunk_text)
     nonduplicate_sentences = list(dict.fromkeys(sentences))
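The dict.fromkeys trick relies on dicts preserving insertion order (guaranteed since Python 3.7), so it deduplicates without reshuffling the transcript:

sentences = ["Hello.", "How are you?", "Hello.", "Bye."]
print(list(dict.fromkeys(sentences)))
# ['Hello.', 'How are you?', 'Bye.']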
@@ -69,6 +84,12 @@ def remove_outright_duplicate_sentences_from_chunk(chunk):
 
 
 def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
+    """
+    Remove sentences that are repeated as a result of Whisper
+    hallucinations
+    :param nonduplicate_sentences:
+    :return:
+    """
     chunk_sentences = []
 
     for sent in nonduplicate_sentences:
@@ -91,6 +112,11 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
 
 
 def post_process_transcription(whisper_result):
+    """
+    Parent function to perform post-processing on the transcription result
+    :param whisper_result:
+    :return:
+    """
     transcript_text = ""
     for chunk in whisper_result["chunks"]:
         nonduplicate_sentences = \
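The rest of post_process_transcription is cut off at the line continuation; from the helpers above, the per-chunk flow is presumably dedupe, then hallucination filtering, then reassembly, along these lines (a sketch, with the joining details assumed):

def post_process_transcription(whisper_result):
    # Sketch: run each chunk through the two cleaning passes above
    transcript_text = ""
    for chunk in whisper_result["chunks"]:
        nonduplicate_sentences = \
            remove_outright_duplicate_sentences_from_chunk(chunk)
        chunk_sentences = \
            remove_whisper_repetitive_hallucination(nonduplicate_sentences)
        transcript_text += " ".join(chunk_sentences)
    return transcript_text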
@@ -121,9 +147,9 @@ def summarize_chunks(chunks, tokenizer, model):
         with torch.no_grad():
             summary_ids = \
                 model.generate(input_ids,
-                               num_beams=int(config["SUMMARIZER"]["BEAM_SIZE"]),
+                               num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]),
                                length_penalty=2.0,
-                               max_length=int(config["SUMMARIZER"]["MAX_LENGTH"]),
+                               max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]),
                                early_stopping=True)
         summary = tokenizer.decode(summary_ids[0],
                                    skip_special_tokens=True)
@@ -132,7 +158,7 @@ def summarize_chunks(chunks, tokenizer, model):
 
 
 def chunk_text(text,
-               max_chunk_length=int(config["SUMMARIZER"]["MAX_CHUNK_LENGTH"])):
+               max_chunk_length=int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])):
     """
     Split text into smaller chunks.
     :param text: Text to be chunked
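One subtlety the rename preserves: the CONFIG lookup in chunk_text's default is evaluated once, at import time, not per call, so later changes to CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"] would not be picked up. The usual workaround is a None default, as in this sketch (CONFIG is the module-level object already imported above):

def chunk_text(text, max_chunk_length=None):
    # Re-read the config on every call instead of freezing it at import
    if max_chunk_length is None:
        max_chunk_length = int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])
    ...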
@@ -154,14 +180,22 @@ def chunk_text(text,
 
 def summarize(transcript_text, timestamp,
               real_time=False,
-              chunk_summarize=config["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]):
+              chunk_summarize=CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]):
+    """
+    Summarize the given text either as a whole or as chunks as needed
+    :param transcript_text:
+    :param timestamp:
+    :param real_time:
+    :param chunk_summarize:
+    :return:
+    """
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    summary_model = config["SUMMARIZER"]["SUMMARY_MODEL"]
+    summary_model = CONFIG["SUMMARIZER"]["SUMMARY_MODEL"]
     if not summary_model:
         summary_model = "facebook/bart-large-cnn"
 
     # Summarize the generated transcript using the BART model
-    logger.info(f"Loading BART model: {summary_model}")
+    LOGGER.info(f"Loading BART model: {summary_model}")
     tokenizer = BartTokenizer.from_pretrained(summary_model)
     model = BartForConditionalGeneration.from_pretrained(summary_model)
     model = model.to(device)
@@ -171,7 +205,7 @@ def summarize(transcript_text, timestamp,
         output_file = "real_time_" + output_file
 
     if chunk_summarize != "YES":
-        max_length = int(config["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"])
+        max_length = int(CONFIG["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"])
         inputs = tokenizer. \
             batch_encode_plus([transcript_text], truncation=True,
                               padding='longest',
@@ -180,8 +214,8 @@ def summarize(transcript_text, timestamp,
         inputs = inputs.to(device)
 
         with torch.no_grad():
-            num_beans = int(config["SUMMARIZER"]["BEAM_SIZE"])
-            max_length = int(config["SUMMARIZER"]["MAX_LENGTH"])
+            num_beans = int(CONFIG["SUMMARIZER"]["BEAM_SIZE"])
+            max_length = int(CONFIG["SUMMARIZER"]["MAX_LENGTH"])
             summaries = model.generate(inputs['input_ids'],
                                        num_beams=num_beans,
                                        length_penalty=2.0,
@@ -194,16 +228,16 @@ def summarize(transcript_text, timestamp,
                                               clean_up_tokenization_spaces=False)
                              for summary in summaries]
         summary = " ".join(decoded_summaries)
-        with open("./artefacts/" + output_file, 'w') as f:
-            f.write(summary.strip() + "\n")
+        with open("./artefacts/" + output_file, 'w', encoding="utf-8") as file:
+            file.write(summary.strip() + "\n")
     else:
-        logger.info("Breaking transcript into smaller chunks")
+        LOGGER.info("Breaking transcript into smaller chunks")
         chunks = chunk_text(transcript_text)
 
-        logger.info(f"Transcript broken into {len(chunks)} "
+        LOGGER.info(f"Transcript broken into {len(chunks)} "
                     f"chunks of at most 500 words")
 
-        logger.info(f"Writing summary text to: {output_file}")
+        LOGGER.info(f"Writing summary text to: {output_file}")
         with open(output_file, 'w') as f:
             summaries = summarize_chunks(chunks, tokenizer, model)
             for summary in summaries:
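For orientation, a hedged example of how the rest of the pipeline appears to call summarize; the transcript string is a placeholder, and the timestamp feeds the output filename via strftime:

from datetime import datetime

transcript = "Placeholder transcript text gathered from Whisper output."
summarize(transcript, datetime.now(), real_time=False)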
@@ -1,3 +1,7 @@
+"""
+Utility file for all visualization related functions
+"""
+
 import ast
 import collections
 import os
@@ -81,8 +85,8 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
     else:
         filename = "./artefacts/transcript_with_timestamp_" + \
                    timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
-    with open(filename) as f:
-        transcription_timestamp_text = f.read()
+    with open(filename) as file:
+        transcription_timestamp_text = file.read()
 
     res = ast.literal_eval(transcription_timestamp_text)
     chunks = res["chunks"]
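The artefact here is a Python-literal dump rather than JSON, hence ast.literal_eval. A toy round-trip showing the shape the function expects; any keys beyond "chunks" are assumptions:

import ast

raw = "{'chunks': [{'text': 'hello world', 'timestamp': (0.0, 1.2)}]}"
res = ast.literal_eval(raw)   # safely parses literals, unlike eval()
print(res["chunks"][0]["text"])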