Merge pull request #20 from Monadical-SAS/whisper-jax-gokul

Micro improvements and bug fixes
projects-g authored on 2023-06-23 20:03:25 +05:30 · committed by GitHub
9 changed files with 586 additions and 170 deletions


@@ -27,10 +27,8 @@ To setup,
 ```sh setup_dependencies.sh cuda12```
-4) ``` pip install -r requirements.txt```
-5) Run the Whisper-JAX pipeline. Currently, the repo can take a Youtube video and transcribes/summarizes it.
+4) Run the Whisper-JAX pipeline. Currently, the repo can take a Youtube video and transcribes/summarizes it.
 ``` python3 whisjax.py "https://www.youtube.com/watch?v=ihf0S97oxuQ"```
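For orientation, the run step above ultimately drives the upstream whisper-jax pipeline. A minimal sketch of that call, assuming the `FlaxWhisperPipline` API from sanchit-gandhi/whisper-jax, with "audio.mp3" standing in for the audio extracted from the YouTube URL:

```python
# Minimal sketch, assuming the upstream whisper-jax API; "audio.mp3" is a
# placeholder for the audio whisjax.py pulls from the YouTube video.
from whisper_jax import FlaxWhisperPipline

pipeline = FlaxWhisperPipline("openai/whisper-large-v2")
result = pipeline("audio.mp3", return_timestamps=True)
print(result["text"])    # full transcript
print(result["chunks"])  # [{"timestamp": (start, end), "text": ...}, ...]
```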

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long


@@ -47,4 +47,5 @@ scattertext
 pandas
 jupyter
 seaborn
-matplotlib
+matplotlib
+termcolor


@@ -1,4 +1,4 @@
-# Upgrade pip
+Upgrade pip
 pip install --upgrade pip
 # Default to CPU Installation of JAX
@@ -24,3 +24,10 @@ pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
 # Update to latest version
 pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git
+pip install -r requirements.txt
+# download spacy models
+export KMP_DUPLICATE_LIB_OK=True
+python -m spacy download en_core_web_sm
+python -m spacy download en_core_web_md
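The `export KMP_DUPLICATE_LIB_OK=True` line is the usual workaround for the "duplicate OpenMP runtime" abort some numeric wheels trigger; note it only affects the shell session running the script. A quick sanity check that the two spaCy models resolve after the script runs (a sketch, assuming the installs succeeded):

```python
# Sanity check: spacy.load raises OSError if a model package is missing.
import spacy

for model in ("en_core_web_sm", "en_core_web_md"):
    nlp = spacy.load(model)
    print(model, "->", nlp.meta["version"])
```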


@@ -20,8 +20,11 @@ def preprocess_sentence(sentence):
 def compute_similarity(sent1, sent2):
     tfidf_vectorizer = TfidfVectorizer()
-    tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2])
-    return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
+    print("semt1", sent1, sent2)
+    if sent1 is not None and sent2 is not None:
+        tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2])
+        return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
+    return 0.0
 
 def remove_almost_alike_sentences(sentences, threshold=0.7):
     num_sentences = len(sentences)
@@ -31,12 +34,21 @@ def remove_almost_alike_sentences(sentences, threshold=0.7):
         if i not in removed_indices:
             for j in range(i + 1, num_sentences):
                 if j not in removed_indices:
-                    sentence1 = preprocess_sentence(sentences[i])
-                    sentence2 = preprocess_sentence(sentences[j])
-                    similarity = compute_similarity(sentence1, sentence2)
-                    if similarity >= threshold:
-                        removed_indices.add(max(i, j))
+                    l_i = len(sentences[i])
+                    l_j = len(sentences[j])
+                    if l_i == 0 or l_j == 0:
+                        if l_i == 0:
+                            removed_indices.add(i)
+                        if l_j == 0:
+                            removed_indices.add(j)
+                    else:
+                        sentence1 = preprocess_sentence(sentences[i])
+                        sentence2 = preprocess_sentence(sentences[j])
+                        if len(sentence1) != 0 and len(sentence2) != 0:
+                            similarity = compute_similarity(sentence1, sentence2)
+                            if similarity >= threshold:
+                                removed_indices.add(max(i, j))
     filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices]
     return filtered_sentences
@@ -67,11 +79,14 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
     return chunk_sentences
 
 def post_process_transcription(whisper_result):
+    transcript_text = ""
     for chunk in whisper_result["chunks"]:
         nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk)
         chunk_sentences = remove_whisper_repetitive_hallucination(nonduplicate_sentences)
         similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences)
         chunk["text"] = " ".join(similarity_matched_sentences)
+        transcript_text += chunk["text"]
+    whisper_result["text"] = transcript_text
     return whisper_result
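Net effect of the three hunks: `compute_similarity` no longer assumes both inputs exist, empty sentences are discarded instead of reaching the TF-IDF step, and the cleaned chunk texts are re-joined into `whisper_result["text"]`. A small behavior sketch (inputs invented; output will also include the `semt1` debug print left in the diff):

```python
# Behavior sketch for the hardened helpers above; inputs are invented.
print(compute_similarity(None, "hello world"))           # 0.0 -- guard path
print(compute_similarity("hello world", "hello world"))  # ~1.0 -- identical inputs

sentences = ["", "hello world", "hello world"]
print(remove_almost_alike_sentences(sentences))
# -> ['hello world']: the empty sentence is dropped up front and the exact
#    duplicate collapses onto its earlier occurrence (similarity >= 0.7)
```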


@@ -1,5 +1,6 @@
 import matplotlib.pyplot as plt
 from wordcloud import WordCloud, STOPWORDS
+from nltk.corpus import stopwords
 import collections
 import spacy
 import pickle
@@ -11,6 +12,10 @@ import configparser
 config = configparser.ConfigParser()
 config.read('config.ini')
+en = spacy.load('en_core_web_md')
+spacy_stopwords = en.Defaults.stop_words
+STOPWORDS = set(STOPWORDS).union(set(stopwords.words("english"))).union(set(spacy_stopwords))
 
 def create_wordcloud(timestamp, real_time=False):
     """
@@ -26,13 +31,11 @@ def create_wordcloud(timestamp, real_time=False):
     with open(filename, "r") as f:
         transcription_text = f.read()
-    stopwords = set(STOPWORDS)
-    # python_mask = np.array(PIL.Image.open("download1.png"))
     wordcloud = WordCloud(height=800, width=800,
                           background_color='white',
-                          stopwords=stopwords,
+                          stopwords=STOPWORDS,
                           min_font_size=8).generate(transcription_text)
     # Plot wordcloud and save image
@@ -192,4 +195,7 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
                                            width_in_pixels=1000,
                                            transform=st.Scalers.dense_rank
                                            )
-    open('./scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
+    if real_time:
+        open('./real_time_scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
+    else:
+        open('./scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
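The import changes above merge three stopword sources, so the word cloud filters more words than wordcloud's built-in list alone. A sketch of the resulting set (assumes the NLTK stopwords corpus is already downloaded, as the main scripts do at startup):

```python
# Sketch: the merged stopword set is a superset of wordcloud's own list.
from wordcloud import STOPWORDS as wordcloud_stopwords
from nltk.corpus import stopwords as nltk_stopwords
import spacy

en = spacy.load("en_core_web_md")
merged = (set(wordcloud_stopwords)
          | set(nltk_stopwords.words("english"))
          | set(en.Defaults.stop_words))
print(len(wordcloud_stopwords), "->", len(merged))
```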


@@ -26,8 +26,8 @@ from file_utilities import upload_files, download_files
 from viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
 from text_utilities import summarize, post_process_transcription
-nltk.download('punkt')
-nltk.download('stopwords')
+nltk.download('punkt', quiet=True)
+nltk.download('stopwords', quiet=True)
 
 # Configurations can be found in config.ini. Set them properly before executing
 config = configparser.ConfigParser()
@@ -141,7 +141,8 @@ def main():
                        "transcript_with_timestamp_" + suffix + ".txt",
                        "df_" + suffix + ".pkl",
                        "wordcloud_" + suffix + ".png",
-                       "mappings_" + suffix + ".pkl"]
+                       "mappings_" + suffix + ".pkl",
+                       "scatter_" + suffix + ".html"]
     upload_files(files_to_upload)
     summarize(transcript_text, NOW, False, False)
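Adding `"scatter_" + suffix + ".html"` to the upload list only finds a file if `suffix` matches the strftime pattern the viz code writes with. A sketch of that naming contract (the format string is taken from the viz_utilities hunk above; building `suffix` from the same `NOW` timestamp is an assumption about code outside this diff):

```python
# Assumed: whisjax.py derives `suffix` from the same NOW timestamp and the
# same format string viz_utilities uses when writing the scatter HTML.
from datetime import datetime

NOW = datetime.now()
suffix = NOW.strftime("%m-%d-%Y_%H:%M:%S")
print("scatter_" + suffix + ".html")  # e.g. scatter_06-23-2023_20:03:25.html
```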


@@ -12,7 +12,10 @@ from viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
 from text_utilities import summarize, post_process_transcription
 from loguru import logger
 import nltk
-nltk.download('stopwords')
+import time
+from termcolor import colored
+nltk.download('stopwords', quiet=True)
 
 config = configparser.ConfigParser()
 config.read('config.ini')
@@ -68,9 +71,11 @@ def main():
     try:
         while proceed:
             frames = []
+            start_time = time.time()
             for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)):
                 data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False)
                 frames.append(data)
+            end_time = time.time()
 
             wf = wave.open(TEMP_AUDIO_FILE, 'wb')
             wf.setnchannels(CHANNELS)
@@ -80,8 +85,6 @@ def main():
             wf.close()
 
             whisper_result = pipeline(TEMP_AUDIO_FILE, return_timestamps=True)
-            print(whisper_result['text'])
             timestamp = whisper_result["chunks"][0]["timestamp"]
             start = timestamp[0]
             end = timestamp[1]
@@ -89,12 +92,18 @@ def main():
                 end = start + 15.0
             duration = end - start
             item = {'timestamp': (last_transcribed_time, last_transcribed_time + duration),
-                    'text': whisper_result['text']}
+                    'text': whisper_result['text'],
+                    'stats': (str(end_time - start_time), str(duration))
+                    }
             last_transcribed_time = last_transcribed_time + duration
             transcript_with_timestamp["chunks"].append(item)
             transcription += whisper_result['text']
+            print(colored("<START>", "yellow"))
+            print(colored(whisper_result['text'], 'green'))
+            print(colored("<END> Recorded duration: " + str(end_time - start_time) + " | Transcribed duration: " +
+                          str(duration), "yellow"))
     except Exception as e:
         print(e)
     finally:
@@ -106,10 +115,6 @@ def main():
     transcript_with_timestamp = post_process_transcription(transcript_with_timestamp)
-    transcript_text = ""
-    for chunk in transcript_with_timestamp["chunks"]:
-        transcript_text += chunk["text"]
 
     logger.info("Creating word cloud")
     create_wordcloud(NOW, True)
@@ -122,10 +127,11 @@ def main():
                        "real_time_transcript_with_timestamp" + suffix + ".txt",
                        "real_time_df_" + suffix + ".pkl",
                        "real_time_wordcloud_" + suffix + ".png",
-                       "real_time_mappings_" + suffix + ".pkl"]
+                       "real_time_mappings_" + suffix + ".pkl",
+                       "real_time_scatter_" + suffix + ".html"]
     upload_files(files_to_upload)
-    summarize(transcript_text, NOW, True, True)
+    summarize(transcript_with_timestamp["text"], NOW, True, True)
     logger.info("Summarization completed")
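The new console trace brackets each transcribed chunk in color so recorded versus transcribed durations stand out during a live session. A standalone sketch of the same pattern (duration values invented for illustration):

```python
# Standalone sketch of the colored trace added above; durations are invented.
from termcolor import colored

recorded_s, transcribed_s = 15.03, 14.80
print(colored("<START>", "yellow"))
print(colored("...one chunk of transcribed speech...", "green"))
print(colored("<END> Recorded duration: " + str(recorded_s) +
              " | Transcribed duration: " + str(transcribed_s), "yellow"))
```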