code style updates

This commit is contained in:
Gokul Mohanarangan
2023-07-26 09:59:25 +05:30
parent b892fc0562
commit c970fc89dd
8 changed files with 54 additions and 56 deletions

View File

@@ -31,7 +31,7 @@ class TitleSummaryInput:
@dataclass @dataclass
class IncrementalResponse: class IncrementalResult:
title = str title = str
description = str description = str
transcript = str transcript = str
@@ -45,12 +45,12 @@ class IncrementalResponse:
@dataclass @dataclass
class TitleSummaryOutput: class TitleSummaryOutput:
cmd = str cmd = str
topics = List[IncrementalResponse] topics = List[IncrementalResult]
def __init__(self, inc_responses): def __init__(self, inc_responses):
self.topics = inc_responses self.topics = inc_responses
def get_response(self): def get_result(self):
return { return {
"cmd": self.cmd, "cmd": self.cmd,
"topics": self.topics "topics": self.topics
@@ -93,7 +93,7 @@ class TranscriptionOutput:
self.cmd = "SHOW_TRANSCRIPTION" self.cmd = "SHOW_TRANSCRIPTION"
self.result_text = result_text self.result_text = result_text
def get_response(self): def get_result(self):
return { return {
"cmd": self.cmd, "cmd": self.cmd,
"text": self.result_text "text": self.result_text
@@ -101,7 +101,7 @@ class TranscriptionOutput:
@dataclass @dataclass
class FinalSummaryResponse: class FinalSummaryResult:
cmd = str cmd = str
final_summary = str final_summary = str
duration = str duration = str
@@ -111,7 +111,7 @@ class FinalSummaryResponse:
self.final_summary = final_summary self.final_summary = final_summary
self.cmd = "" self.cmd = ""
def get_response(self): def get_result(self):
return { return {
"cmd": self.cmd, "cmd": self.cmd,
"duration": self.duration, "duration": self.duration,

View File

@@ -6,20 +6,21 @@ import os
import uuid import uuid
import wave import wave
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Any from typing import Any, NoReturn
import aiohttp_cors import aiohttp_cors
import av
import requests import requests
from aiohttp import web from aiohttp import web
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaRelay from aiortc.contrib.media import MediaRelay
from av import AudioFifo
from faster_whisper import WhisperModel from faster_whisper import WhisperModel
from loguru import logger from loguru import logger
from sortedcontainers import SortedDict from sortedcontainers import SortedDict
from reflector_dataclasses import FinalSummaryResponse, ParseLLMResult, TitleSummaryInput, TitleSummaryOutput, \ from reflector_dataclasses import FinalSummaryResult, ParseLLMResult,\
TranscriptionInput, TranscriptionOutput TitleSummaryInput, TitleSummaryOutput, TranscriptionInput,\
TranscriptionOutput
from utils.run_utils import config, run_in_executor from utils.run_utils import config, run_in_executor
pcs = set() pcs = set()
@@ -31,25 +32,21 @@ model = WhisperModel("tiny", device="cpu",
CHANNELS = 2 CHANNELS = 2
RATE = 48000 RATE = 48000
audio_buffer = AudioFifo() audio_buffer = av.AudioFifo()
executor = ThreadPoolExecutor() executor = ThreadPoolExecutor()
transcription_text = "" transcription_text = ""
last_transcribed_time = 0.0 last_transcribed_time = 0.0
LLM_MACHINE_IP = config["DEFAULT"]["LLM_MACHINE_IP"] LLM_MACHINE_IP = config["LLM"]["LLM_MACHINE_IP"]
LLM_MACHINE_PORT = config["DEFAULT"]["LLM_MACHINE_PORT"] LLM_MACHINE_PORT = config["LLM"]["LLM_MACHINE_PORT"]
LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate" LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate"
incremental_responses = [] incremental_responses = []
sorted_transcripts = SortedDict() sorted_transcripts = SortedDict()
blacklisted_messages = [" Thank you.", " See you next time!",
" Thank you for watching!", " Bye!",
" And that's what I'm talking about."]
def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Any[None, ParseLLMResult]: def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Any[None, ParseLLMResult]:
try: try:
output = json.loads(response.json()["results"][0]["text"]) output = json.loads(response.json()["results"][0]["text"])
return ParseLLMResult(param, output).get_result() return ParseLLMResult(param, output)
except Exception as e: except Exception as e:
logger.info("Exception" + str(e)) logger.info("Exception" + str(e))
return None return None
@@ -65,33 +62,35 @@ def get_title_and_summary(param: TitleSummaryInput) -> Any[None, TitleSummaryOut
json=param.data) json=param.data)
output = parse_llm_output(param, response) output = parse_llm_output(param, response)
if output: if output:
incremental_responses.append(output) result = output.get_result()
return TitleSummaryOutput(incremental_responses).get_response() incremental_responses.append(result)
return TitleSummaryOutput(incremental_responses)
except Exception as e: except Exception as e:
logger.info("Exception" + str(e)) logger.info("Exception" + str(e))
return None return None
def channel_log(channel, t, message): def channel_log(channel, t: str, message: str) -> NoReturn:
logger.info("channel(%s) %s %s" % (channel.label, t, message)) logger.info("channel(%s) %s %s" % (channel.label, t, message))
def channel_send(channel, message): def channel_send(channel, message: str) -> NoReturn:
if channel: if channel:
channel.send(message) channel.send(message)
def channel_send_increment(channel, message): def channel_send_increment(channel, param: Any[FinalSummaryResult, TitleSummaryOutput]) -> NoReturn:
if channel and message: if channel and param:
message = param.get_result()
channel.send(json.dumps(message)) channel.send(json.dumps(message))
def channel_send_transcript(channel): def channel_send_transcript(channel) -> NoReturn:
# channel_log(channel, ">", message) # channel_log(channel, ">", message)
if channel: if channel:
try: try:
least_time = sorted_transcripts.keys()[0] least_time = sorted_transcripts.keys()[0]
message = sorted_transcripts[least_time] message = sorted_transcripts[least_time].get_result()
if message: if message:
del sorted_transcripts[least_time] del sorted_transcripts[least_time]
if message["text"] not in blacklisted_messages: if message["text"] not in blacklisted_messages:
@@ -157,19 +156,19 @@ def get_transcription(input_frames: TranscriptionInput) -> Any[None, Transcripti
logger.info("Exception" + str(e)) logger.info("Exception" + str(e))
pass pass
result = TranscriptionOutput(result_text).get_response() result = TranscriptionOutput(result_text)
sorted_transcripts[input_frames.frames[0].time] = result sorted_transcripts[input_frames.frames[0].time] = result
return result return result
def get_final_summary_response() -> Any[None, FinalSummaryResponse]: def get_final_summary_response() -> FinalSummaryResult:
final_summary = "" final_summary = ""
# Collate inc summaries # Collate inc summaries
for topic in incremental_responses: for topic in incremental_responses:
final_summary += topic["description"] final_summary += topic["description"]
response = FinalSummaryResponse(final_summary, last_transcribed_time).get_response() response = FinalSummaryResult(final_summary, last_transcribed_time)
with open("./artefacts/meeting_titles_and_summaries.txt", "a") as f: with open("./artefacts/meeting_titles_and_summaries.txt", "a") as f:
f.write(json.dumps(incremental_responses)) f.write(json.dumps(incremental_responses))
@@ -188,7 +187,7 @@ class AudioStreamTrack(MediaStreamTrack):
super().__init__() super().__init__()
self.track = track self.track = track
async def recv(self): async def recv(self) -> av.audio.frame.AudioFrame:
global transcription_text global transcription_text
frame = await self.track.recv() frame = await self.track.recv()
audio_buffer.write(frame) audio_buffer.write(frame)
@@ -222,7 +221,7 @@ class AudioStreamTrack(MediaStreamTrack):
return frame return frame
async def offer(request): async def offer(request: requests.Request) -> web.Response:
params = await request.json() params = await request.json()
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"]) offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
@@ -230,40 +229,39 @@ async def offer(request):
pc_id = "PeerConnection(%s)" % uuid.uuid4() pc_id = "PeerConnection(%s)" % uuid.uuid4()
pcs.add(pc) pcs.add(pc)
def log_info(msg, *args): def log_info(msg, *args) -> NoReturn:
logger.info(pc_id + " " + msg, *args) logger.info(pc_id + " " + msg, *args)
log_info("Created for " + request.remote) log_info("Created for " + request.remote)
@pc.on("datachannel") @pc.on("datachannel")
def on_datachannel(channel): def on_datachannel(channel) -> NoReturn:
global data_channel global data_channel
data_channel = channel data_channel = channel
channel_log(channel, "-", "created by remote party") channel_log(channel, "-", "created by remote party")
@channel.on("message") @channel.on("message")
def on_message(message): def on_message(message: str) -> NoReturn:
channel_log(channel, "<", message) channel_log(channel, "<", message)
if json.loads(message)["cmd"] == "STOP": if json.loads(message)["cmd"] == "STOP":
# Place holder final summary # Placeholder final summary
response = get_final_summary_response() response = get_final_summary_response()
channel_send_increment(data_channel, response) channel_send_increment(data_channel, response)
# To-do Add code to stop connection from server side here # To-do Add code to stop connection from server side here
# But have to handshake with client once # But have to handshake with client once
# pc.close()
if isinstance(message, str) and message.startswith("ping"): if isinstance(message, str) and message.startswith("ping"):
channel_send(channel, "pong" + message[4:]) channel_send(channel, "pong" + message[4:])
@pc.on("connectionstatechange") @pc.on("connectionstatechange")
async def on_connectionstatechange(): async def on_connectionstatechange() -> NoReturn:
log_info("Connection state is " + pc.connectionState) log_info("Connection state is " + pc.connectionState)
if pc.connectionState == "failed": if pc.connectionState == "failed":
await pc.close() await pc.close()
pcs.discard(pc) pcs.discard(pc)
@pc.on("track") @pc.on("track")
def on_track(track): def on_track(track) -> NoReturn:
log_info("Track " + track.kind + " received") log_info("Track " + track.kind + " received")
pc.addTrack(AudioStreamTrack(relay.subscribe(track))) pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
@@ -280,7 +278,7 @@ async def offer(request):
) )
async def on_shutdown(app): async def on_shutdown(app) -> NoReturn:
coros = [pc.close() for pc in pcs] coros = [pc.close() for pc in pcs]
await asyncio.gather(*coros) await asyncio.gather(*coros)
pcs.clear() pcs.clear()

View File

@@ -35,7 +35,7 @@ class StreamClient:
self.time_start = None self.time_start = None
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
self.player = MediaPlayer( self.player = MediaPlayer(
':' + str(config['DEFAULT']["AV_FOUNDATION_DEVICE_ID"]), ':' + str(config['AUDIO']["AV_FOUNDATION_DEVICE_ID"]),
format='avfoundation', format='avfoundation',
options={'channels': '2'}) options={'channels': '2'})

View File

@@ -19,7 +19,7 @@ from whisper_jax import FlaxWhisperPipline
from reflector.utils.log_utils import logger from reflector.utils.log_utils import logger
from reflector.utils.run_utils import config, Mutex from reflector.utils.run_utils import config, Mutex
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_REAL_TIME_MODEL_SIZE"] WHISPER_MODEL_SIZE = config['WHISPER']["WHISPER_REAL_TIME_MODEL_SIZE"]
pcs = set() pcs = set()
relay = MediaRelay() relay = MediaRelay()
data_channel = None data_channel = None

View File

@@ -27,7 +27,7 @@ from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud
nltk.download('punkt', quiet=True) nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True) nltk.download('stopwords', quiet=True)
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"] WHISPER_MODEL_SIZE = config['WHISPER']["WHISPER_MODEL_SIZE"]
NOW = datetime.now() NOW = datetime.now()
if not os.path.exists('../../artefacts'): if not os.path.exists('../../artefacts'):

View File

@@ -16,7 +16,7 @@ from ...utils.run_utils import config
from ...utils.text_utils import post_process_transcription, summarize from ...utils.text_utils import post_process_transcription, summarize
from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"] WHISPER_MODEL_SIZE = config['WHISPER']["WHISPER_MODEL_SIZE"]
FRAMES_PER_BUFFER = 8000 FRAMES_PER_BUFFER = 8000
FORMAT = pyaudio.paInt16 FORMAT = pyaudio.paInt16
@@ -31,7 +31,7 @@ def main():
AUDIO_DEVICE_ID = -1 AUDIO_DEVICE_ID = -1
for i in range(p.get_device_count()): for i in range(p.get_device_count()):
if p.get_device_info_by_index(i)["name"] == \ if p.get_device_info_by_index(i)["name"] == \
config["DEFAULT"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]: config["AUDIO"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]:
AUDIO_DEVICE_ID = i AUDIO_DEVICE_ID = i
audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID) audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID)
stream = p.open( stream = p.open(
@@ -44,7 +44,7 @@ def main():
) )
pipeline = FlaxWhisperPipline("openai/whisper-" + pipeline = FlaxWhisperPipline("openai/whisper-" +
config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"], config["WHISPER"]["WHISPER_REAL_TIME_MODEL_SIZE"],
dtype=jnp.float16, dtype=jnp.float16,
batch_size=16) batch_size=16)

View File

@@ -6,11 +6,11 @@ import botocore
from .log_utils import logger from .log_utils import logger
from .run_utils import config from .run_utils import config
BUCKET_NAME = config["DEFAULT"]["BUCKET_NAME"] BUCKET_NAME = config["AWS"]["BUCKET_NAME"]
s3 = boto3.client('s3', s3 = boto3.client('s3',
aws_access_key_id=config["DEFAULT"]["AWS_ACCESS_KEY"], aws_access_key_id=config["AWS"]["AWS_ACCESS_KEY"],
aws_secret_access_key=config["DEFAULT"]["AWS_SECRET_KEY"]) aws_secret_access_key=config["AWS"]["AWS_SECRET_KEY"])
def upload_files(files_to_upload): def upload_files(files_to_upload):

View File

@@ -121,9 +121,9 @@ def summarize_chunks(chunks, tokenizer, model):
with torch.no_grad(): with torch.no_grad():
summary_ids = \ summary_ids = \
model.generate(input_ids, model.generate(input_ids,
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), num_beams=int(config["SUMMARIZER"]["BEAM_SIZE"]),
length_penalty=2.0, length_penalty=2.0,
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), max_length=int(config["SUMMARIZER"]["MAX_LENGTH"]),
early_stopping=True) early_stopping=True)
summary = tokenizer.decode(summary_ids[0], summary = tokenizer.decode(summary_ids[0],
skip_special_tokens=True) skip_special_tokens=True)
@@ -132,7 +132,7 @@ def summarize_chunks(chunks, tokenizer, model):
def chunk_text(text, def chunk_text(text,
max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])): max_chunk_length=int(config["SUMMARIZER"]["MAX_CHUNK_LENGTH"])):
""" """
Split text into smaller chunks. Split text into smaller chunks.
:param text: Text to be chunked :param text: Text to be chunked
@@ -154,9 +154,9 @@ def chunk_text(text,
def summarize(transcript_text, timestamp, def summarize(transcript_text, timestamp,
real_time=False, real_time=False,
chunk_summarize=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]): chunk_summarize=config["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary_model = config["DEFAULT"]["SUMMARY_MODEL"] summary_model = config["SUMMARIZER"]["SUMMARY_MODEL"]
if not summary_model: if not summary_model:
summary_model = "facebook/bart-large-cnn" summary_model = "facebook/bart-large-cnn"
@@ -171,7 +171,7 @@ def summarize(transcript_text, timestamp,
output_file = "real_time_" + output_file output_file = "real_time_" + output_file
if chunk_summarize != "YES": if chunk_summarize != "YES":
max_length = int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]) max_length = int(config["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"])
inputs = tokenizer. \ inputs = tokenizer. \
batch_encode_plus([transcript_text], truncation=True, batch_encode_plus([transcript_text], truncation=True,
padding='longest', padding='longest',
@@ -180,8 +180,8 @@ def summarize(transcript_text, timestamp,
inputs = inputs.to(device) inputs = inputs.to(device)
with torch.no_grad(): with torch.no_grad():
num_beans = int(config["DEFAULT"]["BEAM_SIZE"]) num_beans = int(config["SUMMARIZER"]["BEAM_SIZE"])
max_length = int(config["DEFAULT"]["MAX_LENGTH"]) max_length = int(config["SUMMARIZER"]["MAX_LENGTH"])
summaries = model.generate(inputs['input_ids'], summaries = model.generate(inputs['input_ids'],
num_beams=num_beans, num_beams=num_beans,
length_penalty=2.0, length_penalty=2.0,