Merge pull request #24 from Monadical-SAS/whisper-jax-gokul

Minor refactor
This commit is contained in:
projects-g
2023-07-10 22:49:23 +05:30
committed by GitHub
8 changed files with 82 additions and 85 deletions

View File

@@ -1,23 +1,26 @@
import asyncio import asyncio
import configparser
import datetime import datetime
import io import io
import json import json
import logging import logging
import os import os
import sys
import threading import threading
import uuid import uuid
import wave import wave
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from sortedcontainers import SortedDict
import configparser
import jax.numpy as jnp import jax.numpy as jnp
from aiohttp import web from aiohttp import webq
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import (MediaRelay) from aiortc.contrib.media import (MediaRelay)
from av import AudioFifo from av import AudioFifo
from sortedcontainers import SortedDict
from whisper_jax import FlaxWhisperPipline from whisper_jax import FlaxWhisperPipline
from utils.server_utils import Mutex
ROOT = os.path.dirname(__file__) ROOT = os.path.dirname(__file__)
config = configparser.ConfigParser() config = configparser.ConfigParser()
@@ -46,14 +49,14 @@ total_bytes_handled = 0
executor = ThreadPoolExecutor() executor = ThreadPoolExecutor()
frame_lock = threading.Lock() frame_lock = Mutex(audio_buffer)
total_bytes_handled_lock = threading.Lock()
def channel_log(channel, t, message): def channel_log(channel, t, message):
print("channel(%s) %s %s" % (channel.label, t, message)) print("channel(%s) %s %s" % (channel.label, t, message))
def thread_queue_channel_send(): def thread_queue_channel_send():
print("M-thread created")
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop) asyncio.set_event_loop(loop)
try: try:
@@ -62,25 +65,15 @@ def thread_queue_channel_send():
if message: if message:
del sorted_message_queue[least_time] del sorted_message_queue[least_time]
data_channel.send(message) data_channel.send(message)
print("M-thread sent message to client")
with total_bytes_handled_lock:
print("Bytes handled :", total_bytes_handled, " Time : ", datetime.datetime.now() - start_time)
except Exception as e: except Exception as e:
print("Exception", str(e)) print("Exception", str(e))
pass pass
loop.run_forever() loop.run_forever()
# async def channel_send(channel, message):
# channel_log(channel, ">", message)
# if channel and message:
# channel.send(message)
def get_transcription(local_thread_id): def get_transcription():
# Block 1
print("T-thread -> ", str(local_thread_id) , "created")
global frame_lock
while True: while True:
with frame_lock: with frame_lock.lock() as audio_buffer:
frames = audio_buffer.read_many(CHUNK_SIZE * 960, partial=False) frames = audio_buffer.read_many(CHUNK_SIZE * 960, partial=False)
if not frames: if not frames:
transcribe = False transcribe = False
@@ -89,7 +82,6 @@ def get_transcription(local_thread_id):
if transcribe: if transcribe:
try: try:
print("T-thread ", str(local_thread_id), "is transcribing")
sorted_message_queue[frames[0].time] = None sorted_message_queue[frames[0].time] = None
out_file = io.BytesIO() out_file = io.BytesIO()
wf = wave.open(out_file, "wb") wf = wave.open(out_file, "wb")
@@ -102,10 +94,6 @@ def get_transcription(local_thread_id):
wf.close() wf.close()
whisper_result = pipeline(out_file.getvalue()) whisper_result = pipeline(out_file.getvalue())
global total_bytes_handled
with total_bytes_handled_lock:
total_bytes_handled += sys.getsizeof(wf)
item = {'text': whisper_result["text"], item = {'text': whisper_result["text"],
'start_time': str(frames[0].time), 'start_time': str(frames[0].time),
'time': str(datetime.datetime.now()) 'time': str(datetime.datetime.now())
@@ -115,6 +103,7 @@ def get_transcription(local_thread_id):
except Exception as e: except Exception as e:
print("Exception -> ", str(e)) print("Exception -> ", str(e))
class AudioStreamTrack(MediaStreamTrack): class AudioStreamTrack(MediaStreamTrack):
""" """
A video stream track that transforms frames from another track. A video stream track that transforms frames from another track.
@@ -127,7 +116,6 @@ class AudioStreamTrack(MediaStreamTrack):
self.track = track self.track = track
async def recv(self): async def recv(self):
# print("Awaiting track in server")
frame = await self.track.recv() frame = await self.track.recv()
audio_buffer.write(frame) audio_buffer.write(frame)
return frame return frame
@@ -136,7 +124,7 @@ class AudioStreamTrack(MediaStreamTrack):
def start_messaging_thread(): def start_messaging_thread():
message_thread = threading.Thread(target=thread_queue_channel_send) message_thread = threading.Thread(target=thread_queue_channel_send)
message_thread.start() message_thread.start()
# message_thread.join()
def start_transcription_thread(max_threads): def start_transcription_thread(max_threads):
t_threads = [] t_threads = []
@@ -145,12 +133,9 @@ def start_transcription_thread(max_threads):
t_threads.append(t_thread) t_threads.append(t_thread)
t_thread.start() t_thread.start()
# for t_thread in t_threads:
# t_thread.join()
async def offer(request): async def offer(request):
params = await request.json() params = await request.json()
print("Request received")
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
pc = RTCPeerConnection() pc = RTCPeerConnection()
@@ -185,11 +170,8 @@ async def offer(request):
@pc.on("track") @pc.on("track")
def on_track(track): def on_track(track):
print("Track %s received", track.kind)
log_info("Track %s received", track.kind) log_info("Track %s received", track.kind)
# Trials to listen to the correct track
pc.addTrack(AudioStreamTrack(relay.subscribe(track))) pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
# pc.addTrack(AudioStreamTrack(track))
# handle offer # handle offer
await pc.setRemoteDescription(offer) await pc.setRemoteDescription(offer)
@@ -197,7 +179,6 @@ async def offer(request):
# send answer # send answer
answer = await pc.createAnswer() answer = await pc.createAnswer()
await pc.setLocalDescription(answer) await pc.setLocalDescription(answer)
print("Response sent")
return web.Response( return web.Response(
content_type="application/json", content_type="application/json",
text=json.dumps( text=json.dumps(
@@ -207,7 +188,6 @@ async def offer(request):
async def on_shutdown(app): async def on_shutdown(app):
# close peer connections
coros = [pc.close() for pc in pcs] coros = [pc.close() for pc in pcs]
await asyncio.gather(*coros) await asyncio.gather(*coros)
pcs.clear() pcs.clear()
@@ -221,5 +201,3 @@ if __name__ == "__main__":
web.run_app( web.run_app(
app, access_log=None, host="127.0.0.1", port=1250 app, access_log=None, host="127.0.0.1", port=1250
) )

View File

@@ -1,16 +1,17 @@
import ast
import asyncio import asyncio
import configparser
import logging import logging
import time import time
import uuid import uuid
import threading
import configparser
import httpx import httpx
import pyaudio import pyaudio
import requests import requests
import ast
import stamina import stamina
from aiortc import (RTCPeerConnection, RTCSessionDescription) from aiortc import (RTCPeerConnection, RTCSessionDescription)
from aiortc.contrib.media import (MediaPlayer, MediaRelay) from aiortc.contrib.media import (MediaPlayer, MediaRelay)
from utils.server_utils import Mutex from utils.server_utils import Mutex
logger = logging.getLogger("pc") logger = logging.getLogger("pc")
@@ -19,13 +20,14 @@ file_lock = Mutex(open("test_sm_6.txt", "a"))
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read('config.ini') config.read('config.ini')
class StreamClient: class StreamClient:
def __init__( def __init__(
self, self,
signaling, signaling,
url="http://127.0.0.1:1250", url="http://127.0.0.1:1250",
play_from=None, play_from=None,
ping_pong=False ping_pong=False
): ):
self.signaling = signaling self.signaling = signaling
self.server_url = url self.server_url = url
@@ -103,7 +105,6 @@ class StreamClient:
channel = pc.createDataChannel("data-channel") channel = pc.createDataChannel("data-channel")
self.channel_log(channel, "-", "created by local party") self.channel_log(channel, "-", "created by local party")
async def send_pings(): async def send_pings():
while True: while True:
self.channel_send(channel, "ping %d" % self.current_stamp()) self.channel_send(channel, "ping %d" % self.current_stamp())

View File

@@ -1,6 +1,7 @@
import configparser
import boto3 import boto3
import botocore import botocore
import configparser
from loguru import logger from loguru import logger
config = configparser.ConfigParser() config = configparser.ConfigParser()
@@ -12,6 +13,7 @@ s3 = boto3.client('s3',
aws_access_key_id=config["DEFAULT"]["AWS_ACCESS_KEY"], aws_access_key_id=config["DEFAULT"]["AWS_ACCESS_KEY"],
aws_secret_access_key=config["DEFAULT"]["AWS_SECRET_KEY"]) aws_secret_access_key=config["DEFAULT"]["AWS_SECRET_KEY"])
def upload_files(files_to_upload): def upload_files(files_to_upload):
""" """
Upload a list of files to the configured S3 bucket Upload a list of files to the configured S3 bucket
@@ -45,6 +47,7 @@ def download_files(files_to_download):
if __name__ == "__main__": if __name__ == "__main__":
import sys import sys
if sys.argv[1] == "download": if sys.argv[1] == "download":
download_files([sys.argv[2]]) download_files([sys.argv[2]])
elif sys.argv[1] == "upload": elif sys.argv[1] == "upload":

View File

@@ -1,9 +1,10 @@
import asyncio import asyncio
from functools import partial
import contextlib import contextlib
from functools import partial
from threading import Lock from threading import Lock
from typing import ContextManager, Generic, TypeVar from typing import ContextManager, Generic, TypeVar
def run_in_executor(func, *args, executor=None, **kwargs): def run_in_executor(func, *args, executor=None, **kwargs):
callback = partial(func, *args, **kwargs) callback = partial(func, *args, **kwargs)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
@@ -11,6 +12,8 @@ def run_in_executor(func, *args, executor=None, **kwargs):
T = TypeVar("T") T = TypeVar("T")
class Mutex(Generic[T]): class Mutex(Generic[T]):
def __init__(self, value: T): def __init__(self, value: T):
self.__value = value self.__value = value
@@ -22,4 +25,4 @@ class Mutex(Generic[T]):
try: try:
yield self.__value yield self.__value
finally: finally:
self.__lock.release() self.__lock.release()

View File

@@ -1,23 +1,27 @@
import torch
import configparser import configparser
import nltk import nltk
from transformers import BartTokenizer, BartForConditionalGeneration import torch
from loguru import logger from loguru import logger
from nltk.corpus import stopwords from nltk.corpus import stopwords
from sklearn.feature_extraction.text import TfidfVectorizer
from nltk.tokenize import word_tokenize from nltk.tokenize import word_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity from sklearn.metrics.pairwise import cosine_similarity
from transformers import BartTokenizer, BartForConditionalGeneration
nltk.download('punkt', quiet=True) nltk.download('punkt', quiet=True)
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read('config.ini') config.read('config.ini')
def preprocess_sentence(sentence): def preprocess_sentence(sentence):
stop_words = set(stopwords.words('english')) stop_words = set(stopwords.words('english'))
tokens = word_tokenize(sentence.lower()) tokens = word_tokenize(sentence.lower())
tokens = [token for token in tokens if token.isalnum() and token not in stop_words] tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
return ' '.join(tokens) return ' '.join(tokens)
def compute_similarity(sent1, sent2): def compute_similarity(sent1, sent2):
""" """
Compute the similarity Compute the similarity
@@ -28,6 +32,7 @@ def compute_similarity(sent1, sent2):
return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0] return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
return 0.0 return 0.0
def remove_almost_alike_sentences(sentences, threshold=0.7): def remove_almost_alike_sentences(sentences, threshold=0.7):
num_sentences = len(sentences) num_sentences = len(sentences)
removed_indices = set() removed_indices = set()
@@ -55,12 +60,14 @@ def remove_almost_alike_sentences(sentences, threshold=0.7):
filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices] filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices]
return filtered_sentences return filtered_sentences
def remove_outright_duplicate_sentences_from_chunk(chunk): def remove_outright_duplicate_sentences_from_chunk(chunk):
chunk_text = chunk["text"] chunk_text = chunk["text"]
sentences = nltk.sent_tokenize(chunk_text) sentences = nltk.sent_tokenize(chunk_text)
nonduplicate_sentences = list(dict.fromkeys(sentences)) nonduplicate_sentences = list(dict.fromkeys(sentences))
return nonduplicate_sentences return nonduplicate_sentences
def remove_whisper_repetitive_hallucination(nonduplicate_sentences): def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
chunk_sentences = [] chunk_sentences = []
@@ -80,6 +87,7 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
chunk_sentences.append(temp_result) chunk_sentences.append(temp_result)
return chunk_sentences return chunk_sentences
def post_process_transcription(whisper_result): def post_process_transcription(whisper_result):
transcript_text = "" transcript_text = ""
for chunk in whisper_result["chunks"]: for chunk in whisper_result["chunks"]:
@@ -107,12 +115,13 @@ def summarize_chunks(chunks, tokenizer, model):
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
with torch.no_grad(): with torch.no_grad():
summary_ids = model.generate(input_ids, summary_ids = model.generate(input_ids,
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0, num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True) max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
summaries.append(summary) summaries.append(summary)
return summaries return summaries
def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])): def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
""" """
Split text into smaller chunks. Split text into smaller chunks.
@@ -132,6 +141,7 @@ def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])
chunks.append(current_chunk.strip()) chunks.append(current_chunk.strip())
return chunks return chunks
def summarize(transcript_text, timestamp, def summarize(transcript_text, timestamp,
real_time=False, summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]): real_time=False, summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

View File

@@ -1,15 +1,16 @@
import matplotlib.pyplot as plt
from wordcloud import WordCloud, STOPWORDS
from nltk.corpus import stopwords
import collections
import spacy
import os
from pathlib import Path
import pickle
import ast import ast
import collections
import configparser
import os
import pickle
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
import scattertext as st import scattertext as st
import configparser import spacy
from nltk.corpus import stopwords
from wordcloud import WordCloud, STOPWORDS
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read('config.ini') config.read('config.ini')
@@ -29,7 +30,7 @@ def create_wordcloud(timestamp, real_time=False):
if real_time: if real_time:
filename = "real_time_" + filename + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" filename = "real_time_" + filename + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
else: else:
filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
with open("./artefacts/" + filename, "r") as f: with open("./artefacts/" + filename, "r") as f:
transcription_text = f.read() transcription_text = f.read()
@@ -202,4 +203,4 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
if real_time: if real_time:
open('./artefacts/real_time_scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html) open('./artefacts/real_time_scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
else: else:
open('./artefacts/scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html) open('./artefacts/scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)

View File

@@ -6,25 +6,24 @@
import argparse import argparse
import configparser import configparser
import jax.numpy as jnp import os
import re
import subprocess
import tempfile
from datetime import datetime
from urllib.parse import urlparse
import jax.numpy as jnp
import moviepy.editor import moviepy.editor
import moviepy.editor import moviepy.editor
import nltk import nltk
import os
import subprocess
import re
import tempfile
from loguru import logger
import yt_dlp as youtube_dl import yt_dlp as youtube_dl
from loguru import logger
from urllib.parse import urlparse
from whisper_jax import FlaxWhisperPipline from whisper_jax import FlaxWhisperPipline
from datetime import datetime
from utils.file_utilities import upload_files, download_files from utils.file_utilities import upload_files, download_files
from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
from utils.text_utilities import summarize, post_process_transcription from utils.text_utilities import summarize, post_process_transcription
from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
nltk.download('punkt', quiet=True) nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True) nltk.download('stopwords', quiet=True)
@@ -36,6 +35,7 @@ config.read('config.ini')
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"] WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
NOW = datetime.now() NOW = datetime.now()
def init_argparse() -> argparse.ArgumentParser: def init_argparse() -> argparse.ArgumentParser:
""" """
Parse the CLI arguments Parse the CLI arguments
@@ -52,7 +52,6 @@ def init_argparse() -> argparse.ArgumentParser:
return parser return parser
def main(): def main():
parser = init_argparse() parser = init_argparse()
args = parser.parse_args() args = parser.parse_args()
@@ -140,10 +139,10 @@ def main():
with open("./artefacts/transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file: with open("./artefacts/transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file:
transcript_file.write(transcript_text) transcript_file.write(transcript_text)
with open("./artefacts/transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file_timestamps: with open("./artefacts/transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt",
"w") as transcript_file_timestamps:
transcript_file_timestamps.write(str(whisper_result)) transcript_file_timestamps.write(str(whisper_result))
logger.info("Creating word cloud") logger.info("Creating word cloud")
create_wordcloud(NOW) create_wordcloud(NOW)

View File

@@ -1,18 +1,20 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import configparser import configparser
import pyaudio import time
from whisper_jax import FlaxWhisperPipline
from pynput import keyboard
import jax.numpy as jnp
import wave import wave
from datetime import datetime from datetime import datetime
from utils.file_utilities import upload_files
from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz import jax.numpy as jnp
from utils.text_utilities import summarize, post_process_transcription import pyaudio
from loguru import logger from loguru import logger
import time from pynput import keyboard
from termcolor import colored from termcolor import colored
from whisper_jax import FlaxWhisperPipline
from utils.file_utilities import upload_files
from utils.text_utilities import summarize, post_process_transcription
from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
config = configparser.ConfigParser() config = configparser.ConfigParser()
config.read('config.ini') config.read('config.ini')