mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: reformat whole project using black
This commit is contained in:
@@ -2,13 +2,13 @@ import argparse
|
|||||||
import asyncio
|
import asyncio
|
||||||
import signal
|
import signal
|
||||||
|
|
||||||
from aiortc.contrib.signaling import (add_signaling_arguments,
|
from aiortc.contrib.signaling import add_signaling_arguments, create_signaling
|
||||||
create_signaling)
|
|
||||||
|
|
||||||
from utils.log_utils import LOGGER
|
from utils.log_utils import LOGGER
|
||||||
from stream_client import StreamClient
|
from stream_client import StreamClient
|
||||||
from typing import NoReturn
|
from typing import NoReturn
|
||||||
|
|
||||||
|
|
||||||
async def main() -> NoReturn:
|
async def main() -> NoReturn:
|
||||||
"""
|
"""
|
||||||
Reflector's entry point to the python client for WebRTC streaming if not
|
Reflector's entry point to the python client for WebRTC streaming if not
|
||||||
@@ -45,8 +45,7 @@ async def main() -> NoReturn:
|
|||||||
LOGGER.info(f"Received exit signal {signal.name}...")
|
LOGGER.info(f"Received exit signal {signal.name}...")
|
||||||
LOGGER.info("Closing database connections")
|
LOGGER.info("Closing database connections")
|
||||||
LOGGER.info("Nacking outstanding messages")
|
LOGGER.info("Nacking outstanding messages")
|
||||||
tasks = [t for t in asyncio.all_tasks() if t is not
|
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||||
asyncio.current_task()]
|
|
||||||
|
|
||||||
[task.cancel() for task in tasks]
|
[task.cancel() for task in tasks]
|
||||||
|
|
||||||
@@ -58,15 +57,14 @@ async def main() -> NoReturn:
|
|||||||
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
|
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
for s in signals:
|
for s in signals:
|
||||||
loop.add_signal_handler(
|
loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown(s, loop)))
|
||||||
s, lambda s=s: asyncio.create_task(shutdown(s, loop)))
|
|
||||||
|
|
||||||
# Init client
|
# Init client
|
||||||
sc = StreamClient(
|
sc = StreamClient(
|
||||||
signaling=signaling,
|
signaling=signaling,
|
||||||
url=args.url,
|
url=args.url,
|
||||||
play_from=args.play_from,
|
play_from=args.play_from,
|
||||||
ping_pong=args.ping_pong
|
ping_pong=args.ping_pong,
|
||||||
)
|
)
|
||||||
await sc.start()
|
await sc.start()
|
||||||
async for msg in sc.get_reader():
|
async for msg in sc.get_reader():
|
||||||
|
|||||||
97
server/poetry.lock
generated
97
server/poetry.lock
generated
@@ -325,6 +325,50 @@ files = [
|
|||||||
{file = "av-10.0.0.tar.gz", hash = "sha256:8afd3d5610e1086f3b2d8389d66672ea78624516912c93612de64dcaa4c67e05"},
|
{file = "av-10.0.0.tar.gz", hash = "sha256:8afd3d5610e1086f3b2d8389d66672ea78624516912c93612de64dcaa4c67e05"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "black"
|
||||||
|
version = "23.7.0"
|
||||||
|
description = "The uncompromising code formatter."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "black-23.7.0-cp310-cp310-macosx_10_16_arm64.whl", hash = "sha256:5c4bc552ab52f6c1c506ccae05681fab58c3f72d59ae6e6639e8885e94fe2587"},
|
||||||
|
{file = "black-23.7.0-cp310-cp310-macosx_10_16_universal2.whl", hash = "sha256:552513d5cd5694590d7ef6f46e1767a4df9af168d449ff767b13b084c020e63f"},
|
||||||
|
{file = "black-23.7.0-cp310-cp310-macosx_10_16_x86_64.whl", hash = "sha256:86cee259349b4448adb4ef9b204bb4467aae74a386bce85d56ba4f5dc0da27be"},
|
||||||
|
{file = "black-23.7.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:501387a9edcb75d7ae8a4412bb8749900386eaef258f1aefab18adddea1936bc"},
|
||||||
|
{file = "black-23.7.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb074d8b213749fa1d077d630db0d5f8cc3b2ae63587ad4116e8a436e9bbe995"},
|
||||||
|
{file = "black-23.7.0-cp311-cp311-macosx_10_16_arm64.whl", hash = "sha256:b5b0ee6d96b345a8b420100b7d71ebfdd19fab5e8301aff48ec270042cd40ac2"},
|
||||||
|
{file = "black-23.7.0-cp311-cp311-macosx_10_16_universal2.whl", hash = "sha256:893695a76b140881531062d48476ebe4a48f5d1e9388177e175d76234ca247cd"},
|
||||||
|
{file = "black-23.7.0-cp311-cp311-macosx_10_16_x86_64.whl", hash = "sha256:c333286dc3ddca6fdff74670b911cccedacb4ef0a60b34e491b8a67c833b343a"},
|
||||||
|
{file = "black-23.7.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:831d8f54c3a8c8cf55f64d0422ee875eecac26f5f649fb6c1df65316b67c8926"},
|
||||||
|
{file = "black-23.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:7f3bf2dec7d541b4619b8ce526bda74a6b0bffc480a163fed32eb8b3c9aed8ad"},
|
||||||
|
{file = "black-23.7.0-cp38-cp38-macosx_10_16_arm64.whl", hash = "sha256:f9062af71c59c004cd519e2fb8f5d25d39e46d3af011b41ab43b9c74e27e236f"},
|
||||||
|
{file = "black-23.7.0-cp38-cp38-macosx_10_16_universal2.whl", hash = "sha256:01ede61aac8c154b55f35301fac3e730baf0c9cf8120f65a9cd61a81cfb4a0c3"},
|
||||||
|
{file = "black-23.7.0-cp38-cp38-macosx_10_16_x86_64.whl", hash = "sha256:327a8c2550ddc573b51e2c352adb88143464bb9d92c10416feb86b0f5aee5ff6"},
|
||||||
|
{file = "black-23.7.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d1c6022b86f83b632d06f2b02774134def5d4d4f1dac8bef16d90cda18ba28a"},
|
||||||
|
{file = "black-23.7.0-cp38-cp38-win_amd64.whl", hash = "sha256:27eb7a0c71604d5de083757fbdb245b1a4fae60e9596514c6ec497eb63f95320"},
|
||||||
|
{file = "black-23.7.0-cp39-cp39-macosx_10_16_arm64.whl", hash = "sha256:8417dbd2f57b5701492cd46edcecc4f9208dc75529bcf76c514864e48da867d9"},
|
||||||
|
{file = "black-23.7.0-cp39-cp39-macosx_10_16_universal2.whl", hash = "sha256:47e56d83aad53ca140da0af87678fb38e44fd6bc0af71eebab2d1f59b1acf1d3"},
|
||||||
|
{file = "black-23.7.0-cp39-cp39-macosx_10_16_x86_64.whl", hash = "sha256:25cc308838fe71f7065df53aedd20327969d05671bac95b38fdf37ebe70ac087"},
|
||||||
|
{file = "black-23.7.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:642496b675095d423f9b8448243336f8ec71c9d4d57ec17bf795b67f08132a91"},
|
||||||
|
{file = "black-23.7.0-cp39-cp39-win_amd64.whl", hash = "sha256:ad0014efc7acf0bd745792bd0d8857413652979200ab924fbf239062adc12491"},
|
||||||
|
{file = "black-23.7.0-py3-none-any.whl", hash = "sha256:9fd59d418c60c0348505f2ddf9609c1e1de8e7493eab96198fc89d9f865e7a96"},
|
||||||
|
{file = "black-23.7.0.tar.gz", hash = "sha256:022a582720b0d9480ed82576c920a8c1dde97cc38ff11d8d8859b3bd6ca9eedb"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
click = ">=8.0.0"
|
||||||
|
mypy-extensions = ">=0.4.3"
|
||||||
|
packaging = ">=22.0"
|
||||||
|
pathspec = ">=0.9.0"
|
||||||
|
platformdirs = ">=2"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
colorama = ["colorama (>=0.4.3)"]
|
||||||
|
d = ["aiohttp (>=3.7.4)"]
|
||||||
|
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
|
||||||
|
uvloop = ["uvloop (>=0.15.2)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "certifi"
|
name = "certifi"
|
||||||
version = "2023.7.22"
|
version = "2023.7.22"
|
||||||
@@ -496,6 +540,20 @@ files = [
|
|||||||
{file = "charset_normalizer-3.2.0-py3-none-any.whl", hash = "sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6"},
|
{file = "charset_normalizer-3.2.0-py3-none-any.whl", hash = "sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "click"
|
||||||
|
version = "8.1.6"
|
||||||
|
description = "Composable command line interface toolkit"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "click-8.1.6-py3-none-any.whl", hash = "sha256:fa244bb30b3b5ee2cae3da8f55c9e5e0c0e86093306301fb418eb9dc40fbded5"},
|
||||||
|
{file = "click-8.1.6.tar.gz", hash = "sha256:48ee849951919527a045bfe3bf7baa8a959c423134e1a5b98c05c20ba75a1cbd"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "colorama"
|
name = "colorama"
|
||||||
version = "0.4.6"
|
version = "0.4.6"
|
||||||
@@ -1080,6 +1138,17 @@ files = [
|
|||||||
{file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"},
|
{file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mypy-extensions"
|
||||||
|
version = "1.0.0"
|
||||||
|
description = "Type system extensions for programs checked with the mypy type checker."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.5"
|
||||||
|
files = [
|
||||||
|
{file = "mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d"},
|
||||||
|
{file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "numpy"
|
name = "numpy"
|
||||||
version = "1.25.1"
|
version = "1.25.1"
|
||||||
@@ -1166,6 +1235,32 @@ files = [
|
|||||||
{file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"},
|
{file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pathspec"
|
||||||
|
version = "0.11.1"
|
||||||
|
description = "Utility library for gitignore style pattern matching of file paths."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "pathspec-0.11.1-py3-none-any.whl", hash = "sha256:d8af70af76652554bd134c22b3e8a1cc46ed7d91edcdd721ef1a0c51a84a5293"},
|
||||||
|
{file = "pathspec-0.11.1.tar.gz", hash = "sha256:2798de800fa92780e33acca925945e9a19a133b715067cf165b8866c15a31687"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "platformdirs"
|
||||||
|
version = "3.9.1"
|
||||||
|
description = "A small Python package for determining appropriate platform-specific dirs, e.g. a \"user data dir\"."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "platformdirs-3.9.1-py3-none-any.whl", hash = "sha256:ad8291ae0ae5072f66c16945166cb11c63394c7a3ad1b1bc9828ca3162da8c2f"},
|
||||||
|
{file = "platformdirs-3.9.1.tar.gz", hash = "sha256:1b42b450ad933e981d56e59f1b97495428c9bd60698baab9f3eb3d00d5822421"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
docs = ["furo (>=2023.5.20)", "proselint (>=0.13)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"]
|
||||||
|
test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "protobuf"
|
name = "protobuf"
|
||||||
version = "4.23.4"
|
version = "4.23.4"
|
||||||
@@ -1619,4 +1714,4 @@ multidict = ">=4.0"
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.11"
|
python-versions = "^3.11"
|
||||||
content-hash = "9b82606318ce1096923c0b25e5b3a6b07292f24465611d968e78f37a26e3d212"
|
content-hash = "e8eb6b4f81c090adb882a1b293d81f32167ea89f4636222d43fe0e9131cb97d6"
|
||||||
|
|||||||
@@ -18,6 +18,9 @@ sortedcontainers = "^2.4.0"
|
|||||||
loguru = "^0.7.0"
|
loguru = "^0.7.0"
|
||||||
|
|
||||||
|
|
||||||
|
[tool.poetry.group.dev.dependencies]
|
||||||
|
black = "^23.7.0"
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["poetry-core"]
|
requires = ["poetry-core"]
|
||||||
build-backend = "poetry.core.masonry.api"
|
build-backend = "poetry.core.masonry.api"
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ class TitleSummaryInput:
|
|||||||
Data class for the input to generate title and summaries.
|
Data class for the input to generate title and summaries.
|
||||||
The outcome will be used to send query to the LLM for processing.
|
The outcome will be used to send query to the LLM for processing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
input_text = str
|
input_text = str
|
||||||
transcribed_time = float
|
transcribed_time = float
|
||||||
prompt = str
|
prompt = str
|
||||||
@@ -25,8 +26,7 @@ class TitleSummaryInput:
|
|||||||
def __init__(self, transcribed_time, input_text=""):
|
def __init__(self, transcribed_time, input_text=""):
|
||||||
self.input_text = input_text
|
self.input_text = input_text
|
||||||
self.transcribed_time = transcribed_time
|
self.transcribed_time = transcribed_time
|
||||||
self.prompt = \
|
self.prompt = f"""
|
||||||
f"""
|
|
||||||
### Human:
|
### Human:
|
||||||
Create a JSON object as response.The JSON object must have 2 fields:
|
Create a JSON object as response.The JSON object must have 2 fields:
|
||||||
i) title and ii) summary.For the title field,generate a short title
|
i) title and ii) summary.For the title field,generate a short title
|
||||||
@@ -47,6 +47,7 @@ class IncrementalResult:
|
|||||||
Data class for the result of generating one title and summaries.
|
Data class for the result of generating one title and summaries.
|
||||||
Defines how a single "topic" looks like.
|
Defines how a single "topic" looks like.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
title = str
|
title = str
|
||||||
description = str
|
description = str
|
||||||
transcript = str
|
transcript = str
|
||||||
@@ -65,6 +66,7 @@ class TitleSummaryOutput:
|
|||||||
Data class for the result of all generated titles and summaries.
|
Data class for the result of all generated titles and summaries.
|
||||||
The result will be sent back to the client
|
The result will be sent back to the client
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cmd = str
|
cmd = str
|
||||||
topics = List[IncrementalResult]
|
topics = List[IncrementalResult]
|
||||||
|
|
||||||
@@ -77,10 +79,7 @@ class TitleSummaryOutput:
|
|||||||
Return the result dict for displaying the transcription
|
Return the result dict for displaying the transcription
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
return {
|
return {"cmd": self.cmd, "topics": self.topics}
|
||||||
"cmd": self.cmd,
|
|
||||||
"topics": self.topics
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -89,6 +88,7 @@ class ParseLLMResult:
|
|||||||
Data class to parse the result returned by the LLM while generating title
|
Data class to parse the result returned by the LLM while generating title
|
||||||
and summaries. The result will be sent back to the client.
|
and summaries. The result will be sent back to the client.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
title = str
|
title = str
|
||||||
description = str
|
description = str
|
||||||
transcript = str
|
transcript = str
|
||||||
@@ -98,8 +98,7 @@ class ParseLLMResult:
|
|||||||
self.title = output["title"]
|
self.title = output["title"]
|
||||||
self.transcript = param.input_text
|
self.transcript = param.input_text
|
||||||
self.description = output.pop("summary")
|
self.description = output.pop("summary")
|
||||||
self.timestamp = \
|
self.timestamp = str(datetime.timedelta(seconds=round(param.transcribed_time)))
|
||||||
str(datetime.timedelta(seconds=round(param.transcribed_time)))
|
|
||||||
|
|
||||||
def get_result(self) -> dict:
|
def get_result(self) -> dict:
|
||||||
"""
|
"""
|
||||||
@@ -110,7 +109,7 @@ class ParseLLMResult:
|
|||||||
"title": self.title,
|
"title": self.title,
|
||||||
"description": self.description,
|
"description": self.description,
|
||||||
"transcript": self.transcript,
|
"transcript": self.transcript,
|
||||||
"timestamp": self.timestamp
|
"timestamp": self.timestamp,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -120,6 +119,7 @@ class TranscriptionInput:
|
|||||||
Data class to define the input to the transcription function
|
Data class to define the input to the transcription function
|
||||||
AudioFrames -> input
|
AudioFrames -> input
|
||||||
"""
|
"""
|
||||||
|
|
||||||
frames = List[av.audio.frame.AudioFrame]
|
frames = List[av.audio.frame.AudioFrame]
|
||||||
|
|
||||||
def __init__(self, frames):
|
def __init__(self, frames):
|
||||||
@@ -132,6 +132,7 @@ class TranscriptionOutput:
|
|||||||
Dataclass to define the result of the transcription function.
|
Dataclass to define the result of the transcription function.
|
||||||
The result will be sent back to the client
|
The result will be sent back to the client
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cmd = str
|
cmd = str
|
||||||
result_text = str
|
result_text = str
|
||||||
|
|
||||||
@@ -144,10 +145,7 @@ class TranscriptionOutput:
|
|||||||
Return the result dict for displaying the transcription
|
Return the result dict for displaying the transcription
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
return {
|
return {"cmd": self.cmd, "text": self.result_text}
|
||||||
"cmd": self.cmd,
|
|
||||||
"text": self.result_text
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -156,6 +154,7 @@ class FinalSummaryResult:
|
|||||||
Dataclass to define the result of the final summary function.
|
Dataclass to define the result of the final summary function.
|
||||||
The result will be sent back to the client.
|
The result will be sent back to the client.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cmd = str
|
cmd = str
|
||||||
final_summary = str
|
final_summary = str
|
||||||
duration = str
|
duration = str
|
||||||
@@ -173,7 +172,7 @@ class FinalSummaryResult:
|
|||||||
return {
|
return {
|
||||||
"cmd": self.cmd,
|
"cmd": self.cmd,
|
||||||
"duration": self.duration,
|
"duration": self.duration,
|
||||||
"summary": self.final_summary
|
"summary": self.final_summary,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -182,9 +181,14 @@ class BlackListedMessages:
|
|||||||
Class to hold the blacklisted messages. These messages should be filtered
|
Class to hold the blacklisted messages. These messages should be filtered
|
||||||
out and not sent back to the client as part of the transcription.
|
out and not sent back to the client as part of the transcription.
|
||||||
"""
|
"""
|
||||||
messages = [" Thank you.", " See you next time!",
|
|
||||||
" Thank you for watching!", " Bye!",
|
messages = [
|
||||||
" And that's what I'm talking about."]
|
" Thank you.",
|
||||||
|
" See you next time!",
|
||||||
|
" Thank you for watching!",
|
||||||
|
" Bye!",
|
||||||
|
" And that's what I'm talking about.",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -17,8 +17,15 @@ from aiortc.contrib.media import MediaRelay
|
|||||||
from faster_whisper import WhisperModel
|
from faster_whisper import WhisperModel
|
||||||
|
|
||||||
from reflector_dataclasses import (
|
from reflector_dataclasses import (
|
||||||
BlackListedMessages, FinalSummaryResult, ParseLLMResult, TitleSummaryInput,
|
BlackListedMessages,
|
||||||
TitleSummaryOutput, TranscriptionInput, TranscriptionOutput, TranscriptionContext)
|
FinalSummaryResult,
|
||||||
|
ParseLLMResult,
|
||||||
|
TitleSummaryInput,
|
||||||
|
TitleSummaryOutput,
|
||||||
|
TranscriptionInput,
|
||||||
|
TranscriptionOutput,
|
||||||
|
TranscriptionContext,
|
||||||
|
)
|
||||||
from utils.log_utils import LOGGER
|
from utils.log_utils import LOGGER
|
||||||
from utils.run_utils import CONFIG, run_in_executor, SECRETS
|
from utils.run_utils import CONFIG, run_in_executor, SECRETS
|
||||||
|
|
||||||
@@ -28,9 +35,7 @@ relay = MediaRelay()
|
|||||||
executor = ThreadPoolExecutor()
|
executor = ThreadPoolExecutor()
|
||||||
|
|
||||||
# Transcription model
|
# Transcription model
|
||||||
model = WhisperModel("tiny", device="cpu",
|
model = WhisperModel("tiny", device="cpu", compute_type="float32", num_workers=12)
|
||||||
compute_type="float32",
|
|
||||||
num_workers=12)
|
|
||||||
|
|
||||||
# Audio configurations
|
# Audio configurations
|
||||||
CHANNELS = int(CONFIG["AUDIO"]["CHANNELS"])
|
CHANNELS = int(CONFIG["AUDIO"]["CHANNELS"])
|
||||||
@@ -46,7 +51,10 @@ else:
|
|||||||
LLM_MACHINE_PORT = CONFIG["LLM"]["LLM_MACHINE_PORT"]
|
LLM_MACHINE_PORT = CONFIG["LLM"]["LLM_MACHINE_PORT"]
|
||||||
LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate"
|
LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate"
|
||||||
|
|
||||||
def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Union[None, ParseLLMResult]:
|
|
||||||
|
def parse_llm_output(
|
||||||
|
param: TitleSummaryInput, response: requests.Response
|
||||||
|
) -> Union[None, ParseLLMResult]:
|
||||||
"""
|
"""
|
||||||
Function to parse the LLM response
|
Function to parse the LLM response
|
||||||
:param param:
|
:param param:
|
||||||
@@ -61,7 +69,9 @@ def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> U
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_title_and_summary(ctx: TranscriptionContext, param: TitleSummaryInput) -> Union[None, TitleSummaryOutput]:
|
def get_title_and_summary(
|
||||||
|
ctx: TranscriptionContext, param: TitleSummaryInput
|
||||||
|
) -> Union[None, TitleSummaryOutput]:
|
||||||
"""
|
"""
|
||||||
From the input provided (transcript), query the LLM to generate
|
From the input provided (transcript), query the LLM to generate
|
||||||
topics and summaries
|
topics and summaries
|
||||||
@@ -72,9 +82,7 @@ def get_title_and_summary(ctx: TranscriptionContext, param: TitleSummaryInput) -
|
|||||||
|
|
||||||
# TODO : Handle unexpected output formats from the model
|
# TODO : Handle unexpected output formats from the model
|
||||||
try:
|
try:
|
||||||
response = requests.post(LLM_URL,
|
response = requests.post(LLM_URL, headers=param.headers, json=param.data)
|
||||||
headers=param.headers,
|
|
||||||
json=param.data)
|
|
||||||
output = parse_llm_output(param, response)
|
output = parse_llm_output(param, response)
|
||||||
if output:
|
if output:
|
||||||
result = output.get_result()
|
result = output.get_result()
|
||||||
@@ -107,7 +115,9 @@ def channel_send(channel, message: str) -> NoReturn:
|
|||||||
channel.send(message)
|
channel.send(message)
|
||||||
|
|
||||||
|
|
||||||
def channel_send_increment(channel, param: Union[FinalSummaryResult, TitleSummaryOutput]) -> NoReturn:
|
def channel_send_increment(
|
||||||
|
channel, param: Union[FinalSummaryResult, TitleSummaryOutput]
|
||||||
|
) -> NoReturn:
|
||||||
"""
|
"""
|
||||||
Send the incremental topics and summaries via the data channel
|
Send the incremental topics and summaries via the data channel
|
||||||
:param channel:
|
:param channel:
|
||||||
@@ -145,7 +155,9 @@ def channel_send_transcript(ctx: TranscriptionContext) -> NoReturn:
|
|||||||
LOGGER.info("Exception", str(exception))
|
LOGGER.info("Exception", str(exception))
|
||||||
|
|
||||||
|
|
||||||
def get_transcription(ctx: TranscriptionContext, input_frames: TranscriptionInput) -> Union[None, TranscriptionOutput]:
|
def get_transcription(
|
||||||
|
ctx: TranscriptionContext, input_frames: TranscriptionInput
|
||||||
|
) -> Union[None, TranscriptionOutput]:
|
||||||
"""
|
"""
|
||||||
From the collected audio frames create transcription by inferring from
|
From the collected audio frames create transcription by inferring from
|
||||||
the chosen transcription model
|
the chosen transcription model
|
||||||
@@ -173,12 +185,13 @@ def get_transcription(ctx: TranscriptionContext, input_frames: TranscriptionInpu
|
|||||||
result_text = ""
|
result_text = ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
segments, _ = \
|
segments, _ = model.transcribe(
|
||||||
model.transcribe(audio_file,
|
audio_file,
|
||||||
language="en",
|
language="en",
|
||||||
beam_size=5,
|
beam_size=5,
|
||||||
vad_filter=True,
|
vad_filter=True,
|
||||||
vad_parameters={"min_silence_duration_ms": 500})
|
vad_parameters={"min_silence_duration_ms": 500},
|
||||||
|
)
|
||||||
os.remove(audio_file)
|
os.remove(audio_file)
|
||||||
segments = list(segments)
|
segments = list(segments)
|
||||||
result_text = ""
|
result_text = ""
|
||||||
@@ -191,7 +204,7 @@ def get_transcription(ctx: TranscriptionContext, input_frames: TranscriptionInpu
|
|||||||
start_time = 0.0
|
start_time = 0.0
|
||||||
if not segment.end:
|
if not segment.end:
|
||||||
end_time = 5.5
|
end_time = 5.5
|
||||||
duration += (end_time - start_time)
|
duration += end_time - start_time
|
||||||
|
|
||||||
ctx.last_transcribed_time += duration
|
ctx.last_transcribed_time += duration
|
||||||
ctx.transcription_text += result_text
|
ctx.transcription_text += result_text
|
||||||
@@ -218,8 +231,9 @@ def get_final_summary_response(ctx: TranscriptionContext) -> FinalSummaryResult:
|
|||||||
|
|
||||||
response = FinalSummaryResult(final_summary, ctx.last_transcribed_time)
|
response = FinalSummaryResult(final_summary, ctx.last_transcribed_time)
|
||||||
|
|
||||||
with open("./artefacts/meeting_titles_and_summaries.txt", "a",
|
with open(
|
||||||
encoding="utf-8") as file:
|
"./artefacts/meeting_titles_and_summaries.txt", "a", encoding="utf-8"
|
||||||
|
) as file:
|
||||||
file.write(json.dumps(ctx.incremental_responses))
|
file.write(json.dumps(ctx.incremental_responses))
|
||||||
|
|
||||||
return response
|
return response
|
||||||
@@ -243,31 +257,30 @@ class AudioStreamTrack(MediaStreamTrack):
|
|||||||
frame = await self.track.recv()
|
frame = await self.track.recv()
|
||||||
self.audio_buffer.write(frame)
|
self.audio_buffer.write(frame)
|
||||||
|
|
||||||
if local_frames := self.audio_buffer.read_many(AUDIO_BUFFER_SIZE, partial=False):
|
if local_frames := self.audio_buffer.read_many(
|
||||||
|
AUDIO_BUFFER_SIZE, partial=False
|
||||||
|
):
|
||||||
whisper_result = run_in_executor(
|
whisper_result = run_in_executor(
|
||||||
get_transcription,
|
get_transcription,
|
||||||
ctx,
|
ctx,
|
||||||
TranscriptionInput(local_frames),
|
TranscriptionInput(local_frames),
|
||||||
executor=executor
|
executor=executor,
|
||||||
)
|
)
|
||||||
whisper_result.add_done_callback(
|
whisper_result.add_done_callback(
|
||||||
lambda f: channel_send_transcript(ctx)
|
lambda f: channel_send_transcript(ctx) if f.result() else None
|
||||||
if f.result()
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(ctx.transcription_text) > 25:
|
if len(ctx.transcription_text) > 25:
|
||||||
llm_input_text = ctx.transcription_text
|
llm_input_text = ctx.transcription_text
|
||||||
ctx.transcription_text = ""
|
ctx.transcription_text = ""
|
||||||
param = TitleSummaryInput(input_text=llm_input_text,
|
param = TitleSummaryInput(
|
||||||
transcribed_time=ctx.last_transcribed_time)
|
input_text=llm_input_text, transcribed_time=ctx.last_transcribed_time
|
||||||
llm_result = run_in_executor(get_title_and_summary,
|
)
|
||||||
ctx,
|
llm_result = run_in_executor(
|
||||||
param,
|
get_title_and_summary, ctx, param, executor=executor
|
||||||
executor=executor)
|
)
|
||||||
llm_result.add_done_callback(
|
llm_result.add_done_callback(
|
||||||
lambda f: channel_send_increment(ctx.data_channel,
|
lambda f: channel_send_increment(ctx.data_channel, llm_result.result())
|
||||||
llm_result.result())
|
|
||||||
if f.result()
|
if f.result()
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
@@ -330,10 +343,7 @@ async def offer(request: requests.Request) -> web.Response:
|
|||||||
return web.Response(
|
return web.Response(
|
||||||
content_type="application/json",
|
content_type="application/json",
|
||||||
text=json.dumps(
|
text=json.dumps(
|
||||||
{
|
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
|
||||||
"sdp": pc.localDescription.sdp,
|
|
||||||
"type": pc.localDescription.type
|
|
||||||
}
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -351,9 +361,7 @@ async def on_shutdown(application: web.Application) -> NoReturn:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(description="WebRTC based server for Reflector")
|
||||||
description="WebRTC based server for Reflector"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)"
|
"--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)"
|
||||||
)
|
)
|
||||||
@@ -366,9 +374,7 @@ if __name__ == "__main__":
|
|||||||
app,
|
app,
|
||||||
defaults={
|
defaults={
|
||||||
"*": aiohttp_cors.ResourceOptions(
|
"*": aiohttp_cors.ResourceOptions(
|
||||||
allow_credentials=True,
|
allow_credentials=True, expose_headers="*", allow_headers="*"
|
||||||
expose_headers="*",
|
|
||||||
allow_headers="*"
|
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import httpx
|
|||||||
import pyaudio
|
import pyaudio
|
||||||
import requests
|
import requests
|
||||||
import stamina
|
import stamina
|
||||||
from aiortc import (RTCPeerConnection, RTCSessionDescription)
|
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||||
from aiortc.contrib.media import (MediaPlayer, MediaRelay)
|
from aiortc.contrib.media import MediaPlayer, MediaRelay
|
||||||
|
|
||||||
from utils.log_utils import LOGGER
|
from utils.log_utils import LOGGER
|
||||||
from utils.run_utils import CONFIG
|
from utils.run_utils import CONFIG
|
||||||
@@ -15,11 +15,7 @@ from utils.run_utils import CONFIG
|
|||||||
|
|
||||||
class StreamClient:
|
class StreamClient:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self, signaling, url="http://0.0.0.0:1250", play_from=None, ping_pong=False
|
||||||
signaling,
|
|
||||||
url="http://0.0.0.0:1250",
|
|
||||||
play_from=None,
|
|
||||||
ping_pong=False
|
|
||||||
):
|
):
|
||||||
self.signaling = signaling
|
self.signaling = signaling
|
||||||
self.server_url = url
|
self.server_url = url
|
||||||
@@ -35,9 +31,10 @@ class StreamClient:
|
|||||||
self.time_start = None
|
self.time_start = None
|
||||||
self.queue = asyncio.Queue()
|
self.queue = asyncio.Queue()
|
||||||
self.player = MediaPlayer(
|
self.player = MediaPlayer(
|
||||||
':' + str(CONFIG['AUDIO']["AV_FOUNDATION_DEVICE_ID"]),
|
":" + str(CONFIG["AUDIO"]["AV_FOUNDATION_DEVICE_ID"]),
|
||||||
format='avfoundation',
|
format="avfoundation",
|
||||||
options={'channels': '2'})
|
options={"channels": "2"},
|
||||||
|
)
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self.loop.run_until_complete(self.signaling.close())
|
self.loop.run_until_complete(self.signaling.close())
|
||||||
@@ -114,16 +111,12 @@ class StreamClient:
|
|||||||
self.channel_log(channel, "<", message)
|
self.channel_log(channel, "<", message)
|
||||||
|
|
||||||
if isinstance(message, str) and message.startswith("pong"):
|
if isinstance(message, str) and message.startswith("pong"):
|
||||||
elapsed_ms = (self.current_stamp() - int(message[5:])) \
|
elapsed_ms = (self.current_stamp() - int(message[5:])) / 1000
|
||||||
/ 1000
|
|
||||||
print(" RTT %.2f ms" % elapsed_ms)
|
print(" RTT %.2f ms" % elapsed_ms)
|
||||||
|
|
||||||
await pc.setLocalDescription(await pc.createOffer())
|
await pc.setLocalDescription(await pc.createOffer())
|
||||||
|
|
||||||
sdp = {
|
sdp = {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
|
||||||
"sdp": pc.localDescription.sdp,
|
|
||||||
"type": pc.localDescription.type
|
|
||||||
}
|
|
||||||
|
|
||||||
@stamina.retry(on=httpx.HTTPError, attempts=5)
|
@stamina.retry(on=httpx.HTTPError, attempts=5)
|
||||||
def connect_to_server():
|
def connect_to_server():
|
||||||
|
|||||||
@@ -14,9 +14,11 @@ from .run_utils import SECRETS
|
|||||||
|
|
||||||
BUCKET_NAME = SECRETS["AWS-S3"]["BUCKET_NAME"]
|
BUCKET_NAME = SECRETS["AWS-S3"]["BUCKET_NAME"]
|
||||||
|
|
||||||
s3 = boto3.client('s3',
|
s3 = boto3.client(
|
||||||
|
"s3",
|
||||||
aws_access_key_id=SECRETS["AWS-S3"]["AWS_ACCESS_KEY"],
|
aws_access_key_id=SECRETS["AWS-S3"]["AWS_ACCESS_KEY"],
|
||||||
aws_secret_access_key=SECRETS["AWS-S3"]["AWS_SECRET_KEY"])
|
aws_secret_access_key=SECRETS["AWS-S3"]["AWS_SECRET_KEY"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def upload_files(files_to_upload: List[str]) -> NoReturn:
|
def upload_files(files_to_upload: List[str]) -> NoReturn:
|
||||||
@@ -44,7 +46,7 @@ def download_files(files_to_download: List[str]) -> NoReturn:
|
|||||||
try:
|
try:
|
||||||
s3.download_file(BUCKET_NAME, key, key)
|
s3.download_file(BUCKET_NAME, key, key)
|
||||||
except botocore.exceptions.ClientError as exception:
|
except botocore.exceptions.ClientError as exception:
|
||||||
if exception.response['Error']['Code'] == "404":
|
if exception.response["Error"]["Code"] == "404":
|
||||||
print("The object does not exist.")
|
print("The object does not exist.")
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -4,21 +4,16 @@ Utility function to format the artefacts created during Reflector run
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
with open("../artefacts/meeting_titles_and_summaries.txt", "r",
|
with open("../artefacts/meeting_titles_and_summaries.txt", "r", encoding="utf-8") as f:
|
||||||
encoding='utf-8') as f:
|
|
||||||
outputs = f.read()
|
outputs = f.read()
|
||||||
|
|
||||||
outputs = json.loads(outputs)
|
outputs = json.loads(outputs)
|
||||||
|
|
||||||
transcript_file = open("../artefacts/meeting_transcript.txt",
|
transcript_file = open("../artefacts/meeting_transcript.txt", "a", encoding="utf-8")
|
||||||
"a",
|
title_desc_file = open(
|
||||||
encoding='utf-8')
|
"../artefacts/meeting_title_description.txt", "a", encoding="utf-8"
|
||||||
title_desc_file = open("../artefacts/meeting_title_description.txt",
|
)
|
||||||
"a",
|
summary_file = open("../artefacts/meeting_summary.txt", "a", encoding="utf-8")
|
||||||
encoding='utf-8')
|
|
||||||
summary_file = open("../artefacts/meeting_summary.txt",
|
|
||||||
"a",
|
|
||||||
encoding='utf-8')
|
|
||||||
|
|
||||||
for item in outputs["topics"]:
|
for item in outputs["topics"]:
|
||||||
transcript_file.write(item["transcript"])
|
transcript_file.write(item["transcript"])
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ class SingletonLogger:
|
|||||||
Use Singleton design pattern to create a logger object and share it
|
Use Singleton design pattern to create a logger object and share it
|
||||||
across the entire project
|
across the entire project
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__instance = None
|
__instance = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ class ReflectorConfig:
|
|||||||
"""
|
"""
|
||||||
Create a single config object to share across the project
|
Create a single config object to share across the project
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__config = None
|
__config = None
|
||||||
__secrets = None
|
__secrets = None
|
||||||
|
|
||||||
@@ -25,7 +26,7 @@ class ReflectorConfig:
|
|||||||
"""
|
"""
|
||||||
if ReflectorConfig.__config is None:
|
if ReflectorConfig.__config is None:
|
||||||
ReflectorConfig.__config = configparser.ConfigParser()
|
ReflectorConfig.__config = configparser.ConfigParser()
|
||||||
ReflectorConfig.__config.read('utils/config.ini')
|
ReflectorConfig.__config.read("utils/config.ini")
|
||||||
return ReflectorConfig.__config
|
return ReflectorConfig.__config
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -36,7 +37,7 @@ class ReflectorConfig:
|
|||||||
"""
|
"""
|
||||||
if ReflectorConfig.__secrets is None:
|
if ReflectorConfig.__secrets is None:
|
||||||
ReflectorConfig.__secrets = configparser.ConfigParser()
|
ReflectorConfig.__secrets = configparser.ConfigParser()
|
||||||
ReflectorConfig.__secrets.read('utils/secrets.ini')
|
ReflectorConfig.__secrets.read("utils/secrets.ini")
|
||||||
return ReflectorConfig.__secrets
|
return ReflectorConfig.__secrets
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from transformers import BartForConditionalGeneration, BartTokenizer
|
|||||||
from log_utils import LOGGER
|
from log_utils import LOGGER
|
||||||
from run_utils import CONFIG
|
from run_utils import CONFIG
|
||||||
|
|
||||||
nltk.download('punkt', quiet=True)
|
nltk.download("punkt", quiet=True)
|
||||||
|
|
||||||
|
|
||||||
def preprocess_sentence(sentence: str) -> str:
|
def preprocess_sentence(sentence: str) -> str:
|
||||||
@@ -24,11 +24,10 @@ def preprocess_sentence(sentence: str) -> str:
|
|||||||
:param sentence:
|
:param sentence:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
stop_words = set(stopwords.words('english'))
|
stop_words = set(stopwords.words("english"))
|
||||||
tokens = word_tokenize(sentence.lower())
|
tokens = word_tokenize(sentence.lower())
|
||||||
tokens = [token for token in tokens
|
tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
|
||||||
if token.isalnum() and token not in stop_words]
|
return " ".join(tokens)
|
||||||
return ' '.join(tokens)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_similarity(sent1: str, sent2: str) -> float:
|
def compute_similarity(sent1: str, sent2: str) -> float:
|
||||||
@@ -67,14 +66,14 @@ def remove_almost_alike_sentences(sentences: List[str], threshold=0.7) -> List[s
|
|||||||
sentence1 = preprocess_sentence(sentences[i])
|
sentence1 = preprocess_sentence(sentences[i])
|
||||||
sentence2 = preprocess_sentence(sentences[j])
|
sentence2 = preprocess_sentence(sentences[j])
|
||||||
if len(sentence1) != 0 and len(sentence2) != 0:
|
if len(sentence1) != 0 and len(sentence2) != 0:
|
||||||
similarity = compute_similarity(sentence1,
|
similarity = compute_similarity(sentence1, sentence2)
|
||||||
sentence2)
|
|
||||||
|
|
||||||
if similarity >= threshold:
|
if similarity >= threshold:
|
||||||
removed_indices.add(max(i, j))
|
removed_indices.add(max(i, j))
|
||||||
|
|
||||||
filtered_sentences = [sentences[i] for i in range(num_sentences)
|
filtered_sentences = [
|
||||||
if i not in removed_indices]
|
sentences[i] for i in range(num_sentences) if i not in removed_indices
|
||||||
|
]
|
||||||
return filtered_sentences
|
return filtered_sentences
|
||||||
|
|
||||||
|
|
||||||
@@ -90,7 +89,9 @@ def remove_outright_duplicate_sentences_from_chunk(chunk: str) -> List[str]:
|
|||||||
return nonduplicate_sentences
|
return nonduplicate_sentences
|
||||||
|
|
||||||
|
|
||||||
def remove_whisper_repetitive_hallucination(nonduplicate_sentences: List[str]) -> List[str]:
|
def remove_whisper_repetitive_hallucination(
|
||||||
|
nonduplicate_sentences: List[str],
|
||||||
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Remove sentences that are repeated as a result of Whisper
|
Remove sentences that are repeated as a result of Whisper
|
||||||
hallucinations
|
hallucinations
|
||||||
@@ -105,13 +106,16 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences: List[str]) -
|
|||||||
words = nltk.word_tokenize(sent)
|
words = nltk.word_tokenize(sent)
|
||||||
n_gram_filter = 3
|
n_gram_filter = 3
|
||||||
for i in range(len(words)):
|
for i in range(len(words)):
|
||||||
if str(words[i:i + n_gram_filter]) in seen and \
|
if (
|
||||||
seen[str(words[i:i + n_gram_filter])] == \
|
str(words[i : i + n_gram_filter]) in seen
|
||||||
words[i + 1:i + n_gram_filter + 2]:
|
and seen[str(words[i : i + n_gram_filter])]
|
||||||
|
== words[i + 1 : i + n_gram_filter + 2]
|
||||||
|
):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
seen[str(words[i:i + n_gram_filter])] = \
|
seen[str(words[i : i + n_gram_filter])] = words[
|
||||||
words[i + 1:i + n_gram_filter + 2]
|
i + 1 : i + n_gram_filter + 2
|
||||||
|
]
|
||||||
temp_result += words[i]
|
temp_result += words[i]
|
||||||
temp_result += " "
|
temp_result += " "
|
||||||
chunk_sentences.append(temp_result)
|
chunk_sentences.append(temp_result)
|
||||||
@@ -126,12 +130,11 @@ def post_process_transcription(whisper_result: dict) -> dict:
|
|||||||
"""
|
"""
|
||||||
transcript_text = ""
|
transcript_text = ""
|
||||||
for chunk in whisper_result["chunks"]:
|
for chunk in whisper_result["chunks"]:
|
||||||
nonduplicate_sentences = \
|
nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk)
|
||||||
remove_outright_duplicate_sentences_from_chunk(chunk)
|
chunk_sentences = remove_whisper_repetitive_hallucination(
|
||||||
chunk_sentences = \
|
nonduplicate_sentences
|
||||||
remove_whisper_repetitive_hallucination(nonduplicate_sentences)
|
)
|
||||||
similarity_matched_sentences = \
|
similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences)
|
||||||
remove_almost_alike_sentences(chunk_sentences)
|
|
||||||
chunk["text"] = " ".join(similarity_matched_sentences)
|
chunk["text"] = " ".join(similarity_matched_sentences)
|
||||||
transcript_text += chunk["text"]
|
transcript_text += chunk["text"]
|
||||||
whisper_result["text"] = transcript_text
|
whisper_result["text"] = transcript_text
|
||||||
@@ -149,23 +152,24 @@ def summarize_chunks(chunks: List[str], tokenizer, model) -> List[str]:
|
|||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
summaries = []
|
summaries = []
|
||||||
for c in chunks:
|
for c in chunks:
|
||||||
input_ids = tokenizer.encode(c, return_tensors='pt')
|
input_ids = tokenizer.encode(c, return_tensors="pt")
|
||||||
input_ids = input_ids.to(device)
|
input_ids = input_ids.to(device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
summary_ids = \
|
summary_ids = model.generate(
|
||||||
model.generate(input_ids,
|
input_ids,
|
||||||
num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]),
|
num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]),
|
||||||
length_penalty=2.0,
|
length_penalty=2.0,
|
||||||
max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]),
|
max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]),
|
||||||
early_stopping=True)
|
early_stopping=True,
|
||||||
summary = tokenizer.decode(summary_ids[0],
|
)
|
||||||
skip_special_tokens=True)
|
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
||||||
summaries.append(summary)
|
summaries.append(summary)
|
||||||
return summaries
|
return summaries
|
||||||
|
|
||||||
|
|
||||||
def chunk_text(text: str,
|
def chunk_text(
|
||||||
max_chunk_length: int = int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])) -> List[str]:
|
text: str, max_chunk_length: int = int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])
|
||||||
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Split text into smaller chunks.
|
Split text into smaller chunks.
|
||||||
:param text: Text to be chunked
|
:param text: Text to be chunked
|
||||||
@@ -185,9 +189,12 @@ def chunk_text(text: str,
|
|||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
def summarize(transcript_text: str, timestamp: datetime.datetime.timestamp,
|
def summarize(
|
||||||
|
transcript_text: str,
|
||||||
|
timestamp: datetime.datetime.timestamp,
|
||||||
real_time: bool = False,
|
real_time: bool = False,
|
||||||
chunk_summarize: str = CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]):
|
chunk_summarize: str = CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"],
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Summarize the given text either as a whole or as chunks as needed
|
Summarize the given text either as a whole or as chunks as needed
|
||||||
:param transcript_text:
|
:param transcript_text:
|
||||||
@@ -213,39 +220,45 @@ def summarize(transcript_text: str, timestamp: datetime.datetime.timestamp,
|
|||||||
|
|
||||||
if chunk_summarize != "YES":
|
if chunk_summarize != "YES":
|
||||||
max_length = int(CONFIG["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"])
|
max_length = int(CONFIG["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"])
|
||||||
inputs = tokenizer. \
|
inputs = tokenizer.batch_encode_plus(
|
||||||
batch_encode_plus([transcript_text], truncation=True,
|
[transcript_text],
|
||||||
padding='longest',
|
truncation=True,
|
||||||
|
padding="longest",
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
return_tensors='pt')
|
return_tensors="pt",
|
||||||
|
)
|
||||||
inputs = inputs.to(device)
|
inputs = inputs.to(device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
num_beans = int(CONFIG["SUMMARIZER"]["BEAM_SIZE"])
|
num_beans = int(CONFIG["SUMMARIZER"]["BEAM_SIZE"])
|
||||||
max_length = int(CONFIG["SUMMARIZER"]["MAX_LENGTH"])
|
max_length = int(CONFIG["SUMMARIZER"]["MAX_LENGTH"])
|
||||||
summaries = model.generate(inputs['input_ids'],
|
summaries = model.generate(
|
||||||
|
inputs["input_ids"],
|
||||||
num_beams=num_beans,
|
num_beams=num_beans,
|
||||||
length_penalty=2.0,
|
length_penalty=2.0,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
early_stopping=True)
|
early_stopping=True,
|
||||||
|
)
|
||||||
|
|
||||||
decoded_summaries = \
|
decoded_summaries = [
|
||||||
[tokenizer.decode(summary,
|
tokenizer.decode(
|
||||||
skip_special_tokens=True,
|
summary, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||||
clean_up_tokenization_spaces=False)
|
)
|
||||||
for summary in summaries]
|
for summary in summaries
|
||||||
|
]
|
||||||
summary = " ".join(decoded_summaries)
|
summary = " ".join(decoded_summaries)
|
||||||
with open("./artefacts/" + output_file, 'w', encoding="utf-8") as file:
|
with open("./artefacts/" + output_file, "w", encoding="utf-8") as file:
|
||||||
file.write(summary.strip() + "\n")
|
file.write(summary.strip() + "\n")
|
||||||
else:
|
else:
|
||||||
LOGGER.info("Breaking transcript into smaller chunks")
|
LOGGER.info("Breaking transcript into smaller chunks")
|
||||||
chunks = chunk_text(transcript_text)
|
chunks = chunk_text(transcript_text)
|
||||||
|
|
||||||
LOGGER.info(f"Transcript broken into {len(chunks)} "
|
LOGGER.info(
|
||||||
f"chunks of at most 500 words")
|
f"Transcript broken into {len(chunks)} " f"chunks of at most 500 words"
|
||||||
|
)
|
||||||
|
|
||||||
LOGGER.info(f"Writing summary text to: {output_file}")
|
LOGGER.info(f"Writing summary text to: {output_file}")
|
||||||
with open(output_file, 'w') as f:
|
with open(output_file, "w") as f:
|
||||||
summaries = summarize_chunks(chunks, tokenizer, model)
|
summaries = summarize_chunks(chunks, tokenizer, model)
|
||||||
for summary in summaries:
|
for summary in summaries:
|
||||||
f.write(summary.strip() + " ")
|
f.write(summary.strip() + " ")
|
||||||
|
|||||||
@@ -16,23 +16,30 @@ import spacy
|
|||||||
from nltk.corpus import stopwords
|
from nltk.corpus import stopwords
|
||||||
from wordcloud import STOPWORDS, WordCloud
|
from wordcloud import STOPWORDS, WordCloud
|
||||||
|
|
||||||
en = spacy.load('en_core_web_md')
|
en = spacy.load("en_core_web_md")
|
||||||
spacy_stopwords = en.Defaults.stop_words
|
spacy_stopwords = en.Defaults.stop_words
|
||||||
|
|
||||||
STOPWORDS = set(STOPWORDS).union(set(stopwords.words("english"))). \
|
STOPWORDS = (
|
||||||
union(set(spacy_stopwords))
|
set(STOPWORDS).union(set(stopwords.words("english"))).union(set(spacy_stopwords))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_wordcloud(timestamp: datetime.datetime.timestamp,
|
def create_wordcloud(
|
||||||
real_time: bool = False) -> NoReturn:
|
timestamp: datetime.datetime.timestamp, real_time: bool = False
|
||||||
|
) -> NoReturn:
|
||||||
"""
|
"""
|
||||||
Create a basic word cloud visualization of transcribed text
|
Create a basic word cloud visualization of transcribed text
|
||||||
:return: None. The wordcloud image is saved locally
|
:return: None. The wordcloud image is saved locally
|
||||||
"""
|
"""
|
||||||
filename = "transcript"
|
filename = "transcript"
|
||||||
if real_time:
|
if real_time:
|
||||||
filename = "real_time_" + filename + "_" + \
|
filename = (
|
||||||
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
"real_time_"
|
||||||
|
+ filename
|
||||||
|
+ "_"
|
||||||
|
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||||
|
+ ".txt"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
||||||
|
|
||||||
@@ -41,10 +48,13 @@ def create_wordcloud(timestamp: datetime.datetime.timestamp,
|
|||||||
|
|
||||||
# python_mask = np.array(PIL.Image.open("download1.png"))
|
# python_mask = np.array(PIL.Image.open("download1.png"))
|
||||||
|
|
||||||
wordcloud = WordCloud(height=800, width=800,
|
wordcloud = WordCloud(
|
||||||
background_color='white',
|
height=800,
|
||||||
|
width=800,
|
||||||
|
background_color="white",
|
||||||
stopwords=STOPWORDS,
|
stopwords=STOPWORDS,
|
||||||
min_font_size=8).generate(transcription_text)
|
min_font_size=8,
|
||||||
|
).generate(transcription_text)
|
||||||
|
|
||||||
# Plot wordcloud and save image
|
# Plot wordcloud and save image
|
||||||
plt.figure(facecolor=None)
|
plt.figure(facecolor=None)
|
||||||
@@ -54,16 +64,22 @@ def create_wordcloud(timestamp: datetime.datetime.timestamp,
|
|||||||
|
|
||||||
wordcloud = "wordcloud"
|
wordcloud = "wordcloud"
|
||||||
if real_time:
|
if real_time:
|
||||||
wordcloud = "real_time_" + wordcloud + "_" + \
|
wordcloud = (
|
||||||
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
|
"real_time_"
|
||||||
|
+ wordcloud
|
||||||
|
+ "_"
|
||||||
|
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||||
|
+ ".png"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
wordcloud += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
|
wordcloud += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
|
||||||
|
|
||||||
plt.savefig("./artefacts/" + wordcloud)
|
plt.savefig("./artefacts/" + wordcloud)
|
||||||
|
|
||||||
|
|
||||||
def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp,
|
def create_talk_diff_scatter_viz(
|
||||||
real_time: bool = False) -> NoReturn:
|
timestamp: datetime.datetime.timestamp, real_time: bool = False
|
||||||
|
) -> NoReturn:
|
||||||
"""
|
"""
|
||||||
Perform agenda vs transcription diff to see covered topics.
|
Perform agenda vs transcription diff to see covered topics.
|
||||||
Create a scatter plot of words in topics.
|
Create a scatter plot of words in topics.
|
||||||
@@ -71,7 +87,7 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp,
|
|||||||
"""
|
"""
|
||||||
spacy_model = "en_core_web_md"
|
spacy_model = "en_core_web_md"
|
||||||
nlp = spacy.load(spacy_model)
|
nlp = spacy.load(spacy_model)
|
||||||
nlp.add_pipe('sentencizer')
|
nlp.add_pipe("sentencizer")
|
||||||
|
|
||||||
agenda_topics = []
|
agenda_topics = []
|
||||||
agenda = []
|
agenda = []
|
||||||
@@ -84,11 +100,17 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp,
|
|||||||
|
|
||||||
# Load the transcription with timestamp
|
# Load the transcription with timestamp
|
||||||
if real_time:
|
if real_time:
|
||||||
filename = "./artefacts/real_time_transcript_with_timestamp_" + \
|
filename = (
|
||||||
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
"./artefacts/real_time_transcript_with_timestamp_"
|
||||||
|
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||||
|
+ ".txt"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
filename = "./artefacts/transcript_with_timestamp_" + \
|
filename = (
|
||||||
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
"./artefacts/transcript_with_timestamp_"
|
||||||
|
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||||
|
+ ".txt"
|
||||||
|
)
|
||||||
with open(filename) as file:
|
with open(filename) as file:
|
||||||
transcription_timestamp_text = file.read()
|
transcription_timestamp_text = file.read()
|
||||||
|
|
||||||
@@ -128,14 +150,20 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp,
|
|||||||
covered_items[agenda[topic_similarities[i][0]]] = True
|
covered_items[agenda[topic_similarities[i][0]]] = True
|
||||||
# top1 match
|
# top1 match
|
||||||
if i == 0:
|
if i == 0:
|
||||||
ts_to_topic_mapping_top_1[c["timestamp"]] = \
|
ts_to_topic_mapping_top_1[c["timestamp"]] = agenda_topics[
|
||||||
|
topic_similarities[i][0]
|
||||||
|
]
|
||||||
|
topic_to_ts_mapping_top_1[
|
||||||
agenda_topics[topic_similarities[i][0]]
|
agenda_topics[topic_similarities[i][0]]
|
||||||
topic_to_ts_mapping_top_1[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"])
|
].append(c["timestamp"])
|
||||||
# top2 match
|
# top2 match
|
||||||
else:
|
else:
|
||||||
ts_to_topic_mapping_top_2[c["timestamp"]] = \
|
ts_to_topic_mapping_top_2[c["timestamp"]] = agenda_topics[
|
||||||
|
topic_similarities[i][0]
|
||||||
|
]
|
||||||
|
topic_to_ts_mapping_top_2[
|
||||||
agenda_topics[topic_similarities[i][0]]
|
agenda_topics[topic_similarities[i][0]]
|
||||||
topic_to_ts_mapping_top_2[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"])
|
].append(c["timestamp"])
|
||||||
|
|
||||||
def create_new_columns(record: dict) -> dict:
|
def create_new_columns(record: dict) -> dict:
|
||||||
"""
|
"""
|
||||||
@@ -143,10 +171,12 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp,
|
|||||||
:param record:
|
:param record:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
record["ts_to_topic_mapping_top_1"] = \
|
record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[
|
||||||
ts_to_topic_mapping_top_1[record["timestamp"]]
|
record["timestamp"]
|
||||||
record["ts_to_topic_mapping_top_2"] = \
|
]
|
||||||
ts_to_topic_mapping_top_2[record["timestamp"]]
|
record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[
|
||||||
|
record["timestamp"]
|
||||||
|
]
|
||||||
return record
|
return record
|
||||||
|
|
||||||
df = df.apply(create_new_columns, axis=1)
|
df = df.apply(create_new_columns, axis=1)
|
||||||
@@ -167,19 +197,33 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp,
|
|||||||
# Save df, mappings for further experimentation
|
# Save df, mappings for further experimentation
|
||||||
df_name = "df"
|
df_name = "df"
|
||||||
if real_time:
|
if real_time:
|
||||||
df_name = "real_time_" + df_name + "_" + \
|
df_name = (
|
||||||
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
"real_time_"
|
||||||
|
+ df_name
|
||||||
|
+ "_"
|
||||||
|
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||||
|
+ ".pkl"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
||||||
df.to_pickle("./artefacts/" + df_name)
|
df.to_pickle("./artefacts/" + df_name)
|
||||||
|
|
||||||
my_mappings = [ts_to_topic_mapping_top_1, ts_to_topic_mapping_top_2,
|
my_mappings = [
|
||||||
topic_to_ts_mapping_top_1, topic_to_ts_mapping_top_2]
|
ts_to_topic_mapping_top_1,
|
||||||
|
ts_to_topic_mapping_top_2,
|
||||||
|
topic_to_ts_mapping_top_1,
|
||||||
|
topic_to_ts_mapping_top_2,
|
||||||
|
]
|
||||||
|
|
||||||
mappings_name = "mappings"
|
mappings_name = "mappings"
|
||||||
if real_time:
|
if real_time:
|
||||||
mappings_name = "real_time_" + mappings_name + "_" + \
|
mappings_name = (
|
||||||
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
"real_time_"
|
||||||
|
+ mappings_name
|
||||||
|
+ "_"
|
||||||
|
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||||
|
+ ".pkl"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
||||||
pickle.dump(my_mappings, open("./artefacts/" + mappings_name, "wb"))
|
pickle.dump(my_mappings, open("./artefacts/" + mappings_name, "wb"))
|
||||||
@@ -203,23 +247,37 @@ def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp,
|
|||||||
|
|
||||||
# Scatter plot of topics
|
# Scatter plot of topics
|
||||||
df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences))
|
df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences))
|
||||||
corpus = st.CorpusFromParsedDocuments(
|
corpus = (
|
||||||
df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse'
|
st.CorpusFromParsedDocuments(
|
||||||
).build().get_unigram_corpus().compact(st.AssociationCompactor(2000))
|
df, category_col="ts_to_topic_mapping_top_1", parsed_col="parse"
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
.get_unigram_corpus()
|
||||||
|
.compact(st.AssociationCompactor(2000))
|
||||||
|
)
|
||||||
html = st.produce_scattertext_explorer(
|
html = st.produce_scattertext_explorer(
|
||||||
corpus,
|
corpus,
|
||||||
category=cat_1,
|
category=cat_1,
|
||||||
category_name=cat_1_name,
|
category_name=cat_1_name,
|
||||||
not_category_name=cat_2_name,
|
not_category_name=cat_2_name,
|
||||||
minimum_term_frequency=0, pmi_threshold_coefficient=0,
|
minimum_term_frequency=0,
|
||||||
|
pmi_threshold_coefficient=0,
|
||||||
width_in_pixels=1000,
|
width_in_pixels=1000,
|
||||||
transform=st.Scalers.dense_rank
|
transform=st.Scalers.dense_rank,
|
||||||
)
|
)
|
||||||
if real_time:
|
if real_time:
|
||||||
with open('./artefacts/real_time_scatter_' +
|
with open(
|
||||||
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w') as file:
|
"./artefacts/real_time_scatter_"
|
||||||
|
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||||
|
+ ".html",
|
||||||
|
"w",
|
||||||
|
) as file:
|
||||||
file.write(html)
|
file.write(html)
|
||||||
else:
|
else:
|
||||||
with open('./artefacts/scatter_' +
|
with open(
|
||||||
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w') as file:
|
"./artefacts/scatter_"
|
||||||
|
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||||
|
+ ".html",
|
||||||
|
"w",
|
||||||
|
) as file:
|
||||||
file.write(html)
|
file.write(html)
|
||||||
|
|||||||
Reference in New Issue
Block a user