From 2dba4ddeb88ef6aa2fd40ce52e54074ae0fdff3e Mon Sep 17 00:00:00 2001
From: gokul
Date: Wed, 21 Jun 2023 15:47:32 +0530
Subject: [PATCH] Refactor codebase and fix errors from demo run

---
 README.md                         |  22 +-
 config.ini                        |   4 +-
 file_util.py => file_utilities.py |   0
 text_utilities.py                 | 160 ++++++++++++++
 viz_utilities.py                  | 190 ++++++++++++++++
 whisjax.py                        | 354 ++----------------------------
 whisjax_realtime.py               | 137 ++++++++++++
 whisjax_realtime_trial.py         |  84 -------
 8 files changed, 527 insertions(+), 424 deletions(-)
 rename file_util.py => file_utilities.py (100%)
 create mode 100644 text_utilities.py
 create mode 100644 viz_utilities.py
 create mode 100644 whisjax_realtime.py
 delete mode 100644 whisjax_realtime_trial.py

diff --git a/README.md b/README.md
index 453d2515..c1b28b5c 100644
--- a/README.md
+++ b/README.md
@@ -123,24 +123,32 @@ microphone input which you will be using for speaking. We use [Blackhole](https:
 2) Setup [Aggregate device](https://github.com/ExistentialAudio/BlackHole/wiki/Aggregate-Device) to route web audio
 and local microphone input.
-   Be sure to mirror the settings given (including the name) ![here](./images/aggregate_input.png)
+   Be sure to mirror the settings shown ![here](./images/aggregate_input.png)
 3) Setup [Multi-Output device](https://github.com/ExistentialAudio/BlackHole/wiki/Multi-Output-Device)
    Refer ![here](./images/multi-output.png)
+4) Set the name of the aggregate input device created in step 2 in config.ini as ```BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME```
 
-Then goto ``` System Preferences -> Sound ``` and choose the devices created from the Output and
+5) Then go to ``` System Preferences -> Sound ``` and choose the devices created from the Output and
 Input tabs.
 
-From the reflector root folder,
-
-run ```python3 whisjax_realtime_trial.py```
+6) If everything is configured properly, the input from your local microphone and the meeting running in your browser
+will be aggregated into one virtual stream, and the output will be played back on your chosen output devices.
+Verify this before attempting real-time transcription.
 
 **Permissions:**
 You may have to add permission for "Terminal"/Code Editors [Pycharm/VSCode, etc.] microphone access to record audio in
-```System Preferences -> Privacy & Security -> Microphone``` and in
-```System Preferences -> Privacy & Security -> Accessibility```.
+```System Preferences -> Privacy & Security -> Microphone```,
+```System Preferences -> Privacy & Security -> Accessibility```, and
+```System Preferences -> Privacy & Security -> Input Monitoring```.
+
+From the reflector root folder, run
+
+```python3 whisjax_realtime.py```
+
+The transcript should be written to ```real_time_transcript_<timestamp>.txt```.
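+
+If transcription cannot find your device, the name in config.ini likely does not match an actual input device. A quick
+sketch for listing input device names with pyaudio (assuming pyaudio is installed, as whisjax_realtime.py already
+requires):
+
+```python
+# List input-capable audio devices so the aggregate device name can be
+# copied into config.ini as BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME.
+import pyaudio
+
+p = pyaudio.PyAudio()
+for i in range(p.get_device_count()):
+    info = p.get_device_info_by_index(i)
+    if info["maxInputChannels"] > 0:
+        print(i, info["name"])
+p.terminate()
+```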
NEXT STEPS: diff --git a/config.ini b/config.ini index b5d17a84..138a1778 100644 --- a/config.ini +++ b/config.ini @@ -16,4 +16,6 @@ INPUT_ENCODING_MAX_LENGTH=1024 MAX_LENGTH=2048 BEAM_SIZE=6 MAX_CHUNK_LENGTH=1024 -SUMMARIZE_USING_CHUNKS=YES \ No newline at end of file +SUMMARIZE_USING_CHUNKS=YES +# Audio device +BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME=ref-agg-input \ No newline at end of file diff --git a/file_util.py b/file_utilities.py similarity index 100% rename from file_util.py rename to file_utilities.py diff --git a/text_utilities.py b/text_utilities.py new file mode 100644 index 00000000..bf7cc5ff --- /dev/null +++ b/text_utilities.py @@ -0,0 +1,160 @@ +import torch +import configparser +import nltk +from transformers import BartTokenizer, BartForConditionalGeneration +from loguru import logger +from nltk.corpus import stopwords +from sklearn.feature_extraction.text import TfidfVectorizer +from nltk.tokenize import word_tokenize +from sklearn.metrics.pairwise import cosine_similarity + + +config = configparser.ConfigParser() +config.read('config.ini') + +def preprocess_sentence(sentence): + stop_words = set(stopwords.words('english')) + tokens = word_tokenize(sentence.lower()) + tokens = [token for token in tokens if token.isalnum() and token not in stop_words] + return ' '.join(tokens) + +def compute_similarity(sent1, sent2): + tfidf_vectorizer = TfidfVectorizer() + tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2]) + return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0] + +def remove_almost_alike_sentences(sentences, threshold=0.7): + num_sentences = len(sentences) + removed_indices = set() + + for i in range(num_sentences): + if i not in removed_indices: + for j in range(i + 1, num_sentences): + if j not in removed_indices: + sentence1 = preprocess_sentence(sentences[i]) + sentence2 = preprocess_sentence(sentences[j]) + similarity = compute_similarity(sentence1, sentence2) + + if similarity >= threshold: + removed_indices.add(max(i, j)) + + filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices] + return filtered_sentences + +def remove_outright_duplicate_sentences_from_chunk(chunk): + chunk_text = chunk["text"] + sentences = nltk.sent_tokenize(chunk_text) + nonduplicate_sentences = list(dict.fromkeys(sentences)) + return nonduplicate_sentences + +def remove_whisper_repetitive_hallucination(nonduplicate_sentences): + chunk_sentences = [] + + for sent in nonduplicate_sentences: + temp_result = "" + seen = {} + words = nltk.word_tokenize(sent) + n_gram_filter = 3 + for i in range(len(words)): + if str(words[i:i + n_gram_filter]) in seen and seen[str(words[i:i + n_gram_filter])] == words[ + i + 1:i + n_gram_filter + 2]: + pass + else: + seen[str(words[i:i + n_gram_filter])] = words[i + 1:i + n_gram_filter + 2] + temp_result += words[i] + temp_result += " " + chunk_sentences.append(temp_result) + return chunk_sentences + +def post_process_transcription(whisper_result): + for chunk in whisper_result["chunks"]: + nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk) + chunk_sentences = remove_whisper_repetitive_hallucination(nonduplicate_sentences) + similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences) + chunk["text"] = " ".join(similarity_matched_sentences) + return whisper_result + +def summarize(transcript_text, timestamp, + real_time=False, summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]): + device = torch.device("cuda" if torch.cuda.is_available() else 
"cpu") + summary_model = config["DEFAULT"]["SUMMARY_MODEL"] + if not summary_model: + summary_model = "facebook/bart-large-cnn" + + # Summarize the generated transcript using the BART model + logger.info(f"Loading BART model: {summary_model}") + tokenizer = BartTokenizer.from_pretrained(summary_model) + model = BartForConditionalGeneration.from_pretrained(summary_model) + model = model.to(device) + + output_filename = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" + if real_time: + output_filename = "real_time_" + output_filename + + if summarize_using_chunks != "YES": + inputs = tokenizer.batch_encode_plus([transcript_text], truncation=True, padding='longest', + max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]), + return_tensors='pt') + inputs = inputs.to(device) + + with torch.no_grad(): + summaries = model.generate(inputs['input_ids'], + num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0, + max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True) + + decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) for + summary in summaries] + summary = " ".join(decoded_summaries) + with open(output_filename, 'w') as f: + f.write(summary.strip() + "\n") + else: + logger.info("Breaking transcript into smaller chunks") + chunks = chunk_text(transcript_text) + + logger.info(f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable + + logger.info(f"Writing summary text to: {output_filename}") + with open(output_filename, 'w') as f: + summaries = summarize_chunks(chunks, tokenizer, model) + for summary in summaries: + f.write(summary.strip() + " ") + +def summarize_chunks(chunks, tokenizer, model): + """ + Summarize each chunk using a summarizer model + :param chunks: + :param tokenizer: + :param model: + :return: + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + summaries = [] + for c in chunks: + input_ids = tokenizer.encode(c, return_tensors='pt') + input_ids = input_ids.to(device) + with torch.no_grad(): + summary_ids = model.generate(input_ids, + num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0, + max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True) + summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) + summaries.append(summary) + return summaries + +def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])): + """ + Split text into smaller chunks. 
+
+    :param text: Text to be chunked
+    :param max_chunk_length: maximum chunk length in characters
+    :return: list of chunked texts
+    """
+    sentences = nltk.sent_tokenize(text)
+    chunks = []
+    current_chunk = ""
+    for sentence in sentences:
+        if len(current_chunk) + len(sentence) < max_chunk_length:
+            current_chunk += f" {sentence.strip()}"
+        else:
+            chunks.append(current_chunk.strip())
+            current_chunk = f"{sentence.strip()}"
+    chunks.append(current_chunk.strip())
+    return chunks
\ No newline at end of file
diff --git a/viz_utilities.py b/viz_utilities.py
new file mode 100644
index 00000000..51456c38
--- /dev/null
+++ b/viz_utilities.py
@@ -0,0 +1,190 @@
+import matplotlib.pyplot as plt
+from wordcloud import WordCloud, STOPWORDS
+import collections
+import spacy
+import pickle
+import ast
+import pandas as pd
+import scattertext as st
+import configparser
+
+config = configparser.ConfigParser()
+config.read('config.ini')
+
+
+def create_wordcloud(timestamp, real_time=False):
+    """
+    Create a basic word cloud visualization of transcribed text
+    :return: None. The wordcloud image is saved locally
+    """
+    filename = "transcript"
+    if real_time:
+        filename = "real_time_" + filename + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
+    else:
+        filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
+
+    with open(filename, "r") as f:
+        transcription_text = f.read()
+
+    stopwords = set(STOPWORDS)
+
+    # python_mask = np.array(PIL.Image.open("download1.png"))
+
+    wordcloud = WordCloud(height=800, width=800,
+                          background_color='white',
+                          stopwords=stopwords,
+                          min_font_size=8).generate(transcription_text)
+
+    # Plot wordcloud and save image
+    plt.figure(facecolor=None)
+    plt.imshow(wordcloud, interpolation="bilinear")
+    plt.axis("off")
+    plt.tight_layout(pad=0)
+
+    wordcloud_name = "wordcloud"
+    if real_time:
+        wordcloud_name = "real_time_" + wordcloud_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
+    else:
+        wordcloud_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
+
+    plt.savefig(wordcloud_name)
+
+
+def create_talk_diff_scatter_viz(timestamp, real_time=False):
+    """
+    Perform agenda vs transcription diff to see covered topics.
+    Create a scatter plot of words in topics.
+    :return: None. Saved locally.
+ """ + spaCy_model = "en_core_web_md" + nlp = spacy.load(spaCy_model) + nlp.add_pipe('sentencizer') + + agenda_topics = [] + agenda = [] + # Load the agenda + with open("agenda-headers.txt", "r") as f: + for line in f.readlines(): + if line.strip(): + agenda.append(line.strip()) + agenda_topics.append(line.split(":")[0]) + + # Load the transcription with timestamp + with open("transcript_with_timestamp_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt") as f: + transcription_timestamp_text = f.read() + + res = ast.literal_eval(transcription_timestamp_text) + chunks = res["chunks"] + + # create df for processing + df = pd.DataFrame.from_dict(res["chunks"]) + + covered_items = {} + # ts: timestamp + # Map each timestamped chunk with top1 and top2 matched agenda + ts_to_topic_mapping_top_1 = {} + ts_to_topic_mapping_top_2 = {} + + # Also create a mapping of the different timestamps in which each topic was covered + topic_to_ts_mapping_top_1 = collections.defaultdict(list) + topic_to_ts_mapping_top_2 = collections.defaultdict(list) + + similarity_threshold = 0.7 + + for c in chunks: + doc_transcription = nlp(c["text"]) + topic_similarities = [] + for item in range(len(agenda)): + item_doc = nlp(agenda[item]) + # if not doc_transcription or not all(token.has_vector for token in doc_transcription): + if not doc_transcription: + continue + similarity = doc_transcription.similarity(item_doc) + topic_similarities.append((item, similarity)) + topic_similarities.sort(key=lambda x: x[1], reverse=True) + for i in range(2): + if topic_similarities[i][1] >= similarity_threshold: + covered_items[agenda[topic_similarities[i][0]]] = True + # top1 match + if i == 0: + ts_to_topic_mapping_top_1[c["timestamp"]] = agenda_topics[topic_similarities[i][0]] + topic_to_ts_mapping_top_1[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"]) + # top2 match + else: + ts_to_topic_mapping_top_2[c["timestamp"]] = agenda_topics[topic_similarities[i][0]] + topic_to_ts_mapping_top_2[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"]) + + def create_new_columns(record): + """ + Accumulate the mapping information into the df + :param record: + :return: + """ + record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[record["timestamp"]] + record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[record["timestamp"]] + return record + + df = df.apply(create_new_columns, axis=1) + + # Count the number of items covered and calculatre the percentage + num_covered_items = sum(covered_items.values()) + percentage_covered = num_covered_items / len(agenda) * 100 + + # Print the results + print("💬 Agenda items covered in the transcription:") + for item in agenda: + if item in covered_items and covered_items[item]: + print("✅ ", item) + else: + print("❌ ", item) + print("📊 Coverage: {:.2f}%".format(percentage_covered)) + + # Save df, mappings for further experimentation + df_name = "df" + if real_time: + df_name = "real_time_" + df_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" + else: + df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" + df.to_pickle(df_name) + + my_mappings = [ts_to_topic_mapping_top_1, ts_to_topic_mapping_top_2, + topic_to_ts_mapping_top_1, topic_to_ts_mapping_top_2] + + mappings_name = "mappings" + if real_time: + mappings_name = "real_time_" + mappings_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" + else: + mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" + pickle.dump(my_mappings, open(mappings_name, "wb")) + 
+ # to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") ) + + # pick the 2 most matched topic to be used for plotting + topic_times = collections.defaultdict(int) + for key in ts_to_topic_mapping_top_1.keys(): + if key[0] is None or key[1] is None: + continue + duration = key[1] - key[0] + topic_times[ts_to_topic_mapping_top_1[key]] += duration + + topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True) + + cat_1 = topic_times[0][0] + cat_1_name = topic_times[0][0] + cat_2_name = topic_times[1][0] + + # Scatter plot of topics + df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences)) + corpus = st.CorpusFromParsedDocuments( + df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse' + ).build().get_unigram_corpus().compact(st.AssociationCompactor(2000)) + html = st.produce_scattertext_explorer( + corpus, + category=cat_1, + category_name=cat_1_name, + not_category_name=cat_2_name, + minimum_term_frequency=0, pmi_threshold_coefficient=0, + width_in_pixels=1000, + transform=st.Scalers.dense_rank + ) + open('./scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html) \ No newline at end of file diff --git a/whisjax.py b/whisjax.py index be6707c5..4264ffce 100644 --- a/whisjax.py +++ b/whisjax.py @@ -5,35 +5,26 @@ # summarize podcast.mp3 summary.txt import argparse -import ast -import torch -import collections import configparser import jax.numpy as jnp -import matplotlib.pyplot as plt + import moviepy.editor import moviepy.editor import nltk import os import subprocess -import pandas as pd -import pickle import re -import scattertext as st -import spacy import tempfile from loguru import logger from pytube import YouTube -from transformers import BartTokenizer, BartForConditionalGeneration + from urllib.parse import urlparse from whisper_jax import FlaxWhisperPipline -from wordcloud import WordCloud, STOPWORDS -from nltk.corpus import stopwords -from sklearn.feature_extraction.text import TfidfVectorizer -from nltk.tokenize import word_tokenize -from sklearn.metrics.pairwise import cosine_similarity -from file_util import upload_files, download_files +from datetime import datetime +from file_utilities import upload_files, download_files +from viz_utilities import create_wordcloud, create_talk_diff_scatter_viz +from text_utilities import summarize, post_process_transcription nltk.download('punkt') nltk.download('stopwords') @@ -43,7 +34,7 @@ config = configparser.ConfigParser() config.read('config.ini') WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"] - +NOW = datetime.now() def init_argparse() -> argparse.ArgumentParser: """ @@ -57,310 +48,10 @@ def init_argparse() -> argparse.ArgumentParser: parser.add_argument("-l", "--language", help="Language that the summary should be written in", type=str, default="english", choices=['english', 'spanish', 'french', 'german', 'romanian']) - parser.add_argument("-t", "--transcript", help="Save a copy of the intermediary transcript file", type=str) parser.add_argument("location") - parser.add_argument("output") - return parser -def chunk_text(txt, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])): - """ - Split text into smaller chunks. 
- :param txt: Text to be chunked - :param max_chunk_length: length of chunk - :return: chunked texts - """ - sentences = nltk.sent_tokenize(txt) - chunks = [] - current_chunk = "" - for sentence in sentences: - if len(current_chunk) + len(sentence) < max_chunk_length: - current_chunk += f" {sentence.strip()}" - else: - chunks.append(current_chunk.strip()) - current_chunk = f"{sentence.strip()}" - chunks.append(current_chunk.strip()) - return chunks - - -def summarize_chunks(chunks, tokenizer, model): - """ - Summarize each chunk using a summarizer model - :param chunks: - :param tokenizer: - :param model: - :return: - """ - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - summaries = [] - for c in chunks: - input_ids = tokenizer.encode(c, return_tensors='pt') - input_ids = input_ids.to(device) - with torch.no_grad(): - summary_ids = model.generate(input_ids, - num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0, - max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True) - summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) - summaries.append(summary) - return summaries - - -def create_wordcloud(): - """ - Create a basic word cloud visualization of transcribed text - :return: None. The wordcloud image is saved locally - """ - with open("transcript.txt", "r") as f: - transcription_text = f.read() - - stopwords = set(STOPWORDS) - - # python_mask = np.array(PIL.Image.open("download1.png")) - - wordcloud = WordCloud(height=800, width=800, - background_color='white', - stopwords=stopwords, - min_font_size=8).generate(transcription_text) - - # Plot wordcloud and save image - plt.figure(facecolor=None) - plt.imshow(wordcloud, interpolation="bilinear") - plt.axis("off") - plt.tight_layout(pad=0) - plt.savefig("wordcloud.png") - - -def create_talk_diff_scatter_viz(): - """ - Perform agenda vs transription diff to see covered topics. - Create a scatter plot of words in topics. - :return: None. Saved locally. 
- """ - spaCy_model = "en_core_web_md" - nlp = spacy.load(spaCy_model) - nlp.add_pipe('sentencizer') - - agenda_topics = [] - agenda = [] - # Load the agenda - with open("agenda-headers.txt", "r") as f: - for line in f.readlines(): - if line.strip(): - agenda.append(line.strip()) - agenda_topics.append(line.split(":")[0]) - - # Load the transcription with timestamp - with open("transcript_timestamps.txt", "r") as f: - transcription_timestamp_text = f.read() - - res = ast.literal_eval(transcription_timestamp_text) - chunks = res["chunks"] - - # create df for processing - df = pd.DataFrame.from_dict(res["chunks"]) - - covered_items = {} - # ts: timestamp - # Map each timestamped chunk with top1 and top2 matched agenda - ts_to_topic_mapping_top_1 = {} - ts_to_topic_mapping_top_2 = {} - - # Also create a mapping of the different timestamps in which each topic was covered - topic_to_ts_mapping_top_1 = collections.defaultdict(list) - topic_to_ts_mapping_top_2 = collections.defaultdict(list) - - similarity_threshold = 0.7 - - for c in chunks: - doc_transcription = nlp(c["text"]) - topic_similarities = [] - for item in range(len(agenda)): - item_doc = nlp(agenda[item]) - # if not doc_transcription or not all(token.has_vector for token in doc_transcription): - if not doc_transcription: - continue - similarity = doc_transcription.similarity(item_doc) - topic_similarities.append((item, similarity)) - topic_similarities.sort(key=lambda x: x[1], reverse=True) - for i in range(2): - if topic_similarities[i][1] >= similarity_threshold: - covered_items[agenda[topic_similarities[i][0]]] = True - # top1 match - if i == 0: - ts_to_topic_mapping_top_1[c["timestamp"]] = agenda_topics[topic_similarities[i][0]] - topic_to_ts_mapping_top_1[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"]) - # top2 match - else: - ts_to_topic_mapping_top_2[c["timestamp"]] = agenda_topics[topic_similarities[i][0]] - topic_to_ts_mapping_top_2[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"]) - - def create_new_columns(record): - """ - Accumulate the mapping information into the df - :param record: - :return: - """ - record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[record["timestamp"]] - record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[record["timestamp"]] - return record - - df = df.apply(create_new_columns, axis=1) - - # Count the number of items covered and calculatre the percentage - num_covered_items = sum(covered_items.values()) - percentage_covered = num_covered_items / len(agenda) * 100 - - # Print the results - print("💬 Agenda items covered in the transcription:") - for item in agenda: - if item in covered_items and covered_items[item]: - print("✅ ", item) - else: - print("❌ ", item) - print("📊 Coverage: {:.2f}%".format(percentage_covered)) - - # Save df, mappings for further experimentation - df.to_pickle("df.pkl") - - my_mappings = [ts_to_topic_mapping_top_1, ts_to_topic_mapping_top_2, - topic_to_ts_mapping_top_1, topic_to_ts_mapping_top_2] - pickle.dump(my_mappings, open("mappings.pkl", "wb")) - - # to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") ) - - # pick the 2 most matched topic to be used for plotting - topic_times = collections.defaultdict(int) - for key in ts_to_topic_mapping_top_1.keys(): - if key[0] is None or key[1] is None: - continue - duration = key[1] - key[0] - topic_times[ts_to_topic_mapping_top_1[key]] += duration - - topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True) - - cat_1 = topic_times[0][0] - 
cat_1_name = topic_times[0][0] - cat_2_name = topic_times[1][0] - - # Scatter plot of topics - df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences)) - corpus = st.CorpusFromParsedDocuments( - df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse' - ).build().get_unigram_corpus().compact(st.AssociationCompactor(2000)) - html = st.produce_scattertext_explorer( - corpus, - category=cat_1, - category_name=cat_1_name, - not_category_name=cat_2_name, - minimum_term_frequency=0, pmi_threshold_coefficient=0, - width_in_pixels=1000, - transform=st.Scalers.dense_rank - ) - open('./demo_compact.html', 'w').write(html) - -def preprocess_sentence(sentence): - stop_words = set(stopwords.words('english')) - tokens = word_tokenize(sentence.lower()) - tokens = [token for token in tokens if token.isalnum() and token not in stop_words] - return ' '.join(tokens) - -def compute_similarity(sent1, sent2): - tfidf_vectorizer = TfidfVectorizer() - tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2]) - return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0] - -def remove_almost_alike_sentences(sentences, threshold=0.7): - num_sentences = len(sentences) - removed_indices = set() - - for i in range(num_sentences): - if i not in removed_indices: - for j in range(i + 1, num_sentences): - if j not in removed_indices: - sentence1 = preprocess_sentence(sentences[i]) - sentence2 = preprocess_sentence(sentences[j]) - similarity = compute_similarity(sentence1, sentence2) - - if similarity >= threshold: - removed_indices.add(max(i, j)) - - filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices] - return filtered_sentences - -def remove_outright_duplicate_sentences_from_chunk(chunk): - chunk_text = chunk["text"] - sentences = nltk.sent_tokenize(chunk_text) - nonduplicate_sentences = list(dict.fromkeys(sentences)) - return nonduplicate_sentences - -def remove_whisper_repititive_hallucination(nonduplicate_sentences): - chunk_sentences = [] - - for sent in nonduplicate_sentences: - temp_result = "" - seen = {} - words = nltk.word_tokenize(sent) - n_gram_filter = 3 - for i in range(len(words)): - if str(words[i:i + n_gram_filter]) in seen and seen[str(words[i:i + n_gram_filter])] == words[ - i + 1:i + n_gram_filter + 2]: - pass - else: - seen[str(words[i:i + n_gram_filter])] = words[i + 1:i + n_gram_filter + 2] - temp_result += words[i] - temp_result += " " - chunk_sentences.append(temp_result) - return chunk_sentences - -def remove_duplicates_from_transcript_chunk(whisper_result): - for chunk in whisper_result["chunks"]: - nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk) - chunk_sentences = remove_whisper_repititive_hallucination(nonduplicate_sentences) - similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences) - chunk["text"] = " ".join(similarity_matched_sentences) - return whisper_result - -def summarize(transcript_text, output_file, - summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - summary_model = config["DEFAULT"]["SUMMARY_MODEL"] - if not summary_model: - summary_model = "facebook/bart-large-cnn" - - # Summarize the generated transcript using the BART model - logger.info(f"Loading BART model: {summary_model}") - tokenizer = BartTokenizer.from_pretrained(summary_model) - model = BartForConditionalGeneration.from_pretrained(summary_model) - model = model.to(device) - - if 
summarize_using_chunks != "YES": - inputs = tokenizer.batch_encode_plus([transcript_text], truncation=True, padding='longest', - max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]), - return_tensors='pt') - inputs = inputs.to(device) - - with torch.no_grad(): - summaries = model.generate(inputs['input_ids'], - num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0, - max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True) - - decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) for - summary in summaries] - summary = " ".join(decoded_summaries) - with open(output_file, 'w') as f: - f.write(summary.strip() + "\n\n") - else: - logger.info("Breaking transcript into smaller chunks") - chunks = chunk_text(transcript_text) - - logger.info(f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable - - logger.info(f"Writing summary text to: {output_file}") - with open(output_file, 'w') as f: - summaries = summarize_chunks(chunks, tokenizer, model) - for summary in summaries: - f.write(summary.strip() + " ") def main(): parser = init_argparse() @@ -425,41 +116,40 @@ def main(): whisper_result = pipeline(audio_filename, return_timestamps=True) logger.info("Finished transcribing file") - whisper_result = remove_duplicates_from_transcript_chunk(whisper_result) + whisper_result = post_process_transcription(whisper_result) transcript_text = "" for chunk in whisper_result["chunks"]: transcript_text += chunk["text"] - # If we got the transcript parameter on the command line, - # save the transcript to the specified file. - if args.transcript: - logger.info(f"Saving transcript to: {args.transcript}") - transcript_file = open(args.transcript, "w") - transcript_file_timestamps = open(args.transcript[0:len(args.transcript) - 4] + "_timestamps.txt", "w") + with open("transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file: transcript_file.write(transcript_text) + + with open("transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file_timestamps: transcript_file_timestamps.write(str(whisper_result)) - transcript_file.close() - transcript_file_timestamps.close() + logger.info("Creating word cloud") - create_wordcloud() + create_wordcloud(NOW) logger.info("Performing talk-diff and talk-diff visualization") - create_talk_diff_scatter_viz() + create_talk_diff_scatter_viz(NOW) # S3 : Push artefacts to S3 bucket - files_to_upload = ["transcript.txt", "transcript_timestamps.txt", - "df.pkl", - "wordcloud.png", "mappings.pkl"] + suffix = NOW.strftime("%m-%d-%Y_%H:%M:%S") + files_to_upload = ["transcript_" + suffix + ".txt", + "transcript_with_timestamp_" + suffix + ".txt", + "df_" + suffix + ".pkl", + "wordcloud_" + suffix + ".png", + "mappings_" + suffix + ".pkl"] upload_files(files_to_upload) - summarize(transcript_text, args.output) + summarize(transcript_text, NOW, False, False) logger.info("Summarization completed") # Summarization takes a lot of time, so do this separately at the end - files_to_upload = ["summary.txt"] + files_to_upload = ["summary_" + suffix + ".txt"] upload_files(files_to_upload) diff --git a/whisjax_realtime.py b/whisjax_realtime.py new file mode 100644 index 00000000..6e8c4c5d --- /dev/null +++ b/whisjax_realtime.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python3 + +import configparser +import pyaudio +from whisper_jax import FlaxWhisperPipline +from pynput import keyboard +import jax.numpy as jnp +import wave 
+from datetime import datetime
+from file_utilities import upload_files
+from viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
+from text_utilities import summarize, post_process_transcription
+from loguru import logger
+
+
+config = configparser.ConfigParser()
+config.read('config.ini')
+
+WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
+
+FRAMES_PER_BUFFER = 8000
+FORMAT = pyaudio.paInt16
+CHANNELS = 2
+RATE = 44100
+RECORD_SECONDS = 15
+NOW = datetime.now()
+
+
+def main():
+    p = pyaudio.PyAudio()
+    AUDIO_DEVICE_ID = -1
+    for i in range(p.get_device_count()):
+        if p.get_device_info_by_index(i)["name"] == config["DEFAULT"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]:
+            AUDIO_DEVICE_ID = i
+    if AUDIO_DEVICE_ID == -1:
+        # Fail fast with a clear message instead of opening device index -1
+        raise RuntimeError("Input device named in config.ini (BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME) not found")
+    audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID)
+    stream = p.open(
+        format=FORMAT,
+        channels=CHANNELS,
+        rate=RATE,
+        input=True,
+        frames_per_buffer=FRAMES_PER_BUFFER,
+        input_device_index=audio_devices['index']
+    )
+
+    pipeline = FlaxWhisperPipline("openai/whisper-" + config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"],
+                                  dtype=jnp.float16,
+                                  batch_size=16)
+
+    transcription = ""
+
+    TEMP_AUDIO_FILE = "temp_audio.wav"
+    global proceed
+    proceed = True
+
+    def on_press(key):
+        if key == keyboard.Key.esc:
+            global proceed
+            proceed = False
+
+    transcript_with_timestamp = {"text": "", "chunks": []}
+    last_transcribed_time = 0.0
+
+    listener = keyboard.Listener(on_press=on_press)
+    listener.start()
+    print("Attempting real-time transcription... Listening...")
+
+    try:
+        while proceed:
+            frames = []
+            for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)):
+                data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False)
+                frames.append(data)
+
+            wf = wave.open(TEMP_AUDIO_FILE, 'wb')
+            wf.setnchannels(CHANNELS)
+            wf.setsampwidth(p.get_sample_size(FORMAT))
+            wf.setframerate(RATE)
+            wf.writeframes(b''.join(frames))
+            wf.close()
+
+            whisper_result = pipeline(TEMP_AUDIO_FILE, return_timestamps=True)
+            print(whisper_result['text'])
+
+            timestamp = whisper_result["chunks"][0]["timestamp"]
+            start = timestamp[0]
+            end = timestamp[1]
+            if end is None:
+                end = start + 15.0
+            duration = end - start
+            item = {'timestamp': (last_transcribed_time, last_transcribed_time + duration),
+                    'text': whisper_result['text']}
+            last_transcribed_time = last_transcribed_time + duration
+            transcript_with_timestamp["chunks"].append(item)
+
+            transcription += whisper_result['text']
+
+    except Exception as e:
+        print(e)
+    finally:
+        with open("real_time_transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
+            f.write(transcription)
+        with open("real_time_transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
+            transcript_with_timestamp["text"] = transcription
+            f.write(str(transcript_with_timestamp))
+
+        transcript_with_timestamp = post_process_transcription(transcript_with_timestamp)
+
+        transcript_text = ""
+        for chunk in transcript_with_timestamp["chunks"]:
+            transcript_text += chunk["text"]
+
+        logger.info("Creating word cloud")
+        create_wordcloud(NOW, True)
+
+        logger.info("Performing talk-diff and talk-diff visualization")
+        create_talk_diff_scatter_viz(NOW, True)
+
+        # S3 : Push artefacts to S3 bucket
+        suffix = NOW.strftime("%m-%d-%Y_%H:%M:%S")
+        files_to_upload = ["real_time_transcript_" + suffix + ".txt",
+                           "real_time_transcript_with_timestamp_" + suffix + ".txt",
+                           "real_time_df_" + suffix + ".pkl",
+                           "real_time_wordcloud_" + suffix + ".png",
+                           "real_time_mappings_" + suffix + ".pkl"]
+        upload_files(files_to_upload)
+
+        summarize(transcript_text, NOW, True)
+
+        logger.info("Summarization completed")
+
+        # Summarization takes a lot of time, so do this separately at the end
+        files_to_upload = ["real_time_summary_" + suffix + ".txt"]
+        upload_files(files_to_upload)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/whisjax_realtime_trial.py b/whisjax_realtime_trial.py
deleted file mode 100644
index 3be1edad..00000000
--- a/whisjax_realtime_trial.py
+++ /dev/null
@@ -1,84 +0,0 @@
-#!/usr/bin/env python3
-
-import configparser
-import pyaudio
-from whisper_jax import FlaxWhisperPipline
-from pynput import keyboard
-import jax.numpy as jnp
-import wave
-
-config = configparser.ConfigParser()
-config.read('config.ini')
-
-WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
-
-FRAMES_PER_BUFFER = 8000
-FORMAT = pyaudio.paInt16
-CHANNELS = 2
-RATE = 44100
-RECORD_SECONDS = 15
-
-
-def main():
-    p = pyaudio.PyAudio()
-    AUDIO_DEVICE_ID = -1
-    for i in range(p.get_device_count()):
-        if p.get_device_info_by_index(i)["name"] == "ref-agg-input":
-            AUDIO_DEVICE_ID = i
-    audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID)
-    stream = p.open(
-        format=FORMAT,
-        channels=CHANNELS,
-        rate=RATE,
-        input=True,
-        frames_per_buffer=FRAMES_PER_BUFFER,
-        input_device_index=audio_devices['index']
-    )
-
-    pipeline = FlaxWhisperPipline("openai/whisper-" + config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"],
-                                  dtype=jnp.float16,
-                                  batch_size=16)
-
-    transcript_file = open("transcript.txt", "w+")
-    transcription = ""
-
-    TEMP_AUDIO_FILE = "temp_audio.wav"
-    global proceed
-    proceed = True
-
-    def on_press(key):
-        if key == keyboard.Key.esc:
-            global proceed
-            proceed = False
-
-    listener = keyboard.Listener(on_press=on_press)
-    listener.start()
-    print("Attempting real-time transcription.. Listening...")
-    while proceed:
-        try:
-            frames = []
-            for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)):
-                data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False)
-                frames.append(data)
-
-            wf = wave.open(TEMP_AUDIO_FILE, 'wb')
-            wf.setnchannels(CHANNELS)
-            wf.setsampwidth(p.get_sample_size(FORMAT))
-            wf.setframerate(RATE)
-            wf.writeframes(b''.join(frames))
-            wf.close()
-
-            whisper_result = pipeline(TEMP_AUDIO_FILE, return_timestamps=True)
-            print(whisper_result['text'])
-
-            transcription += whisper_result['text']
-
-        except Exception as e:
-            print(e)
-        finally:
-            with open("real_time_transcription.txt", "w") as f:
-                transcript_file.write(transcription)
-
-
-if __name__ == "__main__":
-    main()
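---

Usage sketch (illustrative): after this refactor, the summarization helpers extracted into text_utilities.py can be
driven on their own. This assumes config.ini is readable from the working directory, nltk's punkt data is available,
and a transcript file from an earlier run exists; the filename below is hypothetical.

```python
# Summarize an existing transcript with the shared utilities from this patch.
from datetime import datetime

import nltk

from text_utilities import chunk_text, summarize

nltk.download('punkt')  # needed by nltk.sent_tokenize inside chunk_text

# Hypothetical transcript produced by a previous whisjax.py run.
with open("transcript_06-21-2023_15:47:32.txt") as f:
    transcript_text = f.read()

# chunk_text splits on sentence boundaries, keeping each chunk under
# MAX_CHUNK_LENGTH characters (read from config.ini).
print(f"{len(chunk_text(transcript_text))} chunks")

# Writes summary_<MM-DD-YYYY_HH:MM:SS>.txt; with SUMMARIZE_USING_CHUNKS=YES
# in config.ini, each chunk is summarized separately with BART.
summarize(transcript_text, datetime.now())
```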