Mirror of https://github.com/Monadical-SAS/reflector.git
Synced 2025-12-20 20:29:06 +00:00
Merge pull request #20 from Monadical-SAS/whisper-jax-gokul
Micro improvements and bug fixes
@@ -27,10 +27,8 @@ To setup,
 ```sh setup_dependencies.sh cuda12```
 
-4) ``` pip install -r requirements.txt```
-
-4) Run the Whisper-JAX pipeline. Currently, the repo can take a Youtube video and transcribes/summarizes it.
+5) Run the Whisper-JAX pipeline. Currently, the repo can take a Youtube video and transcribes/summarizes it.
 
 ``` python3 whisjax.py "https://www.youtube.com/watch?v=ihf0S97oxuQ"```
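For orientation, the step above amounts to fetching a video's audio track and handing it to the Whisper-JAX pipeline. Below is a minimal sketch of that flow; it is not the repo's whisjax.py, and it assumes pytube for the download, whisper-jax's FlaxWhisperPipline, and a guessed model checkpoint.

```python
# Hedged sketch of the "YouTube URL -> transcript" flow only; whisjax.py may differ.
# Assumes pytube and whisper-jax are installed (see setup_dependencies.sh).
from pytube import YouTube
from whisper_jax import FlaxWhisperPipline


def transcribe_youtube(url: str) -> dict:
    # Download the audio-only stream to a local file.
    audio_path = YouTube(url).streams.filter(only_audio=True).first().download(filename="audio.mp4")
    # Model checkpoint is an assumption, not taken from the repo.
    pipeline = FlaxWhisperPipline("openai/whisper-large-v2")
    # Same call shape as the real-time script further down: file path + timestamps.
    return pipeline(audio_path, return_timestamps=True)


if __name__ == "__main__":
    result = transcribe_youtube("https://www.youtube.com/watch?v=ihf0S97oxuQ")
    print(result["text"])
```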
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@@ -47,4 +47,5 @@ scattertext
 pandas
 jupyter
 seaborn
 matplotlib
+termcolor
@@ -1,4 +1,4 @@
-# Upgrade pip
+Upgrade pip
 pip install --upgrade pip
 
 # Default to CPU Installation of JAX
@@ -24,3 +24,10 @@ pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
 # Update to latest version
 pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git
 
+pip install -r requirements.txt
+
+# download spacy models
+export KMP_DUPLICATE_LIB_OK=True
+python -m spacy download en_core_web_sm
+python -m spacy download en_core_web_md
@@ -20,8 +20,11 @@ def preprocess_sentence(sentence):
 
 def compute_similarity(sent1, sent2):
     tfidf_vectorizer = TfidfVectorizer()
-    tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2])
-    return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
+    print("semt1", sent1, sent2)
+    if sent1 is not None and sent2 is not None:
+        tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2])
+        return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
+    return 0.0
 
 
 def remove_almost_alike_sentences(sentences, threshold=0.7):
     num_sentences = len(sentences)
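The hunk above makes compute_similarity tolerate None inputs and fall back to 0.0 instead of letting TfidfVectorizer raise. A self-contained sketch of the same TF-IDF cosine-similarity guard, using scikit-learn only and leaving out the repo's surrounding module:

```python
# Hedged sketch of the guarded TF-IDF cosine similarity from the hunk above.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


def compute_similarity(sent1, sent2):
    # Missing input yields 0.0 rather than a vectorizer error.
    if sent1 is None or sent2 is None:
        return 0.0
    tfidf_matrix = TfidfVectorizer().fit_transform([sent1, sent2])
    return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]


print(compute_similarity("we will start the meeting now", "we will start the meeting now"))  # 1.0
print(compute_similarity("something unrelated", None))                                       # 0.0
```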
@@ -31,12 +34,21 @@ def remove_almost_alike_sentences(sentences, threshold=0.7):
         if i not in removed_indices:
             for j in range(i + 1, num_sentences):
                 if j not in removed_indices:
-                    sentence1 = preprocess_sentence(sentences[i])
-                    sentence2 = preprocess_sentence(sentences[j])
-                    similarity = compute_similarity(sentence1, sentence2)
+                    l_i = len(sentences[i])
+                    l_j = len(sentences[j])
+                    if l_i == 0 or l_j == 0:
+                        if l_i == 0:
+                            removed_indices.add(i)
+                        if l_j == 0:
+                            removed_indices.add(j)
+                    else:
+                        sentence1 = preprocess_sentence(sentences[i])
+                        sentence2 = preprocess_sentence(sentences[j])
+                        if len(sentence1) != 0 and len(sentence2) != 0:
+                            similarity = compute_similarity(sentence1, sentence2)
 
                     if similarity >= threshold:
                         removed_indices.add(max(i, j))
 
     filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices]
     return filtered_sentences
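The reworked loop above drops empty sentences by index and only runs the TF-IDF comparison when both preprocessed sentences are non-empty. A condensed, runnable sketch of the same idea; preprocess_sentence is a stand-in here (lowercasing only), since its body is outside this diff, and the control flow is simplified slightly:

```python
# Hedged sketch of near-duplicate sentence removal; not byte-for-byte the repo's version.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


def preprocess_sentence(sentence):
    # Stand-in for the repo's preprocessing helper.
    return sentence.lower().strip()


def compute_similarity(sent1, sent2):
    if sent1 is None or sent2 is None:
        return 0.0
    m = TfidfVectorizer().fit_transform([sent1, sent2])
    return cosine_similarity(m[0], m[1])[0][0]


def remove_almost_alike_sentences(sentences, threshold=0.7):
    removed = set()
    for i in range(len(sentences)):
        if i in removed:
            continue
        for j in range(i + 1, len(sentences)):
            if j in removed:
                continue
            s1, s2 = preprocess_sentence(sentences[i]), preprocess_sentence(sentences[j])
            if not s1 or not s2:
                # Empty sentences are removed outright instead of being compared.
                if not s1:
                    removed.add(i)
                if not s2:
                    removed.add(j)
            elif compute_similarity(s1, s2) >= threshold:
                removed.add(max(i, j))  # keep the earlier of the two near-duplicates
    return [s for k, s in enumerate(sentences) if k not in removed]


print(remove_almost_alike_sentences([
    "We will start the meeting now.",
    "we will start the meeting now",
    "",
    "Next item is the budget.",
]))
```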
@@ -67,11 +79,14 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
     return chunk_sentences
 
 
 def post_process_transcription(whisper_result):
+    transcript_text = ""
     for chunk in whisper_result["chunks"]:
         nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk)
         chunk_sentences = remove_whisper_repetitive_hallucination(nonduplicate_sentences)
         similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences)
         chunk["text"] = " ".join(similarity_matched_sentences)
+        transcript_text += chunk["text"]
+    whisper_result["text"] = transcript_text
     return whisper_result
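With the two added lines, post_process_transcription also rebuilds the transcript's top-level "text" field from the cleaned chunks; that is what lets the real-time script pass transcript_with_timestamp["text"] directly to summarize() later in this commit. A toy illustration of the aggregation, with the cleanup helpers elided:

```python
# Hedged sketch: only the aggregation added in this commit is shown; the repo's
# de-duplication and hallucination-removal helpers are elided for brevity.
def post_process_transcription(whisper_result):
    transcript_text = ""
    for chunk in whisper_result["chunks"]:
        # In the repo, chunk["text"] is cleaned here before being re-joined.
        transcript_text += chunk["text"]
    whisper_result["text"] = transcript_text  # keep the full text in sync with the chunks
    return whisper_result


result = {"chunks": [{"text": "Hello everyone. "}, {"text": "Let's begin."}]}
print(post_process_transcription(result)["text"])  # "Hello everyone. Let's begin."
```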
@@ -1,5 +1,6 @@
 import matplotlib.pyplot as plt
 from wordcloud import WordCloud, STOPWORDS
+from nltk.corpus import stopwords
 import collections
 import spacy
 import pickle
@@ -11,6 +12,10 @@ import configparser
 config = configparser.ConfigParser()
 config.read('config.ini')
 
+en = spacy.load('en_core_web_md')
+spacy_stopwords = en.Defaults.stop_words
+
+STOPWORDS = set(STOPWORDS).union(set(stopwords.words("english"))).union(set(spacy_stopwords))
 
 def create_wordcloud(timestamp, real_time=False):
     """
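The module now builds one STOPWORDS set from three sources: wordcloud's built-ins, NLTK's English list, and spaCy's defaults. A minimal sketch of that union, assuming the NLTK corpus and the en_core_web_md model have already been downloaded as in setup_dependencies.sh:

```python
# Hedged sketch of the combined stopword set used by the word cloud.
# Requires prior downloads:
#   python -m nltk.downloader stopwords
#   python -m spacy download en_core_web_md
import spacy
from nltk.corpus import stopwords
from wordcloud import STOPWORDS

en = spacy.load("en_core_web_md")
combined = set(STOPWORDS) | set(stopwords.words("english")) | set(en.Defaults.stop_words)
print(len(STOPWORDS), len(combined))  # the union is larger than any single list
```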
@@ -26,13 +31,11 @@ def create_wordcloud(timestamp, real_time=False):
     with open(filename, "r") as f:
         transcription_text = f.read()
 
-    stopwords = set(STOPWORDS)
-
     # python_mask = np.array(PIL.Image.open("download1.png"))
 
     wordcloud = WordCloud(height=800, width=800,
                           background_color='white',
-                          stopwords=stopwords,
+                          stopwords=STOPWORDS,
                           min_font_size=8).generate(transcription_text)
 
     # Plot wordcloud and save image
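Since the local stopwords variable is gone, WordCloud now receives the merged module-level set directly. A short sketch with the same parameters, using to_file() in place of the repo's matplotlib plotting and a throwaway string in place of the transcript file:

```python
# Hedged sketch: same WordCloud parameters as the hunk above; the input text
# and output path are placeholders, not the repo's real files.
from wordcloud import WordCloud, STOPWORDS

transcription_text = "whisper jax transcribes the audio and the pipeline summarizes the transcript"
wordcloud = WordCloud(height=800, width=800,
                      background_color='white',
                      stopwords=STOPWORDS,
                      min_font_size=8).generate(transcription_text)
wordcloud.to_file("wordcloud_example.png")
```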
@@ -192,4 +195,7 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
         width_in_pixels=1000,
         transform=st.Scalers.dense_rank
     )
-    open('./scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
+    if real_time:
+        open('./real_time_scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
+    else:
+        open('./scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
@@ -26,8 +26,8 @@ from file_utilities import upload_files, download_files
 from viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
 from text_utilities import summarize, post_process_transcription
 
-nltk.download('punkt')
-nltk.download('stopwords')
+nltk.download('punkt', quiet=True)
+nltk.download('stopwords', quiet=True)
 
 # Configurations can be found in config.ini. Set them properly before executing
 config = configparser.ConfigParser()
@@ -141,7 +141,8 @@ def main():
                        "transcript_with_timestamp_" + suffix + ".txt",
                        "df_" + suffix + ".pkl",
                        "wordcloud_" + suffix + ".png",
-                       "mappings_" + suffix + ".pkl"]
+                       "mappings_" + suffix + ".pkl",
+                       "scatter_" + suffix + ".html"]
     upload_files(files_to_upload)
 
     summarize(transcript_text, NOW, False, False)
@@ -12,7 +12,10 @@ from viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
 from text_utilities import summarize, post_process_transcription
 from loguru import logger
 import nltk
-nltk.download('stopwords')
+import time
+from termcolor import colored
+
+nltk.download('stopwords', quiet=True)
 
 config = configparser.ConfigParser()
 config.read('config.ini')
@@ -68,9 +71,11 @@ def main():
     try:
         while proceed:
             frames = []
+            start_time = time.time()
             for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)):
                 data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False)
                 frames.append(data)
+            end_time = time.time()
 
             wf = wave.open(TEMP_AUDIO_FILE, 'wb')
             wf.setnchannels(CHANNELS)
@@ -80,8 +85,6 @@ def main():
             wf.close()
 
             whisper_result = pipeline(TEMP_AUDIO_FILE, return_timestamps=True)
-            print(whisper_result['text'])
-
             timestamp = whisper_result["chunks"][0]["timestamp"]
             start = timestamp[0]
             end = timestamp[1]
@@ -89,12 +92,18 @@ def main():
                 end = start + 15.0
             duration = end - start
             item = {'timestamp': (last_transcribed_time, last_transcribed_time + duration),
-                    'text': whisper_result['text']}
+                    'text': whisper_result['text'],
+                    'stats': (str(end_time - start_time), str(duration))
+                    }
             last_transcribed_time = last_transcribed_time + duration
             transcript_with_timestamp["chunks"].append(item)
 
             transcription += whisper_result['text']
 
+            print(colored("<START>", "yellow"))
+            print(colored(whisper_result['text'], 'green'))
+            print(colored("<END> Recorded duration: " + str(end_time - start_time) + " | Transcribed duration: " +
+                          str(duration), "yellow"))
     except Exception as e:
         print(e)
     finally:
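The real-time loop now stores a per-chunk 'stats' tuple (wall-clock recording time versus Whisper's reported duration) and prints each chunk between colored <START>/<END> markers via termcolor, which this commit adds to requirements.txt. A small sketch of that console output with dummy values in place of the live recording:

```python
# Hedged sketch of the colored per-chunk output; the durations are dummy values
# standing in for end_time - start_time and the Whisper timestamp span.
from termcolor import colored

chunk_text = "So today we are going to talk about the roadmap."
recorded_duration = 15.2     # seconds spent capturing audio (end_time - start_time)
transcribed_duration = 14.8  # seconds covered by the Whisper timestamps

item = {"timestamp": (0.0, transcribed_duration),
        "text": chunk_text,
        "stats": (str(recorded_duration), str(transcribed_duration))}

print(colored("<START>", "yellow"))
print(colored(item["text"], "green"))
print(colored("<END> Recorded duration: " + item["stats"][0] +
              " | Transcribed duration: " + item["stats"][1], "yellow"))
```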
@@ -106,10 +115,6 @@ def main():
 
     transcript_with_timestamp = post_process_transcription(transcript_with_timestamp)
 
-    transcript_text = ""
-    for chunk in transcript_with_timestamp["chunks"]:
-        transcript_text += chunk["text"]
-
     logger.info("Creating word cloud")
     create_wordcloud(NOW, True)
@@ -122,10 +127,11 @@ def main():
                        "real_time_transcript_with_timestamp" + suffix + ".txt",
                        "real_time_df_" + suffix + ".pkl",
                        "real_time_wordcloud_" + suffix + ".png",
-                       "real_time_mappings_" + suffix + ".pkl"]
+                       "real_time_mappings_" + suffix + ".pkl",
+                       "real_time_scatter_" + suffix + ".html"]
     upload_files(files_to_upload)
 
-    summarize(transcript_text, NOW, True, True)
+    summarize(transcript_with_timestamp["text"], NOW, True, True)
 
     logger.info("Summarization completed")