Merge pull request #42 from Monadical-SAS/feat/gokul

Code clean up and refactoring
This commit is contained in:
projects-g
2023-07-26 12:11:41 +05:30
committed by GitHub
15 changed files with 436 additions and 206 deletions


@@ -5,11 +5,16 @@ import signal
 from aiortc.contrib.signaling import (add_signaling_arguments,
                                       create_signaling)
-from utils.log_utils import logger
+from utils.log_utils import LOGGER
 from stream_client import StreamClient
+from typing import NoReturn


-async def main():
+async def main() -> NoReturn:
+    """
+    Reflector's entry point to the python client for WebRTC streaming if not
+    using the browser based UI-application
+    :return:
+    """
     parser = argparse.ArgumentParser(description="Data channels ping/pong")
     parser.add_argument(
@@ -37,17 +42,17 @@ async def main():
     async def shutdown(signal, loop):
         """Cleanup tasks tied to the service's shutdown."""
-        logger.info(f"Received exit signal {signal.name}...")
-        logger.info("Closing database connections")
-        logger.info("Nacking outstanding messages")
+        LOGGER.info(f"Received exit signal {signal.name}...")
+        LOGGER.info("Closing database connections")
+        LOGGER.info("Nacking outstanding messages")
         tasks = [t for t in asyncio.all_tasks() if t is not
                  asyncio.current_task()]
         [task.cancel() for task in tasks]
-        logger.info(f"Cancelling {len(tasks)} outstanding tasks")
+        LOGGER.info(f"Cancelling {len(tasks)} outstanding tasks")
         await asyncio.gather(*tasks, return_exceptions=True)
-        logger.info(f'{"Flushing metrics"}')
+        LOGGER.info(f'{"Flushing metrics"}')
         loop.stop()

     signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
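Note: the hunk ends before showing how `signals` is wired to the event loop. A hedged sketch of the usual wiring for this kind of shutdown coroutine (the actual registration code falls outside the hunk, so treat this as an assumption):

    import asyncio
    import signal

    def install_signal_handlers(loop, shutdown):
        # Register each POSIX signal so it schedules the shutdown coroutine
        for sig in (signal.SIGHUP, signal.SIGTERM, signal.SIGINT):
            loop.add_signal_handler(
                sig, lambda sig=sig: asyncio.create_task(shutdown(sig, loop)))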

reflector_dataclasses.py (new file)

@@ -0,0 +1,164 @@
"""
Collection of data classes for streamlining and rigidly structuring
the input and output parameters of functions
"""
import datetime
from dataclasses import dataclass
from typing import List
import av
@dataclass
class TitleSummaryInput:
"""
Data class for the input to generate title and summaries.
The outcome will be used to send query to the LLM for processing.
"""
input_text = str
transcribed_time = float
prompt = str
data = dict
def __init__(self, transcribed_time, input_text=""):
self.input_text = input_text
self.transcribed_time = transcribed_time
self.prompt = \
f"""
### Human:
Create a JSON object as response.The JSON object must have 2 fields:
i) title and ii) summary.For the title field,generate a short title
for the given text. For the summary field, summarize the given text
in three sentences.
{self.input_text}
### Assistant:
"""
self.data = {"data": self.prompt}
self.headers = {"Content-Type": "application/json"}
@dataclass
class IncrementalResult:
"""
Data class for the result of generating one title and summaries.
Defines how a single "topic" looks like.
"""
title = str
description = str
transcript = str
def __init__(self, title, desc, transcript):
self.title = title
self.description = desc
self.transcript = transcript
@dataclass
class TitleSummaryOutput:
"""
Data class for the result of all generated titles and summaries.
The result will be sent back to the client
"""
cmd = str
topics = List[IncrementalResult]
def __init__(self, inc_responses):
self.topics = inc_responses
def get_result(self):
return {
"cmd": self.cmd,
"topics": self.topics
}
@dataclass
class ParseLLMResult:
"""
Data class to parse the result returned by the LLM while generating title
and summaries. The result will be sent back to the client.
"""
description = str
transcript = str
timestamp = str
def __init__(self, param: TitleSummaryInput, output: dict):
self.transcript = param.input_text
self.description = output.pop("summary")
self.timestamp = \
str(datetime.timedelta(seconds=round(param.transcribed_time)))
def get_result(self):
return {
"description": self.description,
"transcript": self.transcript,
"timestamp": self.timestamp
}
@dataclass
class TranscriptionInput:
"""
Data class to define the input to the transcription function
AudioFrames -> input
"""
frames = List[av.audio.frame.AudioFrame]
def __init__(self, frames):
self.frames = frames
@dataclass
class TranscriptionOutput:
"""
Dataclass to define the result of the transcription function.
The result will be sent back to the client
"""
cmd = str
result_text = str
def __init__(self, result_text):
self.cmd = "SHOW_TRANSCRIPTION"
self.result_text = result_text
def get_result(self):
return {
"cmd": self.cmd,
"text": self.result_text
}
@dataclass
class FinalSummaryResult:
"""
Dataclass to define the result of the final summary function.
The result will be sent back to the client.
"""
cmd = str
final_summary = str
duration = str
def __init__(self, final_summary, time):
self.duration = str(datetime.timedelta(seconds=round(time)))
self.final_summary = final_summary
self.cmd = ""
def get_result(self):
return {
"cmd": self.cmd,
"duration": self.duration,
"summary": self.final_summary
}
class BlackListedMessages:
"""
Class to hold the blacklisted messages. These messages should be filtered
out and not sent back to the client as part of the transcription.
"""
messages = [" Thank you.", " See you next time!",
" Thank you for watching!", " Bye!",
" And that's what I'm talking about."]

server.py

@@ -6,18 +6,22 @@ import os
 import uuid
 import wave
 from concurrent.futures import ThreadPoolExecutor
+from typing import Union, NoReturn

 import aiohttp_cors
+import av
 import requests
 from aiohttp import web
 from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
 from aiortc.contrib.media import MediaRelay
-from av import AudioFifo
 from faster_whisper import WhisperModel
-from loguru import logger
 from sortedcontainers import SortedDict

-from utils.run_utils import run_in_executor, config
+from reflector_dataclasses import FinalSummaryResult, ParseLLMResult,\
+    TitleSummaryInput, TitleSummaryOutput, TranscriptionInput,\
+    TranscriptionOutput, BlackListedMessages
+from utils.run_utils import CONFIG, run_in_executor
+from utils.log_utils import LOGGER

 pcs = set()
 relay = MediaRelay()
@@ -28,89 +32,68 @@ model = WhisperModel("tiny", device="cpu",
 CHANNELS = 2
 RATE = 48000
-audio_buffer = AudioFifo()
+audio_buffer = av.AudioFifo()
 executor = ThreadPoolExecutor()
 transcription_text = ""
 last_transcribed_time = 0.0
-LLM_MACHINE_IP = config["DEFAULT"]["LLM_MACHINE_IP"]
-LLM_MACHINE_PORT = config["DEFAULT"]["LLM_MACHINE_PORT"]
+LLM_MACHINE_IP = CONFIG["LLM"]["LLM_MACHINE_IP"]
+LLM_MACHINE_PORT = CONFIG["LLM"]["LLM_MACHINE_PORT"]
 LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate"
 incremental_responses = []
 sorted_transcripts = SortedDict()
-blacklisted_messages = [" Thank you.", " See you next time!",
-                        " Thank you for watching!", " Bye!",
-                        " And that's what I'm talking about."]


+def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Union[None, ParseLLMResult]:
+    try:
+        output = json.loads(response.json()["results"][0]["text"])
+        return ParseLLMResult(param, output)
+    except Exception as e:
+        LOGGER.info("Exception" + str(e))
+        return None


-def get_title_and_summary(llm_input_text, last_timestamp):
-    logger.info("Generating title and summary")
-    # output = llm.generate(prompt)
-    # Use monadical-ml to fire this query to an LLM and get result
-    headers = {
-        "Content-Type": "application/json"
-    }
-    prompt = f"""
-    ### Human:
-    Create a JSON object as response.The JSON object must have 2 fields:
-    i) title and ii) summary.For the title field,generate a short title
-    for the given text. For the summary field, summarize the given text
-    in three sentences.
-    {llm_input_text}
-    ### Assistant:
-    """
-    data = {
-        "prompt": prompt
-    }
+def get_title_and_summary(param: TitleSummaryInput) -> Union[None, TitleSummaryOutput]:
+    LOGGER.info("Generating title and summary")
     # TODO : Handle unexpected output formats from the model
     try:
-        response = requests.post(LLM_URL, headers=headers, json=data)
-        output = json.loads(response.json()["results"][0]["text"])
-        output["description"] = output.pop("summary")
-        output["transcript"] = llm_input_text
-        output["timestamp"] = \
-            str(datetime.timedelta(seconds=round(last_timestamp)))
-        incremental_responses.append(output)
-        result = {
-            "cmd": "UPDATE_TOPICS",
-            "topics": incremental_responses,
-        }
+        response = requests.post(LLM_URL,
+                                 headers=param.headers,
+                                 json=param.data)
+        output = parse_llm_output(param, response)
+        if output:
+            result = output.get_result()
+            incremental_responses.append(result)
+        return TitleSummaryOutput(incremental_responses)
     except Exception as e:
-        logger.info("Exception" + str(e))
-        result = None
-    return result
+        LOGGER.info("Exception" + str(e))
+        return None


-def channel_log(channel, t, message):
-    logger.info("channel(%s) %s %s" % (channel.label, t, message))
+def channel_log(channel, t: str, message: str) -> NoReturn:
+    LOGGER.info("channel(%s) %s %s" % (channel.label, t, message))


-def channel_send(channel, message):
+def channel_send(channel, message: str) -> NoReturn:
     if channel:
         channel.send(message)


-def channel_send_increment(channel, message):
-    if channel and message:
+def channel_send_increment(channel, param: Union[FinalSummaryResult, TitleSummaryOutput]) -> NoReturn:
+    if channel and param:
+        message = param.get_result()
         channel.send(json.dumps(message))


-def channel_send_transcript(channel):
+def channel_send_transcript(channel) -> NoReturn:
     # channel_log(channel, ">", message)
     if channel:
         try:
-            least_time = sorted_transcripts.keys()[0]
-            message = sorted_transcripts[least_time]
+            least_time = next(iter(sorted_transcripts))
+            message = sorted_transcripts[least_time].get_result()
             if message:
                 del sorted_transcripts[least_time]
-                if message["text"] not in blacklisted_messages:
+                if message["text"] not in BlackListedMessages.messages:
                     channel.send(json.dumps(message))
             # Due to exceptions if one of the earlier batches can't return
             # a transcript, we don't want to be stuck waiting for the result
@@ -118,27 +101,26 @@ def channel_send_transcript(channel):
             else:
                 if len(sorted_transcripts) >= 3:
                     del sorted_transcripts[least_time]
-        except Exception as e:
-            logger.info("Exception", str(e))
-            pass
+        except Exception as exception:
+            LOGGER.info("Exception", str(exception))


-def get_transcription(frames):
-    logger.info("Transcribing..")
-    sorted_transcripts[frames[0].time] = None
-    # TODO:
+def get_transcription(input_frames: TranscriptionInput) -> Union[None, TranscriptionOutput]:
+    LOGGER.info("Transcribing..")
+    sorted_transcripts[input_frames.frames[0].time] = None
+    # TODO: Find cleaner way, watch "no transcription" issue below
     # Passing IO objects instead of temporary files throws an error
-    # Passing ndarrays (typecasted with float) does not give any
+    # Passing ndarray (type casted with float) does not give any
     # transcription. Refer issue,
     # https://github.com/guillaumekln/faster-whisper/issues/369
-    audiofilename = "test" + str(datetime.datetime.now())
-    wf = wave.open(audiofilename, "wb")
+    audio_file = "test" + str(datetime.datetime.now())
+    wf = wave.open(audio_file, "wb")
     wf.setnchannels(CHANNELS)
     wf.setframerate(RATE)
     wf.setsampwidth(2)
-    for frame in frames:
+    for frame in input_frames.frames:
         wf.writeframes(b"".join(frame.to_ndarray()))
     wf.close()
@@ -146,12 +128,12 @@ def get_transcription(frames):
     try:
         segments, _ = \
-            model.transcribe(audiofilename,
+            model.transcribe(audio_file,
                              language="en",
                              beam_size=5,
                              vad_filter=True,
-                             vad_parameters=dict(min_silence_duration_ms=500))
-        os.remove(audiofilename)
+                             vad_parameters={"min_silence_duration_ms": 500})
+        os.remove(audio_file)
         segments = list(segments)
         result_text = ""
         duration = 0.0
@@ -169,34 +151,32 @@ def get_transcription(frames):
         last_transcribed_time += duration
         transcription_text += result_text
-    except Exception as e:
-        logger.info("Exception" + str(e))
-        pass
-    result = {
-        "cmd": "SHOW_TRANSCRIPTION",
-        "text": result_text
-    }
-    sorted_transcripts[frames[0].time] = result
+    except Exception as exception:
+        LOGGER.info("Exception" + str(exception))
+    result = TranscriptionOutput(result_text)
+    sorted_transcripts[input_frames.frames[0].time] = result
     return result


-def get_final_summary_response():
+def get_final_summary_response() -> FinalSummaryResult:
+    """
+    Collate the incremental summaries generated so far and return as the final
+    summary
+    :return:
+    """
     final_summary = ""
     # Collate inc summaries
     for topic in incremental_responses:
         final_summary += topic["description"]
-    response = {
-        "cmd": "DISPLAY_FINAL_SUMMARY",
-        "duration": str(datetime.timedelta(
-            seconds=round(last_transcribed_time))),
-        "summary": final_summary
-    }
-    with open("./artefacts/meeting_titles_and_summaries.txt", "a") as f:
-        f.write(json.dumps(incremental_responses))
+    response = FinalSummaryResult(final_summary, last_transcribed_time)
+
+    with open("./artefacts/meeting_titles_and_summaries.txt", "a",
+              encoding="utf-8") as file:
+        file.write(json.dumps(incremental_responses))
     return response
@@ -211,14 +191,16 @@ class AudioStreamTrack(MediaStreamTrack):
         super().__init__()
         self.track = track

-    async def recv(self):
+    async def recv(self) -> av.audio.frame.AudioFrame:
         global transcription_text
         frame = await self.track.recv()
         audio_buffer.write(frame)
         if local_frames := audio_buffer.read_many(256 * 960, partial=False):
             whisper_result = run_in_executor(
-                get_transcription, local_frames, executor=executor
+                get_transcription,
+                TranscriptionInput(local_frames),
+                executor=executor
             )
             whisper_result.add_done_callback(
                 lambda f: channel_send_transcript(data_channel)
@@ -226,12 +208,13 @@ class AudioStreamTrack(MediaStreamTrack):
                 else None
             )
-        if len(transcription_text) > 750:
+        if len(transcription_text) > 25:
             llm_input_text = transcription_text
             transcription_text = ""
+            param = TitleSummaryInput(input_text=llm_input_text,
+                                      transcribed_time=last_transcribed_time)
             llm_result = run_in_executor(get_title_and_summary,
-                                         llm_input_text,
-                                         last_transcribed_time,
+                                         param,
                                          executor=executor)
             llm_result.add_done_callback(
                 lambda f: channel_send_increment(data_channel,
@@ -242,7 +225,12 @@ class AudioStreamTrack(MediaStreamTrack):
         return frame


-async def offer(request):
+async def offer(request: requests.Request) -> web.Response:
+    """
+    Establish the WebRTC connection with the client
+    :param request:
+    :return:
+    """
     params = await request.json()
     offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
@@ -250,40 +238,39 @@ async def offer(request):
     pc_id = "PeerConnection(%s)" % uuid.uuid4()
     pcs.add(pc)

-    def log_info(msg, *args):
-        logger.info(pc_id + " " + msg, *args)
+    def log_info(msg, *args) -> NoReturn:
+        LOGGER.info(pc_id + " " + msg, *args)

     log_info("Created for " + request.remote)

     @pc.on("datachannel")
-    def on_datachannel(channel):
+    def on_datachannel(channel) -> NoReturn:
         global data_channel
         data_channel = channel
         channel_log(channel, "-", "created by remote party")

         @channel.on("message")
-        def on_message(message):
+        def on_message(message: str) -> NoReturn:
             channel_log(channel, "<", message)
             if json.loads(message)["cmd"] == "STOP":
-                # Place holder final summary
+                # Placeholder final summary
                 response = get_final_summary_response()
                 channel_send_increment(data_channel, response)
                 # To-do Add code to stop connection from server side here
                 # But have to handshake with client once
+                # pc.close()
             if isinstance(message, str) and message.startswith("ping"):
                 channel_send(channel, "pong" + message[4:])

     @pc.on("connectionstatechange")
-    async def on_connectionstatechange():
+    async def on_connectionstatechange() -> NoReturn:
         log_info("Connection state is " + pc.connectionState)
         if pc.connectionState == "failed":
             await pc.close()
             pcs.discard(pc)

     @pc.on("track")
-    def on_track(track):
+    def on_track(track) -> NoReturn:
         log_info("Track " + track.kind + " received")
         pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
@@ -294,15 +281,17 @@ async def offer(request):
     return web.Response(
         content_type="application/json",
         text=json.dumps(
-            {"sdp": pc.localDescription.sdp,
-             "type": pc.localDescription.type}
+            {
+                "sdp": pc.localDescription.sdp,
+                "type": pc.localDescription.type
+            }
         ),
     )


-async def on_shutdown(app):
-    coros = [pc.close() for pc in pcs]
-    await asyncio.gather(*coros)
+async def on_shutdown(application: web.Application) -> NoReturn:
+    coroutines = [pc.close() for pc in pcs]
+    await asyncio.gather(*coroutines)
     pcs.clear()
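Note: the ordering guarantee in `channel_send_transcript` comes from the `SortedDict`: `get_transcription` reserves a slot keyed by the batch's first frame time before it starts, and results are only flushed from the smallest key, so executor tasks that finish out of order still reach the client in stream order. A standalone sketch of that pattern (simplified types, not the server's actual code):

    from sortedcontainers import SortedDict

    pending = SortedDict()

    def reserve(frame_time):
        # Reserve a slot before the executor starts on this batch
        pending[frame_time] = None

    def complete(frame_time, text):
        pending[frame_time] = text

    def flush():
        # Emit only from the head; a still-pending head blocks later results
        out = []
        while pending and pending[next(iter(pending))] is not None:
            out.append(pending.pop(next(iter(pending))))
        return out

    reserve(1.0); reserve(2.0)
    complete(2.0, "second")   # finished first, but held back
    assert flush() == []
    complete(1.0, "first")
    assert flush() == ["first", "second"]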


@@ -9,8 +9,8 @@ import stamina
 from aiortc import (RTCPeerConnection, RTCSessionDescription)
 from aiortc.contrib.media import (MediaPlayer, MediaRelay)

-from utils.log_utils import logger
-from utils.run_utils import config
+from utils.log_utils import LOGGER
+from utils.run_utils import CONFIG


 class StreamClient:
@@ -35,7 +35,7 @@ class StreamClient:
         self.time_start = None
         self.queue = asyncio.Queue()
         self.player = MediaPlayer(
-            ':' + str(config['DEFAULT']["AV_FOUNDATION_DEVICE_ID"]),
+            ':' + str(CONFIG['AUDIO']["AV_FOUNDATION_DEVICE_ID"]),
             format='avfoundation',
             options={'channels': '2'})
@@ -74,7 +74,7 @@ class StreamClient:
         self.pcs.add(pc)

         def log_info(msg, *args):
-            logger.info(pc_id + " " + msg, *args)
+            LOGGER.info(pc_id + " " + msg, *args)

         @pc.on("connectionstatechange")
         async def on_connectionstatechange():


@@ -93,6 +93,6 @@ def generate_finetuning_dataset(video_ids):
     video_ids = ["yTnSEZIwnkU"]
     dataset = generate_finetuning_dataset(video_ids)
-    with open("finetuning_dataset.jsonl", "w") as f:
+    with open("finetuning_dataset.jsonl", "w", encoding="utf-8") as file:
         for example in dataset:
-            f.write(json.dumps(example) + "\n")
+            file.write(json.dumps(example) + "\n")


@@ -16,10 +16,10 @@ from av import AudioFifo
 from sortedcontainers import SortedDict
 from whisper_jax import FlaxWhisperPipline

-from reflector.utils.log_utils import logger
-from reflector.utils.run_utils import config, Mutex
+from reflector.utils.log_utils import LOGGER
+from reflector.utils.run_utils import CONFIG, Mutex

-WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_REAL_TIME_MODEL_SIZE"]
+WHISPER_MODEL_SIZE = CONFIG['WHISPER']["WHISPER_REAL_TIME_MODEL_SIZE"]
 pcs = set()
 relay = MediaRelay()
 data_channel = None
@@ -127,7 +127,7 @@ async def offer(request: requests.Request):
     pcs.add(pc)

     def log_info(msg: str, *args):
-        logger.info(pc_id + " " + msg, *args)
+        LOGGER.info(pc_id + " " + msg, *args)

     log_info("Created for " + request.remote)


@@ -3,14 +3,14 @@ import sys

 # Observe the incremental summaries by performing summaries in chunks
-with open("transcript.txt") as f:
-    transcription = f.read()
+with open("transcript.txt", "r", encoding="utf-8") as file:
+    transcription = file.read()


 def split_text_file(filename, token_count):
     nlp = spacy.load('en_core_web_md')
-    with open(filename, 'r') as file:
+    with open(filename, 'r', encoding="utf-8") as file:
         text = file.read()
     doc = nlp(text)
@@ -36,9 +36,9 @@ chunks = split_text_file("transcript.txt", MAX_CHUNK_LENGTH)
 print("Number of chunks", len(chunks))

 # Write chunks to file to refer to input vs output, separated by blank lines
-with open("chunks" + str(MAX_CHUNK_LENGTH) + ".txt", "a") as f:
+with open("chunks" + str(MAX_CHUNK_LENGTH) + ".txt", "a", encoding="utf-8") as file:
     for c in chunks:
-        f.write(c + "\n\n")
+        file.write(c + "\n\n")

 # If we want to run only a certain model, type the option while running
 # ex. python incsum.py 1 => will run approach 1
@@ -78,9 +78,9 @@ if index == "1" or index is None:
         summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
         summaries.append(summary)

-    with open("bart-summaries.txt", "a") as f:
+    with open("bart-summaries.txt", "a", encoding="utf-8") as file:
         for summary in summaries:
-            f.write(summary + "\n\n")
+            file.write(summary + "\n\n")

 # Approach 2
 if index == "2" or index is None:
@@ -114,8 +114,8 @@ if index == "2" or index is None:
         summary_ids = output[0, input_length:]
         summary = tokenizer.decode(summary_ids, skip_special_tokens=True)
         summaries.append(summary)
-        with open("gptneo1.3B-summaries.txt", "a") as f:
-            f.write(summary + "\n\n")
+        with open("gptneo1.3B-summaries.txt", "a", encoding="utf-8") as file:
+            file.write(summary + "\n\n")

 # Approach 3
 if index == "3" or index is None:
@@ -152,6 +152,6 @@ if index == "3" or index is None:
                                    skip_special_tokens=True)
         summaries.append(summary)

-    with open("mpt-7b-summaries.txt", "a") as f:
+    with open("mpt-7b-summaries.txt", "a", encoding="utf-8") as file:
         for summary in summaries:
-            f.write(summary + "\n\n")
+            file.write(summary + "\n\n")


@@ -19,15 +19,15 @@ import yt_dlp as youtube_dl
 from whisper_jax import FlaxWhisperPipline

 from ...utils.file_utils import download_files, upload_files
-from ...utils.log_utils import logger
-from ...utils.run_utils import config
+from ...utils.log_utils import LOGGER
+from ...utils.run_utils import CONFIG
 from ...utils.text_utils import post_process_transcription, summarize
 from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud

 nltk.download('punkt', quiet=True)
 nltk.download('stopwords', quiet=True)

-WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
+WHISPER_MODEL_SIZE = CONFIG['WHISPER']["WHISPER_MODEL_SIZE"]
 NOW = datetime.now()

 if not os.path.exists('../../artefacts'):
@@ -75,7 +75,7 @@ def main():
         # Download the lowest resolution YouTube video
         # (since we're just interested in the audio).
         # It will be saved to the current directory.
-        logger.info("Downloading YouTube video at url: " + args.location)
+        LOGGER.info("Downloading YouTube video at url: " + args.location)

         # Create options for the download
         ydl_opts = {
@@ -93,12 +93,12 @@ def main():
             ydl.download([args.location])
             media_file = "../artefacts/audio.mp3"
-            logger.info("Saved downloaded YouTube video to: " + media_file)
+            LOGGER.info("Saved downloaded YouTube video to: " + media_file)
         else:
             # XXX - Download file using urllib, check if file is
             # audio/video using python-magic
-            logger.info(f"Downloading file at url: {args.location}")
-            logger.info(" XXX - This method hasn't been implemented yet.")
+            LOGGER.info(f"Downloading file at url: {args.location}")
+            LOGGER.info(" XXX - This method hasn't been implemented yet.")
     elif url.scheme == '':
         media_file = url.path
         # If file is not present locally, take it from S3 bucket
@@ -119,7 +119,7 @@ def main():
             audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3",
                                                          delete=False).name
             video.audio.write_audiofile(audio_filename, logger=None)
-            logger.info(f"Extracting audio to: {audio_filename}")
+            LOGGER.info(f"Extracting audio to: {audio_filename}")
         # Handle audio only file
         except Exception:
             audio = moviepy.editor.AudioFileClip(media_file)
@@ -129,14 +129,14 @@ def main():
     else:
         audio_filename = media_file

-    logger.info("Finished extracting audio")
-    logger.info("Transcribing")
+    LOGGER.info("Finished extracting audio")
+    LOGGER.info("Transcribing")

     # Convert the audio to text using the OpenAI Whisper model
     pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE,
                                   dtype=jnp.float16,
                                   batch_size=16)
     whisper_result = pipeline(audio_filename, return_timestamps=True)
-    logger.info("Finished transcribing file")
+    LOGGER.info("Finished transcribing file")

     whisper_result = post_process_transcription(whisper_result)
@@ -153,10 +153,10 @@ def main():
               "w") as transcript_file_timestamps:
         transcript_file_timestamps.write(str(whisper_result))

-    logger.info("Creating word cloud")
+    LOGGER.info("Creating word cloud")
     create_wordcloud(NOW)
-    logger.info("Performing talk-diff and talk-diff visualization")
+    LOGGER.info("Performing talk-diff and talk-diff visualization")
     create_talk_diff_scatter_viz(NOW)

     # S3 : Push artefacts to S3 bucket
@@ -172,7 +172,7 @@ def main():
     summarize(transcript_text, NOW, False, False)
-    logger.info("Summarization completed")
+    LOGGER.info("Summarization completed")

     # Summarization takes a lot of time, so do this separately at the end
     files_to_upload = [prefix + "summary_" + suffix + ".txt"]


@@ -11,12 +11,12 @@ from termcolor import colored
 from whisper_jax import FlaxWhisperPipline

 from ...utils.file_utils import upload_files
-from ...utils.log_utils import logger
-from ...utils.run_utils import config
+from ...utils.log_utils import LOGGER
+from ...utils.run_utils import CONFIG
 from ...utils.text_utils import post_process_transcription, summarize
 from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud

-WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
+WHISPER_MODEL_SIZE = CONFIG['WHISPER']["WHISPER_MODEL_SIZE"]

 FRAMES_PER_BUFFER = 8000
 FORMAT = pyaudio.paInt16
@@ -31,7 +31,7 @@ def main():
     AUDIO_DEVICE_ID = -1
     for i in range(p.get_device_count()):
         if p.get_device_info_by_index(i)["name"] == \
-                config["DEFAULT"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]:
+                CONFIG["AUDIO"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]:
             AUDIO_DEVICE_ID = i
     audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID)
     stream = p.open(
@@ -44,7 +44,7 @@ def main():
     )

     pipeline = FlaxWhisperPipline("openai/whisper-" +
-                                  config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"],
+                                  CONFIG["WHISPER"]["WHISPER_REAL_TIME_MODEL_SIZE"],
                                   dtype=jnp.float16,
                                   batch_size=16)
@@ -106,23 +106,26 @@ def main():
                           " | Transcribed duration: " +
                           str(duration), "yellow"))
-    except Exception as e:
-        print(e)
+    except Exception as exception:
+        print(str(exception))
     finally:
-        with open("real_time_transcript_" +
-                  NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
-            f.write(transcription)
+        with open("real_time_transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S")
+                  + ".txt", "w", encoding="utf-8") as file:
+            file.write(transcription)
-        with open("real_time_transcript_with_timestamp_" +
-                  NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
+        with open("real_time_transcript_with_timestamp_" +
+                  NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w",
+                  encoding="utf-8") as file:
             transcript_with_timestamp["text"] = transcription
-            f.write(str(transcript_with_timestamp))
+            file.write(str(transcript_with_timestamp))

-        transcript_with_timestamp = post_process_transcription(transcript_with_timestamp)
+        transcript_with_timestamp = \
+            post_process_transcription(transcript_with_timestamp)

-        logger.info("Creating word cloud")
+        LOGGER.info("Creating word cloud")
         create_wordcloud(NOW, True)
-        logger.info("Performing talk-diff and talk-diff visualization")
+        LOGGER.info("Performing talk-diff and talk-diff visualization")
         create_talk_diff_scatter_viz(NOW, True)

         # S3 : Push artefacts to S3 bucket
@@ -137,7 +140,7 @@ def main():
         summarize(transcript_with_timestamp["text"], NOW, True, True)
-        logger.info("Summarization completed")
+        LOGGER.info("Summarization completed")

         # Summarization takes a lot of time, so do this separately at the end
         files_to_upload = ["real_time_summary_" + suffix + ".txt"]


@@ -1,16 +1,21 @@
+"""
+Utility file for file handling related functions, including file downloads and
+uploads to cloud storage
+"""
 import sys

 import boto3
 import botocore

-from .log_utils import logger
-from .run_utils import config
+from .log_utils import LOGGER
+from .run_utils import CONFIG

-BUCKET_NAME = config["DEFAULT"]["BUCKET_NAME"]
+BUCKET_NAME = CONFIG["AWS"]["BUCKET_NAME"]
 s3 = boto3.client('s3',
-                  aws_access_key_id=config["DEFAULT"]["AWS_ACCESS_KEY"],
-                  aws_secret_access_key=config["DEFAULT"]["AWS_SECRET_KEY"])
+                  aws_access_key_id=CONFIG["AWS"]["AWS_ACCESS_KEY"],
+                  aws_secret_access_key=CONFIG["AWS"]["AWS_SECRET_KEY"])


 def upload_files(files_to_upload):
@@ -19,12 +24,12 @@ def upload_files(files_to_upload):
     :param files_to_upload: List of files to upload
     :return: None
     """
-    for KEY in files_to_upload:
-        logger.info("Uploading file " + KEY)
+    for key in files_to_upload:
+        LOGGER.info("Uploading file " + key)
         try:
-            s3.upload_file(KEY, BUCKET_NAME, KEY)
-        except botocore.exceptions.ClientError as e:
-            print(e.response)
+            s3.upload_file(key, BUCKET_NAME, key)
+        except botocore.exceptions.ClientError as exception:
+            print(exception.response)


 def download_files(files_to_download):
@@ -33,12 +38,12 @@ def download_files(files_to_download):
     :param files_to_download: List of files to download
     :return: None
     """
-    for KEY in files_to_download:
-        logger.info("Downloading file " + KEY)
+    for key in files_to_download:
+        LOGGER.info("Downloading file " + key)
         try:
-            s3.download_file(BUCKET_NAME, KEY, KEY)
-        except botocore.exceptions.ClientError as e:
-            if e.response['Error']['Code'] == "404":
+            s3.download_file(BUCKET_NAME, key, key)
+        except botocore.exceptions.ClientError as exception:
+            if exception.response['Error']['Code'] == "404":
                 print("The object does not exist.")
             else:
                 raise


@@ -1,13 +1,24 @@
+"""
+Utility function to format the artefacts created during Reflector run
+"""
 import json

-with open("../artefacts/meeting_titles_and_summaries.txt", "r") as f:
+with open("../artefacts/meeting_titles_and_summaries.txt", "r",
+          encoding='utf-8') as f:
     outputs = f.read()

 outputs = json.loads(outputs)

-transcript_file = open("../artefacts/meeting_transcript.txt", "a")
-title_desc_file = open("../artefacts/meeting_title_description.txt", "a")
-summary_file = open("../artefacts/meeting_summary.txt", "a")
+transcript_file = open("../artefacts/meeting_transcript.txt",
+                       "a",
+                       encoding='utf-8')
+title_desc_file = open("../artefacts/meeting_title_description.txt",
+                       "a",
+                       encoding='utf-8')
+summary_file = open("../artefacts/meeting_summary.txt",
+                    "a",
+                    encoding='utf-8')

 for item in outputs["topics"]:
     transcript_file.write(item["transcript"])


@@ -1,7 +1,15 @@
+"""
+Utility file for logging
+"""
 import loguru


 class SingletonLogger:
+    """
+    Use Singleton design pattern to create a logger object and share it
+    across the entire project
+    """
     __instance = None

     @staticmethod
@@ -15,4 +23,4 @@ class SingletonLogger:
         return SingletonLogger.__instance


-logger = SingletonLogger.get_logger()
+LOGGER = SingletonLogger.get_logger()
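Note: the body of `get_logger` is elided by this hunk. A sketch of the shape this singleton presumably takes, under the assumption that it simply hands out loguru's module-level logger on first use:

    import loguru

    class SingletonLogger:
        """Share one loguru logger object across the entire project."""
        __instance = None

        @staticmethod
        def get_logger():
            # Create the shared instance lazily, then keep returning it
            if SingletonLogger.__instance is None:
                SingletonLogger.__instance = loguru.logger
            return SingletonLogger.__instance

    LOGGER = SingletonLogger.get_logger()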


@@ -1,3 +1,7 @@
+"""
+Utility file for server side asynchronous task running and config objects
+"""
 import asyncio
 import configparser
 import contextlib
@@ -7,6 +11,9 @@ from typing import ContextManager, Generic, TypeVar


 class ReflectorConfig:
+    """
+    Create a single config object to share across the project
+    """
     __config = None

     @staticmethod
@@ -17,7 +24,7 @@ class ReflectorConfig:
         return ReflectorConfig.__config


-config = ReflectorConfig.get_config()
+CONFIG = ReflectorConfig.get_config()


 def run_in_executor(func, *args, executor=None, **kwargs):
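Note: `run_in_executor`'s body also falls outside the hunk. A common implementation consistent with how server.py calls it (positional args plus an `executor` keyword, returning a future that accepts `add_done_callback`) would be something like this sketch:

    import asyncio
    import functools

    def run_in_executor(func, *args, executor=None, **kwargs):
        # Bind the arguments and hand the blocking call to the executor;
        # the returned asyncio.Future supports add_done_callback.
        loop = asyncio.get_event_loop()
        return loop.run_in_executor(executor,
                                    functools.partial(func, *args, **kwargs))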


@@ -1,3 +1,7 @@
+"""
+Utility file for all text processing related functionalities
+"""
 import nltk
 import torch
 from nltk.corpus import stopwords
@@ -6,8 +10,8 @@ from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
 from transformers import BartForConditionalGeneration, BartTokenizer

-from log_utils import logger
-from run_utils import config
+from log_utils import LOGGER
+from run_utils import CONFIG

 nltk.download('punkt', quiet=True)
@@ -32,6 +36,12 @@ def compute_similarity(sent1, sent2):

 def remove_almost_alike_sentences(sentences, threshold=0.7):
+    """
+    Filter sentences that are similar beyond a set threshold
+    :param sentences:
+    :param threshold:
+    :return:
+    """
     num_sentences = len(sentences)
     removed_indices = set()
@@ -62,6 +72,11 @@ def remove_almost_alike_sentences(sentences, threshold=0.7):

 def remove_outright_duplicate_sentences_from_chunk(chunk):
+    """
+    Remove repetitive sentences
+    :param chunk:
+    :return:
+    """
     chunk_text = chunk["text"]
     sentences = nltk.sent_tokenize(chunk_text)
     nonduplicate_sentences = list(dict.fromkeys(sentences))
@@ -69,6 +84,12 @@ def remove_outright_duplicate_sentences_from_chunk(chunk):

 def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
+    """
+    Remove sentences that are repeated as a result of Whisper
+    hallucinations
+    :param nonduplicate_sentences:
+    :return:
+    """
     chunk_sentences = []
     for sent in nonduplicate_sentences:
@@ -91,6 +112,11 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):

 def post_process_transcription(whisper_result):
+    """
+    Parent function to perform post-processing on the transcription result
+    :param whisper_result:
+    :return:
+    """
     transcript_text = ""
     for chunk in whisper_result["chunks"]:
         nonduplicate_sentences = \
@@ -121,9 +147,9 @@ def summarize_chunks(chunks, tokenizer, model):
         with torch.no_grad():
             summary_ids = \
                 model.generate(input_ids,
-                               num_beams=int(config["DEFAULT"]["BEAM_SIZE"]),
+                               num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]),
                                length_penalty=2.0,
-                               max_length=int(config["DEFAULT"]["MAX_LENGTH"]),
+                               max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]),
                                early_stopping=True)
         summary = tokenizer.decode(summary_ids[0],
                                    skip_special_tokens=True)
@@ -132,7 +158,7 @@ def summarize_chunks(chunks, tokenizer, model):

 def chunk_text(text,
-               max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
+               max_chunk_length=int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])):
     """
     Split text into smaller chunks.
     :param text: Text to be chunked
@@ -154,14 +180,22 @@ def chunk_text(text,

 def summarize(transcript_text, timestamp,
               real_time=False,
-              chunk_summarize=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
+              chunk_summarize=CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]):
+    """
+    Summarize the given text either as a whole or as chunks as needed
+    :param transcript_text:
+    :param timestamp:
+    :param real_time:
+    :param chunk_summarize:
+    :return:
+    """
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
+    summary_model = CONFIG["SUMMARIZER"]["SUMMARY_MODEL"]
     if not summary_model:
         summary_model = "facebook/bart-large-cnn"

     # Summarize the generated transcript using the BART model
-    logger.info(f"Loading BART model: {summary_model}")
+    LOGGER.info(f"Loading BART model: {summary_model}")
     tokenizer = BartTokenizer.from_pretrained(summary_model)
     model = BartForConditionalGeneration.from_pretrained(summary_model)
     model = model.to(device)
@@ -171,7 +205,7 @@ def summarize(transcript_text, timestamp,
         output_file = "real_time_" + output_file

     if chunk_summarize != "YES":
-        max_length = int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"])
+        max_length = int(CONFIG["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"])
         inputs = tokenizer. \
             batch_encode_plus([transcript_text], truncation=True,
                               padding='longest',
@@ -180,8 +214,8 @@ def summarize(transcript_text, timestamp,
         inputs = inputs.to(device)

         with torch.no_grad():
-            num_beans = int(config["DEFAULT"]["BEAM_SIZE"])
-            max_length = int(config["DEFAULT"]["MAX_LENGTH"])
+            num_beans = int(CONFIG["SUMMARIZER"]["BEAM_SIZE"])
+            max_length = int(CONFIG["SUMMARIZER"]["MAX_LENGTH"])
             summaries = model.generate(inputs['input_ids'],
                                        num_beams=num_beans,
                                        length_penalty=2.0,
@@ -194,16 +228,16 @@ def summarize(transcript_text, timestamp,
                                              clean_up_tokenization_spaces=False)
                              for summary in summaries]
         summary = " ".join(decoded_summaries)
-        with open("./artefacts/" + output_file, 'w') as f:
-            f.write(summary.strip() + "\n")
+        with open("./artefacts/" + output_file, 'w', encoding="utf-8") as file:
+            file.write(summary.strip() + "\n")
     else:
-        logger.info("Breaking transcript into smaller chunks")
+        LOGGER.info("Breaking transcript into smaller chunks")
         chunks = chunk_text(transcript_text)
-        logger.info(f"Transcript broken into {len(chunks)} "
-                    f"chunks of at most 500 words")
-        logger.info(f"Writing summary text to: {output_file}")
+        LOGGER.info(f"Transcript broken into {len(chunks)} "
+                    f"chunks of at most 500 words")
+        LOGGER.info(f"Writing summary text to: {output_file}")
         with open(output_file, 'w') as f:
             summaries = summarize_chunks(chunks, tokenizer, model)
             for summary in summaries:
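Note: `compute_similarity`'s body is not shown in these hunks, but the `TfidfVectorizer` and `cosine_similarity` imports suggest TF-IDF cosine similarity between sentence pairs; a minimal sketch assuming that is the measure behind the 0.7 threshold:

    from sklearn.feature_extraction.text import TfidfVectorizer
    from sklearn.metrics.pairwise import cosine_similarity

    def compute_similarity(sent1, sent2):
        # Vectorize both sentences into a shared TF-IDF space and compare
        matrix = TfidfVectorizer().fit_transform([sent1, sent2])
        return cosine_similarity(matrix[0], matrix[1])[0][0]

    # Sentence pairs scoring above the threshold count as near-duplicates
    print(compute_similarity("the cat sat on the mat",
                             "the cat sat on a mat"))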


@@ -1,3 +1,7 @@
+"""
+Utility file for all visualization related functions
+"""
 import ast
 import collections
 import os
@@ -81,8 +85,8 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
     else:
         filename = "./artefacts/transcript_with_timestamp_" + \
                    timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
-    with open(filename) as f:
-        transcription_timestamp_text = f.read()
+    with open(filename) as file:
+        transcription_timestamp_text = file.read()

     res = ast.literal_eval(transcription_timestamp_text)
     chunks = res["chunks"]