mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 04:39:06 +00:00
Merge pull request #19 from Monadical-SAS/whisper-jax-gokul
Added new features and refactored codebase to split logic into standalone components
This commit is contained in:
28
README.md
28
README.md
@@ -32,11 +32,11 @@ To setup,
|
|||||||
|
|
||||||
5) Run the Whisper-JAX pipeline. Currently, the repo can take a Youtube video and transcribes/summarizes it.
|
5) Run the Whisper-JAX pipeline. Currently, the repo can take a Youtube video and transcribes/summarizes it.
|
||||||
|
|
||||||
``` python3 whisjax.py "https://www.youtube.com/watch?v=ihf0S97oxuQ" --transcript transcript.txt summary.txt ```
|
``` python3 whisjax.py "https://www.youtube.com/watch?v=ihf0S97oxuQ"```
|
||||||
|
|
||||||
You can even run it on local file or a file in your configured S3 bucket.
|
You can even run it on local file or a file in your configured S3 bucket.
|
||||||
|
|
||||||
``` python3 whisjax.py "startup.mp4" --transcript transcript.txt summary.txt ```
|
``` python3 whisjax.py "startup.mp4"```
|
||||||
|
|
||||||
The script will take care of a few cases like youtube file, local file, video file, audio-only file,
|
The script will take care of a few cases like youtube file, local file, video file, audio-only file,
|
||||||
file in S3, etc. If local file is not present, it can automatically take the file from S3.
|
file in S3, etc. If local file is not present, it can automatically take the file from S3.
|
||||||
@@ -85,7 +85,7 @@ mentioned above or simply use the GUI of AWS Management Console.
|
|||||||
1) ```agenda_topic : <short description>```
|
1) ```agenda_topic : <short description>```
|
||||||
3) Check all the values in ```config.ini```. You need to predefine 2 categories for which you need to scatter plot the
|
3) Check all the values in ```config.ini```. You need to predefine 2 categories for which you need to scatter plot the
|
||||||
topic modelling visualization in the config file. This is the default visualization. But, from the dataframe artefact called
|
topic modelling visualization in the config file. This is the default visualization. But, from the dataframe artefact called
|
||||||
```df.pkl``` , you can load the df and choose different topics to plot. You can filter using certain words to search for the
|
```df_<timestamp>.pkl``` , you can load the df and choose different topics to plot. You can filter using certain words to search for the
|
||||||
transcriptions and you can see the top influencers and characteristic in each topic we have chosen to plot in the
|
transcriptions and you can see the top influencers and characteristic in each topic we have chosen to plot in the
|
||||||
interactive HTML document. I have added a new jupyter notebook that gives the base template to play around with, named
|
interactive HTML document. I have added a new jupyter notebook that gives the base template to play around with, named
|
||||||
```Viz_experiments.ipynb```.
|
```Viz_experiments.ipynb```.
|
||||||
@@ -123,24 +123,32 @@ microphone input which you will be using for speaking. We use [Blackhole](https:
|
|||||||
2) Setup [Aggregate device](https://github.com/ExistentialAudio/BlackHole/wiki/Aggregate-Device) to route web audio and
|
2) Setup [Aggregate device](https://github.com/ExistentialAudio/BlackHole/wiki/Aggregate-Device) to route web audio and
|
||||||
local microphone input.
|
local microphone input.
|
||||||
|
|
||||||
Be sure to mirror the settings given (including the name) 
|
Be sure to mirror the settings given 
|
||||||
3) Setup [Multi-Output device](https://github.com/ExistentialAudio/BlackHole/wiki/Multi-Output-Device)
|
3) Setup [Multi-Output device](https://github.com/ExistentialAudio/BlackHole/wiki/Multi-Output-Device)
|
||||||
|
|
||||||
Refer 
|
Refer 
|
||||||
|
|
||||||
|
4) Set the aggregator input device name created in step 2 in config.ini as ```BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME```
|
||||||
|
|
||||||
Then goto ``` System Preferences -> Sound ``` and choose the devices created from the Output and
|
5) Then goto ``` System Preferences -> Sound ``` and choose the devices created from the Output and
|
||||||
Input tabs.
|
Input tabs.
|
||||||
|
|
||||||
From the reflector root folder,
|
6) The input from your local microphone, the browser run meeting should be aggregated into one virtual stream to listen to
|
||||||
|
and the output should be fed back to your specified output devices if everything is configured properly. Check this
|
||||||
run ```python3 whisjax_realtime_trial.py```
|
before trying out the trial.
|
||||||
|
|
||||||
**Permissions:**
|
**Permissions:**
|
||||||
|
|
||||||
You may have to add permission for "Terminal"/Code Editors [Pycharm/VSCode, etc.] microphone access to record audio in
|
You may have to add permission for "Terminal"/Code Editors [Pycharm/VSCode, etc.] microphone access to record audio in
|
||||||
```System Preferences -> Privacy & Security -> Microphone``` and in
|
```System Preferences -> Privacy & Security -> Microphone```,
|
||||||
```System Preferences -> Privacy & Security -> Accessibility```.
|
```System Preferences -> Privacy & Security -> Accessibility```,
|
||||||
|
```System Preferences -> Privacy & Security -> Input Monitoring```.
|
||||||
|
|
||||||
|
From the reflector root folder,
|
||||||
|
|
||||||
|
run ```python3 whisjax_realtime.py```
|
||||||
|
|
||||||
|
The transcription text should be written to ```real_time_transcription_<timestamp>.txt```.
|
||||||
|
|
||||||
|
|
||||||
NEXT STEPS:
|
NEXT STEPS:
|
||||||
|
|||||||
@@ -17,3 +17,5 @@ MAX_LENGTH=2048
|
|||||||
BEAM_SIZE=6
|
BEAM_SIZE=6
|
||||||
MAX_CHUNK_LENGTH=1024
|
MAX_CHUNK_LENGTH=1024
|
||||||
SUMMARIZE_USING_CHUNKS=YES
|
SUMMARIZE_USING_CHUNKS=YES
|
||||||
|
# Audio device
|
||||||
|
BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME=ref-agg-input
|
||||||
160
text_utilities.py
Normal file
160
text_utilities.py
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
import torch
|
||||||
|
import configparser
|
||||||
|
import nltk
|
||||||
|
from transformers import BartTokenizer, BartForConditionalGeneration
|
||||||
|
from loguru import logger
|
||||||
|
from nltk.corpus import stopwords
|
||||||
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||||
|
from nltk.tokenize import word_tokenize
|
||||||
|
from sklearn.metrics.pairwise import cosine_similarity
|
||||||
|
|
||||||
|
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read('config.ini')
|
||||||
|
|
||||||
|
def preprocess_sentence(sentence):
|
||||||
|
stop_words = set(stopwords.words('english'))
|
||||||
|
tokens = word_tokenize(sentence.lower())
|
||||||
|
tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
|
||||||
|
return ' '.join(tokens)
|
||||||
|
|
||||||
|
def compute_similarity(sent1, sent2):
|
||||||
|
tfidf_vectorizer = TfidfVectorizer()
|
||||||
|
tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2])
|
||||||
|
return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
|
||||||
|
|
||||||
|
def remove_almost_alike_sentences(sentences, threshold=0.7):
|
||||||
|
num_sentences = len(sentences)
|
||||||
|
removed_indices = set()
|
||||||
|
|
||||||
|
for i in range(num_sentences):
|
||||||
|
if i not in removed_indices:
|
||||||
|
for j in range(i + 1, num_sentences):
|
||||||
|
if j not in removed_indices:
|
||||||
|
sentence1 = preprocess_sentence(sentences[i])
|
||||||
|
sentence2 = preprocess_sentence(sentences[j])
|
||||||
|
similarity = compute_similarity(sentence1, sentence2)
|
||||||
|
|
||||||
|
if similarity >= threshold:
|
||||||
|
removed_indices.add(max(i, j))
|
||||||
|
|
||||||
|
filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices]
|
||||||
|
return filtered_sentences
|
||||||
|
|
||||||
|
def remove_outright_duplicate_sentences_from_chunk(chunk):
|
||||||
|
chunk_text = chunk["text"]
|
||||||
|
sentences = nltk.sent_tokenize(chunk_text)
|
||||||
|
nonduplicate_sentences = list(dict.fromkeys(sentences))
|
||||||
|
return nonduplicate_sentences
|
||||||
|
|
||||||
|
def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
|
||||||
|
chunk_sentences = []
|
||||||
|
|
||||||
|
for sent in nonduplicate_sentences:
|
||||||
|
temp_result = ""
|
||||||
|
seen = {}
|
||||||
|
words = nltk.word_tokenize(sent)
|
||||||
|
n_gram_filter = 3
|
||||||
|
for i in range(len(words)):
|
||||||
|
if str(words[i:i + n_gram_filter]) in seen and seen[str(words[i:i + n_gram_filter])] == words[
|
||||||
|
i + 1:i + n_gram_filter + 2]:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
seen[str(words[i:i + n_gram_filter])] = words[i + 1:i + n_gram_filter + 2]
|
||||||
|
temp_result += words[i]
|
||||||
|
temp_result += " "
|
||||||
|
chunk_sentences.append(temp_result)
|
||||||
|
return chunk_sentences
|
||||||
|
|
||||||
|
def post_process_transcription(whisper_result):
|
||||||
|
for chunk in whisper_result["chunks"]:
|
||||||
|
nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk)
|
||||||
|
chunk_sentences = remove_whisper_repetitive_hallucination(nonduplicate_sentences)
|
||||||
|
similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences)
|
||||||
|
chunk["text"] = " ".join(similarity_matched_sentences)
|
||||||
|
return whisper_result
|
||||||
|
|
||||||
|
def summarize(transcript_text, timestamp,
|
||||||
|
real_time=False, summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
|
||||||
|
if not summary_model:
|
||||||
|
summary_model = "facebook/bart-large-cnn"
|
||||||
|
|
||||||
|
# Summarize the generated transcript using the BART model
|
||||||
|
logger.info(f"Loading BART model: {summary_model}")
|
||||||
|
tokenizer = BartTokenizer.from_pretrained(summary_model)
|
||||||
|
model = BartForConditionalGeneration.from_pretrained(summary_model)
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
output_filename = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
||||||
|
if real_time:
|
||||||
|
output_filename = "real_time_" + output_filename
|
||||||
|
|
||||||
|
if summarize_using_chunks != "YES":
|
||||||
|
inputs = tokenizer.batch_encode_plus([transcript_text], truncation=True, padding='longest',
|
||||||
|
max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]),
|
||||||
|
return_tensors='pt')
|
||||||
|
inputs = inputs.to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
summaries = model.generate(inputs['input_ids'],
|
||||||
|
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
|
||||||
|
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
|
||||||
|
|
||||||
|
decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) for
|
||||||
|
summary in summaries]
|
||||||
|
summary = " ".join(decoded_summaries)
|
||||||
|
with open(output_filename, 'w') as f:
|
||||||
|
f.write(summary.strip() + "\n")
|
||||||
|
else:
|
||||||
|
logger.info("Breaking transcript into smaller chunks")
|
||||||
|
chunks = chunk_text(transcript_text)
|
||||||
|
|
||||||
|
logger.info(f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable
|
||||||
|
|
||||||
|
logger.info(f"Writing summary text to: {output_filename}")
|
||||||
|
with open(output_filename, 'w') as f:
|
||||||
|
summaries = summarize_chunks(chunks, tokenizer, model)
|
||||||
|
for summary in summaries:
|
||||||
|
f.write(summary.strip() + " ")
|
||||||
|
|
||||||
|
def summarize_chunks(chunks, tokenizer, model):
|
||||||
|
"""
|
||||||
|
Summarize each chunk using a summarizer model
|
||||||
|
:param chunks:
|
||||||
|
:param tokenizer:
|
||||||
|
:param model:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
summaries = []
|
||||||
|
for c in chunks:
|
||||||
|
input_ids = tokenizer.encode(c, return_tensors='pt')
|
||||||
|
input_ids = input_ids.to(device)
|
||||||
|
with torch.no_grad():
|
||||||
|
summary_ids = model.generate(input_ids,
|
||||||
|
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
|
||||||
|
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
|
||||||
|
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
||||||
|
summaries.append(summary)
|
||||||
|
return summaries
|
||||||
|
|
||||||
|
def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
|
||||||
|
"""
|
||||||
|
Split text into smaller chunks.
|
||||||
|
:param txt: Text to be chunked
|
||||||
|
:param max_chunk_length: length of chunk
|
||||||
|
:return: chunked texts
|
||||||
|
"""
|
||||||
|
sentences = nltk.sent_tokenize(text)
|
||||||
|
chunks = []
|
||||||
|
current_chunk = ""
|
||||||
|
for sentence in sentences:
|
||||||
|
if len(current_chunk) + len(sentence) < max_chunk_length:
|
||||||
|
current_chunk += f" {sentence.strip()}"
|
||||||
|
else:
|
||||||
|
chunks.append(current_chunk.strip())
|
||||||
|
current_chunk = f"{sentence.strip()}"
|
||||||
|
chunks.append(current_chunk.strip())
|
||||||
|
return chunks
|
||||||
190
viz_utilities.py
Normal file
190
viz_utilities.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from wordcloud import WordCloud, STOPWORDS
|
||||||
|
import collections
|
||||||
|
import spacy
|
||||||
|
import pickle
|
||||||
|
import ast
|
||||||
|
import pandas as pd
|
||||||
|
import scattertext as st
|
||||||
|
import configparser
|
||||||
|
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read('config.ini')
|
||||||
|
|
||||||
|
|
||||||
|
def create_wordcloud(timestamp, real_time=False):
|
||||||
|
"""
|
||||||
|
Create a basic word cloud visualization of transcribed text
|
||||||
|
:return: None. The wordcloud image is saved locally
|
||||||
|
"""
|
||||||
|
filename = "transcript"
|
||||||
|
if real_time:
|
||||||
|
filename = "real_time_" + filename + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
||||||
|
else:
|
||||||
|
filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
||||||
|
|
||||||
|
with open(filename, "r") as f:
|
||||||
|
transcription_text = f.read()
|
||||||
|
|
||||||
|
stopwords = set(STOPWORDS)
|
||||||
|
|
||||||
|
# python_mask = np.array(PIL.Image.open("download1.png"))
|
||||||
|
|
||||||
|
wordcloud = WordCloud(height=800, width=800,
|
||||||
|
background_color='white',
|
||||||
|
stopwords=stopwords,
|
||||||
|
min_font_size=8).generate(transcription_text)
|
||||||
|
|
||||||
|
# Plot wordcloud and save image
|
||||||
|
plt.figure(facecolor=None)
|
||||||
|
plt.imshow(wordcloud, interpolation="bilinear")
|
||||||
|
plt.axis("off")
|
||||||
|
plt.tight_layout(pad=0)
|
||||||
|
|
||||||
|
wordcloud_name = "wordcloud"
|
||||||
|
if real_time:
|
||||||
|
wordcloud_name = "real_time_" + wordcloud_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
|
||||||
|
else:
|
||||||
|
wordcloud_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
|
||||||
|
|
||||||
|
plt.savefig(wordcloud_name)
|
||||||
|
|
||||||
|
|
||||||
|
def create_talk_diff_scatter_viz(timestamp, real_time=False):
|
||||||
|
"""
|
||||||
|
Perform agenda vs transription diff to see covered topics.
|
||||||
|
Create a scatter plot of words in topics.
|
||||||
|
:return: None. Saved locally.
|
||||||
|
"""
|
||||||
|
spaCy_model = "en_core_web_md"
|
||||||
|
nlp = spacy.load(spaCy_model)
|
||||||
|
nlp.add_pipe('sentencizer')
|
||||||
|
|
||||||
|
agenda_topics = []
|
||||||
|
agenda = []
|
||||||
|
# Load the agenda
|
||||||
|
with open("agenda-headers.txt", "r") as f:
|
||||||
|
for line in f.readlines():
|
||||||
|
if line.strip():
|
||||||
|
agenda.append(line.strip())
|
||||||
|
agenda_topics.append(line.split(":")[0])
|
||||||
|
|
||||||
|
# Load the transcription with timestamp
|
||||||
|
with open("transcript_with_timestamp_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt") as f:
|
||||||
|
transcription_timestamp_text = f.read()
|
||||||
|
|
||||||
|
res = ast.literal_eval(transcription_timestamp_text)
|
||||||
|
chunks = res["chunks"]
|
||||||
|
|
||||||
|
# create df for processing
|
||||||
|
df = pd.DataFrame.from_dict(res["chunks"])
|
||||||
|
|
||||||
|
covered_items = {}
|
||||||
|
# ts: timestamp
|
||||||
|
# Map each timestamped chunk with top1 and top2 matched agenda
|
||||||
|
ts_to_topic_mapping_top_1 = {}
|
||||||
|
ts_to_topic_mapping_top_2 = {}
|
||||||
|
|
||||||
|
# Also create a mapping of the different timestamps in which each topic was covered
|
||||||
|
topic_to_ts_mapping_top_1 = collections.defaultdict(list)
|
||||||
|
topic_to_ts_mapping_top_2 = collections.defaultdict(list)
|
||||||
|
|
||||||
|
similarity_threshold = 0.7
|
||||||
|
|
||||||
|
for c in chunks:
|
||||||
|
doc_transcription = nlp(c["text"])
|
||||||
|
topic_similarities = []
|
||||||
|
for item in range(len(agenda)):
|
||||||
|
item_doc = nlp(agenda[item])
|
||||||
|
# if not doc_transcription or not all(token.has_vector for token in doc_transcription):
|
||||||
|
if not doc_transcription:
|
||||||
|
continue
|
||||||
|
similarity = doc_transcription.similarity(item_doc)
|
||||||
|
topic_similarities.append((item, similarity))
|
||||||
|
topic_similarities.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
for i in range(2):
|
||||||
|
if topic_similarities[i][1] >= similarity_threshold:
|
||||||
|
covered_items[agenda[topic_similarities[i][0]]] = True
|
||||||
|
# top1 match
|
||||||
|
if i == 0:
|
||||||
|
ts_to_topic_mapping_top_1[c["timestamp"]] = agenda_topics[topic_similarities[i][0]]
|
||||||
|
topic_to_ts_mapping_top_1[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"])
|
||||||
|
# top2 match
|
||||||
|
else:
|
||||||
|
ts_to_topic_mapping_top_2[c["timestamp"]] = agenda_topics[topic_similarities[i][0]]
|
||||||
|
topic_to_ts_mapping_top_2[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"])
|
||||||
|
|
||||||
|
def create_new_columns(record):
|
||||||
|
"""
|
||||||
|
Accumulate the mapping information into the df
|
||||||
|
:param record:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[record["timestamp"]]
|
||||||
|
record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[record["timestamp"]]
|
||||||
|
return record
|
||||||
|
|
||||||
|
df = df.apply(create_new_columns, axis=1)
|
||||||
|
|
||||||
|
# Count the number of items covered and calculatre the percentage
|
||||||
|
num_covered_items = sum(covered_items.values())
|
||||||
|
percentage_covered = num_covered_items / len(agenda) * 100
|
||||||
|
|
||||||
|
# Print the results
|
||||||
|
print("💬 Agenda items covered in the transcription:")
|
||||||
|
for item in agenda:
|
||||||
|
if item in covered_items and covered_items[item]:
|
||||||
|
print("✅ ", item)
|
||||||
|
else:
|
||||||
|
print("❌ ", item)
|
||||||
|
print("📊 Coverage: {:.2f}%".format(percentage_covered))
|
||||||
|
|
||||||
|
# Save df, mappings for further experimentation
|
||||||
|
df_name = "df"
|
||||||
|
if real_time:
|
||||||
|
df_name = "real_time_" + df_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
||||||
|
else:
|
||||||
|
df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
||||||
|
df.to_pickle(df_name)
|
||||||
|
|
||||||
|
my_mappings = [ts_to_topic_mapping_top_1, ts_to_topic_mapping_top_2,
|
||||||
|
topic_to_ts_mapping_top_1, topic_to_ts_mapping_top_2]
|
||||||
|
|
||||||
|
mappings_name = "mappings"
|
||||||
|
if real_time:
|
||||||
|
mappings_name = "real_time_" + mappings_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
||||||
|
else:
|
||||||
|
mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
||||||
|
pickle.dump(my_mappings, open(mappings_name, "wb"))
|
||||||
|
|
||||||
|
# to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") )
|
||||||
|
|
||||||
|
# pick the 2 most matched topic to be used for plotting
|
||||||
|
topic_times = collections.defaultdict(int)
|
||||||
|
for key in ts_to_topic_mapping_top_1.keys():
|
||||||
|
if key[0] is None or key[1] is None:
|
||||||
|
continue
|
||||||
|
duration = key[1] - key[0]
|
||||||
|
topic_times[ts_to_topic_mapping_top_1[key]] += duration
|
||||||
|
|
||||||
|
topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
cat_1 = topic_times[0][0]
|
||||||
|
cat_1_name = topic_times[0][0]
|
||||||
|
cat_2_name = topic_times[1][0]
|
||||||
|
|
||||||
|
# Scatter plot of topics
|
||||||
|
df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences))
|
||||||
|
corpus = st.CorpusFromParsedDocuments(
|
||||||
|
df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse'
|
||||||
|
).build().get_unigram_corpus().compact(st.AssociationCompactor(2000))
|
||||||
|
html = st.produce_scattertext_explorer(
|
||||||
|
corpus,
|
||||||
|
category=cat_1,
|
||||||
|
category_name=cat_1_name,
|
||||||
|
not_category_name=cat_2_name,
|
||||||
|
minimum_term_frequency=0, pmi_threshold_coefficient=0,
|
||||||
|
width_in_pixels=1000,
|
||||||
|
transform=st.Scalers.dense_rank
|
||||||
|
)
|
||||||
|
open('./scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
|
||||||
354
whisjax.py
354
whisjax.py
@@ -5,35 +5,26 @@
|
|||||||
# summarize podcast.mp3 summary.txt
|
# summarize podcast.mp3 summary.txt
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import ast
|
|
||||||
import torch
|
|
||||||
import collections
|
|
||||||
import configparser
|
import configparser
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import moviepy.editor
|
import moviepy.editor
|
||||||
import moviepy.editor
|
import moviepy.editor
|
||||||
import nltk
|
import nltk
|
||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import pandas as pd
|
|
||||||
import pickle
|
|
||||||
import re
|
import re
|
||||||
import scattertext as st
|
|
||||||
import spacy
|
|
||||||
import tempfile
|
import tempfile
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pytube import YouTube
|
from pytube import YouTube
|
||||||
from transformers import BartTokenizer, BartForConditionalGeneration
|
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from whisper_jax import FlaxWhisperPipline
|
from whisper_jax import FlaxWhisperPipline
|
||||||
from wordcloud import WordCloud, STOPWORDS
|
|
||||||
from nltk.corpus import stopwords
|
|
||||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
||||||
from nltk.tokenize import word_tokenize
|
|
||||||
from sklearn.metrics.pairwise import cosine_similarity
|
|
||||||
|
|
||||||
from file_util import upload_files, download_files
|
from datetime import datetime
|
||||||
|
from file_utilities import upload_files, download_files
|
||||||
|
from viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
|
||||||
|
from text_utilities import summarize, post_process_transcription
|
||||||
|
|
||||||
nltk.download('punkt')
|
nltk.download('punkt')
|
||||||
nltk.download('stopwords')
|
nltk.download('stopwords')
|
||||||
@@ -43,7 +34,7 @@ config = configparser.ConfigParser()
|
|||||||
config.read('config.ini')
|
config.read('config.ini')
|
||||||
|
|
||||||
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
|
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
|
||||||
|
NOW = datetime.now()
|
||||||
|
|
||||||
def init_argparse() -> argparse.ArgumentParser:
|
def init_argparse() -> argparse.ArgumentParser:
|
||||||
"""
|
"""
|
||||||
@@ -57,310 +48,10 @@ def init_argparse() -> argparse.ArgumentParser:
|
|||||||
|
|
||||||
parser.add_argument("-l", "--language", help="Language that the summary should be written in", type=str,
|
parser.add_argument("-l", "--language", help="Language that the summary should be written in", type=str,
|
||||||
default="english", choices=['english', 'spanish', 'french', 'german', 'romanian'])
|
default="english", choices=['english', 'spanish', 'french', 'german', 'romanian'])
|
||||||
parser.add_argument("-t", "--transcript", help="Save a copy of the intermediary transcript file", type=str)
|
|
||||||
parser.add_argument("location")
|
parser.add_argument("location")
|
||||||
parser.add_argument("output")
|
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def chunk_text(txt, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
|
|
||||||
"""
|
|
||||||
Split text into smaller chunks.
|
|
||||||
:param txt: Text to be chunked
|
|
||||||
:param max_chunk_length: length of chunk
|
|
||||||
:return: chunked texts
|
|
||||||
"""
|
|
||||||
sentences = nltk.sent_tokenize(txt)
|
|
||||||
chunks = []
|
|
||||||
current_chunk = ""
|
|
||||||
for sentence in sentences:
|
|
||||||
if len(current_chunk) + len(sentence) < max_chunk_length:
|
|
||||||
current_chunk += f" {sentence.strip()}"
|
|
||||||
else:
|
|
||||||
chunks.append(current_chunk.strip())
|
|
||||||
current_chunk = f"{sentence.strip()}"
|
|
||||||
chunks.append(current_chunk.strip())
|
|
||||||
return chunks
|
|
||||||
|
|
||||||
|
|
||||||
def summarize_chunks(chunks, tokenizer, model):
|
|
||||||
"""
|
|
||||||
Summarize each chunk using a summarizer model
|
|
||||||
:param chunks:
|
|
||||||
:param tokenizer:
|
|
||||||
:param model:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
summaries = []
|
|
||||||
for c in chunks:
|
|
||||||
input_ids = tokenizer.encode(c, return_tensors='pt')
|
|
||||||
input_ids = input_ids.to(device)
|
|
||||||
with torch.no_grad():
|
|
||||||
summary_ids = model.generate(input_ids,
|
|
||||||
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
|
|
||||||
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
|
|
||||||
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
|
||||||
summaries.append(summary)
|
|
||||||
return summaries
|
|
||||||
|
|
||||||
|
|
||||||
def create_wordcloud():
|
|
||||||
"""
|
|
||||||
Create a basic word cloud visualization of transcribed text
|
|
||||||
:return: None. The wordcloud image is saved locally
|
|
||||||
"""
|
|
||||||
with open("transcript.txt", "r") as f:
|
|
||||||
transcription_text = f.read()
|
|
||||||
|
|
||||||
stopwords = set(STOPWORDS)
|
|
||||||
|
|
||||||
# python_mask = np.array(PIL.Image.open("download1.png"))
|
|
||||||
|
|
||||||
wordcloud = WordCloud(height=800, width=800,
|
|
||||||
background_color='white',
|
|
||||||
stopwords=stopwords,
|
|
||||||
min_font_size=8).generate(transcription_text)
|
|
||||||
|
|
||||||
# Plot wordcloud and save image
|
|
||||||
plt.figure(facecolor=None)
|
|
||||||
plt.imshow(wordcloud, interpolation="bilinear")
|
|
||||||
plt.axis("off")
|
|
||||||
plt.tight_layout(pad=0)
|
|
||||||
plt.savefig("wordcloud.png")
|
|
||||||
|
|
||||||
|
|
||||||
def create_talk_diff_scatter_viz():
|
|
||||||
"""
|
|
||||||
Perform agenda vs transription diff to see covered topics.
|
|
||||||
Create a scatter plot of words in topics.
|
|
||||||
:return: None. Saved locally.
|
|
||||||
"""
|
|
||||||
spaCy_model = "en_core_web_md"
|
|
||||||
nlp = spacy.load(spaCy_model)
|
|
||||||
nlp.add_pipe('sentencizer')
|
|
||||||
|
|
||||||
agenda_topics = []
|
|
||||||
agenda = []
|
|
||||||
# Load the agenda
|
|
||||||
with open("agenda-headers.txt", "r") as f:
|
|
||||||
for line in f.readlines():
|
|
||||||
if line.strip():
|
|
||||||
agenda.append(line.strip())
|
|
||||||
agenda_topics.append(line.split(":")[0])
|
|
||||||
|
|
||||||
# Load the transcription with timestamp
|
|
||||||
with open("transcript_timestamps.txt", "r") as f:
|
|
||||||
transcription_timestamp_text = f.read()
|
|
||||||
|
|
||||||
res = ast.literal_eval(transcription_timestamp_text)
|
|
||||||
chunks = res["chunks"]
|
|
||||||
|
|
||||||
# create df for processing
|
|
||||||
df = pd.DataFrame.from_dict(res["chunks"])
|
|
||||||
|
|
||||||
covered_items = {}
|
|
||||||
# ts: timestamp
|
|
||||||
# Map each timestamped chunk with top1 and top2 matched agenda
|
|
||||||
ts_to_topic_mapping_top_1 = {}
|
|
||||||
ts_to_topic_mapping_top_2 = {}
|
|
||||||
|
|
||||||
# Also create a mapping of the different timestamps in which each topic was covered
|
|
||||||
topic_to_ts_mapping_top_1 = collections.defaultdict(list)
|
|
||||||
topic_to_ts_mapping_top_2 = collections.defaultdict(list)
|
|
||||||
|
|
||||||
similarity_threshold = 0.7
|
|
||||||
|
|
||||||
for c in chunks:
|
|
||||||
doc_transcription = nlp(c["text"])
|
|
||||||
topic_similarities = []
|
|
||||||
for item in range(len(agenda)):
|
|
||||||
item_doc = nlp(agenda[item])
|
|
||||||
# if not doc_transcription or not all(token.has_vector for token in doc_transcription):
|
|
||||||
if not doc_transcription:
|
|
||||||
continue
|
|
||||||
similarity = doc_transcription.similarity(item_doc)
|
|
||||||
topic_similarities.append((item, similarity))
|
|
||||||
topic_similarities.sort(key=lambda x: x[1], reverse=True)
|
|
||||||
for i in range(2):
|
|
||||||
if topic_similarities[i][1] >= similarity_threshold:
|
|
||||||
covered_items[agenda[topic_similarities[i][0]]] = True
|
|
||||||
# top1 match
|
|
||||||
if i == 0:
|
|
||||||
ts_to_topic_mapping_top_1[c["timestamp"]] = agenda_topics[topic_similarities[i][0]]
|
|
||||||
topic_to_ts_mapping_top_1[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"])
|
|
||||||
# top2 match
|
|
||||||
else:
|
|
||||||
ts_to_topic_mapping_top_2[c["timestamp"]] = agenda_topics[topic_similarities[i][0]]
|
|
||||||
topic_to_ts_mapping_top_2[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"])
|
|
||||||
|
|
||||||
def create_new_columns(record):
|
|
||||||
"""
|
|
||||||
Accumulate the mapping information into the df
|
|
||||||
:param record:
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[record["timestamp"]]
|
|
||||||
record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[record["timestamp"]]
|
|
||||||
return record
|
|
||||||
|
|
||||||
df = df.apply(create_new_columns, axis=1)
|
|
||||||
|
|
||||||
# Count the number of items covered and calculatre the percentage
|
|
||||||
num_covered_items = sum(covered_items.values())
|
|
||||||
percentage_covered = num_covered_items / len(agenda) * 100
|
|
||||||
|
|
||||||
# Print the results
|
|
||||||
print("💬 Agenda items covered in the transcription:")
|
|
||||||
for item in agenda:
|
|
||||||
if item in covered_items and covered_items[item]:
|
|
||||||
print("✅ ", item)
|
|
||||||
else:
|
|
||||||
print("❌ ", item)
|
|
||||||
print("📊 Coverage: {:.2f}%".format(percentage_covered))
|
|
||||||
|
|
||||||
# Save df, mappings for further experimentation
|
|
||||||
df.to_pickle("df.pkl")
|
|
||||||
|
|
||||||
my_mappings = [ts_to_topic_mapping_top_1, ts_to_topic_mapping_top_2,
|
|
||||||
topic_to_ts_mapping_top_1, topic_to_ts_mapping_top_2]
|
|
||||||
pickle.dump(my_mappings, open("mappings.pkl", "wb"))
|
|
||||||
|
|
||||||
# to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") )
|
|
||||||
|
|
||||||
# pick the 2 most matched topic to be used for plotting
|
|
||||||
topic_times = collections.defaultdict(int)
|
|
||||||
for key in ts_to_topic_mapping_top_1.keys():
|
|
||||||
if key[0] is None or key[1] is None:
|
|
||||||
continue
|
|
||||||
duration = key[1] - key[0]
|
|
||||||
topic_times[ts_to_topic_mapping_top_1[key]] += duration
|
|
||||||
|
|
||||||
topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True)
|
|
||||||
|
|
||||||
cat_1 = topic_times[0][0]
|
|
||||||
cat_1_name = topic_times[0][0]
|
|
||||||
cat_2_name = topic_times[1][0]
|
|
||||||
|
|
||||||
# Scatter plot of topics
|
|
||||||
df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences))
|
|
||||||
corpus = st.CorpusFromParsedDocuments(
|
|
||||||
df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse'
|
|
||||||
).build().get_unigram_corpus().compact(st.AssociationCompactor(2000))
|
|
||||||
html = st.produce_scattertext_explorer(
|
|
||||||
corpus,
|
|
||||||
category=cat_1,
|
|
||||||
category_name=cat_1_name,
|
|
||||||
not_category_name=cat_2_name,
|
|
||||||
minimum_term_frequency=0, pmi_threshold_coefficient=0,
|
|
||||||
width_in_pixels=1000,
|
|
||||||
transform=st.Scalers.dense_rank
|
|
||||||
)
|
|
||||||
open('./demo_compact.html', 'w').write(html)
|
|
||||||
|
|
||||||
def preprocess_sentence(sentence):
|
|
||||||
stop_words = set(stopwords.words('english'))
|
|
||||||
tokens = word_tokenize(sentence.lower())
|
|
||||||
tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
|
|
||||||
return ' '.join(tokens)
|
|
||||||
|
|
||||||
def compute_similarity(sent1, sent2):
|
|
||||||
tfidf_vectorizer = TfidfVectorizer()
|
|
||||||
tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2])
|
|
||||||
return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
|
|
||||||
|
|
||||||
def remove_almost_alike_sentences(sentences, threshold=0.7):
|
|
||||||
num_sentences = len(sentences)
|
|
||||||
removed_indices = set()
|
|
||||||
|
|
||||||
for i in range(num_sentences):
|
|
||||||
if i not in removed_indices:
|
|
||||||
for j in range(i + 1, num_sentences):
|
|
||||||
if j not in removed_indices:
|
|
||||||
sentence1 = preprocess_sentence(sentences[i])
|
|
||||||
sentence2 = preprocess_sentence(sentences[j])
|
|
||||||
similarity = compute_similarity(sentence1, sentence2)
|
|
||||||
|
|
||||||
if similarity >= threshold:
|
|
||||||
removed_indices.add(max(i, j))
|
|
||||||
|
|
||||||
filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices]
|
|
||||||
return filtered_sentences
|
|
||||||
|
|
||||||
def remove_outright_duplicate_sentences_from_chunk(chunk):
|
|
||||||
chunk_text = chunk["text"]
|
|
||||||
sentences = nltk.sent_tokenize(chunk_text)
|
|
||||||
nonduplicate_sentences = list(dict.fromkeys(sentences))
|
|
||||||
return nonduplicate_sentences
|
|
||||||
|
|
||||||
def remove_whisper_repititive_hallucination(nonduplicate_sentences):
|
|
||||||
chunk_sentences = []
|
|
||||||
|
|
||||||
for sent in nonduplicate_sentences:
|
|
||||||
temp_result = ""
|
|
||||||
seen = {}
|
|
||||||
words = nltk.word_tokenize(sent)
|
|
||||||
n_gram_filter = 3
|
|
||||||
for i in range(len(words)):
|
|
||||||
if str(words[i:i + n_gram_filter]) in seen and seen[str(words[i:i + n_gram_filter])] == words[
|
|
||||||
i + 1:i + n_gram_filter + 2]:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
seen[str(words[i:i + n_gram_filter])] = words[i + 1:i + n_gram_filter + 2]
|
|
||||||
temp_result += words[i]
|
|
||||||
temp_result += " "
|
|
||||||
chunk_sentences.append(temp_result)
|
|
||||||
return chunk_sentences
|
|
||||||
|
|
||||||
def remove_duplicates_from_transcript_chunk(whisper_result):
|
|
||||||
for chunk in whisper_result["chunks"]:
|
|
||||||
nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk)
|
|
||||||
chunk_sentences = remove_whisper_repititive_hallucination(nonduplicate_sentences)
|
|
||||||
similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences)
|
|
||||||
chunk["text"] = " ".join(similarity_matched_sentences)
|
|
||||||
return whisper_result
|
|
||||||
|
|
||||||
def summarize(transcript_text, output_file,
|
|
||||||
summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
|
|
||||||
if not summary_model:
|
|
||||||
summary_model = "facebook/bart-large-cnn"
|
|
||||||
|
|
||||||
# Summarize the generated transcript using the BART model
|
|
||||||
logger.info(f"Loading BART model: {summary_model}")
|
|
||||||
tokenizer = BartTokenizer.from_pretrained(summary_model)
|
|
||||||
model = BartForConditionalGeneration.from_pretrained(summary_model)
|
|
||||||
model = model.to(device)
|
|
||||||
|
|
||||||
if summarize_using_chunks != "YES":
|
|
||||||
inputs = tokenizer.batch_encode_plus([transcript_text], truncation=True, padding='longest',
|
|
||||||
max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]),
|
|
||||||
return_tensors='pt')
|
|
||||||
inputs = inputs.to(device)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
summaries = model.generate(inputs['input_ids'],
|
|
||||||
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
|
|
||||||
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
|
|
||||||
|
|
||||||
decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) for
|
|
||||||
summary in summaries]
|
|
||||||
summary = " ".join(decoded_summaries)
|
|
||||||
with open(output_file, 'w') as f:
|
|
||||||
f.write(summary.strip() + "\n\n")
|
|
||||||
else:
|
|
||||||
logger.info("Breaking transcript into smaller chunks")
|
|
||||||
chunks = chunk_text(transcript_text)
|
|
||||||
|
|
||||||
logger.info(f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable
|
|
||||||
|
|
||||||
logger.info(f"Writing summary text to: {output_file}")
|
|
||||||
with open(output_file, 'w') as f:
|
|
||||||
summaries = summarize_chunks(chunks, tokenizer, model)
|
|
||||||
for summary in summaries:
|
|
||||||
f.write(summary.strip() + " ")
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = init_argparse()
|
parser = init_argparse()
|
||||||
@@ -425,41 +116,40 @@ def main():
|
|||||||
whisper_result = pipeline(audio_filename, return_timestamps=True)
|
whisper_result = pipeline(audio_filename, return_timestamps=True)
|
||||||
logger.info("Finished transcribing file")
|
logger.info("Finished transcribing file")
|
||||||
|
|
||||||
whisper_result = remove_duplicates_from_transcript_chunk(whisper_result)
|
whisper_result = post_process_transcription(whisper_result)
|
||||||
|
|
||||||
transcript_text = ""
|
transcript_text = ""
|
||||||
for chunk in whisper_result["chunks"]:
|
for chunk in whisper_result["chunks"]:
|
||||||
transcript_text += chunk["text"]
|
transcript_text += chunk["text"]
|
||||||
|
|
||||||
# If we got the transcript parameter on the command line,
|
with open("transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file:
|
||||||
# save the transcript to the specified file.
|
|
||||||
if args.transcript:
|
|
||||||
logger.info(f"Saving transcript to: {args.transcript}")
|
|
||||||
transcript_file = open(args.transcript, "w")
|
|
||||||
transcript_file_timestamps = open(args.transcript[0:len(args.transcript) - 4] + "_timestamps.txt", "w")
|
|
||||||
transcript_file.write(transcript_text)
|
transcript_file.write(transcript_text)
|
||||||
|
|
||||||
|
with open("transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file_timestamps:
|
||||||
transcript_file_timestamps.write(str(whisper_result))
|
transcript_file_timestamps.write(str(whisper_result))
|
||||||
transcript_file.close()
|
|
||||||
transcript_file_timestamps.close()
|
|
||||||
|
|
||||||
logger.info("Creating word cloud")
|
logger.info("Creating word cloud")
|
||||||
create_wordcloud()
|
create_wordcloud(NOW)
|
||||||
|
|
||||||
logger.info("Performing talk-diff and talk-diff visualization")
|
logger.info("Performing talk-diff and talk-diff visualization")
|
||||||
create_talk_diff_scatter_viz()
|
create_talk_diff_scatter_viz(NOW)
|
||||||
|
|
||||||
# S3 : Push artefacts to S3 bucket
|
# S3 : Push artefacts to S3 bucket
|
||||||
files_to_upload = ["transcript.txt", "transcript_timestamps.txt",
|
suffix = NOW.strftime("%m-%d-%Y_%H:%M:%S")
|
||||||
"df.pkl",
|
files_to_upload = ["transcript_" + suffix + ".txt",
|
||||||
"wordcloud.png", "mappings.pkl"]
|
"transcript_with_timestamp_" + suffix + ".txt",
|
||||||
|
"df_" + suffix + ".pkl",
|
||||||
|
"wordcloud_" + suffix + ".png",
|
||||||
|
"mappings_" + suffix + ".pkl"]
|
||||||
upload_files(files_to_upload)
|
upload_files(files_to_upload)
|
||||||
|
|
||||||
summarize(transcript_text, args.output)
|
summarize(transcript_text, NOW, False, False)
|
||||||
|
|
||||||
logger.info("Summarization completed")
|
logger.info("Summarization completed")
|
||||||
|
|
||||||
# Summarization takes a lot of time, so do this separately at the end
|
# Summarization takes a lot of time, so do this separately at the end
|
||||||
files_to_upload = ["summary.txt"]
|
files_to_upload = ["summary_" + suffix + ".txt"]
|
||||||
upload_files(files_to_upload)
|
upload_files(files_to_upload)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
137
whisjax_realtime.py
Normal file
137
whisjax_realtime.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import configparser
|
||||||
|
import pyaudio
|
||||||
|
from whisper_jax import FlaxWhisperPipline
|
||||||
|
from pynput import keyboard
|
||||||
|
import jax.numpy as jnp
|
||||||
|
import wave
|
||||||
|
import datetime
|
||||||
|
from file_utilities import upload_files
|
||||||
|
from viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
|
||||||
|
from text_utilities import summarize, post_process_transcription
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
|
config = configparser.ConfigParser()
|
||||||
|
config.read('config.ini')
|
||||||
|
|
||||||
|
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
|
||||||
|
|
||||||
|
FRAMES_PER_BUFFER = 8000
|
||||||
|
FORMAT = pyaudio.paInt16
|
||||||
|
CHANNELS = 2
|
||||||
|
RATE = 44100
|
||||||
|
RECORD_SECONDS = 15
|
||||||
|
NOW = datetime.now()
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
p = pyaudio.PyAudio()
|
||||||
|
AUDIO_DEVICE_ID = -1
|
||||||
|
for i in range(p.get_device_count()):
|
||||||
|
if p.get_device_info_by_index(i)["name"] == config["DEFAULT"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]:
|
||||||
|
AUDIO_DEVICE_ID = i
|
||||||
|
audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID)
|
||||||
|
stream = p.open(
|
||||||
|
format=FORMAT,
|
||||||
|
channels=CHANNELS,
|
||||||
|
rate=RATE,
|
||||||
|
input=True,
|
||||||
|
frames_per_buffer=FRAMES_PER_BUFFER,
|
||||||
|
input_device_index=audio_devices['index']
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline = FlaxWhisperPipline("openai/whisper-" + config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"],
|
||||||
|
dtype=jnp.float16,
|
||||||
|
batch_size=16)
|
||||||
|
|
||||||
|
transcription = ""
|
||||||
|
|
||||||
|
TEMP_AUDIO_FILE = "temp_audio.wav"
|
||||||
|
global proceed
|
||||||
|
proceed = True
|
||||||
|
|
||||||
|
def on_press(key):
|
||||||
|
if key == keyboard.Key.esc:
|
||||||
|
global proceed
|
||||||
|
proceed = False
|
||||||
|
|
||||||
|
transcript_with_timestamp = {"text": "", "chunks": []}
|
||||||
|
last_transcribed_time = 0.0
|
||||||
|
|
||||||
|
listener = keyboard.Listener(on_press=on_press)
|
||||||
|
listener.start()
|
||||||
|
print("Attempting real-time transcription.. Listening...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
while proceed:
|
||||||
|
frames = []
|
||||||
|
for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)):
|
||||||
|
data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False)
|
||||||
|
frames.append(data)
|
||||||
|
|
||||||
|
wf = wave.open(TEMP_AUDIO_FILE, 'wb')
|
||||||
|
wf.setnchannels(CHANNELS)
|
||||||
|
wf.setsampwidth(p.get_sample_size(FORMAT))
|
||||||
|
wf.setframerate(RATE)
|
||||||
|
wf.writeframes(b''.join(frames))
|
||||||
|
wf.close()
|
||||||
|
|
||||||
|
whisper_result = pipeline(TEMP_AUDIO_FILE, return_timestamps=True)
|
||||||
|
print(whisper_result['text'])
|
||||||
|
|
||||||
|
timestamp = whisper_result["chunks"][0]["timestamp"]
|
||||||
|
start = timestamp[0]
|
||||||
|
end = timestamp[1]
|
||||||
|
if end is None:
|
||||||
|
end = start + 15.0
|
||||||
|
duration = end - start
|
||||||
|
item = {'timestamp': (last_transcribed_time, last_transcribed_time + duration),
|
||||||
|
'text': whisper_result['text']}
|
||||||
|
last_transcribed_time = last_transcribed_time + duration
|
||||||
|
transcript_with_timestamp["chunks"].append(item)
|
||||||
|
|
||||||
|
transcription += whisper_result['text']
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
finally:
|
||||||
|
with open("real_time_transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
|
||||||
|
f.write(transcription)
|
||||||
|
with open("real_time_transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
|
||||||
|
transcript_with_timestamp["text"] = transcription
|
||||||
|
f.write(str(transcript_with_timestamp))
|
||||||
|
|
||||||
|
transcript_with_timestamp = post_process_transcription(transcript_with_timestamp)
|
||||||
|
|
||||||
|
transcript_text = ""
|
||||||
|
for chunk in transcript_with_timestamp["chunks"]:
|
||||||
|
transcript_text += chunk["text"]
|
||||||
|
|
||||||
|
logger.info("Creating word cloud")
|
||||||
|
create_wordcloud(NOW, True)
|
||||||
|
|
||||||
|
logger.info("Performing talk-diff and talk-diff visualization")
|
||||||
|
create_talk_diff_scatter_viz(NOW, True)
|
||||||
|
|
||||||
|
# S3 : Push artefacts to S3 bucket
|
||||||
|
suffix = NOW.strftime("%m-%d-%Y_%H:%M:%S")
|
||||||
|
files_to_upload = ["real_time_transcript_" + suffix + ".txt",
|
||||||
|
"real_time_transcript_with_timestamp" + suffix + ".txt",
|
||||||
|
"real_time_df_" + suffix + ".pkl",
|
||||||
|
"real_time_wordcloud_" + suffix + ".png",
|
||||||
|
"real_time_mappings_" + suffix + ".pkl"]
|
||||||
|
upload_files(files_to_upload)
|
||||||
|
|
||||||
|
summarize(transcript_text, NOW, True, True)
|
||||||
|
|
||||||
|
logger.info("Summarization completed")
|
||||||
|
|
||||||
|
# Summarization takes a lot of time, so do this separately at the end
|
||||||
|
files_to_upload = ["real_time_summary_" + suffix + ".txt"]
|
||||||
|
upload_files(files_to_upload)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
@@ -1,84 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
import configparser
|
|
||||||
import pyaudio
|
|
||||||
from whisper_jax import FlaxWhisperPipline
|
|
||||||
from pynput import keyboard
|
|
||||||
import jax.numpy as jnp
|
|
||||||
import wave
|
|
||||||
|
|
||||||
config = configparser.ConfigParser()
|
|
||||||
config.read('config.ini')
|
|
||||||
|
|
||||||
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
|
|
||||||
|
|
||||||
FRAMES_PER_BUFFER = 8000
|
|
||||||
FORMAT = pyaudio.paInt16
|
|
||||||
CHANNELS = 2
|
|
||||||
RATE = 44100
|
|
||||||
RECORD_SECONDS = 15
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
p = pyaudio.PyAudio()
|
|
||||||
AUDIO_DEVICE_ID = -1
|
|
||||||
for i in range(p.get_device_count()):
|
|
||||||
if p.get_device_info_by_index(i)["name"] == "ref-agg-input":
|
|
||||||
AUDIO_DEVICE_ID = i
|
|
||||||
audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID)
|
|
||||||
stream = p.open(
|
|
||||||
format=FORMAT,
|
|
||||||
channels=CHANNELS,
|
|
||||||
rate=RATE,
|
|
||||||
input=True,
|
|
||||||
frames_per_buffer=FRAMES_PER_BUFFER,
|
|
||||||
input_device_index=audio_devices['index']
|
|
||||||
)
|
|
||||||
|
|
||||||
pipeline = FlaxWhisperPipline("openai/whisper-" + config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"],
|
|
||||||
dtype=jnp.float16,
|
|
||||||
batch_size=16)
|
|
||||||
|
|
||||||
transcript_file = open("transcript.txt", "w+")
|
|
||||||
transcription = ""
|
|
||||||
|
|
||||||
TEMP_AUDIO_FILE = "temp_audio.wav"
|
|
||||||
global proceed
|
|
||||||
proceed = True
|
|
||||||
|
|
||||||
def on_press(key):
|
|
||||||
if key == keyboard.Key.esc:
|
|
||||||
global proceed
|
|
||||||
proceed = False
|
|
||||||
|
|
||||||
listener = keyboard.Listener(on_press=on_press)
|
|
||||||
listener.start()
|
|
||||||
print("Attempting real-time transcription.. Listening...")
|
|
||||||
while proceed:
|
|
||||||
try:
|
|
||||||
frames = []
|
|
||||||
for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)):
|
|
||||||
data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False)
|
|
||||||
frames.append(data)
|
|
||||||
|
|
||||||
wf = wave.open(TEMP_AUDIO_FILE, 'wb')
|
|
||||||
wf.setnchannels(CHANNELS)
|
|
||||||
wf.setsampwidth(p.get_sample_size(FORMAT))
|
|
||||||
wf.setframerate(RATE)
|
|
||||||
wf.writeframes(b''.join(frames))
|
|
||||||
wf.close()
|
|
||||||
|
|
||||||
whisper_result = pipeline(TEMP_AUDIO_FILE, return_timestamps=True)
|
|
||||||
print(whisper_result['text'])
|
|
||||||
|
|
||||||
transcription += whisper_result['text']
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
finally:
|
|
||||||
with open("real_time_transcription.txt", "w") as f:
|
|
||||||
transcript_file.write(transcription)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
Reference in New Issue
Block a user