diff --git a/server_executor_cleaned.py b/server_executor_cleaned.py index e032da69..55909402 100644 --- a/server_executor_cleaned.py +++ b/server_executor_cleaned.py @@ -15,8 +15,8 @@ from aiortc.contrib.media import MediaRelay from av import AudioFifo from loguru import logger from whisper_jax import FlaxWhisperPipline - from utils.run_utils import run_in_executor +from sortedcontainers import SortedDict pcs = set() relay = MediaRelay() @@ -34,9 +34,11 @@ last_transcribed_time = 0.0 LLM_MACHINE_IP = "216.153.52.83" LLM_MACHINE_PORT = "5000" LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate" +incremental_responses = [] +sorted_transcripts = SortedDict() -def get_title_and_summary(llm_input_text): +def get_title_and_summary(llm_input_text, last_timestamp): print("Generating title and summary") # output = llm.generate(prompt) @@ -47,8 +49,10 @@ def get_title_and_summary(llm_input_text): prompt = f""" ### Human: - Create a JSON object as response. The JSON object must have 2 fields: i) title and ii) summary. For the title field, - generate a short title for the given text. For the summary field, summarize the given text in three sentences. + Create a JSON object as response. The JSON object must have 2 fields: + i) title and ii) summary. For the title field,generate a short title + for the given text. For the summary field, summarize the given text + in three sentences. {llm_input_text} @@ -59,14 +63,22 @@ def get_title_and_summary(llm_input_text): "prompt": prompt } + # To-do: Handle unexpected output formats from the model try: response = requests.post(LLM_URL, headers=headers, json=data) output = json.loads(response.json()["results"][0]["text"]) output["description"] = output.pop("summary") + output["timestamp"] =\ + str(datetime.timedelta(seconds=round(last_timestamp))) + incremental_responses.append(output) + result = { + "cmd": "UPDATE_TOPICS", + "topics": incremental_responses, + } except Exception as e: - print(str(e)) - output = None - return output + print("Exception" + str(e)) + result = None + return result def channel_log(channel, t, message): @@ -74,13 +86,32 @@ def channel_log(channel, t, message): def channel_send(channel, message): - # channel_log(channel, ">", message) + if channel: + channel.send(message) + + +def channel_send_increment(channel, message): if channel and message: channel.send(json.dumps(message)) +def channel_send_transcript(channel): + # channel_log(channel, ">", message) + if channel: + try: + least_time = sorted_transcripts.keys()[0] + message = sorted_transcripts[least_time] + if message: + del sorted_transcripts[least_time] + channel.send(json.dumps(message)) + except Exception as e: + print("Exception", str(e)) + pass + + def get_transcription(frames): print("Transcribing..") + sorted_transcripts[frames[0].time] = None out_file = io.BytesIO() wf = wave.open(out_file, "wb") wf.setnchannels(CHANNELS) @@ -105,10 +136,10 @@ def get_transcription(frames): last_transcribed_time += duration result = { - "text": whisper_result["text"], - "timestamp": str(datetime.timedelta(seconds= - round(last_transcribed_time))) + "cmd": "SHOW_TRANSCRIPTION", + "text": whisper_result["text"] } + sorted_transcripts[frames[0].time] = result return result @@ -133,22 +164,24 @@ class AudioStreamTrack(MediaStreamTrack): get_transcription, local_frames, executor=executor ) whisper_result.add_done_callback( - lambda f: channel_send(data_channel, whisper_result.result()) + lambda f: channel_send_transcript(data_channel) if f.result() else None ) - if len(transcription_text) > 2000: - llm_input_text = transcription_text - transcription_text = "" - llm_result = run_in_executor(get_title_and_summary, - llm_input_text, - executor=executor) - llm_result.add_done_callback( - lambda f: channel_send(data_channel, llm_result.result()) - if f.result() - else None - ) + if len(transcription_text) > 100: + llm_input_text = transcription_text + transcription_text = "" + llm_result = run_in_executor(get_title_and_summary, + llm_input_text, + last_transcribed_time, + executor=executor) + llm_result.add_done_callback( + lambda f: channel_send_increment(data_channel, + llm_result.result()) + if f.result() + else None + ) return frame @@ -215,7 +248,9 @@ if __name__ == "__main__": app, defaults={ "*": aiohttp_cors.ResourceOptions( - allow_credentials=True, expose_headers="*", allow_headers="*" + allow_credentials=True, + expose_headers="*", + allow_headers="*" ) }, )