server: reformat whole project using black

This commit is contained in:
Mathieu Virbel
2023-07-27 14:08:41 +02:00
parent 314321c603
commit 094ed696c4
12 changed files with 406 additions and 237 deletions

View File

@@ -2,13 +2,13 @@ import argparse
import asyncio import asyncio
import signal import signal
from aiortc.contrib.signaling import (add_signaling_arguments, from aiortc.contrib.signaling import add_signaling_arguments, create_signaling
create_signaling)
from utils.log_utils import LOGGER from utils.log_utils import LOGGER
from stream_client import StreamClient from stream_client import StreamClient
from typing import NoReturn from typing import NoReturn
async def main() -> NoReturn: async def main() -> NoReturn:
""" """
Reflector's entry point to the python client for WebRTC streaming if not Reflector's entry point to the python client for WebRTC streaming if not
@@ -45,8 +45,7 @@ async def main() -> NoReturn:
LOGGER.info(f"Received exit signal {signal.name}...") LOGGER.info(f"Received exit signal {signal.name}...")
LOGGER.info("Closing database connections") LOGGER.info("Closing database connections")
LOGGER.info("Nacking outstanding messages") LOGGER.info("Nacking outstanding messages")
tasks = [t for t in asyncio.all_tasks() if t is not tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
asyncio.current_task()]
[task.cancel() for task in tasks] [task.cancel() for task in tasks]
@@ -58,15 +57,14 @@ async def main() -> NoReturn:
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT) signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
for s in signals: for s in signals:
loop.add_signal_handler( loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown(s, loop)))
s, lambda s=s: asyncio.create_task(shutdown(s, loop)))
# Init client # Init client
sc = StreamClient( sc = StreamClient(
signaling=signaling, signaling=signaling,
url=args.url, url=args.url,
play_from=args.play_from, play_from=args.play_from,
ping_pong=args.ping_pong ping_pong=args.ping_pong,
) )
await sc.start() await sc.start()
async for msg in sc.get_reader(): async for msg in sc.get_reader():

97
server/poetry.lock generated
View File

@@ -325,6 +325,50 @@ files = [
{file = "av-10.0.0.tar.gz", hash = "sha256:8afd3d5610e1086f3b2d8389d66672ea78624516912c93612de64dcaa4c67e05"}, {file = "av-10.0.0.tar.gz", hash = "sha256:8afd3d5610e1086f3b2d8389d66672ea78624516912c93612de64dcaa4c67e05"},
] ]
[[package]]
name = "black"
version = "23.7.0"
description = "The uncompromising code formatter."
optional = false
python-versions = ">=3.8"
files = [
{file = "black-23.7.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:5c4bc552ab52f6c1c506ccae05681fab58c3f72d59ae6e6639e8885e94fe2587"},
{file = "black-23.7.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:552513d5cd5694590d7ef6f46e1767a4df9af168d449ff767b13b084c020e63f"},
{file = "black-23.7.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:86cee259349b4448adb4ef9b204bb4467aae74a386bce85d56ba4f5dc0da27be"},
{file = "black-23.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:501387a9edcb75d7ae8a4412bb8749900386eaef258f1aefab18adddea1936bc"},
{file = "black-23.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb074d8b213749fa1d077d630db0d5f8cc3b2ae63587ad4116e8a436e9bbe995"},
{file = "black-23.7.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b5b0ee6d96b345a8b420100b7d71ebfdd19fab5e8301aff48ec270042cd40ac2"},
{file = "black-23.7.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:893695a76b140881531062d48476ebe4a48f5d1e9388177e175d76234ca247cd"},
{file = "black-23.7.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:c333286dc3ddca6fdff74670b911cccedacb4ef0a60b34e491b8a67c833b343a"},
{file = "black-23.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831d8f54c3a8c8cf55f64d0422ee875eecac26f5f649fb6c1df65316b67c8926"},
{file = "black-23.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:7f3bf2dec7d541b4619b8ce526bda74a6b0bffc480a163fed32eb8b3c9aed8ad"},
{file = "black-23.7.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:f9062af71c59c004cd519e2fb8f5d25d39e46d3af011b41ab43b9c74e27e236f"},
{file = "black-23.7.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:01ede61aac8c154b55f35301fac3e730baf0c9cf8120f65a9cd61a81cfb4a0c3"},
{file = "black-23.7.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:327a8c2550ddc573b51e2c352adb88143464bb9d92c10416feb86b0f5aee5ff6"},
{file = "black-23.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1c6022b86f83b632d06f2b02774134def5d4d4f1dac8bef16d90cda18ba28a"},
{file = "black-23.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:27eb7a0c71604d5de083757fbdb245b1a4fae60e9596514c6ec497eb63f95320"},
{file = "black-23.7.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:8417dbd2f57b5701492cd46edcecc4f9208dc75529bcf76c514864e48da867d9"},
{file = "black-23.7.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:47e56d83aad53ca140da0af87678fb38e44fd6bc0af71eebab2d1f59b1acf1d3"},
{file = "black-23.7.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:25cc308838fe71f7065df53aedd20327969d05671bac95b38fdf37ebe70ac087"},
{file = "black-23.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:642496b675095d423f9b8448243336f8ec71c9d4d57ec17bf795b67f08132a91"},
{file = "black-23.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:ad0014efc7acf0bd745792bd0d8857413652979200ab924fbf239062adc12491"},
{file = "black-23.7.0-py3-none-any.whl", hash = "sha256:9fd59d418c60c0348505f2ddf9609c1e1de8e7493eab96198fc89d9f865e7a96"},
{file = "black-23.7.0.tar.gz", hash = "sha256:022a582720b0d9480ed82576c920a8c1dde97cc38ff11d8d8859b3bd6ca9eedb"},
]
[package.dependencies]
click = ">=8.0.0"
mypy-extensions = ">=0.4.3"
packaging = ">=22.0"
pathspec = ">=0.9.0"
platformdirs = ">=2"
[package.extras]
colorama = ["colorama (>=0.4.3)"]
d = ["aiohttp (>=3.7.4)"]
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"]
[[package]] [[package]]
name = "certifi" name = "certifi"
version = "2023.7.22" version = "2023.7.22"
@@ -496,6 +540,20 @@ files = [
{file = "charset_normalizer-3.2.0-py3-none-any.whl", hash = "sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6"}, {file = "charset_normalizer-3.2.0-py3-none-any.whl", hash = "sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6"},
] ]
[[package]]
name = "click"
version = "8.1.6"
description = "Composable command line interface toolkit"
optional = false
python-versions = ">=3.7"
files = [
{file = "click-8.1.6-py3-none-any.whl", hash = "sha256:fa244bb30b3b5ee2cae3da8f55c9e5e0c0e86093306301fb418eb9dc40fbded5"},
{file = "click-8.1.6.tar.gz", hash = "sha256:48ee849951919527a045bfe3bf7baa8a959c423134e1a5b98c05c20ba75a1cbd"},
]
[package.dependencies]
colorama = {version = "*", markers = "platform_system == \"Windows\""}
[[package]] [[package]]
name = "colorama" name = "colorama"
version = "0.4.6" version = "0.4.6"
@@ -1080,6 +1138,17 @@ files = [
{file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"},
] ]
[[package]]
name = "mypy-extensions"
version = "1.0.0"
description = "Type system extensions for programs checked with the mypy type checker."
optional = false
python-versions = ">=3.5"
files = [
{file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"},
{file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"},
]
[[package]] [[package]]
name = "numpy" name = "numpy"
version = "1.25.1" version = "1.25.1"
@@ -1166,6 +1235,32 @@ files = [
{file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"},
] ]
[[package]]
name = "pathspec"
version = "0.11.1"
description = "Utility library for gitignore style pattern matching of file paths."
optional = false
python-versions = ">=3.7"
files = [
{file = "pathspec-0.11.1-py3-none-any.whl", hash = "sha256:d8af70af76652554bd134c22b3e8a1cc46ed7d91edcdd721ef1a0c51a84a5293"},
{file = "pathspec-0.11.1.tar.gz", hash = "sha256:2798de800fa92780e33acca925945e9a19a133b715067cf165b8866c15a31687"},
]
[[package]]
name = "platformdirs"
version = "3.9.1"
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
optional = false
python-versions = ">=3.7"
files = [
{file = "platformdirs-3.9.1-py3-none-any.whl", hash = "sha256:ad8291ae0ae5072f66c16945166cb11c63394c7a3ad1b1bc9828ca3162da8c2f"},
{file = "platformdirs-3.9.1.tar.gz", hash = "sha256:1b42b450ad933e981d56e59f1b97495428c9bd60698baab9f3eb3d00d5822421"},
]
[package.extras]
docs = ["furo (>=2023.5.20)", "proselint (>=0.13)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"]
test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)"]
[[package]] [[package]]
name = "protobuf" name = "protobuf"
version = "4.23.4" version = "4.23.4"
@@ -1619,4 +1714,4 @@ multidict = ">=4.0"
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.11" python-versions = "^3.11"
content-hash = "9b82606318ce1096923c0b25e5b3a6b07292f24465611d968e78f37a26e3d212" content-hash = "e8eb6b4f81c090adb882a1b293d81f32167ea89f4636222d43fe0e9131cb97d6"

View File

@@ -18,6 +18,9 @@ sortedcontainers = "^2.4.0"
loguru = "^0.7.0" loguru = "^0.7.0"
[tool.poetry.group.dev.dependencies]
black = "^23.7.0"
[build-system] [build-system]
requires = ["poetry-core"] requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api" build-backend = "poetry.core.masonry.api"

View File

@@ -17,6 +17,7 @@ class TitleSummaryInput:
Data class for the input to generate title and summaries. Data class for the input to generate title and summaries.
The outcome will be used to send query to the LLM for processing. The outcome will be used to send query to the LLM for processing.
""" """
input_text = str input_text = str
transcribed_time = float transcribed_time = float
prompt = str prompt = str
@@ -25,8 +26,7 @@ class TitleSummaryInput:
def __init__(self, transcribed_time, input_text=""): def __init__(self, transcribed_time, input_text=""):
self.input_text = input_text self.input_text = input_text
self.transcribed_time = transcribed_time self.transcribed_time = transcribed_time
self.prompt = \ self.prompt = f"""
f"""
### Human: ### Human:
Create a JSON object as response.The JSON object must have 2 fields: Create a JSON object as response.The JSON object must have 2 fields:
i) title and ii) summary.For the title field,generate a short title i) title and ii) summary.For the title field,generate a short title
@@ -47,6 +47,7 @@ class IncrementalResult:
Data class for the result of generating one title and summaries. Data class for the result of generating one title and summaries.
Defines how a single "topic" looks like. Defines how a single "topic" looks like.
""" """
title = str title = str
description = str description = str
transcript = str transcript = str
@@ -65,6 +66,7 @@ class TitleSummaryOutput:
Data class for the result of all generated titles and summaries. Data class for the result of all generated titles and summaries.
The result will be sent back to the client The result will be sent back to the client
""" """
cmd = str cmd = str
topics = List[IncrementalResult] topics = List[IncrementalResult]
@@ -77,10 +79,7 @@ class TitleSummaryOutput:
Return the result dict for displaying the transcription Return the result dict for displaying the transcription
:return: :return:
""" """
return { return {"cmd": self.cmd, "topics": self.topics}
"cmd": self.cmd,
"topics": self.topics
}
@dataclass @dataclass
@@ -89,6 +88,7 @@ class ParseLLMResult:
Data class to parse the result returned by the LLM while generating title Data class to parse the result returned by the LLM while generating title
and summaries. The result will be sent back to the client. and summaries. The result will be sent back to the client.
""" """
title = str title = str
description = str description = str
transcript = str transcript = str
@@ -98,8 +98,7 @@ class ParseLLMResult:
self.title = output["title"] self.title = output["title"]
self.transcript = param.input_text self.transcript = param.input_text
self.description = output.pop("summary") self.description = output.pop("summary")
self.timestamp = \ self.timestamp = str(datetime.timedelta(seconds=round(param.transcribed_time)))
str(datetime.timedelta(seconds=round(param.transcribed_time)))
def get_result(self) -> dict: def get_result(self) -> dict:
""" """
@@ -110,7 +109,7 @@ class ParseLLMResult:
"title": self.title, "title": self.title,
"description": self.description, "description": self.description,
"transcript": self.transcript, "transcript": self.transcript,
"timestamp": self.timestamp "timestamp": self.timestamp,
} }
@@ -120,6 +119,7 @@ class TranscriptionInput:
Data class to define the input to the transcription function Data class to define the input to the transcription function
AudioFrames -> input AudioFrames -> input
""" """
frames = List[av.audio.frame.AudioFrame] frames = List[av.audio.frame.AudioFrame]
def __init__(self, frames): def __init__(self, frames):
@@ -132,6 +132,7 @@ class TranscriptionOutput:
Dataclass to define the result of the transcription function. Dataclass to define the result of the transcription function.
The result will be sent back to the client The result will be sent back to the client
""" """
cmd = str cmd = str
result_text = str result_text = str
@@ -144,10 +145,7 @@ class TranscriptionOutput:
Return the result dict for displaying the transcription Return the result dict for displaying the transcription
:return: :return:
""" """
return { return {"cmd": self.cmd, "text": self.result_text}
"cmd": self.cmd,
"text": self.result_text
}
@dataclass @dataclass
@@ -156,6 +154,7 @@ class FinalSummaryResult:
Dataclass to define the result of the final summary function. Dataclass to define the result of the final summary function.
The result will be sent back to the client. The result will be sent back to the client.
""" """
cmd = str cmd = str
final_summary = str final_summary = str
duration = str duration = str
@@ -173,7 +172,7 @@ class FinalSummaryResult:
return { return {
"cmd": self.cmd, "cmd": self.cmd,
"duration": self.duration, "duration": self.duration,
"summary": self.final_summary "summary": self.final_summary,
} }
@@ -182,9 +181,14 @@ class BlackListedMessages:
Class to hold the blacklisted messages. These messages should be filtered Class to hold the blacklisted messages. These messages should be filtered
out and not sent back to the client as part of the transcription. out and not sent back to the client as part of the transcription.
""" """
messages = [" Thank you.", " See you next time!",
" Thank you for watching!", " Bye!", messages = [
" And that's what I'm talking about."] " Thank you.",
" See you next time!",
" Thank you for watching!",
" Bye!",
" And that's what I'm talking about.",
]
@dataclass @dataclass

View File

@@ -17,8 +17,15 @@ from aiortc.contrib.media import MediaRelay
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
from reflector_dataclasses import ( from reflector_dataclasses import (
BlackListedMessages, FinalSummaryResult, ParseLLMResult, TitleSummaryInput, BlackListedMessages,
TitleSummaryOutput, TranscriptionInput, TranscriptionOutput, TranscriptionContext) FinalSummaryResult,
ParseLLMResult,
TitleSummaryInput,
TitleSummaryOutput,
TranscriptionInput,
TranscriptionOutput,
TranscriptionContext,
)
from utils.log_utils import LOGGER from utils.log_utils import LOGGER
from utils.run_utils import CONFIG, run_in_executor, SECRETS from utils.run_utils import CONFIG, run_in_executor, SECRETS
@@ -28,9 +35,7 @@ relay = MediaRelay()
executor = ThreadPoolExecutor() executor = ThreadPoolExecutor()
# Transcription model # Transcription model
model = WhisperModel("tiny", device="cpu", model = WhisperModel("tiny", device="cpu", compute_type="float32", num_workers=12)
compute_type="float32",
num_workers=12)
# Audio configurations # Audio configurations
CHANNELS = int(CONFIG["AUDIO"]["CHANNELS"]) CHANNELS = int(CONFIG["AUDIO"]["CHANNELS"])
@@ -46,7 +51,10 @@ else:
LLM_MACHINE_PORT = CONFIG["LLM"]["LLM_MACHINE_PORT"] LLM_MACHINE_PORT = CONFIG["LLM"]["LLM_MACHINE_PORT"]
LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate" LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate"
def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Union[None, ParseLLMResult]:
def parse_llm_output(
param: TitleSummaryInput, response: requests.Response
) -> Union[None, ParseLLMResult]:
""" """
Function to parse the LLM response Function to parse the LLM response
:param param: :param param:
@@ -61,7 +69,9 @@ def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> U
return None return None
def get_title_and_summary(ctx: TranscriptionContext, param: TitleSummaryInput) -> Union[None, TitleSummaryOutput]: def get_title_and_summary(
ctx: TranscriptionContext, param: TitleSummaryInput
) -> Union[None, TitleSummaryOutput]:
""" """
From the input provided (transcript), query the LLM to generate From the input provided (transcript), query the LLM to generate
topics and summaries topics and summaries
@@ -72,9 +82,7 @@ def get_title_and_summary(ctx: TranscriptionContext, param: TitleSummaryInput) -
# TODO : Handle unexpected output formats from the model # TODO : Handle unexpected output formats from the model
try: try:
response = requests.post(LLM_URL, response = requests.post(LLM_URL, headers=param.headers, json=param.data)
headers=param.headers,
json=param.data)
output = parse_llm_output(param, response) output = parse_llm_output(param, response)
if output: if output:
result = output.get_result() result = output.get_result()
@@ -107,7 +115,9 @@ def channel_send(channel, message: str) -> NoReturn:
channel.send(message) channel.send(message)
def channel_send_increment(channel, param: Union[FinalSummaryResult, TitleSummaryOutput]) -> NoReturn: def channel_send_increment(
channel, param: Union[FinalSummaryResult, TitleSummaryOutput]
) -> NoReturn:
""" """
Send the incremental topics and summaries via the data channel Send the incremental topics and summaries via the data channel
:param channel: :param channel:
@@ -145,7 +155,9 @@ def channel_send_transcript(ctx: TranscriptionContext) -> NoReturn:
LOGGER.info("Exception", str(exception)) LOGGER.info("Exception", str(exception))
def get_transcription(ctx: TranscriptionContext, input_frames: TranscriptionInput) -> Union[None, TranscriptionOutput]: def get_transcription(
ctx: TranscriptionContext, input_frames: TranscriptionInput
) -> Union[None, TranscriptionOutput]:
""" """
From the collected audio frames create transcription by inferring from From the collected audio frames create transcription by inferring from
the chosen transcription model the chosen transcription model
@@ -173,12 +185,13 @@ def get_transcription(ctx: TranscriptionContext, input_frames: TranscriptionInpu
result_text = "" result_text = ""
try: try:
segments, _ = \ segments, _ = model.transcribe(
model.transcribe(audio_file, audio_file,
language="en", language="en",
beam_size=5, beam_size=5,
vad_filter=True, vad_filter=True,
vad_parameters={"min_silence_duration_ms": 500}) vad_parameters={"min_silence_duration_ms": 500},
)
os.remove(audio_file) os.remove(audio_file)
segments = list(segments) segments = list(segments)
result_text = "" result_text = ""
@@ -191,7 +204,7 @@ def get_transcription(ctx: TranscriptionContext, input_frames: TranscriptionInpu
start_time = 0.0 start_time = 0.0
if not segment.end: if not segment.end:
end_time = 5.5 end_time = 5.5
duration += (end_time - start_time) duration += end_time - start_time
ctx.last_transcribed_time += duration ctx.last_transcribed_time += duration
ctx.transcription_text += result_text ctx.transcription_text += result_text
@@ -218,8 +231,9 @@ def get_final_summary_response(ctx: TranscriptionContext) -> FinalSummaryResult:
response = FinalSummaryResult(final_summary, ctx.last_transcribed_time) response = FinalSummaryResult(final_summary, ctx.last_transcribed_time)
with open("./artefacts/meeting_titles_and_summaries.txt", "a", with open(
encoding="utf-8") as file: "./artefacts/meeting_titles_and_summaries.txt", "a", encoding="utf-8"
) as file:
file.write(json.dumps(ctx.incremental_responses)) file.write(json.dumps(ctx.incremental_responses))
return response return response
@@ -243,31 +257,30 @@ class AudioStreamTrack(MediaStreamTrack):
frame = await self.track.recv() frame = await self.track.recv()
self.audio_buffer.write(frame) self.audio_buffer.write(frame)
if local_frames := self.audio_buffer.read_many(AUDIO_BUFFER_SIZE, partial=False): if local_frames := self.audio_buffer.read_many(
AUDIO_BUFFER_SIZE, partial=False
):
whisper_result = run_in_executor( whisper_result = run_in_executor(
get_transcription, get_transcription,
ctx, ctx,
TranscriptionInput(local_frames), TranscriptionInput(local_frames),
executor=executor executor=executor,
) )
whisper_result.add_done_callback( whisper_result.add_done_callback(
lambda f: channel_send_transcript(ctx) lambda f: channel_send_transcript(ctx) if f.result() else None
if f.result()
else None
) )
if len(ctx.transcription_text) > 25: if len(ctx.transcription_text) > 25:
llm_input_text = ctx.transcription_text llm_input_text = ctx.transcription_text
ctx.transcription_text = "" ctx.transcription_text = ""
param = TitleSummaryInput(input_text=llm_input_text, param = TitleSummaryInput(
transcribed_time=ctx.last_transcribed_time) input_text=llm_input_text, transcribed_time=ctx.last_transcribed_time
llm_result = run_in_executor(get_title_and_summary, )
ctx, llm_result = run_in_executor(
param, get_title_and_summary, ctx, param, executor=executor
executor=executor) )
llm_result.add_done_callback( llm_result.add_done_callback(
lambda f: channel_send_increment(ctx.data_channel, lambda f: channel_send_increment(ctx.data_channel, llm_result.result())
llm_result.result())
if f.result() if f.result()
else None else None
) )
@@ -330,10 +343,7 @@ async def offer(request: requests.Request) -> web.Response:
return web.Response( return web.Response(
content_type="application/json", content_type="application/json",
text=json.dumps( text=json.dumps(
{ {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
"sdp": pc.localDescription.sdp,
"type": pc.localDescription.type
}
), ),
) )
@@ -351,9 +361,7 @@ async def on_shutdown(application: web.Application) -> NoReturn:
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(description="WebRTC based server for Reflector")
description="WebRTC based server for Reflector"
)
parser.add_argument( parser.add_argument(
"--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)" "--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)"
) )
@@ -366,9 +374,7 @@ if __name__ == "__main__":
app, app,
defaults={ defaults={
"*": aiohttp_cors.ResourceOptions( "*": aiohttp_cors.ResourceOptions(
allow_credentials=True, allow_credentials=True, expose_headers="*", allow_headers="*"
expose_headers="*",
allow_headers="*"
) )
}, },
) )

View File

@@ -6,8 +6,8 @@ import httpx
import pyaudio import pyaudio
import requests import requests
import stamina import stamina
from aiortc import (RTCPeerConnection, RTCSessionDescription) from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import (MediaPlayer, MediaRelay) from aiortc.contrib.media import MediaPlayer, MediaRelay
from utils.log_utils import LOGGER from utils.log_utils import LOGGER
from utils.run_utils import CONFIG from utils.run_utils import CONFIG
@@ -15,11 +15,7 @@ from utils.run_utils import CONFIG
class StreamClient: class StreamClient:
def __init__( def __init__(
self, self, signaling, url="http://0.0.0.0:1250", play_from=None, ping_pong=False
signaling,
url="http://0.0.0.0:1250",
play_from=None,
ping_pong=False
): ):
self.signaling = signaling self.signaling = signaling
self.server_url = url self.server_url = url
@@ -35,9 +31,10 @@ class StreamClient:
self.time_start = None self.time_start = None
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
self.player = MediaPlayer( self.player = MediaPlayer(
':' + str(CONFIG['AUDIO']["AV_FOUNDATION_DEVICE_ID"]), ":" + str(CONFIG["AUDIO"]["AV_FOUNDATION_DEVICE_ID"]),
format='avfoundation', format="avfoundation",
options={'channels': '2'}) options={"channels": "2"},
)
def stop(self): def stop(self):
self.loop.run_until_complete(self.signaling.close()) self.loop.run_until_complete(self.signaling.close())
@@ -114,16 +111,12 @@ class StreamClient:
self.channel_log(channel, "<", message) self.channel_log(channel, "<", message)
if isinstance(message, str) and message.startswith("pong"): if isinstance(message, str) and message.startswith("pong"):
elapsed_ms = (self.current_stamp() - int(message[5:])) \ elapsed_ms = (self.current_stamp() - int(message[5:])) / 1000
/ 1000
print(" RTT %.2f ms" % elapsed_ms) print(" RTT %.2f ms" % elapsed_ms)
await pc.setLocalDescription(await pc.createOffer()) await pc.setLocalDescription(await pc.createOffer())
sdp = { sdp = {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
"sdp": pc.localDescription.sdp,
"type": pc.localDescription.type
}
@stamina.retry(on=httpx.HTTPError, attempts=5) @stamina.retry(on=httpx.HTTPError, attempts=5)
def connect_to_server(): def connect_to_server():

View File

@@ -14,9 +14,11 @@ from .run_utils import SECRETS
BUCKET_NAME = SECRETS["AWS-S3"]["BUCKET_NAME"] BUCKET_NAME = SECRETS["AWS-S3"]["BUCKET_NAME"]
s3 = boto3.client('s3', s3 = boto3.client(
"s3",
aws_access_key_id=SECRETS["AWS-S3"]["AWS_ACCESS_KEY"], aws_access_key_id=SECRETS["AWS-S3"]["AWS_ACCESS_KEY"],
aws_secret_access_key=SECRETS["AWS-S3"]["AWS_SECRET_KEY"]) aws_secret_access_key=SECRETS["AWS-S3"]["AWS_SECRET_KEY"],
)
def upload_files(files_to_upload: List[str]) -> NoReturn: def upload_files(files_to_upload: List[str]) -> NoReturn:
@@ -44,7 +46,7 @@ def download_files(files_to_download: List[str]) -> NoReturn:
try: try:
s3.download_file(BUCKET_NAME, key, key) s3.download_file(BUCKET_NAME, key, key)
except botocore.exceptions.ClientError as exception: except botocore.exceptions.ClientError as exception:
if exception.response['Error']['Code'] == "404": if exception.response["Error"]["Code"] == "404":
print("The object does not exist.") print("The object does not exist.")
else: else:
raise raise

View File

@@ -4,21 +4,16 @@ Utility function to format the artefacts created during Reflector run
import json import json
with open("../artefacts/meeting_titles_and_summaries.txt", "r", with open("../artefacts/meeting_titles_and_summaries.txt", "r", encoding="utf-8") as f:
encoding='utf-8') as f:
outputs = f.read() outputs = f.read()
outputs = json.loads(outputs) outputs = json.loads(outputs)
transcript_file = open("../artefacts/meeting_transcript.txt", transcript_file = open("../artefacts/meeting_transcript.txt", "a", encoding="utf-8")
"a", title_desc_file = open(
encoding='utf-8') "../artefacts/meeting_title_description.txt", "a", encoding="utf-8"
title_desc_file = open("../artefacts/meeting_title_description.txt", )
"a", summary_file = open("../artefacts/meeting_summary.txt", "a", encoding="utf-8")
encoding='utf-8')
summary_file = open("../artefacts/meeting_summary.txt",
"a",
encoding='utf-8')
for item in outputs["topics"]: for item in outputs["topics"]:
transcript_file.write(item["transcript"]) transcript_file.write(item["transcript"])

View File

@@ -10,6 +10,7 @@ class SingletonLogger:
Use Singleton design pattern to create a logger object and share it Use Singleton design pattern to create a logger object and share it
across the entire project across the entire project
""" """
__instance = None __instance = None
@staticmethod @staticmethod

View File

@@ -14,6 +14,7 @@ class ReflectorConfig:
""" """
Create a single config object to share across the project Create a single config object to share across the project
""" """
__config = None __config = None
__secrets = None __secrets = None
@@ -25,7 +26,7 @@ class ReflectorConfig:
""" """
if ReflectorConfig.__config is None: if ReflectorConfig.__config is None:
ReflectorConfig.__config = configparser.ConfigParser() ReflectorConfig.__config = configparser.ConfigParser()
ReflectorConfig.__config.read('utils/config.ini') ReflectorConfig.__config.read("utils/config.ini")
return ReflectorConfig.__config return ReflectorConfig.__config
@staticmethod @staticmethod
@@ -36,7 +37,7 @@ class ReflectorConfig:
""" """
if ReflectorConfig.__secrets is None: if ReflectorConfig.__secrets is None:
ReflectorConfig.__secrets = configparser.ConfigParser() ReflectorConfig.__secrets = configparser.ConfigParser()
ReflectorConfig.__secrets.read('utils/secrets.ini') ReflectorConfig.__secrets.read("utils/secrets.ini")
return ReflectorConfig.__secrets return ReflectorConfig.__secrets

View File

@@ -15,7 +15,7 @@ from transformers import BartForConditionalGeneration, BartTokenizer
from log_utils import LOGGER from log_utils import LOGGER
from run_utils import CONFIG from run_utils import CONFIG
nltk.download('punkt', quiet=True) nltk.download("punkt", quiet=True)
def preprocess_sentence(sentence: str) -> str: def preprocess_sentence(sentence: str) -> str:
@@ -24,11 +24,10 @@ def preprocess_sentence(sentence: str) -> str:
:param sentence: :param sentence:
:return: :return:
""" """
stop_words = set(stopwords.words('english')) stop_words = set(stopwords.words("english"))
tokens = word_tokenize(sentence.lower()) tokens = word_tokenize(sentence.lower())
tokens = [token for token in tokens tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
if token.isalnum() and token not in stop_words] return " ".join(tokens)
return ' '.join(tokens)
def compute_similarity(sent1: str, sent2: str) -> float: def compute_similarity(sent1: str, sent2: str) -> float:
@@ -67,14 +66,14 @@ def remove_almost_alike_sentences(sentences: List[str], threshold=0.7) -> List[s
sentence1 = preprocess_sentence(sentences[i]) sentence1 = preprocess_sentence(sentences[i])
sentence2 = preprocess_sentence(sentences[j]) sentence2 = preprocess_sentence(sentences[j])
if len(sentence1) != 0 and len(sentence2) != 0: if len(sentence1) != 0 and len(sentence2) != 0:
similarity = compute_similarity(sentence1, similarity = compute_similarity(sentence1, sentence2)
sentence2)
if similarity >= threshold: if similarity >= threshold:
removed_indices.add(max(i, j)) removed_indices.add(max(i, j))
filtered_sentences = [sentences[i] for i in range(num_sentences) filtered_sentences = [
if i not in removed_indices] sentences[i] for i in range(num_sentences) if i not in removed_indices
]
return filtered_sentences return filtered_sentences
@@ -90,7 +89,9 @@ def remove_outright_duplicate_sentences_from_chunk(chunk: str) -> List[str]:
return nonduplicate_sentences return nonduplicate_sentences
def remove_whisper_repetitive_hallucination(nonduplicate_sentences: List[str]) -> List[str]: def remove_whisper_repetitive_hallucination(
nonduplicate_sentences: List[str],
) -> List[str]:
""" """
Remove sentences that are repeated as a result of Whisper Remove sentences that are repeated as a result of Whisper
hallucinations hallucinations
@@ -105,13 +106,16 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences: List[str]) -
words = nltk.word_tokenize(sent) words = nltk.word_tokenize(sent)
n_gram_filter = 3 n_gram_filter = 3
for i in range(len(words)): for i in range(len(words)):
if str(words[i:i + n_gram_filter]) in seen and \ if (
seen[str(words[i:i + n_gram_filter])] == \ str(words[i : i + n_gram_filter]) in seen
words[i + 1:i + n_gram_filter + 2]: and seen[str(words[i : i + n_gram_filter])]
== words[i + 1 : i + n_gram_filter + 2]
):
pass pass
else: else:
seen[str(words[i:i + n_gram_filter])] = \ seen[str(words[i : i + n_gram_filter])] = words[
words[i + 1:i + n_gram_filter + 2] i + 1 : i + n_gram_filter + 2
]
temp_result += words[i] temp_result += words[i]
temp_result += " " temp_result += " "
chunk_sentences.append(temp_result) chunk_sentences.append(temp_result)
@@ -126,12 +130,11 @@ def post_process_transcription(whisper_result: dict) -> dict:
""" """
transcript_text = "" transcript_text = ""
for chunk in whisper_result["chunks"]: for chunk in whisper_result["chunks"]:
nonduplicate_sentences = \ nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk)
remove_outright_duplicate_sentences_from_chunk(chunk) chunk_sentences = remove_whisper_repetitive_hallucination(
chunk_sentences = \ nonduplicate_sentences
remove_whisper_repetitive_hallucination(nonduplicate_sentences) )
similarity_matched_sentences = \ similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences)
remove_almost_alike_sentences(chunk_sentences)
chunk["text"] = " ".join(similarity_matched_sentences) chunk["text"] = " ".join(similarity_matched_sentences)
transcript_text += chunk["text"] transcript_text += chunk["text"]
whisper_result["text"] = transcript_text whisper_result["text"] = transcript_text
@@ -149,23 +152,24 @@ def summarize_chunks(chunks: List[str], tokenizer, model) -> List[str]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summaries = [] summaries = []
for c in chunks: for c in chunks:
input_ids = tokenizer.encode(c, return_tensors='pt') input_ids = tokenizer.encode(c, return_tensors="pt")
input_ids = input_ids.to(device) input_ids = input_ids.to(device)
with torch.no_grad(): with torch.no_grad():
summary_ids = \ summary_ids = model.generate(
model.generate(input_ids, input_ids,
num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]), num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]),
length_penalty=2.0, length_penalty=2.0,
max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]), max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]),
early_stopping=True) early_stopping=True,
summary = tokenizer.decode(summary_ids[0], )
skip_special_tokens=True) summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
summaries.append(summary) summaries.append(summary)
return summaries return summaries
def chunk_text(text: str, def chunk_text(
max_chunk_length: int = int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])) -> List[str]: text: str, max_chunk_length: int = int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])
) -> List[str]:
""" """
Split text into smaller chunks. Split text into smaller chunks.
:param text: Text to be chunked :param text: Text to be chunked
@@ -185,9 +189,12 @@ def chunk_text(text: str,
return chunks return chunks
def summarize(transcript_text: str, timestamp: datetime.datetime.timestamp, def summarize(
transcript_text: str,
timestamp: datetime.datetime.timestamp,
real_time: bool = False, real_time: bool = False,
chunk_summarize: str = CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]): chunk_summarize: str = CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"],
):
""" """
Summarize the given text either as a whole or as chunks as needed Summarize the given text either as a whole or as chunks as needed
:param transcript_text: :param transcript_text:
@@ -213,39 +220,45 @@ def summarize(transcript_text: str, timestamp: datetime.datetime.timestamp,
if chunk_summarize != "YES": if chunk_summarize != "YES":
max_length = int(CONFIG["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"]) max_length = int(CONFIG["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"])
inputs = tokenizer. \ inputs = tokenizer.batch_encode_plus(
batch_encode_plus([transcript_text], truncation=True, [transcript_text],
padding='longest', truncation=True,
padding="longest",
max_length=max_length, max_length=max_length,
return_tensors='pt') return_tensors="pt",
)
inputs = inputs.to(device) inputs = inputs.to(device)
with torch.no_grad(): with torch.no_grad():
num_beans = int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]) num_beans = int(CONFIG["SUMMARIZER"]["BEAM_SIZE"])
max_length = int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]) max_length = int(CONFIG["SUMMARIZER"]["MAX_LENGTH"])
summaries = model.generate(inputs['input_ids'], summaries = model.generate(
inputs["input_ids"],
num_beams=num_beans, num_beams=num_beans,
length_penalty=2.0, length_penalty=2.0,
max_length=max_length, max_length=max_length,
early_stopping=True) early_stopping=True,
)
decoded_summaries = \ decoded_summaries = [
[tokenizer.decode(summary, tokenizer.decode(
skip_special_tokens=True, summary, skip_special_tokens=True, clean_up_tokenization_spaces=False
clean_up_tokenization_spaces=False) )
for summary in summaries] for summary in summaries
]
summary = " ".join(decoded_summaries) summary = " ".join(decoded_summaries)
with open("./artefacts/" + output_file, 'w', encoding="utf-8") as file: with open("./artefacts/" + output_file, "w", encoding="utf-8") as file:
file.write(summary.strip() + "\n") file.write(summary.strip() + "\n")
else: else:
LOGGER.info("Breaking transcript into smaller chunks") LOGGER.info("Breaking transcript into smaller chunks")
chunks = chunk_text(transcript_text) chunks = chunk_text(transcript_text)
LOGGER.info(f"Transcript broken into {len(chunks)} " LOGGER.info(
f"chunks of at most 500 words") f"Transcript broken into {len(chunks)} " f"chunks of at most 500 words"
)
LOGGER.info(f"Writing summary text to: {output_file}") LOGGER.info(f"Writing summary text to: {output_file}")
with open(output_file, 'w') as f: with open(output_file, "w") as f:
summaries = summarize_chunks(chunks, tokenizer, model) summaries = summarize_chunks(chunks, tokenizer, model)
for summary in summaries: for summary in summaries:
f.write(summary.strip() + " ") f.write(summary.strip() + " ")

View File

@@ -16,23 +16,30 @@ import spacy
from nltk.corpus import stopwords from nltk.corpus import stopwords
from wordcloud import STOPWORDS, WordCloud from wordcloud import STOPWORDS, WordCloud
en = spacy.load('en_core_web_md') en = spacy.load("en_core_web_md")
spacy_stopwords = en.Defaults.stop_words spacy_stopwords = en.Defaults.stop_words
STOPWORDS = set(STOPWORDS).union(set(stopwords.words("english"))). \ STOPWORDS = (
union(set(spacy_stopwords)) set(STOPWORDS).union(set(stopwords.words("english"))).union(set(spacy_stopwords))
)
def create_wordcloud(timestamp: datetime.datetime.timestamp, def create_wordcloud(
real_time: bool = False) -> NoReturn: timestamp: datetime.datetime.timestamp, real_time: bool = False
) -> NoReturn:
""" """
Create a basic word cloud visualization of transcribed text Create a basic word cloud visualization of transcribed text
:return: None. The wordcloud image is saved locally :return: None. The wordcloud image is saved locally
""" """
filename = "transcript" filename = "transcript"
if real_time: if real_time:
filename = "real_time_" + filename + "_" + \ filename = (
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" "real_time_"
+ filename
+ "_"
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
+ ".txt"
)
else: else:
filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
@@ -41,10 +48,13 @@ def create_wordcloud(timestamp: datetime.datetime.timestamp,
# python_mask = np.array(PIL.Image.open("download1.png")) # python_mask = np.array(PIL.Image.open("download1.png"))
wordcloud = WordCloud(height=800, width=800, wordcloud = WordCloud(
background_color='white', height=800,
width=800,
background_color="white",
stopwords=STOPWORDS, stopwords=STOPWORDS,
min_font_size=8).generate(transcription_text) min_font_size=8,
).generate(transcription_text)
# Plot wordcloud and save image # Plot wordcloud and save image
plt.figure(facecolor=None) plt.figure(facecolor=None)
@@ -54,16 +64,22 @@ def create_wordcloud(timestamp: datetime.datetime.timestamp,
wordcloud = "wordcloud" wordcloud = "wordcloud"
if real_time: if real_time:
wordcloud = "real_time_" + wordcloud + "_" + \ wordcloud = (
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png" "real_time_"
+ wordcloud
+ "_"
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
+ ".png"
)
else: else:
wordcloud += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png" wordcloud += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
plt.savefig("./artefacts/" + wordcloud) plt.savefig("./artefacts/" + wordcloud)
def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp, def create_talk_diff_scatter_viz(
real_time: bool = False) -> NoReturn: timestamp: datetime.datetime.timestamp, real_time: bool = False
) -> NoReturn:
""" """
Perform agenda vs transcription diff to see covered topics. Perform agenda vs transcription diff to see covered topics.
Create a scatter plot of words in topics. Create a scatter plot of words in topics.
@@ -71,7 +87,7 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp,
""" """
spacy_model = "en_core_web_md" spacy_model = "en_core_web_md"
nlp = spacy.load(spacy_model) nlp = spacy.load(spacy_model)
nlp.add_pipe('sentencizer') nlp.add_pipe("sentencizer")
agenda_topics = [] agenda_topics = []
agenda = [] agenda = []
@@ -84,11 +100,17 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp,
# Load the transcription with timestamp # Load the transcription with timestamp
if real_time: if real_time:
filename = "./artefacts/real_time_transcript_with_timestamp_" + \ filename = (
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" "./artefacts/real_time_transcript_with_timestamp_"
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
+ ".txt"
)
else: else:
filename = "./artefacts/transcript_with_timestamp_" + \ filename = (
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" "./artefacts/transcript_with_timestamp_"
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
+ ".txt"
)
with open(filename) as file: with open(filename) as file:
transcription_timestamp_text = file.read() transcription_timestamp_text = file.read()
@@ -128,14 +150,20 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp,
covered_items[agenda[topic_similarities[i][0]]] = True covered_items[agenda[topic_similarities[i][0]]] = True
# top1 match # top1 match
if i == 0: if i == 0:
ts_to_topic_mapping_top_1[c["timestamp"]] = \ ts_to_topic_mapping_top_1[c["timestamp"]] = agenda_topics[
topic_similarities[i][0]
]
topic_to_ts_mapping_top_1[
agenda_topics[topic_similarities[i][0]] agenda_topics[topic_similarities[i][0]]
topic_to_ts_mapping_top_1[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"]) ].append(c["timestamp"])
# top2 match # top2 match
else: else:
ts_to_topic_mapping_top_2[c["timestamp"]] = \ ts_to_topic_mapping_top_2[c["timestamp"]] = agenda_topics[
topic_similarities[i][0]
]
topic_to_ts_mapping_top_2[
agenda_topics[topic_similarities[i][0]] agenda_topics[topic_similarities[i][0]]
topic_to_ts_mapping_top_2[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"]) ].append(c["timestamp"])
def create_new_columns(record: dict) -> dict: def create_new_columns(record: dict) -> dict:
""" """
@@ -143,10 +171,12 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp,
:param record: :param record:
:return: :return:
""" """
record["ts_to_topic_mapping_top_1"] = \ record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[
ts_to_topic_mapping_top_1[record["timestamp"]] record["timestamp"]
record["ts_to_topic_mapping_top_2"] = \ ]
ts_to_topic_mapping_top_2[record["timestamp"]] record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[
record["timestamp"]
]
return record return record
df = df.apply(create_new_columns, axis=1) df = df.apply(create_new_columns, axis=1)
@@ -167,19 +197,33 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp,
# Save df, mappings for further experimentation # Save df, mappings for further experimentation
df_name = "df" df_name = "df"
if real_time: if real_time:
df_name = "real_time_" + df_name + "_" + \ df_name = (
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" "real_time_"
+ df_name
+ "_"
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
+ ".pkl"
)
else: else:
df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
df.to_pickle("./artefacts/" + df_name) df.to_pickle("./artefacts/" + df_name)
my_mappings = [ts_to_topic_mapping_top_1, ts_to_topic_mapping_top_2, my_mappings = [
topic_to_ts_mapping_top_1, topic_to_ts_mapping_top_2] ts_to_topic_mapping_top_1,
ts_to_topic_mapping_top_2,
topic_to_ts_mapping_top_1,
topic_to_ts_mapping_top_2,
]
mappings_name = "mappings" mappings_name = "mappings"
if real_time: if real_time:
mappings_name = "real_time_" + mappings_name + "_" + \ mappings_name = (
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" "real_time_"
+ mappings_name
+ "_"
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
+ ".pkl"
)
else: else:
mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl" mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
pickle.dump(my_mappings, open("./artefacts/" + mappings_name, "wb")) pickle.dump(my_mappings, open("./artefacts/" + mappings_name, "wb"))
@@ -203,23 +247,37 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp,
# Scatter plot of topics # Scatter plot of topics
df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences)) df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences))
corpus = st.CorpusFromParsedDocuments( corpus = (
df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse' st.CorpusFromParsedDocuments(
).build().get_unigram_corpus().compact(st.AssociationCompactor(2000)) df, category_col="ts_to_topic_mapping_top_1", parsed_col="parse"
)
.build()
.get_unigram_corpus()
.compact(st.AssociationCompactor(2000))
)
html = st.produce_scattertext_explorer( html = st.produce_scattertext_explorer(
corpus, corpus,
category=cat_1, category=cat_1,
category_name=cat_1_name, category_name=cat_1_name,
not_category_name=cat_2_name, not_category_name=cat_2_name,
minimum_term_frequency=0, pmi_threshold_coefficient=0, minimum_term_frequency=0,
pmi_threshold_coefficient=0,
width_in_pixels=1000, width_in_pixels=1000,
transform=st.Scalers.dense_rank transform=st.Scalers.dense_rank,
) )
if real_time: if real_time:
with open('./artefacts/real_time_scatter_' + with open(
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w') as file: "./artefacts/real_time_scatter_"
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
+ ".html",
"w",
) as file:
file.write(html) file.write(html)
else: else:
with open('./artefacts/scatter_' + with open(
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w') as file: "./artefacts/scatter_"
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
+ ".html",
"w",
) as file:
file.write(html) file.write(html)