mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
add data classes and typing
This commit is contained in:
119
reflector_dataclasses.py
Normal file
119
reflector_dataclasses.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import datetime
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
import av
|
||||
|
||||
|
||||
@dataclass
class TitleSummaryInput:
    """Request parameters for generating a title and summary with the LLM.

    Holds the transcribed text and its timestamp, plus the prompt, JSON
    body and HTTP headers derived from them.

    Fields are declared as annotations (``name: type``) so that
    ``@dataclass`` registers them; the hand-written ``__init__`` is kept
    because ``prompt``, ``data`` and ``headers`` are computed, not passed in.
    """

    input_text: str
    transcribed_time: float
    prompt: str
    data: dict
    headers: dict

    def __init__(self, transcribed_time, input_text=""):
        """Build the prompt and request payload.

        Args:
            transcribed_time: Timestamp (seconds) associated with the text.
            input_text: Transcribed text to be titled and summarised.
        """
        self.input_text = input_text
        self.transcribed_time = transcribed_time
        self.prompt = (
            "\n"
            "### Human:\n"
            "Create a JSON object as response. The JSON object must have 2 fields:\n"
            "i) title and ii) summary. For the title field,generate a short title\n"
            "for the given text. For the summary field, summarize the given text\n"
            "in three sentences.\n"
            "\n"
            f"{self.input_text}\n"
            "\n"
            "### Assistant:\n"
        )
        # The request body carries the prompt under the "prompt" key, as the
        # inline payload in server.py did before this class replaced it.
        # NOTE(review): confirm the key against the LLM endpoint's API.
        self.data = {"prompt": self.prompt}
        self.headers = {"Content-Type": "application/json"}
|
||||
|
||||
|
||||
@dataclass
class IncrementalResponse:
    """One incremental topic entry: a title, a description and a transcript.

    Fields are declared as annotations so ``@dataclass`` registers them and
    the generated ``__repr__``/``__eq__`` reflect the actual values.
    """

    title: str
    description: str
    transcript: str

    def __init__(self, title, desc, transcript):
        """Store the topic fields.

        Args:
            title: Title of the topic.
            desc: Stored as ``description``.
            transcript: Transcript text for the topic.
        """
        self.title = title
        self.description = desc
        self.transcript = transcript
|
||||
|
||||
|
||||
@dataclass
class TitleSummaryOutput:
    """Response wrapper carrying the accumulated topics for the client."""

    cmd: str
    # NOTE(review): server.py appends plain dicts (ParseLLMResult.get_result())
    # to the list passed in here, not IncrementalResponse objects -- confirm
    # which element type is intended.
    topics: List["IncrementalResponse"]

    def __init__(self, inc_responses):
        """Store the topics and set the command name.

        Args:
            inc_responses: List of topic entries produced so far.
        """
        # Without this assignment ``self.cmd`` resolved to a class-level
        # placeholder (the ``str`` type itself), which cannot be serialised.
        # "UPDATE_TOPICS" is the command the inline response in server.py
        # used before this class replaced it.
        self.cmd = "UPDATE_TOPICS"
        self.topics = inc_responses

    def get_response(self):
        """Return the response as a plain dict with ``cmd`` and ``topics``."""
        return {
            "cmd": self.cmd,
            "topics": self.topics,
        }
|
||||
|
||||
|
||||
@dataclass
class ParseLLMResult:
    """Topic entry assembled from the LLM's JSON output and the request input."""

    title: str
    description: str
    transcript: str
    timestamp: str

    def __init__(self, param: "TitleSummaryInput", output: dict):
        """Combine the LLM output with the originating request.

        Args:
            param: The request input; ``input_text`` becomes the transcript
                and ``transcribed_time`` (seconds) becomes the timestamp.
            output: Parsed LLM JSON. Must contain ``"summary"``; ``"title"``
                is carried through when present.

        Raises:
            KeyError: If ``output`` has no ``"summary"`` entry.
        """
        # The prompt asks the model for both a title and a summary; keep the
        # title so it is not lost from the topic entry.
        self.title = output.get("title", "")
        self.transcript = param.input_text
        # Read rather than pop, so the caller's dict is left untouched.
        self.description = output["summary"]
        # Rendered as H:MM:SS from whole seconds.
        self.timestamp = str(
            datetime.timedelta(seconds=round(param.transcribed_time))
        )

    def get_result(self):
        """Return the entry as a plain dict."""
        return {
            "title": self.title,
            "description": self.description,
            "transcript": self.transcript,
            "timestamp": self.timestamp,
        }
|
||||
|
||||
|
||||
@dataclass
class TranscriptionInput:
    """Wrapper around the audio frames handed to the transcription step."""

    # Quoted so the annotation is not evaluated at class-creation time.
    frames: "List[av.audio.frame.AudioFrame]"

    def __init__(self, frames):
        """Store the frames.

        Args:
            frames: Sequence of audio frames to transcribe.
        """
        self.frames = frames

    def __getitem__(self, index):
        """Index into ``frames``.

        server.py's ``get_transcription`` indexes the wrapper directly
        (``input_frames[0].time``); delegating keeps that call working.
        """
        return self.frames[index]
|
||||
|
||||
|
||||
@dataclass
class TranscriptionOutput:
    """Response wrapper for a piece of transcribed text."""

    cmd: str
    result_text: str

    def __init__(self, result_text):
        """Store the text and set the command name.

        Args:
            result_text: The transcribed text.
        """
        self.cmd = "SHOW_TRANSCRIPTION"
        self.result_text = result_text

    def get_response(self):
        """Return the response as a plain dict with ``cmd`` and ``text``."""
        return {
            "cmd": self.cmd,
            "text": self.result_text,
        }
|
||||
|
||||
|
||||
@dataclass
class FinalSummaryResponse:
    """Response wrapper for the final summary and its duration."""

    cmd: str
    final_summary: str
    duration: str

    def __init__(self, final_summary, time):
        """Store the summary and format the duration.

        Args:
            final_summary: The collated summary text.
            time: Elapsed time in seconds; rounded and rendered as H:MM:SS.
        """
        self.duration = str(datetime.timedelta(seconds=round(time)))
        self.final_summary = final_summary
        # "DISPLAY_FINAL_SUMMARY" is the command the inline response in
        # server.py used before this class replaced it; an empty command
        # would leave the receiver with nothing to dispatch on.
        self.cmd = "DISPLAY_FINAL_SUMMARY"

    def get_response(self):
        """Return the response as a dict with ``cmd``, ``duration``, ``summary``."""
        return {
            "cmd": self.cmd,
            "duration": self.duration,
            "summary": self.final_summary,
        }
|
||||
98
server.py
98
server.py
@@ -6,6 +6,7 @@ import os
|
||||
import uuid
|
||||
import wave
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
import aiohttp_cors
|
||||
import requests
|
||||
@@ -17,7 +18,9 @@ from faster_whisper import WhisperModel
|
||||
from loguru import logger
|
||||
from sortedcontainers import SortedDict
|
||||
|
||||
from utils.run_utils import run_in_executor, config
|
||||
from reflector_dataclasses import FinalSummaryResponse, ParseLLMResult, TitleSummaryInput, TitleSummaryOutput, \
|
||||
TranscriptionInput, TranscriptionOutput
|
||||
from utils.run_utils import config, run_in_executor
|
||||
|
||||
pcs = set()
|
||||
relay = MediaRelay()
|
||||
@@ -43,49 +46,30 @@ blacklisted_messages = [" Thank you.", " See you next time!",
|
||||
" And that's what I'm talking about."]
|
||||
|
||||
|
||||
def get_title_and_summary(llm_input_text, last_timestamp):
|
||||
def parse_llm_output(
    param: TitleSummaryInput, response: requests.Response
) -> "dict | None":
    """Turn the LLM HTTP response into a topic dict.

    The generated text is read from ``results[0]["text"]`` of the response
    JSON, parsed as JSON itself, and combined with ``param`` through
    ``ParseLLMResult``.

    Args:
        param: The request input that produced ``response``.
        response: The HTTP response returned by the LLM endpoint.

    Returns:
        The dict from ``ParseLLMResult.get_result()``, or ``None`` when the
        response could not be parsed (the error is logged).
    """
    # The return annotation is a string because ``typing.Any`` cannot be
    # subscripted; ``Any[...]`` raises TypeError as soon as the ``def`` runs.
    try:
        output = json.loads(response.json()["results"][0]["text"])
        return ParseLLMResult(param, output).get_result()
    except Exception as e:
        # Best effort: a malformed model reply must not break the caller.
        logger.info("Exception" + str(e))
        return None
|
||||
|
||||
|
||||
def get_title_and_summary(param: TitleSummaryInput) -> Any[None, TitleSummaryOutput]:
|
||||
logger.info("Generating title and summary")
|
||||
# output = llm.generate(prompt)
|
||||
|
||||
# Use monadical-ml to fire this query to an LLM and get result
|
||||
headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
prompt = f"""
|
||||
### Human:
|
||||
Create a JSON object as response. The JSON object must have 2 fields:
|
||||
i) title and ii) summary. For the title field,generate a short title
|
||||
for the given text. For the summary field, summarize the given text
|
||||
in three sentences.
|
||||
|
||||
{llm_input_text}
|
||||
|
||||
### Assistant:
|
||||
"""
|
||||
|
||||
data = {
|
||||
"prompt": prompt
|
||||
}
|
||||
|
||||
# TODO : Handle unexpected output formats from the model
|
||||
try:
|
||||
response = requests.post(LLM_URL, headers=headers, json=data)
|
||||
output = json.loads(response.json()["results"][0]["text"])
|
||||
output["description"] = output.pop("summary")
|
||||
output["transcript"] = llm_input_text
|
||||
output["timestamp"] = \
|
||||
str(datetime.timedelta(seconds=round(last_timestamp)))
|
||||
incremental_responses.append(output)
|
||||
result = {
|
||||
"cmd": "UPDATE_TOPICS",
|
||||
"topics": incremental_responses,
|
||||
}
|
||||
|
||||
response = requests.post(LLM_URL,
|
||||
headers=param.headers,
|
||||
json=param.data)
|
||||
output = parse_llm_output(param, response)
|
||||
if output:
|
||||
incremental_responses.append(output)
|
||||
return TitleSummaryOutput(incremental_responses).get_response()
|
||||
except Exception as e:
|
||||
logger.info("Exception" + str(e))
|
||||
result = None
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def channel_log(channel, t, message):
|
||||
@@ -123,11 +107,11 @@ def channel_send_transcript(channel):
|
||||
pass
|
||||
|
||||
|
||||
def get_transcription(frames):
|
||||
def get_transcription(input_frames: TranscriptionInput) -> Any[None, TranscriptionOutput]:
|
||||
logger.info("Transcribing..")
|
||||
sorted_transcripts[frames[0].time] = None
|
||||
sorted_transcripts[input_frames[0].time] = None
|
||||
|
||||
# TODO:
|
||||
# TODO: Find cleaner way, watch "no transcription" issue below
|
||||
# Passing IO objects instead of temporary files throws an error
|
||||
# Passing ndarrays (typecasted with float) does not give any
|
||||
# transcription. Refer issue,
|
||||
@@ -138,7 +122,7 @@ def get_transcription(frames):
|
||||
wf.setframerate(RATE)
|
||||
wf.setsampwidth(2)
|
||||
|
||||
for frame in frames:
|
||||
for frame in input_frames.frames:
|
||||
wf.writeframes(b"".join(frame.to_ndarray()))
|
||||
wf.close()
|
||||
|
||||
@@ -173,30 +157,23 @@ def get_transcription(frames):
|
||||
logger.info("Exception" + str(e))
|
||||
pass
|
||||
|
||||
result = {
|
||||
"cmd": "SHOW_TRANSCRIPTION",
|
||||
"text": result_text
|
||||
}
|
||||
sorted_transcripts[frames[0].time] = result
|
||||
result = TranscriptionOutput(result_text).get_response()
|
||||
sorted_transcripts[input_frames.frames[0].time] = result
|
||||
return result
|
||||
|
||||
|
||||
def get_final_summary_response():
|
||||
def get_final_summary_response() -> Any[None, FinalSummaryResponse]:
|
||||
final_summary = ""
|
||||
|
||||
# Collate inc summaries
|
||||
for topic in incremental_responses:
|
||||
final_summary += topic["description"]
|
||||
|
||||
response = {
|
||||
"cmd": "DISPLAY_FINAL_SUMMARY",
|
||||
"duration": str(datetime.timedelta(
|
||||
seconds=round(last_transcribed_time))),
|
||||
"summary": final_summary
|
||||
}
|
||||
response = FinalSummaryResponse(final_summary, last_transcribed_time).get_response()
|
||||
|
||||
with open("./artefacts/meeting_titles_and_summaries.txt", "a") as f:
|
||||
f.write(json.dumps(incremental_responses))
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@@ -218,7 +195,9 @@ class AudioStreamTrack(MediaStreamTrack):
|
||||
|
||||
if local_frames := audio_buffer.read_many(256 * 960, partial=False):
|
||||
whisper_result = run_in_executor(
|
||||
get_transcription, local_frames, executor=executor
|
||||
get_transcription,
|
||||
TranscriptionInput(local_frames),
|
||||
executor=executor
|
||||
)
|
||||
whisper_result.add_done_callback(
|
||||
lambda f: channel_send_transcript(data_channel)
|
||||
@@ -226,12 +205,13 @@ class AudioStreamTrack(MediaStreamTrack):
|
||||
else None
|
||||
)
|
||||
|
||||
if len(transcription_text) > 750:
|
||||
if len(transcription_text) > 25:
|
||||
llm_input_text = transcription_text
|
||||
transcription_text = ""
|
||||
param = TitleSummaryInput(input_text=llm_input_text,
|
||||
transcribed_time=last_transcribed_time)
|
||||
llm_result = run_in_executor(get_title_and_summary,
|
||||
llm_input_text,
|
||||
last_transcribed_time,
|
||||
param,
|
||||
executor=executor)
|
||||
llm_result.add_done_callback(
|
||||
lambda f: channel_send_increment(data_channel,
|
||||
@@ -332,4 +312,4 @@ if __name__ == "__main__":
|
||||
offer_resource = cors.add(app.router.add_resource("/offer"))
|
||||
cors.add(offer_resource.add_route("POST", offer))
|
||||
app.on_shutdown.append(on_shutdown)
|
||||
web.run_app(app, access_log=None, host=args.host, port=args.port)
|
||||
web.run_app(app, access_log=None, host=args.host, port=args.port)
|
||||
|
||||
Reference in New Issue
Block a user