server: reformat whole project using black

Mathieu Virbel
2023-07-27 14:08:41 +02:00
parent 314321c603
commit 094ed696c4
12 changed files with 406 additions and 237 deletions
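A reformat like this is normally produced by running black over the whole source tree ("black .") and verified in CI with "black --check --diff .". As a small illustrative sketch (not part of the commit), the snippet below uses black's Python API, black.format_str with black.Mode and the default 88-character line length, to reproduce the kind of rewrite shown in the hunks that follow, here on the WhisperModel call:

# Illustrative sketch only: reproduce black's rewrite of one statement from this diff.
# Assumes black is installed (pip install black) and default settings (line length 88).
import black

OLD_SRC = (
    'model = WhisperModel("tiny", device="cpu",\n'
    '                     compute_type="float32",\n'
    '                     num_workers=12)\n'
)

formatted = black.format_str(OLD_SRC, mode=black.Mode(line_length=88))
print(formatted)
# Expected output (one line, matching the "+" side of the corresponding hunk below):
# model = WhisperModel("tiny", device="cpu", compute_type="float32", num_workers=12)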

@@ -17,8 +17,15 @@ from aiortc.contrib.media import MediaRelay
 from faster_whisper import WhisperModel
 from reflector_dataclasses import (
-    BlackListedMessages, FinalSummaryResult, ParseLLMResult, TitleSummaryInput,
-    TitleSummaryOutput, TranscriptionInput, TranscriptionOutput, TranscriptionContext)
+    BlackListedMessages,
+    FinalSummaryResult,
+    ParseLLMResult,
+    TitleSummaryInput,
+    TitleSummaryOutput,
+    TranscriptionInput,
+    TranscriptionOutput,
+    TranscriptionContext,
+)
 from utils.log_utils import LOGGER
 from utils.run_utils import CONFIG, run_in_executor, SECRETS
@@ -28,9 +35,7 @@ relay = MediaRelay()
 executor = ThreadPoolExecutor()

 # Transcription model
-model = WhisperModel("tiny", device="cpu",
-                     compute_type="float32",
-                     num_workers=12)
+model = WhisperModel("tiny", device="cpu", compute_type="float32", num_workers=12)

 # Audio configurations
 CHANNELS = int(CONFIG["AUDIO"]["CHANNELS"])
@@ -46,7 +51,10 @@ else:
     LLM_MACHINE_PORT = CONFIG["LLM"]["LLM_MACHINE_PORT"]
 LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate"


-def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Union[None, ParseLLMResult]:
+def parse_llm_output(
+    param: TitleSummaryInput, response: requests.Response
+) -> Union[None, ParseLLMResult]:
     """
     Function to parse the LLM response
     :param param:
@@ -61,7 +69,9 @@ def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> U
     return None


-def get_title_and_summary(ctx: TranscriptionContext, param: TitleSummaryInput) -> Union[None, TitleSummaryOutput]:
+def get_title_and_summary(
+    ctx: TranscriptionContext, param: TitleSummaryInput
+) -> Union[None, TitleSummaryOutput]:
     """
     From the input provided (transcript), query the LLM to generate
     topics and summaries
@@ -72,9 +82,7 @@ def get_title_and_summary(ctx: TranscriptionContext, param: TitleSummaryInput) -
     # TODO : Handle unexpected output formats from the model
     try:
-        response = requests.post(LLM_URL,
-                                 headers=param.headers,
-                                 json=param.data)
+        response = requests.post(LLM_URL, headers=param.headers, json=param.data)
         output = parse_llm_output(param, response)
         if output:
             result = output.get_result()
@@ -107,7 +115,9 @@ def channel_send(channel, message: str) -> NoReturn:
     channel.send(message)


-def channel_send_increment(channel, param: Union[FinalSummaryResult, TitleSummaryOutput]) -> NoReturn:
+def channel_send_increment(
+    channel, param: Union[FinalSummaryResult, TitleSummaryOutput]
+) -> NoReturn:
     """
     Send the incremental topics and summaries via the data channel
     :param channel:
@@ -145,7 +155,9 @@ def channel_send_transcript(ctx: TranscriptionContext) -> NoReturn:
         LOGGER.info("Exception", str(exception))


-def get_transcription(ctx: TranscriptionContext, input_frames: TranscriptionInput) -> Union[None, TranscriptionOutput]:
+def get_transcription(
+    ctx: TranscriptionContext, input_frames: TranscriptionInput
+) -> Union[None, TranscriptionOutput]:
     """
     From the collected audio frames create transcription by inferring from
     the chosen transcription model
@@ -173,12 +185,13 @@ def get_transcription(ctx: TranscriptionContext, input_frames: TranscriptionInpu
     result_text = ""
     try:
-        segments, _ = \
-            model.transcribe(audio_file,
-                             language="en",
-                             beam_size=5,
-                             vad_filter=True,
-                             vad_parameters={"min_silence_duration_ms": 500})
+        segments, _ = model.transcribe(
+            audio_file,
+            language="en",
+            beam_size=5,
+            vad_filter=True,
+            vad_parameters={"min_silence_duration_ms": 500},
+        )
         os.remove(audio_file)
         segments = list(segments)
         result_text = ""
@@ -191,7 +204,7 @@ def get_transcription(ctx: TranscriptionContext, input_frames: TranscriptionInpu
                 start_time = 0.0
             if not segment.end:
                 end_time = 5.5
-            duration += (end_time - start_time)
+            duration += end_time - start_time
         ctx.last_transcribed_time += duration
         ctx.transcription_text += result_text
@@ -218,8 +231,9 @@ def get_final_summary_response(ctx: TranscriptionContext) -> FinalSummaryResult:
     response = FinalSummaryResult(final_summary, ctx.last_transcribed_time)
-    with open("./artefacts/meeting_titles_and_summaries.txt", "a",
-              encoding="utf-8") as file:
+    with open(
+        "./artefacts/meeting_titles_and_summaries.txt", "a", encoding="utf-8"
+    ) as file:
         file.write(json.dumps(ctx.incremental_responses))
     return response
@@ -243,33 +257,32 @@ class AudioStreamTrack(MediaStreamTrack):
         frame = await self.track.recv()
         self.audio_buffer.write(frame)
-        if local_frames := self.audio_buffer.read_many(AUDIO_BUFFER_SIZE, partial=False):
+        if local_frames := self.audio_buffer.read_many(
+            AUDIO_BUFFER_SIZE, partial=False
+        ):
             whisper_result = run_in_executor(
-                get_transcription,
-                ctx,
-                TranscriptionInput(local_frames),
-                executor=executor
+                get_transcription,
+                ctx,
+                TranscriptionInput(local_frames),
+                executor=executor,
             )
             whisper_result.add_done_callback(
-                lambda f: channel_send_transcript(ctx)
-                if f.result()
-                else None
+                lambda f: channel_send_transcript(ctx) if f.result() else None
             )
             if len(ctx.transcription_text) > 25:
                 llm_input_text = ctx.transcription_text
                 ctx.transcription_text = ""
-                param = TitleSummaryInput(input_text=llm_input_text,
-                                          transcribed_time=ctx.last_transcribed_time)
-                llm_result = run_in_executor(get_title_and_summary,
-                                             ctx,
-                                             param,
-                                             executor=executor)
+                param = TitleSummaryInput(
+                    input_text=llm_input_text, transcribed_time=ctx.last_transcribed_time
+                )
+                llm_result = run_in_executor(
+                    get_title_and_summary, ctx, param, executor=executor
+                )
                 llm_result.add_done_callback(
-                    lambda f: channel_send_increment(ctx.data_channel,
-                                                     llm_result.result())
-                    if f.result()
-                    else None
+                    lambda f: channel_send_increment(ctx.data_channel, llm_result.result())
+                    if f.result()
+                    else None
                 )
         return frame
@@ -328,13 +341,10 @@ async def offer(request: requests.Request) -> web.Response:
     answer = await pc.createAnswer()
     await pc.setLocalDescription(answer)

     return web.Response(
-        content_type="application/json",
-        text=json.dumps(
-            {
-                "sdp": pc.localDescription.sdp,
-                "type": pc.localDescription.type
-            }
-        ),
+        content_type="application/json",
+        text=json.dumps(
+            {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
+        ),
     )
@@ -351,26 +361,22 @@ async def on_shutdown(application: web.Application) -> NoReturn:
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description="WebRTC based server for Reflector"
-    )
+    parser = argparse.ArgumentParser(description="WebRTC based server for Reflector")
-    parser.add_argument(
-        "--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)"
-    )
+    parser.add_argument(
+        "--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)"
+    )
     parser.add_argument(
-        "--port", type=int, default=1250, help="Server port (def: 1250)"
+        "--port", type=int, default=1250, help="Server port (def: 1250)"
     )
     args = parser.parse_args()

     app = web.Application()
     cors = aiohttp_cors.setup(
-        app,
-        defaults={
-            "*": aiohttp_cors.ResourceOptions(
-                allow_credentials=True,
-                expose_headers="*",
-                allow_headers="*"
-            )
-        },
+        app,
+        defaults={
+            "*": aiohttp_cors.ResourceOptions(
+                allow_credentials=True, expose_headers="*", allow_headers="*"
+            )
+        },
     )
     offer_resource = cors.add(app.router.add_resource("/offer"))