Add front-end integration

Gokul Mohanarangan
2023-07-19 13:14:59 +05:30
parent 44e2f0c7b7
commit 317a113384
3 changed files with 81 additions and 30 deletions


@@ -58,3 +58,4 @@ httpx==0.24.1
 sortedcontainers==2.4.0
 https://github.com/yt-dlp/yt-dlp/archive/master.tar.gz
 gpt4all==1.0.5
+aiohttp_cors==0.7.0


@@ -3,18 +3,19 @@ import io
 import json
 import uuid
 import wave
+import requests
 from concurrent.futures import ThreadPoolExecutor

+import aiohttp_cors
 import jax.numpy as jnp
 from aiohttp import web
 from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
 from aiortc.contrib.media import MediaRelay
 from av import AudioFifo
-from gpt4all import GPT4All
+import datetime
 from loguru import logger
 from whisper_jax import FlaxWhisperPipline
-from utils.run_utils import run_in_executor, config
+from utils.run_utils import run_in_executor

 pcs = set()
 relay = MediaRelay()
@@ -28,26 +29,43 @@ RATE = 48000
 audio_buffer = AudioFifo()
 executor = ThreadPoolExecutor()
 transcription_text = ""
-llm = GPT4All(config["DEFAULT"]["LLM_PATH"])
+last_transcribed_time = 0.0
+LLM_MACHINE_IP = "216.153.52.83"
+LLM_MACHINE_PORT = "5000"
+LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate"


-def get_title_and_summary():
-    global transcription_text
-    output = None
-    if len(transcription_text) > 1000:
-        print("Generating title and summary")
-        # output = llm.generate(prompt)
-        # Use monadical-ml to fire this query to an LLM and get result
-        prompt = f"""
-        ### Human:
-        Create a JSON object having 2 fields: title and summary. For the title field generate a short title for the given
-        text and for the summary field, summarize the given text by creating 3 key points.
-        {transcription_text}
-        ### Assistant:
-        """
-        transcription_text = ""
-        output = llm.generate(prompt)
-        return str(output)
+def get_title_and_summary(llm_input_text):
+    print("Generating title and summary")
+    headers = {
+        "Content-Type": "application/json"
+    }
+    prompt = f"""
+    ### Human:
+    Create a JSON object as response. The JSON object must have 2 fields: i) title and ii) summary. For the title field,
+    generate a short title for the given text. For the summary field, summarize the given text in three sentences.
+    {llm_input_text}
+    ### Assistant:
+    """
+    data = {
+        "prompt": prompt
+    }
+    try:
+        response = requests.post(LLM_URL, headers=headers, json=data)
+        output = json.loads(response.json()["results"][0]["text"])
+        output["description"] = output.pop("summary")
+    except Exception as e:
+        print(str(e))
+        output = None
     return output
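Note: the /api/v1/generate endpoint itself is not part of this diff. The contract sketched below is inferred from the parsing in get_title_and_summary (a "results" list whose first "text" entry is itself a JSON string), not confirmed by the source:

# Inferred request/response contract for the remote LLM endpoint.
# Request body:  {"prompt": "### Human: ... ### Assistant:"}
# Reply body:    {"results": [{"text": "{\"title\": ..., \"summary\": ...}"}]}
import json

reply = {"results": [{"text": json.dumps({"title": "Standup", "summary": "..."})}]}  # hypothetical reply
output = json.loads(reply["results"][0]["text"])
output["description"] = output.pop("summary")  # same rename the handler performs for the frontend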
@@ -58,7 +76,7 @@ def channel_log(channel, t, message):
 def channel_send(channel, message):
     # channel_log(channel, ">", message)
     if channel and message:
-        channel.send(str(message))
+        channel.send(json.dumps(message))


 def get_transcription(frames):
@@ -72,11 +90,26 @@ def get_transcription(frames):
     for frame in frames:
         wf.writeframes(b"".join(frame.to_ndarray()))
     wf.close()
-    whisper_result = pipeline(out_file.getvalue(), return_timestamps=True)
-    # whisper_result['start_time'] = [f.time for f in frames]
-    global transcription_text
+    # To-Do: Look into WhisperTimeStampLogitsProcessor exception
+    try:
+        whisper_result = pipeline(out_file.getvalue(), return_timestamps=True)
+    except Exception as e:
+        return
+    global transcription_text, last_transcribed_time
     transcription_text += whisper_result["text"]
-    return whisper_result
+    duration = whisper_result["chunks"][0]["timestamp"][1]
+    if not duration:
+        duration = 5.0
+    last_transcribed_time += duration
+    result = {
+        "text": whisper_result["text"],
+        "timestamp": str(datetime.timedelta(seconds=round(last_transcribed_time)))
+    }
+    return result


 class AudioStreamTrack(MediaStreamTrack):
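Note: get_transcription now returns a dict rather than the raw whisper output; combined with the channel_send change above, the browser receives JSON messages shaped like this (values illustrative):

# Illustrative message as delivered over the data channel; the timestamp
# string comes from datetime.timedelta, e.g. 5 s -> "0:00:05".
import datetime

msg = {
    "text": " Hello, and welcome.",  # hypothetical transcript chunk
    "timestamp": str(datetime.timedelta(seconds=round(5.12))),  # "0:00:05"
}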
@@ -91,8 +124,10 @@ class AudioStreamTrack(MediaStreamTrack):
         self.track = track

     async def recv(self):
+        global transcription_text
         frame = await self.track.recv()
         audio_buffer.write(frame)
         if local_frames := audio_buffer.read_many(256 * 960, partial=False):
             whisper_result = run_in_executor(
                 get_transcription, local_frames, executor=executor
@@ -102,7 +137,12 @@ class AudioStreamTrack(MediaStreamTrack):
if f.result() if f.result()
else None else None
) )
if len(transcription_text) > 2000:
llm_input_text = transcription_text
transcription_text = ""
llm_result = run_in_executor(get_title_and_summary, llm_result = run_in_executor(get_title_and_summary,
llm_input_text,
executor=executor) executor=executor)
llm_result.add_done_callback( llm_result.add_done_callback(
lambda f: channel_send(data_channel, llm_result.result()) lambda f: channel_send(data_channel, llm_result.result())
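As a sanity check on the buffering in recv(): read_many(256 * 960, partial=False) only fires once 245,760 samples have accumulated, which at RATE = 48000 is 5.12 s of audio per transcription call. That is presumably why get_transcription falls back to a 5.0 s duration when whisper returns no end timestamp:

# Buffer-size arithmetic for the recv() path above (RATE = 48000).
samples_per_batch = 256 * 960  # 245_760 samples
seconds_per_batch = samples_per_batch / 48000
print(seconds_per_batch)  # 5.12 s of audio per get_transcription call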
@@ -171,6 +211,16 @@ async def on_shutdown(app):
 if __name__ == "__main__":
     app = web.Application()
+    cors = aiohttp_cors.setup(
+        app,
+        defaults={
+            "*": aiohttp_cors.ResourceOptions(
+                allow_credentials=True, expose_headers="*", allow_headers="*"
+            )
+        },
+    )
+    offer_resource = cors.add(app.router.add_resource("/offer"))
+    cors.add(offer_resource.add_route("POST", offer))
     app.on_shutdown.append(on_shutdown)
-    app.router.add_post("/offer", offer)
     web.run_app(app, access_log=None, host="127.0.0.1", port=1250)
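Since the commit wires CORS onto POST /offer for a browser frontend, a minimal Python client sketch of the expected signalling flow follows. The offer handler's body is not shown in this diff, so the {"sdp", "type"} JSON contract and the data-channel label are assumptions based on the standard aiortc server examples:

# Hypothetical client for POST /offer; assumes the aiortc-demo-style
# {"sdp": ..., "type": ...} request/response contract (not shown in this diff).
import asyncio
import json

import aiohttp
from aiortc import RTCPeerConnection, RTCSessionDescription

async def connect():
    pc = RTCPeerConnection()
    channel = pc.createDataChannel("results")
    channel.on("message", lambda m: print(json.loads(m)))  # transcription / summary dicts
    pc.addTransceiver("audio", direction="sendonly")  # a real client adds a microphone track here
    offer = await pc.createOffer()
    await pc.setLocalDescription(offer)
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://127.0.0.1:1250/offer",
            json={"sdp": pc.localDescription.sdp, "type": pc.localDescription.type},
        ) as resp:
            answer = await resp.json()  # assumed to contain "sdp" and "type" keys
    await pc.setRemoteDescription(RTCSessionDescription(**answer))
    # a real client would keep the event loop alive here to stream audio

asyncio.run(connect())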

trials/api.py (new file)