mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
add summary features
This commit is contained in:
12
config.ini
12
config.ini
@@ -4,7 +4,15 @@ KMP_DUPLICATE_LIB_OK=TRUE
|
|||||||
# Export OpenAI API Key
|
# Export OpenAI API Key
|
||||||
OPENAI_APIKEY=
|
OPENAI_APIKEY=
|
||||||
# Export Whisper Model Size
|
# Export Whisper Model Size
|
||||||
WHISPER_MODEL_SIZE=tiny
|
WHISPER_MODEL_SIZE=medium
|
||||||
|
# AWS config
|
||||||
AWS_ACCESS_KEY=***REMOVED***
|
AWS_ACCESS_KEY=***REMOVED***
|
||||||
AWS_SECRET_KEY=***REMOVED***
|
AWS_SECRET_KEY=***REMOVED***
|
||||||
BUCKET_NAME='reflector-bucket'
|
BUCKET_NAME='reflector-bucket'
|
||||||
|
# Summarizer config
|
||||||
|
SUMMARY_MODEL=facebook/bart-large-cnn
|
||||||
|
INPUT_ENCODING_MAX_LENGTH=1024
|
||||||
|
MAX_LENGTH=2048
|
||||||
|
BEAM_SIZE=6
|
||||||
|
MAX_CHUNK_LENGTH=1024
|
||||||
|
SUMMARIZE_USING_CHUNKS=YES
|
||||||
1
transcript_timestamps(2).txt
Normal file
1
transcript_timestamps(2).txt
Normal file
File diff suppressed because one or more lines are too long
205
whisjax.py
205
whisjax.py
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import ast
|
import ast
|
||||||
|
import torch
|
||||||
import collections
|
import collections
|
||||||
import configparser
|
import configparser
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
@@ -27,10 +28,15 @@ from transformers import BartTokenizer, BartForConditionalGeneration
|
|||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from whisper_jax import FlaxWhisperPipline
|
from whisper_jax import FlaxWhisperPipline
|
||||||
from wordcloud import WordCloud, STOPWORDS
|
from wordcloud import WordCloud, STOPWORDS
|
||||||
|
from nltk.corpus import stopwords
|
||||||
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||||
|
from nltk.tokenize import word_tokenize
|
||||||
|
from sklearn.metrics.pairwise import cosine_similarity
|
||||||
|
|
||||||
from file_util import upload_files, download_files
|
from file_util import upload_files, download_files
|
||||||
|
|
||||||
nltk.download('punkt')
|
nltk.download('punkt')
|
||||||
|
nltk.download('stopwords')
|
||||||
|
|
||||||
# Configurations can be found in config.ini. Set them properly before executing
|
# Configurations can be found in config.ini. Set them properly before executing
|
||||||
config = configparser.ConfigParser()
|
config = configparser.ConfigParser()
|
||||||
@@ -52,16 +58,13 @@ def init_argparse() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument("-l", "--language", help="Language that the summary should be written in", type=str,
|
parser.add_argument("-l", "--language", help="Language that the summary should be written in", type=str,
|
||||||
default="english", choices=['english', 'spanish', 'french', 'german', 'romanian'])
|
default="english", choices=['english', 'spanish', 'french', 'german', 'romanian'])
|
||||||
parser.add_argument("-t", "--transcript", help="Save a copy of the intermediary transcript file", type=str)
|
parser.add_argument("-t", "--transcript", help="Save a copy of the intermediary transcript file", type=str)
|
||||||
parser.add_argument(
|
|
||||||
"-m", "--model_name", help="Name or path of the BART model",
|
|
||||||
type=str, default="facebook/bart-base")
|
|
||||||
parser.add_argument("location")
|
parser.add_argument("location")
|
||||||
parser.add_argument("output")
|
parser.add_argument("output")
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def chunk_text(txt, max_chunk_length=500):
|
def chunk_text(txt, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
|
||||||
"""
|
"""
|
||||||
Split text into smaller chunks.
|
Split text into smaller chunks.
|
||||||
:param txt: Text to be chunked
|
:param txt: Text to be chunked
|
||||||
@@ -89,13 +92,17 @@ def summarize_chunks(chunks, tokenizer, model):
|
|||||||
:param model:
|
:param model:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
summaries = []
|
summaries = []
|
||||||
for c in chunks:
|
for c in chunks:
|
||||||
input_ids = tokenizer.encode(c, return_tensors='pt')
|
input_ids = tokenizer.encode(c, return_tensors='pt')
|
||||||
summary_ids = model.generate(
|
input_ids = input_ids.to(device)
|
||||||
input_ids, num_beams=4, length_penalty=2.0, max_length=1024, no_repeat_ngram_size=3)
|
with torch.no_grad():
|
||||||
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
summary_ids = model.generate(input_ids,
|
||||||
summaries.append(summary)
|
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
|
||||||
|
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
|
||||||
|
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
||||||
|
summaries.append(summary)
|
||||||
return summaries
|
return summaries
|
||||||
|
|
||||||
|
|
||||||
@@ -223,33 +230,137 @@ def create_talk_diff_scatter_viz():
|
|||||||
# to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") )
|
# to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") )
|
||||||
|
|
||||||
# pick the 2 most matched topic to be used for plotting
|
# pick the 2 most matched topic to be used for plotting
|
||||||
# topic_times = collections.defaultdict(int)
|
topic_times = collections.defaultdict(int)
|
||||||
# for key in ts_to_topic_mapping_top_1.keys():
|
for key in ts_to_topic_mapping_top_1.keys():
|
||||||
# duration = key[1] - key[0]
|
if key[0] is None or key[1] is None:
|
||||||
# topic_times[ts_to_topic_mapping_top_1[key]] += duration
|
continue
|
||||||
#
|
duration = key[1] - key[0]
|
||||||
# topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True)
|
topic_times[ts_to_topic_mapping_top_1[key]] += duration
|
||||||
#
|
|
||||||
# cat_1 = topic_times[0][0]
|
|
||||||
# cat_1_name = topic_times[0][0]
|
|
||||||
# cat_2_name = topic_times[1][0]
|
|
||||||
#
|
|
||||||
# # Scatter plot of topics
|
|
||||||
# df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences))
|
|
||||||
# corpus = st.CorpusFromParsedDocuments(
|
|
||||||
# df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse'
|
|
||||||
# ).build().get_unigram_corpus().compact(st.AssociationCompactor(2000))
|
|
||||||
# html = st.produce_scattertext_explorer(
|
|
||||||
# corpus,
|
|
||||||
# category=cat_1,
|
|
||||||
# category_name=cat_1_name,
|
|
||||||
# not_category_name=cat_2_name,
|
|
||||||
# minimum_term_frequency=0, pmi_threshold_coefficient=0,
|
|
||||||
# width_in_pixels=1000,
|
|
||||||
# transform=st.Scalers.dense_rank
|
|
||||||
# )
|
|
||||||
# open('./demo_compact.html', 'w').write(html)
|
|
||||||
|
|
||||||
|
topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
cat_1 = topic_times[0][0]
|
||||||
|
cat_1_name = topic_times[0][0]
|
||||||
|
cat_2_name = topic_times[1][0]
|
||||||
|
|
||||||
|
# Scatter plot of topics
|
||||||
|
df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences))
|
||||||
|
corpus = st.CorpusFromParsedDocuments(
|
||||||
|
df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse'
|
||||||
|
).build().get_unigram_corpus().compact(st.AssociationCompactor(2000))
|
||||||
|
html = st.produce_scattertext_explorer(
|
||||||
|
corpus,
|
||||||
|
category=cat_1,
|
||||||
|
category_name=cat_1_name,
|
||||||
|
not_category_name=cat_2_name,
|
||||||
|
minimum_term_frequency=0, pmi_threshold_coefficient=0,
|
||||||
|
width_in_pixels=1000,
|
||||||
|
transform=st.Scalers.dense_rank
|
||||||
|
)
|
||||||
|
open('./demo_compact.html', 'w').write(html)
|
||||||
|
|
||||||
|
def preprocess_sentence(sentence):
|
||||||
|
stop_words = set(stopwords.words('english'))
|
||||||
|
tokens = word_tokenize(sentence.lower())
|
||||||
|
tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
|
||||||
|
return ' '.join(tokens)
|
||||||
|
|
||||||
|
def compute_similarity(sent1, sent2):
|
||||||
|
tfidf_vectorizer = TfidfVectorizer()
|
||||||
|
tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2])
|
||||||
|
return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
|
||||||
|
|
||||||
|
def remove_almost_alike_sentences(sentences, threshold=0.7):
|
||||||
|
num_sentences = len(sentences)
|
||||||
|
removed_indices = set()
|
||||||
|
|
||||||
|
for i in range(num_sentences):
|
||||||
|
if i not in removed_indices:
|
||||||
|
for j in range(i + 1, num_sentences):
|
||||||
|
if j not in removed_indices:
|
||||||
|
sentence1 = preprocess_sentence(sentences[i])
|
||||||
|
sentence2 = preprocess_sentence(sentences[j])
|
||||||
|
similarity = compute_similarity(sentence1, sentence2)
|
||||||
|
|
||||||
|
if similarity >= threshold:
|
||||||
|
removed_indices.add(max(i, j))
|
||||||
|
|
||||||
|
filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices]
|
||||||
|
return filtered_sentences
|
||||||
|
|
||||||
|
def remove_outright_duplicate_sentences_from_chunk(chunk):
|
||||||
|
chunk_text = chunk["text"]
|
||||||
|
sentences = nltk.sent_tokenize(chunk_text)
|
||||||
|
nonduplicate_sentences = list(dict.fromkeys(sentences))
|
||||||
|
return nonduplicate_sentences
|
||||||
|
|
||||||
|
def remove_whisper_repititive_hallucination(nonduplicate_sentences):
|
||||||
|
chunk_sentences = []
|
||||||
|
|
||||||
|
for sent in nonduplicate_sentences:
|
||||||
|
temp_result = ""
|
||||||
|
seen = {}
|
||||||
|
words = nltk.word_tokenize(sent)
|
||||||
|
n_gram_filter = 3
|
||||||
|
for i in range(len(words)):
|
||||||
|
if str(words[i:i + n_gram_filter]) in seen and seen[str(words[i:i + n_gram_filter])] == words[
|
||||||
|
i + 1:i + n_gram_filter + 2]:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
seen[str(words[i:i + n_gram_filter])] = words[i + 1:i + n_gram_filter + 2]
|
||||||
|
temp_result += words[i]
|
||||||
|
temp_result += " "
|
||||||
|
chunk_sentences.append(temp_result)
|
||||||
|
return chunk_sentences
|
||||||
|
|
||||||
|
def remove_duplicates_from_transcript_chunk(whisper_result):
|
||||||
|
for chunk in whisper_result["chunks"]:
|
||||||
|
nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk)
|
||||||
|
chunk_sentences = remove_whisper_repititive_hallucination(nonduplicate_sentences)
|
||||||
|
similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences)
|
||||||
|
chunk["text"] = " ".join(similarity_matched_sentences)
|
||||||
|
return whisper_result
|
||||||
|
|
||||||
|
def summarize(transcript_text, output_file,
|
||||||
|
summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
|
||||||
|
if not summary_model:
|
||||||
|
summary_model = "facebook/bart-large-cnn"
|
||||||
|
|
||||||
|
# Summarize the generated transcript using the BART model
|
||||||
|
logger.info(f"Loading BART model: {summary_model}")
|
||||||
|
tokenizer = BartTokenizer.from_pretrained(summary_model)
|
||||||
|
model = BartForConditionalGeneration.from_pretrained(summary_model)
|
||||||
|
model = model.to(device)
|
||||||
|
|
||||||
|
if summarize_using_chunks != "YES":
|
||||||
|
inputs = tokenizer.batch_encode_plus([transcript_text], truncation=True, padding='longest',
|
||||||
|
max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]),
|
||||||
|
return_tensors='pt')
|
||||||
|
inputs = inputs.to(device)
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
summaries = model.generate(inputs['input_ids'],
|
||||||
|
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
|
||||||
|
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
|
||||||
|
|
||||||
|
decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) for
|
||||||
|
summary in summaries]
|
||||||
|
summary = " ".join(decoded_summaries)
|
||||||
|
with open(output_file, 'w') as f:
|
||||||
|
f.write(summary.strip() + "\n\n")
|
||||||
|
else:
|
||||||
|
logger.info("Breaking transcript into smaller chunks")
|
||||||
|
chunks = chunk_text(transcript_text)
|
||||||
|
|
||||||
|
logger.info(f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable
|
||||||
|
|
||||||
|
logger.info(f"Writing summary text to: {output_file}")
|
||||||
|
with open(output_file, 'w') as f:
|
||||||
|
summaries = summarize_chunks(chunks, tokenizer, model)
|
||||||
|
for summary in summaries:
|
||||||
|
f.write(summary.strip() + " ")
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = init_argparse()
|
parser = init_argparse()
|
||||||
@@ -314,13 +425,19 @@ def main():
|
|||||||
whisper_result = pipeline(audio_filename, return_timestamps=True)
|
whisper_result = pipeline(audio_filename, return_timestamps=True)
|
||||||
logger.info("Finished transcribing file")
|
logger.info("Finished transcribing file")
|
||||||
|
|
||||||
|
whisper_result = remove_duplicates_from_transcript_chunk(whisper_result)
|
||||||
|
|
||||||
|
transcript_text = ""
|
||||||
|
for chunk in whisper_result["chunks"]:
|
||||||
|
transcript_text += chunk["text"]
|
||||||
|
|
||||||
# If we got the transcript parameter on the command line,
|
# If we got the transcript parameter on the command line,
|
||||||
# save the transcript to the specified file.
|
# save the transcript to the specified file.
|
||||||
if args.transcript:
|
if args.transcript:
|
||||||
logger.info(f"Saving transcript to: {args.transcript}")
|
logger.info(f"Saving transcript to: {args.transcript}")
|
||||||
transcript_file = open(args.transcript, "w")
|
transcript_file = open(args.transcript, "w")
|
||||||
transcript_file_timestamps = open(args.transcript[0:len(args.transcript) - 4] + "_timestamps.txt", "w")
|
transcript_file_timestamps = open(args.transcript[0:len(args.transcript) - 4] + "_timestamps.txt", "w")
|
||||||
transcript_file.write(whisper_result["text"])
|
transcript_file.write(transcript_text)
|
||||||
transcript_file_timestamps.write(str(whisper_result))
|
transcript_file_timestamps.write(str(whisper_result))
|
||||||
transcript_file.close()
|
transcript_file.close()
|
||||||
transcript_file_timestamps.close()
|
transcript_file_timestamps.close()
|
||||||
@@ -337,23 +454,7 @@ def main():
|
|||||||
"wordcloud.png", "mappings.pkl"]
|
"wordcloud.png", "mappings.pkl"]
|
||||||
upload_files(files_to_upload)
|
upload_files(files_to_upload)
|
||||||
|
|
||||||
# Summarize the generated transcript using the BART model
|
summarize(transcript_text, args.output)
|
||||||
logger.info(f"Loading BART model: {args.model_name}")
|
|
||||||
tokenizer = BartTokenizer.from_pretrained(args.model_name)
|
|
||||||
model = BartForConditionalGeneration.from_pretrained(args.model_name)
|
|
||||||
|
|
||||||
logger.info("Breaking transcript into smaller chunks")
|
|
||||||
chunks = chunk_text(whisper_result['text'])
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable
|
|
||||||
|
|
||||||
logger.info(f"Writing summary text in {args.language} to: {args.output}")
|
|
||||||
with open(args.output, 'w') as f:
|
|
||||||
f.write('Summary of: ' + args.location + "\n\n")
|
|
||||||
summaries = summarize_chunks(chunks, tokenizer, model)
|
|
||||||
for summary in summaries:
|
|
||||||
f.write(summary.strip() + "\n\n")
|
|
||||||
|
|
||||||
logger.info("Summarization completed")
|
logger.info("Summarization completed")
|
||||||
|
|
||||||
|
|||||||
@@ -11,13 +11,12 @@ config = configparser.ConfigParser()
|
|||||||
config.read('config.ini')
|
config.read('config.ini')
|
||||||
|
|
||||||
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
|
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
|
||||||
OPENAI_APIKEY = config['DEFAULT']["OPENAI_APIKEY"]
|
|
||||||
|
|
||||||
FRAMES_PER_BUFFER = 8000
|
FRAMES_PER_BUFFER = 8000
|
||||||
FORMAT = pyaudio.paInt16
|
FORMAT = pyaudio.paInt16
|
||||||
CHANNELS = 1
|
CHANNELS = 1
|
||||||
RATE = 44100
|
RATE = 44100
|
||||||
RECORD_SECONDS = 10
|
RECORD_SECONDS = 5
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -49,7 +48,7 @@ def main():
|
|||||||
|
|
||||||
listener = keyboard.Listener(on_press=on_press)
|
listener = keyboard.Listener(on_press=on_press)
|
||||||
listener.start()
|
listener.start()
|
||||||
|
print("Listening...")
|
||||||
|
|
||||||
while proceed:
|
while proceed:
|
||||||
try:
|
try:
|
||||||
@@ -57,7 +56,6 @@ def main():
|
|||||||
for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)):
|
for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)):
|
||||||
data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False)
|
data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False)
|
||||||
frames.append(data)
|
frames.append(data)
|
||||||
print("Collected Input", len(frames))
|
|
||||||
|
|
||||||
wf = wave.open(TEMP_AUDIO_FILE, 'wb')
|
wf = wave.open(TEMP_AUDIO_FILE, 'wb')
|
||||||
wf.setnchannels(CHANNELS)
|
wf.setnchannels(CHANNELS)
|
||||||
@@ -70,14 +68,12 @@ def main():
|
|||||||
print(whisper_result['text'])
|
print(whisper_result['text'])
|
||||||
|
|
||||||
transcription += whisper_result['text']
|
transcription += whisper_result['text']
|
||||||
if len(transcription) > 10:
|
|
||||||
transcription += "\n"
|
|
||||||
transcript_file.write(transcription)
|
|
||||||
transcription = ""
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
break
|
finally:
|
||||||
|
with open("real_time_transcription.txt", "w") as f:
|
||||||
|
transcript_file.write(transcription)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
Reference in New Issue
Block a user