Issues 44, 46, 47

This commit is contained in:
Gokul Mohanarangan
2023-07-27 11:54:24 +05:30
parent 499edd665b
commit 60ea3ac137
6 changed files with 141 additions and 40 deletions

View File

@@ -36,7 +36,7 @@ class TitleSummaryInput:
### Assistant:
"""
self.data = {"data": self.prompt}
self.data = {"prompt": self.prompt}
self.headers = {"Content-Type": "application/json"}
@@ -49,11 +49,13 @@ class IncrementalResult:
title = str
description = str
transcript = str
timestamp = str
def __init__(self, title, desc, transcript):
def __init__(self, title, desc, transcript, timestamp):
self.title = title
self.description = desc
self.transcript = transcript
self.timestamp = timestamp
@dataclass
@@ -67,8 +69,13 @@ class TitleSummaryOutput:
def __init__(self, inc_responses):
self.topics = inc_responses
self.cmd = "UPDATE_TOPICS"
def get_result(self):
def get_result(self) -> dict:
"""
Return the result dict for updating the topics on the client
:return:
"""
return {
"cmd": self.cmd,
"topics": self.topics
@@ -81,18 +88,25 @@ class ParseLLMResult:
Data class to parse the result returned by the LLM while generating title
and summaries. The result will be sent back to the client.
"""
title = str
description = str
transcript = str
timestamp = str
def __init__(self, param: TitleSummaryInput, output: dict):
self.title = output["title"]
self.transcript = param.input_text
self.description = output.pop("summary")
self.timestamp = \
str(datetime.timedelta(seconds=round(param.transcribed_time)))
def get_result(self):
def get_result(self) -> dict:
"""
Return the result dict after parsing the response from LLM
:return:
"""
return {
"title": self.title,
"description": self.description,
"transcript": self.transcript,
"timestamp": self.timestamp
@@ -124,7 +138,11 @@ class TranscriptionOutput:
self.cmd = "SHOW_TRANSCRIPTION"
self.result_text = result_text
def get_result(self):
def get_result(self) -> dict:
"""
Return the result dict for displaying the transcription
:return:
"""
return {
"cmd": self.cmd,
"text": self.result_text
@@ -144,9 +162,13 @@ class FinalSummaryResult:
def __init__(self, final_summary, time):
self.duration = str(datetime.timedelta(seconds=round(time)))
self.final_summary = final_summary
self.cmd = ""
self.cmd = "DISPLAY_FINAL_SUMMARY"
def get_result(self):
def get_result(self) -> dict:
"""
Return the result dict for displaying the final summary
:return:
"""
return {
"cmd": self.cmd,
"duration": self.duration,

View File

@@ -6,7 +6,7 @@ import os
import uuid
import wave
from concurrent.futures import ThreadPoolExecutor
from typing import Union, NoReturn
from typing import NoReturn, Union
import aiohttp_cors
import av
@@ -17,33 +17,50 @@ from aiortc.contrib.media import MediaRelay
from faster_whisper import WhisperModel
from sortedcontainers import SortedDict
from reflector_dataclasses import FinalSummaryResult, ParseLLMResult,\
TitleSummaryInput, TitleSummaryOutput, TranscriptionInput,\
TranscriptionOutput, BlackListedMessages
from utils.run_utils import CONFIG, run_in_executor
from reflector_dataclasses import BlackListedMessages, FinalSummaryResult, ParseLLMResult, TitleSummaryInput, \
TitleSummaryOutput, TranscriptionInput, TranscriptionOutput
from utils.log_utils import LOGGER
from utils.run_utils import CONFIG, run_in_executor
# WebRTC components
pcs = set()
relay = MediaRelay()
data_channel = None
audio_buffer = av.AudioFifo()
executor = ThreadPoolExecutor()
# Transcription model
model = WhisperModel("tiny", device="cpu",
compute_type="float32",
num_workers=12)
CHANNELS = 2
RATE = 48000
audio_buffer = av.AudioFifo()
executor = ThreadPoolExecutor()
# Audio configurations
CHANNELS = int(CONFIG["AUDIO"]["CHANNELS"])
RATE = int(CONFIG["AUDIO"]["SAMPLING_RATE"])
# Global vars
transcription_text = ""
last_transcribed_time = 0.0
# LLM
LLM_MACHINE_IP = CONFIG["LLM"]["LLM_MACHINE_IP"]
LLM_MACHINE_PORT = CONFIG["LLM"]["LLM_MACHINE_PORT"]
LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate"
# Topic and summary responses
incremental_responses = []
# To synchronize the thread results before returning to the client
sorted_transcripts = SortedDict()
def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Union[None, ParseLLMResult]:
"""
Function to parse the LLM response
:param param:
:param response:
:return:
"""
try:
output = json.loads(response.json()["results"][0]["text"])
return ParseLLMResult(param, output)
@@ -53,6 +70,12 @@ def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> U
def get_title_and_summary(param: TitleSummaryInput) -> Union[None, TitleSummaryOutput]:
"""
From the input provided (transcript), query the LLM to generate
topics and summaries
:param param:
:return:
"""
LOGGER.info("Generating title and summary")
# TODO : Handle unexpected output formats from the model
@@ -71,21 +94,45 @@ def get_title_and_summary(param: TitleSummaryInput) -> Union[None, TitleSummaryO
def channel_log(channel, t: str, message: str) -> None:
    """
    Write an INFO log line describing an event on a data channel
    :param channel: Data channel; its ``label`` attribute is included in
                    the log line
    :param t: Short marker placed between the channel label and the
              message (e.g. ">")
    :param message: Text to log
    :return: None
    """
    # Pass the values as logger arguments instead of pre-formatting the
    # string, so interpolation only happens when INFO is enabled
    LOGGER.info("channel(%s) %s %s", channel.label, t, message)
def channel_send(channel, message: str) -> None:
    """
    Send a text message via the data channel
    :param channel: Data channel to send on; nothing is sent when it is
                    falsy (e.g. None)
    :param message: Text to send
    :return: None
    """
    if channel:
        channel.send(message)
def channel_send_increment(channel, param: Union[FinalSummaryResult, TitleSummaryOutput]) -> None:
    """
    Send the incremental topics and summaries via the data channel
    :param channel: Data channel to send on; nothing is sent when it is
                    falsy (e.g. None)
    :param param: Result object whose ``get_result()`` dict is serialized
                  to JSON and sent; nothing is sent when it is falsy
    :return: None
    """
    if channel and param:
        message = param.get_result()
        channel.send(json.dumps(message))
def channel_send_transcript(channel) -> NoReturn:
"""
Send the transcription result via the data channel
:param channel:
:return:
"""
# channel_log(channel, ">", message)
if channel:
try:
@@ -106,6 +153,12 @@ def channel_send_transcript(channel) -> NoReturn:
def get_transcription(input_frames: TranscriptionInput) -> Union[None, TranscriptionOutput]:
"""
From the collected audio frames create transcription by inferring from
the chosen transcription model
:param input_frames:
:return:
"""
LOGGER.info("Transcribing..")
sorted_transcripts[input_frames.frames[0].time] = None
@@ -290,6 +343,12 @@ async def offer(request: requests.Request) -> web.Response:
async def on_shutdown(application: web.Application) -> NoReturn:
"""
On shutdown, run the coroutines that close all open client
connections
:param application:
:return:
"""
coroutines = [pc.close() for pc in pcs]
await asyncio.gather(*coroutines)
pcs.clear()

View File

@@ -4,6 +4,7 @@ uploads to cloud storage
"""
import sys
from typing import List, NoReturn
import boto3
import botocore
@@ -18,7 +19,7 @@ s3 = boto3.client('s3',
aws_secret_access_key=CONFIG["AWS"]["AWS_SECRET_KEY"])
def upload_files(files_to_upload):
def upload_files(files_to_upload: List[str]) -> NoReturn:
"""
Upload a list of files to the configured S3 bucket
:param files_to_upload: List of files to upload
@@ -32,7 +33,7 @@ def upload_files(files_to_upload):
print(exception.response)
def download_files(files_to_download):
def download_files(files_to_download: List[str]) -> NoReturn:
"""
Download a list of files from the configured S3 bucket
:param files_to_download: List of files to download

View File

@@ -18,6 +18,10 @@ class ReflectorConfig:
@staticmethod
def get_config():
"""
Load the configurations from the local config.ini file
:return:
"""
if ReflectorConfig.__config is None:
ReflectorConfig.__config = configparser.ConfigParser()
ReflectorConfig.__config.read('utils/config.ini')

View File

@@ -1,6 +1,8 @@
"""
Utility file for all text processing related functionalities
"""
import datetime
from typing import List
import nltk
import torch
@@ -16,7 +18,12 @@ from run_utils import CONFIG
nltk.download('punkt', quiet=True)
def preprocess_sentence(sentence):
def preprocess_sentence(sentence: str) -> str:
"""
Filter out undesirable tokens from the sentence
:param sentence:
:return:
"""
stop_words = set(stopwords.words('english'))
tokens = word_tokenize(sentence.lower())
tokens = [token for token in tokens
@@ -24,7 +31,7 @@ def preprocess_sentence(sentence):
return ' '.join(tokens)
def compute_similarity(sent1, sent2):
def compute_similarity(sent1: str, sent2: str) -> float:
"""
Compute the similarity
"""
@@ -35,7 +42,7 @@ def compute_similarity(sent1, sent2):
return 0.0
def remove_almost_alike_sentences(sentences, threshold=0.7):
def remove_almost_alike_sentences(sentences: List[str], threshold=0.7) -> List[str]:
"""
Filter sentences that are similar beyond a set threshold
:param sentences:
@@ -71,7 +78,7 @@ def remove_almost_alike_sentences(sentences, threshold=0.7):
return filtered_sentences
def remove_outright_duplicate_sentences_from_chunk(chunk):
def remove_outright_duplicate_sentences_from_chunk(chunk: str) -> List[str]:
"""
Remove repetitive sentences
:param chunk:
@@ -83,7 +90,7 @@ def remove_outright_duplicate_sentences_from_chunk(chunk):
return nonduplicate_sentences
def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
def remove_whisper_repetitive_hallucination(nonduplicate_sentences: List[str]) -> List[str]:
"""
Remove sentences that are repeated as a result of Whisper
hallucinations
@@ -111,7 +118,7 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
return chunk_sentences
def post_process_transcription(whisper_result):
def post_process_transcription(whisper_result: dict) -> dict:
"""
Parent function to perform post-processing on the transcription result
:param whisper_result:
@@ -131,7 +138,7 @@ def post_process_transcription(whisper_result):
return whisper_result
def summarize_chunks(chunks, tokenizer, model):
def summarize_chunks(chunks: List[str], tokenizer, model) -> List[str]:
"""
Summarize each chunk using a summarizer model
:param chunks:
@@ -157,8 +164,8 @@ def summarize_chunks(chunks, tokenizer, model):
return summaries
def chunk_text(text,
max_chunk_length=int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])):
def chunk_text(text: str,
max_chunk_length: int = int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])) -> List[str]:
"""
Split text into smaller chunks.
:param text: Text to be chunked
@@ -178,9 +185,9 @@ def chunk_text(text,
return chunks
def summarize(transcript_text, timestamp,
real_time=False,
chunk_summarize=CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]):
def summarize(transcript_text: str, timestamp: datetime.datetime.timestamp,
real_time: bool = False,
chunk_summarize: str = CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]):
"""
Summarize the given text either as a whole or as chunks as needed
:param transcript_text:

View File

@@ -4,8 +4,10 @@ Utility file for all visualization related functions
import ast
import collections
import datetime
import os
import pickle
from typing import NoReturn
import matplotlib.pyplot as plt
import pandas as pd
@@ -21,7 +23,8 @@ STOPWORDS = set(STOPWORDS).union(set(stopwords.words("english"))). \
union(set(spacy_stopwords))
def create_wordcloud(timestamp, real_time=False):
def create_wordcloud(timestamp: datetime.datetime.timestamp,
real_time: bool = False) -> NoReturn:
"""
Create a basic word cloud visualization of transcribed text
:return: None. The wordcloud image is saved locally
@@ -52,14 +55,15 @@ def create_wordcloud(timestamp, real_time=False):
wordcloud = "wordcloud"
if real_time:
wordcloud = "real_time_" + wordcloud + "_" + \
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
else:
wordcloud += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
plt.savefig("./artefacts/" + wordcloud)
def create_talk_diff_scatter_viz(timestamp, real_time=False):
def create_talk_diff_scatter_viz(timestamp: datetime.datetime.timestamp,
real_time: bool = False) -> NoReturn:
"""
Perform agenda vs transcription diff to see covered topics.
Create a scatter plot of words in topics.
@@ -124,14 +128,16 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
covered_items[agenda[topic_similarities[i][0]]] = True
# top1 match
if i == 0:
ts_to_topic_mapping_top_1[c["timestamp"]] = agenda_topics[topic_similarities[i][0]]
ts_to_topic_mapping_top_1[c["timestamp"]] = \
agenda_topics[topic_similarities[i][0]]
topic_to_ts_mapping_top_1[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"])
# top2 match
else:
ts_to_topic_mapping_top_2[c["timestamp"]] = agenda_topics[topic_similarities[i][0]]
ts_to_topic_mapping_top_2[c["timestamp"]] = \
agenda_topics[topic_similarities[i][0]]
topic_to_ts_mapping_top_2[agenda_topics[topic_similarities[i][0]]].append(c["timestamp"])
def create_new_columns(record):
def create_new_columns(record: dict) -> dict:
"""
Accumulate the mapping information into the df
:param record:
@@ -210,8 +216,10 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
transform=st.Scalers.dense_rank
)
if real_time:
open('./artefacts/real_time_scatter_' +
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
with open('./artefacts/real_time_scatter_' +
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w') as file:
file.write(html)
else:
open('./artefacts/scatter_' +
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
with open('./artefacts/scatter_' +
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w') as file:
file.write(html)