Merge pull request #24 from Monadical-SAS/whisper-jax-gokul

Minor refactor
This commit is contained in:
projects-g
2023-07-10 22:49:23 +05:30
committed by GitHub
8 changed files with 82 additions and 85 deletions

View File

@@ -1,23 +1,26 @@
import asyncio
import configparser
import datetime
import io
import json
import logging
import os
import sys
import threading
import uuid
import wave
from concurrent.futures import ThreadPoolExecutor
from sortedcontainers import SortedDict
import configparser
import jax.numpy as jnp
from aiohttp import web
from aiohttp import webq
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import (MediaRelay)
from av import AudioFifo
from sortedcontainers import SortedDict
from whisper_jax import FlaxWhisperPipline
from utils.server_utils import Mutex
ROOT = os.path.dirname(__file__)
config = configparser.ConfigParser()
@@ -46,14 +49,14 @@ total_bytes_handled = 0
executor = ThreadPoolExecutor()
frame_lock = threading.Lock()
total_bytes_handled_lock = threading.Lock()
frame_lock = Mutex(audio_buffer)
def channel_log(channel, t, message):
    """
    Print a one-line trace entry for a data channel.

    :param channel: object exposing a ``label`` attribute
    :param t: short marker printed between the channel label and the message
    :param message: text printed at the end of the line
    :return: None
    """
    print("channel(%s) %s %s" % (channel.label, t, message))
def thread_queue_channel_send():
print("M-thread created")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
@@ -62,25 +65,15 @@ def thread_queue_channel_send():
if message:
del sorted_message_queue[least_time]
data_channel.send(message)
print("M-thread sent message to client")
with total_bytes_handled_lock:
print("Bytes handled :", total_bytes_handled, " Time : ", datetime.datetime.now() - start_time)
except Exception as e:
print("Exception", str(e))
pass
loop.run_forever()
# async def channel_send(channel, message):
# channel_log(channel, ">", message)
# if channel and message:
# channel.send(message)
def get_transcription(local_thread_id):
# Block 1
print("T-thread -> ", str(local_thread_id) , "created")
global frame_lock
def get_transcription():
while True:
with frame_lock:
with frame_lock.lock() as audio_buffer:
frames = audio_buffer.read_many(CHUNK_SIZE * 960, partial=False)
if not frames:
transcribe = False
@@ -89,7 +82,6 @@ def get_transcription(local_thread_id):
if transcribe:
try:
print("T-thread ", str(local_thread_id), "is transcribing")
sorted_message_queue[frames[0].time] = None
out_file = io.BytesIO()
wf = wave.open(out_file, "wb")
@@ -102,10 +94,6 @@ def get_transcription(local_thread_id):
wf.close()
whisper_result = pipeline(out_file.getvalue())
global total_bytes_handled
with total_bytes_handled_lock:
total_bytes_handled += sys.getsizeof(wf)
item = {'text': whisper_result["text"],
'start_time': str(frames[0].time),
'time': str(datetime.datetime.now())
@@ -115,6 +103,7 @@ def get_transcription(local_thread_id):
except Exception as e:
print("Exception -> ", str(e))
class AudioStreamTrack(MediaStreamTrack):
"""
An audio stream track that transforms frames from another track.
@@ -127,7 +116,6 @@ class AudioStreamTrack(MediaStreamTrack):
self.track = track
async def recv(self):
# print("Awaiting track in server")
frame = await self.track.recv()
audio_buffer.write(frame)
return frame
@@ -136,7 +124,7 @@ class AudioStreamTrack(MediaStreamTrack):
def start_messaging_thread():
    """
    Spawn the thread that runs ``thread_queue_channel_send``.

    The thread is started but not joined here, so this function does not
    wait for the messaging loop to finish.
    """
    message_thread = threading.Thread(target=thread_queue_channel_send)
    message_thread.start()
# message_thread.join()
def start_transcription_thread(max_threads):
t_threads = []
@@ -145,12 +133,9 @@ def start_transcription_thread(max_threads):
t_threads.append(t_thread)
t_thread.start()
# for t_thread in t_threads:
# t_thread.join()
async def offer(request):
params = await request.json()
print("Request received")
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
pc = RTCPeerConnection()
@@ -185,11 +170,8 @@ async def offer(request):
@pc.on("track")
def on_track(track):
print("Track %s received", track.kind)
log_info("Track %s received", track.kind)
# Trials to listen to the correct track
pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
# pc.addTrack(AudioStreamTrack(track))
# handle offer
await pc.setRemoteDescription(offer)
@@ -197,7 +179,6 @@ async def offer(request):
# send answer
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
print("Response sent")
return web.Response(
content_type="application/json",
text=json.dumps(
@@ -207,7 +188,6 @@ async def offer(request):
async def on_shutdown(app):
# close peer connections
coros = [pc.close() for pc in pcs]
await asyncio.gather(*coros)
pcs.clear()
@@ -221,5 +201,3 @@ if __name__ == "__main__":
web.run_app(
app, access_log=None, host="127.0.0.1", port=1250
)

View File

@@ -1,16 +1,17 @@
import ast
import asyncio
import configparser
import logging
import time
import uuid
import threading
import configparser
import httpx
import pyaudio
import requests
import ast
import stamina
from aiortc import (RTCPeerConnection, RTCSessionDescription)
from aiortc.contrib.media import (MediaPlayer, MediaRelay)
from utils.server_utils import Mutex
logger = logging.getLogger("pc")
@@ -19,6 +20,7 @@ file_lock = Mutex(open("test_sm_6.txt", "a"))
config = configparser.ConfigParser()
config.read('config.ini')
class StreamClient:
def __init__(
self,
@@ -103,7 +105,6 @@ class StreamClient:
channel = pc.createDataChannel("data-channel")
self.channel_log(channel, "-", "created by local party")
async def send_pings():
while True:
self.channel_send(channel, "ping %d" % self.current_stamp())

View File

@@ -1,6 +1,7 @@
import configparser
import boto3
import botocore
import configparser
from loguru import logger
config = configparser.ConfigParser()
@@ -12,6 +13,7 @@ s3 = boto3.client('s3',
aws_access_key_id=config["DEFAULT"]["AWS_ACCESS_KEY"],
aws_secret_access_key=config["DEFAULT"]["AWS_SECRET_KEY"])
def upload_files(files_to_upload):
"""
Upload a list of files to the configured S3 bucket
@@ -45,6 +47,7 @@ def download_files(files_to_download):
if __name__ == "__main__":
import sys
if sys.argv[1] == "download":
download_files([sys.argv[2]])
elif sys.argv[1] == "upload":

View File

@@ -1,9 +1,10 @@
import asyncio
from functools import partial
import contextlib
from functools import partial
from threading import Lock
from typing import ContextManager, Generic, TypeVar
def run_in_executor(func, *args, executor=None, **kwargs):
callback = partial(func, *args, **kwargs)
loop = asyncio.get_event_loop()
@@ -11,6 +12,8 @@ def run_in_executor(func, *args, executor=None, **kwargs):
T = TypeVar("T")
class Mutex(Generic[T]):
def __init__(self, value: T):
self.__value = value

View File

@@ -1,23 +1,27 @@
import torch
import configparser
import nltk
from transformers import BartTokenizer, BartForConditionalGeneration
import torch
from loguru import logger
from nltk.corpus import stopwords
from sklearn.feature_extraction.text import TfidfVectorizer
from nltk.tokenize import word_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import BartTokenizer, BartForConditionalGeneration
nltk.download('punkt', quiet=True)
config = configparser.ConfigParser()
config.read('config.ini')
def preprocess_sentence(sentence):
    """
    Normalise a sentence for similarity comparison.

    The sentence is lowercased and tokenized; tokens that are not purely
    alphanumeric or that are English stop words are dropped.

    :param sentence: text to normalise
    :return: the remaining tokens joined by single spaces
    """
    stop_words = set(stopwords.words('english'))
    tokens = word_tokenize(sentence.lower())
    tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
    return ' '.join(tokens)
def compute_similarity(sent1, sent2):
"""
Compute the similarity
@@ -28,6 +32,7 @@ def compute_similarity(sent1, sent2):
return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
return 0.0
def remove_almost_alike_sentences(sentences, threshold=0.7):
num_sentences = len(sentences)
removed_indices = set()
@@ -55,12 +60,14 @@ def remove_almost_alike_sentences(sentences, threshold=0.7):
filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices]
return filtered_sentences
def remove_outright_duplicate_sentences_from_chunk(chunk):
    """
    Split a chunk's text into sentences and drop exact repeats.

    :param chunk: mapping with a "text" entry holding the chunk's transcript
    :return: list of sentences with exact duplicates removed, keeping the
             first occurrence of each in its original order
    """
    # Named `text` rather than `chunk_text` so the local does not shadow the
    # module-level chunk_text() function.
    text = chunk["text"]
    sentences = nltk.sent_tokenize(text)
    # dict.fromkeys keeps insertion order, so first occurrences survive.
    nonduplicate_sentences = list(dict.fromkeys(sentences))
    return nonduplicate_sentences
def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
chunk_sentences = []
@@ -80,6 +87,7 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
chunk_sentences.append(temp_result)
return chunk_sentences
def post_process_transcription(whisper_result):
transcript_text = ""
for chunk in whisper_result["chunks"]:
@@ -113,6 +121,7 @@ def summarize_chunks(chunks, tokenizer, model):
summaries.append(summary)
return summaries
def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
"""
Split text into smaller chunks.
@@ -132,6 +141,7 @@ def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])
chunks.append(current_chunk.strip())
return chunks
def summarize(transcript_text, timestamp,
real_time=False, summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

View File

@@ -1,15 +1,16 @@
import matplotlib.pyplot as plt
from wordcloud import WordCloud, STOPWORDS
from nltk.corpus import stopwords
import collections
import spacy
import os
from pathlib import Path
import pickle
import ast
import collections
import configparser
import os
import pickle
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
import scattertext as st
import configparser
import spacy
from nltk.corpus import stopwords
from wordcloud import WordCloud, STOPWORDS
config = configparser.ConfigParser()
config.read('config.ini')

View File

@@ -6,25 +6,24 @@
import argparse
import configparser
import jax.numpy as jnp
import os
import re
import subprocess
import tempfile
from datetime import datetime
from urllib.parse import urlparse
import jax.numpy as jnp
import moviepy.editor
import moviepy.editor
import nltk
import os
import subprocess
import re
import tempfile
from loguru import logger
import yt_dlp as youtube_dl
from urllib.parse import urlparse
from loguru import logger
from whisper_jax import FlaxWhisperPipline
from datetime import datetime
from utils.file_utilities import upload_files, download_files
from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
from utils.text_utilities import summarize, post_process_transcription
from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)
@@ -36,6 +35,7 @@ config.read('config.ini')
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
NOW = datetime.now()
def init_argparse() -> argparse.ArgumentParser:
"""
Parse the CLI arguments
@@ -52,7 +52,6 @@ def init_argparse() -> argparse.ArgumentParser:
return parser
def main():
parser = init_argparse()
args = parser.parse_args()
@@ -140,10 +139,10 @@ def main():
with open("./artefacts/transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file:
transcript_file.write(transcript_text)
with open("./artefacts/transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file_timestamps:
with open("./artefacts/transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt",
"w") as transcript_file_timestamps:
transcript_file_timestamps.write(str(whisper_result))
logger.info("Creating word cloud")
create_wordcloud(NOW)

View File

@@ -1,18 +1,20 @@
#!/usr/bin/env python3
import configparser
import pyaudio
from whisper_jax import FlaxWhisperPipline
from pynput import keyboard
import jax.numpy as jnp
import time
import wave
from datetime import datetime
from utils.file_utilities import upload_files
from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
from utils.text_utilities import summarize, post_process_transcription
import jax.numpy as jnp
import pyaudio
from loguru import logger
import time
from pynput import keyboard
from termcolor import colored
from whisper_jax import FlaxWhisperPipline
from utils.file_utilities import upload_files
from utils.text_utilities import summarize, post_process_transcription
from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
config = configparser.ConfigParser()
config.read('config.ini')