diff --git a/README.md b/README.md
index 453d2515..12e4ef15 100644
--- a/README.md
+++ b/README.md
@@ -32,11 +32,11 @@ To setup,
 5) Run the Whisper-JAX pipeline. Currently, the repo can take a YouTube video and transcribe/summarize it.
 
-``` python3 whisjax.py "https://www.youtube.com/watch?v=ihf0S97oxuQ" --transcript transcript.txt summary.txt ```
+``` python3 whisjax.py "https://www.youtube.com/watch?v=ihf0S97oxuQ" ```
 
 You can also run it on a local file or a file in your configured S3 bucket.
 
-``` python3 whisjax.py "startup.mp4" --transcript transcript.txt summary.txt ```
+``` python3 whisjax.py "startup.mp4" ```
 
 The script takes care of several cases: YouTube link, local file, video file, audio-only file, file in S3, etc.
 If the local file is not present, it can automatically fetch it from S3.
@@ -85,7 +85,7 @@ mentioned above or simply use the GUI of AWS Management Console.
 1) ```agenda_topic : ```
 3) Check all the values in ```config.ini```. You need to predefine 2 categories for which you need to scatter plot the
 topic modelling visualization in the config file. This is the default visualization. But, from the dataframe artefact called
-```df.pkl``` , you can load the df and choose different topics to plot. You can filter using certain words to search for the
+```df_<timestamp>.pkl```, you can load the df and choose different topics to plot. You can filter using certain words to search for the
 transcriptions and you can see the top influencers and characteristics in each topic we have chosen to plot in the
 interactive HTML document. I have added a new Jupyter notebook that gives the base template to play around with, named
 ```Viz_experiments.ipynb```.
@@ -123,24 +123,32 @@ microphone input which you will be using for speaking. We use [Blackhole](https:
 2) Setup [Aggregate device](https://github.com/ExistentialAudio/BlackHole/wiki/Aggregate-Device) to route web audio and
 local microphone input.
-Be sure to mirror the settings given (including the name) ![here](./images/aggregate_input.png)
+Be sure to mirror the settings given ![here](./images/aggregate_input.png)
 3) Setup [Multi-Output device](https://github.com/ExistentialAudio/BlackHole/wiki/Multi-Output-Device)
 Refer to ![here](./images/multi-output.png)
+4) Set the aggregate input device name created in step 2 in ```config.ini``` as ```BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME```
+(see the device-listing snippet below).
-Then goto ``` System Preferences -> Sound ``` and choose the devices created from the Output and
+5) Then go to ``` System Preferences -> Sound ``` and choose the devices created above from the Output and
 Input tabs.
 
-From the reflector root folder,
-
-run ```python3 whisjax_realtime_trial.py```
+6) If everything is configured properly, the input from your local microphone and the browser-run meeting should be
+aggregated into one virtual stream to listen to, and the output should be fed back to your specified output devices.
+Check this before trying out real-time transcription.
 
 **Permissions:**
 You may have to grant "Terminal"/code editors [PyCharm/VSCode, etc.] permission to record audio in
-```System Preferences -> Privacy & Security -> Microphone``` and in
-```System Preferences -> Privacy & Security -> Accessibility```.
+```System Preferences -> Privacy & Security -> Microphone```,
+```System Preferences -> Privacy & Security -> Accessibility```, and
+```System Preferences -> Privacy & Security -> Input Monitoring```.
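+
+To double-check that the device name you set as ```BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME``` matches what PyAudio
+actually sees, you can list the available input devices. This is only a quick diagnostic sketch (it assumes
+```pyaudio``` is already installed, as the realtime script needs it anyway; the listing code itself is not part of this repo):
+
+```python
+import pyaudio
+
+p = pyaudio.PyAudio()
+for i in range(p.get_device_count()):
+    info = p.get_device_info_by_index(i)
+    # Only input-capable devices can serve as the aggregate input
+    if info["maxInputChannels"] > 0:
+        print(i, info["name"])
+p.terminate()
+```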
+ +From the reflector root folder, + +run ```python3 whisjax_realtime.py``` + +The transcription text should be written to ```real_time_transcription_.txt```. NEXT STEPS: diff --git a/config.ini b/config.ini index b5d17a84..138a1778 100644 --- a/config.ini +++ b/config.ini @@ -16,4 +16,6 @@ INPUT_ENCODING_MAX_LENGTH=1024 MAX_LENGTH=2048 BEAM_SIZE=6 MAX_CHUNK_LENGTH=1024 -SUMMARIZE_USING_CHUNKS=YES \ No newline at end of file +SUMMARIZE_USING_CHUNKS=YES +# Audio device +BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME=ref-agg-input \ No newline at end of file diff --git a/file_util.py b/file_utilities.py similarity index 100% rename from file_util.py rename to file_utilities.py diff --git a/text_utilities.py b/text_utilities.py new file mode 100644 index 00000000..bf7cc5ff --- /dev/null +++ b/text_utilities.py @@ -0,0 +1,160 @@ +import torch +import configparser +import nltk +from transformers import BartTokenizer, BartForConditionalGeneration +from loguru import logger +from nltk.corpus import stopwords +from sklearn.feature_extraction.text import TfidfVectorizer +from nltk.tokenize import word_tokenize +from sklearn.metrics.pairwise import cosine_similarity + + +config = configparser.ConfigParser() +config.read('config.ini') + +def preprocess_sentence(sentence): + stop_words = set(stopwords.words('english')) + tokens = word_tokenize(sentence.lower()) + tokens = [token for token in tokens if token.isalnum() and token not in stop_words] + return ' '.join(tokens) + +def compute_similarity(sent1, sent2): + tfidf_vectorizer = TfidfVectorizer() + tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2]) + return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0] + +def remove_almost_alike_sentences(sentences, threshold=0.7): + num_sentences = len(sentences) + removed_indices = set() + + for i in range(num_sentences): + if i not in removed_indices: + for j in range(i + 1, num_sentences): + if j not in removed_indices: + sentence1 = preprocess_sentence(sentences[i]) + sentence2 = preprocess_sentence(sentences[j]) + similarity = compute_similarity(sentence1, sentence2) + + if similarity >= threshold: + removed_indices.add(max(i, j)) + + filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices] + return filtered_sentences + +def remove_outright_duplicate_sentences_from_chunk(chunk): + chunk_text = chunk["text"] + sentences = nltk.sent_tokenize(chunk_text) + nonduplicate_sentences = list(dict.fromkeys(sentences)) + return nonduplicate_sentences + +def remove_whisper_repetitive_hallucination(nonduplicate_sentences): + chunk_sentences = [] + + for sent in nonduplicate_sentences: + temp_result = "" + seen = {} + words = nltk.word_tokenize(sent) + n_gram_filter = 3 + for i in range(len(words)): + if str(words[i:i + n_gram_filter]) in seen and seen[str(words[i:i + n_gram_filter])] == words[ + i + 1:i + n_gram_filter + 2]: + pass + else: + seen[str(words[i:i + n_gram_filter])] = words[i + 1:i + n_gram_filter + 2] + temp_result += words[i] + temp_result += " " + chunk_sentences.append(temp_result) + return chunk_sentences + +def post_process_transcription(whisper_result): + for chunk in whisper_result["chunks"]: + nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk) + chunk_sentences = remove_whisper_repetitive_hallucination(nonduplicate_sentences) + similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences) + chunk["text"] = " ".join(similarity_matched_sentences) + return whisper_result + +def summarize(transcript_text, 
timestamp, + real_time=False, summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + summary_model = config["DEFAULT"]["SUMMARY_MODEL"] + if not summary_model: + summary_model = "facebook/bart-large-cnn" + + # Summarize the generated transcript using the BART model + logger.info(f"Loading BART model: {summary_model}") + tokenizer = BartTokenizer.from_pretrained(summary_model) + model = BartForConditionalGeneration.from_pretrained(summary_model) + model = model.to(device) + + output_filename = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" + if real_time: + output_filename = "real_time_" + output_filename + + if summarize_using_chunks != "YES": + inputs = tokenizer.batch_encode_plus([transcript_text], truncation=True, padding='longest', + max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]), + return_tensors='pt') + inputs = inputs.to(device) + + with torch.no_grad(): + summaries = model.generate(inputs['input_ids'], + num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0, + max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True) + + decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) for + summary in summaries] + summary = " ".join(decoded_summaries) + with open(output_filename, 'w') as f: + f.write(summary.strip() + "\n") + else: + logger.info("Breaking transcript into smaller chunks") + chunks = chunk_text(transcript_text) + + logger.info(f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable + + logger.info(f"Writing summary text to: {output_filename}") + with open(output_filename, 'w') as f: + summaries = summarize_chunks(chunks, tokenizer, model) + for summary in summaries: + f.write(summary.strip() + " ") + +def summarize_chunks(chunks, tokenizer, model): + """ + Summarize each chunk using a summarizer model + :param chunks: + :param tokenizer: + :param model: + :return: + """ + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + summaries = [] + for c in chunks: + input_ids = tokenizer.encode(c, return_tensors='pt') + input_ids = input_ids.to(device) + with torch.no_grad(): + summary_ids = model.generate(input_ids, + num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0, + max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True) + summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) + summaries.append(summary) + return summaries + +def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])): + """ + Split text into smaller chunks. 
+ :param txt: Text to be chunked + :param max_chunk_length: length of chunk + :return: chunked texts + """ + sentences = nltk.sent_tokenize(text) + chunks = [] + current_chunk = "" + for sentence in sentences: + if len(current_chunk) + len(sentence) < max_chunk_length: + current_chunk += f" {sentence.strip()}" + else: + chunks.append(current_chunk.strip()) + current_chunk = f"{sentence.strip()}" + chunks.append(current_chunk.strip()) + return chunks \ No newline at end of file diff --git a/viz_utilities.py b/viz_utilities.py new file mode 100644 index 00000000..51456c38 --- /dev/null +++ b/viz_utilities.py @@ -0,0 +1,190 @@ +import matplotlib.pyplot as plt +from wordcloud import WordCloud, STOPWORDS +import collections +import spacy +import pickle +import ast +import pandas as pd +import scattertext as st +import configparser + +config = configparser.ConfigParser() +config.read('config.ini') + + +def create_wordcloud(timestamp, real_time=False): + """ + Create a basic word cloud visualization of transcribed text + :return: None. The wordcloud image is saved locally + """ + filename = "transcript" + if real_time: + filename = "real_time_" + filename + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" + else: + filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" + + with open(filename, "r") as f: + transcription_text = f.read() + + stopwords = set(STOPWORDS) + + # python_mask = np.array(PIL.Image.open("download1.png")) + + wordcloud = WordCloud(height=800, width=800, + background_color='white', + stopwords=stopwords, + min_font_size=8).generate(transcription_text) + + # Plot wordcloud and save image + plt.figure(facecolor=None) + plt.imshow(wordcloud, interpolation="bilinear") + plt.axis("off") + plt.tight_layout(pad=0) + + wordcloud_name = "wordcloud" + if real_time: + wordcloud_name = "real_time_" + wordcloud_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png" + else: + wordcloud_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png" + + plt.savefig(wordcloud_name) + + +def create_talk_diff_scatter_viz(timestamp, real_time=False): + """ + Perform agenda vs transription diff to see covered topics. + Create a scatter plot of words in topics. + :return: None. Saved locally. 
+ """ + spaCy_model = "en_core_web_md" + nlp = spacy.load(spaCy_model) + nlp.add_pipe('sentencizer') + + agenda_topics = [] + agenda = [] + # Load the agenda + with open("agenda-headers.txt", "r") as f: + for line in f.readlines(): + if line.strip(): + agenda.append(line.strip()) + agenda_topics.append(line.split(":")[0]) + + # Load the transcription with timestamp + with open("transcript_with_timestamp_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt") as f: + transcription_timestamp_text = f.read() + + res = ast.literal_eval(transcription_timestamp_text) + chunks = res["chunks"] + + # create df for processing + df = pd.DataFrame.from_dict(res["chunks"]) + + covered_items = {} + # ts: timestamp + # Map each timestamped chunk with top1 and top2 matched agenda + ts_to_topic_mapping_top_1 = {} + ts_to_topic_mapping_top_2 = {} + + # Also create a mapping of the different timestamps in which each topic was covered + topic_to_ts_mapping_top_1 = collections.defaultdict(list) + topic_to_ts_mapping_top_2 = collections.defaultdict(list) + + similarity_threshold = 0.7 + + for c in chunks: + doc_transcription = nlp(c["text"]) + topic_similarities = [] + for item in range(len(agenda)): + item_doc = nlp(agenda[item]) + # if not doc_transcription or not all(token.has_vector for token in doc_transcription): + if not doc_transcription: + continue + similarity = doc_transcription.similarity(item_doc) + topic_similarities.append((item, similarity)) + topic_similarities.sort(key=lambda x: x[1], reverse=True) + for i in range(2): + if topic_similarities[i][1] >= similarity_threshold: + covered_items[agenda[topic_similarities[i][0]]] = True + # top1 match + if i == 0: + ts_to_topic_mapping_top_1[c["timestamp"]] = agenda_topics[topic_similarities[i][0]] + topic_to_ts_mapping_top_1[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"]) + # top2 match + else: + ts_to_topic_mapping_top_2[c["timestamp"]] = agenda_topics[topic_similarities[i][0]] + topic_to_ts_mapping_top_2[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"]) + + def create_new_columns(record): + """ + Accumulate the mapping information into the df + :param record: + :return: + """ + record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[record["timestamp"]] + record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[record["timestamp"]] + return record + + df = df.apply(create_new_columns, axis=1) + + # Count the number of items covered and calculatre the percentage + num_covered_items = sum(covered_items.values()) + percentage_covered = num_covered_items / len(agenda) * 100 + + # Print the results + print("💬 Agenda items covered in the transcription:") + for item in agenda: + if item in covered_items and covered_items[item]: + print("✅ ", item) + else: + print("❌ ", item) + print("📊 Coverage: {:.2f}%".format(percentage_covered)) + + # Save df, mappings for further experimentation + df_name = "df" + if real_time: + df_name = "real_time_" + df_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" + else: + df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" + df.to_pickle(df_name) + + my_mappings = [ts_to_topic_mapping_top_1, ts_to_topic_mapping_top_2, + topic_to_ts_mapping_top_1, topic_to_ts_mapping_top_2] + + mappings_name = "mappings" + if real_time: + mappings_name = "real_time_" + mappings_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" + else: + mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" + pickle.dump(my_mappings, open(mappings_name, "wb")) + 
+ # to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") ) + + # pick the 2 most matched topic to be used for plotting + topic_times = collections.defaultdict(int) + for key in ts_to_topic_mapping_top_1.keys(): + if key[0] is None or key[1] is None: + continue + duration = key[1] - key[0] + topic_times[ts_to_topic_mapping_top_1[key]] += duration + + topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True) + + cat_1 = topic_times[0][0] + cat_1_name = topic_times[0][0] + cat_2_name = topic_times[1][0] + + # Scatter plot of topics + df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences)) + corpus = st.CorpusFromParsedDocuments( + df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse' + ).build().get_unigram_corpus().compact(st.AssociationCompactor(2000)) + html = st.produce_scattertext_explorer( + corpus, + category=cat_1, + category_name=cat_1_name, + not_category_name=cat_2_name, + minimum_term_frequency=0, pmi_threshold_coefficient=0, + width_in_pixels=1000, + transform=st.Scalers.dense_rank + ) + open('./scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html) \ No newline at end of file diff --git a/whisjax.py b/whisjax.py index be6707c5..4264ffce 100644 --- a/whisjax.py +++ b/whisjax.py @@ -5,35 +5,26 @@ # summarize podcast.mp3 summary.txt import argparse -import ast -import torch -import collections import configparser import jax.numpy as jnp -import matplotlib.pyplot as plt + import moviepy.editor import moviepy.editor import nltk import os import subprocess -import pandas as pd -import pickle import re -import scattertext as st -import spacy import tempfile from loguru import logger from pytube import YouTube -from transformers import BartTokenizer, BartForConditionalGeneration + from urllib.parse import urlparse from whisper_jax import FlaxWhisperPipline -from wordcloud import WordCloud, STOPWORDS -from nltk.corpus import stopwords -from sklearn.feature_extraction.text import TfidfVectorizer -from nltk.tokenize import word_tokenize -from sklearn.metrics.pairwise import cosine_similarity -from file_util import upload_files, download_files +from datetime import datetime +from file_utilities import upload_files, download_files +from viz_utilities import create_wordcloud, create_talk_diff_scatter_viz +from text_utilities import summarize, post_process_transcription nltk.download('punkt') nltk.download('stopwords') @@ -43,7 +34,7 @@ config = configparser.ConfigParser() config.read('config.ini') WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"] - +NOW = datetime.now() def init_argparse() -> argparse.ArgumentParser: """ @@ -57,310 +48,10 @@ def init_argparse() -> argparse.ArgumentParser: parser.add_argument("-l", "--language", help="Language that the summary should be written in", type=str, default="english", choices=['english', 'spanish', 'french', 'german', 'romanian']) - parser.add_argument("-t", "--transcript", help="Save a copy of the intermediary transcript file", type=str) parser.add_argument("location") - parser.add_argument("output") - return parser -def chunk_text(txt, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])): - """ - Split text into smaller chunks. 
- :param txt: Text to be chunked - :param max_chunk_length: length of chunk - :return: chunked texts - """ - sentences = nltk.sent_tokenize(txt) - chunks = [] - current_chunk = "" - for sentence in sentences: - if len(current_chunk) + len(sentence) < max_chunk_length: - current_chunk += f" {sentence.strip()}" - else: - chunks.append(current_chunk.strip()) - current_chunk = f"{sentence.strip()}" - chunks.append(current_chunk.strip()) - return chunks - - -def summarize_chunks(chunks, tokenizer, model): - """ - Summarize each chunk using a summarizer model - :param chunks: - :param tokenizer: - :param model: - :return: - """ - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - summaries = [] - for c in chunks: - input_ids = tokenizer.encode(c, return_tensors='pt') - input_ids = input_ids.to(device) - with torch.no_grad(): - summary_ids = model.generate(input_ids, - num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0, - max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True) - summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) - summaries.append(summary) - return summaries - - -def create_wordcloud(): - """ - Create a basic word cloud visualization of transcribed text - :return: None. The wordcloud image is saved locally - """ - with open("transcript.txt", "r") as f: - transcription_text = f.read() - - stopwords = set(STOPWORDS) - - # python_mask = np.array(PIL.Image.open("download1.png")) - - wordcloud = WordCloud(height=800, width=800, - background_color='white', - stopwords=stopwords, - min_font_size=8).generate(transcription_text) - - # Plot wordcloud and save image - plt.figure(facecolor=None) - plt.imshow(wordcloud, interpolation="bilinear") - plt.axis("off") - plt.tight_layout(pad=0) - plt.savefig("wordcloud.png") - - -def create_talk_diff_scatter_viz(): - """ - Perform agenda vs transription diff to see covered topics. - Create a scatter plot of words in topics. - :return: None. Saved locally. 
- """ - spaCy_model = "en_core_web_md" - nlp = spacy.load(spaCy_model) - nlp.add_pipe('sentencizer') - - agenda_topics = [] - agenda = [] - # Load the agenda - with open("agenda-headers.txt", "r") as f: - for line in f.readlines(): - if line.strip(): - agenda.append(line.strip()) - agenda_topics.append(line.split(":")[0]) - - # Load the transcription with timestamp - with open("transcript_timestamps.txt", "r") as f: - transcription_timestamp_text = f.read() - - res = ast.literal_eval(transcription_timestamp_text) - chunks = res["chunks"] - - # create df for processing - df = pd.DataFrame.from_dict(res["chunks"]) - - covered_items = {} - # ts: timestamp - # Map each timestamped chunk with top1 and top2 matched agenda - ts_to_topic_mapping_top_1 = {} - ts_to_topic_mapping_top_2 = {} - - # Also create a mapping of the different timestamps in which each topic was covered - topic_to_ts_mapping_top_1 = collections.defaultdict(list) - topic_to_ts_mapping_top_2 = collections.defaultdict(list) - - similarity_threshold = 0.7 - - for c in chunks: - doc_transcription = nlp(c["text"]) - topic_similarities = [] - for item in range(len(agenda)): - item_doc = nlp(agenda[item]) - # if not doc_transcription or not all(token.has_vector for token in doc_transcription): - if not doc_transcription: - continue - similarity = doc_transcription.similarity(item_doc) - topic_similarities.append((item, similarity)) - topic_similarities.sort(key=lambda x: x[1], reverse=True) - for i in range(2): - if topic_similarities[i][1] >= similarity_threshold: - covered_items[agenda[topic_similarities[i][0]]] = True - # top1 match - if i == 0: - ts_to_topic_mapping_top_1[c["timestamp"]] = agenda_topics[topic_similarities[i][0]] - topic_to_ts_mapping_top_1[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"]) - # top2 match - else: - ts_to_topic_mapping_top_2[c["timestamp"]] = agenda_topics[topic_similarities[i][0]] - topic_to_ts_mapping_top_2[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"]) - - def create_new_columns(record): - """ - Accumulate the mapping information into the df - :param record: - :return: - """ - record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[record["timestamp"]] - record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[record["timestamp"]] - return record - - df = df.apply(create_new_columns, axis=1) - - # Count the number of items covered and calculatre the percentage - num_covered_items = sum(covered_items.values()) - percentage_covered = num_covered_items / len(agenda) * 100 - - # Print the results - print("💬 Agenda items covered in the transcription:") - for item in agenda: - if item in covered_items and covered_items[item]: - print("✅ ", item) - else: - print("❌ ", item) - print("📊 Coverage: {:.2f}%".format(percentage_covered)) - - # Save df, mappings for further experimentation - df.to_pickle("df.pkl") - - my_mappings = [ts_to_topic_mapping_top_1, ts_to_topic_mapping_top_2, - topic_to_ts_mapping_top_1, topic_to_ts_mapping_top_2] - pickle.dump(my_mappings, open("mappings.pkl", "wb")) - - # to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") ) - - # pick the 2 most matched topic to be used for plotting - topic_times = collections.defaultdict(int) - for key in ts_to_topic_mapping_top_1.keys(): - if key[0] is None or key[1] is None: - continue - duration = key[1] - key[0] - topic_times[ts_to_topic_mapping_top_1[key]] += duration - - topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True) - - cat_1 = topic_times[0][0] - 
cat_1_name = topic_times[0][0] - cat_2_name = topic_times[1][0] - - # Scatter plot of topics - df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences)) - corpus = st.CorpusFromParsedDocuments( - df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse' - ).build().get_unigram_corpus().compact(st.AssociationCompactor(2000)) - html = st.produce_scattertext_explorer( - corpus, - category=cat_1, - category_name=cat_1_name, - not_category_name=cat_2_name, - minimum_term_frequency=0, pmi_threshold_coefficient=0, - width_in_pixels=1000, - transform=st.Scalers.dense_rank - ) - open('./demo_compact.html', 'w').write(html) - -def preprocess_sentence(sentence): - stop_words = set(stopwords.words('english')) - tokens = word_tokenize(sentence.lower()) - tokens = [token for token in tokens if token.isalnum() and token not in stop_words] - return ' '.join(tokens) - -def compute_similarity(sent1, sent2): - tfidf_vectorizer = TfidfVectorizer() - tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2]) - return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0] - -def remove_almost_alike_sentences(sentences, threshold=0.7): - num_sentences = len(sentences) - removed_indices = set() - - for i in range(num_sentences): - if i not in removed_indices: - for j in range(i + 1, num_sentences): - if j not in removed_indices: - sentence1 = preprocess_sentence(sentences[i]) - sentence2 = preprocess_sentence(sentences[j]) - similarity = compute_similarity(sentence1, sentence2) - - if similarity >= threshold: - removed_indices.add(max(i, j)) - - filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices] - return filtered_sentences - -def remove_outright_duplicate_sentences_from_chunk(chunk): - chunk_text = chunk["text"] - sentences = nltk.sent_tokenize(chunk_text) - nonduplicate_sentences = list(dict.fromkeys(sentences)) - return nonduplicate_sentences - -def remove_whisper_repititive_hallucination(nonduplicate_sentences): - chunk_sentences = [] - - for sent in nonduplicate_sentences: - temp_result = "" - seen = {} - words = nltk.word_tokenize(sent) - n_gram_filter = 3 - for i in range(len(words)): - if str(words[i:i + n_gram_filter]) in seen and seen[str(words[i:i + n_gram_filter])] == words[ - i + 1:i + n_gram_filter + 2]: - pass - else: - seen[str(words[i:i + n_gram_filter])] = words[i + 1:i + n_gram_filter + 2] - temp_result += words[i] - temp_result += " " - chunk_sentences.append(temp_result) - return chunk_sentences - -def remove_duplicates_from_transcript_chunk(whisper_result): - for chunk in whisper_result["chunks"]: - nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk) - chunk_sentences = remove_whisper_repititive_hallucination(nonduplicate_sentences) - similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences) - chunk["text"] = " ".join(similarity_matched_sentences) - return whisper_result - -def summarize(transcript_text, output_file, - summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - summary_model = config["DEFAULT"]["SUMMARY_MODEL"] - if not summary_model: - summary_model = "facebook/bart-large-cnn" - - # Summarize the generated transcript using the BART model - logger.info(f"Loading BART model: {summary_model}") - tokenizer = BartTokenizer.from_pretrained(summary_model) - model = BartForConditionalGeneration.from_pretrained(summary_model) - model = model.to(device) - - if 
summarize_using_chunks != "YES": - inputs = tokenizer.batch_encode_plus([transcript_text], truncation=True, padding='longest', - max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]), - return_tensors='pt') - inputs = inputs.to(device) - - with torch.no_grad(): - summaries = model.generate(inputs['input_ids'], - num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0, - max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True) - - decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) for - summary in summaries] - summary = " ".join(decoded_summaries) - with open(output_file, 'w') as f: - f.write(summary.strip() + "\n\n") - else: - logger.info("Breaking transcript into smaller chunks") - chunks = chunk_text(transcript_text) - - logger.info(f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable - - logger.info(f"Writing summary text to: {output_file}") - with open(output_file, 'w') as f: - summaries = summarize_chunks(chunks, tokenizer, model) - for summary in summaries: - f.write(summary.strip() + " ") def main(): parser = init_argparse() @@ -425,41 +116,40 @@ def main(): whisper_result = pipeline(audio_filename, return_timestamps=True) logger.info("Finished transcribing file") - whisper_result = remove_duplicates_from_transcript_chunk(whisper_result) + whisper_result = post_process_transcription(whisper_result) transcript_text = "" for chunk in whisper_result["chunks"]: transcript_text += chunk["text"] - # If we got the transcript parameter on the command line, - # save the transcript to the specified file. - if args.transcript: - logger.info(f"Saving transcript to: {args.transcript}") - transcript_file = open(args.transcript, "w") - transcript_file_timestamps = open(args.transcript[0:len(args.transcript) - 4] + "_timestamps.txt", "w") + with open("transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file: transcript_file.write(transcript_text) + + with open("transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file_timestamps: transcript_file_timestamps.write(str(whisper_result)) - transcript_file.close() - transcript_file_timestamps.close() + logger.info("Creating word cloud") - create_wordcloud() + create_wordcloud(NOW) logger.info("Performing talk-diff and talk-diff visualization") - create_talk_diff_scatter_viz() + create_talk_diff_scatter_viz(NOW) # S3 : Push artefacts to S3 bucket - files_to_upload = ["transcript.txt", "transcript_timestamps.txt", - "df.pkl", - "wordcloud.png", "mappings.pkl"] + suffix = NOW.strftime("%m-%d-%Y_%H:%M:%S") + files_to_upload = ["transcript_" + suffix + ".txt", + "transcript_with_timestamp_" + suffix + ".txt", + "df_" + suffix + ".pkl", + "wordcloud_" + suffix + ".png", + "mappings_" + suffix + ".pkl"] upload_files(files_to_upload) - summarize(transcript_text, args.output) + summarize(transcript_text, NOW, False, False) logger.info("Summarization completed") # Summarization takes a lot of time, so do this separately at the end - files_to_upload = ["summary.txt"] + files_to_upload = ["summary_" + suffix + ".txt"] upload_files(files_to_upload) diff --git a/whisjax_realtime.py b/whisjax_realtime.py new file mode 100644 index 00000000..6e8c4c5d --- /dev/null +++ b/whisjax_realtime.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python3 + +import configparser +import pyaudio +from whisper_jax import FlaxWhisperPipline +from pynput import keyboard +import jax.numpy as jnp +import wave 
+
+from datetime import datetime
+from file_utilities import upload_files
+from viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
+from text_utilities import summarize, post_process_transcription
+from loguru import logger
+
+
+config = configparser.ConfigParser()
+config.read('config.ini')
+
+WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
+
+FRAMES_PER_BUFFER = 8000
+FORMAT = pyaudio.paInt16
+CHANNELS = 2
+RATE = 44100
+RECORD_SECONDS = 15
+NOW = datetime.now()
+
+
+def main():
+    p = pyaudio.PyAudio()
+    AUDIO_DEVICE_ID = -1
+    # Locate the aggregate input device configured in config.ini
+    for i in range(p.get_device_count()):
+        if p.get_device_info_by_index(i)["name"] == config["DEFAULT"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]:
+            AUDIO_DEVICE_ID = i
+    audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID)
+    stream = p.open(
+        format=FORMAT,
+        channels=CHANNELS,
+        rate=RATE,
+        input=True,
+        frames_per_buffer=FRAMES_PER_BUFFER,
+        input_device_index=audio_devices['index']
+    )
+
+    pipeline = FlaxWhisperPipline("openai/whisper-" + config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"],
+                                  dtype=jnp.float16,
+                                  batch_size=16)
+
+    transcription = ""
+
+    TEMP_AUDIO_FILE = "temp_audio.wav"
+    global proceed
+    proceed = True
+
+    def on_press(key):
+        if key == keyboard.Key.esc:
+            global proceed
+            proceed = False
+
+    transcript_with_timestamp = {"text": "", "chunks": []}
+    last_transcribed_time = 0.0
+
+    listener = keyboard.Listener(on_press=on_press)
+    listener.start()
+    print("Attempting real-time transcription... Listening...")
+
+    try:
+        while proceed:
+            frames = []
+            for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)):
+                data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False)
+                frames.append(data)
+
+            wf = wave.open(TEMP_AUDIO_FILE, 'wb')
+            wf.setnchannels(CHANNELS)
+            wf.setsampwidth(p.get_sample_size(FORMAT))
+            wf.setframerate(RATE)
+            wf.writeframes(b''.join(frames))
+            wf.close()
+
+            whisper_result = pipeline(TEMP_AUDIO_FILE, return_timestamps=True)
+            print(whisper_result['text'])
+
+            timestamp = whisper_result["chunks"][0]["timestamp"]
+            start = timestamp[0]
+            end = timestamp[1]
+            if end is None:
+                end = start + 15.0
+            duration = end - start
+            item = {'timestamp': (last_transcribed_time, last_transcribed_time + duration),
+                    'text': whisper_result['text']}
+            last_transcribed_time = last_transcribed_time + duration
+            transcript_with_timestamp["chunks"].append(item)
+
+            transcription += whisper_result['text']
+
+    except Exception as e:
+        print(e)
+    finally:
+        with open("real_time_transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
+            f.write(transcription)
+        with open("real_time_transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
+            transcript_with_timestamp["text"] = transcription
+            f.write(str(transcript_with_timestamp))
+
+    transcript_with_timestamp = post_process_transcription(transcript_with_timestamp)
+
+    transcript_text = ""
+    for chunk in transcript_with_timestamp["chunks"]:
+        transcript_text += chunk["text"]
+
+    logger.info("Creating word cloud")
+    create_wordcloud(NOW, True)
+
+    logger.info("Performing talk-diff and talk-diff visualization")
+    create_talk_diff_scatter_viz(NOW, True)
+
+    # S3 : Push artefacts to S3 bucket
+    suffix = NOW.strftime("%m-%d-%Y_%H:%M:%S")
+    files_to_upload = ["real_time_transcript_" + suffix + ".txt",
+                       "real_time_transcript_with_timestamp_" + suffix + ".txt",
+                       "real_time_df_" + suffix + ".pkl",
+                       "real_time_wordcloud_" + suffix + ".png",
+                       "real_time_mappings_" + suffix + ".pkl"]
+    upload_files(files_to_upload)
+
+
summarize(transcript_text, NOW, True, True) + + logger.info("Summarization completed") + + # Summarization takes a lot of time, so do this separately at the end + files_to_upload = ["real_time_summary_" + suffix + ".txt"] + upload_files(files_to_upload) + + +if __name__ == "__main__": + main() diff --git a/whisjax_realtime_trial.py b/whisjax_realtime_trial.py deleted file mode 100644 index 3be1edad..00000000 --- a/whisjax_realtime_trial.py +++ /dev/null @@ -1,84 +0,0 @@ -#!/usr/bin/env python3 - -import configparser -import pyaudio -from whisper_jax import FlaxWhisperPipline -from pynput import keyboard -import jax.numpy as jnp -import wave - -config = configparser.ConfigParser() -config.read('config.ini') - -WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"] - -FRAMES_PER_BUFFER = 8000 -FORMAT = pyaudio.paInt16 -CHANNELS = 2 -RATE = 44100 -RECORD_SECONDS = 15 - - -def main(): - p = pyaudio.PyAudio() - AUDIO_DEVICE_ID = -1 - for i in range(p.get_device_count()): - if p.get_device_info_by_index(i)["name"] == "ref-agg-input": - AUDIO_DEVICE_ID = i - audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID) - stream = p.open( - format=FORMAT, - channels=CHANNELS, - rate=RATE, - input=True, - frames_per_buffer=FRAMES_PER_BUFFER, - input_device_index=audio_devices['index'] - ) - - pipeline = FlaxWhisperPipline("openai/whisper-" + config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"], - dtype=jnp.float16, - batch_size=16) - - transcript_file = open("transcript.txt", "w+") - transcription = "" - - TEMP_AUDIO_FILE = "temp_audio.wav" - global proceed - proceed = True - - def on_press(key): - if key == keyboard.Key.esc: - global proceed - proceed = False - - listener = keyboard.Listener(on_press=on_press) - listener.start() - print("Attempting real-time transcription.. Listening...") - while proceed: - try: - frames = [] - for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)): - data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False) - frames.append(data) - - wf = wave.open(TEMP_AUDIO_FILE, 'wb') - wf.setnchannels(CHANNELS) - wf.setsampwidth(p.get_sample_size(FORMAT)) - wf.setframerate(RATE) - wf.writeframes(b''.join(frames)) - wf.close() - - whisper_result = pipeline(TEMP_AUDIO_FILE, return_timestamps=True) - print(whisper_result['text']) - - transcription += whisper_result['text'] - - except Exception as e: - print(e) - finally: - with open("real_time_transcription.txt", "w") as f: - transcript_file.write(transcription) - - -if __name__ == "__main__": - main()
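
For further exploration of the saved artefacts referenced above, here is a minimal sketch of how they can be reloaded (the ```suffix``` value is a placeholder for the timestamp an actual run produces; real-time runs additionally carry the ```real_time_``` prefix on these filenames):

```python
import pickle
import pandas as pd

suffix = "07-04-2023_10:15:30"  # placeholder; use the timestamp suffix of an actual run

# Chunk-level dataframe with the top-1/top-2 agenda-topic columns added by create_talk_diff_scatter_viz
df = pd.read_pickle(f"df_{suffix}.pkl")

# [ts_to_topic_mapping_top_1, ts_to_topic_mapping_top_2,
#  topic_to_ts_mapping_top_1, topic_to_ts_mapping_top_2]
with open(f"mappings_{suffix}.pkl", "rb") as f:
    mappings = pickle.load(f)

print(df[["timestamp", "text", "ts_to_topic_mapping_top_1"]].head())
```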