diff --git a/text_utilities.py b/text_utilities.py index bf7cc5ff..ccb2bbb1 100644 --- a/text_utilities.py +++ b/text_utilities.py @@ -74,6 +74,47 @@ def post_process_transcription(whisper_result): chunk["text"] = " ".join(similarity_matched_sentences) return whisper_result + +def summarize_chunks(chunks, tokenizer, model): + """ + Summarize each chunk using a summarizer model + :param chunks: + :param tokenizer: + :param model: + :return: + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + summaries = [] + for c in chunks: + input_ids = tokenizer.encode(c, return_tensors='pt') + input_ids = input_ids.to(device) + with torch.no_grad(): + summary_ids = model.generate(input_ids, + num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0, + max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True) + summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) + summaries.append(summary) + return summaries + +def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])): + """ + Split text into smaller chunks. 
+ :param text: Text to be chunked + :param max_chunk_length: length of chunk + :return: chunked texts + """ + sentences = nltk.sent_tokenize(text) + chunks = [] + current_chunk = "" + for sentence in sentences: + if len(current_chunk) + len(sentence) < max_chunk_length: + current_chunk += f" {sentence.strip()}" + else: + chunks.append(current_chunk.strip()) + current_chunk = f"{sentence.strip()}" + chunks.append(current_chunk.strip()) + return chunks + def summarize(transcript_text, timestamp, real_time=False, summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -118,43 +159,3 @@ def summarize(transcript_text, timestamp, summaries = summarize_chunks(chunks, tokenizer, model) for summary in summaries: f.write(summary.strip() + " ") - -def summarize_chunks(chunks, tokenizer, model): - """ - Summarize each chunk using a summarizer model - :param chunks: - :param tokenizer: - :param model: - :return: - """ - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - summaries = [] - for c in chunks: - input_ids = tokenizer.encode(c, return_tensors='pt') - input_ids = input_ids.to(device) - with torch.no_grad(): - summary_ids = model.generate(input_ids, - num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0, - max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True) - summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) - summaries.append(summary) - return summaries - -def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])): - """ - Split text into smaller chunks. 
- :param txt: Text to be chunked - :param max_chunk_length: length of chunk - :return: chunked texts - """ - sentences = nltk.sent_tokenize(text) - chunks = [] - current_chunk = "" - for sentence in sentences: - if len(current_chunk) + len(sentence) < max_chunk_length: - current_chunk += f" {sentence.strip()}" - else: - chunks.append(current_chunk.strip()) - current_chunk = f"{sentence.strip()}" - chunks.append(current_chunk.strip()) - return chunks \ No newline at end of file diff --git a/viz_utilities.py b/viz_utilities.py index 51456c38..2d270394 100644 --- a/viz_utilities.py +++ b/viz_utilities.py @@ -70,7 +70,12 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False): agenda_topics.append(line.split(":")[0]) # Load the transcription with timestamp - with open("transcript_with_timestamp_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt") as f: + filename = "" + if real_time: + filename = "real_time_transcript_with_timestamp_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" + else: + filename = "transcript_with_timestamp_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" + with open(filename) as f: transcription_timestamp_text = f.read() res = ast.literal_eval(transcription_timestamp_text) diff --git a/whisjax_realtime.py b/whisjax_realtime.py index 6e8c4c5d..7a93cee5 100644 --- a/whisjax_realtime.py +++ b/whisjax_realtime.py @@ -6,7 +6,7 @@ from whisper_jax import FlaxWhisperPipline from pynput import keyboard import jax.numpy as jnp import wave -import datetime +from datetime import datetime from file_utilities import upload_files from viz_utilities import create_wordcloud, create_talk_diff_scatter_viz from text_utilities import summarize, post_process_transcription