Merge pull request #35 from Monadical-SAS/feat/gokul

Front end integrations and small updates to return types
Authored by projects-g on 2023-07-19 13:19:13 +05:30; committed by GitHub.
3 changed files with 80 additions and 30 deletions


@@ -58,3 +58,4 @@ httpx==0.24.1
 sortedcontainers==2.4.0
 https://github.com/yt-dlp/yt-dlp/archive/master.tar.gz
 gpt4all==1.0.5
+aiohttp_cors==0.7.0


@@ -1,20 +1,22 @@
 import asyncio
+import datetime
 import io
 import json
 import uuid
 import wave
 from concurrent.futures import ThreadPoolExecutor

+import aiohttp_cors
 import jax.numpy as jnp
+import requests
 from aiohttp import web
 from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
 from aiortc.contrib.media import MediaRelay
 from av import AudioFifo
-from gpt4all import GPT4All
 from loguru import logger
 from whisper_jax import FlaxWhisperPipline
-from utils.run_utils import run_in_executor, config
+from utils.run_utils import run_in_executor


 pcs = set()
 relay = MediaRelay()
@@ -28,27 +30,42 @@ RATE = 48000
 audio_buffer = AudioFifo()
 executor = ThreadPoolExecutor()
 transcription_text = ""
-# Load your locally downloaded Vicuna model and load it here. Set this path in the config.ini file
-llm = GPT4All(config["DEFAULT"]["LLM_PATH"])
+last_transcribed_time = 0.0
+LLM_MACHINE_IP = "216.153.52.83"
+LLM_MACHINE_PORT = "5000"
+LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate"


-def get_title_and_summary():
-    global transcription_text
-    output = None
-    if len(transcription_text) > 1000:
-        print("Generating title and summary")
-        prompt = f"""
+def get_title_and_summary(llm_input_text):
+    print("Generating title and summary")
+    # output = llm.generate(prompt)
+    # Use monadical-ml to fire this query to an LLM and get result
+    headers = {
+        "Content-Type": "application/json"
+    }
+    prompt = f"""
 ### Human:
-Create a JSON object having 2 fields: title and summary. For the title field generate a short title for the given
-text and for the summary field, summarize the given text by creating 3 key points.
-{transcription_text}
+Create a JSON object as response. The JSON object must have 2 fields: i) title and ii) summary. For the title field,
+generate a short title for the given text. For the summary field, summarize the given text in three sentences.
+{llm_input_text}
 ### Assistant:
 """
-        transcription_text = ""
-        output = llm.generate(prompt)
-        return str(output)
+    data = {
+        "prompt": prompt
+    }
+    try:
+        response = requests.post(LLM_URL, headers=headers, json=data)
+        output = json.loads(response.json()["results"][0]["text"])
+        output["description"] = output.pop("summary")
+    except Exception as e:
+        print(str(e))
+        output = None
     return output
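
Note: this hunk swaps the local GPT4All call for an HTTP request to a remote generation endpoint, then renames the summary key to description, which looks like the "updates to return types" the commit title mentions. A minimal sketch of just that parse-and-rename step, with a simulated reply (the response schema is inferred from how the handler parses it, not from a documented API):

    import json

    # Simulated endpoint reply, shaped the way the handler above parses it
    fake_response = {"results": [{"text": '{"title": "Demo", "summary": "Three short sentences."}'}]}

    output = json.loads(fake_response["results"][0]["text"])
    output["description"] = output.pop("summary")  # the front end reads "description"
    assert output == {"title": "Demo", "description": "Three short sentences."}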
@@ -59,7 +76,7 @@ def channel_log(channel, t, message):
 def channel_send(channel, message):
     # channel_log(channel, ">", message)
     if channel and message:
-        channel.send(str(message))
+        channel.send(json.dumps(message))


 def get_transcription(frames):
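
Note: switching channel_send from str() to json.dumps() matters because data-channel payloads are strings, and a browser front end will typically JSON.parse them; Python's str() on a dict produces single-quoted repr that JSON.parse rejects. A quick illustration:

    import json

    message = {"text": "hello", "timestamp": "0:00:05"}
    print(json.dumps(message))  # {"text": "hello", "timestamp": "0:00:05"} -- valid JSON
    print(str(message))         # {'text': 'hello', 'timestamp': '0:00:05'} -- Python repr, not JSON
    assert json.loads(json.dumps(message)) == message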
@@ -73,11 +90,26 @@ def get_transcription(frames):
     for frame in frames:
         wf.writeframes(b"".join(frame.to_ndarray()))
     wf.close()
-    whisper_result = pipeline(out_file.getvalue(), return_timestamps=True)
-    # whisper_result['start_time'] = [f.time for f in frames]
-    global transcription_text
+    # To-Do: Look into WhisperTimeStampLogitsProcessor exception
+    try:
+        whisper_result = pipeline(out_file.getvalue(), return_timestamps=True)
+    except Exception as e:
+        return
+    global transcription_text, last_transcribed_time
     transcription_text += whisper_result["text"]
-    return whisper_result
+    duration = whisper_result["chunks"][0]["timestamp"][1]
+    if not duration:
+        duration = 5.0
+    last_transcribed_time += duration
+    result = {
+        "text": whisper_result["text"],
+        "timestamp": str(datetime.timedelta(seconds=round(last_transcribed_time)))
+    }
+    return result


 class AudioStreamTrack(MediaStreamTrack):
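
Note: get_transcription now returns a dict with the text plus a running clock string instead of the raw whisper-jax result. Each clip's duration is taken from the end timestamp of its first chunk, falling back to 5.0 s when that is missing (256 * 960 samples at 48 kHz is about 5.12 s of audio, so the fallback is close). A standalone sketch of that bookkeeping:

    import datetime

    last_transcribed_time = 0.0
    for end in (4.86, None, 5.2):  # per-clip end timestamps; None when whisper omits one
        duration = end if end else 5.0
        last_transcribed_time += duration
        print(str(datetime.timedelta(seconds=round(last_transcribed_time))))
    # prints 0:00:05, 0:00:10, 0:00:15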
@@ -92,8 +124,10 @@ class AudioStreamTrack(MediaStreamTrack):
         self.track = track

     async def recv(self):
+        global transcription_text
+
         frame = await self.track.recv()
         audio_buffer.write(frame)
         if local_frames := audio_buffer.read_many(256 * 960, partial=False):
             whisper_result = run_in_executor(
                 get_transcription, local_frames, executor=executor
@@ -103,13 +137,18 @@ class AudioStreamTrack(MediaStreamTrack):
                 if f.result()
                 else None
             )
-            llm_result = run_in_executor(get_title_and_summary,
-                                         executor=executor)
-            llm_result.add_done_callback(
-                lambda f: channel_send(data_channel, llm_result.result())
-                if f.result()
-                else None
-            )
+            if len(transcription_text) > 2000:
+                llm_input_text = transcription_text
+                transcription_text = ""
+                llm_result = run_in_executor(get_title_and_summary,
+                                             llm_input_text,
+                                             executor=executor)
+                llm_result.add_done_callback(
+                    lambda f: channel_send(data_channel, llm_result.result())
+                    if f.result()
+                    else None
+                )
         return frame
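
Note: instead of firing the LLM on every batch of frames, recv() now drains transcription_text only once it grows past 2000 characters, resetting the buffer so the same text is never summarized twice. The rule restated in isolation (maybe_take_batch is a hypothetical name for illustration):

    transcription_text = ""

    def maybe_take_batch(chunk):
        # Accumulate transcript text; release a batch only past the threshold.
        global transcription_text
        transcription_text += chunk
        if len(transcription_text) > 2000:
            batch, transcription_text = transcription_text, ""
            return batch  # this is what gets handed to get_title_and_summary
        return None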
@@ -172,6 +211,16 @@ async def on_shutdown(app):
 if __name__ == "__main__":
     app = web.Application()
+    cors = aiohttp_cors.setup(
+        app,
+        defaults={
+            "*": aiohttp_cors.ResourceOptions(
+                allow_credentials=True, expose_headers="*", allow_headers="*"
+            )
+        },
+    )
+    offer_resource = cors.add(app.router.add_resource("/offer"))
+    cors.add(offer_resource.add_route("POST", offer))
     app.on_shutdown.append(on_shutdown)
-    app.router.add_post("/offer", offer)
     web.run_app(app, access_log=None, host="127.0.0.1", port=1250)
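
Note: the /offer route is re-registered through aiohttp_cors so a front end served from another origin can reach it. A hedged client-side check in Python (the {"sdp": ..., "type": ...} body is assumed from aiortc's usual offer/answer examples; this diff does not show the handler):

    import asyncio
    import aiohttp

    async def send_offer():
        async with aiohttp.ClientSession() as session:
            async with session.post(
                "http://127.0.0.1:1250/offer",
                json={"sdp": "<local SDP here>", "type": "offer"},  # assumed schema
            ) as resp:
                print(resp.status, await resp.text())

    asyncio.run(send_offer())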

trials/api.py (new file)