update server to run summarizer and captions

This commit is contained in:
Gokul Mohanarangan
2023-07-18 22:36:35 +05:30
parent 20eaeee46b
commit e5aa943998
3 changed files with 41 additions and 18 deletions

View File

@@ -1,4 +1,3 @@
pyaudio==0.2.13
keyboard==0.13.5 keyboard==0.13.5
pynput==1.7.6 pynput==1.7.6
wave==0.0.2 wave==0.0.2
@@ -57,3 +56,4 @@ stamina==23.1.0
httpx==0.24.1 httpx==0.24.1
sortedcontainers==2.4.0 sortedcontainers==2.4.0
https://github.com/yt-dlp/yt-dlp/archive/master.tar.gz https://github.com/yt-dlp/yt-dlp/archive/master.tar.gz
gpt4all==1.0.5

View File

@@ -1,6 +1,7 @@
import asyncio import asyncio
import io import io
import json import json
import time
import uuid import uuid
import wave import wave
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@@ -12,7 +13,7 @@ from aiortc.contrib.media import MediaRelay
from av import AudioFifo from av import AudioFifo
from loguru import logger from loguru import logger
from whisper_jax import FlaxWhisperPipline from whisper_jax import FlaxWhisperPipline
from gpt4all import GPT4All
from utils.run_utils import run_in_executor from utils.run_utils import run_in_executor
pcs = set() pcs = set()
@@ -26,6 +27,28 @@ CHANNELS = 2
RATE = 48000 RATE = 48000
audio_buffer = AudioFifo() audio_buffer = AudioFifo()
executor = ThreadPoolExecutor() executor = ThreadPoolExecutor()
transcription_text = ""
llm = GPT4All("/Users/gokulmohanarangan/Library/Application Support/nomic.ai/GPT4All/ggml-vicuna-13b-1.1-q4_2.bin")
def get_title_and_summary():
global transcription_text
output = None
if len(transcription_text) > 1000:
print("Generating title and summary")
prompt = f"""
### Human:
Create a JSON object having 2 fields: title and summary. For the title field generate a short title for the given
text and for the summary field, summarize the given text by creating 3 key points.
{transcription_text}
### Assistant:
"""
transcription_text = ""
output = llm.generate(prompt)
return str(output)
return output
def channel_log(channel, t, message): def channel_log(channel, t, message):
@@ -34,8 +57,8 @@ def channel_log(channel, t, message):
def channel_send(channel, message): def channel_send(channel, message):
# channel_log(channel, ">", message) # channel_log(channel, ">", message)
if channel: if channel and message:
channel.send(message) channel.send(str(message))
def get_transcription(frames): def get_transcription(frames):
@@ -50,9 +73,9 @@ def get_transcription(frames):
wf.writeframes(b"".join(frame.to_ndarray())) wf.writeframes(b"".join(frame.to_ndarray()))
wf.close() wf.close()
whisper_result = pipeline(out_file.getvalue(), return_timestamps=True) whisper_result = pipeline(out_file.getvalue(), return_timestamps=True)
with open("test_exec.txt", "a") as f: # whisper_result['start_time'] = [f.time for f in frames]
f.write(whisper_result["text"]) global transcription_text
whisper_result['start_time'] = [f.time for f in frames] transcription_text += whisper_result["text"]
return whisper_result return whisper_result
@@ -75,9 +98,15 @@ class AudioStreamTrack(MediaStreamTrack):
get_transcription, local_frames, executor=executor get_transcription, local_frames, executor=executor
) )
whisper_result.add_done_callback( whisper_result.add_done_callback(
lambda f: channel_send(data_channel, lambda f: channel_send(data_channel, whisper_result.result())
str(whisper_result.result())) if f.result()
if (f.result()) else None
)
llm_result = run_in_executor(get_title_and_summary,
executor=executor)
llm_result.add_done_callback(
lambda f: channel_send(data_channel, llm_result.result())
if f.result()
else None else None
) )
return frame return frame

View File

@@ -11,10 +11,7 @@ from aiortc import (RTCPeerConnection, RTCSessionDescription)
from aiortc.contrib.media import (MediaPlayer, MediaRelay) from aiortc.contrib.media import (MediaPlayer, MediaRelay)
from utils.log_utils import logger from utils.log_utils import logger
from utils.run_utils import config, Mutex from utils.run_utils import config
file_lock = Mutex(open("test_sm_6.txt", "a"))
class StreamClient: class StreamClient:
def __init__( def __init__(
@@ -146,10 +143,7 @@ class StreamClient:
async def worker(self, name, queue): async def worker(self, name, queue):
while True: while True:
msg = await self.queue.get() msg = await self.queue.get()
msg = ast.literal_eval(msg) yield msg
with file_lock.lock() as file:
file.write(msg["text"])
yield msg["text"]
self.queue.task_done() self.queue.task_done()
async def start(self): async def start(self):