integrate webrtc logic with whisper-jax-gokul branch

Gokul Mohanarangan
2023-07-10 14:06:04 +05:30
parent 3f1c59abc6
commit 0d3f2c9072
10 changed files with 678 additions and 13 deletions

client.py Normal file

@@ -0,0 +1,75 @@
import argparse
import asyncio
import logging
import signal

from aiortc.contrib.signaling import (add_signaling_arguments,
                                      create_signaling)

from streamclient import StreamClient

logger = logging.getLogger("pc")


async def main():
    parser = argparse.ArgumentParser(description="Data channels ping/pong")
    parser.add_argument(
        "--url", type=str, nargs="?", default="http://127.0.0.1:1250/offer"
    )
    parser.add_argument(
        "--ping-pong",
        help="Benchmark data channel with ping pong",
        type=eval,
        choices=[True, False],
        default="False",
    )
    parser.add_argument(
        "--play-from",
        type=str,
        default="",
    )
    add_signaling_arguments(parser)
    args = parser.parse_args()
    signaling = create_signaling(args)

    async def shutdown(signal, loop):
        """Cleanup tasks tied to the service's shutdown."""
        logging.info(f"Received exit signal {signal.name}...")
        logging.info("Closing database connections")
        logging.info("Nacking outstanding messages")
        tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
        [task.cancel() for task in tasks]
        logging.info(f"Cancelling {len(tasks)} outstanding tasks")
        await asyncio.gather(*tasks, return_exceptions=True)
        logging.info("Flushing metrics")
        loop.stop()

    signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
    loop = asyncio.get_event_loop()
    for s in signals:
        loop.add_signal_handler(
            s, lambda s=s: asyncio.create_task(shutdown(s, loop)))

    # Init client
    sc = StreamClient(
        signaling=signaling,
        url=args.url,
        play_from=args.play_from,
        ping_pong=args.ping_pong,
    )
    await sc.start()
    print("Stream client started")
    async for msg in sc.get_reader():
        print(msg)


if __name__ == "__main__":
    asyncio.run(main())
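For reference, one way to exercise the pieces in this commit end to end is to launch one of the servers and then run this client against it. The helper below is only an illustrative sketch, not part of the diff: the choice of server script, the 5-second warm-up, and running both processes on one machine are assumptions.

# Illustrative sketch only: starts server_executor_cleaned.py in a subprocess,
# waits for whisper-jax/JAX to warm up, then runs client.py against it.
import subprocess
import sys
import time

server = subprocess.Popen([sys.executable, "server_executor_cleaned.py"])
time.sleep(5)  # assumed warm-up; adjust for model load/compile time
try:
    subprocess.run(
        [sys.executable, "client.py", "--url", "http://127.0.0.1:1250/offer"],
        check=True,
    )
finally:
    server.terminate()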


@@ -4,7 +4,7 @@ KMP_DUPLICATE_LIB_OK=TRUE
# Export OpenAI API Key
OPENAI_APIKEY=
# Export Whisper Model Size
-WHISPER_MODEL_SIZE=medium
+WHISPER_MODEL_SIZE=tiny
WHISPER_REAL_TIME_MODEL_SIZE=tiny
# AWS config
AWS_ACCESS_KEY=***REMOVED***
@@ -19,3 +19,4 @@ MAX_CHUNK_LENGTH=1024
SUMMARIZE_USING_CHUNKS=YES
# Audio device
BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME=aggregator
+AV_FOUNDATION_DEVICE_ID=2
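These settings are consumed at runtime through configparser. A minimal sketch of how the two changed values are read follows; it assumes the keys above end up under the [DEFAULT] section of the config.ini that server_multithreaded.py and streamclient.py load, which is not shown in this diff.

# Sketch only: assumes the keys above live under [DEFAULT] in config.ini,
# matching how server_multithreaded.py and streamclient.py read them.
import configparser

config = configparser.ConfigParser()
config.read("config.ini")

model_size = config["DEFAULT"]["WHISPER_MODEL_SIZE"]        # "tiny" after this change
device_id = config["DEFAULT"]["AV_FOUNDATION_DEVICE_ID"]    # "2": macOS avfoundation capture device index
print("openai/whisper-" + model_size, "avfoundation device :" + device_id)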

File diff suppressed because one or more lines are too long


@@ -41,12 +41,20 @@ urllib3
yarl==1.9.2
boto3==1.26.151
nltk==3.8.1
-wordcloud
-spacy
-scattertext
-pandas
-jupyter
-seaborn
-matplotlib
-termcolor
-ffmpeg
+wordcloud==1.9.2
+spacy==3.5.4
+scattertext==0.1.19
+pandas==2.0.3
+jupyter==1.0.0
+seaborn==0.12.2
+matplotlib==3.7.2
+matplotlib-inline==0.1.6
+termcolor==2.3.0
+ffmpeg==1.4
+aiortc==1.5.0
+cached_property==1.5.2
+stamina==23.1.0
+httpx==0.24.1
+sortedcontainers==2.4.0
+openai-whisper @ git+https://github.com/openai/whisper.git@248b6cb124225dd263bb9bd32d060b6517e067f8
+https://github.com/yt-dlp/yt-dlp/archive/master.tar.gz

server_executor_cleaned.py Normal file

@@ -0,0 +1,181 @@
import asyncio
import datetime
import io
import json
import logging
import sys
import uuid
import wave
from concurrent.futures import ThreadPoolExecutor

import jax.numpy as jnp
from aiohttp import web
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaRelay
from av import AudioFifo
from whisper_jax import FlaxWhisperPipline

from utils import run_in_executor

logger = logging.getLogger(__name__)

transcription = ""
pcs = set()
relay = MediaRelay()
data_channel = None
total_bytes_handled = 0

pipeline = FlaxWhisperPipline("openai/whisper-tiny", dtype=jnp.float16, batch_size=16)

CHANNELS = 2
RATE = 48000
audio_buffer = AudioFifo()
start_time = datetime.datetime.now()
executor = ThreadPoolExecutor()


def channel_log(channel, t, message):
    print("channel(%s) %s %s" % (channel.label, t, message))


def channel_send(channel, message):
    # channel_log(channel, ">", message)
    global start_time
    if channel:
        channel.send(message)
        print(
            "Bytes handled :",
            total_bytes_handled,
            " Time : ",
            datetime.datetime.now() - start_time,
        )


def get_transcription(frames):
    print("Transcribing..")
    # samples = np.ndarray(
    #     np.concatenate([f.to_ndarray() for f in frames], axis=None),
    #     dtype=np.float32,
    # )
    # whisper_result = pipeline(
    #     {
    #         "array": samples,
    #         "sampling_rate": 48000,
    #     },
    #     return_timestamps=True,
    # )
    # Pack the raw frames into an in-memory WAV container and hand it to whisper-jax.
    out_file = io.BytesIO()
    wf = wave.open(out_file, "wb")
    wf.setnchannels(CHANNELS)
    wf.setframerate(RATE)
    wf.setsampwidth(2)
    for frame in frames:
        wf.writeframes(b"".join(frame.to_ndarray()))
    wf.close()

    global total_bytes_handled
    total_bytes_handled += sys.getsizeof(wf)

    whisper_result = pipeline(out_file.getvalue(), return_timestamps=True)
    with open("test_exec.txt", "a") as f:
        f.write(whisper_result["text"])
    whisper_result['start_time'] = [f.time for f in frames]
    return whisper_result


class AudioStreamTrack(MediaStreamTrack):
    """
    An audio stream track.
    """

    kind = "audio"

    def __init__(self, track):
        super().__init__()
        self.track = track

    async def recv(self):
        frame = await self.track.recv()
        audio_buffer.write(frame)
        # Drain the FIFO in chunks of 256 * 960 samples (~5 s at 48 kHz) and
        # transcribe them off the event loop in the thread pool.
        if local_frames := audio_buffer.read_many(256 * 960, partial=False):
            whisper_result = run_in_executor(
                get_transcription, local_frames, executor=executor
            )
            whisper_result.add_done_callback(
                lambda f: channel_send(data_channel, str(f.result()))
                if f.result()
                else None
            )
        return frame


async def offer(request):
    params = await request.json()
    print("Request received")
    offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])

    pc = RTCPeerConnection()
    pc_id = "PeerConnection(%s)" % uuid.uuid4()
    pcs.add(pc)

    def log_info(msg, *args):
        logger.info(pc_id + " " + msg, *args)

    log_info("Created for %s", request.remote)

    @pc.on("datachannel")
    def on_datachannel(channel):
        global data_channel, start_time
        data_channel = channel
        channel_log(channel, "-", "created by remote party")
        start_time = datetime.datetime.now()

        @channel.on("message")
        def on_message(message):
            channel_log(channel, "<", message)
            if isinstance(message, str) and message.startswith("ping"):
                # reply
                channel_send(channel, "pong" + message[4:])

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        log_info("Connection state is %s", pc.connectionState)
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)

    @pc.on("track")
    def on_track(track):
        print("Track %s received" % track.kind)
        log_info("Track %s received", track.kind)
        # Trials to listen to the correct track
        pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
        # pc.addTrack(AudioStreamTrack(track))

    # handle offer
    await pc.setRemoteDescription(offer)

    # send answer
    answer = await pc.createAnswer()
    await pc.setLocalDescription(answer)
    print("Response sent")

    return web.Response(
        content_type="application/json",
        text=json.dumps(
            {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
        ),
    )


async def on_shutdown(app):
    # close peer connections
    coros = [pc.close() for pc in pcs]
    await asyncio.gather(*coros)
    pcs.clear()


if __name__ == "__main__":
    app = web.Application()
    app.on_shutdown.append(on_shutdown)
    app.router.add_post("/offer", offer)
    web.run_app(app, access_log=None, host="127.0.0.1", port=1250)
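recv() above offloads the blocking whisper-jax call to a thread pool via run_in_executor (from utils.py, later in this diff) and pushes the result over the data channel from a done-callback. A stripped-down sketch of that pattern in isolation; blocking_transcribe and the dummy payload are illustrative stand-ins, not from the diff.

# Illustrative sketch of the offload-and-callback pattern used in AudioStreamTrack.recv().
import asyncio
from concurrent.futures import ThreadPoolExecutor

from utils import run_in_executor

executor = ThreadPoolExecutor()


def blocking_transcribe(chunk):
    # stand-in for pipeline(out_file.getvalue(), return_timestamps=True)
    return {"text": "transcribed %d bytes" % len(chunk)}


async def main():
    future = run_in_executor(blocking_transcribe, b"\x00" * 960, executor=executor)
    # mirrors the add_done_callback in recv(): send the result once the worker finishes
    future.add_done_callback(lambda f: print(f.result()["text"]) if f.result() else None)
    await future


asyncio.run(main())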

server_multithreaded.py Normal file

@@ -0,0 +1,227 @@
import asyncio
import datetime
import io
import json
import logging
import os
import sys
import threading
import uuid
import wave
from concurrent.futures import ThreadPoolExecutor

import configparser
import jax.numpy as jnp
from aiohttp import web
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaRelay
from av import AudioFifo
from sortedcontainers import SortedDict
from whisper_jax import FlaxWhisperPipline

ROOT = os.path.dirname(__file__)

config = configparser.ConfigParser()
config.read('config.ini')
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]

logger = logging.getLogger("pc")
pcs = set()
relay = MediaRelay()
data_channel = None

# Transcription results keyed by the start time of their first frame, so they
# can be sent to the client in capture order even if threads finish out of order.
sorted_message_queue = SortedDict()

CHANNELS = 2
RATE = 44100
CHUNK_SIZE = 256
audio_buffer = AudioFifo()

pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE,
                              dtype=jnp.float16,
                              batch_size=16)

transcription = ""
start_time = datetime.datetime.now()
total_bytes_handled = 0
executor = ThreadPoolExecutor()

frame_lock = threading.Lock()
file_lock = threading.Lock()
total_bytes_handled_lock = threading.Lock()


def channel_log(channel, t, message):
    print("channel(%s) %s %s" % (channel.label, t, message))


def thread_queue_channel_send():
    print("M-thread created")
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    try:
        # Only the earliest chunk is eligible; a None value means it is still
        # being transcribed, so nothing is sent yet.
        least_time = sorted_message_queue.keys()[0]
        message = sorted_message_queue[least_time]
        if message:
            del sorted_message_queue[least_time]
            data_channel.send(message)
            print("M-thread sent message to client")
            with total_bytes_handled_lock:
                print("Bytes handled :", total_bytes_handled,
                      " Time : ", datetime.datetime.now() - start_time)
    except Exception as e:
        print("Exception", str(e))
    loop.run_forever()


# async def channel_send(channel, message):
#     channel_log(channel, ">", message)
#     if channel and message:
#         channel.send(message)


def get_transcription(local_thread_id):
    # Block 1
    print("T-thread -> ", str(local_thread_id), "created")
    global frame_lock
    while True:
        with frame_lock:
            frames = audio_buffer.read_many(CHUNK_SIZE * 960, partial=False)
        if not frames:
            transcribe = False
        else:
            transcribe = True
        if transcribe:
            try:
                print("T-thread ", str(local_thread_id), "is transcribing")
                # Reserve this chunk's slot in the ordered queue before transcribing.
                sorted_message_queue[frames[0].time] = None

                out_file = io.BytesIO()
                wf = wave.open(out_file, "wb")
                wf.setnchannels(CHANNELS)
                wf.setframerate(RATE)
                wf.setsampwidth(2)
                for frame in frames:
                    wf.writeframes(b''.join(frame.to_ndarray()))
                wf.close()

                whisper_result = pipeline(out_file.getvalue())

                global total_bytes_handled
                with total_bytes_handled_lock:
                    total_bytes_handled += sys.getsizeof(wf)

                item = {'text': whisper_result["text"],
                        'start_time': str(frames[0].time),
                        'time': str(datetime.datetime.now())
                        }
                sorted_message_queue[frames[0].time] = str(item)
                start_messaging_thread()
            except Exception as e:
                print("Exception -> ", str(e))


class AudioStreamTrack(MediaStreamTrack):
    """
    An audio stream track that buffers frames from another track.
    """

    kind = "audio"

    def __init__(self, track):
        super().__init__()  # don't forget this!
        self.track = track

    async def recv(self):
        # print("Awaiting track in server")
        frame = await self.track.recv()
        audio_buffer.write(frame)
        return frame


def start_messaging_thread():
    message_thread = threading.Thread(target=thread_queue_channel_send)
    message_thread.start()
    # message_thread.join()


def start_transcription_thread(max_threads):
    t_threads = []
    for i in range(max_threads):
        t_thread = threading.Thread(target=get_transcription, args=(i,))
        t_threads.append(t_thread)
        t_thread.start()
    # for t_thread in t_threads:
    #     t_thread.join()


async def offer(request):
    params = await request.json()
    print("Request received")
    offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])

    pc = RTCPeerConnection()
    pc_id = "PeerConnection(%s)" % uuid.uuid4()
    pcs.add(pc)

    def log_info(msg, *args):
        logger.info(pc_id + " " + msg, *args)

    log_info("Created for %s", request.remote)

    @pc.on("datachannel")
    def on_datachannel(channel):
        global data_channel, start_time
        data_channel = channel
        channel_log(channel, "-", "created by remote party")
        start_time = datetime.datetime.now()

        @channel.on("message")
        def on_message(message):
            channel_log(channel, "<", message)
            if isinstance(message, str) and message.startswith("ping"):
                # reply
                channel.send("pong" + message[4:])

    @pc.on("connectionstatechange")
    async def on_connectionstatechange():
        log_info("Connection state is %s", pc.connectionState)
        if pc.connectionState == "failed":
            await pc.close()
            pcs.discard(pc)

    @pc.on("track")
    def on_track(track):
        print("Track %s received" % track.kind)
        log_info("Track %s received", track.kind)
        # Trials to listen to the correct track
        pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
        # pc.addTrack(AudioStreamTrack(track))

    # handle offer
    await pc.setRemoteDescription(offer)

    # send answer
    answer = await pc.createAnswer()
    await pc.setLocalDescription(answer)
    print("Response sent")

    return web.Response(
        content_type="application/json",
        text=json.dumps(
            {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
        ),
    )


async def on_shutdown(app):
    # close peer connections
    coros = [pc.close() for pc in pcs]
    await asyncio.gather(*coros)
    pcs.clear()


if __name__ == "__main__":
    app = web.Application()
    app.on_shutdown.append(on_shutdown)
    start_transcription_thread(6)
    app.router.add_post("/offer", offer)
    web.run_app(
        app, access_log=None, host="127.0.0.1", port=1250
    )
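The multithreaded server keeps results in capture order by claiming a slot in sorted_message_queue (a SortedDict keyed by the chunk's first frame time) before transcribing, and only sending the earliest slot once it is filled. A stripped-down sketch of that ordering idea follows; the keys and values below are made up for illustration.

# Illustrative sketch of the ordering scheme in server_multithreaded.py.
from sortedcontainers import SortedDict

queue = SortedDict()

# A transcription thread claims its slot first, then fills it in when done.
queue[3.2] = None                                      # chunk starting at t=3.2s, still transcribing
queue[1.6] = "{'text': 'hello', 'start_time': '1.6'}"  # chunk at t=1.6s, finished

# The messaging thread only ships the earliest slot, and only once it is filled,
# so text reaches the client in capture order even if chunks finish out of order.
earliest = queue.keys()[0]
if queue[earliest] is not None:
    print("send:", queue.pop(earliest))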

streamclient.py Normal file

@@ -0,0 +1,168 @@
import asyncio
import ast
import configparser
import logging
import threading
import time
import uuid

import httpx
import pyaudio
import stamina
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaPlayer, MediaRelay

logger = logging.getLogger("pc")
file_lock = threading.Lock()

config = configparser.ConfigParser()
config.read('config.ini')


class StreamClient:
    def __init__(
        self,
        signaling,
        url="http://127.0.0.1:1250",
        play_from=None,
        ping_pong=False,
        audio_stream=None,
    ):
        self.signaling = signaling
        self.server_url = url
        self.play_from = play_from
        self.ping_pong = ping_pong
        self.paudio = pyaudio.PyAudio()
        self.pc = RTCPeerConnection()
        self.loop = asyncio.get_event_loop()
        # self.loop = asyncio.new_event_loop()
        self.relay = None
        self.pcs = set()
        self.time_start = None
        self.queue = asyncio.Queue()
        # Capture from the macOS avfoundation device configured in config.ini.
        self.player = MediaPlayer(':' + str(config['DEFAULT']["AV_FOUNDATION_DEVICE_ID"]),
                                  format='avfoundation', options={'channels': '2'})

    def stop(self):
        self.loop.run_until_complete(self.signaling.close())
        self.loop.run_until_complete(self.pc.close())
        # self.loop.close()
        print("ended")

    def create_local_tracks(self, play_from):
        if play_from:
            player = MediaPlayer(play_from)
            return player.audio, player.video
        else:
            if self.relay is None:
                self.relay = MediaRelay()
            print("Created local track from microphone stream")
            return self.relay.subscribe(self.player.audio), None

    def channel_log(self, channel, t, message):
        print("channel(%s) %s %s" % (channel.label, t, message))

    def channel_send(self, channel, message):
        # self.channel_log(channel, ">", message)
        channel.send(message)

    def current_stamp(self):
        if self.time_start is None:
            self.time_start = time.time()
            return 0
        else:
            return int((time.time() - self.time_start) * 1000000)

    async def run_offer(self, pc, signaling):
        # microphone
        audio, video = self.create_local_tracks(self.play_from)
        pc_id = "PeerConnection(%s)" % uuid.uuid4()
        self.pcs.add(pc)

        def log_info(msg, *args):
            logger.info(pc_id + " " + msg, *args)

        @pc.on("connectionstatechange")
        async def on_connectionstatechange():
            print("Connection state is %s" % pc.connectionState)
            if pc.connectionState == "failed":
                await pc.close()
                self.pcs.discard(pc)

        @pc.on("track")
        def on_track(track):
            print("Sending %s" % track.kind)
            # Trials
            self.pc.addTrack(track)
            # self.pc.addTrack(self.microphone)

            @track.on("ended")
            async def on_ended():
                log_info("Track %s ended", track.kind)

        self.pc.addTrack(audio)

        # DataChannel
        channel = pc.createDataChannel("data-channel")
        self.channel_log(channel, "-", "created by local party")

        async def send_pings():
            while True:
                self.channel_send(channel, "ping %d" % self.current_stamp())
                await asyncio.sleep(1)

        @channel.on("open")
        def on_open():
            if self.ping_pong:
                asyncio.ensure_future(send_pings())

        @channel.on("message")
        def on_message(message):
            self.queue.put_nowait(message)
            if self.ping_pong:
                self.channel_log(channel, "<", message)
                if isinstance(message, str) and message.startswith("pong"):
                    elapsed_ms = (self.current_stamp() - int(message[5:])) / 1000
                    print(" RTT %.2f ms" % elapsed_ms)

        await pc.setLocalDescription(await pc.createOffer())
        sdp = {
            "sdp": pc.localDescription.sdp,
            "type": pc.localDescription.type
        }

        @stamina.retry(on=httpx.HTTPError, attempts=5)
        def connect_to_server():
            # Use httpx here so connection errors and non-2xx responses raise
            # httpx.HTTPError, which is what the stamina retry condition matches.
            response = httpx.post(self.server_url, json=sdp, timeout=10)
            response.raise_for_status()
            return response

        params = connect_to_server().json()
        answer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
        await pc.setRemoteDescription(answer)

        self.reader = self.worker("worker", self.queue)

    def get_reader(self):
        return self.reader

    async def worker(self, name, queue):
        while True:
            msg = await self.queue.get()
            msg = ast.literal_eval(msg)
            with file_lock:
                with open("test_sm_6.txt", "a") as f:
                    f.write(msg["text"])
            yield msg["text"]
            self.queue.task_done()

    async def start(self):
        print("Starting stream client")
        coro = self.run_offer(self.pc, self.signaling)
        task = asyncio.create_task(coro)
        await task
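The offer POST in run_offer is wrapped with stamina, presumably so transient failures while the server is still warming up get retried. The same pattern in isolation is sketched below; post_offer is an illustrative name and the URL is just the default from this diff.

# Illustrative sketch of the retry wrapper used around the /offer POST.
import httpx
import stamina


@stamina.retry(on=httpx.HTTPError, attempts=5)
def post_offer(sdp: dict) -> dict:
    # Retries on connection errors and non-2xx responses (both raise httpx.HTTPError).
    response = httpx.post("http://127.0.0.1:1250/offer", json=sdp, timeout=10)
    response.raise_for_status()
    return response.json()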

File diff suppressed because one or more lines are too long

utils.py Normal file

@@ -0,0 +1,7 @@
import asyncio
from functools import partial


def run_in_executor(func, *args, executor=None, **kwargs):
    """Run a blocking callable in `executor` and return an awaitable asyncio Future."""
    callback = partial(func, *args, **kwargs)
    loop = asyncio.get_event_loop()
    return loop.run_in_executor(executor, callback)