server: refactor to reflector module

- replaced loguru with structlog, to enable open tracing later
- moved configuration to pydantic-settings
- merged secrets.ini and config.ini into .env (see reflector/settings.py)
Mathieu Virbel
2023-07-27 15:29:41 +02:00
parent 094ed696c4
commit 69ba871481
24 changed files with 385 additions and 283 deletions
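
For reference, a minimal .env consumed by the new reflector/settings.py might look like this (key names come from the Settings class in this diff; the values are placeholders, not part of the commit):

    LLM_BACKEND=oobagooda
    LLM_HOST=localhost
    LLM_PORT=7860
    STORAGE_BACKEND=aws
    STORAGE_AWS_ACCESS_KEY=changeme
    STORAGE_AWS_SECRET_KEY=changeme
    STORAGE_AWS_BUCKET=reflector-artefacts
    OPENAI_API_KEY=changeme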


@@ -0,0 +1,75 @@
import argparse
import asyncio
import signal
from aiortc.contrib.signaling import add_signaling_arguments, create_signaling
from reflector.logger import logger
from reflector.stream_client import StreamClient
from typing import NoReturn
async def main() -> NoReturn:
"""
Entry point for Reflector's Python WebRTC streaming client, used when
not going through the browser-based UI application
:return:
"""
parser = argparse.ArgumentParser(description="Data channels ping/pong")
parser.add_argument(
"--url", type=str, nargs="?", default="http://0.0.0.0:1250/offer"
)
parser.add_argument(
"--ping-pong",
help="Benchmark data channel with ping pong",
type=eval,
choices=[True, False],
default="False",
)
parser.add_argument(
"--play-from",
type=str,
default="",
)
add_signaling_arguments(parser)
args = parser.parse_args()
signaling = create_signaling(args)
async def shutdown(signal, loop):
"""Cleanup tasks tied to the service's shutdown."""
logger.info(f"Received exit signal {signal.name}...")
logger.info("Closing database connections")
logger.info("Nacking outstanding messages")
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
    task.cancel()
logger.info(f"Cancelling {len(tasks)} outstanding tasks")
await asyncio.gather(*tasks, return_exceptions=True)
logger.info("Flushing metrics")
loop.stop()
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
loop = asyncio.get_event_loop()
for s in signals:
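# bind s as a default argument so each handler keeps its own signal; a bare lambda would capture the loop variable and always see the last one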
loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown(s, loop)))
# Init client
sc = StreamClient(
signaling=signaling,
url=args.url,
play_from=args.play_from,
ping_pong=args.ping_pong,
)
await sc.start()
async for msg in sc.get_reader():
print(msg)
if __name__ == "__main__":
asyncio.run(main())
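
Assuming this script lands at server/reflector/client.py (the exact path is not shown in this view), it would be launched along these lines:

    python client.py --url http://localhost:1250/offer --ping-pong True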


@@ -0,0 +1,3 @@
import structlog
logger = structlog.get_logger()

server/reflector/models.py

@@ -0,0 +1,209 @@
"""
Collection of data classes for streamlining and rigidly structuring
the input and output parameters of functions
"""
import datetime
from dataclasses import dataclass
from typing import List
from sortedcontainers import SortedDict
import av
@dataclass
class TitleSummaryInput:
"""
Data class for the input to generate title and summaries.
The outcome will be used to send a query to the LLM for processing.
"""
input_text: str
transcribed_time: float
prompt: str
data: dict
def __init__(self, transcribed_time, input_text=""):
self.input_text = input_text
self.transcribed_time = transcribed_time
self.prompt = f"""
### Human:
Create a JSON object as a response. The JSON object must have 2 fields:
i) title and ii) summary. For the title field, generate a short title
for the given text. For the summary field, summarize the given text
in three sentences.
{self.input_text}
### Assistant:
"""
self.data = {"prompt": self.prompt}
self.headers = {"Content-Type": "application/json"}
@dataclass
class IncrementalResult:
"""
Data class for the result of generating one title and summaries.
Defines what a single "topic" looks like.
"""
title: str
description: str
transcript: str
timestamp: str
def __init__(self, title, desc, transcript, timestamp):
self.title = title
self.description = desc
self.transcript = transcript
self.timestamp = timestamp
@dataclass
class TitleSummaryOutput:
"""
Data class for the result of all generated titles and summaries.
The result will be sent back to the client
"""
cmd: str
topics: List[IncrementalResult]
def __init__(self, inc_responses):
self.topics = inc_responses
self.cmd = "UPDATE_TOPICS"
def get_result(self) -> dict:
"""
Return the result dict for displaying the transcription
:return:
"""
return {"cmd": self.cmd, "topics": self.topics}
@dataclass
class ParseLLMResult:
"""
Data class to parse the result returned by the LLM while generating title
and summaries. The result will be sent back to the client.
"""
title: str
description: str
transcript: str
timestamp: str
def __init__(self, param: TitleSummaryInput, output: dict):
self.title = output["title"]
self.transcript = param.input_text
self.description = output.pop("summary")
self.timestamp = str(datetime.timedelta(seconds=round(param.transcribed_time)))
def get_result(self) -> dict:
"""
Return the result dict after parsing the response from LLM
:return:
"""
return {
"title": self.title,
"description": self.description,
"transcript": self.transcript,
"timestamp": self.timestamp,
}
@dataclass
class TranscriptionInput:
"""
Data class to define the input to the transcription function
AudioFrames -> input
"""
frames: List[av.audio.frame.AudioFrame]
def __init__(self, frames):
self.frames = frames
@dataclass
class TranscriptionOutput:
"""
Dataclass to define the result of the transcription function.
The result will be sent back to the client
"""
cmd: str
result_text: str
def __init__(self, result_text):
self.cmd = "SHOW_TRANSCRIPTION"
self.result_text = result_text
def get_result(self) -> dict:
"""
Return the result dict for displaying the transcription
:return:
"""
return {"cmd": self.cmd, "text": self.result_text}
@dataclass
class FinalSummaryResult:
"""
Dataclass to define the result of the final summary function.
The result will be sent back to the client.
"""
cmd: str
final_summary: str
duration: str
def __init__(self, final_summary, time):
self.duration = str(datetime.timedelta(seconds=round(time)))
self.final_summary = final_summary
self.cmd = "DISPLAY_FINAL_SUMMARY"
def get_result(self) -> dict:
"""
Return the result dict for displaying the final summary
:return:
"""
return {
"cmd": self.cmd,
"duration": self.duration,
"summary": self.final_summary,
}
class BlackListedMessages:
"""
Class to hold the blacklisted messages. These messages should be filtered
out and not sent back to the client as part of the transcription.
"""
messages = [
" Thank you.",
" See you next time!",
" Thank you for watching!",
" Bye!",
" And that's what I'm talking about.",
]
@dataclass
class TranscriptionContext:
transcription_text: str
last_transcribed_time: float
incremental_responses: List[IncrementalResult]
sorted_transcripts: dict
data_channel: None # FIXME
logger: None
def __init__(self, logger):
self.transcription_text = ""
self.last_transcribed_time = 0.0
self.incremental_responses = []
self.data_channel = None
self.sorted_transcripts = SortedDict()
self.logger = logger
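
Pieced together, server.py below drives these models roughly as follows (values illustrative, not part of the diff):

    param = TitleSummaryInput(transcribed_time=42.0, input_text="...transcript so far...")
    # param.data is the JSON body POSTed to the LLM
    parsed = ParseLLMResult(param, {"title": "Standup", "summary": "Status updates."})
    update = TitleSummaryOutput([parsed.get_result()]).get_result()
    # -> {"cmd": "UPDATE_TOPICS", "topics": [...]}, sent to the client over the data channel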

server/reflector/server.py

@@ -0,0 +1,373 @@
import argparse
import asyncio
import datetime
import json
import os
import wave
import uuid
from concurrent.futures import ThreadPoolExecutor
from typing import NoReturn, Union
import aiohttp_cors
import av
import requests
from aiohttp import web
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaRelay
from faster_whisper import WhisperModel
from reflector.models import (
BlackListedMessages,
FinalSummaryResult,
ParseLLMResult,
TitleSummaryInput,
TitleSummaryOutput,
TranscriptionInput,
TranscriptionOutput,
TranscriptionContext,
)
from reflector.logger import logger
from reflector.utils.run_utils import run_in_executor
from reflector.settings import settings
# WebRTC components
pcs = set()
relay = MediaRelay()
executor = ThreadPoolExecutor()
# Transcription model
model = WhisperModel("tiny", device="cpu", compute_type="float32", num_workers=12)
# LLM
LLM_URL = settings.LLM_URL
if not LLM_URL:
assert settings.LLM_BACKEND == "oobagooda"
LLM_URL = f"http://{settings.LLM_HOST}:{settings.LLM_PORT}/api/v1/generate"
logger.info(f"Using LLM [{settings.LLM_BACKEND}]: {LLM_URL}")
def parse_llm_output(
param: TitleSummaryInput, response: requests.Response
) -> Union[None, ParseLLMResult]:
"""
Function to parse the LLM response
:param param:
:param response:
:return:
"""
try:
output = json.loads(response.json()["results"][0]["text"])
return ParseLLMResult(param, output)
except Exception:
logger.exception("Exception while parsing LLM output")
return None
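# Illustrative shape of a successful LLM response, implied by the parsing
# above but not pinned down by this commit:
#   {"results": [{"text": "{\"title\": \"...\", \"summary\": \"...\"}"}]}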
def get_title_and_summary(
ctx: TranscriptionContext, param: TitleSummaryInput
) -> Union[None, TitleSummaryOutput]:
"""
From the input provided (transcript), query the LLM to generate
topics and summaries
:param param:
:return:
"""
logger.info("Generating title and summary")
# TODO : Handle unexpected output formats from the model
try:
response = requests.post(LLM_URL, headers=param.headers, json=param.data)
output = parse_llm_output(param, response)
if output:
result = output.get_result()
ctx.incremental_responses.append(result)
return TitleSummaryOutput(ctx.incremental_responses)
except Exception:
logger.exception("Exception while generating title and summary")
return None
def channel_send(channel, message: str) -> NoReturn:
"""
Send text messages via the data channel
:param channel:
:param message:
:return:
"""
if channel:
channel.send(message)
def channel_send_increment(
channel, param: Union[FinalSummaryResult, TitleSummaryOutput]
) -> NoReturn:
"""
Send the incremental topics and summaries via the data channel
:param channel:
:param param:
:return:
"""
if channel and param:
message = param.get_result()
channel.send(json.dumps(message))
def channel_send_transcript(ctx: TranscriptionContext) -> NoReturn:
"""
Send the transcription result via the data channel
:param ctx:
:return:
"""
if not ctx.data_channel:
return
try:
least_time = next(iter(ctx.sorted_transcripts))
message = ctx.sorted_transcripts[least_time].get_result()
if message:
del ctx.sorted_transcripts[least_time]
if message["text"] not in BlackListedMessages.messages:
ctx.data_channel.send(json.dumps(message))
# If an earlier batch raised and never produced a transcript, we don't
# want to be stuck waiting for its result. With a threshold size of 3,
# we pop the first (lost) element
else:
if len(ctx.sorted_transcripts) >= 3:
del ctx.sorted_transcripts[least_time]
except Exception:
logger.exception("Exception while sending transcript")
def get_transcription(
ctx: TranscriptionContext, input_frames: TranscriptionInput
) -> Union[None, TranscriptionOutput]:
"""
Create a transcription from the collected audio frames by running the
chosen transcription model
:param input_frames:
:return:
"""
ctx.logger.info("Transcribing..")
ctx.sorted_transcripts[input_frames.frames[0].time] = None
# TODO: Find cleaner way, watch "no transcription" issue below
# Passing IO objects instead of temporary files throws an error
# Passing ndarray (type casted with float) does not give any
# transcription. Refer issue,
# https://github.com/guillaumekln/faster-whisper/issues/369
audio_file = "test" + str(datetime.datetime.now())
wf = wave.open(audio_file, "wb")
wf.setnchannels(settings.AUDIO_CHANNELS)
wf.setframerate(settings.AUDIO_SAMPLING_RATE)
wf.setsampwidth(settings.AUDIO_SAMPLING_WIDTH)
for frame in input_frames.frames:
wf.writeframes(b"".join(frame.to_ndarray()))
wf.close()
result_text = ""
try:
segments, _ = model.transcribe(
audio_file,
language="en",
beam_size=5,
vad_filter=True,
vad_parameters={"min_silence_duration_ms": 500},
)
os.remove(audio_file)
segments = list(segments)
result_text = ""
duration = 0.0
for segment in segments:
result_text += segment.text
start_time = segment.start
end_time = segment.end
if not segment.start:
start_time = 0.0
if not segment.end:
end_time = 5.5
duration += end_time - start_time
ctx.last_transcribed_time += duration
ctx.transcription_text += result_text
except Exception:
logger.exception("Exception while transcribing")
result = TranscriptionOutput(result_text)
ctx.sorted_transcripts[input_frames.frames[0].time] = result
return result
def get_final_summary_response(ctx: TranscriptionContext) -> FinalSummaryResult:
"""
Collate the incremental summaries generated so far and return as the final
summary
:return:
"""
final_summary = ""
# Collate inc summaries
for topic in ctx.incremental_responses:
final_summary += topic["description"]
response = FinalSummaryResult(final_summary, ctx.last_transcribed_time)
with open(
"./artefacts/meeting_titles_and_summaries.txt", "a", encoding="utf-8"
) as file:
file.write(json.dumps(ctx.incremental_responses))
return response
class AudioStreamTrack(MediaStreamTrack):
"""
An audio stream track.
"""
kind = "audio"
def __init__(self, ctx: TranscriptionContext, track):
super().__init__()
self.ctx = ctx
self.track = track
self.audio_buffer = av.AudioFifo()
async def recv(self) -> av.audio.frame.AudioFrame:
ctx = self.ctx
frame = await self.track.recv()
self.audio_buffer.write(frame)
if local_frames := self.audio_buffer.read_many(
settings.AUDIO_BUFFER_SIZE, partial=False
):
whisper_result = run_in_executor(
get_transcription,
ctx,
TranscriptionInput(local_frames),
executor=executor,
)
whisper_result.add_done_callback(
lambda f: channel_send_transcript(ctx) if f.result() else None
)
if len(ctx.transcription_text) > 25:
llm_input_text = ctx.transcription_text
ctx.transcription_text = ""
param = TitleSummaryInput(
input_text=llm_input_text, transcribed_time=ctx.last_transcribed_time
)
llm_result = run_in_executor(
get_title_and_summary, ctx, param, executor=executor
)
llm_result.add_done_callback(
lambda f: channel_send_increment(ctx.data_channel, f.result())
if f.result()
else None
)
return frame
async def offer(request: web.Request) -> web.Response:
"""
Establish the WebRTC connection with the client
:param request:
:return:
"""
params = await request.json()
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
# client identification
peername = request.transport.get_extra_info("peername")
if peername is not None:
clientid = f"{peername[0]}:{peername[1]}"
else:
clientid = uuid.uuid4()
# create a context for the whole rtc transaction
# add a customised logger to the context
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
# handle RTC peer connection
pc = RTCPeerConnection()
pcs.add(pc)
@pc.on("datachannel")
def on_datachannel(channel) -> NoReturn:
ctx.data_channel = channel
ctx.logger = ctx.logger.bind(channel=channel.label)
ctx.logger.info("Channel created by remote party")
@channel.on("message")
def on_message(message: str) -> NoReturn:
ctx.logger.info(f"Message: {message}")
if json.loads(message)["cmd"] == "STOP":
# Placeholder final summary
response = get_final_summary_response(ctx)
channel_send_increment(channel, response)
# TODO: add code to stop the connection from the server side here,
# but we need to handshake with the client once first
if isinstance(message, str) and message.startswith("ping"):
channel_send(channel, "pong" + message[4:])
@pc.on("connectionstatechange")
async def on_connectionstatechange() -> NoReturn:
ctx.logger.info(f"Connection state changed: {pc.connectionState}")
if pc.connectionState == "failed":
await pc.close()
pcs.discard(pc)
@pc.on("track")
def on_track(track) -> NoReturn:
ctx.logger.info(f"Track {track.kind} received")
pc.addTrack(AudioStreamTrack(ctx, relay.subscribe(track)))
await pc.setRemoteDescription(offer)
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
return web.Response(
content_type="application/json",
text=json.dumps(
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
),
)
async def on_shutdown(application: web.Application) -> NoReturn:
"""
On shutdown, run the coroutines that shut down client connections
:param application:
:return:
"""
coroutines = [pc.close() for pc in pcs]
await asyncio.gather(*coroutines)
pcs.clear()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="WebRTC based server for Reflector")
parser.add_argument(
"--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)"
)
parser.add_argument(
"--port", type=int, default=1250, help="Server port (def: 1250)"
)
args = parser.parse_args()
app = web.Application()
cors = aiohttp_cors.setup(
app,
defaults={
"*": aiohttp_cors.ResourceOptions(
allow_credentials=True, expose_headers="*", allow_headers="*"
)
},
)
offer_resource = cors.add(app.router.add_resource("/offer"))
cors.add(offer_resource.add_route("POST", offer))
app.on_shutdown.append(on_shutdown)
web.run_app(app, access_log=None, host=args.host, port=args.port)


@@ -0,0 +1,45 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
OPENMP_KMP_DUPLICATE_LIB_OK: bool = False
# Whisper
WHISPER_MODEL_SIZE: str = "tiny"
WHISPER_REAL_TIME_MODEL_SIZE: str = "tiny"
# Summarizer
SUMMARIZER_MODEL: str = "facebook/bart-large-cnn"
SUMMARIZER_INPUT_ENCODING_MAX_LENGTH: int = 1024
SUMMARIZER_MAX_LENGTH: int = 2048
SUMMARIZER_BEAM_SIZE: int = 6
SUMMARIZER_MAX_CHUNK_LENGTH: int = 1024
SUMMARIZER_USING_CHUNKS: bool = True
# Audio
AUDIO_BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME: str = "aggregator"
AUDIO_AV_FOUNDATION_DEVICE_ID: int = 1
AUDIO_CHANNELS: int = 2
AUDIO_SAMPLING_RATE: int = 48000
AUDIO_SAMPLING_WIDTH: int = 2
AUDIO_BUFFER_SIZE: int = 256 * 960
# LLM
LLM_BACKEND: str = "oobagooda"
LLM_URL: str | None = None
LLM_HOST: str = "localhost"
LLM_PORT: int = 7860
# Storage
STORAGE_BACKEND: str = "aws"
STORAGE_AWS_ACCESS_KEY: str = ""
STORAGE_AWS_SECRET_KEY: str = ""
STORAGE_AWS_BUCKET: str = ""
# OpenAI
OPENAI_API_KEY: str = ""
settings = Settings()
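
Since this is pydantic-settings, process environment variables take precedence over .env; a quick sketch (illustrative):

    import os
    os.environ["LLM_PORT"] = "8001"  # overrides any LLM_PORT in .env
    from reflector.settings import Settings
    print(Settings().LLM_PORT)  # -> 8001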


@@ -0,0 +1,145 @@
import asyncio
import time
import uuid
import httpx
import pyaudio
import requests
import stamina
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaPlayer, MediaRelay
from reflector.logger import logger
from reflector.settings import settings
class StreamClient:
def __init__(
self, signaling, url="http://0.0.0.0:1250", play_from=None, ping_pong=False
):
self.signaling = signaling
self.server_url = url
self.play_from = play_from
self.ping_pong = ping_pong
self.paudio = pyaudio.PyAudio()
self.pc = RTCPeerConnection()
self.loop = asyncio.get_event_loop()
self.relay = None
self.pcs = set()
self.time_start = None
self.queue = asyncio.Queue()
self.player = MediaPlayer(
f":{settings.AUDIO_AV_FOUNDATION_DEVICE_ID}",
format="avfoundation",
options={"channels": "2"},
)
def stop(self):
self.loop.run_until_complete(self.signaling.close())
self.loop.run_until_complete(self.pc.close())
# self.loop.close()
def create_local_tracks(self, play_from):
if play_from:
player = MediaPlayer(play_from)
return player.audio, player.video
else:
if self.relay is None:
self.relay = MediaRelay()
return self.relay.subscribe(self.player.audio), None
def channel_log(self, channel, t, message):
print("channel(%s) %s %s" % (channel.label, t, message))
def channel_send(self, channel, message):
# self.channel_log(channel, ">", message)
channel.send(message)
def current_stamp(self):
if self.time_start is None:
self.time_start = time.time()
return 0
else:
return int((time.time() - self.time_start) * 1000000)
async def run_offer(self, pc, signaling):
# microphone
audio, video = self.create_local_tracks(self.play_from)
pc_id = "PeerConnection(%s)" % uuid.uuid4()
self.pcs.add(pc)
def log_info(msg, *args):
logger.info(pc_id + " " + msg, *args)
@pc.on("connectionstatechange")
async def on_connectionstatechange():
print("Connection state is %s" % pc.connectionState)
if pc.connectionState == "failed":
await pc.close()
self.pcs.discard(pc)
@pc.on("track")
def on_track(track):
print("Sending %s" % track.kind)
self.pc.addTrack(track)
@track.on("ended")
async def on_ended():
log_info("Track %s ended", track.kind)
self.pc.addTrack(audio)
channel = pc.createDataChannel("data-channel")
self.channel_log(channel, "-", "created by local party")
async def send_pings():
while True:
self.channel_send(channel, "ping %d" % self.current_stamp())
await asyncio.sleep(1)
@channel.on("open")
def on_open():
if self.ping_pong:
asyncio.ensure_future(send_pings())
@channel.on("message")
def on_message(message):
self.queue.put_nowait(message)
if self.ping_pong:
self.channel_log(channel, "<", message)
if isinstance(message, str) and message.startswith("pong"):
elapsed_ms = (self.current_stamp() - int(message[5:])) / 1000
print(" RTT %.2f ms" % elapsed_ms)
await pc.setLocalDescription(await pc.createOffer())
sdp = {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
@stamina.retry(on=httpx.HTTPError, attempts=5)
def connect_to_server():
response = requests.post(self.server_url, json=sdp, timeout=10)
response.raise_for_status()
return response
params = connect_to_server().json()
answer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
await pc.setRemoteDescription(answer)
self.reader = self.worker("worker", self.queue)
def get_reader(self):
return self.reader
async def worker(self, name, queue):
while True:
msg = await self.queue.get()
yield msg
self.queue.task_done()
async def start(self):
coro = self.run_offer(self.pc, self.signaling)
task = asyncio.create_task(coro)
await task


@@ -0,0 +1,59 @@
"""
Utility file for file handling related functions, including file downloads and
uploads to cloud storage
"""
import sys
from typing import List, NoReturn
import boto3
import botocore
from .log_utils import LOGGER
from .run_utils import SECRETS
BUCKET_NAME = SECRETS["AWS-S3"]["BUCKET_NAME"]
s3 = boto3.client(
"s3",
aws_access_key_id=SECRETS["AWS-S3"]["AWS_ACCESS_KEY"],
aws_secret_access_key=SECRETS["AWS-S3"]["AWS_SECRET_KEY"],
)
def upload_files(files_to_upload: List[str]) -> NoReturn:
"""
Upload a list of files to the configured S3 bucket
:param files_to_upload: List of files to upload
:return: None
"""
for key in files_to_upload:
LOGGER.info("Uploading file " + key)
try:
s3.upload_file(key, BUCKET_NAME, key)
except botocore.exceptions.ClientError as exception:
LOGGER.error(exception.response)
def download_files(files_to_download: List[str]) -> NoReturn:
"""
Download a list of files from the configured S3 bucket
:param files_to_download: List of files to download
:return: None
"""
for key in files_to_download:
LOGGER.info("Downloading file " + key)
try:
s3.download_file(BUCKET_NAME, key, key)
except botocore.exceptions.ClientError as exception:
if exception.response["Error"]["Code"] == "404":
print("The object does not exist.")
else:
raise
if __name__ == "__main__":
if sys.argv[1] == "download":
download_files([sys.argv[2]])
elif sys.argv[1] == "upload":
upload_files([sys.argv[2]])


@@ -0,0 +1,38 @@
"""
Utility function to format the artefacts created during a Reflector run
"""
import json
with open("../artefacts/meeting_titles_and_summaries.txt", "r", encoding="utf-8") as f:
outputs = f.read()
outputs = json.loads(outputs)
transcript_file = open("../artefacts/meeting_transcript.txt", "a", encoding="utf-8")
title_desc_file = open(
"../artefacts/meeting_title_description.txt", "a", encoding="utf-8"
)
summary_file = open("../artefacts/meeting_summary.txt", "a", encoding="utf-8")
for item in outputs["topics"]:
transcript_file.write(item["transcript"])
summary_file.write(item["description"])
title_desc_file.write("TITLE: \n")
title_desc_file.write(item["title"])
title_desc_file.write("\n")
title_desc_file.write("DESCRIPTION: \n")
title_desc_file.write(item["description"])
title_desc_file.write("\n")
title_desc_file.write("TRANSCRIPT: \n")
title_desc_file.write(item["transcript"])
title_desc_file.write("\n")
title_desc_file.write("---------------------------------------- \n\n")
transcript_file.close()
title_desc_file.close()
summary_file.close()


@@ -0,0 +1,55 @@
"""
Utility file for server-side asynchronous task running and shared-resource locking
"""
import asyncio
import contextlib
from functools import partial
from threading import Lock
from typing import ContextManager, Generic, TypeVar
def run_in_executor(func, *args, executor=None, **kwargs):
"""
Run the function in an executor, unblocking the main loop
:param func: Function to be run in executor
:param args: function parameters
:param executor: executor instance [Thread | Process]
:param kwargs: Additional parameters
:return: Future of function result upon completion
"""
callback = partial(func, *args, **kwargs)
loop = asyncio.get_event_loop()
return loop.run_in_executor(executor, callback)
# Generic type variable
T = TypeVar("T")
class Mutex(Generic[T]):
"""
Mutex class to implement lock/release of a shared
protected variable
"""
def __init__(self, value: T):
"""
Create an instance of Mutex wrapper for the given resource
:param value: Shared resources to be thread protected
"""
self.__value = value
self.__lock = Lock()
@contextlib.contextmanager
def lock(self) -> ContextManager[T]:
"""
Lock the resource with a mutex to be used within a context block
The lock is automatically released on context exit
:return: Shared resource
"""
self.__lock.acquire()
try:
yield self.__value
finally:
self.__lock.release()
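
A quick usage sketch for Mutex (illustrative):

    counter = Mutex({"value": 0})
    with counter.lock() as shared:
        shared["value"] += 1  # lock held for this block, released on exit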


@@ -0,0 +1,264 @@
"""
Utility file for all text processing related functionalities
"""
import datetime
from typing import List
import nltk
import torch
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import BartForConditionalGeneration, BartTokenizer
from log_utils import LOGGER
from run_utils import CONFIG
nltk.download("punkt", quiet=True)
def preprocess_sentence(sentence: str) -> str:
"""
Filter out undesirable tokens from the sentence
:param sentence:
:return:
"""
stop_words = set(stopwords.words("english"))
tokens = word_tokenize(sentence.lower())
tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
return " ".join(tokens)
def compute_similarity(sent1: str, sent2: str) -> float:
"""
Compute the TF-IDF cosine similarity between two sentences
"""
tfidf_vectorizer = TfidfVectorizer()
if sent1 is not None and sent2 is not None:
tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2])
return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
return 0.0
def remove_almost_alike_sentences(sentences: List[str], threshold=0.7) -> List[str]:
"""
Filter sentences that are similar beyond a set threshold
:param sentences:
:param threshold:
:return:
"""
num_sentences = len(sentences)
removed_indices = set()
for i in range(num_sentences):
if i not in removed_indices:
for j in range(i + 1, num_sentences):
if j not in removed_indices:
l_i = len(sentences[i])
l_j = len(sentences[j])
if l_i == 0 or l_j == 0:
if l_i == 0:
removed_indices.add(i)
if l_j == 0:
removed_indices.add(j)
else:
sentence1 = preprocess_sentence(sentences[i])
sentence2 = preprocess_sentence(sentences[j])
if len(sentence1) != 0 and len(sentence2) != 0:
similarity = compute_similarity(sentence1, sentence2)
if similarity >= threshold:
removed_indices.add(max(i, j))
filtered_sentences = [
sentences[i] for i in range(num_sentences) if i not in removed_indices
]
return filtered_sentences
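# e.g. remove_almost_alike_sentences(["The cat sat on the mat.",
# "A cat was sitting on the mat.", "A totally unrelated point."]) drops the
# later of the two near-duplicates once their TF-IDF cosine similarity
# reaches the 0.7 threshold (illustrative)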
def remove_outright_duplicate_sentences_from_chunk(chunk: dict) -> List[str]:
"""
Remove repetitive sentences
:param chunk:
:return:
"""
chunk_text = chunk["text"]
sentences = nltk.sent_tokenize(chunk_text)
nonduplicate_sentences = list(dict.fromkeys(sentences))
return nonduplicate_sentences
def remove_whisper_repetitive_hallucination(
nonduplicate_sentences: List[str],
) -> List[str]:
"""
Remove sentences that are repeated as a result of Whisper
hallucinations
:param nonduplicate_sentences:
:return:
"""
chunk_sentences = []
for sent in nonduplicate_sentences:
temp_result = ""
seen = {}
words = nltk.word_tokenize(sent)
n_gram_filter = 3
for i in range(len(words)):
if (
str(words[i : i + n_gram_filter]) in seen
and seen[str(words[i : i + n_gram_filter])]
== words[i + 1 : i + n_gram_filter + 2]
):
pass
else:
seen[str(words[i : i + n_gram_filter])] = words[
i + 1 : i + n_gram_filter + 2
]
temp_result += words[i]
temp_result += " "
chunk_sentences.append(temp_result)
return chunk_sentences
def post_process_transcription(whisper_result: dict) -> dict:
"""
Parent function to perform post-processing on the transcription result
:param whisper_result:
:return:
"""
transcript_text = ""
for chunk in whisper_result["chunks"]:
nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk)
chunk_sentences = remove_whisper_repetitive_hallucination(
nonduplicate_sentences
)
similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences)
chunk["text"] = " ".join(similarity_matched_sentences)
transcript_text += chunk["text"]
whisper_result["text"] = transcript_text
return whisper_result
def summarize_chunks(chunks: List[str], tokenizer, model) -> List[str]:
"""
Summarize each chunk using a summarizer model
:param chunks:
:param tokenizer:
:param model:
:return:
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summaries = []
for c in chunks:
input_ids = tokenizer.encode(c, return_tensors="pt")
input_ids = input_ids.to(device)
with torch.no_grad():
summary_ids = model.generate(
input_ids,
num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]),
length_penalty=2.0,
max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]),
early_stopping=True,
)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
summaries.append(summary)
return summaries
def chunk_text(
text: str, max_chunk_length: int = int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])
) -> List[str]:
"""
Split text into smaller chunks.
:param text: Text to be chunked
:param max_chunk_length: length of chunk
:return: chunked texts
"""
sentences = nltk.sent_tokenize(text)
chunks = []
current_chunk = ""
for sentence in sentences:
if len(current_chunk) + len(sentence) < max_chunk_length:
current_chunk += f" {sentence.strip()}"
else:
chunks.append(current_chunk.strip())
current_chunk = f"{sentence.strip()}"
chunks.append(current_chunk.strip())
return chunks
def summarize(
transcript_text: str,
timestamp: datetime.datetime,
real_time: bool = False,
chunk_summarize: str = CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"],
):
"""
Summarize the given text either as a whole or as chunks as needed
:param transcript_text:
:param timestamp:
:param real_time:
:param chunk_summarize:
:return:
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary_model = CONFIG["SUMMARIZER"]["SUMMARY_MODEL"]
if not summary_model:
summary_model = "facebook/bart-large-cnn"
# Summarize the generated transcript using the BART model
LOGGER.info(f"Loading BART model: {summary_model}")
tokenizer = BartTokenizer.from_pretrained(summary_model)
model = BartForConditionalGeneration.from_pretrained(summary_model)
model = model.to(device)
output_file = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
if real_time:
output_file = "real_time_" + output_file
if chunk_summarize != "YES":
max_length = int(CONFIG["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"])
inputs = tokenizer.batch_encode_plus(
[transcript_text],
truncation=True,
padding="longest",
max_length=max_length,
return_tensors="pt",
)
inputs = inputs.to(device)
with torch.no_grad():
num_beams = int(CONFIG["SUMMARIZER"]["BEAM_SIZE"])
max_length = int(CONFIG["SUMMARIZER"]["MAX_LENGTH"])
summaries = model.generate(
inputs["input_ids"],
num_beams=num_beams,
length_penalty=2.0,
max_length=max_length,
early_stopping=True,
)
decoded_summaries = [
tokenizer.decode(
summary, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
for summary in summaries
]
summary = " ".join(decoded_summaries)
with open("./artefacts/" + output_file, "w", encoding="utf-8") as file:
file.write(summary.strip() + "\n")
else:
LOGGER.info("Breaking transcript into smaller chunks")
chunks = chunk_text(transcript_text)
LOGGER.info(
f"Transcript broken into {len(chunks)} chunks of at most "
f"{int(CONFIG['SUMMARIZER']['MAX_CHUNK_LENGTH'])} characters"
)
LOGGER.info(f"Writing summary text to: {output_file}")
with open("./artefacts/" + output_file, "w", encoding="utf-8") as f:
summaries = summarize_chunks(chunks, tokenizer, model)
for summary in summaries:
f.write(summary.strip() + " ")
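
End to end, the post-processing helpers expect a Whisper-style result dict; a rough sketch (illustrative):

    whisper_result = {"chunks": [{"text": "We shipped it. We shipped it. Next is QA."}]}
    cleaned = post_process_transcription(whisper_result)
    print(cleaned["text"])  # duplicate sentences removed, chunks re-joined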


@@ -0,0 +1,283 @@
"""
Utility file for all visualization related functions
"""
import ast
import collections
import datetime
import os
import pickle
from typing import NoReturn
import matplotlib.pyplot as plt
import pandas as pd
import scattertext as st
import spacy
from nltk.corpus import stopwords
from wordcloud import STOPWORDS, WordCloud
en = spacy.load("en_core_web_md")
spacy_stopwords = en.Defaults.stop_words
STOPWORDS = (
set(STOPWORDS).union(set(stopwords.words("english"))).union(set(spacy_stopwords))
)
def create_wordcloud(
timestamp: datetime.datetime, real_time: bool = False
) -> NoReturn:
"""
Create a basic word cloud visualization of transcribed text
:return: None. The wordcloud image is saved locally
"""
filename = "transcript"
if real_time:
filename = (
"real_time_"
+ filename
+ "_"
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
+ ".txt"
)
else:
filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
with open("./artefacts/" + filename, "r") as f:
transcription_text = f.read()
# python_mask = np.array(PIL.Image.open("download1.png"))
wordcloud = WordCloud(
height=800,
width=800,
background_color="white",
stopwords=STOPWORDS,
min_font_size=8,
).generate(transcription_text)
# Plot wordcloud and save image
plt.figure(facecolor=None)
plt.imshow(wordcloud, interpolation="bilinear")
plt.axis("off")
plt.tight_layout(pad=0)
wordcloud = "wordcloud"
if real_time:
wordcloud = (
"real_time_"
+ wordcloud
+ "_"
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
+ ".png"
)
else:
wordcloud += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
plt.savefig("./artefacts/" + wordcloud)
def create_talk_diff_scatter_viz(
timestamp: datetime.datetime, real_time: bool = False
) -> NoReturn:
"""
Perform agenda vs transcription diff to see covered topics.
Create a scatter plot of words in topics.
:return: None. Saved locally.
"""
spacy_model = "en_core_web_md"
nlp = spacy.load(spacy_model)
nlp.add_pipe("sentencizer")
agenda_topics = []
agenda = []
# Load the agenda
with open(os.path.join(os.getcwd(), "agenda-headers.txt"), "r") as f:
for line in f.readlines():
if line.strip():
agenda.append(line.strip())
agenda_topics.append(line.split(":")[0])
# Load the transcription with timestamp
if real_time:
filename = (
"./artefacts/real_time_transcript_with_timestamp_"
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
+ ".txt"
)
else:
filename = (
"./artefacts/transcript_with_timestamp_"
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
+ ".txt"
)
with open(filename) as file:
transcription_timestamp_text = file.read()
res = ast.literal_eval(transcription_timestamp_text)
chunks = res["chunks"]
# create df for processing
df = pd.DataFrame.from_dict(res["chunks"])
covered_items = {}
# ts: timestamp
# Map each timestamped chunk with top1 and top2 matched agenda
ts_to_topic_mapping_top_1 = {}
ts_to_topic_mapping_top_2 = {}
# Also create a mapping of the different timestamps
# in which each topic was covered
topic_to_ts_mapping_top_1 = collections.defaultdict(list)
topic_to_ts_mapping_top_2 = collections.defaultdict(list)
similarity_threshold = 0.7
for c in chunks:
doc_transcription = nlp(c["text"])
topic_similarities = []
for item in range(len(agenda)):
item_doc = nlp(agenda[item])
# if not doc_transcription or not all
# (token.has_vector for token in doc_transcription):
if not doc_transcription:
continue
similarity = doc_transcription.similarity(item_doc)
topic_similarities.append((item, similarity))
topic_similarities.sort(key=lambda x: x[1], reverse=True)
for i in range(2):
if topic_similarities[i][1] >= similarity_threshold:
covered_items[agenda[topic_similarities[i][0]]] = True
# top1 match
if i == 0:
ts_to_topic_mapping_top_1[c["timestamp"]] = agenda_topics[
topic_similarities[i][0]
]
topic_to_ts_mapping_top_1[
agenda_topics[topic_similarities[i][0]]
].append(c["timestamp"])
# top2 match
else:
ts_to_topic_mapping_top_2[c["timestamp"]] = agenda_topics[
topic_similarities[i][0]
]
topic_to_ts_mapping_top_2[
agenda_topics[topic_similarities[i][0]]
].append(c["timestamp"])
def create_new_columns(record: dict) -> dict:
"""
Accumulate the mapping information into the df
:param record:
:return:
"""
record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[
record["timestamp"]
]
record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[
record["timestamp"]
]
return record
df = df.apply(create_new_columns, axis=1)
# Count the number of items covered and calculate the percentage
num_covered_items = sum(covered_items.values())
percentage_covered = num_covered_items / len(agenda) * 100
# Print the results
print("💬 Agenda items covered in the transcription:")
for item in agenda:
if item in covered_items and covered_items[item]:
print("", item)
else:
print("", item)
print("📊 Coverage: {:.2f}%".format(percentage_covered))
# Save df, mappings for further experimentation
df_name = "df"
if real_time:
df_name = (
"real_time_"
+ df_name
+ "_"
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
+ ".pkl"
)
else:
df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
df.to_pickle("./artefacts/" + df_name)
my_mappings = [
ts_to_topic_mapping_top_1,
ts_to_topic_mapping_top_2,
topic_to_ts_mapping_top_1,
topic_to_ts_mapping_top_2,
]
mappings_name = "mappings"
if real_time:
mappings_name = (
"real_time_"
+ mappings_name
+ "_"
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
+ ".pkl"
)
else:
mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
pickle.dump(my_mappings, open("./artefacts/" + mappings_name, "wb"))
# to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") )
# pick the 2 most matched topics to be used for plotting
topic_times = collections.defaultdict(int)
for key in ts_to_topic_mapping_top_1.keys():
if key[0] is None or key[1] is None:
continue
duration = key[1] - key[0]
topic_times[ts_to_topic_mapping_top_1[key]] += duration
topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True)
if len(topic_times) > 1:
cat_1 = topic_times[0][0]
cat_1_name = topic_times[0][0]
cat_2_name = topic_times[1][0]
# Scatter plot of topics
df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences))
corpus = (
st.CorpusFromParsedDocuments(
df, category_col="ts_to_topic_mapping_top_1", parsed_col="parse"
)
.build()
.get_unigram_corpus()
.compact(st.AssociationCompactor(2000))
)
html = st.produce_scattertext_explorer(
corpus,
category=cat_1,
category_name=cat_1_name,
not_category_name=cat_2_name,
minimum_term_frequency=0,
pmi_threshold_coefficient=0,
width_in_pixels=1000,
transform=st.Scalers.dense_rank,
)
if real_time:
with open(
"./artefacts/real_time_scatter_"
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
+ ".html",
"w",
) as file:
file.write(html)
else:
with open(
"./artefacts/scatter_"
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
+ ".html",
"w",
) as file:
file.write(html)