flake8 checks

This commit is contained in:
Gokul Mohanarangan
2023-07-25 10:35:47 +05:30
parent cec8bbcf6c
commit 8be41647fe
6 changed files with 40 additions and 32 deletions

View File

@@ -144,12 +144,12 @@ def get_transcription(frames):
result_text = "" result_text = ""
try: try:
segments, _ = model.transcribe(audiofilename, segments, _ = \
model.transcribe(audiofilename,
language="en", language="en",
beam_size=5, beam_size=5,
vad_filter=True, vad_filter=True,
vad_parameters=dict(min_silence_duration_ms=500) vad_parameters=dict(min_silence_duration_ms=500))
)
os.remove(audiofilename) os.remove(audiofilename)
segments = list(segments) segments = list(segments)
result_text = "" result_text = ""

View File

@@ -16,8 +16,8 @@ from av import AudioFifo
from sortedcontainers import SortedDict from sortedcontainers import SortedDict
from whisper_jax import FlaxWhisperPipline from whisper_jax import FlaxWhisperPipline
from utils.log_utils import logger from ..utils.log_utils import logger
from utils.run_utils import config, Mutex from ..utils.run_utils import config, Mutex
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_REAL_TIME_MODEL_SIZE"] WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_REAL_TIME_MODEL_SIZE"]
pcs = set() pcs = set()

View File

@@ -21,8 +21,8 @@ from whisper_jax import FlaxWhisperPipline
from ...utils.file_utils import download_files, upload_files from ...utils.file_utils import download_files, upload_files
from ...utils.log_utils import logger from ...utils.log_utils import logger
from ...utils.run_utils import config from ...utils.run_utils import config
from ...utils.text_utilities import post_process_transcription, summarize from ...utils.text_utils import post_process_transcription, summarize
from ...utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud
nltk.download('punkt', quiet=True) nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True) nltk.download('stopwords', quiet=True)

View File

@@ -13,8 +13,8 @@ from whisper_jax import FlaxWhisperPipline
from ...utils.file_utils import upload_files from ...utils.file_utils import upload_files
from ...utils.log_utils import logger from ...utils.log_utils import logger
from ...utils.run_utils import config from ...utils.run_utils import config
from ...utils.text_utilities import post_process_transcription, summarize from ...utils.text_utils import post_process_transcription, summarize
from ...utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"] WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]

View File

@@ -154,7 +154,7 @@ def chunk_text(text,
def summarize(transcript_text, timestamp, def summarize(transcript_text, timestamp,
real_time=False, real_time=False,
summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]): chunk_summarize=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary_model = config["DEFAULT"]["SUMMARY_MODEL"] summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
if not summary_model: if not summary_model:
@@ -166,27 +166,35 @@ def summarize(transcript_text, timestamp,
model = BartForConditionalGeneration.from_pretrained(summary_model) model = BartForConditionalGeneration.from_pretrained(summary_model)
model = model.to(device) model = model.to(device)
output_filename = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" output_file = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
if real_time: if real_time:
output_filename = "real_time_" + output_filename output_file = "real_time_" + output_file
if summarize_using_chunks != "YES": if chunk_summarize != "YES":
max_length = int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"])
inputs = tokenizer. \ inputs = tokenizer. \
batch_encode_plus([transcript_text], truncation=True, batch_encode_plus([transcript_text], truncation=True,
padding='longest', padding='longest',
max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]), max_length=max_length,
return_tensors='pt') return_tensors='pt')
inputs = inputs.to(device) inputs = inputs.to(device)
with torch.no_grad(): with torch.no_grad():
num_beans = int(config["DEFAULT"]["BEAM_SIZE"])
max_length = int(config["DEFAULT"]["MAX_LENGTH"])
summaries = model.generate(inputs['input_ids'], summaries = model.generate(inputs['input_ids'],
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0, num_beams=num_beans,
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True) length_penalty=2.0,
max_length=max_length,
early_stopping=True)
decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) decoded_summaries = \
[tokenizer.decode(summary,
skip_special_tokens=True,
clean_up_tokenization_spaces=False)
for summary in summaries] for summary in summaries]
summary = " ".join(decoded_summaries) summary = " ".join(decoded_summaries)
with open("./artefacts/" + output_filename, 'w') as f: with open("./artefacts/" + output_file, 'w') as f:
f.write(summary.strip() + "\n") f.write(summary.strip() + "\n")
else: else:
logger.info("Breaking transcript into smaller chunks") logger.info("Breaking transcript into smaller chunks")
@@ -195,8 +203,8 @@ def summarize(transcript_text, timestamp,
logger.info(f"Transcript broken into {len(chunks)} " logger.info(f"Transcript broken into {len(chunks)} "
f"chunks of at most 500 words") f"chunks of at most 500 words")
logger.info(f"Writing summary text to: {output_filename}") logger.info(f"Writing summary text to: {output_file}")
with open(output_filename, 'w') as f: with open(output_file, 'w') as f:
summaries = summarize_chunks(chunks, tokenizer, model) summaries = summarize_chunks(chunks, tokenizer, model)
for summary in summaries: for summary in summaries:
f.write(summary.strip() + " ") f.write(summary.strip() + " ")

View File

@@ -45,14 +45,14 @@ def create_wordcloud(timestamp, real_time=False):
plt.axis("off") plt.axis("off")
plt.tight_layout(pad=0) plt.tight_layout(pad=0)
wordcloud_name = "wordcloud" wordcloud = "wordcloud"
if real_time: if real_time:
wordcloud_name = "real_time_" + wordcloud_name + "_" + \ wordcloud = "real_time_" + wordcloud + "_" + \
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png" timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
else: else:
wordcloud_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png" wordcloud += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
plt.savefig("./artefacts/" + wordcloud_name) plt.savefig("./artefacts/" + wordcloud)
def create_talk_diff_scatter_viz(timestamp, real_time=False): def create_talk_diff_scatter_viz(timestamp, real_time=False):