mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: reformat whole project using black
This commit is contained in:
126
server/server.py
126
server/server.py
@@ -17,8 +17,15 @@ from aiortc.contrib.media import MediaRelay
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
from reflector_dataclasses import (
|
||||
BlackListedMessages, FinalSummaryResult, ParseLLMResult, TitleSummaryInput,
|
||||
TitleSummaryOutput, TranscriptionInput, TranscriptionOutput, TranscriptionContext)
|
||||
BlackListedMessages,
|
||||
FinalSummaryResult,
|
||||
ParseLLMResult,
|
||||
TitleSummaryInput,
|
||||
TitleSummaryOutput,
|
||||
TranscriptionInput,
|
||||
TranscriptionOutput,
|
||||
TranscriptionContext,
|
||||
)
|
||||
from utils.log_utils import LOGGER
|
||||
from utils.run_utils import CONFIG, run_in_executor, SECRETS
|
||||
|
||||
@@ -28,9 +35,7 @@ relay = MediaRelay()
|
||||
executor = ThreadPoolExecutor()
|
||||
|
||||
# Transcription model
|
||||
model = WhisperModel("tiny", device="cpu",
|
||||
compute_type="float32",
|
||||
num_workers=12)
|
||||
model = WhisperModel("tiny", device="cpu", compute_type="float32", num_workers=12)
|
||||
|
||||
# Audio configurations
|
||||
CHANNELS = int(CONFIG["AUDIO"]["CHANNELS"])
|
||||
@@ -46,7 +51,10 @@ else:
|
||||
LLM_MACHINE_PORT = CONFIG["LLM"]["LLM_MACHINE_PORT"]
|
||||
LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate"
|
||||
|
||||
def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Union[None, ParseLLMResult]:
|
||||
|
||||
def parse_llm_output(
|
||||
param: TitleSummaryInput, response: requests.Response
|
||||
) -> Union[None, ParseLLMResult]:
|
||||
"""
|
||||
Function to parse the LLM response
|
||||
:param param:
|
||||
@@ -61,7 +69,9 @@ def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> U
|
||||
return None
|
||||
|
||||
|
||||
def get_title_and_summary(ctx: TranscriptionContext, param: TitleSummaryInput) -> Union[None, TitleSummaryOutput]:
|
||||
def get_title_and_summary(
|
||||
ctx: TranscriptionContext, param: TitleSummaryInput
|
||||
) -> Union[None, TitleSummaryOutput]:
|
||||
"""
|
||||
From the input provided (transcript), query the LLM to generate
|
||||
topics and summaries
|
||||
@@ -72,9 +82,7 @@ def get_title_and_summary(ctx: TranscriptionContext, param: TitleSummaryInput) -
|
||||
|
||||
# TODO : Handle unexpected output formats from the model
|
||||
try:
|
||||
response = requests.post(LLM_URL,
|
||||
headers=param.headers,
|
||||
json=param.data)
|
||||
response = requests.post(LLM_URL, headers=param.headers, json=param.data)
|
||||
output = parse_llm_output(param, response)
|
||||
if output:
|
||||
result = output.get_result()
|
||||
@@ -107,7 +115,9 @@ def channel_send(channel, message: str) -> NoReturn:
|
||||
channel.send(message)
|
||||
|
||||
|
||||
def channel_send_increment(channel, param: Union[FinalSummaryResult, TitleSummaryOutput]) -> NoReturn:
|
||||
def channel_send_increment(
|
||||
channel, param: Union[FinalSummaryResult, TitleSummaryOutput]
|
||||
) -> NoReturn:
|
||||
"""
|
||||
Send the incremental topics and summaries via the data channel
|
||||
:param channel:
|
||||
@@ -145,7 +155,9 @@ def channel_send_transcript(ctx: TranscriptionContext) -> NoReturn:
|
||||
LOGGER.info("Exception", str(exception))
|
||||
|
||||
|
||||
def get_transcription(ctx: TranscriptionContext, input_frames: TranscriptionInput) -> Union[None, TranscriptionOutput]:
|
||||
def get_transcription(
|
||||
ctx: TranscriptionContext, input_frames: TranscriptionInput
|
||||
) -> Union[None, TranscriptionOutput]:
|
||||
"""
|
||||
From the collected audio frames create transcription by inferring from
|
||||
the chosen transcription model
|
||||
@@ -173,12 +185,13 @@ def get_transcription(ctx: TranscriptionContext, input_frames: TranscriptionInpu
|
||||
result_text = ""
|
||||
|
||||
try:
|
||||
segments, _ = \
|
||||
model.transcribe(audio_file,
|
||||
language="en",
|
||||
beam_size=5,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500})
|
||||
segments, _ = model.transcribe(
|
||||
audio_file,
|
||||
language="en",
|
||||
beam_size=5,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
os.remove(audio_file)
|
||||
segments = list(segments)
|
||||
result_text = ""
|
||||
@@ -191,7 +204,7 @@ def get_transcription(ctx: TranscriptionContext, input_frames: TranscriptionInpu
|
||||
start_time = 0.0
|
||||
if not segment.end:
|
||||
end_time = 5.5
|
||||
duration += (end_time - start_time)
|
||||
duration += end_time - start_time
|
||||
|
||||
ctx.last_transcribed_time += duration
|
||||
ctx.transcription_text += result_text
|
||||
@@ -218,8 +231,9 @@ def get_final_summary_response(ctx: TranscriptionContext) -> FinalSummaryResult:
|
||||
|
||||
response = FinalSummaryResult(final_summary, ctx.last_transcribed_time)
|
||||
|
||||
with open("./artefacts/meeting_titles_and_summaries.txt", "a",
|
||||
encoding="utf-8") as file:
|
||||
with open(
|
||||
"./artefacts/meeting_titles_and_summaries.txt", "a", encoding="utf-8"
|
||||
) as file:
|
||||
file.write(json.dumps(ctx.incremental_responses))
|
||||
|
||||
return response
|
||||
@@ -243,33 +257,32 @@ class AudioStreamTrack(MediaStreamTrack):
|
||||
frame = await self.track.recv()
|
||||
self.audio_buffer.write(frame)
|
||||
|
||||
if local_frames := self.audio_buffer.read_many(AUDIO_BUFFER_SIZE, partial=False):
|
||||
if local_frames := self.audio_buffer.read_many(
|
||||
AUDIO_BUFFER_SIZE, partial=False
|
||||
):
|
||||
whisper_result = run_in_executor(
|
||||
get_transcription,
|
||||
ctx,
|
||||
TranscriptionInput(local_frames),
|
||||
executor=executor
|
||||
get_transcription,
|
||||
ctx,
|
||||
TranscriptionInput(local_frames),
|
||||
executor=executor,
|
||||
)
|
||||
whisper_result.add_done_callback(
|
||||
lambda f: channel_send_transcript(ctx)
|
||||
if f.result()
|
||||
else None
|
||||
lambda f: channel_send_transcript(ctx) if f.result() else None
|
||||
)
|
||||
|
||||
if len(ctx.transcription_text) > 25:
|
||||
llm_input_text = ctx.transcription_text
|
||||
ctx.transcription_text = ""
|
||||
param = TitleSummaryInput(input_text=llm_input_text,
|
||||
transcribed_time=ctx.last_transcribed_time)
|
||||
llm_result = run_in_executor(get_title_and_summary,
|
||||
ctx,
|
||||
param,
|
||||
executor=executor)
|
||||
param = TitleSummaryInput(
|
||||
input_text=llm_input_text, transcribed_time=ctx.last_transcribed_time
|
||||
)
|
||||
llm_result = run_in_executor(
|
||||
get_title_and_summary, ctx, param, executor=executor
|
||||
)
|
||||
llm_result.add_done_callback(
|
||||
lambda f: channel_send_increment(ctx.data_channel,
|
||||
llm_result.result())
|
||||
if f.result()
|
||||
else None
|
||||
lambda f: channel_send_increment(ctx.data_channel, llm_result.result())
|
||||
if f.result()
|
||||
else None
|
||||
)
|
||||
return frame
|
||||
|
||||
@@ -328,13 +341,10 @@ async def offer(request: requests.Request) -> web.Response:
|
||||
answer = await pc.createAnswer()
|
||||
await pc.setLocalDescription(answer)
|
||||
return web.Response(
|
||||
content_type="application/json",
|
||||
text=json.dumps(
|
||||
{
|
||||
"sdp": pc.localDescription.sdp,
|
||||
"type": pc.localDescription.type
|
||||
}
|
||||
),
|
||||
content_type="application/json",
|
||||
text=json.dumps(
|
||||
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -351,26 +361,22 @@ async def on_shutdown(application: web.Application) -> NoReturn:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="WebRTC based server for Reflector"
|
||||
parser = argparse.ArgumentParser(description="WebRTC based server for Reflector")
|
||||
parser.add_argument(
|
||||
"--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=1250, help="Server port (def: 1250)"
|
||||
"--port", type=int, default=1250, help="Server port (def: 1250)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
app = web.Application()
|
||||
cors = aiohttp_cors.setup(
|
||||
app,
|
||||
defaults={
|
||||
"*": aiohttp_cors.ResourceOptions(
|
||||
allow_credentials=True,
|
||||
expose_headers="*",
|
||||
allow_headers="*"
|
||||
)
|
||||
},
|
||||
app,
|
||||
defaults={
|
||||
"*": aiohttp_cors.ResourceOptions(
|
||||
allow_credentials=True, expose_headers="*", allow_headers="*"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
offer_resource = cors.add(app.router.add_resource("/offer"))
|
||||
|
||||
Reference in New Issue
Block a user