flake8 checks

This commit is contained in:
Gokul Mohanarangan
2023-07-25 10:35:47 +05:30
parent cec8bbcf6c
commit 8be41647fe
6 changed files with 40 additions and 32 deletions

View File

@@ -53,11 +53,11 @@ def get_title_and_summary(llm_input_text, last_timestamp):
prompt = f"""
### Human:
Create a JSON object as response. The JSON object must have 2 fields:
i) title and ii) summary. For the title field,generate a short title
for the given text. For the summary field, summarize the given text
Create a JSON object as response. The JSON object must have 2 fields:
i) title and ii) summary. For the title field,generate a short title
for the given text. For the summary field, summarize the given text
in three sentences.
{llm_input_text}
### Assistant:
@@ -144,12 +144,12 @@ def get_transcription(frames):
result_text = ""
try:
segments, _ = model.transcribe(audiofilename,
language="en",
beam_size=5,
vad_filter=True,
vad_parameters=dict(min_silence_duration_ms=500)
)
segments, _ = \
model.transcribe(audiofilename,
language="en",
beam_size=5,
vad_filter=True,
vad_parameters=dict(min_silence_duration_ms=500))
os.remove(audiofilename)
segments = list(segments)
result_text = ""

View File

@@ -16,8 +16,8 @@ from av import AudioFifo
from sortedcontainers import SortedDict
from whisper_jax import FlaxWhisperPipline
from utils.log_utils import logger
from utils.run_utils import config, Mutex
from ..utils.log_utils import logger
from ..utils.run_utils import config, Mutex
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_REAL_TIME_MODEL_SIZE"]
pcs = set()

View File

@@ -21,8 +21,8 @@ from whisper_jax import FlaxWhisperPipline
from ...utils.file_utils import download_files, upload_files
from ...utils.log_utils import logger
from ...utils.run_utils import config
from ...utils.text_utilities import post_process_transcription, summarize
from ...utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud
from ...utils.text_utils import post_process_transcription, summarize
from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)

View File

@@ -13,8 +13,8 @@ from whisper_jax import FlaxWhisperPipline
from ...utils.file_utils import upload_files
from ...utils.log_utils import logger
from ...utils.run_utils import config
from ...utils.text_utilities import post_process_transcription, summarize
from ...utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud
from ...utils.text_utils import post_process_transcription, summarize
from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]

View File

@@ -154,7 +154,7 @@ def chunk_text(text,
def summarize(transcript_text, timestamp,
real_time=False,
summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
chunk_summarize=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
if not summary_model:
@@ -166,27 +166,35 @@ def summarize(transcript_text, timestamp,
model = BartForConditionalGeneration.from_pretrained(summary_model)
model = model.to(device)
output_filename = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
output_file = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
if real_time:
output_filename = "real_time_" + output_filename
output_file = "real_time_" + output_file
if summarize_using_chunks != "YES":
if chunk_summarize != "YES":
max_length = int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"])
inputs = tokenizer. \
batch_encode_plus([transcript_text], truncation=True,
padding='longest',
max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]),
max_length=max_length,
return_tensors='pt')
inputs = inputs.to(device)
with torch.no_grad():
num_beans = int(config["DEFAULT"]["BEAM_SIZE"])
max_length = int(config["DEFAULT"]["MAX_LENGTH"])
summaries = model.generate(inputs['input_ids'],
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
num_beams=num_beans,
length_penalty=2.0,
max_length=max_length,
early_stopping=True)
decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for summary in summaries]
decoded_summaries = \
[tokenizer.decode(summary,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
for summary in summaries]
summary = " ".join(decoded_summaries)
with open("./artefacts/" + output_filename, 'w') as f:
with open("./artefacts/" + output_file, 'w') as f:
f.write(summary.strip() + "\n")
else:
logger.info("Breaking transcript into smaller chunks")
@@ -195,8 +203,8 @@ def summarize(transcript_text, timestamp,
logger.info(f"Transcript broken into {len(chunks)} "
f"chunks of at most 500 words")
logger.info(f"Writing summary text to: {output_filename}")
with open(output_filename, 'w') as f:
logger.info(f"Writing summary text to: {output_file}")
with open(output_file, 'w') as f:
summaries = summarize_chunks(chunks, tokenizer, model)
for summary in summaries:
f.write(summary.strip() + " ")

View File

@@ -45,14 +45,14 @@ def create_wordcloud(timestamp, real_time=False):
plt.axis("off")
plt.tight_layout(pad=0)
wordcloud_name = "wordcloud"
wordcloud = "wordcloud"
if real_time:
wordcloud_name = "real_time_" + wordcloud_name + "_" + \
wordcloud = "real_time_" + wordcloud + "_" + \
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
else:
wordcloud_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
wordcloud += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
plt.savefig("./artefacts/" + wordcloud_name)
plt.savefig("./artefacts/" + wordcloud)
def create_talk_diff_scatter_viz(timestamp, real_time=False):