flake8 checks

Gokul Mohanarangan
2023-07-25 10:35:47 +05:30
parent cec8bbcf6c
commit 8be41647fe
6 changed files with 40 additions and 32 deletions


@@ -154,7 +154,7 @@ def chunk_text(text,
 def summarize(transcript_text, timestamp,
               real_time=False,
-              summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
+              chunk_summarize=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
     if not summary_model:
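For reference, the keyword default above is resolved through Python's configparser when the module is loaded. A minimal sketch of the keys this commit touches; the key names match the diff, but every value below is an illustrative assumption, not the repo's real config.ini:

import configparser

config = configparser.ConfigParser()
config.read_string("""
[DEFAULT]
SUMMARIZE_USING_CHUNKS = YES
SUMMARY_MODEL = facebook/bart-large-cnn
INPUT_ENCODING_MAX_LENGTH = 1024
BEAM_SIZE = 4
MAX_LENGTH = 256
""")

# Mirrors the renamed keyword default: chunk_summarize picks up "YES" or "NO"
print(config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"])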
@@ -166,27 +166,35 @@ def summarize(transcript_text, timestamp,
     model = BartForConditionalGeneration.from_pretrained(summary_model)
     model = model.to(device)
-    output_filename = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
+    output_file = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
     if real_time:
-        output_filename = "real_time_" + output_filename
-    if summarize_using_chunks != "YES":
+        output_file = "real_time_" + output_file
+    if chunk_summarize != "YES":
+        max_length = int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"])
         inputs = tokenizer. \
             batch_encode_plus([transcript_text], truncation=True,
                               padding='longest',
-                              max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]),
+                              max_length=max_length,
                               return_tensors='pt')
         inputs = inputs.to(device)
         with torch.no_grad():
+            num_beams = int(config["DEFAULT"]["BEAM_SIZE"])
+            max_length = int(config["DEFAULT"]["MAX_LENGTH"])
             summaries = model.generate(inputs['input_ids'],
-                                       num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
-                                       max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
-            decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-                                 for summary in summaries]
+                                       num_beams=num_beams,
+                                       length_penalty=2.0,
+                                       max_length=max_length,
+                                       early_stopping=True)
+            decoded_summaries = \
+                [tokenizer.decode(summary,
+                                  skip_special_tokens=True,
+                                  clean_up_tokenization_spaces=False)
+                 for summary in summaries]
             summary = " ".join(decoded_summaries)
-            with open("./artefacts/" + output_filename, 'w') as f:
+            with open("./artefacts/" + output_file, 'w') as f:
                 f.write(summary.strip() + "\n")
     else:
         logger.info("Breaking transcript into smaller chunks")
@@ -195,8 +203,8 @@ def summarize(transcript_text, timestamp,
         logger.info(f"Transcript broken into {len(chunks)} "
                     f"chunks of at most 500 words")
-        logger.info(f"Writing summary text to: {output_filename}")
-        with open(output_filename, 'w') as f:
+        logger.info(f"Writing summary text to: {output_file}")
+        with open(output_file, 'w') as f:
             summaries = summarize_chunks(chunks, tokenizer, model)
             for summary in summaries:
                 f.write(summary.strip() + " ")
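The chunked branch relies on chunk_text, whose body sits outside this diff; given the log line promising chunks "of at most 500 words", a plausible word-based sketch (the real signature takes more parameters, per the hunk header above):

# Hypothetical reconstruction of chunk_text; only the 500-word cap is
# implied by the log message, the rest is an assumption.
def chunk_text(text, max_words=500):
    words = text.split()
    return [" ".join(words[i:i + max_words])
            for i in range(0, len(words), max_words)]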


@@ -45,14 +45,14 @@ def create_wordcloud(timestamp, real_time=False):
     plt.axis("off")
     plt.tight_layout(pad=0)
-    wordcloud_name = "wordcloud"
+    wordcloud = "wordcloud"
     if real_time:
-        wordcloud_name = "real_time_" + wordcloud_name + "_" + \
+        wordcloud = "real_time_" + wordcloud + "_" + \
             timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
     else:
-        wordcloud_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
-    plt.savefig("./artefacts/" + wordcloud_name)
+        wordcloud += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
+    plt.savefig("./artefacts/" + wordcloud)


 def create_talk_diff_scatter_viz(timestamp, real_time=False):
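
For context around the rename, create_wordcloud follows the usual wordcloud-plus-matplotlib pattern of generating an image and saving it through the pyplot figure. A minimal sketch under that assumption; the helper name, figure size, and generator options below are illustrative, not from this repo:

import matplotlib.pyplot as plt
from wordcloud import WordCloud

def save_wordcloud(text, path):
    # Render the word frequencies of `text` and save the figure to `path`.
    wc = WordCloud(width=800, height=400).generate(text)
    plt.figure(figsize=(8, 4))
    plt.imshow(wc, interpolation="bilinear")
    plt.axis("off")
    plt.tight_layout(pad=0)
    plt.savefig(path)

# Example usage, mirroring the artefacts path used in the diff:
# save_wordcloud(transcript_text, "./artefacts/wordcloud.png")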