Files
reflector/server_multithreaded.py
2023-07-10 14:06:04 +05:30

228 lines
6.5 KiB
Python

import asyncio
import datetime
import io
import json
import logging
import os
import sys
import threading
import uuid
import wave
from concurrent.futures import ThreadPoolExecutor
from sortedcontainers import SortedDict
import configparser
import jax.numpy as jnp
from aiohttp import web
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import (MediaRelay)
from av import AudioFifo
from whisper_jax import FlaxWhisperPipline
# ---------------------------------------------------------------------------
# Module-level configuration and shared state.
# Everything below runs at import time, including loading the Whisper model.
# ---------------------------------------------------------------------------

# Directory containing this file (not referenced elsewhere in the visible code).
ROOT = os.path.dirname(__file__)
# 'config.ini' is a relative path, so it is looked up from the current working
# directory. ConfigParser.read() ignores missing files; the lookup on the next
# line is what fails (KeyError) if the option is absent.
config = configparser.ConfigParser()
config.read('config.ini')
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
logger = logging.getLogger("pc")
# Every RTCPeerConnection created by offer(); closed and cleared in on_shutdown().
pcs = set()
relay = MediaRelay()
# Most recently opened data channel; assigned in offer()'s "datachannel" handler
# and read by thread_queue_channel_send().
data_channel = None
# Maps a chunk's start time (frames[0].time) to its transcription message, or to
# None while that chunk is still being transcribed.
sorted_message_queue = SortedDict()
# WAV header parameters used when encoding buffered audio in get_transcription().
CHANNELS = 2
RATE = 44100
# get_transcription() reads CHUNK_SIZE * 960 samples from the FIFO per chunk.
CHUNK_SIZE = 256
# FIFO written by AudioStreamTrack.recv() and drained by the transcription threads.
audio_buffer = AudioFifo()
# Whisper pipeline for the configured model size, built once at import time.
pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE,
dtype=jnp.float16,
batch_size=16)
# NOTE(review): `transcription` is not referenced in the visible code.
transcription = ""
# Reference point for the elapsed-time statistic; reset when a data channel opens.
start_time = datetime.datetime.now()
# Running byte counter, updated under total_bytes_handled_lock.
total_bytes_handled = 0
# NOTE(review): `executor` and `file_lock` are not referenced in the visible code.
executor = ThreadPoolExecutor()
# Serialises reads from audio_buffer across transcription threads.
frame_lock = threading.Lock()
file_lock = threading.Lock()
total_bytes_handled_lock = threading.Lock()
def channel_log(channel, t, message):
    """Print a single data-channel event line to stdout.

    Args:
        channel: Object exposing a ``label`` attribute (the data channel).
        t: Short marker placed between the label and the message
            (callers in this file pass "-" and "<").
        message: Payload to display; rendered with ``str()``.
    """
    # ``!s`` keeps the rendering identical to the classic ``%s`` conversion.
    line = f"channel({channel.label!s}) {t!s} {message!s}"
    print(line)
def thread_queue_channel_send():
    """Send the earliest queued transcription over the data channel.

    Runs in its own thread (see ``start_messaging_thread``). Only the entry
    with the smallest key in ``sorted_message_queue`` is considered, so
    messages leave in chunk-time order. If that entry is still ``None``
    (transcription pending), nothing is sent.

    The thread only stays alive when a message was actually handed to the
    data channel. In every other case (empty queue, pending head entry,
    failed send) it closes its event loop and returns, instead of parking
    forever in ``run_forever()`` with nothing to run.
    """
    print("M-thread created")
    if not sorted_message_queue:
        # Nothing queued: avoid creating a loop and using IndexError as
        # control flow.
        return

    # The send below is issued with this thread's own event loop installed.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    sent = False
    try:
        least_time = sorted_message_queue.keys()[0]
        message = sorted_message_queue[least_time]
        if message:
            # Remove before sending, so a message is delivered at most once.
            del sorted_message_queue[least_time]
            data_channel.send(message)
            sent = True
            print("M-thread sent message to client")
            with total_bytes_handled_lock:
                print("Bytes handled :", total_bytes_handled, " Time : ", datetime.datetime.now() - start_time)
    except Exception as e:
        # Best-effort delivery: report the problem and carry on.
        print("Exception", str(e))

    if not sent:
        # No work was scheduled on this loop; release it and let the thread end.
        loop.close()
        return

    # NOTE(review): presumably the loop must keep running so that work
    # scheduled by data_channel.send() on this thread's loop can complete --
    # confirm against aiortc. This still leaves one parked thread per message.
    loop.run_forever()
# async def channel_send(channel, message):
# channel_log(channel, ">", message)
# if channel and message:
# channel.send(message)
def _encode_wav(frames):
    """Serialise audio frames into an in-memory 16-bit WAV payload (bytes)."""
    out_file = io.BytesIO()
    # The context manager finalises the header and closes the writer even if
    # converting a frame raises.
    with wave.open(out_file, "wb") as wf:
        wf.setnchannels(CHANNELS)
        wf.setframerate(RATE)
        wf.setsampwidth(2)
        for frame in frames:
            wf.writeframes(b''.join(frame.to_ndarray()))
    return out_file.getvalue()


def get_transcription(local_thread_id):
    """Worker loop: pull buffered audio, transcribe it, queue the result.

    Repeatedly reads ``CHUNK_SIZE * 960`` samples from ``audio_buffer``
    (under ``frame_lock``), encodes them as WAV, runs the Whisper pipeline
    and stores the stringified result in ``sorted_message_queue`` keyed by
    the first frame's time. A messaging thread is then started to deliver it.

    Args:
        local_thread_id: Identifier used only in the log output.

    This function never returns.
    """
    import time  # local import: only needed for the idle back-off below

    global total_bytes_handled

    print("T-thread -> ", str(local_thread_id) , "created")
    while True:
        with frame_lock:
            frames = audio_buffer.read_many(CHUNK_SIZE * 960, partial=False)
        if not frames:
            # Not enough audio buffered yet. Back off briefly instead of
            # spinning on the lock at full speed.
            time.sleep(0.05)
            continue

        chunk_time = frames[0].time
        try:
            print("T-thread ", str(local_thread_id), "is transcribing")
            # Reserve this chunk's slot so later chunks are not sent first.
            sorted_message_queue[chunk_time] = None
            wav_bytes = _encode_wav(frames)
            whisper_result = pipeline(wav_bytes)
            with total_bytes_handled_lock:
                # Count the audio payload itself, not the size of the
                # wave-writer object.
                total_bytes_handled += len(wav_bytes)
            item = {'text': whisper_result["text"],
                    'start_time': str(chunk_time),
                    'time': str(datetime.datetime.now())
                    }
            sorted_message_queue[chunk_time] = str(item)
            start_messaging_thread()
        except Exception as e:
            print("Exception -> ", str(e))
            # A placeholder left at None would sit at the head of the queue
            # forever and block every later message. Drop it if this chunk
            # never produced a result.
            if sorted_message_queue.get(chunk_time) is None:
                sorted_message_queue.pop(chunk_time, None)
class AudioStreamTrack(MediaStreamTrack):
    """
    An audio stream track that taps the frames of another track.

    Each frame received from the wrapped track is written to the module-level
    ``audio_buffer`` and then returned unchanged.
    """

    kind = "audio"

    def __init__(self, track):
        # The base-class initialiser must run before the track is used.
        super().__init__()
        self.track = track

    async def recv(self):
        """Receive the next frame from the wrapped track, buffer it, return it."""
        incoming = await self.track.recv()
        # NOTE(review): this write does not take frame_lock, which the
        # transcription threads hold while reading -- confirm that is intended.
        audio_buffer.write(incoming)
        return incoming
def start_messaging_thread():
    """Start a background thread that runs ``thread_queue_channel_send``.

    The thread is started and not joined, so the caller is never blocked.
    It is a daemon thread because the target can end in
    ``loop.run_forever()``; a non-daemon thread would keep the interpreter
    alive after the web server has shut down.
    """
    message_thread = threading.Thread(
        target=thread_queue_channel_send,
        daemon=True,
    )
    message_thread.start()
def start_transcription_thread(max_threads):
    """Start ``max_threads`` background transcription workers.

    Each worker runs ``get_transcription(i)`` with its index as the thread id.
    The workers loop forever, so they are daemon threads: otherwise they
    would prevent the process from exiting once the web server stops.
    The threads are started and not joined.

    Args:
        max_threads: Number of worker threads to start; values <= 0 start none.
    """
    for i in range(max_threads):
        t_thread = threading.Thread(
            target=get_transcription,
            args=(i,),
            daemon=True,
        )
        t_thread.start()
async def offer(request):
    """Handle ``POST /offer``: accept a WebRTC offer and reply with an answer.

    The JSON request body must contain ``sdp`` and ``type``. A new
    ``RTCPeerConnection`` is created, registered in ``pcs`` and wired with
    handlers for the data channel, connection-state changes and incoming
    tracks.

    Args:
        request: The aiohttp request carrying the JSON offer.

    Returns:
        A JSON ``web.Response`` with the local description's ``sdp`` and
        ``type``.
    """
    params = await request.json()
    print("Request received")
    # Named so it does not shadow this handler's own name.
    remote_offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])

    pc = RTCPeerConnection()
    pc_id = "PeerConnection(%s)" % uuid.uuid4()
    pcs.add(pc)

    def log_info(msg, *args):
        logger.info(pc_id + " " + msg, *args)

    log_info("Created for %s", request.remote)

    @pc.on("datachannel")
    def on_datachannel(channel):
        # Publish the channel for the messaging threads and restart the clock
        # used for the elapsed-time statistic.
        global data_channel, start_time
        data_channel = channel
        channel_log(channel, "-", "created by remote party")
        start_time = datetime.datetime.now()

        @channel.on("message")
        def on_message(message):
            channel_log(channel, "<", message)
            if isinstance(message, str) and message.startswith("ping"):
                # reply, echoing whatever followed "ping"
                channel.send("pong" + message[4:])

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        log_info("Connection state is %s", pc.connectionState)
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)

    @pc.on("track")
    def on_track(track):
        # print() does not interpolate extra arguments, so format explicitly.
        print("Track %s received" % track.kind)
        log_info("Track %s received", track.kind)
        # Send the relayed track back through AudioStreamTrack, whose recv()
        # writes each frame into audio_buffer.
        pc.addTrack(AudioStreamTrack(relay.subscribe(track)))

    # handle offer
    await pc.setRemoteDescription(remote_offer)

    # send answer
    answer = await pc.createAnswer()
    await pc.setLocalDescription(answer)
    print("Response sent")
    return web.Response(
        content_type="application/json",
        text=json.dumps(
            {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
        ),
    )
async def on_shutdown(app):
    """Close every tracked peer connection when the application shuts down.

    Args:
        app: The aiohttp application (unused; required by the signal signature).
    """
    # Close all connections concurrently, then forget them.
    await asyncio.gather(*[connection.close() for connection in pcs])
    pcs.clear()
if __name__ == "__main__":
    # Build the application: shutdown hook first, then the signalling route.
    app = web.Application()
    app.on_shutdown.append(on_shutdown)

    # Start the six transcription workers before the server begins serving.
    start_transcription_thread(6)

    app.router.add_post("/offer", offer)

    # Serve on localhost only, with access logging disabled.
    web.run_app(
        app,
        access_log=None,
        host="127.0.0.1",
        port=1250,
    )