diff --git a/server.py b/server.py
index 6ff68400..290e0456 100644
--- a/server.py
+++ b/server.py
@@ -53,11 +53,11 @@ def get_title_and_summary(llm_input_text, last_timestamp):
     prompt = f"""
     ### Human:
 
-    Create a JSON object as response. The JSON object must have 2 fields:
-    i) title and ii) summary. For the title field,generate a short title
-    for the given text. For the summary field, summarize the given text
+    Create a JSON object as response. The JSON object must have 2 fields:
+    i) title and ii) summary. For the title field, generate a short title
+    for the given text. For the summary field, summarize the given text
     in three sentences.
-    
+
     {llm_input_text}
 
     ### Assistant:
@@ -144,12 +144,12 @@ def get_transcription(frames):
 
     result_text = ""
     try:
-        segments, _ = model.transcribe(audiofilename,
-                                       language="en",
-                                       beam_size=5,
-                                       vad_filter=True,
-                                       vad_parameters=dict(min_silence_duration_ms=500)
-                                       )
+        segments, _ = \
+            model.transcribe(audiofilename,
+                             language="en",
+                             beam_size=5,
+                             vad_filter=True,
+                             vad_parameters=dict(min_silence_duration_ms=500))
         os.remove(audiofilename)
         segments = list(segments)
         result_text = ""
diff --git a/server_multithreaded.py b/trials/server_multithreaded.py
similarity index 98%
rename from server_multithreaded.py
rename to trials/server_multithreaded.py
index 2862fa36..1d27dfdb 100644
--- a/server_multithreaded.py
+++ b/trials/server_multithreaded.py
@@ -16,8 +16,8 @@ from av import AudioFifo
 from sortedcontainers import SortedDict
 from whisper_jax import FlaxWhisperPipline
 
-from utils.log_utils import logger
-from utils.run_utils import config, Mutex
+from ..utils.log_utils import logger
+from ..utils.run_utils import config, Mutex
 
 WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_REAL_TIME_MODEL_SIZE"]
 pcs = set()
diff --git a/trials/whisper-jax/whisjax.py b/trials/whisper-jax/whisjax.py
index 98f718f3..eb87629d 100644
--- a/trials/whisper-jax/whisjax.py
+++ b/trials/whisper-jax/whisjax.py
@@ -21,8 +21,8 @@ from whisper_jax import FlaxWhisperPipline
 from ...utils.file_utils import download_files, upload_files
 from ...utils.log_utils import logger
 from ...utils.run_utils import config
-from ...utils.text_utilities import post_process_transcription, summarize
-from ...utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud
+from ...utils.text_utils import post_process_transcription, summarize
+from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud
 
 nltk.download('punkt', quiet=True)
 nltk.download('stopwords', quiet=True)
diff --git a/trials/whisper-jax/whisjax_realtime.py b/trials/whisper-jax/whisjax_realtime.py
index d1ec1a82..efb39461 100644
--- a/trials/whisper-jax/whisjax_realtime.py
+++ b/trials/whisper-jax/whisjax_realtime.py
@@ -13,8 +13,8 @@ from whisper_jax import FlaxWhisperPipline
 from ...utils.file_utils import upload_files
 from ...utils.log_utils import logger
 from ...utils.run_utils import config
-from ...utils.text_utilities import post_process_transcription, summarize
-from ...utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud
+from ...utils.text_utils import post_process_transcription, summarize
+from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud
 
 WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
 
diff --git a/utils/text_utilities.py b/utils/text_utils.py
similarity index 85%
rename from utils/text_utilities.py
rename to utils/text_utils.py
index 6210e78e..25126b34 100644
--- a/utils/text_utilities.py
+++ b/utils/text_utils.py
@@ -154,7 +154,7 @@ def chunk_text(text,
 def summarize(transcript_text, timestamp, real_time=False,
-              summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
+              chunk_summarize=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
     if not summary_model:
@@ -166,27 +166,35 @@ def summarize(transcript_text, timestamp,
     model = BartForConditionalGeneration.from_pretrained(summary_model)
     model = model.to(device)
 
-    output_filename = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
+    output_file = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
     if real_time:
-        output_filename = "real_time_" + output_filename
+        output_file = "real_time_" + output_file
 
-    if summarize_using_chunks != "YES":
+    if chunk_summarize != "YES":
+        max_length = int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"])
         inputs = tokenizer. \
             batch_encode_plus([transcript_text],
                               truncation=True,
                               padding='longest',
-                              max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]),
+                              max_length=max_length,
                               return_tensors='pt')
         inputs = inputs.to(device)
         with torch.no_grad():
+            num_beams = int(config["DEFAULT"]["BEAM_SIZE"])
+            max_length = int(config["DEFAULT"]["MAX_LENGTH"])
             summaries = model.generate(inputs['input_ids'],
-                                       num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
-                                       max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
+                                       num_beams=num_beams,
+                                       length_penalty=2.0,
+                                       max_length=max_length,
+                                       early_stopping=True)
-        decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False)
-                             for summary in summaries]
+        decoded_summaries = \
+            [tokenizer.decode(summary,
+                              skip_special_tokens=True,
+                              clean_up_tokenization_spaces=False)
+             for summary in summaries]
 
         summary = " ".join(decoded_summaries)
-        with open("./artefacts/" + output_filename, 'w') as f:
+        with open("./artefacts/" + output_file, 'w') as f:
             f.write(summary.strip() + "\n")
     else:
         logger.info("Breaking transcript into smaller chunks")
@@ -195,8 +203,8 @@ def summarize(transcript_text, timestamp,
         logger.info(f"Transcript broken into {len(chunks)} "
                     f"chunks of at most 500 words")
 
-        logger.info(f"Writing summary text to: {output_filename}")
-        with open(output_filename, 'w') as f:
+        logger.info(f"Writing summary text to: {output_file}")
+        with open(output_file, 'w') as f:
             summaries = summarize_chunks(chunks, tokenizer, model)
             for summary in summaries:
                 f.write(summary.strip() + " ")
diff --git a/utils/viz_utilities.py b/utils/viz_utils.py
similarity index 97%
rename from utils/viz_utilities.py
rename to utils/viz_utils.py
index 6da24bb0..d7debd0c 100644
--- a/utils/viz_utilities.py
+++ b/utils/viz_utils.py
@@ -45,14 +45,14 @@ def create_wordcloud(timestamp, real_time=False):
     plt.axis("off")
     plt.tight_layout(pad=0)
 
-    wordcloud_name = "wordcloud"
+    wordcloud = "wordcloud"
     if real_time:
-        wordcloud_name = "real_time_" + wordcloud_name + "_" + \
+        wordcloud = "real_time_" + wordcloud + "_" + \
                         timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
     else:
-        wordcloud_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
+        wordcloud += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
 
-    plt.savefig("./artefacts/" + wordcloud_name)
+    plt.savefig("./artefacts/" + wordcloud)
 
 
 def create_talk_diff_scatter_viz(timestamp, real_time=False):
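
Caller-side note on the renames above, as a minimal sketch rather than part of the patch: imports now resolve against utils/text_utils.py (and utils/viz_utils.py), and the keyword argument of summarize() is chunk_summarize instead of summarize_using_chunks. The snippet assumes the repository root is on PYTHONPATH, that config and the ./artefacts directory exist as the module expects, and it uses a placeholder transcript string.

    from datetime import datetime

    from utils.text_utils import summarize  # was utils.text_utilities

    # Placeholder text; a real caller would pass post-processed transcription output.
    transcript = "example transcript text to be summarized"

    # Keyword argument renamed from summarize_using_chunks to chunk_summarize;
    # any value other than "YES" takes the single-pass (non-chunked) path.
    summarize(transcript, datetime.now(), chunk_summarize="NO")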