flake8 checks

This commit is contained in:
Gokul Mohanarangan
2023-07-25 10:35:47 +05:30
parent cec8bbcf6c
commit 8be41647fe
6 changed files with 40 additions and 32 deletions

View File

@@ -53,11 +53,11 @@ def get_title_and_summary(llm_input_text, last_timestamp):
prompt = f""" prompt = f"""
### Human: ### Human:
Create a JSON object as response. The JSON object must have 2 fields: Create a JSON object as response. The JSON object must have 2 fields:
i) title and ii) summary. For the title field,generate a short title i) title and ii) summary. For the title field,generate a short title
for the given text. For the summary field, summarize the given text for the given text. For the summary field, summarize the given text
in three sentences. in three sentences.
{llm_input_text} {llm_input_text}
### Assistant: ### Assistant:
@@ -144,12 +144,12 @@ def get_transcription(frames):
result_text = "" result_text = ""
try: try:
segments, _ = model.transcribe(audiofilename, segments, _ = \
language="en", model.transcribe(audiofilename,
beam_size=5, language="en",
vad_filter=True, beam_size=5,
vad_parameters=dict(min_silence_duration_ms=500) vad_filter=True,
) vad_parameters=dict(min_silence_duration_ms=500))
os.remove(audiofilename) os.remove(audiofilename)
segments = list(segments) segments = list(segments)
result_text = "" result_text = ""

View File

@@ -16,8 +16,8 @@ from av import AudioFifo
from sortedcontainers import SortedDict from sortedcontainers import SortedDict
from whisper_jax import FlaxWhisperPipline from whisper_jax import FlaxWhisperPipline
from utils.log_utils import logger from ..utils.log_utils import logger
from utils.run_utils import config, Mutex from ..utils.run_utils import config, Mutex
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_REAL_TIME_MODEL_SIZE"] WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_REAL_TIME_MODEL_SIZE"]
pcs = set() pcs = set()

View File

@@ -21,8 +21,8 @@ from whisper_jax import FlaxWhisperPipline
from ...utils.file_utils import download_files, upload_files from ...utils.file_utils import download_files, upload_files
from ...utils.log_utils import logger from ...utils.log_utils import logger
from ...utils.run_utils import config from ...utils.run_utils import config
from ...utils.text_utilities import post_process_transcription, summarize from ...utils.text_utils import post_process_transcription, summarize
from ...utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud
nltk.download('punkt', quiet=True) nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True) nltk.download('stopwords', quiet=True)

View File

@@ -13,8 +13,8 @@ from whisper_jax import FlaxWhisperPipline
from ...utils.file_utils import upload_files from ...utils.file_utils import upload_files
from ...utils.log_utils import logger from ...utils.log_utils import logger
from ...utils.run_utils import config from ...utils.run_utils import config
from ...utils.text_utilities import post_process_transcription, summarize from ...utils.text_utils import post_process_transcription, summarize
from ...utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"] WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]

View File

@@ -154,7 +154,7 @@ def chunk_text(text,
def summarize(transcript_text, timestamp, def summarize(transcript_text, timestamp,
real_time=False, real_time=False,
summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]): chunk_summarize=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary_model = config["DEFAULT"]["SUMMARY_MODEL"] summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
if not summary_model: if not summary_model:
@@ -166,27 +166,35 @@ def summarize(transcript_text, timestamp,
model = BartForConditionalGeneration.from_pretrained(summary_model) model = BartForConditionalGeneration.from_pretrained(summary_model)
model = model.to(device) model = model.to(device)
output_filename = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" output_file = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
if real_time: if real_time:
output_filename = "real_time_" + output_filename output_file = "real_time_" + output_file
if summarize_using_chunks != "YES": if chunk_summarize != "YES":
max_length = int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"])
inputs = tokenizer. \ inputs = tokenizer. \
batch_encode_plus([transcript_text], truncation=True, batch_encode_plus([transcript_text], truncation=True,
padding='longest', padding='longest',
max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]), max_length=max_length,
return_tensors='pt') return_tensors='pt')
inputs = inputs.to(device) inputs = inputs.to(device)
with torch.no_grad(): with torch.no_grad():
num_beans = int(config["DEFAULT"]["BEAM_SIZE"])
max_length = int(config["DEFAULT"]["MAX_LENGTH"])
summaries = model.generate(inputs['input_ids'], summaries = model.generate(inputs['input_ids'],
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0, num_beams=num_beans,
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True) length_penalty=2.0,
max_length=max_length,
early_stopping=True)
decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) decoded_summaries = \
for summary in summaries] [tokenizer.decode(summary,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
for summary in summaries]
summary = " ".join(decoded_summaries) summary = " ".join(decoded_summaries)
with open("./artefacts/" + output_filename, 'w') as f: with open("./artefacts/" + output_file, 'w') as f:
f.write(summary.strip() + "\n") f.write(summary.strip() + "\n")
else: else:
logger.info("Breaking transcript into smaller chunks") logger.info("Breaking transcript into smaller chunks")
@@ -195,8 +203,8 @@ def summarize(transcript_text, timestamp,
logger.info(f"Transcript broken into {len(chunks)} " logger.info(f"Transcript broken into {len(chunks)} "
f"chunks of at most 500 words") f"chunks of at most 500 words")
logger.info(f"Writing summary text to: {output_filename}") logger.info(f"Writing summary text to: {output_file}")
with open(output_filename, 'w') as f: with open(output_file, 'w') as f:
summaries = summarize_chunks(chunks, tokenizer, model) summaries = summarize_chunks(chunks, tokenizer, model)
for summary in summaries: for summary in summaries:
f.write(summary.strip() + " ") f.write(summary.strip() + " ")

View File

@@ -45,14 +45,14 @@ def create_wordcloud(timestamp, real_time=False):
plt.axis("off") plt.axis("off")
plt.tight_layout(pad=0) plt.tight_layout(pad=0)
wordcloud_name = "wordcloud" wordcloud = "wordcloud"
if real_time: if real_time:
wordcloud_name = "real_time_" + wordcloud_name + "_" + \ wordcloud = "real_time_" + wordcloud + "_" + \
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png" timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
else: else:
wordcloud_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png" wordcloud += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
plt.savefig("./artefacts/" + wordcloud_name) plt.savefig("./artefacts/" + wordcloud)
def create_talk_diff_scatter_viz(timestamp, real_time=False): def create_talk_diff_scatter_viz(timestamp, real_time=False):