mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 12:49:06 +00:00
server: refactor to reflector module
- replaced loguru to structlog, to get ability of having open tracing later - moved configuration to pydantic-settings - merged both secrets.ini and config.ini to .env (check reflector/settings.py)
This commit is contained in:
75
server/reflector/client.py
Normal file
75
server/reflector/client.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import signal
|
||||
|
||||
from aiortc.contrib.signaling import add_signaling_arguments, create_signaling
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.stream_client import StreamClient
|
||||
from typing import NoReturn
|
||||
|
||||
|
||||
async def main() -> NoReturn:
|
||||
"""
|
||||
Reflector's entry point to the python client for WebRTC streaming if not
|
||||
using the browser based UI-application
|
||||
:return:
|
||||
"""
|
||||
parser = argparse.ArgumentParser(description="Data channels ping/pong")
|
||||
|
||||
parser.add_argument(
|
||||
"--url", type=str, nargs="?", default="http://0.0.0.0:1250/offer"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ping-pong",
|
||||
help="Benchmark data channel with ping pong",
|
||||
type=eval,
|
||||
choices=[True, False],
|
||||
default="False",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--play-from",
|
||||
type=str,
|
||||
default="",
|
||||
)
|
||||
add_signaling_arguments(parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
signaling = create_signaling(args)
|
||||
|
||||
async def shutdown(signal, loop):
|
||||
"""Cleanup tasks tied to the service's shutdown."""
|
||||
logger.info(f"Received exit signal {signal.name}...")
|
||||
logger.info("Closing database connections")
|
||||
logger.info("Nacking outstanding messages")
|
||||
tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
|
||||
|
||||
[task.cancel() for task in tasks]
|
||||
|
||||
logger.info(f"Cancelling {len(tasks)} outstanding tasks")
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
logger.info(f'{"Flushing metrics"}')
|
||||
loop.stop()
|
||||
|
||||
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
|
||||
loop = asyncio.get_event_loop()
|
||||
for s in signals:
|
||||
loop.add_signal_handler(s, lambda s=s: asyncio.create_task(shutdown(s, loop)))
|
||||
|
||||
# Init client
|
||||
sc = StreamClient(
|
||||
signaling=signaling,
|
||||
url=args.url,
|
||||
play_from=args.play_from,
|
||||
ping_pong=args.ping_pong,
|
||||
)
|
||||
await sc.start()
|
||||
async for msg in sc.get_reader():
|
||||
print(msg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
3
server/reflector/logger.py
Normal file
3
server/reflector/logger.py
Normal file
@@ -0,0 +1,3 @@
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
209
server/reflector/models.py
Normal file
209
server/reflector/models.py
Normal file
@@ -0,0 +1,209 @@
|
||||
"""
|
||||
Collection of data classes for streamlining and rigidly structuring
|
||||
the input and output parameters of functions
|
||||
"""
|
||||
|
||||
import datetime
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from sortedcontainers import SortedDict
|
||||
|
||||
import av
|
||||
|
||||
|
||||
@dataclass
|
||||
class TitleSummaryInput:
|
||||
"""
|
||||
Data class for the input to generate title and summaries.
|
||||
The outcome will be used to send query to the LLM for processing.
|
||||
"""
|
||||
|
||||
input_text = str
|
||||
transcribed_time = float
|
||||
prompt = str
|
||||
data = dict
|
||||
|
||||
def __init__(self, transcribed_time, input_text=""):
|
||||
self.input_text = input_text
|
||||
self.transcribed_time = transcribed_time
|
||||
self.prompt = f"""
|
||||
### Human:
|
||||
Create a JSON object as response.The JSON object must have 2 fields:
|
||||
i) title and ii) summary.For the title field,generate a short title
|
||||
for the given text. For the summary field, summarize the given text
|
||||
in three sentences.
|
||||
|
||||
{self.input_text}
|
||||
|
||||
### Assistant:
|
||||
"""
|
||||
self.data = {"prompt": self.prompt}
|
||||
self.headers = {"Content-Type": "application/json"}
|
||||
|
||||
|
||||
@dataclass
|
||||
class IncrementalResult:
|
||||
"""
|
||||
Data class for the result of generating one title and summaries.
|
||||
Defines how a single "topic" looks like.
|
||||
"""
|
||||
|
||||
title = str
|
||||
description = str
|
||||
transcript = str
|
||||
timestamp = str
|
||||
|
||||
def __init__(self, title, desc, transcript, timestamp):
|
||||
self.title = title
|
||||
self.description = desc
|
||||
self.transcript = transcript
|
||||
self.timestamp = timestamp
|
||||
|
||||
|
||||
@dataclass
|
||||
class TitleSummaryOutput:
|
||||
"""
|
||||
Data class for the result of all generated titles and summaries.
|
||||
The result will be sent back to the client
|
||||
"""
|
||||
|
||||
cmd = str
|
||||
topics = List[IncrementalResult]
|
||||
|
||||
def __init__(self, inc_responses):
|
||||
self.topics = inc_responses
|
||||
self.cmd = "UPDATE_TOPICS"
|
||||
|
||||
def get_result(self) -> dict:
|
||||
"""
|
||||
Return the result dict for displaying the transcription
|
||||
:return:
|
||||
"""
|
||||
return {"cmd": self.cmd, "topics": self.topics}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParseLLMResult:
|
||||
"""
|
||||
Data class to parse the result returned by the LLM while generating title
|
||||
and summaries. The result will be sent back to the client.
|
||||
"""
|
||||
|
||||
title = str
|
||||
description = str
|
||||
transcript = str
|
||||
timestamp = str
|
||||
|
||||
def __init__(self, param: TitleSummaryInput, output: dict):
|
||||
self.title = output["title"]
|
||||
self.transcript = param.input_text
|
||||
self.description = output.pop("summary")
|
||||
self.timestamp = str(datetime.timedelta(seconds=round(param.transcribed_time)))
|
||||
|
||||
def get_result(self) -> dict:
|
||||
"""
|
||||
Return the result dict after parsing the response from LLM
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
"title": self.title,
|
||||
"description": self.description,
|
||||
"transcript": self.transcript,
|
||||
"timestamp": self.timestamp,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptionInput:
|
||||
"""
|
||||
Data class to define the input to the transcription function
|
||||
AudioFrames -> input
|
||||
"""
|
||||
|
||||
frames = List[av.audio.frame.AudioFrame]
|
||||
|
||||
def __init__(self, frames):
|
||||
self.frames = frames
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptionOutput:
|
||||
"""
|
||||
Dataclass to define the result of the transcription function.
|
||||
The result will be sent back to the client
|
||||
"""
|
||||
|
||||
cmd = str
|
||||
result_text = str
|
||||
|
||||
def __init__(self, result_text):
|
||||
self.cmd = "SHOW_TRANSCRIPTION"
|
||||
self.result_text = result_text
|
||||
|
||||
def get_result(self) -> dict:
|
||||
"""
|
||||
Return the result dict for displaying the transcription
|
||||
:return:
|
||||
"""
|
||||
return {"cmd": self.cmd, "text": self.result_text}
|
||||
|
||||
|
||||
@dataclass
|
||||
class FinalSummaryResult:
|
||||
"""
|
||||
Dataclass to define the result of the final summary function.
|
||||
The result will be sent back to the client.
|
||||
"""
|
||||
|
||||
cmd = str
|
||||
final_summary = str
|
||||
duration = str
|
||||
|
||||
def __init__(self, final_summary, time):
|
||||
self.duration = str(datetime.timedelta(seconds=round(time)))
|
||||
self.final_summary = final_summary
|
||||
self.cmd = "DISPLAY_FINAL_SUMMARY"
|
||||
|
||||
def get_result(self) -> dict:
|
||||
"""
|
||||
Return the result dict for displaying the final summary
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
"cmd": self.cmd,
|
||||
"duration": self.duration,
|
||||
"summary": self.final_summary,
|
||||
}
|
||||
|
||||
|
||||
class BlackListedMessages:
|
||||
"""
|
||||
Class to hold the blacklisted messages. These messages should be filtered
|
||||
out and not sent back to the client as part of the transcription.
|
||||
"""
|
||||
|
||||
messages = [
|
||||
" Thank you.",
|
||||
" See you next time!",
|
||||
" Thank you for watching!",
|
||||
" Bye!",
|
||||
" And that's what I'm talking about.",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptionContext:
|
||||
transcription_text: str
|
||||
last_transcribed_time: float
|
||||
incremental_responses: List[IncrementalResult]
|
||||
sorted_transcripts: dict
|
||||
data_channel: None # FIXME
|
||||
logger: None
|
||||
|
||||
def __init__(self, logger):
|
||||
self.transcription_text = ""
|
||||
self.last_transcribed_time = 0.0
|
||||
self.incremental_responses = []
|
||||
self.data_channel = None
|
||||
self.sorted_transcripts = SortedDict()
|
||||
self.logger = logger
|
||||
373
server/reflector/server.py
Normal file
373
server/reflector/server.py
Normal file
@@ -0,0 +1,373 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import wave
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import NoReturn, Union
|
||||
|
||||
import aiohttp_cors
|
||||
import av
|
||||
import requests
|
||||
from aiohttp import web
|
||||
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
|
||||
from aiortc.contrib.media import MediaRelay
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
from reflector.models import (
|
||||
BlackListedMessages,
|
||||
FinalSummaryResult,
|
||||
ParseLLMResult,
|
||||
TitleSummaryInput,
|
||||
TitleSummaryOutput,
|
||||
TranscriptionInput,
|
||||
TranscriptionOutput,
|
||||
TranscriptionContext,
|
||||
)
|
||||
from reflector.logger import logger
|
||||
from reflector.utils.run_utils import run_in_executor
|
||||
from reflector.settings import settings
|
||||
|
||||
# WebRTC components
|
||||
pcs = set()
|
||||
relay = MediaRelay()
|
||||
executor = ThreadPoolExecutor()
|
||||
|
||||
# Transcription model
|
||||
model = WhisperModel("tiny", device="cpu", compute_type="float32", num_workers=12)
|
||||
|
||||
# LLM
|
||||
LLM_URL = settings.LLM_URL
|
||||
if not LLM_URL:
|
||||
assert settings.LLM_BACKEND == "oobagooda"
|
||||
LLM_URL = f"http://{settings.LLM_HOST}:{settings.LLM_PORT}/api/v1/generate"
|
||||
logger.info(f"Using LLM [{settings.LLM_BACKEND}]: {LLM_URL}")
|
||||
|
||||
|
||||
def parse_llm_output(
|
||||
param: TitleSummaryInput, response: requests.Response
|
||||
) -> Union[None, ParseLLMResult]:
|
||||
"""
|
||||
Function to parse the LLM response
|
||||
:param param:
|
||||
:param response:
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
output = json.loads(response.json()["results"][0]["text"])
|
||||
return ParseLLMResult(param, output)
|
||||
except Exception:
|
||||
logger.exception("Exception while parsing LLM output")
|
||||
return None
|
||||
|
||||
|
||||
def get_title_and_summary(
|
||||
ctx: TranscriptionContext, param: TitleSummaryInput
|
||||
) -> Union[None, TitleSummaryOutput]:
|
||||
"""
|
||||
From the input provided (transcript), query the LLM to generate
|
||||
topics and summaries
|
||||
:param param:
|
||||
:return:
|
||||
"""
|
||||
logger.info("Generating title and summary")
|
||||
|
||||
# TODO : Handle unexpected output formats from the model
|
||||
try:
|
||||
response = requests.post(LLM_URL, headers=param.headers, json=param.data)
|
||||
output = parse_llm_output(param, response)
|
||||
if output:
|
||||
result = output.get_result()
|
||||
ctx.incremental_responses.append(result)
|
||||
return TitleSummaryOutput(ctx.incremental_responses)
|
||||
except Exception:
|
||||
logger.exception("Exception while generating title and summary")
|
||||
return None
|
||||
|
||||
|
||||
def channel_send(channel, message: str) -> NoReturn:
|
||||
"""
|
||||
Send text messages via the data channel
|
||||
:param channel:
|
||||
:param message:
|
||||
:return:
|
||||
"""
|
||||
if channel:
|
||||
channel.send(message)
|
||||
|
||||
|
||||
def channel_send_increment(
|
||||
channel, param: Union[FinalSummaryResult, TitleSummaryOutput]
|
||||
) -> NoReturn:
|
||||
"""
|
||||
Send the incremental topics and summaries via the data channel
|
||||
:param channel:
|
||||
:param param:
|
||||
:return:
|
||||
"""
|
||||
if channel and param:
|
||||
message = param.get_result()
|
||||
channel.send(json.dumps(message))
|
||||
|
||||
|
||||
def channel_send_transcript(ctx: TranscriptionContext) -> NoReturn:
|
||||
"""
|
||||
Send the transcription result via the data channel
|
||||
:param channel:
|
||||
:return:
|
||||
"""
|
||||
if not ctx.data_channel:
|
||||
return
|
||||
try:
|
||||
least_time = next(iter(ctx.sorted_transcripts))
|
||||
message = ctx.sorted_transcripts[least_time].get_result()
|
||||
if message:
|
||||
del ctx.sorted_transcripts[least_time]
|
||||
if message["text"] not in BlackListedMessages.messages:
|
||||
ctx.data_channel.send(json.dumps(message))
|
||||
# Due to exceptions if one of the earlier batches can't return
|
||||
# a transcript, we don't want to be stuck waiting for the result
|
||||
# With the threshold size of 3, we pop the first(lost) element
|
||||
else:
|
||||
if len(ctx.sorted_transcripts) >= 3:
|
||||
del ctx.sorted_transcripts[least_time]
|
||||
except Exception:
|
||||
logger.exception("Exception while sending transcript")
|
||||
|
||||
|
||||
def get_transcription(
|
||||
ctx: TranscriptionContext, input_frames: TranscriptionInput
|
||||
) -> Union[None, TranscriptionOutput]:
|
||||
"""
|
||||
From the collected audio frames create transcription by inferring from
|
||||
the chosen transcription model
|
||||
:param input_frames:
|
||||
:return:
|
||||
"""
|
||||
ctx.logger.info("Transcribing..")
|
||||
ctx.sorted_transcripts[input_frames.frames[0].time] = None
|
||||
|
||||
# TODO: Find cleaner way, watch "no transcription" issue below
|
||||
# Passing IO objects instead of temporary files throws an error
|
||||
# Passing ndarray (type casted with float) does not give any
|
||||
# transcription. Refer issue,
|
||||
# https://github.com/guillaumekln/faster-whisper/issues/369
|
||||
audio_file = "test" + str(datetime.datetime.now())
|
||||
wf = wave.open(audio_file, "wb")
|
||||
wf.setnchannels(settings.AUDIO_CHANNELS)
|
||||
wf.setframerate(settings.AUDIO_SAMPLING_RATE)
|
||||
wf.setsampwidth(settings.AUDIO_SAMPLING_WIDTH)
|
||||
|
||||
for frame in input_frames.frames:
|
||||
wf.writeframes(b"".join(frame.to_ndarray()))
|
||||
wf.close()
|
||||
|
||||
result_text = ""
|
||||
|
||||
try:
|
||||
segments, _ = model.transcribe(
|
||||
audio_file,
|
||||
language="en",
|
||||
beam_size=5,
|
||||
vad_filter=True,
|
||||
vad_parameters={"min_silence_duration_ms": 500},
|
||||
)
|
||||
os.remove(audio_file)
|
||||
segments = list(segments)
|
||||
result_text = ""
|
||||
duration = 0.0
|
||||
for segment in segments:
|
||||
result_text += segment.text
|
||||
start_time = segment.start
|
||||
end_time = segment.end
|
||||
if not segment.start:
|
||||
start_time = 0.0
|
||||
if not segment.end:
|
||||
end_time = 5.5
|
||||
duration += end_time - start_time
|
||||
|
||||
ctx.last_transcribed_time += duration
|
||||
ctx.transcription_text += result_text
|
||||
|
||||
except Exception:
|
||||
logger.exception("Exception while transcribing")
|
||||
|
||||
result = TranscriptionOutput(result_text)
|
||||
ctx.sorted_transcripts[input_frames.frames[0].time] = result
|
||||
return result
|
||||
|
||||
|
||||
def get_final_summary_response(ctx: TranscriptionContext) -> FinalSummaryResult:
|
||||
"""
|
||||
Collate the incremental summaries generated so far and return as the final
|
||||
summary
|
||||
:return:
|
||||
"""
|
||||
final_summary = ""
|
||||
|
||||
# Collate inc summaries
|
||||
for topic in ctx.incremental_responses:
|
||||
final_summary += topic["description"]
|
||||
|
||||
response = FinalSummaryResult(final_summary, ctx.last_transcribed_time)
|
||||
|
||||
with open(
|
||||
"./artefacts/meeting_titles_and_summaries.txt", "a", encoding="utf-8"
|
||||
) as file:
|
||||
file.write(json.dumps(ctx.incremental_responses))
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class AudioStreamTrack(MediaStreamTrack):
|
||||
"""
|
||||
An audio stream track.
|
||||
"""
|
||||
|
||||
kind = "audio"
|
||||
|
||||
def __init__(self, ctx: TranscriptionContext, track):
|
||||
super().__init__()
|
||||
self.ctx = ctx
|
||||
self.track = track
|
||||
self.audio_buffer = av.AudioFifo()
|
||||
|
||||
async def recv(self) -> av.audio.frame.AudioFrame:
|
||||
ctx = self.ctx
|
||||
frame = await self.track.recv()
|
||||
self.audio_buffer.write(frame)
|
||||
|
||||
if local_frames := self.audio_buffer.read_many(
|
||||
settings.AUDIO_BUFFER_SIZE, partial=False
|
||||
):
|
||||
whisper_result = run_in_executor(
|
||||
get_transcription,
|
||||
ctx,
|
||||
TranscriptionInput(local_frames),
|
||||
executor=executor,
|
||||
)
|
||||
whisper_result.add_done_callback(
|
||||
lambda f: channel_send_transcript(ctx) if f.result() else None
|
||||
)
|
||||
|
||||
if len(ctx.transcription_text) > 25:
|
||||
llm_input_text = ctx.transcription_text
|
||||
ctx.transcription_text = ""
|
||||
param = TitleSummaryInput(
|
||||
input_text=llm_input_text, transcribed_time=ctx.last_transcribed_time
|
||||
)
|
||||
llm_result = run_in_executor(
|
||||
get_title_and_summary, ctx, param, executor=executor
|
||||
)
|
||||
llm_result.add_done_callback(
|
||||
lambda f: channel_send_increment(ctx.data_channel, llm_result.result())
|
||||
if f.result()
|
||||
else None
|
||||
)
|
||||
return frame
|
||||
|
||||
|
||||
async def offer(request: requests.Request) -> web.Response:
|
||||
"""
|
||||
Establish the WebRTC connection with the client
|
||||
:param request:
|
||||
:return:
|
||||
"""
|
||||
params = await request.json()
|
||||
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
|
||||
|
||||
# client identification
|
||||
peername = request.transport.get_extra_info("peername")
|
||||
if peername is not None:
|
||||
clientid = f"{peername[0]}:{peername[1]}"
|
||||
else:
|
||||
clientid = uuid.uuid4()
|
||||
|
||||
# create a context for the whole rtc transaction
|
||||
# add a customised logger to the context
|
||||
ctx = TranscriptionContext(logger=logger.bind(client=clientid))
|
||||
|
||||
# handle RTC peer connection
|
||||
pc = RTCPeerConnection()
|
||||
pcs.add(pc)
|
||||
|
||||
@pc.on("datachannel")
|
||||
def on_datachannel(channel) -> NoReturn:
|
||||
ctx.data_channel = channel
|
||||
ctx.logger = ctx.logger.bind(channel=channel.label)
|
||||
ctx.logger.info("Channel created by remote party")
|
||||
|
||||
@channel.on("message")
|
||||
def on_message(message: str) -> NoReturn:
|
||||
ctx.logger.info(f"Message: {message}")
|
||||
if json.loads(message)["cmd"] == "STOP":
|
||||
# Placeholder final summary
|
||||
response = get_final_summary_response()
|
||||
channel_send_increment(channel, response)
|
||||
# To-do Add code to stop connection from server side here
|
||||
# But have to handshake with client once
|
||||
|
||||
if isinstance(message, str) and message.startswith("ping"):
|
||||
channel_send(channel, "pong" + message[4:])
|
||||
|
||||
@pc.on("connectionstatechange")
|
||||
async def on_connectionstatechange() -> NoReturn:
|
||||
ctx.logger.info(f"Connection state changed: {pc.connectionState}")
|
||||
if pc.connectionState == "failed":
|
||||
await pc.close()
|
||||
pcs.discard(pc)
|
||||
|
||||
@pc.on("track")
|
||||
def on_track(track) -> NoReturn:
|
||||
ctx.logger.info(f"Track {track.kind} received")
|
||||
pc.addTrack(AudioStreamTrack(ctx, relay.subscribe(track)))
|
||||
|
||||
await pc.setRemoteDescription(offer)
|
||||
|
||||
answer = await pc.createAnswer()
|
||||
await pc.setLocalDescription(answer)
|
||||
return web.Response(
|
||||
content_type="application/json",
|
||||
text=json.dumps(
|
||||
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def on_shutdown(application: web.Application) -> NoReturn:
|
||||
"""
|
||||
On shutdown, the coroutines that shutdown client connections are
|
||||
executed
|
||||
:param application:
|
||||
:return:
|
||||
"""
|
||||
coroutines = [pc.close() for pc in pcs]
|
||||
await asyncio.gather(*coroutines)
|
||||
pcs.clear()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="WebRTC based server for Reflector")
|
||||
parser.add_argument(
|
||||
"--host", default="0.0.0.0", help="Server host IP (def: 0.0.0.0)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port", type=int, default=1250, help="Server port (def: 1250)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
app = web.Application()
|
||||
cors = aiohttp_cors.setup(
|
||||
app,
|
||||
defaults={
|
||||
"*": aiohttp_cors.ResourceOptions(
|
||||
allow_credentials=True, expose_headers="*", allow_headers="*"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
offer_resource = cors.add(app.router.add_resource("/offer"))
|
||||
cors.add(offer_resource.add_route("POST", offer))
|
||||
app.on_shutdown.append(on_shutdown)
|
||||
web.run_app(app, access_log=None, host=args.host, port=args.port)
|
||||
45
server/reflector/settings.py
Normal file
45
server/reflector/settings.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||
|
||||
OPENMP_KMP_DUPLICATE_LIB_OK: bool = False
|
||||
|
||||
# Whisper
|
||||
WHISPER_MODEL_SIZE: str = "tiny"
|
||||
whisper_real_time_model_size: str = "tiny"
|
||||
|
||||
# Summarizer
|
||||
SUMMARIZER_MODEL: str = "facebook/bart-large-cnn"
|
||||
SUMMARIZER_INPUT_ENCODING_MAX_LENGTH: int = 1024
|
||||
SUMMARIZER_MAX_LENGTH: int = 2048
|
||||
SUMMARIZER_BEAM_SIZE: int = 6
|
||||
SUMMARIZER_MAX_CHUNK_LENGTH: int = 1024
|
||||
SUMMARIZER_USING_CHUNKS: bool = True
|
||||
|
||||
# Audio
|
||||
AUDIO_BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME: str = "aggregator"
|
||||
AUDIO_AV_FOUNDATION_DEVICE_ID: int = 1
|
||||
AUDIO_CHANNELS: int = 2
|
||||
AUDIO_SAMPLING_RATE: int = 48000
|
||||
AUDIO_SAMPLING_WIDTH: int = 2
|
||||
AUDIO_BUFFER_SIZE: int = 256 * 960
|
||||
|
||||
# LLM
|
||||
LLM_BACKEND: str = "oobagooda"
|
||||
LLM_URL: str | None = None
|
||||
LLM_HOST: str = "localhost"
|
||||
LLM_PORT: int = 7860
|
||||
|
||||
# Storage
|
||||
STORAGE_BACKEND: str = "aws"
|
||||
STORAGE_AWS_ACCESS_KEY: str = ""
|
||||
STORAGE_AWS_SECRET_KEY: str = ""
|
||||
STORAGE_AWS_BUCKET: str = ""
|
||||
|
||||
# OpenAI
|
||||
OPENAI_API_KEY: str = ""
|
||||
|
||||
|
||||
settings = Settings()
|
||||
145
server/reflector/stream_client.py
Normal file
145
server/reflector/stream_client.py
Normal file
@@ -0,0 +1,145 @@
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import httpx
|
||||
import pyaudio
|
||||
import requests
|
||||
import stamina
|
||||
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||||
from aiortc.contrib.media import MediaPlayer, MediaRelay
|
||||
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class StreamClient:
|
||||
def __init__(
|
||||
self, signaling, url="http://0.0.0.0:1250", play_from=None, ping_pong=False
|
||||
):
|
||||
self.signaling = signaling
|
||||
self.server_url = url
|
||||
self.play_from = play_from
|
||||
self.ping_pong = ping_pong
|
||||
self.paudio = pyaudio.PyAudio()
|
||||
|
||||
self.pc = RTCPeerConnection()
|
||||
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.relay = None
|
||||
self.pcs = set()
|
||||
self.time_start = None
|
||||
self.queue = asyncio.Queue()
|
||||
self.player = MediaPlayer(
|
||||
f":{settings.AUDIO_AV_FOUNDATION_DEVICE_ID}",
|
||||
format="avfoundation",
|
||||
options={"channels": "2"},
|
||||
)
|
||||
|
||||
def stop(self):
|
||||
self.loop.run_until_complete(self.signaling.close())
|
||||
self.loop.run_until_complete(self.pc.close())
|
||||
# self.loop.close()
|
||||
|
||||
def create_local_tracks(self, play_from):
|
||||
if play_from:
|
||||
player = MediaPlayer(play_from)
|
||||
return player.audio, player.video
|
||||
else:
|
||||
if self.relay is None:
|
||||
self.relay = MediaRelay()
|
||||
return self.relay.subscribe(self.player.audio), None
|
||||
|
||||
def channel_log(self, channel, t, message):
|
||||
print("channel(%s) %s %s" % (channel.label, t, message))
|
||||
|
||||
def channel_send(self, channel, message):
|
||||
# self.channel_log(channel, ">", message)
|
||||
channel.send(message)
|
||||
|
||||
def current_stamp(self):
|
||||
if self.time_start is None:
|
||||
self.time_start = time.time()
|
||||
return 0
|
||||
else:
|
||||
return int((time.time() - self.time_start) * 1000000)
|
||||
|
||||
async def run_offer(self, pc, signaling):
|
||||
# microphone
|
||||
audio, video = self.create_local_tracks(self.play_from)
|
||||
pc_id = "PeerConnection(%s)" % uuid.uuid4()
|
||||
self.pcs.add(pc)
|
||||
|
||||
def log_info(msg, *args):
|
||||
logger.info(pc_id + " " + msg, *args)
|
||||
|
||||
@pc.on("connectionstatechange")
|
||||
async def on_connectionstatechange():
|
||||
print("Connection state is %s" % pc.connectionState)
|
||||
if pc.connectionState == "failed":
|
||||
await pc.close()
|
||||
self.pcs.discard(pc)
|
||||
|
||||
@pc.on("track")
|
||||
def on_track(track):
|
||||
print("Sending %s" % track.kind)
|
||||
self.pc.addTrack(track)
|
||||
|
||||
@track.on("ended")
|
||||
async def on_ended():
|
||||
log_info("Track %s ended", track.kind)
|
||||
|
||||
self.pc.addTrack(audio)
|
||||
|
||||
channel = pc.createDataChannel("data-channel")
|
||||
self.channel_log(channel, "-", "created by local party")
|
||||
|
||||
async def send_pings():
|
||||
while True:
|
||||
self.channel_send(channel, "ping %d" % self.current_stamp())
|
||||
await asyncio.sleep(1)
|
||||
|
||||
@channel.on("open")
|
||||
def on_open():
|
||||
if self.ping_pong:
|
||||
asyncio.ensure_future(send_pings())
|
||||
|
||||
@channel.on("message")
|
||||
def on_message(message):
|
||||
self.queue.put_nowait(message)
|
||||
if self.ping_pong:
|
||||
self.channel_log(channel, "<", message)
|
||||
|
||||
if isinstance(message, str) and message.startswith("pong"):
|
||||
elapsed_ms = (self.current_stamp() - int(message[5:])) / 1000
|
||||
print(" RTT %.2f ms" % elapsed_ms)
|
||||
|
||||
await pc.setLocalDescription(await pc.createOffer())
|
||||
|
||||
sdp = {"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
|
||||
|
||||
@stamina.retry(on=httpx.HTTPError, attempts=5)
|
||||
def connect_to_server():
|
||||
response = requests.post(self.server_url, json=sdp, timeout=10)
|
||||
response.raise_for_status()
|
||||
return response
|
||||
|
||||
params = connect_to_server().json()
|
||||
answer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
|
||||
await pc.setRemoteDescription(answer)
|
||||
|
||||
self.reader = self.worker(f'{"worker"}', self.queue)
|
||||
|
||||
def get_reader(self):
|
||||
return self.reader
|
||||
|
||||
async def worker(self, name, queue):
|
||||
while True:
|
||||
msg = await self.queue.get()
|
||||
yield msg
|
||||
self.queue.task_done()
|
||||
|
||||
async def start(self):
|
||||
coro = self.run_offer(self.pc, self.signaling)
|
||||
task = asyncio.create_task(coro)
|
||||
await task
|
||||
0
server/reflector/utils/__init__.py
Normal file
0
server/reflector/utils/__init__.py
Normal file
59
server/reflector/utils/file_utils.py
Normal file
59
server/reflector/utils/file_utils.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Utility file for file handling related functions, including file downloads and
|
||||
uploads to cloud storage
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import List, NoReturn
|
||||
|
||||
import boto3
|
||||
import botocore
|
||||
|
||||
from .log_utils import LOGGER
|
||||
from .run_utils import SECRETS
|
||||
|
||||
BUCKET_NAME = SECRETS["AWS-S3"]["BUCKET_NAME"]
|
||||
|
||||
s3 = boto3.client(
|
||||
"s3",
|
||||
aws_access_key_id=SECRETS["AWS-S3"]["AWS_ACCESS_KEY"],
|
||||
aws_secret_access_key=SECRETS["AWS-S3"]["AWS_SECRET_KEY"],
|
||||
)
|
||||
|
||||
|
||||
def upload_files(files_to_upload: List[str]) -> NoReturn:
|
||||
"""
|
||||
Upload a list of files to the configured S3 bucket
|
||||
:param files_to_upload: List of files to upload
|
||||
:return: None
|
||||
"""
|
||||
for key in files_to_upload:
|
||||
LOGGER.info("Uploading file " + key)
|
||||
try:
|
||||
s3.upload_file(key, BUCKET_NAME, key)
|
||||
except botocore.exceptions.ClientError as exception:
|
||||
print(exception.response)
|
||||
|
||||
|
||||
def download_files(files_to_download: List[str]) -> NoReturn:
|
||||
"""
|
||||
Download a list of files from the configured S3 bucket
|
||||
:param files_to_download: List of files to download
|
||||
:return: None
|
||||
"""
|
||||
for key in files_to_download:
|
||||
LOGGER.info("Downloading file " + key)
|
||||
try:
|
||||
s3.download_file(BUCKET_NAME, key, key)
|
||||
except botocore.exceptions.ClientError as exception:
|
||||
if exception.response["Error"]["Code"] == "404":
|
||||
print("The object does not exist.")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if sys.argv[1] == "download":
|
||||
download_files([sys.argv[2]])
|
||||
elif sys.argv[1] == "upload":
|
||||
upload_files([sys.argv[2]])
|
||||
38
server/reflector/utils/format_output.py
Normal file
38
server/reflector/utils/format_output.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Utility function to format the artefacts created during Reflector run
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
with open("../artefacts/meeting_titles_and_summaries.txt", "r", encoding="utf-8") as f:
|
||||
outputs = f.read()
|
||||
|
||||
outputs = json.loads(outputs)
|
||||
|
||||
transcript_file = open("../artefacts/meeting_transcript.txt", "a", encoding="utf-8")
|
||||
title_desc_file = open(
|
||||
"../artefacts/meeting_title_description.txt", "a", encoding="utf-8"
|
||||
)
|
||||
summary_file = open("../artefacts/meeting_summary.txt", "a", encoding="utf-8")
|
||||
|
||||
for item in outputs["topics"]:
|
||||
transcript_file.write(item["transcript"])
|
||||
summary_file.write(item["description"])
|
||||
|
||||
title_desc_file.write("TITLE: \n")
|
||||
title_desc_file.write(item["title"])
|
||||
title_desc_file.write("\n")
|
||||
|
||||
title_desc_file.write("DESCRIPTION: \n")
|
||||
title_desc_file.write(item["description"])
|
||||
title_desc_file.write("\n")
|
||||
|
||||
title_desc_file.write("TRANSCRIPT: \n")
|
||||
title_desc_file.write(item["transcript"])
|
||||
title_desc_file.write("\n")
|
||||
|
||||
title_desc_file.write("---------------------------------------- \n\n")
|
||||
|
||||
transcript_file.close()
|
||||
title_desc_file.close()
|
||||
summary_file.close()
|
||||
55
server/reflector/utils/run_utils.py
Normal file
55
server/reflector/utils/run_utils.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Utility file for server side asynchronous task running and config objects
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from functools import partial
|
||||
from threading import Lock
|
||||
from typing import ContextManager, Generic, TypeVar
|
||||
|
||||
|
||||
def run_in_executor(func, *args, executor=None, **kwargs):
|
||||
"""
|
||||
Run the function in an executor, unblocking the main loop
|
||||
:param func: Function to be run in executor
|
||||
:param args: function parameters
|
||||
:param executor: executor instance [Thread | Process]
|
||||
:param kwargs: Additional parameters
|
||||
:return: Future of function result upon completion
|
||||
"""
|
||||
callback = partial(func, *args, **kwargs)
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_in_executor(executor, callback)
|
||||
|
||||
|
||||
# Genetic type template
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Mutex(Generic[T]):
|
||||
"""
|
||||
Mutex class to implement lock/release of a shared
|
||||
protected variable
|
||||
"""
|
||||
|
||||
def __init__(self, value: T):
|
||||
"""
|
||||
Create an instance of Mutex wrapper for the given resource
|
||||
:param value: Shared resources to be thread protected
|
||||
"""
|
||||
self.__value = value
|
||||
self.__lock = Lock()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def lock(self) -> ContextManager[T]:
|
||||
"""
|
||||
Lock the resource with a mutex to be used within a context block
|
||||
The lock is automatically released on context exit
|
||||
:return: Shared resource
|
||||
"""
|
||||
self.__lock.acquire()
|
||||
try:
|
||||
yield self.__value
|
||||
finally:
|
||||
self.__lock.release()
|
||||
264
server/reflector/utils/text_utils.py
Normal file
264
server/reflector/utils/text_utils.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
Utility file for all text processing related functionalities
|
||||
"""
|
||||
import datetime
|
||||
from typing import List
|
||||
|
||||
import nltk
|
||||
import torch
|
||||
from nltk.corpus import stopwords
|
||||
from nltk.tokenize import word_tokenize
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.metrics.pairwise import cosine_similarity
|
||||
from transformers import BartForConditionalGeneration, BartTokenizer
|
||||
|
||||
from log_utils import LOGGER
|
||||
from run_utils import CONFIG
|
||||
|
||||
nltk.download("punkt", quiet=True)
|
||||
|
||||
|
||||
def preprocess_sentence(sentence: str) -> str:
|
||||
"""
|
||||
Filter out undesirable tokens from thr sentence
|
||||
:param sentence:
|
||||
:return:
|
||||
"""
|
||||
stop_words = set(stopwords.words("english"))
|
||||
tokens = word_tokenize(sentence.lower())
|
||||
tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
|
||||
return " ".join(tokens)
|
||||
|
||||
|
||||
def compute_similarity(sent1: str, sent2: str) -> float:
|
||||
"""
|
||||
Compute the similarity
|
||||
"""
|
||||
tfidf_vectorizer = TfidfVectorizer()
|
||||
if sent1 is not None and sent2 is not None:
|
||||
tfidf_matrix = tfidf_vectorizer.fit_transform([sent1, sent2])
|
||||
return cosine_similarity(tfidf_matrix[0], tfidf_matrix[1])[0][0]
|
||||
return 0.0
|
||||
|
||||
|
||||
def remove_almost_alike_sentences(sentences: List[str], threshold=0.7) -> List[str]:
|
||||
"""
|
||||
Filter sentences that are similar beyond a set threshold
|
||||
:param sentences:
|
||||
:param threshold:
|
||||
:return:
|
||||
"""
|
||||
num_sentences = len(sentences)
|
||||
removed_indices = set()
|
||||
|
||||
for i in range(num_sentences):
|
||||
if i not in removed_indices:
|
||||
for j in range(i + 1, num_sentences):
|
||||
if j not in removed_indices:
|
||||
l_i = len(sentences[i])
|
||||
l_j = len(sentences[j])
|
||||
if l_i == 0 or l_j == 0:
|
||||
if l_i == 0:
|
||||
removed_indices.add(i)
|
||||
if l_j == 0:
|
||||
removed_indices.add(j)
|
||||
else:
|
||||
sentence1 = preprocess_sentence(sentences[i])
|
||||
sentence2 = preprocess_sentence(sentences[j])
|
||||
if len(sentence1) != 0 and len(sentence2) != 0:
|
||||
similarity = compute_similarity(sentence1, sentence2)
|
||||
|
||||
if similarity >= threshold:
|
||||
removed_indices.add(max(i, j))
|
||||
|
||||
filtered_sentences = [
|
||||
sentences[i] for i in range(num_sentences) if i not in removed_indices
|
||||
]
|
||||
return filtered_sentences
|
||||
|
||||
|
||||
def remove_outright_duplicate_sentences_from_chunk(chunk: str) -> List[str]:
|
||||
"""
|
||||
Remove repetitive sentences
|
||||
:param chunk:
|
||||
:return:
|
||||
"""
|
||||
chunk_text = chunk["text"]
|
||||
sentences = nltk.sent_tokenize(chunk_text)
|
||||
nonduplicate_sentences = list(dict.fromkeys(sentences))
|
||||
return nonduplicate_sentences
|
||||
|
||||
|
||||
def remove_whisper_repetitive_hallucination(
|
||||
nonduplicate_sentences: List[str],
|
||||
) -> List[str]:
|
||||
"""
|
||||
Remove sentences that are repeated as a result of Whisper
|
||||
hallucinations
|
||||
:param nonduplicate_sentences:
|
||||
:return:
|
||||
"""
|
||||
chunk_sentences = []
|
||||
|
||||
for sent in nonduplicate_sentences:
|
||||
temp_result = ""
|
||||
seen = {}
|
||||
words = nltk.word_tokenize(sent)
|
||||
n_gram_filter = 3
|
||||
for i in range(len(words)):
|
||||
if (
|
||||
str(words[i : i + n_gram_filter]) in seen
|
||||
and seen[str(words[i : i + n_gram_filter])]
|
||||
== words[i + 1 : i + n_gram_filter + 2]
|
||||
):
|
||||
pass
|
||||
else:
|
||||
seen[str(words[i : i + n_gram_filter])] = words[
|
||||
i + 1 : i + n_gram_filter + 2
|
||||
]
|
||||
temp_result += words[i]
|
||||
temp_result += " "
|
||||
chunk_sentences.append(temp_result)
|
||||
return chunk_sentences
|
||||
|
||||
|
||||
def post_process_transcription(whisper_result: dict) -> dict:
|
||||
"""
|
||||
Parent function to perform post-processing on the transcription result
|
||||
:param whisper_result:
|
||||
:return:
|
||||
"""
|
||||
transcript_text = ""
|
||||
for chunk in whisper_result["chunks"]:
|
||||
nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk)
|
||||
chunk_sentences = remove_whisper_repetitive_hallucination(
|
||||
nonduplicate_sentences
|
||||
)
|
||||
similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences)
|
||||
chunk["text"] = " ".join(similarity_matched_sentences)
|
||||
transcript_text += chunk["text"]
|
||||
whisper_result["text"] = transcript_text
|
||||
return whisper_result
|
||||
|
||||
|
||||
def summarize_chunks(chunks: List[str], tokenizer, model) -> List[str]:
|
||||
"""
|
||||
Summarize each chunk using a summarizer model
|
||||
:param chunks:
|
||||
:param tokenizer:
|
||||
:param model:
|
||||
:return:
|
||||
"""
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
summaries = []
|
||||
for c in chunks:
|
||||
input_ids = tokenizer.encode(c, return_tensors="pt")
|
||||
input_ids = input_ids.to(device)
|
||||
with torch.no_grad():
|
||||
summary_ids = model.generate(
|
||||
input_ids,
|
||||
num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]),
|
||||
length_penalty=2.0,
|
||||
max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]),
|
||||
early_stopping=True,
|
||||
)
|
||||
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
||||
summaries.append(summary)
|
||||
return summaries
|
||||
|
||||
|
||||
def chunk_text(
|
||||
text: str, max_chunk_length: int = int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])
|
||||
) -> List[str]:
|
||||
"""
|
||||
Split text into smaller chunks.
|
||||
:param text: Text to be chunked
|
||||
:param max_chunk_length: length of chunk
|
||||
:return: chunked texts
|
||||
"""
|
||||
sentences = nltk.sent_tokenize(text)
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
for sentence in sentences:
|
||||
if len(current_chunk) + len(sentence) < max_chunk_length:
|
||||
current_chunk += f" {sentence.strip()}"
|
||||
else:
|
||||
chunks.append(current_chunk.strip())
|
||||
current_chunk = f"{sentence.strip()}"
|
||||
chunks.append(current_chunk.strip())
|
||||
return chunks
|
||||
|
||||
|
||||
def summarize(
|
||||
transcript_text: str,
|
||||
timestamp: datetime.datetime.timestamp,
|
||||
real_time: bool = False,
|
||||
chunk_summarize: str = CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"],
|
||||
):
|
||||
"""
|
||||
Summarize the given text either as a whole or as chunks as needed
|
||||
:param transcript_text:
|
||||
:param timestamp:
|
||||
:param real_time:
|
||||
:param chunk_summarize:
|
||||
:return:
|
||||
"""
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
summary_model = CONFIG["SUMMARIZER"]["SUMMARY_MODEL"]
|
||||
if not summary_model:
|
||||
summary_model = "facebook/bart-large-cnn"
|
||||
|
||||
# Summarize the generated transcript using the BART model
|
||||
LOGGER.info(f"Loading BART model: {summary_model}")
|
||||
tokenizer = BartTokenizer.from_pretrained(summary_model)
|
||||
model = BartForConditionalGeneration.from_pretrained(summary_model)
|
||||
model = model.to(device)
|
||||
|
||||
output_file = "summary_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
||||
if real_time:
|
||||
output_file = "real_time_" + output_file
|
||||
|
||||
if chunk_summarize != "YES":
|
||||
max_length = int(CONFIG["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"])
|
||||
inputs = tokenizer.batch_encode_plus(
|
||||
[transcript_text],
|
||||
truncation=True,
|
||||
padding="longest",
|
||||
max_length=max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
num_beans = int(CONFIG["SUMMARIZER"]["BEAM_SIZE"])
|
||||
max_length = int(CONFIG["SUMMARIZER"]["MAX_LENGTH"])
|
||||
summaries = model.generate(
|
||||
inputs["input_ids"],
|
||||
num_beams=num_beans,
|
||||
length_penalty=2.0,
|
||||
max_length=max_length,
|
||||
early_stopping=True,
|
||||
)
|
||||
|
||||
decoded_summaries = [
|
||||
tokenizer.decode(
|
||||
summary, skip_special_tokens=True, clean_up_tokenization_spaces=False
|
||||
)
|
||||
for summary in summaries
|
||||
]
|
||||
summary = " ".join(decoded_summaries)
|
||||
with open("./artefacts/" + output_file, "w", encoding="utf-8") as file:
|
||||
file.write(summary.strip() + "\n")
|
||||
else:
|
||||
LOGGER.info("Breaking transcript into smaller chunks")
|
||||
chunks = chunk_text(transcript_text)
|
||||
|
||||
LOGGER.info(
|
||||
f"Transcript broken into {len(chunks)} " f"chunks of at most 500 words"
|
||||
)
|
||||
|
||||
LOGGER.info(f"Writing summary text to: {output_file}")
|
||||
with open(output_file, "w") as f:
|
||||
summaries = summarize_chunks(chunks, tokenizer, model)
|
||||
for summary in summaries:
|
||||
f.write(summary.strip() + " ")
|
||||
283
server/reflector/utils/viz_utils.py
Normal file
283
server/reflector/utils/viz_utils.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Utility file for all visualization related functions
|
||||
"""
|
||||
|
||||
import ast
|
||||
import collections
|
||||
import datetime
|
||||
import os
|
||||
import pickle
|
||||
from typing import NoReturn
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import scattertext as st
|
||||
import spacy
|
||||
from nltk.corpus import stopwords
|
||||
from wordcloud import STOPWORDS, WordCloud
|
||||
|
||||
en = spacy.load("en_core_web_md")
|
||||
spacy_stopwords = en.Defaults.stop_words
|
||||
|
||||
STOPWORDS = (
|
||||
set(STOPWORDS).union(set(stopwords.words("english"))).union(set(spacy_stopwords))
|
||||
)
|
||||
|
||||
|
||||
def create_wordcloud(
|
||||
timestamp: datetime.datetime.timestamp, real_time: bool = False
|
||||
) -> NoReturn:
|
||||
"""
|
||||
Create a basic word cloud visualization of transcribed text
|
||||
:return: None. The wordcloud image is saved locally
|
||||
"""
|
||||
filename = "transcript"
|
||||
if real_time:
|
||||
filename = (
|
||||
"real_time_"
|
||||
+ filename
|
||||
+ "_"
|
||||
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||
+ ".txt"
|
||||
)
|
||||
else:
|
||||
filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
||||
|
||||
with open("./artefacts/" + filename, "r") as f:
|
||||
transcription_text = f.read()
|
||||
|
||||
# python_mask = np.array(PIL.Image.open("download1.png"))
|
||||
|
||||
wordcloud = WordCloud(
|
||||
height=800,
|
||||
width=800,
|
||||
background_color="white",
|
||||
stopwords=STOPWORDS,
|
||||
min_font_size=8,
|
||||
).generate(transcription_text)
|
||||
|
||||
# Plot wordcloud and save image
|
||||
plt.figure(facecolor=None)
|
||||
plt.imshow(wordcloud, interpolation="bilinear")
|
||||
plt.axis("off")
|
||||
plt.tight_layout(pad=0)
|
||||
|
||||
wordcloud = "wordcloud"
|
||||
if real_time:
|
||||
wordcloud = (
|
||||
"real_time_"
|
||||
+ wordcloud
|
||||
+ "_"
|
||||
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||
+ ".png"
|
||||
)
|
||||
else:
|
||||
wordcloud += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
|
||||
|
||||
plt.savefig("./artefacts/" + wordcloud)
|
||||
|
||||
|
||||
def create_talk_diff_scatter_viz(
|
||||
timestamp: datetime.datetime.timestamp, real_time: bool = False
|
||||
) -> NoReturn:
|
||||
"""
|
||||
Perform agenda vs transcription diff to see covered topics.
|
||||
Create a scatter plot of words in topics.
|
||||
:return: None. Saved locally.
|
||||
"""
|
||||
spacy_model = "en_core_web_md"
|
||||
nlp = spacy.load(spacy_model)
|
||||
nlp.add_pipe("sentencizer")
|
||||
|
||||
agenda_topics = []
|
||||
agenda = []
|
||||
# Load the agenda
|
||||
with open(os.path.join(os.getcwd(), "agenda-headers.txt"), "r") as f:
|
||||
for line in f.readlines():
|
||||
if line.strip():
|
||||
agenda.append(line.strip())
|
||||
agenda_topics.append(line.split(":")[0])
|
||||
|
||||
# Load the transcription with timestamp
|
||||
if real_time:
|
||||
filename = (
|
||||
"./artefacts/real_time_transcript_with_timestamp_"
|
||||
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||
+ ".txt"
|
||||
)
|
||||
else:
|
||||
filename = (
|
||||
"./artefacts/transcript_with_timestamp_"
|
||||
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||
+ ".txt"
|
||||
)
|
||||
with open(filename) as file:
|
||||
transcription_timestamp_text = file.read()
|
||||
|
||||
res = ast.literal_eval(transcription_timestamp_text)
|
||||
chunks = res["chunks"]
|
||||
|
||||
# create df for processing
|
||||
df = pd.DataFrame.from_dict(res["chunks"])
|
||||
|
||||
covered_items = {}
|
||||
# ts: timestamp
|
||||
# Map each timestamped chunk with top1 and top2 matched agenda
|
||||
ts_to_topic_mapping_top_1 = {}
|
||||
ts_to_topic_mapping_top_2 = {}
|
||||
|
||||
# Also create a mapping of the different timestamps
|
||||
# in which each topic was covered
|
||||
topic_to_ts_mapping_top_1 = collections.defaultdict(list)
|
||||
topic_to_ts_mapping_top_2 = collections.defaultdict(list)
|
||||
|
||||
similarity_threshold = 0.7
|
||||
|
||||
for c in chunks:
|
||||
doc_transcription = nlp(c["text"])
|
||||
topic_similarities = []
|
||||
for item in range(len(agenda)):
|
||||
item_doc = nlp(agenda[item])
|
||||
# if not doc_transcription or not all
|
||||
# (token.has_vector for token in doc_transcription):
|
||||
if not doc_transcription:
|
||||
continue
|
||||
similarity = doc_transcription.similarity(item_doc)
|
||||
topic_similarities.append((item, similarity))
|
||||
topic_similarities.sort(key=lambda x: x[1], reverse=True)
|
||||
for i in range(2):
|
||||
if topic_similarities[i][1] >= similarity_threshold:
|
||||
covered_items[agenda[topic_similarities[i][0]]] = True
|
||||
# top1 match
|
||||
if i == 0:
|
||||
ts_to_topic_mapping_top_1[c["timestamp"]] = agenda_topics[
|
||||
topic_similarities[i][0]
|
||||
]
|
||||
topic_to_ts_mapping_top_1[
|
||||
agenda_topics[topic_similarities[i][0]]
|
||||
].append(c["timestamp"])
|
||||
# top2 match
|
||||
else:
|
||||
ts_to_topic_mapping_top_2[c["timestamp"]] = agenda_topics[
|
||||
topic_similarities[i][0]
|
||||
]
|
||||
topic_to_ts_mapping_top_2[
|
||||
agenda_topics[topic_similarities[i][0]]
|
||||
].append(c["timestamp"])
|
||||
|
||||
def create_new_columns(record: dict) -> dict:
|
||||
"""
|
||||
Accumulate the mapping information into the df
|
||||
:param record:
|
||||
:return:
|
||||
"""
|
||||
record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[
|
||||
record["timestamp"]
|
||||
]
|
||||
record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[
|
||||
record["timestamp"]
|
||||
]
|
||||
return record
|
||||
|
||||
df = df.apply(create_new_columns, axis=1)
|
||||
|
||||
# Count the number of items covered and calculate the percentage
|
||||
num_covered_items = sum(covered_items.values())
|
||||
percentage_covered = num_covered_items / len(agenda) * 100
|
||||
|
||||
# Print the results
|
||||
print("💬 Agenda items covered in the transcription:")
|
||||
for item in agenda:
|
||||
if item in covered_items and covered_items[item]:
|
||||
print("✅ ", item)
|
||||
else:
|
||||
print("❌ ", item)
|
||||
print("📊 Coverage: {:.2f}%".format(percentage_covered))
|
||||
|
||||
# Save df, mappings for further experimentation
|
||||
df_name = "df"
|
||||
if real_time:
|
||||
df_name = (
|
||||
"real_time_"
|
||||
+ df_name
|
||||
+ "_"
|
||||
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||
+ ".pkl"
|
||||
)
|
||||
else:
|
||||
df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
||||
df.to_pickle("./artefacts/" + df_name)
|
||||
|
||||
my_mappings = [
|
||||
ts_to_topic_mapping_top_1,
|
||||
ts_to_topic_mapping_top_2,
|
||||
topic_to_ts_mapping_top_1,
|
||||
topic_to_ts_mapping_top_2,
|
||||
]
|
||||
|
||||
mappings_name = "mappings"
|
||||
if real_time:
|
||||
mappings_name = (
|
||||
"real_time_"
|
||||
+ mappings_name
|
||||
+ "_"
|
||||
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||
+ ".pkl"
|
||||
)
|
||||
else:
|
||||
mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
||||
pickle.dump(my_mappings, open("./artefacts/" + mappings_name, "wb"))
|
||||
|
||||
# to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") )
|
||||
|
||||
# pick the 2 most matched topic to be used for plotting
|
||||
topic_times = collections.defaultdict(int)
|
||||
for key in ts_to_topic_mapping_top_1.keys():
|
||||
if key[0] is None or key[1] is None:
|
||||
continue
|
||||
duration = key[1] - key[0]
|
||||
topic_times[ts_to_topic_mapping_top_1[key]] += duration
|
||||
|
||||
topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
if len(topic_times) > 1:
|
||||
cat_1 = topic_times[0][0]
|
||||
cat_1_name = topic_times[0][0]
|
||||
cat_2_name = topic_times[1][0]
|
||||
|
||||
# Scatter plot of topics
|
||||
df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences))
|
||||
corpus = (
|
||||
st.CorpusFromParsedDocuments(
|
||||
df, category_col="ts_to_topic_mapping_top_1", parsed_col="parse"
|
||||
)
|
||||
.build()
|
||||
.get_unigram_corpus()
|
||||
.compact(st.AssociationCompactor(2000))
|
||||
)
|
||||
html = st.produce_scattertext_explorer(
|
||||
corpus,
|
||||
category=cat_1,
|
||||
category_name=cat_1_name,
|
||||
not_category_name=cat_2_name,
|
||||
minimum_term_frequency=0,
|
||||
pmi_threshold_coefficient=0,
|
||||
width_in_pixels=1000,
|
||||
transform=st.Scalers.dense_rank,
|
||||
)
|
||||
if real_time:
|
||||
with open(
|
||||
"./artefacts/real_time_scatter_"
|
||||
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||
+ ".html",
|
||||
"w",
|
||||
) as file:
|
||||
file.write(html)
|
||||
else:
|
||||
with open(
|
||||
"./artefacts/scatter_"
|
||||
+ timestamp.strftime("%m-%d-%Y_%H:%M:%S")
|
||||
+ ".html",
|
||||
"w",
|
||||
) as file:
|
||||
file.write(html)
|
||||
Reference in New Issue
Block a user