This commit is contained in:
Gokul Mohanarangan
2023-07-11 11:06:27 +05:30
parent 58c9cdf676
commit b7fbfb2a54
13 changed files with 54 additions and 44 deletions

View File

@@ -2,7 +2,6 @@ import asyncio
import datetime
import io
import json
from loguru import logger
import sys
import uuid
import wave
@@ -13,6 +12,7 @@ from aiohttp import web
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaRelay
from av import AudioFifo
from loguru import logger
from whisper_jax import FlaxWhisperPipline
from utils.server_utils import run_in_executor
@@ -23,7 +23,9 @@ pcs = set()
relay = MediaRelay()
data_channel = None
total_bytes_handled = 0
pipeline = FlaxWhisperPipline("openai/whisper-tiny", dtype=jnp.float16, batch_size=16)
pipeline = FlaxWhisperPipline("openai/whisper-tiny",
dtype=jnp.float16,
batch_size=16)
CHANNELS = 2
RATE = 48000
@@ -50,18 +52,6 @@ def channel_send(channel, message):
def get_transcription(frames):
print("Transcribing..")
# samples = np.ndarray(
# np.concatenate([f.to_ndarray() for f in frames], axis=None),
# dtype=np.float32,
# )
# whisper_result = pipeline(
# {
# "array": samples,
# "sampling_rate": 48000,
# },
# return_timestamps=True,
# )
out_file = io.BytesIO()
wf = wave.open(out_file, "wb")
wf.setnchannels(CHANNELS)
@@ -108,7 +98,6 @@ class AudioStreamTrack(MediaStreamTrack):
async def offer(request):
params = await request.json()
print("Request received")
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
pc = RTCPeerConnection()
@@ -132,7 +121,6 @@ async def offer(request):
channel_log(channel, "<", message)
if isinstance(message, str) and message.startswith("ping"):
# reply
channel_send(channel, "pong" + message[4:])
@pc.on("connectionstatechange")
@@ -144,19 +132,13 @@ async def offer(request):
@pc.on("track")
def on_track(track):
print("Track %s received" % track.kind)
log_info("Track %s received", track.kind)
# Trials to listen to the correct track
pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
# pc.addTrack(AudioStreamTrack(track))
# handle offer
await pc.setRemoteDescription(offer)
# send answer
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
print("Response sent")
return web.Response(
content_type="application/json",
text=json.dumps(
@@ -166,7 +148,6 @@ async def offer(request):
async def on_shutdown(app):
    """Close all tracked peer connections, then empty the registry.

    Every connection currently held in the module-level ``pcs`` set has
    its ``close()`` coroutine awaited concurrently; once they have all
    finished, ``pcs`` is cleared.

    Args:
        app: Not used in the body. Presumably the aiohttp application
            passed to a shutdown hook; confirm at the registration site.
    """
    # Unpacking the generator creates every close() coroutine up front,
    # before gather() starts awaiting them together.
    await asyncio.gather(*(peer.close() for peer in pcs))
    pcs.clear()