Merge pull request #24 from Monadical-SAS/whisper-jax-gokul
Minor refactor
@@ -1,23 +1,26 @@
 import asyncio
+import configparser
 import datetime
 import io
 import json
 import logging
 import os
-import sys
 import threading
 import uuid
 import wave
 from concurrent.futures import ThreadPoolExecutor
-from sortedcontainers import SortedDict
-import configparser
 import jax.numpy as jnp
 from aiohttp import web
 
 from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
 from aiortc.contrib.media import (MediaRelay)
 from av import AudioFifo
+from sortedcontainers import SortedDict
 from whisper_jax import FlaxWhisperPipline
 
+from utils.server_utils import Mutex
 
 ROOT = os.path.dirname(__file__)
 
 config = configparser.ConfigParser()
@@ -46,14 +49,14 @@ total_bytes_handled = 0
 
 executor = ThreadPoolExecutor()
 
-frame_lock = threading.Lock()
-total_bytes_handled_lock = threading.Lock()
+frame_lock = Mutex(audio_buffer)
 
 
 def channel_log(channel, t, message):
     print("channel(%s) %s %s" % (channel.label, t, message))
 
 
 def thread_queue_channel_send():
-    print("M-thread created")
     loop = asyncio.new_event_loop()
     asyncio.set_event_loop(loop)
     try:
@@ -62,25 +65,15 @@ def thread_queue_channel_send():
             if message:
                 del sorted_message_queue[least_time]
                 data_channel.send(message)
-                print("M-thread sent message to client")
-                with total_bytes_handled_lock:
-                    print("Bytes handled :", total_bytes_handled, " Time : ", datetime.datetime.now() - start_time)
     except Exception as e:
         print("Exception", str(e))
         pass
     loop.run_forever()
 
-# async def channel_send(channel, message):
-#     channel_log(channel, ">", message)
-#     if channel and message:
-#         channel.send(message)
 
-def get_transcription(local_thread_id):
-    # Block 1
-    print("T-thread -> ", str(local_thread_id), "created")
-    global frame_lock
+def get_transcription():
     while True:
-        with frame_lock:
+        with frame_lock.lock() as audio_buffer:
             frames = audio_buffer.read_many(CHUNK_SIZE * 960, partial=False)
             if not frames:
                 transcribe = False
@@ -89,7 +82,6 @@ def get_transcription(local_thread_id):
 
         if transcribe:
             try:
-                print("T-thread ", str(local_thread_id), "is transcribing")
                 sorted_message_queue[frames[0].time] = None
                 out_file = io.BytesIO()
                 wf = wave.open(out_file, "wb")
@@ -102,10 +94,6 @@ def get_transcription(local_thread_id):
                 wf.close()
 
                 whisper_result = pipeline(out_file.getvalue())
-
-                global total_bytes_handled
-                with total_bytes_handled_lock:
-                    total_bytes_handled += sys.getsizeof(wf)
                 item = {'text': whisper_result["text"],
                         'start_time': str(frames[0].time),
                         'time': str(datetime.datetime.now())
@@ -115,6 +103,7 @@ def get_transcription(local_thread_id):
             except Exception as e:
                 print("Exception -> ", str(e))
 
 
 class AudioStreamTrack(MediaStreamTrack):
     """
     A video stream track that transforms frames from an another track.
@@ -127,7 +116,6 @@ class AudioStreamTrack(MediaStreamTrack):
         self.track = track
 
     async def recv(self):
-        # print("Awaiting track in server")
         frame = await self.track.recv()
         audio_buffer.write(frame)
         return frame
@@ -136,7 +124,7 @@ class AudioStreamTrack(MediaStreamTrack):
 def start_messaging_thread():
     message_thread = threading.Thread(target=thread_queue_channel_send)
     message_thread.start()
-    # message_thread.join()
 
 
 def start_transcription_thread(max_threads):
     t_threads = []
@@ -145,12 +133,9 @@ def start_transcription_thread(max_threads):
         t_threads.append(t_thread)
         t_thread.start()
 
-    # for t_thread in t_threads:
-    #     t_thread.join()
 
 async def offer(request):
     params = await request.json()
-    print("Request received")
     offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
 
     pc = RTCPeerConnection()
@@ -185,11 +170,8 @@ async def offer(request):
 
     @pc.on("track")
     def on_track(track):
-        print("Track %s received", track.kind)
         log_info("Track %s received", track.kind)
-        # Trials to listen to the correct track
         pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
-        # pc.addTrack(AudioStreamTrack(track))
 
     # handle offer
     await pc.setRemoteDescription(offer)
@@ -197,7 +179,6 @@ async def offer(request):
     # send answer
     answer = await pc.createAnswer()
     await pc.setLocalDescription(answer)
-    print("Response sent")
     return web.Response(
         content_type="application/json",
         text=json.dumps(
@@ -207,7 +188,6 @@ async def offer(request):
 
 
 async def on_shutdown(app):
-    # close peer connections
     coros = [pc.close() for pc in pcs]
     await asyncio.gather(*coros)
     pcs.clear()
@@ -221,5 +201,3 @@ if __name__ == "__main__":
     web.run_app(
         app, access_log=None, host="127.0.0.1", port=1250
     )
-
-
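The substantive change in the file above is the locking model: the bare threading.Lock() that guarded a shared AudioFifo becomes frame_lock = Mutex(audio_buffer), so get_transcription() can only reach the buffer through frame_lock.lock(). A minimal self-contained sketch of that pattern, assuming the Mutex implementation reconstructed later in this diff (a plain list stands in for the real av.AudioFifo):

import contextlib
import threading


class Mutex:
    # Wraps a value so it is reachable only while its lock is held,
    # mirroring the utils/server_utils.Mutex shown further down.
    def __init__(self, value):
        self.__value = value
        self.__lock = threading.Lock()

    @contextlib.contextmanager
    def lock(self):
        self.__lock.acquire()
        try:
            yield self.__value
        finally:
            self.__lock.release()


frame_lock = Mutex([])  # hypothetical stand-in for Mutex(audio_buffer)

# Before: acquire a separate lock, then touch a global buffer anyway.
# After: the context manager hands out the guarded value itself.
with frame_lock.lock() as audio_buffer:
    audio_buffer.append("frame")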
@@ -1,16 +1,17 @@
+import ast
 import asyncio
+import configparser
 import logging
 import time
 import uuid
-import threading
-import configparser
 import httpx
 import pyaudio
 import requests
-import ast
 import stamina
 from aiortc import (RTCPeerConnection, RTCSessionDescription)
 from aiortc.contrib.media import (MediaPlayer, MediaRelay)
 
 from utils.server_utils import Mutex
 
 logger = logging.getLogger("pc")
@@ -19,13 +20,14 @@ file_lock = Mutex(open("test_sm_6.txt", "a"))
 config = configparser.ConfigParser()
 config.read('config.ini')
 
 
 class StreamClient:
     def __init__(
         self,
         signaling,
         url="http://127.0.0.1:1250",
         play_from=None,
         ping_pong=False
     ):
         self.signaling = signaling
         self.server_url = url
@@ -103,7 +105,6 @@ class StreamClient:
         channel = pc.createDataChannel("data-channel")
         self.channel_log(channel, "-", "created by local party")
 
 
         async def send_pings():
             while True:
                 self.channel_send(channel, "ping %d" % self.current_stamp())
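The client creates its data channel above and drives it with send_pings; the hunk cuts off before the loop's delay. A sketch of the complete pattern follows, where the sleep interval, the stamp helper, and the stub channel are all assumptions rather than the repository's code:

import asyncio
import time

START = time.time()


def current_stamp():
    # Assumed helper: milliseconds since the client started.
    return int((time.time() - START) * 1000)


class FakeChannel:
    # Stub standing in for an aiortc RTCDataChannel in this sketch.
    def send(self, message):
        print("sent:", message)


async def send_pings(channel):
    while True:
        channel.send("ping %d" % current_stamp())
        await asyncio.sleep(1)  # assumed interval


async def main():
    task = asyncio.create_task(send_pings(FakeChannel()))
    await asyncio.sleep(3)
    task.cancel()


asyncio.run(main())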
utils/file_utilities.py
@@ -1,6 +1,7 @@
+import configparser
+
 import boto3
 import botocore
-import configparser
 from loguru import logger
 
 config = configparser.ConfigParser()
@@ -12,6 +13,7 @@ s3 = boto3.client('s3',
                   aws_access_key_id=config["DEFAULT"]["AWS_ACCESS_KEY"],
                   aws_secret_access_key=config["DEFAULT"]["AWS_SECRET_KEY"])
 
 
 def upload_files(files_to_upload):
     """
     Upload a list of files to the configured S3 bucket
@@ -45,6 +47,7 @@ def download_files(files_to_download):
 
 if __name__ == "__main__":
     import sys
 
     if sys.argv[1] == "download":
         download_files([sys.argv[2]])
     elif sys.argv[1] == "upload":
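Several modules in this diff read the same config.ini through configparser. A hypothetical sketch of the [DEFAULT] keys referenced across these files; every value below is a placeholder rather than a value from the repository, and the real file presumably also names the S3 bucket:

[DEFAULT]
AWS_ACCESS_KEY = <access-key-id>
AWS_SECRET_KEY = <secret-access-key>
WHISPER_MODEL_SIZE = openai/whisper-large-v2
BEAM_SIZE = 4
MAX_LENGTH = 150
MAX_CHUNK_LENGTH = 1024
SUMMARIZE_USING_CHUNKS = True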
utils/server_utils.py
@@ -1,9 +1,10 @@
 import asyncio
-from functools import partial
 import contextlib
+from functools import partial
 from threading import Lock
 from typing import ContextManager, Generic, TypeVar
 
 
 def run_in_executor(func, *args, executor=None, **kwargs):
     callback = partial(func, *args, **kwargs)
     loop = asyncio.get_event_loop()
@@ -11,6 +12,8 @@ def run_in_executor(func, *args, executor=None, **kwargs):
 
 
 T = TypeVar("T")
 
 
 class Mutex(Generic[T]):
     def __init__(self, value: T):
         self.__value = value
@@ -22,4 +25,4 @@ class Mutex(Generic[T]):
         try:
             yield self.__value
         finally:
             self.__lock.release()
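The hunks above show only the opening lines of run_in_executor. A plausible completion, as a sketch rather than the repository's verbatim code: bind the arguments eagerly with functools.partial, then hand the blocking callable to the event loop's executor so async code can await it.

import asyncio
import time
from functools import partial


def run_in_executor(func, *args, executor=None, **kwargs):
    # Bind arguments now; run the blocking call on a worker thread and
    # return an awaitable future.
    callback = partial(func, *args, **kwargs)
    loop = asyncio.get_event_loop()
    return loop.run_in_executor(executor, callback)


async def main():
    # time.sleep stands in for any blocking call, e.g. model inference.
    await run_in_executor(time.sleep, 0.1)


asyncio.run(main())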
utils/text_utilities.py
@@ -1,23 +1,27 @@
-import torch
 import configparser
 
 import nltk
-from transformers import BartTokenizer, BartForConditionalGeneration
+import torch
 from loguru import logger
 from nltk.corpus import stopwords
-from sklearn.feature_extraction.text import TfidfVectorizer
 from nltk.tokenize import word_tokenize
+from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
+from transformers import BartTokenizer, BartForConditionalGeneration
 
 nltk.download('punkt', quiet=True)
 
 config = configparser.ConfigParser()
 config.read('config.ini')
 
 
 def preprocess_sentence(sentence):
     stop_words = set(stopwords.words('english'))
     tokens = word_tokenize(sentence.lower())
     tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
     return ' '.join(tokens)
 
 
 def compute_similarity(sent1, sent2):
     """
     Compute the similarity
@@ -28,6 +32,7 @@ def compute_similarity(sent1, sent2):
         return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
     return 0.0
 
 
 def remove_almost_alike_sentences(sentences, threshold=0.7):
     num_sentences = len(sentences)
     removed_indices = set()
@@ -55,12 +60,14 @@ def remove_almost_alike_sentences(sentences, threshold=0.7):
     filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices]
     return filtered_sentences
 
 
 def remove_outright_duplicate_sentences_from_chunk(chunk):
     chunk_text = chunk["text"]
     sentences = nltk.sent_tokenize(chunk_text)
     nonduplicate_sentences = list(dict.fromkeys(sentences))
     return nonduplicate_sentences
 
 
 def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
     chunk_sentences = []
 
@@ -80,6 +87,7 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
         chunk_sentences.append(temp_result)
     return chunk_sentences
 
 
 def post_process_transcription(whisper_result):
     transcript_text = ""
     for chunk in whisper_result["chunks"]:
@@ -107,12 +115,13 @@ def summarize_chunks(chunks, tokenizer, model):
         input_ids = input_ids.to(device)
         with torch.no_grad():
             summary_ids = model.generate(input_ids,
                                          num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
                                          max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
         summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
         summaries.append(summary)
     return summaries
 
 
 def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
     """
     Split text into smaller chunks.
@@ -132,6 +141,7 @@ def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])
     chunks.append(current_chunk.strip())
     return chunks
 
 
 def summarize(transcript_text, timestamp,
               real_time=False, summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
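The near-duplicate filter above works by TF-IDF cosine similarity: sentences are preprocessed, vectorized pairwise, and any pair scoring above the threshold (0.7 here) is treated as almost alike. A self-contained sketch of that core step, not the repository's verbatim code:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity


def compute_similarity(sent1, sent2):
    # Vectorize the two sentences against each other and compare rows.
    tfidf_matrix = TfidfVectorizer().fit_transform([sent1, sent2])
    return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]


print(compute_similarity("the model transcribes audio",
                         "the model transcribes the audio"))  # high: near-duplicates
print(compute_similarity("the model transcribes audio",
                         "word clouds need stopwords"))       # 0.0: no shared terms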
utils/viz_utilities.py
@@ -1,15 +1,16 @@
-import matplotlib.pyplot as plt
-from wordcloud import WordCloud, STOPWORDS
-from nltk.corpus import stopwords
-import collections
-import spacy
-import os
-from pathlib import Path
-import pickle
 import ast
+import collections
+import configparser
+import os
+import pickle
+from pathlib import Path
 
+import matplotlib.pyplot as plt
 import pandas as pd
 import scattertext as st
-import configparser
+import spacy
+from nltk.corpus import stopwords
+from wordcloud import WordCloud, STOPWORDS
 
 config = configparser.ConfigParser()
 config.read('config.ini')
@@ -29,7 +30,7 @@ def create_wordcloud(timestamp, real_time=False):
     if real_time:
         filename = "real_time_" + filename + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
     else:
         filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
 
     with open("./artefacts/" + filename, "r") as f:
         transcription_text = f.read()
@@ -202,4 +203,4 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
     if real_time:
         open('./artefacts/real_time_scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
     else:
         open('./artefacts/scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
whisjax.py
@@ -6,25 +6,24 @@
 
 import argparse
 import configparser
-import jax.numpy as jnp
+import os
+import re
+import subprocess
+import tempfile
+from datetime import datetime
+from urllib.parse import urlparse
+
+import jax.numpy as jnp
 import moviepy.editor
 import moviepy.editor
 import nltk
-import os
-import subprocess
-import re
-import tempfile
-from loguru import logger
 import yt_dlp as youtube_dl
-from urllib.parse import urlparse
+from loguru import logger
 from whisper_jax import FlaxWhisperPipline
 
-from datetime import datetime
 from utils.file_utilities import upload_files, download_files
-from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
 from utils.text_utilities import summarize, post_process_transcription
+from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
 
 nltk.download('punkt', quiet=True)
 nltk.download('stopwords', quiet=True)
@@ -36,6 +35,7 @@ config.read('config.ini')
 WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
 NOW = datetime.now()
 
 
 def init_argparse() -> argparse.ArgumentParser:
     """
     Parse the CLI arguments
@@ -52,7 +52,6 @@ def init_argparse() -> argparse.ArgumentParser:
     return parser
 
 
-
 def main():
     parser = init_argparse()
     args = parser.parse_args()
@@ -140,10 +139,10 @@ def main():
     with open("./artefacts/transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file:
         transcript_file.write(transcript_text)
 
-    with open("./artefacts/transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file_timestamps:
+    with open("./artefacts/transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt",
+              "w") as transcript_file_timestamps:
         transcript_file_timestamps.write(str(whisper_result))
 
-
     logger.info("Creating word cloud")
     create_wordcloud(NOW)
 
@@ -1,18 +1,20 @@
 #!/usr/bin/env python3
 
 import configparser
-import pyaudio
-from whisper_jax import FlaxWhisperPipline
-from pynput import keyboard
-import jax.numpy as jnp
+import time
 import wave
 from datetime import datetime
-from utils.file_utilities import upload_files
-from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
-from utils.text_utilities import summarize, post_process_transcription
+
+import jax.numpy as jnp
+import pyaudio
 from loguru import logger
-import time
+from pynput import keyboard
 from termcolor import colored
+from whisper_jax import FlaxWhisperPipline
 
+from utils.file_utilities import upload_files
+from utils.text_utilities import summarize, post_process_transcription
+from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
 
 config = configparser.ConfigParser()
 config.read('config.ini')