Merge pull request #35 from Monadical-SAS/feat/gokul

Front-end integrations and small updates to return types
projects-g committed 2023-07-19 13:19:13 +05:30 (committed by GitHub)
3 changed files with 80 additions and 30 deletions


@@ -58,3 +58,4 @@ httpx==0.24.1
 sortedcontainers==2.4.0
 https://github.com/yt-dlp/yt-dlp/archive/master.tar.gz
 gpt4all==1.0.5
+aiohttp_cors==0.7.0


@@ -1,20 +1,22 @@
import asyncio
import datetime
import io
import json
import uuid
import wave
from concurrent.futures import ThreadPoolExecutor
import aiohttp_cors
import jax.numpy as jnp
import requests
from aiohttp import web
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaRelay
from av import AudioFifo
from gpt4all import GPT4All
from loguru import logger
from whisper_jax import FlaxWhisperPipline
from utils.run_utils import run_in_executor, config
from utils.run_utils import run_in_executor
pcs = set()
relay = MediaRelay()
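Note for reviewers: the hunks below use a module-level `pipeline` object that is defined in an unchanged part of the file this diff does not show. For context, a minimal sketch of how whisper-jax is typically initialized; the checkpoint name and dtype here are assumptions, not the repo's actual values:

    import jax.numpy as jnp
    from whisper_jax import FlaxWhisperPipline

    # Checkpoint and dtype are illustrative assumptions; the real values live
    # in the unchanged part of the file.
    pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16)
    out = pipeline("audio.wav", return_timestamps=True)
    # out -> {"text": "...", "chunks": [{"timestamp": (0.0, 5.0), "text": "..."}]}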
@@ -28,27 +30,42 @@ RATE = 48000
 audio_buffer = AudioFifo()
 executor = ThreadPoolExecutor()
 transcription_text = ""
-# Load your locally downloaded Vicuna model and load it here. Set this path in the config.ini file
-llm = GPT4All(config["DEFAULT"]["LLM_PATH"])
+last_transcribed_time = 0.0
+
+LLM_MACHINE_IP = "216.153.52.83"
+LLM_MACHINE_PORT = "5000"
+LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate"


-def get_title_and_summary():
-    global transcription_text
-    output = None
-    if len(transcription_text) > 1000:
-        print("Generating title and summary")
-        prompt = f"""
+def get_title_and_summary(llm_input_text):
+    print("Generating title and summary")
+    # output = llm.generate(prompt)
+    # Use monadical-ml to fire this query to an LLM and get result
+    headers = {
+        "Content-Type": "application/json"
+    }
+    prompt = f"""
 ### Human:
-Create a JSON object having 2 fields: title and summary. For the title field generate a short title for the given
-text and for the summary field, summarize the given text by creating 3 key points.
-{transcription_text}
+Create a JSON object as response. The JSON object must have 2 fields: i) title and ii) summary. For the title field,
+generate a short title for the given text. For the summary field, summarize the given text in three sentences.
+{llm_input_text}
 ### Assistant:
 """
-        transcription_text = ""
-        output = llm.generate(prompt)
-    return str(output)
+    data = {
+        "prompt": prompt
+    }
+    try:
+        response = requests.post(LLM_URL, headers=headers, json=data)
+        output = json.loads(response.json()["results"][0]["text"])
+        output["description"] = output.pop("summary")
+    except Exception as e:
+        print(str(e))
+        output = None
+    return output
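The endpoint path follows text-generation-webui's legacy blocking API, which wraps the completion in a results list; since the prompt asks the model to emit JSON, the completion text is itself parsed with json.loads. A runnable sketch of the shape this code expects (the sample completion is illustrative, not captured traffic):

    import json

    # Illustrative response body; /api/v1/generate-style endpoints return the
    # model completion under results[0]["text"].
    response_body = {
        "results": [{"text": '{"title": "Demo call", "summary": "One. Two. Three."}'}]
    }

    output = json.loads(response_body["results"][0]["text"])
    output["description"] = output.pop("summary")
    print(output)  # {'title': 'Demo call', 'description': 'One. Two. Three.'}

If the model returns anything that is not valid JSON, json.loads raises, the except branch returns None, and nothing is sent to the client.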
@@ -59,7 +76,7 @@ def channel_log(channel, t, message):
 def channel_send(channel, message):
     # channel_log(channel, ">", message)
     if channel and message:
-        channel.send(str(message))
+        channel.send(json.dumps(message))


 def get_transcription(frames):
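The switch from str() to json.dumps() is what makes the front-end integration work: str() on a dict produces a Python repr with single quotes, which JSON.parse on the browser side rejects, while json.dumps emits valid JSON. For example:

    import json

    message = {"text": "hello", "timestamp": "0:00:05"}
    str(message)         # "{'text': 'hello', 'timestamp': '0:00:05'}"  (not valid JSON)
    json.dumps(message)  # '{"text": "hello", "timestamp": "0:00:05"}'  (JSON.parse-able)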
@@ -73,11 +90,26 @@ def get_transcription(frames):
     for frame in frames:
         wf.writeframes(b"".join(frame.to_ndarray()))
     wf.close()
-    whisper_result = pipeline(out_file.getvalue(), return_timestamps=True)
-    # whisper_result['start_time'] = [f.time for f in frames]
-    global transcription_text
+    # To-Do: Look into WhisperTimeStampLogitsProcessor exception
+    try:
+        whisper_result = pipeline(out_file.getvalue(), return_timestamps=True)
+    except Exception as e:
+        return
+    global transcription_text, last_transcribed_time
     transcription_text += whisper_result["text"]
-    return whisper_result
+    duration = whisper_result["chunks"][0]["timestamp"][1]
+    if not duration:
+        duration = 5.0
+    last_transcribed_time += duration
+    result = {
+        "text": whisper_result["text"],
+        "timestamp": str(datetime.timedelta(seconds=round(last_transcribed_time)))
+    }
+    return result


 class AudioStreamTrack(MediaStreamTrack):
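The 5.0-second fallback for a missing end timestamp lines up with how much audio each call receives: recv() below batches 256 * 960 samples from the FIFO, and at the file's RATE of 48000 Hz that is roughly five seconds per chunk:

    RATE = 48000         # sample rate declared earlier in this file
    chunk = 256 * 960    # samples read from the AudioFifo per transcription call
    print(chunk / RATE)  # 5.12 -> each transcribed chunk is ~5 s of audio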
@@ -92,8 +124,10 @@ class AudioStreamTrack(MediaStreamTrack):
         self.track = track

     async def recv(self):
+        global transcription_text
         frame = await self.track.recv()
         audio_buffer.write(frame)
+
         if local_frames := audio_buffer.read_many(256 * 960, partial=False):
             whisper_result = run_in_executor(
                 get_transcription, local_frames, executor=executor
@@ -103,13 +137,18 @@ class AudioStreamTrack(MediaStreamTrack):
                 if f.result()
                 else None
             )
-            llm_result = run_in_executor(get_title_and_summary,
-                                         executor=executor)
-            llm_result.add_done_callback(
-                lambda f: channel_send(data_channel, llm_result.result())
-                if f.result()
-                else None
-            )
+            if len(transcription_text) > 2000:
+                llm_input_text = transcription_text
+                transcription_text = ""
+                llm_result = run_in_executor(get_title_and_summary,
+                                             llm_input_text,
+                                             executor=executor)
+                llm_result.add_done_callback(
+                    lambda f: channel_send(data_channel, llm_result.result())
+                    if f.result()
+                    else None
+                )
         return frame
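utils/run_utils is not part of this diff, but from the call sites above (positional args, an executor keyword, and a return value exposing add_done_callback and result) it plausibly wraps ThreadPoolExecutor.submit and hands back a concurrent.futures.Future. A sketch under that assumption, not the repo's actual helper:

    from concurrent.futures import Executor, Future

    # Assumed shape of utils.run_utils.run_in_executor, inferred from usage;
    # executor.submit returns a Future whose done-callbacks fire once the
    # blocking transcription/LLM work finishes, off the event-loop thread.
    def run_in_executor(fn, *args, executor: Executor) -> Future:
        return executor.submit(fn, *args)

Snapshotting transcription_text into llm_input_text and clearing the buffer before submitting means the same text is not re-summarized while a previous future is still running.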
@@ -172,6 +211,16 @@ async def on_shutdown(app):
 if __name__ == "__main__":
     app = web.Application()
+    cors = aiohttp_cors.setup(
+        app,
+        defaults={
+            "*": aiohttp_cors.ResourceOptions(
+                allow_credentials=True, expose_headers="*", allow_headers="*"
+            )
+        },
+    )
+    offer_resource = cors.add(app.router.add_resource("/offer"))
+    cors.add(offer_resource.add_route("POST", offer))
     app.on_shutdown.append(on_shutdown)
-    app.router.add_post("/offer", offer)
     web.run_app(app, access_log=None, host="127.0.0.1", port=1250)
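With the route now registered through aiohttp_cors instead of a plain add_post, a browser front end on another origin can POST the WebRTC offer directly. A sketch of exercising the endpoint, assuming the offer handler (unchanged in this diff) follows the usual aiortc pattern of accepting and returning an {"sdp": ..., "type": ...} payload:

    import requests

    # Placeholder SDP; a real client would send RTCPeerConnection's localDescription.
    offer = {"sdp": "v=0...", "type": "offer"}
    answer = requests.post("http://127.0.0.1:1250/offer", json=offer).json()
    print(answer["type"])  # expected: "answer"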

trials/api.py Normal file