Merge pull request #20 from Monadical-SAS/whisper-jax-gokul

Micro improvements and bug fixes
projects-g
2023-06-23 20:03:25 +05:30
committed by GitHub
9 changed files with 586 additions and 170 deletions

View File

@@ -27,10 +27,8 @@ To setup,
 ```sh setup_dependencies.sh cuda12```
-4) Run the Whisper-JAX pipeline. Currently, the repo can take a Youtube video and transcribes/summarizes it.
+4) ``` pip install -r requirements.txt```
+5) Run the Whisper-JAX pipeline. Currently, the repo can take a Youtube video and transcribes/summarizes it.
 ``` python3 whisjax.py "https://www.youtube.com/watch?v=ihf0S97oxuQ"```

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -47,4 +47,5 @@ scattertext
 pandas
 jupyter
 seaborn
 matplotlib
+termcolor

View File

@@ -1,4 +1,4 @@
 # Upgrade pip
 pip install --upgrade pip
 # Default to CPU Installation of JAX
@@ -24,3 +24,10 @@ pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
 # Update to latest version
 pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git
+pip install -r requirements.txt
+# download spacy models
+export KMP_DUPLICATE_LIB_OK=True
+python -m spacy download en_core_web_sm
+python -m spacy download en_core_web_md
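
A quick way to verify that the spaCy models pulled in by the updated setup script are usable (a standalone sketch, not part of this PR; assumes the script above has already run):

```python
# Hypothetical sanity check for the environment produced by setup_dependencies.sh.
# KMP_DUPLICATE_LIB_OK=True is the usual workaround for the "duplicate OpenMP
# runtime" abort that can occur when several packages bundle their own libiomp.
import os
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "True")

import spacy

for model in ("en_core_web_sm", "en_core_web_md"):
    nlp = spacy.load(model)  # raises OSError if the model was not downloaded
    print(model, "loaded with pipeline:", nlp.pipe_names)
```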

View File

@@ -20,8 +20,11 @@ def preprocess_sentence(sentence):
 def compute_similarity(sent1, sent2):
     tfidf_vectorizer = TfidfVectorizer()
-    tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2])
-    return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
+    print("semt1", sent1, sent2)
+    if sent1 is not None and sent2 is not None:
+        tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2])
+        return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
+    return 0.0
 
 def remove_almost_alike_sentences(sentences, threshold=0.7):
     num_sentences = len(sentences)
@@ -31,12 +34,21 @@ def remove_almost_alike_sentences(sentences, threshold=0.7):
         if i not in removed_indices:
             for j in range(i + 1, num_sentences):
                 if j not in removed_indices:
-                    sentence1 = preprocess_sentence(sentences[i])
-                    sentence2 = preprocess_sentence(sentences[j])
-                    similarity = compute_similarity(sentence1, sentence2)
-                    if similarity >= threshold:
-                        removed_indices.add(max(i, j))
+                    l_i = len(sentences[i])
+                    l_j = len(sentences[j])
+                    if l_i == 0 or l_j == 0:
+                        if l_i == 0:
+                            removed_indices.add(i)
+                        if l_j == 0:
+                            removed_indices.add(j)
+                    else:
+                        sentence1 = preprocess_sentence(sentences[i])
+                        sentence2 = preprocess_sentence(sentences[j])
+                        if len(sentence1) != 0 and len(sentence2) != 0:
+                            similarity = compute_similarity(sentence1, sentence2)
+                            if similarity >= threshold:
+                                removed_indices.add(max(i, j))
     filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices]
     return filtered_sentences
@@ -67,11 +79,14 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
     return chunk_sentences
 
 def post_process_transcription(whisper_result):
+    transcript_text = ""
     for chunk in whisper_result["chunks"]:
         nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk)
         chunk_sentences = remove_whisper_repetitive_hallucination(nonduplicate_sentences)
         similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences)
         chunk["text"] = " ".join(similarity_matched_sentences)
+        transcript_text += chunk["text"]
+    whisper_result["text"] = transcript_text
     return whisper_result
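
For reference, the patched `compute_similarity` now guards against `None` input instead of letting `TfidfVectorizer` raise; a minimal standalone sketch of the same logic (scikit-learn assumed installed, debug print omitted):

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

def compute_similarity(sent1, sent2):
    # Same guard as the patched function: fall back to 0.0 instead of crashing.
    if sent1 is not None and sent2 is not None:
        tfidf_matrix = TfidfVectorizer().fit_transform([sent1, sent2])
        return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
    return 0.0

print(compute_similarity("the cat sat on the mat", "the cat sat on a mat"))  # high similarity
print(compute_similarity("the cat sat on the mat", None))                    # 0.0
```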

View File

@@ -1,5 +1,6 @@
 import matplotlib.pyplot as plt
 from wordcloud import WordCloud, STOPWORDS
+from nltk.corpus import stopwords
 import collections
 import spacy
 import pickle
@@ -11,6 +12,10 @@ import configparser
 config = configparser.ConfigParser()
 config.read('config.ini')
+en = spacy.load('en_core_web_md')
+spacy_stopwords = en.Defaults.stop_words
+STOPWORDS = set(STOPWORDS).union(set(stopwords.words("english"))).union(set(spacy_stopwords))
 
 def create_wordcloud(timestamp, real_time=False):
     """
@@ -26,13 +31,11 @@ def create_wordcloud(timestamp, real_time=False):
     with open(filename, "r") as f:
         transcription_text = f.read()
-    stopwords = set(STOPWORDS)
     # python_mask = np.array(PIL.Image.open("download1.png"))
     wordcloud = WordCloud(height=800, width=800,
                           background_color='white',
-                          stopwords=stopwords,
+                          stopwords=STOPWORDS,
                           min_font_size=8).generate(transcription_text)
     # Plot wordcloud and save image
@@ -192,4 +195,7 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
         width_in_pixels=1000,
         transform=st.Scalers.dense_rank
     )
-    open('./scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
+    if real_time:
+        open('./real_time_scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
+    else:
+        open('./scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
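
The word-cloud change above swaps the per-call stop-word set for a module-level union of three sources; a minimal standalone sketch of that union (assumes the nltk stopwords corpus and the en_core_web_md model are already downloaded, as the setup script now does):

```python
import nltk
import spacy
from nltk.corpus import stopwords
from wordcloud import STOPWORDS

nltk.download("stopwords", quiet=True)

# Same three-way union the PR builds at import time in the viz module.
en = spacy.load("en_core_web_md")
combined = set(STOPWORDS) | set(stopwords.words("english")) | set(en.Defaults.stop_words)
print(len(combined), "stop words in the merged set")
```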

View File

@@ -26,8 +26,8 @@ from file_utilities import upload_files, download_files
 from viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
 from text_utilities import summarize, post_process_transcription
-nltk.download('punkt')
-nltk.download('stopwords')
+nltk.download('punkt', quiet=True)
+nltk.download('stopwords', quiet=True)
 
 # Configurations can be found in config.ini. Set them properly before executing
 config = configparser.ConfigParser()
@@ -141,7 +141,8 @@ def main():
"transcript_with_timestamp_" + suffix + ".txt", "transcript_with_timestamp_" + suffix + ".txt",
"df_" + suffix + ".pkl", "df_" + suffix + ".pkl",
"wordcloud_" + suffix + ".png", "wordcloud_" + suffix + ".png",
"mappings_" + suffix + ".pkl"] "mappings_" + suffix + ".pkl",
"scatter_" + suffix + ".html"]
upload_files(files_to_upload) upload_files(files_to_upload)
summarize(transcript_text, NOW, False, False) summarize(transcript_text, NOW, False, False)

View File

@@ -12,7 +12,10 @@ from viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
 from text_utilities import summarize, post_process_transcription
 from loguru import logger
 import nltk
-nltk.download('stopwords')
+import time
+from termcolor import colored
+nltk.download('stopwords', quiet=True)
 
 config = configparser.ConfigParser()
 config.read('config.ini')
@@ -68,9 +71,11 @@ def main():
     try:
         while proceed:
             frames = []
+            start_time = time.time()
             for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)):
                 data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False)
                 frames.append(data)
+            end_time = time.time()
 
             wf = wave.open(TEMP_AUDIO_FILE, 'wb')
             wf.setnchannels(CHANNELS)
@@ -80,8 +85,6 @@ def main():
             wf.close()
 
             whisper_result = pipeline(TEMP_AUDIO_FILE, return_timestamps=True)
-            print(whisper_result['text'])
 
             timestamp = whisper_result["chunks"][0]["timestamp"]
             start = timestamp[0]
             end = timestamp[1]
@@ -89,12 +92,18 @@ def main():
                 end = start + 15.0
             duration = end - start
             item = {'timestamp': (last_transcribed_time, last_transcribed_time + duration),
-                    'text': whisper_result['text']}
+                    'text': whisper_result['text'],
+                    'stats': (str(end_time - start_time), str(duration))
+                    }
             last_transcribed_time = last_transcribed_time + duration
             transcript_with_timestamp["chunks"].append(item)
             transcription += whisper_result['text']
+            print(colored("<START>", "yellow"))
+            print(colored(whisper_result['text'], 'green'))
+            print(colored("<END> Recorded duration: " + str(end_time - start_time) + " | Transcribed duration: " +
+                          str(duration), "yellow"))
     except Exception as e:
         print(e)
     finally:
@@ -106,10 +115,6 @@ def main():
     transcript_with_timestamp = post_process_transcription(transcript_with_timestamp)
-    transcript_text = ""
-    for chunk in transcript_with_timestamp["chunks"]:
-        transcript_text += chunk["text"]
 
     logger.info("Creating word cloud")
     create_wordcloud(NOW, True)
@@ -122,10 +127,11 @@ def main():
"real_time_transcript_with_timestamp" + suffix + ".txt", "real_time_transcript_with_timestamp" + suffix + ".txt",
"real_time_df_" + suffix + ".pkl", "real_time_df_" + suffix + ".pkl",
"real_time_wordcloud_" + suffix + ".png", "real_time_wordcloud_" + suffix + ".png",
"real_time_mappings_" + suffix + ".pkl"] "real_time_mappings_" + suffix + ".pkl",
"real_time_scatter_" + suffix + ".html"]
upload_files(files_to_upload) upload_files(files_to_upload)
summarize(transcript_text, NOW, True, True) summarize(transcript_with_timestamp["text"], NOW, True, True)
logger.info("Summarization completed") logger.info("Summarization completed")