mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
Issues 44, 46, 47
This commit is contained in:
@@ -6,7 +6,7 @@ import os
|
||||
import uuid
|
||||
import wave
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Union, NoReturn
|
||||
from typing import NoReturn, Union
|
||||
|
||||
import aiohttp_cors
|
||||
import av
|
||||
@@ -17,33 +17,50 @@ from aiortc.contrib.media import MediaRelay
|
||||
from faster_whisper import WhisperModel
|
||||
from sortedcontainers import SortedDict
|
||||
|
||||
from reflector_dataclasses import FinalSummaryResult, ParseLLMResult,\
|
||||
TitleSummaryInput, TitleSummaryOutput, TranscriptionInput,\
|
||||
TranscriptionOutput, BlackListedMessages
|
||||
from utils.run_utils import CONFIG, run_in_executor
|
||||
from reflector_dataclasses import BlackListedMessages, FinalSummaryResult, ParseLLMResult, TitleSummaryInput, \
|
||||
TitleSummaryOutput, TranscriptionInput, TranscriptionOutput
|
||||
from utils.log_utils import LOGGER
|
||||
from utils.run_utils import CONFIG, run_in_executor
|
||||
|
||||
# WebRTC components
|
||||
pcs = set()
|
||||
relay = MediaRelay()
|
||||
data_channel = None
|
||||
audio_buffer = av.AudioFifo()
|
||||
executor = ThreadPoolExecutor()
|
||||
|
||||
# Transcription model
|
||||
model = WhisperModel("tiny", device="cpu",
|
||||
compute_type="float32",
|
||||
num_workers=12)
|
||||
|
||||
CHANNELS = 2
|
||||
RATE = 48000
|
||||
audio_buffer = av.AudioFifo()
|
||||
executor = ThreadPoolExecutor()
|
||||
# Audio configurations
|
||||
CHANNELS = int(CONFIG["AUDIO"]["CHANNELS"])
|
||||
RATE = int(CONFIG["AUDIO"]["SAMPLING_RATE"])
|
||||
|
||||
# Global vars
|
||||
transcription_text = ""
|
||||
last_transcribed_time = 0.0
|
||||
|
||||
# LLM
|
||||
LLM_MACHINE_IP = CONFIG["LLM"]["LLM_MACHINE_IP"]
|
||||
LLM_MACHINE_PORT = CONFIG["LLM"]["LLM_MACHINE_PORT"]
|
||||
LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate"
|
||||
|
||||
# Topic and summary responses
|
||||
incremental_responses = []
|
||||
|
||||
# To synchronize the thread results before returning to the client
|
||||
sorted_transcripts = SortedDict()
|
||||
|
||||
|
||||
def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Union[None, ParseLLMResult]:
|
||||
"""
|
||||
Function to parse the LLM response
|
||||
:param param:
|
||||
:param response:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
output = json.loads(response.json()["results"][0]["text"])
|
||||
return ParseLLMResult(param, output)
|
||||
@@ -53,6 +70,12 @@ def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> U
|
||||
|
||||
|
||||
def get_title_and_summary(param: TitleSummaryInput) -> Union[None, TitleSummaryOutput]:
|
||||
"""
|
||||
From the input provided (transcript), query the LLM to generate
|
||||
topics and summaries
|
||||
:param param:
|
||||
:return:
|
||||
"""
|
||||
LOGGER.info("Generating title and summary")
|
||||
|
||||
# TODO : Handle unexpected output formats from the model
|
||||
@@ -71,21 +94,45 @@ def get_title_and_summary(param: TitleSummaryInput) -> Union[None, TitleSummaryO
|
||||
|
||||
|
||||
def channel_log(channel, t: str, message: str) -> None:
    """
    Write an info-level log line tagged with the data channel's label
    :param channel: data channel object; only its ``label`` attribute is read
    :param t: short marker placed between the channel label and the message
    :param message: text to log
    :return: None
    """
    # Annotated ``None`` rather than ``NoReturn``: the function returns
    # normally, whereas ``NoReturn`` marks callables that never return.
    LOGGER.info("channel(%s) %s %s" % (channel.label, t, message))
|
||||
|
||||
|
||||
def channel_send(channel, message: str) -> None:
    """
    Send a text message via the data channel
    :param channel: object exposing ``send``; a falsy value (e.g. ``None``)
                    turns the call into a no-op
    :param message: text passed unchanged to ``channel.send``
    :return: None
    """
    # Annotated ``None`` rather than ``NoReturn``: the function returns
    # normally, whereas ``NoReturn`` marks callables that never return.
    if channel:
        channel.send(message)
|
||||
|
||||
|
||||
def channel_send_increment(channel, param: Union[FinalSummaryResult, TitleSummaryOutput]) -> None:
    """
    Send the incremental topics and summaries via the data channel
    :param channel: object exposing ``send``; nothing is sent when it is falsy
    :param param: result object whose ``get_result()`` value is serialised
                  with ``json.dumps``; nothing is sent when it is falsy
    :return: None
    """
    # Annotated ``None`` rather than ``NoReturn``: the function returns
    # normally, whereas ``NoReturn`` marks callables that never return.
    if channel and param:
        message = param.get_result()
        channel.send(json.dumps(message))
|
||||
|
||||
|
||||
def channel_send_transcript(channel) -> NoReturn:
|
||||
"""
|
||||
Send the transcription result via the data channel
|
||||
:param channel:
|
||||
:return:
|
||||
"""
|
||||
# channel_log(channel, ">", message)
|
||||
if channel:
|
||||
try:
|
||||
@@ -106,6 +153,12 @@ def channel_send_transcript(channel) -> NoReturn:
|
||||
|
||||
|
||||
def get_transcription(input_frames: TranscriptionInput) -> Union[None, TranscriptionOutput]:
|
||||
"""
|
||||
From the collected audio frames create transcription by inferring from
|
||||
the chosen transcription model
|
||||
:param input_frames:
|
||||
:return:
|
||||
"""
|
||||
LOGGER.info("Transcribing..")
|
||||
sorted_transcripts[input_frames.frames[0].time] = None
|
||||
|
||||
@@ -290,6 +343,12 @@ async def offer(request: requests.Request) -> web.Response:
|
||||
|
||||
|
||||
async def on_shutdown(application: web.Application) -> NoReturn:
|
||||
"""
|
||||
On shutdown, the coroutines that shutdown client connections are
|
||||
executed
|
||||
:param application:
|
||||
:return:
|
||||
"""
|
||||
coroutines = [pc.close() for pc in pcs]
|
||||
await asyncio.gather(*coroutines)
|
||||
pcs.clear()
|
||||
|
||||
Reference in New Issue
Block a user