minor refactor

This commit is contained in:
Gokul Mohanarangan
2023-07-10 22:48:22 +05:30
parent 73c4270764
commit 3128813ca3
8 changed files with 82 additions and 85 deletions

View File

@@ -1,23 +1,26 @@
import asyncio
import configparser
import datetime
import io
import json
import logging
import os
import sys
import threading
import uuid
import wave
from concurrent.futures import ThreadPoolExecutor
from sortedcontainers import SortedDict
import configparser
import jax.numpy as jnp
from aiohttp import web
from aiohttp import webq
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import (MediaRelay)
from av import AudioFifo
from sortedcontainers import SortedDict
from whisper_jax import FlaxWhisperPipline
from utils.server_utils import Mutex
ROOT = os.path.dirname(__file__)
config = configparser.ConfigParser()
@@ -46,14 +49,14 @@ total_bytes_handled = 0
executor = ThreadPoolExecutor()
frame_lock = threading.Lock()
total_bytes_handled_lock = threading.Lock()
frame_lock = Mutex(audio_buffer)
def channel_log(channel, t, message):
print("channel(%s) %s %s" % (channel.label, t, message))
def thread_queue_channel_send():
print("M-thread created")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
@@ -62,25 +65,15 @@ def thread_queue_channel_send():
if message:
del sorted_message_queue[least_time]
data_channel.send(message)
print("M-thread sent message to client")
with total_bytes_handled_lock:
print("Bytes handled :", total_bytes_handled, " Time : ", datetime.datetime.now() - start_time)
except Exception as e:
print("Exception", str(e))
pass
loop.run_forever()
# async def channel_send(channel, message):
# channel_log(channel, ">", message)
# if channel and message:
# channel.send(message)
def get_transcription(local_thread_id):
# Block 1
print("T-thread -> ", str(local_thread_id) , "created")
global frame_lock
def get_transcription():
while True:
with frame_lock:
with frame_lock.lock() as audio_buffer:
frames = audio_buffer.read_many(CHUNK_SIZE * 960, partial=False)
if not frames:
transcribe = False
@@ -89,7 +82,6 @@ def get_transcription(local_thread_id):
if transcribe:
try:
print("T-thread ", str(local_thread_id), "is transcribing")
sorted_message_queue[frames[0].time] = None
out_file = io.BytesIO()
wf = wave.open(out_file, "wb")
@@ -102,10 +94,6 @@ def get_transcription(local_thread_id):
wf.close()
whisper_result = pipeline(out_file.getvalue())
global total_bytes_handled
with total_bytes_handled_lock:
total_bytes_handled += sys.getsizeof(wf)
item = {'text': whisper_result["text"],
'start_time': str(frames[0].time),
'time': str(datetime.datetime.now())
@@ -115,6 +103,7 @@ def get_transcription(local_thread_id):
except Exception as e:
print("Exception -> ", str(e))
class AudioStreamTrack(MediaStreamTrack):
"""
A video stream track that transforms frames from an another track.
@@ -127,7 +116,6 @@ class AudioStreamTrack(MediaStreamTrack):
self.track = track
async def recv(self):
# print("Awaiting track in server")
frame = await self.track.recv()
audio_buffer.write(frame)
return frame
@@ -136,7 +124,7 @@ class AudioStreamTrack(MediaStreamTrack):
def start_messaging_thread():
message_thread = threading.Thread(target=thread_queue_channel_send)
message_thread.start()
# message_thread.join()
def start_transcription_thread(max_threads):
t_threads = []
@@ -145,12 +133,9 @@ def start_transcription_thread(max_threads):
t_threads.append(t_thread)
t_thread.start()
# for t_thread in t_threads:
# t_thread.join()
async def offer(request):
params = await request.json()
print("Request received")
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
pc = RTCPeerConnection()
@@ -185,11 +170,8 @@ async def offer(request):
@pc.on("track")
def on_track(track):
print("Track %s received", track.kind)
log_info("Track %s received", track.kind)
# Trials to listen to the correct track
pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
# pc.addTrack(AudioStreamTrack(track))
# handle offer
await pc.setRemoteDescription(offer)
@@ -197,7 +179,6 @@ async def offer(request):
# send answer
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
print("Response sent")
return web.Response(
content_type="application/json",
text=json.dumps(
@@ -207,7 +188,6 @@ async def offer(request):
async def on_shutdown(app):
# close peer connections
coros = [pc.close() for pc in pcs]
await asyncio.gather(*coros)
pcs.clear()
@@ -221,5 +201,3 @@ if __name__ == "__main__":
web.run_app(
app, access_log=None, host="127.0.0.1", port=1250
)