add data classes and typing

This commit is contained in:
Gokul Mohanarangan
2023-07-25 22:55:17 +05:30
parent 81680796cd
commit b892fc0562
2 changed files with 158 additions and 59 deletions

119
reflector_dataclasses.py Normal file
View File

@@ -0,0 +1,119 @@
import datetime
from dataclasses import dataclass
from typing import List
import av
@dataclass
class TitleSummaryInput:
input_text = str
transcribed_time = float
prompt = str
data = dict
def __init__(self, transcribed_time, input_text=""):
self.input_text = input_text
self.transcribed_time = transcribed_time
self.prompt = f"""
### Human:
Create a JSON object as response. The JSON object must have 2 fields:
i) title and ii) summary. For the title field,generate a short title
for the given text. For the summary field, summarize the given text
in three sentences.
{self.input_text}
### Assistant:
"""
self.data = {"data": self.prompt}
self.headers = {"Content-Type": "application/json"}
@dataclass
class IncrementalResponse:
title = str
description = str
transcript = str
def __init__(self, title, desc, transcript):
self.title = title
self.description = desc
self.transcript = transcript
@dataclass
class TitleSummaryOutput:
cmd = str
topics = List[IncrementalResponse]
def __init__(self, inc_responses):
self.topics = inc_responses
def get_response(self):
return {
"cmd": self.cmd,
"topics": self.topics
}
@dataclass
class ParseLLMResult:
description = str
transcript = str
timestamp = str
def __init__(self, param: TitleSummaryInput, output: dict):
self.transcript = param.input_text
self.description = output.pop("summary")
self.timestamp = str(datetime.timedelta(seconds=round(param.transcribed_time)))
def get_result(self):
return {
"description": self.description,
"transcript": self.transcript,
"timestamp": self.timestamp
}
@dataclass
class TranscriptionInput:
frames = List[av.audio.frame.AudioFrame]
def __init__(self, frames):
self.frames = frames
@dataclass
class TranscriptionOutput:
cmd = str
result_text = str
def __init__(self, result_text):
self.cmd = "SHOW_TRANSCRIPTION"
self.result_text = result_text
def get_response(self):
return {
"cmd": self.cmd,
"text": self.result_text
}
@dataclass
class FinalSummaryResponse:
cmd = str
final_summary = str
duration = str
def __init__(self, final_summary, time):
self.duration = str(datetime.timedelta(seconds=round(time)))
self.final_summary = final_summary
self.cmd = ""
def get_response(self):
return {
"cmd": self.cmd,
"duration": self.duration,
"summary": self.final_summary
}

View File

@@ -6,6 +6,7 @@ import os
import uuid
import wave
from concurrent.futures import ThreadPoolExecutor
from typing import Any
import aiohttp_cors
import requests
@@ -17,7 +18,9 @@ from faster_whisper import WhisperModel
from loguru import logger
from sortedcontainers import SortedDict
from utils.run_utils import run_in_executor, config
from reflector_dataclasses import FinalSummaryResponse, ParseLLMResult, TitleSummaryInput, TitleSummaryOutput, \
TranscriptionInput, TranscriptionOutput
from utils.run_utils import config, run_in_executor
pcs = set()
relay = MediaRelay()
@@ -43,49 +46,30 @@ blacklisted_messages = [" Thank you.", " See you next time!",
" And that's what I'm talking about."]
def get_title_and_summary(llm_input_text, last_timestamp):
def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Any[None, ParseLLMResult]:
try:
output = json.loads(response.json()["results"][0]["text"])
return ParseLLMResult(param, output).get_result()
except Exception as e:
logger.info("Exception" + str(e))
return None
def get_title_and_summary(param: TitleSummaryInput) -> Any[None, TitleSummaryOutput]:
logger.info("Generating title and summary")
# output = llm.generate(prompt)
# Use monadical-ml to fire this query to an LLM and get result
headers = {
"Content-Type": "application/json"
}
prompt = f"""
### Human:
Create a JSON object as response. The JSON object must have 2 fields:
i) title and ii) summary. For the title field,generate a short title
for the given text. For the summary field, summarize the given text
in three sentences.
{llm_input_text}
### Assistant:
"""
data = {
"prompt": prompt
}
# TODO : Handle unexpected output formats from the model
try:
response = requests.post(LLM_URL, headers=headers, json=data)
output = json.loads(response.json()["results"][0]["text"])
output["description"] = output.pop("summary")
output["transcript"] = llm_input_text
output["timestamp"] = \
str(datetime.timedelta(seconds=round(last_timestamp)))
incremental_responses.append(output)
result = {
"cmd": "UPDATE_TOPICS",
"topics": incremental_responses,
}
response = requests.post(LLM_URL,
headers=param.headers,
json=param.data)
output = parse_llm_output(param, response)
if output:
incremental_responses.append(output)
return TitleSummaryOutput(incremental_responses).get_response()
except Exception as e:
logger.info("Exception" + str(e))
result = None
return result
return None
def channel_log(channel, t, message):
@@ -123,11 +107,11 @@ def channel_send_transcript(channel):
pass
def get_transcription(frames):
def get_transcription(input_frames: TranscriptionInput) -> Any[None, TranscriptionOutput]:
logger.info("Transcribing..")
sorted_transcripts[frames[0].time] = None
sorted_transcripts[input_frames[0].time] = None
# TODO:
# TODO: Find cleaner way, watch "no transcription" issue below
# Passing IO objects instead of temporary files throws an error
# Passing ndarrays (typecasted with float) does not give any
# transcription. Refer issue,
@@ -138,7 +122,7 @@ def get_transcription(frames):
wf.setframerate(RATE)
wf.setsampwidth(2)
for frame in frames:
for frame in input_frames.frames:
wf.writeframes(b"".join(frame.to_ndarray()))
wf.close()
@@ -173,30 +157,23 @@ def get_transcription(frames):
logger.info("Exception" + str(e))
pass
result = {
"cmd": "SHOW_TRANSCRIPTION",
"text": result_text
}
sorted_transcripts[frames[0].time] = result
result = TranscriptionOutput(result_text).get_response()
sorted_transcripts[input_frames.frames[0].time] = result
return result
def get_final_summary_response():
def get_final_summary_response() -> Any[None, FinalSummaryResponse]:
final_summary = ""
# Collate inc summaries
for topic in incremental_responses:
final_summary += topic["description"]
response = {
"cmd": "DISPLAY_FINAL_SUMMARY",
"duration": str(datetime.timedelta(
seconds=round(last_transcribed_time))),
"summary": final_summary
}
response = FinalSummaryResponse(final_summary, last_transcribed_time).get_response()
with open("./artefacts/meeting_titles_and_summaries.txt", "a") as f:
f.write(json.dumps(incremental_responses))
return response
@@ -218,7 +195,9 @@ class AudioStreamTrack(MediaStreamTrack):
if local_frames := audio_buffer.read_many(256 * 960, partial=False):
whisper_result = run_in_executor(
get_transcription, local_frames, executor=executor
get_transcription,
TranscriptionInput(local_frames),
executor=executor
)
whisper_result.add_done_callback(
lambda f: channel_send_transcript(data_channel)
@@ -226,12 +205,13 @@ class AudioStreamTrack(MediaStreamTrack):
else None
)
if len(transcription_text) > 750:
if len(transcription_text) > 25:
llm_input_text = transcription_text
transcription_text = ""
param = TitleSummaryInput(input_text=llm_input_text,
transcribed_time=last_transcribed_time)
llm_result = run_in_executor(get_title_and_summary,
llm_input_text,
last_transcribed_time,
param,
executor=executor)
llm_result.add_done_callback(
lambda f: channel_send_increment(data_channel,
@@ -332,4 +312,4 @@ if __name__ == "__main__":
offer_resource = cors.add(app.router.add_resource("/offer"))
cors.add(offer_resource.add_route("POST", offer))
app.on_shutdown.append(on_shutdown)
web.run_app(app, access_log=None, host=args.host, port=args.port)
web.run_app(app, access_log=None, host=args.host, port=args.port)