From 094ed696c4450f427a96a24028d14f5e0d08f897 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 27 Jul 2023 14:08:41 +0200 Subject: [PATCH] server: reformat whole project using black --- server/client.py | 36 ++++---- server/poetry.lock | 97 +++++++++++++++++++- server/pyproject.toml | 3 + server/reflector_dataclasses.py | 48 +++++----- server/server.py | 126 +++++++++++++------------- server/stream_client.py | 25 ++---- server/utils/file_utils.py | 10 ++- server/utils/format_output.py | 17 ++-- server/utils/log_utils.py | 1 + server/utils/run_utils.py | 5 +- server/utils/text_utils.py | 121 ++++++++++++++----------- server/utils/viz_utils.py | 154 ++++++++++++++++++++++---------- 12 files changed, 406 insertions(+), 237 deletions(-) diff --git a/server/client.py b/server/client.py index 519ccc26..aa89934e 100644 --- a/server/client.py +++ b/server/client.py @@ -2,13 +2,13 @@ import argparse import asyncio import signal -from aiortc.contrib.signaling import (add_signaling_arguments, - create_signaling) +from aiortc.contrib.signaling import add_signaling_arguments, create_signaling from utils.log_utils import LOGGER from stream_client import StreamClient from typing import NoReturn + async def main() -> NoReturn: """ Reflector's entry point to the python client for WebRTC streaming if not @@ -18,21 +18,21 @@ async def main() -> NoReturn: parser = argparse.ArgumentParser(description="Data channels ping/pong") parser.add_argument( - "--url", type=str, nargs="?", default="http://0.0.0.0:1250/offer" + "--url", type=str, nargs="?", default="http://0.0.0.0:1250/offer" ) parser.add_argument( - "--ping-pong", - help="Benchmark data channel with ping pong", - type=eval, - choices=[True, False], - default="False", + "--ping-pong", + help="Benchmark data channel with ping pong", + type=eval, + choices=[True, False], + default="False", ) parser.add_argument( - "--play-from", - type=str, - default="", + "--play-from", + type=str, + default="", ) add_signaling_arguments(parser) @@ -45,8 +45,7 @@ async def main() -> NoReturn: LOGGER.info(f"Received exit signal {signal.name}...") LOGGER.info("Closing database connections") LOGGER.info("Nacking outstanding messages") - tasks = [t for t in asyncio.all_tasks() if t is not - asyncio.current_task()] + tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()] [task.cancel() for task in tasks] @@ -58,15 +57,14 @@ async def main() -> NoReturn: signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) loop = asyncio.get_event_loop() for s in signals: - loop.add_signal_handler( - s, lambda s=s: asyncio.create_task(shutdown(s, loop))) + loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown(s, loop))) # Init client sc = StreamClient( - signaling=signaling, - url=args.url, - play_from=args.play_from, - ping_pong=args.ping_pong + signaling=signaling, + url=args.url, + play_from=args.play_from, + ping_pong=args.ping_pong, ) await sc.start() async for msg in sc.get_reader(): diff --git a/server/poetry.lock b/server/poetry.lock index 74420b26..56e0e183 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -325,6 +325,50 @@ files = [ {file = "av-10.0.0.tar.gz", hash = "sha256:8afd3d5610e1086f3b2d8389d66672ea78624516912c93612de64dcaa4c67e05"}, ] +[[package]] +name = "black" +version = "23.7.0" +description = "The uncompromising code formatter." +optional = false +python-versions = ">=3.8" +files = [ + {file = "black-23.7.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:5c4bc552ab52f6c1c506ccae05681fab58c3f72d59ae6e6639e8885e94fe2587"}, + {file = "black-23.7.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:552513d5cd5694590d7ef6f46e1767a4df9af168d449ff767b13b084c020e63f"}, + {file = "black-23.7.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:86cee259349b4448adb4ef9b204bb4467aae74a386bce85d56ba4f5dc0da27be"}, + {file = "black-23.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:501387a9edcb75d7ae8a4412bb8749900386eaef258f1aefab18adddea1936bc"}, + {file = "black-23.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb074d8b213749fa1d077d630db0d5f8cc3b2ae63587ad4116e8a436e9bbe995"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b5b0ee6d96b345a8b420100b7d71ebfdd19fab5e8301aff48ec270042cd40ac2"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:893695a76b140881531062d48476ebe4a48f5d1e9388177e175d76234ca247cd"}, + {file = "black-23.7.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:c333286dc3ddca6fdff74670b911cccedacb4ef0a60b34e491b8a67c833b343a"}, + {file = "black-23.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831d8f54c3a8c8cf55f64d0422ee875eecac26f5f649fb6c1df65316b67c8926"}, + {file = "black-23.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:7f3bf2dec7d541b4619b8ce526bda74a6b0bffc480a163fed32eb8b3c9aed8ad"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:f9062af71c59c004cd519e2fb8f5d25d39e46d3af011b41ab43b9c74e27e236f"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:01ede61aac8c154b55f35301fac3e730baf0c9cf8120f65a9cd61a81cfb4a0c3"}, + {file = "black-23.7.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:327a8c2550ddc573b51e2c352adb88143464bb9d92c10416feb86b0f5aee5ff6"}, + {file = "black-23.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1c6022b86f83b632d06f2b02774134def5d4d4f1dac8bef16d90cda18ba28a"}, + {file = "black-23.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:27eb7a0c71604d5de083757fbdb245b1a4fae60e9596514c6ec497eb63f95320"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:8417dbd2f57b5701492cd46edcecc4f9208dc75529bcf76c514864e48da867d9"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:47e56d83aad53ca140da0af87678fb38e44fd6bc0af71eebab2d1f59b1acf1d3"}, + {file = "black-23.7.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:25cc308838fe71f7065df53aedd20327969d05671bac95b38fdf37ebe70ac087"}, + {file = "black-23.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:642496b675095d423f9b8448243336f8ec71c9d4d57ec17bf795b67f08132a91"}, + {file = "black-23.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:ad0014efc7acf0bd745792bd0d8857413652979200ab924fbf239062adc12491"}, + {file = "black-23.7.0-py3-none-any.whl", hash = "sha256:9fd59d418c60c0348505f2ddf9609c1e1de8e7493eab96198fc89d9f865e7a96"}, + {file = "black-23.7.0.tar.gz", hash = "sha256:022a582720b0d9480ed82576c920a8c1dde97cc38ff11d8d8859b3bd6ca9eedb"}, +] + +[package.dependencies] +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.7.4)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + [[package]] name = "certifi" version = "2023.7.22" @@ -496,6 +540,20 @@ files = [ {file = "charset_normalizer-3.2.0-py3-none-any.whl", hash = "sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6"}, ] +[[package]] +name = "click" +version = "8.1.6" +description = "Composable command line interface toolkit" +optional = false +python-versions = ">=3.7" +files = [ + {file = "click-8.1.6-py3-none-any.whl", hash = "sha256:fa244bb30b3b5ee2cae3da8f55c9e5e0c0e86093306301fb418eb9dc40fbded5"}, + {file = "click-8.1.6.tar.gz", hash = "sha256:48ee849951919527a045bfe3bf7baa8a959c423134e1a5b98c05c20ba75a1cbd"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + [[package]] name = "colorama" version = "0.4.6" @@ -1080,6 +1138,17 @@ files = [ {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, ] +[[package]] +name = "mypy-extensions" +version = "1.0.0" +description = "Type system extensions for programs checked with the mypy type checker." +optional = false +python-versions = ">=3.5" +files = [ + {file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"}, + {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, +] + [[package]] name = "numpy" version = "1.25.1" @@ -1166,6 +1235,32 @@ files = [ {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, ] +[[package]] +name = "pathspec" +version = "0.11.1" +description = "Utility library for gitignore style pattern matching of file paths." +optional = false +python-versions = ">=3.7" +files = [ + {file = "pathspec-0.11.1-py3-none-any.whl", hash = "sha256:d8af70af76652554bd134c22b3e8a1cc46ed7d91edcdd721ef1a0c51a84a5293"}, + {file = "pathspec-0.11.1.tar.gz", hash = "sha256:2798de800fa92780e33acca925945e9a19a133b715067cf165b8866c15a31687"}, +] + +[[package]] +name = "platformdirs" +version = "3.9.1" +description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"." +optional = false +python-versions = ">=3.7" +files = [ + {file = "platformdirs-3.9.1-py3-none-any.whl", hash = "sha256:ad8291ae0ae5072f66c16945166cb11c63394c7a3ad1b1bc9828ca3162da8c2f"}, + {file = "platformdirs-3.9.1.tar.gz", hash = "sha256:1b42b450ad933e981d56e59f1b97495428c9bd60698baab9f3eb3d00d5822421"}, +] + +[package.extras] +docs = ["furo (>=2023.5.20)", "proselint (>=0.13)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)"] + [[package]] name = "protobuf" version = "4.23.4" @@ -1619,4 +1714,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "9b82606318ce1096923c0b25e5b3a6b07292f24465611d968e78f37a26e3d212" +content-hash = "e8eb6b4f81c090adb882a1b293d81f32167ea89f4636222d43fe0e9131cb97d6" diff --git a/server/pyproject.toml b/server/pyproject.toml index a7435ffe..25f4a9ff 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -18,6 +18,9 @@ sortedcontainers = "^2.4.0" loguru = "^0.7.0" +[tool.poetry.group.dev.dependencies] +black = "^23.7.0" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" diff --git a/server/reflector_dataclasses.py b/server/reflector_dataclasses.py index c417b857..d2e91c06 100644 --- a/server/reflector_dataclasses.py +++ b/server/reflector_dataclasses.py @@ -17,6 +17,7 @@ class TitleSummaryInput: Data class for the input to generate title and summaries. The outcome will be used to send query to the LLM for processing. """ + input_text = str transcribed_time = float prompt = str @@ -25,8 +26,7 @@ class TitleSummaryInput: def __init__(self, transcribed_time, input_text=""): self.input_text = input_text self.transcribed_time = transcribed_time - self.prompt = \ - f""" + self.prompt = f""" ### Human: Create a JSON object as response.The JSON object must have 2 fields: i) title and ii) summary.For the title field,generate a short title @@ -47,6 +47,7 @@ class IncrementalResult: Data class for the result of generating one title and summaries. Defines how a single "topic" looks like. """ + title = str description = str transcript = str @@ -65,6 +66,7 @@ class TitleSummaryOutput: Data class for the result of all generated titles and summaries. The result will be sent back to the client """ + cmd = str topics = List[IncrementalResult] @@ -77,10 +79,7 @@ class TitleSummaryOutput: Return the result dict for displaying the transcription :return: """ - return { - "cmd": self.cmd, - "topics": self.topics - } + return {"cmd": self.cmd, "topics": self.topics} @dataclass @@ -89,6 +88,7 @@ class ParseLLMResult: Data class to parse the result returned by the LLM while generating title and summaries. The result will be sent back to the client. """ + title = str description = str transcript = str @@ -98,8 +98,7 @@ class ParseLLMResult: self.title = output["title"] self.transcript = param.input_text self.description = output.pop("summary") - self.timestamp = \ - str(datetime.timedelta(seconds=round(param.transcribed_time))) + self.timestamp = str(datetime.timedelta(seconds=round(param.transcribed_time))) def get_result(self) -> dict: """ @@ -107,10 +106,10 @@ class ParseLLMResult: :return: """ return { - "title": self.title, - "description": self.description, - "transcript": self.transcript, - "timestamp": self.timestamp + "title": self.title, + "description": self.description, + "transcript": self.transcript, + "timestamp": self.timestamp, } @@ -120,6 +119,7 @@ class TranscriptionInput: Data class to define the input to the transcription function AudioFrames -> input """ + frames = List[av.audio.frame.AudioFrame] def __init__(self, frames): @@ -132,6 +132,7 @@ class TranscriptionOutput: Dataclass to define the result of the transcription function. The result will be sent back to the client """ + cmd = str result_text = str @@ -144,10 +145,7 @@ class TranscriptionOutput: Return the result dict for displaying the transcription :return: """ - return { - "cmd": self.cmd, - "text": self.result_text - } + return {"cmd": self.cmd, "text": self.result_text} @dataclass @@ -156,6 +154,7 @@ class FinalSummaryResult: Dataclass to define the result of the final summary function. The result will be sent back to the client. """ + cmd = str final_summary = str duration = str @@ -171,9 +170,9 @@ class FinalSummaryResult: :return: """ return { - "cmd": self.cmd, - "duration": self.duration, - "summary": self.final_summary + "cmd": self.cmd, + "duration": self.duration, + "summary": self.final_summary, } @@ -182,9 +181,14 @@ class BlackListedMessages: Class to hold the blacklisted messages. These messages should be filtered out and not sent back to the client as part of the transcription. """ - messages = [" Thank you.", " See you next time!", - " Thank you for watching!", " Bye!", - " And that's what I'm talking about."] + + messages = [ + " Thank you.", + " See you next time!", + " Thank you for watching!", + " Bye!", + " And that's what I'm talking about.", + ] @dataclass diff --git a/server/server.py b/server/server.py index c72e28b6..b3939f43 100644 --- a/server/server.py +++ b/server/server.py @@ -17,8 +17,15 @@ from aiortc.contrib.media import MediaRelay from faster_whisper import WhisperModel from reflector_dataclasses import ( - BlackListedMessages, FinalSummaryResult, ParseLLMResult, TitleSummaryInput, - TitleSummaryOutput, TranscriptionInput, TranscriptionOutput, TranscriptionContext) + BlackListedMessages, + FinalSummaryResult, + ParseLLMResult, + TitleSummaryInput, + TitleSummaryOutput, + TranscriptionInput, + TranscriptionOutput, + TranscriptionContext, +) from utils.log_utils import LOGGER from utils.run_utils import CONFIG, run_in_executor, SECRETS @@ -28,9 +35,7 @@ relay = MediaRelay() executor = ThreadPoolExecutor() # Transcription model -model = WhisperModel("tiny", device="cpu", - compute_type="float32", - num_workers=12) +model = WhisperModel("tiny", device="cpu", compute_type="float32", num_workers=12) # Audio configurations CHANNELS = int(CONFIG["AUDIO"]["CHANNELS"]) @@ -46,7 +51,10 @@ else: LLM_MACHINE_PORT = CONFIG["LLM"]["LLM_MACHINE_PORT"] LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate" -def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Union[None, ParseLLMResult]: + +def parse_llm_output( + param: TitleSummaryInput, response: requests.Response +) -> Union[None, ParseLLMResult]: """ Function to parse the LLM response :param param: @@ -61,7 +69,9 @@ def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> U return None -def get_title_and_summary(ctx: TranscriptionContext, param: TitleSummaryInput) -> Union[None, TitleSummaryOutput]: +def get_title_and_summary( + ctx: TranscriptionContext, param: TitleSummaryInput +) -> Union[None, TitleSummaryOutput]: """ From the input provided (transcript), query the LLM to generate topics and summaries @@ -72,9 +82,7 @@ def get_title_and_summary(ctx: TranscriptionContext, param: TitleSummaryInput) - # TODO : Handle unexpected output formats from the model try: - response = requests.post(LLM_URL, - headers=param.headers, - json=param.data) + response = requests.post(LLM_URL, headers=param.headers, json=param.data) output = parse_llm_output(param, response) if output: result = output.get_result() @@ -107,7 +115,9 @@ def channel_send(channel, message: str) -> NoReturn: channel.send(message) -def channel_send_increment(channel, param: Union[FinalSummaryResult, TitleSummaryOutput]) -> NoReturn: +def channel_send_increment( + channel, param: Union[FinalSummaryResult, TitleSummaryOutput] +) -> NoReturn: """ Send the incremental topics and summaries via the data channel :param channel: @@ -145,7 +155,9 @@ def channel_send_transcript(ctx: TranscriptionContext) -> NoReturn: LOGGER.info("Exception", str(exception)) -def get_transcription(ctx: TranscriptionContext, input_frames: TranscriptionInput) -> Union[None, TranscriptionOutput]: +def get_transcription( + ctx: TranscriptionContext, input_frames: TranscriptionInput +) -> Union[None, TranscriptionOutput]: """ From the collected audio frames create transcription by inferring from the chosen transcription model @@ -173,12 +185,13 @@ def get_transcription(ctx: TranscriptionContext, input_frames: TranscriptionInpu result_text = "" try: - segments, _ = \ - model.transcribe(audio_file, - language="en", - beam_size=5, - vad_filter=True, - vad_parameters={"min_silence_duration_ms": 500}) + segments, _ = model.transcribe( + audio_file, + language="en", + beam_size=5, + vad_filter=True, + vad_parameters={"min_silence_duration_ms": 500}, + ) os.remove(audio_file) segments = list(segments) result_text = "" @@ -191,7 +204,7 @@ def get_transcription(ctx: TranscriptionContext, input_frames: TranscriptionInpu start_time = 0.0 if not segment.end: end_time = 5.5 - duration += (end_time - start_time) + duration += end_time - start_time ctx.last_transcribed_time += duration ctx.transcription_text += result_text @@ -218,8 +231,9 @@ def get_final_summary_response(ctx: TranscriptionContext) -> FinalSummaryResult: response = FinalSummaryResult(final_summary, ctx.last_transcribed_time) - with open("./artefacts/meeting_titles_and_summaries.txt", "a", - encoding="utf-8") as file: + with open( + "./artefacts/meeting_titles_and_summaries.txt", "a", encoding="utf-8" + ) as file: file.write(json.dumps(ctx.incremental_responses)) return response @@ -243,33 +257,32 @@ class AudioStreamTrack(MediaStreamTrack): frame = await self.track.recv() self.audio_buffer.write(frame) - if local_frames := self.audio_buffer.read_many(AUDIO_BUFFER_SIZE, partial=False): + if local_frames := self.audio_buffer.read_many( + AUDIO_BUFFER_SIZE, partial=False + ): whisper_result = run_in_executor( - get_transcription, - ctx, - TranscriptionInput(local_frames), - executor=executor + get_transcription, + ctx, + TranscriptionInput(local_frames), + executor=executor, ) whisper_result.add_done_callback( - lambda f: channel_send_transcript(ctx) - if f.result() - else None + lambda f: channel_send_transcript(ctx) if f.result() else None ) if len(ctx.transcription_text) > 25: llm_input_text = ctx.transcription_text ctx.transcription_text = "" - param = TitleSummaryInput(input_text=llm_input_text, - transcribed_time=ctx.last_transcribed_time) - llm_result = run_in_executor(get_title_and_summary, - ctx, - param, - executor=executor) + param = TitleSummaryInput( + input_text=llm_input_text, transcribed_time=ctx.last_transcribed_time + ) + llm_result = run_in_executor( + get_title_and_summary, ctx, param, executor=executor + ) llm_result.add_done_callback( - lambda f: channel_send_increment(ctx.data_channel, - llm_result.result()) - if f.result() - else None + lambda f: channel_send_increment(ctx.data_channel, llm_result.result()) + if f.result() + else None ) return frame @@ -328,13 +341,10 @@ async def offer(request: requests.Request) -> web.Response: answer = await pc.createAnswer() await pc.setLocalDescription(answer) return web.Response( - content_type="application/json", - text=json.dumps( - { - "sdp": pc.localDescription.sdp, - "type": pc.localDescription.type - } - ), + content_type="application/json", + text=json.dumps( + {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} + ), ) @@ -351,26 +361,22 @@ async def on_shutdown(application: web.Application) -> NoReturn: if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="WebRTC based server for Reflector" + parser = argparse.ArgumentParser(description="WebRTC based server for Reflector") + parser.add_argument( + "--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)" ) parser.add_argument( - "--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)" - ) - parser.add_argument( - "--port", type=int, default=1250, help="Server port (def: 1250)" + "--port", type=int, default=1250, help="Server port (def: 1250)" ) args = parser.parse_args() app = web.Application() cors = aiohttp_cors.setup( - app, - defaults={ - "*": aiohttp_cors.ResourceOptions( - allow_credentials=True, - expose_headers="*", - allow_headers="*" - ) - }, + app, + defaults={ + "*": aiohttp_cors.ResourceOptions( + allow_credentials=True, expose_headers="*", allow_headers="*" + ) + }, ) offer_resource = cors.add(app.router.add_resource("/offer")) diff --git a/server/stream_client.py b/server/stream_client.py index 57f88229..4ec688bc 100644 --- a/server/stream_client.py +++ b/server/stream_client.py @@ -6,8 +6,8 @@ import httpx import pyaudio import requests import stamina -from aiortc import (RTCPeerConnection, RTCSessionDescription) -from aiortc.contrib.media import (MediaPlayer, MediaRelay) +from aiortc import RTCPeerConnection, RTCSessionDescription +from aiortc.contrib.media import MediaPlayer, MediaRelay from utils.log_utils import LOGGER from utils.run_utils import CONFIG @@ -15,11 +15,7 @@ from utils.run_utils import CONFIG class StreamClient: def __init__( - self, - signaling, - url="http://0.0.0.0:1250", - play_from=None, - ping_pong=False + self, signaling, url="http://0.0.0.0:1250", play_from=None, ping_pong=False ): self.signaling = signaling self.server_url = url @@ -35,9 +31,10 @@ class StreamClient: self.time_start = None self.queue = asyncio.Queue() self.player = MediaPlayer( - ':' + str(CONFIG['AUDIO']["AV_FOUNDATION_DEVICE_ID"]), - format='avfoundation', - options={'channels': '2'}) + ":" + str(CONFIG["AUDIO"]["AV_FOUNDATION_DEVICE_ID"]), + format="avfoundation", + options={"channels": "2"}, + ) def stop(self): self.loop.run_until_complete(self.signaling.close()) @@ -114,16 +111,12 @@ class StreamClient: self.channel_log(channel, "<", message) if isinstance(message, str) and message.startswith("pong"): - elapsed_ms = (self.current_stamp() - int(message[5:])) \ - / 1000 + elapsed_ms = (self.current_stamp() - int(message[5:])) / 1000 print(" RTT %.2f ms" % elapsed_ms) await pc.setLocalDescription(await pc.createOffer()) - sdp = { - "sdp": pc.localDescription.sdp, - "type": pc.localDescription.type - } + sdp = {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type} @stamina.retry(on=httpx.HTTPError, attempts=5) def connect_to_server(): diff --git a/server/utils/file_utils.py b/server/utils/file_utils.py index 9c85ebdc..ba9e4fec 100644 --- a/server/utils/file_utils.py +++ b/server/utils/file_utils.py @@ -14,9 +14,11 @@ from .run_utils import SECRETS BUCKET_NAME = SECRETS["AWS-S3"]["BUCKET_NAME"] -s3 = boto3.client('s3', - aws_access_key_id=SECRETS["AWS-S3"]["AWS_ACCESS_KEY"], - aws_secret_access_key=SECRETS["AWS-S3"]["AWS_SECRET_KEY"]) +s3 = boto3.client( + "s3", + aws_access_key_id=SECRETS["AWS-S3"]["AWS_ACCESS_KEY"], + aws_secret_access_key=SECRETS["AWS-S3"]["AWS_SECRET_KEY"], +) def upload_files(files_to_upload: List[str]) -> NoReturn: @@ -44,7 +46,7 @@ def download_files(files_to_download: List[str]) -> NoReturn: try: s3.download_file(BUCKET_NAME, key, key) except botocore.exceptions.ClientError as exception: - if exception.response['Error']['Code'] == "404": + if exception.response["Error"]["Code"] == "404": print("The object does not exist.") else: raise diff --git a/server/utils/format_output.py b/server/utils/format_output.py index c46b90ba..adf2ff67 100644 --- a/server/utils/format_output.py +++ b/server/utils/format_output.py @@ -4,21 +4,16 @@ Utility function to format the artefacts created during Reflector run import json -with open("../artefacts/meeting_titles_and_summaries.txt", "r", - encoding='utf-8') as f: +with open("../artefacts/meeting_titles_and_summaries.txt", "r", encoding="utf-8") as f: outputs = f.read() outputs = json.loads(outputs) -transcript_file = open("../artefacts/meeting_transcript.txt", - "a", - encoding='utf-8') -title_desc_file = open("../artefacts/meeting_title_description.txt", - "a", - encoding='utf-8') -summary_file = open("../artefacts/meeting_summary.txt", - "a", - encoding='utf-8') +transcript_file = open("../artefacts/meeting_transcript.txt", "a", encoding="utf-8") +title_desc_file = open( + "../artefacts/meeting_title_description.txt", "a", encoding="utf-8" +) +summary_file = open("../artefacts/meeting_summary.txt", "a", encoding="utf-8") for item in outputs["topics"]: transcript_file.write(item["transcript"]) diff --git a/server/utils/log_utils.py b/server/utils/log_utils.py index 84cbe3fe..6d3056ba 100644 --- a/server/utils/log_utils.py +++ b/server/utils/log_utils.py @@ -10,6 +10,7 @@ class SingletonLogger: Use Singleton design pattern to create a logger object and share it across the entire project """ + __instance = None @staticmethod diff --git a/server/utils/run_utils.py b/server/utils/run_utils.py index 4a3dba30..3eac353b 100644 --- a/server/utils/run_utils.py +++ b/server/utils/run_utils.py @@ -14,6 +14,7 @@ class ReflectorConfig: """ Create a single config object to share across the project """ + __config = None __secrets = None @@ -25,7 +26,7 @@ class ReflectorConfig: """ if ReflectorConfig.__config is None: ReflectorConfig.__config = configparser.ConfigParser() - ReflectorConfig.__config.read('utils/config.ini') + ReflectorConfig.__config.read("utils/config.ini") return ReflectorConfig.__config @staticmethod @@ -36,7 +37,7 @@ class ReflectorConfig: """ if ReflectorConfig.__secrets is None: ReflectorConfig.__secrets = configparser.ConfigParser() - ReflectorConfig.__secrets.read('utils/secrets.ini') + ReflectorConfig.__secrets.read("utils/secrets.ini") return ReflectorConfig.__secrets diff --git a/server/utils/text_utils.py b/server/utils/text_utils.py index 5bde199a..01cb671d 100644 --- a/server/utils/text_utils.py +++ b/server/utils/text_utils.py @@ -15,7 +15,7 @@ from transformers import BartForConditionalGeneration, BartTokenizer from log_utils import LOGGER from run_utils import CONFIG -nltk.download('punkt', quiet=True) +nltk.download("punkt", quiet=True) def preprocess_sentence(sentence: str) -> str: @@ -24,11 +24,10 @@ def preprocess_sentence(sentence: str) -> str: :param sentence: :return: """ - stop_words = set(stopwords.words('english')) + stop_words = set(stopwords.words("english")) tokens = word_tokenize(sentence.lower()) - tokens = [token for token in tokens - if token.isalnum() and token not in stop_words] - return ' '.join(tokens) + tokens = [token for token in tokens if token.isalnum() and token not in stop_words] + return " ".join(tokens) def compute_similarity(sent1: str, sent2: str) -> float: @@ -67,14 +66,14 @@ def remove_almost_alike_sentences(sentences: List[str], threshold=0.7) -> List[s sentence1 = preprocess_sentence(sentences[i]) sentence2 = preprocess_sentence(sentences[j]) if len(sentence1) != 0 and len(sentence2) != 0: - similarity = compute_similarity(sentence1, - sentence2) + similarity = compute_similarity(sentence1, sentence2) if similarity >= threshold: removed_indices.add(max(i, j)) - filtered_sentences = [sentences[i] for i in range(num_sentences) - if i not in removed_indices] + filtered_sentences = [ + sentences[i] for i in range(num_sentences) if i not in removed_indices + ] return filtered_sentences @@ -90,7 +89,9 @@ def remove_outright_duplicate_sentences_from_chunk(chunk: str) -> List[str]: return nonduplicate_sentences -def remove_whisper_repetitive_hallucination(nonduplicate_sentences: List[str]) -> List[str]: +def remove_whisper_repetitive_hallucination( + nonduplicate_sentences: List[str], +) -> List[str]: """ Remove sentences that are repeated as a result of Whisper hallucinations @@ -105,13 +106,16 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences: List[str]) - words = nltk.word_tokenize(sent) n_gram_filter = 3 for i in range(len(words)): - if str(words[i:i + n_gram_filter]) in seen and \ - seen[str(words[i:i + n_gram_filter])] == \ - words[i + 1:i + n_gram_filter + 2]: + if ( + str(words[i : i + n_gram_filter]) in seen + and seen[str(words[i : i + n_gram_filter])] + == words[i + 1 : i + n_gram_filter + 2] + ): pass else: - seen[str(words[i:i + n_gram_filter])] = \ - words[i + 1:i + n_gram_filter + 2] + seen[str(words[i : i + n_gram_filter])] = words[ + i + 1 : i + n_gram_filter + 2 + ] temp_result += words[i] temp_result += " " chunk_sentences.append(temp_result) @@ -126,12 +130,11 @@ def post_process_transcription(whisper_result: dict) -> dict: """ transcript_text = "" for chunk in whisper_result["chunks"]: - nonduplicate_sentences = \ - remove_outright_duplicate_sentences_from_chunk(chunk) - chunk_sentences = \ - remove_whisper_repetitive_hallucination(nonduplicate_sentences) - similarity_matched_sentences = \ - remove_almost_alike_sentences(chunk_sentences) + nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk) + chunk_sentences = remove_whisper_repetitive_hallucination( + nonduplicate_sentences + ) + similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences) chunk["text"] = " ".join(similarity_matched_sentences) transcript_text += chunk["text"] whisper_result["text"] = transcript_text @@ -149,23 +152,24 @@ def summarize_chunks(chunks: List[str], tokenizer, model) -> List[str]: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") summaries = [] for c in chunks: - input_ids = tokenizer.encode(c, return_tensors='pt') + input_ids = tokenizer.encode(c, return_tensors="pt") input_ids = input_ids.to(device) with torch.no_grad(): - summary_ids = \ - model.generate(input_ids, - num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]), - length_penalty=2.0, - max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]), - early_stopping=True) - summary = tokenizer.decode(summary_ids[0], - skip_special_tokens=True) + summary_ids = model.generate( + input_ids, + num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]), + length_penalty=2.0, + max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]), + early_stopping=True, + ) + summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True) summaries.append(summary) return summaries -def chunk_text(text: str, - max_chunk_length: int = int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])) -> List[str]: +def chunk_text( + text: str, max_chunk_length: int = int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"]) +) -> List[str]: """ Split text into smaller chunks. :param text: Text to be chunked @@ -185,9 +189,12 @@ def chunk_text(text: str, return chunks -def summarize(transcript_text: str, timestamp: datetime.datetime.timestamp, - real_time: bool = False, - chunk_summarize: str = CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]): +def summarize( + transcript_text: str, + timestamp: datetime.datetime.timestamp, + real_time: bool = False, + chunk_summarize: str = CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"], +): """ Summarize the given text either as a whole or as chunks as needed :param transcript_text: @@ -213,39 +220,45 @@ def summarize(transcript_text: str, timestamp: datetime.datetime.timestamp, if chunk_summarize != "YES": max_length = int(CONFIG["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"]) - inputs = tokenizer. \ - batch_encode_plus([transcript_text], truncation=True, - padding='longest', - max_length=max_length, - return_tensors='pt') + inputs = tokenizer.batch_encode_plus( + [transcript_text], + truncation=True, + padding="longest", + max_length=max_length, + return_tensors="pt", + ) inputs = inputs.to(device) with torch.no_grad(): num_beans = int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]) max_length = int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]) - summaries = model.generate(inputs['input_ids'], - num_beams=num_beans, - length_penalty=2.0, - max_length=max_length, - early_stopping=True) + summaries = model.generate( + inputs["input_ids"], + num_beams=num_beans, + length_penalty=2.0, + max_length=max_length, + early_stopping=True, + ) - decoded_summaries = \ - [tokenizer.decode(summary, - skip_special_tokens=True, - clean_up_tokenization_spaces=False) - for summary in summaries] + decoded_summaries = [ + tokenizer.decode( + summary, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + for summary in summaries + ] summary = " ".join(decoded_summaries) - with open("./artefacts/" + output_file, 'w', encoding="utf-8") as file: + with open("./artefacts/" + output_file, "w", encoding="utf-8") as file: file.write(summary.strip() + "\n") else: LOGGER.info("Breaking transcript into smaller chunks") chunks = chunk_text(transcript_text) - LOGGER.info(f"Transcript broken into {len(chunks)} " - f"chunks of at most 500 words") + LOGGER.info( + f"Transcript broken into {len(chunks)} " f"chunks of at most 500 words" + ) LOGGER.info(f"Writing summary text to: {output_file}") - with open(output_file, 'w') as f: + with open(output_file, "w") as f: summaries = summarize_chunks(chunks, tokenizer, model) for summary in summaries: f.write(summary.strip() + " ") diff --git a/server/utils/viz_utils.py b/server/utils/viz_utils.py index 22e2cc08..d26afdca 100644 --- a/server/utils/viz_utils.py +++ b/server/utils/viz_utils.py @@ -16,23 +16,30 @@ import spacy from nltk.corpus import stopwords from wordcloud import STOPWORDS, WordCloud -en = spacy.load('en_core_web_md') +en = spacy.load("en_core_web_md") spacy_stopwords = en.Defaults.stop_words -STOPWORDS = set(STOPWORDS).union(set(stopwords.words("english"))). \ - union(set(spacy_stopwords)) +STOPWORDS = ( + set(STOPWORDS).union(set(stopwords.words("english"))).union(set(spacy_stopwords)) +) -def create_wordcloud(timestamp: datetime.datetime.timestamp, - real_time: bool = False) -> NoReturn: +def create_wordcloud( + timestamp: datetime.datetime.timestamp, real_time: bool = False +) -> NoReturn: """ Create a basic word cloud visualization of transcribed text :return: None. The wordcloud image is saved locally """ filename = "transcript" if real_time: - filename = "real_time_" + filename + "_" + \ - timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" + filename = ( + "real_time_" + + filename + + "_" + + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + + ".txt" + ) else: filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" @@ -41,10 +48,13 @@ def create_wordcloud(timestamp: datetime.datetime.timestamp, # python_mask = np.array(PIL.Image.open("download1.png")) - wordcloud = WordCloud(height=800, width=800, - background_color='white', - stopwords=STOPWORDS, - min_font_size=8).generate(transcription_text) + wordcloud = WordCloud( + height=800, + width=800, + background_color="white", + stopwords=STOPWORDS, + min_font_size=8, + ).generate(transcription_text) # Plot wordcloud and save image plt.figure(facecolor=None) @@ -54,16 +64,22 @@ def create_wordcloud(timestamp: datetime.datetime.timestamp, wordcloud = "wordcloud" if real_time: - wordcloud = "real_time_" + wordcloud + "_" + \ - timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png" + wordcloud = ( + "real_time_" + + wordcloud + + "_" + + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + + ".png" + ) else: wordcloud += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png" plt.savefig("./artefacts/" + wordcloud) -def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp, - real_time: bool = False) -> NoReturn: +def create_talk_diff_scatter_viz( + timestamp: datetime.datetime.timestamp, real_time: bool = False +) -> NoReturn: """ Perform agenda vs transcription diff to see covered topics. Create a scatter plot of words in topics. @@ -71,7 +87,7 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp, """ spacy_model = "en_core_web_md" nlp = spacy.load(spacy_model) - nlp.add_pipe('sentencizer') + nlp.add_pipe("sentencizer") agenda_topics = [] agenda = [] @@ -84,11 +100,17 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp, # Load the transcription with timestamp if real_time: - filename = "./artefacts/real_time_transcript_with_timestamp_" + \ - timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" + filename = ( + "./artefacts/real_time_transcript_with_timestamp_" + + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + + ".txt" + ) else: - filename = "./artefacts/transcript_with_timestamp_" + \ - timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" + filename = ( + "./artefacts/transcript_with_timestamp_" + + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + + ".txt" + ) with open(filename) as file: transcription_timestamp_text = file.read() @@ -128,14 +150,20 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp, covered_items[agenda[topic_similarities[i][0]]] = True # top1 match if i == 0: - ts_to_topic_mapping_top_1[c["timestamp"]] = \ + ts_to_topic_mapping_top_1[c["timestamp"]] = agenda_topics[ + topic_similarities[i][0] + ] + topic_to_ts_mapping_top_1[ agenda_topics[topic_similarities[i][0]] - topic_to_ts_mapping_top_1[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"]) + ].append(c["timestamp"]) # top2 match else: - ts_to_topic_mapping_top_2[c["timestamp"]] = \ + ts_to_topic_mapping_top_2[c["timestamp"]] = agenda_topics[ + topic_similarities[i][0] + ] + topic_to_ts_mapping_top_2[ agenda_topics[topic_similarities[i][0]] - topic_to_ts_mapping_top_2[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"]) + ].append(c["timestamp"]) def create_new_columns(record: dict) -> dict: """ @@ -143,10 +171,12 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp, :param record: :return: """ - record["ts_to_topic_mapping_top_1"] = \ - ts_to_topic_mapping_top_1[record["timestamp"]] - record["ts_to_topic_mapping_top_2"] = \ - ts_to_topic_mapping_top_2[record["timestamp"]] + record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[ + record["timestamp"] + ] + record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[ + record["timestamp"] + ] return record df = df.apply(create_new_columns, axis=1) @@ -167,19 +197,33 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp, # Save df, mappings for further experimentation df_name = "df" if real_time: - df_name = "real_time_" + df_name + "_" + \ - timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" + df_name = ( + "real_time_" + + df_name + + "_" + + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + + ".pkl" + ) else: df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" df.to_pickle("./artefacts/" + df_name) - my_mappings = [ts_to_topic_mapping_top_1, ts_to_topic_mapping_top_2, - topic_to_ts_mapping_top_1, topic_to_ts_mapping_top_2] + my_mappings = [ + ts_to_topic_mapping_top_1, + ts_to_topic_mapping_top_2, + topic_to_ts_mapping_top_1, + topic_to_ts_mapping_top_2, + ] mappings_name = "mappings" if real_time: - mappings_name = "real_time_" + mappings_name + "_" + \ - timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" + mappings_name = ( + "real_time_" + + mappings_name + + "_" + + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + + ".pkl" + ) else: mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" pickle.dump(my_mappings, open("./artefacts/" + mappings_name, "wb")) @@ -203,23 +247,37 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp, # Scatter plot of topics df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences)) - corpus = st.CorpusFromParsedDocuments( - df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse' - ).build().get_unigram_corpus().compact(st.AssociationCompactor(2000)) + corpus = ( + st.CorpusFromParsedDocuments( + df, category_col="ts_to_topic_mapping_top_1", parsed_col="parse" + ) + .build() + .get_unigram_corpus() + .compact(st.AssociationCompactor(2000)) + ) html = st.produce_scattertext_explorer( - corpus, - category=cat_1, - category_name=cat_1_name, - not_category_name=cat_2_name, - minimum_term_frequency=0, pmi_threshold_coefficient=0, - width_in_pixels=1000, - transform=st.Scalers.dense_rank + corpus, + category=cat_1, + category_name=cat_1_name, + not_category_name=cat_2_name, + minimum_term_frequency=0, + pmi_threshold_coefficient=0, + width_in_pixels=1000, + transform=st.Scalers.dense_rank, ) if real_time: - with open('./artefacts/real_time_scatter_' + - timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w') as file: + with open( + "./artefacts/real_time_scatter_" + + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + + ".html", + "w", + ) as file: file.write(html) else: - with open('./artefacts/scatter_' + - timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w') as file: + with open( + "./artefacts/scatter_" + + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + + ".html", + "w", + ) as file: file.write(html)