Merge pull request #19 from Monadical-SAS/whisper-jax-gokul

Added new features and refactored the codebase to split logic into standalone components
This commit is contained in:
projects-g
2023-06-21 19:14:16 +05:30
committed by GitHub
8 changed files with 530 additions and 427 deletions


@@ -32,11 +32,11 @@ To setup,
5) Run the Whisper-JAX pipeline. Currently, the repo can take a YouTube video and transcribe/summarize it.
``` python3 whisjax.py "https://www.youtube.com/watch?v=ihf0S97oxuQ" --transcript transcript.txt summary.txt ```
``` python3 whisjax.py "https://www.youtube.com/watch?v=ihf0S97oxuQ"```
You can also run it on a local file or a file in your configured S3 bucket.
``` python3 whisjax.py "startup.mp4" --transcript transcript.txt summary.txt ```
``` python3 whisjax.py "startup.mp4"```
The script handles several cases: a YouTube URL, a local file, a video file, an audio-only file,
a file in S3, etc. If the local file is not present, it automatically fetches it from S3, as sketched below.
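For illustration, a rough sketch of that dispatch. The ```resolve_input``` helper is hypothetical, not a function in ```whisjax.py```; ```download_files``` is the repo's S3 helper from ```file_utilities```, but its exact signature is assumed here:
```
# Hypothetical sketch of the input dispatch described above.
import os
from urllib.parse import urlparse
from pytube import YouTube
from file_utilities import download_files  # repo helper; signature assumed

def resolve_input(location: str) -> str:
    """Return a local media path for a YouTube URL, a local file, or an S3 object."""
    parsed = urlparse(location)
    if parsed.scheme in ("http", "https"):
        # YouTube link: download the audio-only stream and return its local path
        return YouTube(location).streams.filter(only_audio=True).first().download()
    if os.path.exists(location):
        return location  # already on disk
    # Otherwise assume it is an object in the configured S3 bucket
    download_files([location])
    return location
```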
@@ -85,7 +85,7 @@ mentioned above or simply use the GUI of AWS Management Console.
1) ```agenda_topic : <short description>```
3) Check all the values in ```config.ini```. You need to predefine the 2 categories used to scatter plot the
topic modelling visualization in the config file. This is the default visualization. But from the dataframe artefact called
```df.pkl``` , you can load the df and choose different topics to plot. You can filter using certain words to search for the
```df_<timestamp>.pkl``` , you can load the df and choose different topics to plot. You can filter using certain words to search for the
transcriptions, and you can see the top influencers and characteristic terms of each topic we have chosen to plot in the
interactive HTML document. I have added a new Jupyter notebook that gives a base template to play around with, named
```Viz_experiments.ipynb```.
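As a starting point for that notebook, a minimal sketch (assuming the artefact was produced by ```viz_utilities.py```; the pickle name, search word, and topic values below are placeholders to replace):
```
import pandas as pd
import scattertext as st

# Load the saved dataframe artefact (placeholder timestamp in the filename).
df = pd.read_pickle("df_<timestamp>.pkl")

# Optionally filter transcript chunks by a search word before plotting.
df = df[df["text"].str.contains("budget", case=False, na=False)]

# Pick any two topics that appear in the top-1 mapping column.
cat_a, cat_b = "agenda_topic_a", "agenda_topic_b"  # placeholders

df = df.assign(parse=lambda d: d["text"].apply(st.whitespace_nlp_with_sentences))
corpus = (st.CorpusFromParsedDocuments(df, category_col="ts_to_topic_mapping_top_1",
                                       parsed_col="parse")
          .build().get_unigram_corpus().compact(st.AssociationCompactor(2000)))
html = st.produce_scattertext_explorer(corpus,
                                       category=cat_a, category_name=cat_a,
                                       not_category_name=cat_b,
                                       minimum_term_frequency=0, pmi_threshold_coefficient=0,
                                       width_in_pixels=1000, transform=st.Scalers.dense_rank)
with open("scatter_custom.html", "w") as f:
    f.write(html)
```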
@@ -123,24 +123,32 @@ microphone input which you will be using for speaking. We use [Blackhole](https:
2) Setup [Aggregate device](https://github.com/ExistentialAudio/BlackHole/wiki/Aggregate-Device) to route web audio and
local microphone input.
Be sure to mirror the settings given (including the name) ![here](./images/aggregate_input.png)
Be sure to mirror the settings given ![here](./images/aggregate_input.png)
3) Setup [Multi-Output device](https://github.com/ExistentialAudio/BlackHole/wiki/Multi-Output-Device)
Refer to the settings shown ![here](./images/multi-output.png)
4) Set the name of the aggregate input device created in step 2 in ```config.ini``` as ```BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME``` (the snippet after these steps shows how to verify the name PyAudio reports)
Then goto ``` System Preferences -> Sound ``` and choose the devices created from the Output and
5) Then go to ``` System Preferences -> Sound ``` and choose the devices created from the Output and
Input tabs.
From the reflector root folder,
run ```python3 whisjax_realtime_trial.py```
6) If everything is configured properly, the input from your local microphone and the browser-run meeting should be
aggregated into one virtual stream to listen to, and the output should be fed back to your specified output devices.
Check this before trying out the trial.
**Permissions:**
You may have to grant microphone access to "Terminal"/code editors [PyCharm/VSCode, etc.] so they can record audio in
```System Preferences -> Privacy & Security -> Microphone``` and in
```System Preferences -> Privacy & Security -> Accessibility```.
```System Preferences -> Privacy & Security -> Microphone```,
```System Preferences -> Privacy & Security -> Accessibility```,
```System Preferences -> Privacy & Security -> Input Monitoring```.
From the reflector root folder,
run ```python3 whisjax_realtime.py```.
The transcription text should be written to ```real_time_transcript_<timestamp>.txt```.
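If the real-time script cannot find your device, you can list the input devices PyAudio sees and compare them against the config value. A minimal check, assuming only ```pyaudio``` and the repo's ```config.ini```:
```
# List audio devices and flag the one matching BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME.
import configparser
import pyaudio

config = configparser.ConfigParser()
config.read("config.ini")
target = config["DEFAULT"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]

p = pyaudio.PyAudio()
for i in range(p.get_device_count()):
    name = p.get_device_info_by_index(i)["name"]
    print(f"{i}: {name}" + ("   <-- matches config" if name == target else ""))
p.terminate()
```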
NEXT STEPS:


@@ -16,4 +16,6 @@ INPUT_ENCODING_MAX_LENGTH=1024
MAX_LENGTH=2048
BEAM_SIZE=6
MAX_CHUNK_LENGTH=1024
SUMMARIZE_USING_CHUNKS=YES
# Audio device
BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME=ref-agg-input

text_utilities.py Normal file

@@ -0,0 +1,160 @@
import torch
import configparser
import nltk
from transformers import BartTokenizer, BartForConditionalGeneration
from loguru import logger
from nltk.corpus import stopwords
from sklearn.feature_extraction.text import TfidfVectorizer
from nltk.tokenize import word_tokenize
from sklearn.metrics.pairwise import cosine_similarity
config = configparser.ConfigParser()
config.read('config.ini')
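# Near-duplicate removal helpers: sentences are lower-cased, stripped of stop words,
# vectorized with TF-IDF, and compared via cosine similarity.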
def preprocess_sentence(sentence):
stop_words = set(stopwords.words('english'))
tokens = word_tokenize(sentence.lower())
tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
return ' '.join(tokens)
def compute_similarity(sent1, sent2):
tfidf_vectorizer = TfidfVectorizer()
tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2])
return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
def remove_almost_alike_sentences(sentences, threshold=0.7):
num_sentences = len(sentences)
removed_indices = set()
for i in range(num_sentences):
if i not in removed_indices:
for j in range(i + 1, num_sentences):
if j not in removed_indices:
sentence1 = preprocess_sentence(sentences[i])
sentence2 = preprocess_sentence(sentences[j])
similarity = compute_similarity(sentence1, sentence2)
if similarity >= threshold:
removed_indices.add(max(i, j))
filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices]
return filtered_sentences
def remove_outright_duplicate_sentences_from_chunk(chunk):
chunk_text = chunk["text"]
sentences = nltk.sent_tokenize(chunk_text)
nonduplicate_sentences = list(dict.fromkeys(sentences))
return nonduplicate_sentences
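# Collapse Whisper's repetitive-hallucination runs: a word is dropped when the trigram
# starting at it has already been seen followed by the same continuation of words.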
def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
chunk_sentences = []
for sent in nonduplicate_sentences:
temp_result = ""
seen = {}
words = nltk.word_tokenize(sent)
n_gram_filter = 3
for i in range(len(words)):
if str(words[i:i + n_gram_filter]) in seen and seen[str(words[i:i + n_gram_filter])] == words[
i + 1:i + n_gram_filter + 2]:
pass
else:
seen[str(words[i:i + n_gram_filter])] = words[i + 1:i + n_gram_filter + 2]
temp_result += words[i]
temp_result += " "
chunk_sentences.append(temp_result)
return chunk_sentences
def post_process_transcription(whisper_result):
for chunk in whisper_result["chunks"]:
nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk)
chunk_sentences = remove_whisper_repetitive_hallucination(nonduplicate_sentences)
similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences)
chunk["text"] = " ".join(similarity_matched_sentences)
return whisper_result
def summarize(transcript_text, timestamp,
real_time=False, summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
if not summary_model:
summary_model = "facebook/bart-large-cnn"
# Summarize the generated transcript using the BART model
logger.info(f"Loading BART model: {summary_model}")
tokenizer = BartTokenizer.from_pretrained(summary_model)
model = BartForConditionalGeneration.from_pretrained(summary_model)
model = model.to(device)
output_filename = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
if real_time:
output_filename = "real_time_" + output_filename
if summarize_using_chunks != "YES":
inputs = tokenizer.batch_encode_plus([transcript_text], truncation=True, padding='longest',
max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]),
return_tensors='pt')
inputs = inputs.to(device)
with torch.no_grad():
summaries = model.generate(inputs['input_ids'],
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) for
summary in summaries]
summary = " ".join(decoded_summaries)
with open(output_filename, 'w') as f:
f.write(summary.strip() + "\n")
else:
logger.info("Breaking transcript into smaller chunks")
chunks = chunk_text(transcript_text)
logger.info(f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable
logger.info(f"Writing summary text to: {output_filename}")
with open(output_filename, 'w') as f:
summaries = summarize_chunks(chunks, tokenizer, model)
for summary in summaries:
f.write(summary.strip() + " ")
def summarize_chunks(chunks, tokenizer, model):
"""
Summarize each chunk using a summarizer model
:param chunks:
:param tokenizer:
:param model:
:return:
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summaries = []
for c in chunks:
input_ids = tokenizer.encode(c, return_tensors='pt')
input_ids = input_ids.to(device)
with torch.no_grad():
summary_ids = model.generate(input_ids,
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
summaries.append(summary)
return summaries
def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
"""
Split text into smaller chunks.
    :param text: Text to be chunked
    :param max_chunk_length: maximum chunk length in characters
:return: chunked texts
"""
sentences = nltk.sent_tokenize(text)
chunks = []
current_chunk = ""
for sentence in sentences:
if len(current_chunk) + len(sentence) < max_chunk_length:
current_chunk += f" {sentence.strip()}"
else:
chunks.append(current_chunk.strip())
current_chunk = f"{sentence.strip()}"
chunks.append(current_chunk.strip())
return chunks

viz_utilities.py Normal file

@@ -0,0 +1,190 @@
import matplotlib.pyplot as plt
from wordcloud import WordCloud, STOPWORDS
import collections
import spacy
import pickle
import ast
import pandas as pd
import scattertext as st
import configparser
config = configparser.ConfigParser()
config.read('config.ini')
def create_wordcloud(timestamp, real_time=False):
"""
Create a basic word cloud visualization of transcribed text
:return: None. The wordcloud image is saved locally
"""
filename = "transcript"
if real_time:
filename = "real_time_" + filename + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
else:
filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
with open(filename, "r") as f:
transcription_text = f.read()
stopwords = set(STOPWORDS)
# python_mask = np.array(PIL.Image.open("download1.png"))
wordcloud = WordCloud(height=800, width=800,
background_color='white',
stopwords=stopwords,
min_font_size=8).generate(transcription_text)
# Plot wordcloud and save image
plt.figure(facecolor=None)
plt.imshow(wordcloud, interpolation="bilinear")
plt.axis("off")
plt.tight_layout(pad=0)
wordcloud_name = "wordcloud"
if real_time:
wordcloud_name = "real_time_" + wordcloud_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
else:
wordcloud_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
plt.savefig(wordcloud_name)
def create_talk_diff_scatter_viz(timestamp, real_time=False):
"""
    Perform an agenda vs transcription diff to see covered topics.
Create a scatter plot of words in topics.
:return: None. Saved locally.
"""
spaCy_model = "en_core_web_md"
nlp = spacy.load(spaCy_model)
nlp.add_pipe('sentencizer')
agenda_topics = []
agenda = []
# Load the agenda
with open("agenda-headers.txt", "r") as f:
for line in f.readlines():
if line.strip():
agenda.append(line.strip())
agenda_topics.append(line.split(":")[0])
    # Load the transcription with timestamps; real-time runs write the "real_time_" prefixed file
    transcript_ts_name = "transcript_with_timestamp_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
    if real_time:
        transcript_ts_name = "real_time_" + transcript_ts_name
    with open(transcript_ts_name) as f:
        transcription_timestamp_text = f.read()
res = ast.literal_eval(transcription_timestamp_text)
chunks = res["chunks"]
# create df for processing
df = pd.DataFrame.from_dict(res["chunks"])
covered_items = {}
# ts: timestamp
# Map each timestamped chunk with top1 and top2 matched agenda
ts_to_topic_mapping_top_1 = {}
ts_to_topic_mapping_top_2 = {}
# Also create a mapping of the different timestamps in which each topic was covered
topic_to_ts_mapping_top_1 = collections.defaultdict(list)
topic_to_ts_mapping_top_2 = collections.defaultdict(list)
similarity_threshold = 0.7
for c in chunks:
doc_transcription = nlp(c["text"])
topic_similarities = []
for item in range(len(agenda)):
item_doc = nlp(agenda[item])
# if not doc_transcription or not all(token.has_vector for token in doc_transcription):
if not doc_transcription:
continue
similarity = doc_transcription.similarity(item_doc)
topic_similarities.append((item, similarity))
topic_similarities.sort(key=lambda x: x[1], reverse=True)
for i in range(2):
if topic_similarities[i][1] >= similarity_threshold:
covered_items[agenda[topic_similarities[i][0]]] = True
# top1 match
if i == 0:
ts_to_topic_mapping_top_1[c["timestamp"]] = agenda_topics[topic_similarities[i][0]]
topic_to_ts_mapping_top_1[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"])
# top2 match
else:
ts_to_topic_mapping_top_2[c["timestamp"]] = agenda_topics[topic_similarities[i][0]]
topic_to_ts_mapping_top_2[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"])
def create_new_columns(record):
"""
Accumulate the mapping information into the df
:param record:
:return:
"""
record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[record["timestamp"]]
record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[record["timestamp"]]
return record
df = df.apply(create_new_columns, axis=1)
    # Count the number of items covered and calculate the percentage
num_covered_items = sum(covered_items.values())
percentage_covered = num_covered_items / len(agenda) * 100
# Print the results
print("💬 Agenda items covered in the transcription:")
for item in agenda:
if item in covered_items and covered_items[item]:
print("", item)
else:
print("", item)
print("📊 Coverage: {:.2f}%".format(percentage_covered))
# Save df, mappings for further experimentation
df_name = "df"
if real_time:
df_name = "real_time_" + df_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
else:
df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
df.to_pickle(df_name)
my_mappings = [ts_to_topic_mapping_top_1, ts_to_topic_mapping_top_2,
topic_to_ts_mapping_top_1, topic_to_ts_mapping_top_2]
mappings_name = "mappings"
if real_time:
mappings_name = "real_time_" + mappings_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
else:
mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
pickle.dump(my_mappings, open(mappings_name, "wb"))
# to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") )
# pick the 2 most matched topic to be used for plotting
topic_times = collections.defaultdict(int)
for key in ts_to_topic_mapping_top_1.keys():
if key[0] is None or key[1] is None:
continue
duration = key[1] - key[0]
topic_times[ts_to_topic_mapping_top_1[key]] += duration
topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True)
cat_1 = topic_times[0][0]
cat_1_name = topic_times[0][0]
cat_2_name = topic_times[1][0]
# Scatter plot of topics
df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences))
corpus = st.CorpusFromParsedDocuments(
df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse'
).build().get_unigram_corpus().compact(st.AssociationCompactor(2000))
html = st.produce_scattertext_explorer(
corpus,
category=cat_1,
category_name=cat_1_name,
not_category_name=cat_2_name,
minimum_term_frequency=0, pmi_threshold_coefficient=0,
width_in_pixels=1000,
transform=st.Scalers.dense_rank
)
open('./scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)


@@ -5,35 +5,26 @@
# summarize podcast.mp3 summary.txt
import argparse
import ast
import torch
import collections
import configparser
import jax.numpy as jnp
import matplotlib.pyplot as plt
import moviepy.editor
import nltk
import os
import subprocess
import pandas as pd
import pickle
import re
import scattertext as st
import spacy
import tempfile
from loguru import logger
from pytube import YouTube
from transformers import BartTokenizer, BartForConditionalGeneration
from urllib.parse import urlparse
from whisper_jax import FlaxWhisperPipline
from wordcloud import WordCloud, STOPWORDS
from nltk.corpus import stopwords
from sklearn.feature_extraction.text import TfidfVectorizer
from nltk.tokenize import word_tokenize
from sklearn.metrics.pairwise import cosine_similarity
from file_util import upload_files, download_files
from datetime import datetime
from file_utilities import upload_files, download_files
from viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
from text_utilities import summarize, post_process_transcription
nltk.download('punkt')
nltk.download('stopwords')
@@ -43,7 +34,7 @@ config = configparser.ConfigParser()
config.read('config.ini')
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
NOW = datetime.now()
def init_argparse() -> argparse.ArgumentParser:
"""
@@ -57,310 +48,10 @@ def init_argparse() -> argparse.ArgumentParser:
parser.add_argument("-l", "--language", help="Language that the summary should be written in", type=str,
default="english", choices=['english', 'spanish', 'french', 'german', 'romanian'])
parser.add_argument("-t", "--transcript", help="Save a copy of the intermediary transcript file", type=str)
parser.add_argument("location")
parser.add_argument("output")
return parser
def chunk_text(txt, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
"""
Split text into smaller chunks.
:param txt: Text to be chunked
:param max_chunk_length: length of chunk
:return: chunked texts
"""
sentences = nltk.sent_tokenize(txt)
chunks = []
current_chunk = ""
for sentence in sentences:
if len(current_chunk) + len(sentence) < max_chunk_length:
current_chunk += f" {sentence.strip()}"
else:
chunks.append(current_chunk.strip())
current_chunk = f"{sentence.strip()}"
chunks.append(current_chunk.strip())
return chunks
def summarize_chunks(chunks, tokenizer, model):
"""
Summarize each chunk using a summarizer model
:param chunks:
:param tokenizer:
:param model:
:return:
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summaries = []
for c in chunks:
input_ids = tokenizer.encode(c, return_tensors='pt')
input_ids = input_ids.to(device)
with torch.no_grad():
summary_ids = model.generate(input_ids,
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
summaries.append(summary)
return summaries
def create_wordcloud():
"""
Create a basic word cloud visualization of transcribed text
:return: None. The wordcloud image is saved locally
"""
with open("transcript.txt", "r") as f:
transcription_text = f.read()
stopwords = set(STOPWORDS)
# python_mask = np.array(PIL.Image.open("download1.png"))
wordcloud = WordCloud(height=800, width=800,
background_color='white',
stopwords=stopwords,
min_font_size=8).generate(transcription_text)
# Plot wordcloud and save image
plt.figure(facecolor=None)
plt.imshow(wordcloud, interpolation="bilinear")
plt.axis("off")
plt.tight_layout(pad=0)
plt.savefig("wordcloud.png")
def create_talk_diff_scatter_viz():
"""
    Perform an agenda vs transcription diff to see covered topics.
Create a scatter plot of words in topics.
:return: None. Saved locally.
"""
spaCy_model = "en_core_web_md"
nlp = spacy.load(spaCy_model)
nlp.add_pipe('sentencizer')
agenda_topics = []
agenda = []
# Load the agenda
with open("agenda-headers.txt", "r") as f:
for line in f.readlines():
if line.strip():
agenda.append(line.strip())
agenda_topics.append(line.split(":")[0])
# Load the transcription with timestamp
with open("transcript_timestamps.txt", "r") as f:
transcription_timestamp_text = f.read()
res = ast.literal_eval(transcription_timestamp_text)
chunks = res["chunks"]
# create df for processing
df = pd.DataFrame.from_dict(res["chunks"])
covered_items = {}
# ts: timestamp
# Map each timestamped chunk with top1 and top2 matched agenda
ts_to_topic_mapping_top_1 = {}
ts_to_topic_mapping_top_2 = {}
# Also create a mapping of the different timestamps in which each topic was covered
topic_to_ts_mapping_top_1 = collections.defaultdict(list)
topic_to_ts_mapping_top_2 = collections.defaultdict(list)
similarity_threshold = 0.7
for c in chunks:
doc_transcription = nlp(c["text"])
topic_similarities = []
for item in range(len(agenda)):
item_doc = nlp(agenda[item])
# if not doc_transcription or not all(token.has_vector for token in doc_transcription):
if not doc_transcription:
continue
similarity = doc_transcription.similarity(item_doc)
topic_similarities.append((item, similarity))
topic_similarities.sort(key=lambda x: x[1], reverse=True)
for i in range(2):
if topic_similarities[i][1] >= similarity_threshold:
covered_items[agenda[topic_similarities[i][0]]] = True
# top1 match
if i == 0:
ts_to_topic_mapping_top_1[c["timestamp"]] = agenda_topics[topic_similarities[i][0]]
topic_to_ts_mapping_top_1[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"])
# top2 match
else:
ts_to_topic_mapping_top_2[c["timestamp"]] = agenda_topics[topic_similarities[i][0]]
topic_to_ts_mapping_top_2[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"])
def create_new_columns(record):
"""
Accumulate the mapping information into the df
:param record:
:return:
"""
record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[record["timestamp"]]
record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[record["timestamp"]]
return record
df = df.apply(create_new_columns, axis=1)
    # Count the number of items covered and calculate the percentage
num_covered_items = sum(covered_items.values())
percentage_covered = num_covered_items / len(agenda) * 100
# Print the results
print("💬 Agenda items covered in the transcription:")
for item in agenda:
if item in covered_items and covered_items[item]:
print("", item)
else:
print("", item)
print("📊 Coverage: {:.2f}%".format(percentage_covered))
# Save df, mappings for further experimentation
df.to_pickle("df.pkl")
my_mappings = [ts_to_topic_mapping_top_1, ts_to_topic_mapping_top_2,
topic_to_ts_mapping_top_1, topic_to_ts_mapping_top_2]
pickle.dump(my_mappings, open("mappings.pkl", "wb"))
# to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") )
# pick the 2 most matched topic to be used for plotting
topic_times = collections.defaultdict(int)
for key in ts_to_topic_mapping_top_1.keys():
if key[0] is None or key[1] is None:
continue
duration = key[1] - key[0]
topic_times[ts_to_topic_mapping_top_1[key]] += duration
topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True)
cat_1 = topic_times[0][0]
cat_1_name = topic_times[0][0]
cat_2_name = topic_times[1][0]
# Scatter plot of topics
df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences))
corpus = st.CorpusFromParsedDocuments(
df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse'
).build().get_unigram_corpus().compact(st.AssociationCompactor(2000))
html = st.produce_scattertext_explorer(
corpus,
category=cat_1,
category_name=cat_1_name,
not_category_name=cat_2_name,
minimum_term_frequency=0, pmi_threshold_coefficient=0,
width_in_pixels=1000,
transform=st.Scalers.dense_rank
)
open('./demo_compact.html', 'w').write(html)
def preprocess_sentence(sentence):
stop_words = set(stopwords.words('english'))
tokens = word_tokenize(sentence.lower())
tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
return ' '.join(tokens)
def compute_similarity(sent1, sent2):
tfidf_vectorizer = TfidfVectorizer()
tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2])
return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
def remove_almost_alike_sentences(sentences, threshold=0.7):
num_sentences = len(sentences)
removed_indices = set()
for i in range(num_sentences):
if i not in removed_indices:
for j in range(i + 1, num_sentences):
if j not in removed_indices:
sentence1 = preprocess_sentence(sentences[i])
sentence2 = preprocess_sentence(sentences[j])
similarity = compute_similarity(sentence1, sentence2)
if similarity >= threshold:
removed_indices.add(max(i, j))
filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices]
return filtered_sentences
def remove_outright_duplicate_sentences_from_chunk(chunk):
chunk_text = chunk["text"]
sentences = nltk.sent_tokenize(chunk_text)
nonduplicate_sentences = list(dict.fromkeys(sentences))
return nonduplicate_sentences
def remove_whisper_repititive_hallucination(nonduplicate_sentences):
chunk_sentences = []
for sent in nonduplicate_sentences:
temp_result = ""
seen = {}
words = nltk.word_tokenize(sent)
n_gram_filter = 3
for i in range(len(words)):
if str(words[i:i + n_gram_filter]) in seen and seen[str(words[i:i + n_gram_filter])] == words[
i + 1:i + n_gram_filter + 2]:
pass
else:
seen[str(words[i:i + n_gram_filter])] = words[i + 1:i + n_gram_filter + 2]
temp_result += words[i]
temp_result += " "
chunk_sentences.append(temp_result)
return chunk_sentences
def remove_duplicates_from_transcript_chunk(whisper_result):
for chunk in whisper_result["chunks"]:
nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk)
chunk_sentences = remove_whisper_repititive_hallucination(nonduplicate_sentences)
similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences)
chunk["text"] = " ".join(similarity_matched_sentences)
return whisper_result
def summarize(transcript_text, output_file,
summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
if not summary_model:
summary_model = "facebook/bart-large-cnn"
# Summarize the generated transcript using the BART model
logger.info(f"Loading BART model: {summary_model}")
tokenizer = BartTokenizer.from_pretrained(summary_model)
model = BartForConditionalGeneration.from_pretrained(summary_model)
model = model.to(device)
if summarize_using_chunks != "YES":
inputs = tokenizer.batch_encode_plus([transcript_text], truncation=True, padding='longest',
max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]),
return_tensors='pt')
inputs = inputs.to(device)
with torch.no_grad():
summaries = model.generate(inputs['input_ids'],
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) for
summary in summaries]
summary = " ".join(decoded_summaries)
with open(output_file, 'w') as f:
f.write(summary.strip() + "\n\n")
else:
logger.info("Breaking transcript into smaller chunks")
chunks = chunk_text(transcript_text)
logger.info(f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable
logger.info(f"Writing summary text to: {output_file}")
with open(output_file, 'w') as f:
summaries = summarize_chunks(chunks, tokenizer, model)
for summary in summaries:
f.write(summary.strip() + " ")
def main():
parser = init_argparse()
@@ -425,41 +116,40 @@ def main():
whisper_result = pipeline(audio_filename, return_timestamps=True)
logger.info("Finished transcribing file")
whisper_result = remove_duplicates_from_transcript_chunk(whisper_result)
whisper_result = post_process_transcription(whisper_result)
transcript_text = ""
for chunk in whisper_result["chunks"]:
transcript_text += chunk["text"]
# If we got the transcript parameter on the command line,
# save the transcript to the specified file.
if args.transcript:
logger.info(f"Saving transcript to: {args.transcript}")
transcript_file = open(args.transcript, "w")
transcript_file_timestamps = open(args.transcript[0:len(args.transcript) - 4] + "_timestamps.txt", "w")
with open("transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file:
transcript_file.write(transcript_text)
with open("transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file_timestamps:
transcript_file_timestamps.write(str(whisper_result))
transcript_file.close()
transcript_file_timestamps.close()
logger.info("Creating word cloud")
create_wordcloud()
create_wordcloud(NOW)
logger.info("Performing talk-diff and talk-diff visualization")
create_talk_diff_scatter_viz()
create_talk_diff_scatter_viz(NOW)
# S3 : Push artefacts to S3 bucket
files_to_upload = ["transcript.txt", "transcript_timestamps.txt",
"df.pkl",
"wordcloud.png", "mappings.pkl"]
suffix = NOW.strftime("%m-%d-%Y_%H:%M:%S")
files_to_upload = ["transcript_" + suffix + ".txt",
"transcript_with_timestamp_" + suffix + ".txt",
"df_" + suffix + ".pkl",
"wordcloud_" + suffix + ".png",
"mappings_" + suffix + ".pkl"]
upload_files(files_to_upload)
summarize(transcript_text, args.output)
summarize(transcript_text, NOW, False, False)
logger.info("Summarization completed")
# Summarization takes a lot of time, so do this separately at the end
files_to_upload = ["summary.txt"]
files_to_upload = ["summary_" + suffix + ".txt"]
upload_files(files_to_upload)

whisjax_realtime.py Normal file

@@ -0,0 +1,137 @@
#!/usr/bin/env python3
import configparser
import pyaudio
from whisper_jax import FlaxWhisperPipline
from pynput import keyboard
import jax.numpy as jnp
import wave
from datetime import datetime
from file_utilities import upload_files
from viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
from text_utilities import summarize, post_process_transcription
from loguru import logger
config = configparser.ConfigParser()
config.read('config.ini')
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
FRAMES_PER_BUFFER = 8000
FORMAT = pyaudio.paInt16
CHANNELS = 2
RATE = 44100
RECORD_SECONDS = 15
NOW = datetime.now()
def main():
p = pyaudio.PyAudio()
AUDIO_DEVICE_ID = -1
for i in range(p.get_device_count()):
if p.get_device_info_by_index(i)["name"] == config["DEFAULT"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]:
AUDIO_DEVICE_ID = i
audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID)
stream = p.open(
format=FORMAT,
channels=CHANNELS,
rate=RATE,
input=True,
frames_per_buffer=FRAMES_PER_BUFFER,
input_device_index=audio_devices['index']
)
pipeline = FlaxWhisperPipline("openai/whisper-" + config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"],
dtype=jnp.float16,
batch_size=16)
transcription = ""
TEMP_AUDIO_FILE = "temp_audio.wav"
global proceed
proceed = True
def on_press(key):
if key == keyboard.Key.esc:
global proceed
proceed = False
transcript_with_timestamp = {"text": "", "chunks": []}
last_transcribed_time = 0.0
listener = keyboard.Listener(on_press=on_press)
listener.start()
print("Attempting real-time transcription.. Listening...")
try:
while proceed:
frames = []
for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)):
data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False)
frames.append(data)
wf = wave.open(TEMP_AUDIO_FILE, 'wb')
wf.setnchannels(CHANNELS)
wf.setsampwidth(p.get_sample_size(FORMAT))
wf.setframerate(RATE)
wf.writeframes(b''.join(frames))
wf.close()
whisper_result = pipeline(TEMP_AUDIO_FILE, return_timestamps=True)
print(whisper_result['text'])
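            # Whisper timestamps restart at 0 for every 15-second capture, so stitch
            # each chunk onto a running timeline using the accumulated duration.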
timestamp = whisper_result["chunks"][0]["timestamp"]
start = timestamp[0]
end = timestamp[1]
if end is None:
end = start + 15.0
duration = end - start
item = {'timestamp': (last_transcribed_time, last_transcribed_time + duration),
'text': whisper_result['text']}
last_transcribed_time = last_transcribed_time + duration
transcript_with_timestamp["chunks"].append(item)
transcription += whisper_result['text']
except Exception as e:
print(e)
finally:
with open("real_time_transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
f.write(transcription)
with open("real_time_transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
transcript_with_timestamp["text"] = transcription
f.write(str(transcript_with_timestamp))
transcript_with_timestamp = post_process_transcription(transcript_with_timestamp)
transcript_text = ""
for chunk in transcript_with_timestamp["chunks"]:
transcript_text += chunk["text"]
logger.info("Creating word cloud")
create_wordcloud(NOW, True)
logger.info("Performing talk-diff and talk-diff visualization")
create_talk_diff_scatter_viz(NOW, True)
# S3 : Push artefacts to S3 bucket
suffix = NOW.strftime("%m-%d-%Y_%H:%M:%S")
files_to_upload = ["real_time_transcript_" + suffix + ".txt",
"real_time_transcript_with_timestamp" + suffix + ".txt",
"real_time_df_" + suffix + ".pkl",
"real_time_wordcloud_" + suffix + ".png",
"real_time_mappings_" + suffix + ".pkl"]
upload_files(files_to_upload)
summarize(transcript_text, NOW, True, True)
logger.info("Summarization completed")
# Summarization takes a lot of time, so do this separately at the end
files_to_upload = ["real_time_summary_" + suffix + ".txt"]
upload_files(files_to_upload)
if __name__ == "__main__":
main()


@@ -1,84 +0,0 @@
#!/usr/bin/env python3
import configparser
import pyaudio
from whisper_jax import FlaxWhisperPipline
from pynput import keyboard
import jax.numpy as jnp
import wave
config = configparser.ConfigParser()
config.read('config.ini')
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
FRAMES_PER_BUFFER = 8000
FORMAT = pyaudio.paInt16
CHANNELS = 2
RATE = 44100
RECORD_SECONDS = 15
def main():
p = pyaudio.PyAudio()
AUDIO_DEVICE_ID = -1
for i in range(p.get_device_count()):
if p.get_device_info_by_index(i)["name"] == "ref-agg-input":
AUDIO_DEVICE_ID = i
audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID)
stream = p.open(
format=FORMAT,
channels=CHANNELS,
rate=RATE,
input=True,
frames_per_buffer=FRAMES_PER_BUFFER,
input_device_index=audio_devices['index']
)
pipeline = FlaxWhisperPipline("openai/whisper-" + config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"],
dtype=jnp.float16,
batch_size=16)
transcript_file = open("transcript.txt", "w+")
transcription = ""
TEMP_AUDIO_FILE = "temp_audio.wav"
global proceed
proceed = True
def on_press(key):
if key == keyboard.Key.esc:
global proceed
proceed = False
listener = keyboard.Listener(on_press=on_press)
listener.start()
print("Attempting real-time transcription.. Listening...")
while proceed:
try:
frames = []
for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)):
data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False)
frames.append(data)
wf = wave.open(TEMP_AUDIO_FILE, 'wb')
wf.setnchannels(CHANNELS)
wf.setsampwidth(p.get_sample_size(FORMAT))
wf.setframerate(RATE)
wf.writeframes(b''.join(frames))
wf.close()
whisper_result = pipeline(TEMP_AUDIO_FILE, return_timestamps=True)
print(whisper_result['text'])
transcription += whisper_result['text']
except Exception as e:
print(e)
finally:
with open("real_time_transcription.txt", "w") as f:
transcript_file.write(transcription)
if __name__ == "__main__":
main()