flake8 / pylint updates

Gokul Mohanarangan
2023-07-26 11:28:14 +05:30
parent c970fc89dd
commit e512b4dca5
15 changed files with 279 additions and 146 deletions

View File

@@ -5,11 +5,16 @@ import signal
 from aiortc.contrib.signaling import (add_signaling_arguments,
                                       create_signaling)
-from utils.log_utils import logger
+from utils.log_utils import LOGGER
 from stream_client import StreamClient
+from typing import NoReturn


-async def main():
+async def main() -> NoReturn:
+    """
+    Reflector's entry point to the python client for WebRTC streaming if not
+    using the browser based UI-application
+    :return:
+    """
     parser = argparse.ArgumentParser(description="Data channels ping/pong")
     parser.add_argument(
@@ -37,17 +42,17 @@ async def main():
     async def shutdown(signal, loop):
        """Cleanup tasks tied to the service's shutdown."""
-        logger.info(f"Received exit signal {signal.name}...")
-        logger.info("Closing database connections")
-        logger.info("Nacking outstanding messages")
+        LOGGER.info(f"Received exit signal {signal.name}...")
+        LOGGER.info("Closing database connections")
+        LOGGER.info("Nacking outstanding messages")
         tasks = [t for t in asyncio.all_tasks() if t is not
                  asyncio.current_task()]
         [task.cancel() for task in tasks]
-        logger.info(f"Cancelling {len(tasks)} outstanding tasks")
+        LOGGER.info(f"Cancelling {len(tasks)} outstanding tasks")
         await asyncio.gather(*tasks, return_exceptions=True)
-        logger.info(f'{"Flushing metrics"}')
+        LOGGER.info(f'{"Flushing metrics"}')
         loop.stop()

     signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
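
Editor's note: for context, a minimal sketch of how a handler like shutdown is typically attached to the loop for those three signals. The wiring below is an assumption, not part of this diff; note also that the diff's own handler shadows the signal module with its first parameter, which is why the sketch renames it to sig.

import asyncio
import signal


async def shutdown(sig, loop):
    """Cancel outstanding tasks, then stop the loop."""
    tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
    for task in tasks:
        task.cancel()
    await asyncio.gather(*tasks, return_exceptions=True)
    loop.stop()


loop = asyncio.get_event_loop()
for sig in (signal.SIGHUP, signal.SIGTERM, signal.SIGINT):
    # Each signal schedules the same cleanup coroutine on the running loop
    loop.add_signal_handler(
        sig, lambda s=sig: asyncio.create_task(shutdown(s, loop)))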

View File

@@ -1,3 +1,8 @@
+"""
+Collection of data classes for streamlining and rigidly structuring
+the input and output parameters of functions
+"""
 import datetime
 from dataclasses import dataclass
 from typing import List
@@ -7,6 +12,10 @@ import av

 @dataclass
 class TitleSummaryInput:
+    """
+    Data class for the input to generate title and summaries.
+    The outcome will be used to send query to the LLM for processing.
+    """
     input_text = str
     transcribed_time = float
     prompt = str
@@ -15,7 +24,8 @@ class TitleSummaryInput:
     def __init__(self, transcribed_time, input_text=""):
         self.input_text = input_text
         self.transcribed_time = transcribed_time
-        self.prompt = f"""
+        self.prompt = \
+            f"""
         ### Human:
         Create a JSON object as response.The JSON object must have 2 fields:
         i) title and ii) summary.For the title field,generate a short title
@@ -32,6 +42,10 @@ class TitleSummaryInput:

 @dataclass
 class IncrementalResult:
+    """
+    Data class for the result of generating one title and summaries.
+    Defines how a single "topic" looks like.
+    """
     title = str
     description = str
     transcript = str
@@ -44,6 +58,10 @@ class IncrementalResult:

 @dataclass
 class TitleSummaryOutput:
+    """
+    Data class for the result of all generated titles and summaries.
+    The result will be sent back to the client
+    """
     cmd = str
     topics = List[IncrementalResult]
@@ -59,6 +77,10 @@ class TitleSummaryOutput:

 @dataclass
 class ParseLLMResult:
+    """
+    Data class to parse the result returned by the LLM while generating title
+    and summaries. The result will be sent back to the client.
+    """
     description = str
     transcript = str
     timestamp = str
@@ -66,7 +88,8 @@ class ParseLLMResult:
     def __init__(self, param: TitleSummaryInput, output: dict):
         self.transcript = param.input_text
         self.description = output.pop("summary")
-        self.timestamp = str(datetime.timedelta(seconds=round(param.transcribed_time)))
+        self.timestamp = \
+            str(datetime.timedelta(seconds=round(param.transcribed_time)))

     def get_result(self):
         return {
@@ -78,6 +101,10 @@ class ParseLLMResult:

 @dataclass
 class TranscriptionInput:
+    """
+    Data class to define the input to the transcription function
+    AudioFrames -> input
+    """
     frames = List[av.audio.frame.AudioFrame]

     def __init__(self, frames):
@@ -86,6 +113,10 @@ class TranscriptionInput:

 @dataclass
 class TranscriptionOutput:
+    """
+    Dataclass to define the result of the transcription function.
+    The result will be sent back to the client
+    """
     cmd = str
     result_text = str
@@ -102,6 +133,10 @@ class TranscriptionOutput:

 @dataclass
 class FinalSummaryResult:
+    """
+    Dataclass to define the result of the final summary function.
+    The result will be sent back to the client.
+    """
     cmd = str
     final_summary = str
     duration = str
@@ -117,3 +152,13 @@ class FinalSummaryResult:
             "duration": self.duration,
             "summary": self.final_summary
         }
+
+
+class BlackListedMessages:
+    """
+    Class to hold the blacklisted messages. These messages should be filtered
+    out and not sent back to the client as part of the transcription.
+    """
+    messages = [" Thank you.", " See you next time!",
+                " Thank you for watching!", " Bye!",
+                " And that's what I'm talking about."]

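Editor's note: the new BlackListedMessages class is a plain container for Whisper's habitual filler hallucinations on silence. A small usage sketch; the helper and the first sample string below are illustrative, while the real membership check appears in channel_send_transcript later in this commit.

def should_send(text: str) -> bool:
    # Drop known Whisper filler hallucinations before they reach the client
    return text not in BlackListedMessages.messages


assert should_send("Let us review the roadmap.")
assert not should_send(" Thank you.")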
View File

@@ -6,7 +6,7 @@ import os
 import uuid
 import wave
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, NoReturn
+from typing import Union, NoReturn

 import aiohttp_cors
 import av
@@ -15,13 +15,13 @@ from aiohttp import web
 from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
 from aiortc.contrib.media import MediaRelay
 from faster_whisper import WhisperModel
-from loguru import logger
 from sortedcontainers import SortedDict

 from reflector_dataclasses import FinalSummaryResult, ParseLLMResult,\
     TitleSummaryInput, TitleSummaryOutput, TranscriptionInput,\
-    TranscriptionOutput
-from utils.run_utils import config, run_in_executor
+    TranscriptionOutput, BlackListedMessages
+from utils.run_utils import CONFIG, run_in_executor
+from utils.log_utils import LOGGER

 pcs = set()
 relay = MediaRelay()
@@ -36,24 +36,24 @@ audio_buffer = av.AudioFifo()
 executor = ThreadPoolExecutor()
 transcription_text = ""
 last_transcribed_time = 0.0
-LLM_MACHINE_IP = config["LLM"]["LLM_MACHINE_IP"]
-LLM_MACHINE_PORT = config["LLM"]["LLM_MACHINE_PORT"]
+LLM_MACHINE_IP = CONFIG["LLM"]["LLM_MACHINE_IP"]
+LLM_MACHINE_PORT = CONFIG["LLM"]["LLM_MACHINE_PORT"]
 LLM_URL = f"http://{LLM_MACHINE_IP}:{LLM_MACHINE_PORT}/api/v1/generate"
 incremental_responses = []
 sorted_transcripts = SortedDict()


-def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Any[None, ParseLLMResult]:
+def parse_llm_output(param: TitleSummaryInput, response: requests.Response) -> Union[None, ParseLLMResult]:
     try:
         output = json.loads(response.json()["results"][0]["text"])
         return ParseLLMResult(param, output)
     except Exception as e:
-        logger.info("Exception" + str(e))
+        LOGGER.info("Exception" + str(e))
         return None


-def get_title_and_summary(param: TitleSummaryInput) -> Any[None, TitleSummaryOutput]:
-    logger.info("Generating title and summary")
+def get_title_and_summary(param: TitleSummaryInput) -> Union[None, TitleSummaryOutput]:
+    LOGGER.info("Generating title and summary")
     # TODO : Handle unexpected output formats from the model
     try:
@@ -66,12 +66,12 @@ def get_title_and_summary(param: TitleSummaryInput) -> Any[None, TitleSummaryOut
         incremental_responses.append(result)
         return TitleSummaryOutput(incremental_responses)
     except Exception as e:
-        logger.info("Exception" + str(e))
+        LOGGER.info("Exception" + str(e))
         return None


 def channel_log(channel, t: str, message: str) -> NoReturn:
-    logger.info("channel(%s) %s %s" % (channel.label, t, message))
+    LOGGER.info("channel(%s) %s %s" % (channel.label, t, message))


 def channel_send(channel, message: str) -> NoReturn:
@@ -79,7 +79,7 @@ def channel_send(channel, message: str) -> NoReturn:
     channel.send(message)


-def channel_send_increment(channel, param: Any[FinalSummaryResult, TitleSummaryOutput]) -> NoReturn:
+def channel_send_increment(channel, param: Union[FinalSummaryResult, TitleSummaryOutput]) -> NoReturn:
     if channel and param:
         message = param.get_result()
         channel.send(json.dumps(message))
@@ -89,11 +89,11 @@ def channel_send_transcript(channel) -> NoReturn:
     # channel_log(channel, ">", message)
     if channel:
         try:
-            least_time = sorted_transcripts.keys()[0]
+            least_time = next(iter(sorted_transcripts))
             message = sorted_transcripts[least_time].get_result()
             if message:
                 del sorted_transcripts[least_time]
-                if message["text"] not in blacklisted_messages:
+                if message["text"] not in BlackListedMessages.messages:
                     channel.send(json.dumps(message))
             # Due to exceptions if one of the earlier batches can't return
             # a transcript, we don't want to be stuck waiting for the result
@@ -101,22 +101,21 @@ def channel_send_transcript(channel) -> NoReturn:
             else:
                 if len(sorted_transcripts) >= 3:
                     del sorted_transcripts[least_time]
-        except Exception as e:
-            logger.info("Exception", str(e))
-            pass
+        except Exception as exception:
+            LOGGER.info("Exception", str(exception))


-def get_transcription(input_frames: TranscriptionInput) -> Any[None, TranscriptionOutput]:
-    logger.info("Transcribing..")
-    sorted_transcripts[input_frames[0].time] = None
+def get_transcription(input_frames: TranscriptionInput) -> Union[None, TranscriptionOutput]:
+    LOGGER.info("Transcribing..")
+    sorted_transcripts[input_frames.frames[0].time] = None
     # TODO: Find cleaner way, watch "no transcription" issue below
     # Passing IO objects instead of temporary files throws an error
-    # Passing ndarrays (typecasted with float) does not give any
+    # Passing ndarray (type casted with float) does not give any
     # transcription. Refer issue,
     # https://github.com/guillaumekln/faster-whisper/issues/369
-    audiofilename = "test" + str(datetime.datetime.now())
-    wf = wave.open(audiofilename, "wb")
+    audio_file = "test" + str(datetime.datetime.now())
+    wf = wave.open(audio_file, "wb")
     wf.setnchannels(CHANNELS)
     wf.setframerate(RATE)
     wf.setsampwidth(2)
@@ -129,12 +128,12 @@ def get_transcription(input_frames: TranscriptionInput) -> Any[None, Transcripti
     try:
         segments, _ = \
-            model.transcribe(audiofilename,
+            model.transcribe(audio_file,
                              language="en",
                              beam_size=5,
                              vad_filter=True,
-                             vad_parameters=dict(min_silence_duration_ms=500))
-        os.remove(audiofilename)
+                             vad_parameters={"min_silence_duration_ms": 500})
+        os.remove(audio_file)
         segments = list(segments)
         result_text = ""
         duration = 0.0
@@ -152,9 +151,8 @@ def get_transcription(input_frames: TranscriptionInput) -> Any[None, Transcripti
         last_transcribed_time += duration
         transcription_text += result_text
-    except Exception as e:
-        logger.info("Exception" + str(e))
-        pass
+    except Exception as exception:
+        LOGGER.info("Exception" + str(exception))

     result = TranscriptionOutput(result_text)
     sorted_transcripts[input_frames.frames[0].time] = result
@@ -162,6 +160,11 @@ def get_transcription(input_frames: TranscriptionInput) -> Any[None, Transcripti

 def get_final_summary_response() -> FinalSummaryResult:
+    """
+    Collate the incremental summaries generated so far and return as the final
+    summary
+    :return:
+    """
     final_summary = ""

     # Collate inc summaries
@@ -170,8 +173,9 @@ def get_final_summary_response() -> FinalSummaryResult:
     response = FinalSummaryResult(final_summary, last_transcribed_time)

-    with open("./artefacts/meeting_titles_and_summaries.txt", "a") as f:
-        f.write(json.dumps(incremental_responses))
+    with open("./artefacts/meeting_titles_and_summaries.txt", "a",
+              encoding="utf-8") as file:
+        file.write(json.dumps(incremental_responses))
     return response
@@ -222,6 +226,11 @@ class AudioStreamTrack(MediaStreamTrack):

 async def offer(request: requests.Request) -> web.Response:
+    """
+    Establish the WebRTC connection with the client
+    :param request:
+    :return:
+    """
     params = await request.json()
     offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
@@ -230,7 +239,7 @@ async def offer(request: requests.Request) -> web.Response:
     pcs.add(pc)

     def log_info(msg, *args) -> NoReturn:
-        logger.info(pc_id + " " + msg, *args)
+        LOGGER.info(pc_id + " " + msg, *args)

     log_info("Created for " + request.remote)
@@ -272,15 +281,17 @@ async def offer(request: requests.Request) -> web.Response:
     return web.Response(
         content_type="application/json",
         text=json.dumps(
-            {"sdp": pc.localDescription.sdp,
-             "type": pc.localDescription.type}
+            {
+                "sdp": pc.localDescription.sdp,
+                "type": pc.localDescription.type
+            }
         ),
     )


-async def on_shutdown(app) -> NoReturn:
-    coros = [pc.close() for pc in pcs]
-    await asyncio.gather(*coros)
+async def on_shutdown(application: web.Application) -> NoReturn:
+    coroutines = [pc.close() for pc in pcs]
+    await asyncio.gather(*coroutines)
     pcs.clear()
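
Editor's note on the recurring Any[...] to Union[...] fix above: typing.Any cannot be subscripted, so the old annotations raise a TypeError as soon as the module is imported (unless postponed evaluation of annotations is enabled). Union[None, X] is correct, and Optional[X] is the equivalent, more idiomatic spelling. A quick sketch with an illustrative function:

from typing import Optional, Union


def parse(payload: dict) -> Optional[dict]:
    # Optional[dict] is shorthand for Union[dict, None]
    return payload.get("results")


# The two spellings compare equal:
assert Union[None, dict] == Optional[dict]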

View File

@@ -9,8 +9,8 @@ import stamina
 from aiortc import (RTCPeerConnection, RTCSessionDescription)
 from aiortc.contrib.media import (MediaPlayer, MediaRelay)

-from utils.log_utils import logger
-from utils.run_utils import config
+from utils.log_utils import LOGGER
+from utils.run_utils import CONFIG


 class StreamClient:
@@ -35,7 +35,7 @@ class StreamClient:
         self.time_start = None
         self.queue = asyncio.Queue()
         self.player = MediaPlayer(
-            ':' + str(config['AUDIO']["AV_FOUNDATION_DEVICE_ID"]),
+            ':' + str(CONFIG['AUDIO']["AV_FOUNDATION_DEVICE_ID"]),
             format='avfoundation',
             options={'channels': '2'})
@@ -74,7 +74,7 @@ class StreamClient:
         self.pcs.add(pc)

         def log_info(msg, *args):
-            logger.info(pc_id + " " + msg, *args)
+            LOGGER.info(pc_id + " " + msg, *args)

         @pc.on("connectionstatechange")
         async def on_connectionstatechange():

View File

@@ -93,6 +93,6 @@ def generate_finetuning_dataset(video_ids):
     video_ids = ["yTnSEZIwnkU"]
     dataset = generate_finetuning_dataset(video_ids)

-    with open("finetuning_dataset.jsonl", "w") as f:
+    with open("finetuning_dataset.jsonl", "w", encoding="utf-8") as file:
         for example in dataset:
-            f.write(json.dumps(example) + "\n")
+            file.write(json.dumps(example) + "\n")
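
Editor's note: most open() calls in this commit gain an explicit encoding, which is what pylint's unspecified-encoding check (W1514) flags. Without the argument, Python falls back to the platform's locale encoding, so the same file can round-trip differently across machines. Minimal before/after, using the file name from the hunk above:

# Before: encoding depends on the machine's locale (pylint W1514)
# with open("finetuning_dataset.jsonl", "w") as f: ...

# After: explicit and portable
with open("finetuning_dataset.jsonl", "w", encoding="utf-8") as file:
    file.write('{"prompt": "...", "completion": "..."}\n')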

View File

@@ -16,10 +16,10 @@ from av import AudioFifo
 from sortedcontainers import SortedDict
 from whisper_jax import FlaxWhisperPipline

-from reflector.utils.log_utils import logger
-from reflector.utils.run_utils import config, Mutex
+from reflector.utils.log_utils import LOGGER
+from reflector.utils.run_utils import CONFIG, Mutex

-WHISPER_MODEL_SIZE = config['WHISPER']["WHISPER_REAL_TIME_MODEL_SIZE"]
+WHISPER_MODEL_SIZE = CONFIG['WHISPER']["WHISPER_REAL_TIME_MODEL_SIZE"]
 pcs = set()
 relay = MediaRelay()
 data_channel = None
@@ -127,7 +127,7 @@ async def offer(request: requests.Request):
     pcs.add(pc)

     def log_info(msg: str, *args):
-        logger.info(pc_id + " " + msg, *args)
+        LOGGER.info(pc_id + " " + msg, *args)

     log_info("Created for " + request.remote)

View File

@@ -3,14 +3,14 @@ import sys

 # Observe the incremental summaries by performing summaries in chunks

-with open("transcript.txt") as f:
-    transcription = f.read()
+with open("transcript.txt", "r", encoding="utf-8") as file:
+    transcription = file.read()


 def split_text_file(filename, token_count):
     nlp = spacy.load('en_core_web_md')

-    with open(filename, 'r') as file:
+    with open(filename, 'r', encoding="utf-8") as file:
         text = file.read()

     doc = nlp(text)
@@ -36,9 +36,9 @@ chunks = split_text_file("transcript.txt", MAX_CHUNK_LENGTH)
 print("Number of chunks", len(chunks))

 # Write chunks to file to refer to input vs output, separated by blank lines
-with open("chunks" + str(MAX_CHUNK_LENGTH) + ".txt", "a") as f:
+with open("chunks" + str(MAX_CHUNK_LENGTH) + ".txt", "a", encoding="utf-8") as file:
     for c in chunks:
-        f.write(c + "\n\n")
+        file.write(c + "\n\n")

 # If we want to run only a certain model, type the option while running
 # ex. python incsum.py 1 => will run approach 1
@@ -78,9 +78,9 @@ if index == "1" or index is None:
         summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
         summaries.append(summary)

-    with open("bart-summaries.txt", "a") as f:
+    with open("bart-summaries.txt", "a", encoding="utf-8") as file:
         for summary in summaries:
-            f.write(summary + "\n\n")
+            file.write(summary + "\n\n")

 # Approach 2
 if index == "2" or index is None:
@@ -114,8 +114,8 @@ if index == "2" or index is None:
         summary_ids = output[0, input_length:]
         summary = tokenizer.decode(summary_ids, skip_special_tokens=True)
         summaries.append(summary)
-        with open("gptneo1.3B-summaries.txt", "a") as f:
-            f.write(summary + "\n\n")
+        with open("gptneo1.3B-summaries.txt", "a", encoding="utf-8") as file:
+            file.write(summary + "\n\n")

 # Approach 3
 if index == "3" or index is None:
@@ -152,6 +152,6 @@ if index == "3" or index is None:
                                    skip_special_tokens=True)
         summaries.append(summary)

-    with open("mpt-7b-summaries.txt", "a") as f:
+    with open("mpt-7b-summaries.txt", "a", encoding="utf-8") as file:
         for summary in summaries:
-            f.write(summary + "\n\n")
+            file.write(summary + "\n\n")

View File

@@ -19,15 +19,15 @@ import yt_dlp as youtube_dl
 from whisper_jax import FlaxWhisperPipline

 from ...utils.file_utils import download_files, upload_files
-from ...utils.log_utils import logger
-from ...utils.run_utils import config
+from ...utils.log_utils import LOGGER
+from ...utils.run_utils import CONFIG
 from ...utils.text_utils import post_process_transcription, summarize
 from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud

 nltk.download('punkt', quiet=True)
 nltk.download('stopwords', quiet=True)

-WHISPER_MODEL_SIZE = config['WHISPER']["WHISPER_MODEL_SIZE"]
+WHISPER_MODEL_SIZE = CONFIG['WHISPER']["WHISPER_MODEL_SIZE"]
 NOW = datetime.now()

 if not os.path.exists('../../artefacts'):
@@ -75,7 +75,7 @@ def main():
         # Download the lowest resolution YouTube video
         # (since we're just interested in the audio).
         # It will be saved to the current directory.
-        logger.info("Downloading YouTube video at url: " + args.location)
+        LOGGER.info("Downloading YouTube video at url: " + args.location)

         # Create options for the download
         ydl_opts = {
@@ -93,12 +93,12 @@ def main():
             ydl.download([args.location])
             media_file = "../artefacts/audio.mp3"
-            logger.info("Saved downloaded YouTube video to: " + media_file)
+            LOGGER.info("Saved downloaded YouTube video to: " + media_file)
         else:
             # XXX - Download file using urllib, check if file is
             # audio/video using python-magic
-            logger.info(f"Downloading file at url: {args.location}")
-            logger.info(" XXX - This method hasn't been implemented yet.")
+            LOGGER.info(f"Downloading file at url: {args.location}")
+            LOGGER.info(" XXX - This method hasn't been implemented yet.")
     elif url.scheme == '':
         media_file = url.path
         # If file is not present locally, take it from S3 bucket
@@ -119,7 +119,7 @@ def main():
             audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3",
                                                          delete=False).name
             video.audio.write_audiofile(audio_filename, logger=None)
-            logger.info(f"Extracting audio to: {audio_filename}")
+            LOGGER.info(f"Extracting audio to: {audio_filename}")
         # Handle audio only file
         except Exception:
             audio = moviepy.editor.AudioFileClip(media_file)
@@ -129,14 +129,14 @@ def main():
     else:
         audio_filename = media_file

-    logger.info("Finished extracting audio")
-    logger.info("Transcribing")
+    LOGGER.info("Finished extracting audio")
+    LOGGER.info("Transcribing")

     # Convert the audio to text using the OpenAI Whisper model
     pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE,
                                   dtype=jnp.float16,
                                   batch_size=16)
     whisper_result = pipeline(audio_filename, return_timestamps=True)
-    logger.info("Finished transcribing file")
+    LOGGER.info("Finished transcribing file")

     whisper_result = post_process_transcription(whisper_result)
@@ -153,10 +153,10 @@ def main():
               "w") as transcript_file_timestamps:
         transcript_file_timestamps.write(str(whisper_result))

-    logger.info("Creating word cloud")
+    LOGGER.info("Creating word cloud")
     create_wordcloud(NOW)

-    logger.info("Performing talk-diff and talk-diff visualization")
+    LOGGER.info("Performing talk-diff and talk-diff visualization")
     create_talk_diff_scatter_viz(NOW)

     # S3 : Push artefacts to S3 bucket
@@ -172,7 +172,7 @@ def main():
     summarize(transcript_text, NOW, False, False)

-    logger.info("Summarization completed")
+    LOGGER.info("Summarization completed")
     # Summarization takes a lot of time, so do this separately at the end
     files_to_upload = [prefix + "summary_" + suffix + ".txt"]

View File

@@ -11,12 +11,12 @@ from termcolor import colored
 from whisper_jax import FlaxWhisperPipline

 from ...utils.file_utils import upload_files
-from ...utils.log_utils import logger
-from ...utils.run_utils import config
+from ...utils.log_utils import LOGGER
+from ...utils.run_utils import CONFIG
 from ...utils.text_utils import post_process_transcription, summarize
 from ...utils.viz_utils import create_talk_diff_scatter_viz, create_wordcloud

-WHISPER_MODEL_SIZE = config['WHISPER']["WHISPER_MODEL_SIZE"]
+WHISPER_MODEL_SIZE = CONFIG['WHISPER']["WHISPER_MODEL_SIZE"]

 FRAMES_PER_BUFFER = 8000
 FORMAT = pyaudio.paInt16
@@ -31,7 +31,7 @@ def main():
     AUDIO_DEVICE_ID = -1
     for i in range(p.get_device_count()):
         if p.get_device_info_by_index(i)["name"] == \
-                config["AUDIO"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]:
+                CONFIG["AUDIO"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]:
             AUDIO_DEVICE_ID = i
     audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID)
     stream = p.open(
@@ -44,7 +44,7 @@ def main():
     )

     pipeline = FlaxWhisperPipline("openai/whisper-" +
-                                  config["WHISPER"]["WHISPER_REAL_TIME_MODEL_SIZE"],
+                                  CONFIG["WHISPER"]["WHISPER_REAL_TIME_MODEL_SIZE"],
                                   dtype=jnp.float16,
                                   batch_size=16)
@@ -106,23 +106,26 @@ def main():
                           " | Transcribed duration: " +
                           str(duration), "yellow"))
-    except Exception as e:
-        print(e)
+    except Exception as exception:
+        print(str(exception))
     finally:
-        with open("real_time_transcript_" +
-                  NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
-            f.write(transcription)
+        with open("real_time_transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S")
+                  + ".txt", "w", encoding="utf-8") as file:
+            file.write(transcription)

         with open("real_time_transcript_with_timestamp_" +
-                  NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
+                  NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w",
+                  encoding="utf-8") as file:
             transcript_with_timestamp["text"] = transcription
-            f.write(str(transcript_with_timestamp))
+            file.write(str(transcript_with_timestamp))

-        transcript_with_timestamp = post_process_transcription(transcript_with_timestamp)
+        transcript_with_timestamp = \
+            post_process_transcription(transcript_with_timestamp)

-        logger.info("Creating word cloud")
+        LOGGER.info("Creating word cloud")
         create_wordcloud(NOW, True)

-        logger.info("Performing talk-diff and talk-diff visualization")
+        LOGGER.info("Performing talk-diff and talk-diff visualization")
         create_talk_diff_scatter_viz(NOW, True)

         # S3 : Push artefacts to S3 bucket
@@ -137,7 +140,7 @@ def main():
         summarize(transcript_with_timestamp["text"], NOW, True, True)

-        logger.info("Summarization completed")
+        LOGGER.info("Summarization completed")
         # Summarization takes a lot of time, so do this separately at the end
         files_to_upload = ["real_time_summary_" + suffix + ".txt"]

View File

@@ -1,16 +1,21 @@
+"""
+Utility file for file handling related functions, including file downloads and
+uploads to cloud storage
+"""
 import sys

 import boto3
 import botocore

-from .log_utils import logger
-from .run_utils import config
+from .log_utils import LOGGER
+from .run_utils import CONFIG

-BUCKET_NAME = config["AWS"]["BUCKET_NAME"]
+BUCKET_NAME = CONFIG["AWS"]["BUCKET_NAME"]

 s3 = boto3.client('s3',
-                  aws_access_key_id=config["AWS"]["AWS_ACCESS_KEY"],
-                  aws_secret_access_key=config["AWS"]["AWS_SECRET_KEY"])
+                  aws_access_key_id=CONFIG["AWS"]["AWS_ACCESS_KEY"],
+                  aws_secret_access_key=CONFIG["AWS"]["AWS_SECRET_KEY"])


 def upload_files(files_to_upload):
@@ -19,12 +24,12 @@ def upload_files(files_to_upload):
     :param files_to_upload: List of files to upload
     :return: None
     """
-    for KEY in files_to_upload:
-        logger.info("Uploading file " + KEY)
+    for key in files_to_upload:
+        LOGGER.info("Uploading file " + key)
         try:
-            s3.upload_file(KEY, BUCKET_NAME, KEY)
-        except botocore.exceptions.ClientError as e:
-            print(e.response)
+            s3.upload_file(key, BUCKET_NAME, key)
+        except botocore.exceptions.ClientError as exception:
+            print(exception.response)


 def download_files(files_to_download):
@@ -33,12 +38,12 @@ def download_files(files_to_download):
     :param files_to_download: List of files to download
     :return: None
     """
-    for KEY in files_to_download:
-        logger.info("Downloading file " + KEY)
+    for key in files_to_download:
+        LOGGER.info("Downloading file " + key)
         try:
-            s3.download_file(BUCKET_NAME, KEY, KEY)
-        except botocore.exceptions.ClientError as e:
-            if e.response['Error']['Code'] == "404":
+            s3.download_file(BUCKET_NAME, key, key)
+        except botocore.exceptions.ClientError as exception:
+            if exception.response['Error']['Code'] == "404":
                 print("The object does not exist.")
             else:
                 raise
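
Editor's note: for reference, a usage sketch of the two helpers after the rename. The file name is made up, and the import path depends on where the caller sits in the package.

from utils.file_utils import upload_files, download_files

upload_files(["summary_07-26-2023.txt"])    # pushed to the configured BUCKET_NAME
download_files(["summary_07-26-2023.txt"])  # missing keys print a 404 notice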

View File

@@ -1,13 +1,24 @@
+"""
+Utility function to format the artefacts created during Reflector run
+"""
 import json

-with open("../artefacts/meeting_titles_and_summaries.txt", "r") as f:
+with open("../artefacts/meeting_titles_and_summaries.txt", "r",
+          encoding='utf-8') as f:
     outputs = f.read()

 outputs = json.loads(outputs)

-transcript_file = open("../artefacts/meeting_transcript.txt", "a")
-title_desc_file = open("../artefacts/meeting_title_description.txt", "a")
-summary_file = open("../artefacts/meeting_summary.txt", "a")
+transcript_file = open("../artefacts/meeting_transcript.txt",
+                       "a",
+                       encoding='utf-8')
+title_desc_file = open("../artefacts/meeting_title_description.txt",
+                       "a",
+                       encoding='utf-8')
+summary_file = open("../artefacts/meeting_summary.txt",
+                    "a",
+                    encoding='utf-8')

 for item in outputs["topics"]:
     transcript_file.write(item["transcript"])
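
Editor's note: these three handles are opened at module level and never explicitly closed in the hunk shown; a contextlib.ExitStack would close them deterministically. A sketch of that alternative, not part of this commit:

import contextlib

with contextlib.ExitStack() as stack:
    transcript_file = stack.enter_context(
        open("../artefacts/meeting_transcript.txt", "a", encoding="utf-8"))
    title_desc_file = stack.enter_context(
        open("../artefacts/meeting_title_description.txt", "a", encoding="utf-8"))
    summary_file = stack.enter_context(
        open("../artefacts/meeting_summary.txt", "a", encoding="utf-8"))
    # ... write the topics; every handle is closed when the block exits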

View File

@@ -1,7 +1,15 @@
+"""
+Utility file for logging
+"""
 import loguru


 class SingletonLogger:
+    """
+    Use Singleton design pattern to create a logger object and share it
+    across the entire project
+    """
     __instance = None

     @staticmethod
@@ -15,4 +23,4 @@ class SingletonLogger:
     return SingletonLogger.__instance


-logger = SingletonLogger.get_logger()
+LOGGER = SingletonLogger.get_logger()
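
Editor's note: putting the two hunks together, the module now reads roughly as below. The body of get_logger is abridged from context lines not shown in this diff, so treat it as a sketch rather than the exact source:

"""
Utility file for logging
"""
import loguru


class SingletonLogger:
    """
    Use Singleton design pattern to create a logger object and share it
    across the entire project
    """
    __instance = None

    @staticmethod
    def get_logger():
        # Create the shared logger once, hand back the same object after
        if SingletonLogger.__instance is None:
            SingletonLogger.__instance = loguru.logger
        return SingletonLogger.__instance


LOGGER = SingletonLogger.get_logger()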

View File

@@ -1,3 +1,7 @@
+"""
+Utility file for server side asynchronous task running and config objects
+"""
 import asyncio
 import configparser
 import contextlib
@@ -7,6 +11,9 @@ from typing import ContextManager, Generic, TypeVar


 class ReflectorConfig:
+    """
+    Create a single config object to share across the project
+    """
     __config = None

     @staticmethod
@@ -17,7 +24,7 @@ class ReflectorConfig:
     return ReflectorConfig.__config


-config = ReflectorConfig.get_config()
+CONFIG = ReflectorConfig.get_config()


 def run_in_executor(func, *args, executor=None, **kwargs):
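
Editor's note: the hunk ends at the run_in_executor signature; its body is not shown here. A common shape for such a helper, offered as an assumption rather than the project's actual code, binds the keyword arguments with functools.partial because loop.run_in_executor only forwards positionals:

import asyncio
import functools


def run_in_executor(func, *args, executor=None, **kwargs):
    # Bind everything now; the executor call site takes positionals only
    loop = asyncio.get_event_loop()
    return loop.run_in_executor(executor,
                                functools.partial(func, *args, **kwargs))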

View File

@@ -1,3 +1,7 @@
+"""
+Utility file for all text processing related functionalities
+"""
 import nltk
 import torch
 from nltk.corpus import stopwords
@@ -6,8 +10,8 @@ from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
 from transformers import BartForConditionalGeneration, BartTokenizer

-from log_utils import logger
-from run_utils import config
+from log_utils import LOGGER
+from run_utils import CONFIG

 nltk.download('punkt', quiet=True)
@@ -32,6 +36,12 @@ def compute_similarity(sent1, sent2):


 def remove_almost_alike_sentences(sentences, threshold=0.7):
+    """
+    Filter sentences that are similar beyond a set threshold
+    :param sentences:
+    :param threshold:
+    :return:
+    """
     num_sentences = len(sentences)
     removed_indices = set()
@@ -62,6 +72,11 @@ def remove_almost_alike_sentences(sentences, threshold=0.7):


 def remove_outright_duplicate_sentences_from_chunk(chunk):
+    """
+    Remove repetitive sentences
+    :param chunk:
+    :return:
+    """
     chunk_text = chunk["text"]
     sentences = nltk.sent_tokenize(chunk_text)
     nonduplicate_sentences = list(dict.fromkeys(sentences))
@@ -69,6 +84,12 @@ def remove_outright_duplicate_sentences_from_chunk(chunk):


 def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
+    """
+    Remove sentences that are repeated as a result of Whisper
+    hallucinations
+    :param nonduplicate_sentences:
+    :return:
+    """
     chunk_sentences = []
     for sent in nonduplicate_sentences:
@@ -91,6 +112,11 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):


 def post_process_transcription(whisper_result):
+    """
+    Parent function to perform post-processing on the transcription result
+    :param whisper_result:
+    :return:
+    """
     transcript_text = ""
     for chunk in whisper_result["chunks"]:
         nonduplicate_sentences = \
@@ -121,9 +147,9 @@ def summarize_chunks(chunks, tokenizer, model):
         with torch.no_grad():
             summary_ids = \
                 model.generate(input_ids,
-                               num_beams=int(config["SUMMARIZER"]["BEAM_SIZE"]),
+                               num_beams=int(CONFIG["SUMMARIZER"]["BEAM_SIZE"]),
                                length_penalty=2.0,
-                               max_length=int(config["SUMMARIZER"]["MAX_LENGTH"]),
+                               max_length=int(CONFIG["SUMMARIZER"]["MAX_LENGTH"]),
                                early_stopping=True)
             summary = tokenizer.decode(summary_ids[0],
                                        skip_special_tokens=True)
@@ -132,7 +158,7 @@ def summarize_chunks(chunks, tokenizer, model):

 def chunk_text(text,
-               max_chunk_length=int(config["SUMMARIZER"]["MAX_CHUNK_LENGTH"])):
+               max_chunk_length=int(CONFIG["SUMMARIZER"]["MAX_CHUNK_LENGTH"])):
     """
     Split text into smaller chunks.
     :param text: Text to be chunked
@@ -154,14 +180,22 @@ def chunk_text(text,

 def summarize(transcript_text, timestamp,
               real_time=False,
-              chunk_summarize=config["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]):
+              chunk_summarize=CONFIG["SUMMARIZER"]["SUMMARIZE_USING_CHUNKS"]):
+    """
+    Summarize the given text either as a whole or as chunks as needed
+    :param transcript_text:
+    :param timestamp:
+    :param real_time:
+    :param chunk_summarize:
+    :return:
+    """
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

-    summary_model = config["SUMMARIZER"]["SUMMARY_MODEL"]
+    summary_model = CONFIG["SUMMARIZER"]["SUMMARY_MODEL"]
     if not summary_model:
         summary_model = "facebook/bart-large-cnn"
     # Summarize the generated transcript using the BART model
-    logger.info(f"Loading BART model: {summary_model}")
+    LOGGER.info(f"Loading BART model: {summary_model}")
     tokenizer = BartTokenizer.from_pretrained(summary_model)
     model = BartForConditionalGeneration.from_pretrained(summary_model)
     model = model.to(device)
@@ -171,7 +205,7 @@ def summarize(transcript_text, timestamp,
         output_file = "real_time_" + output_file

     if chunk_summarize != "YES":
-        max_length = int(config["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"])
+        max_length = int(CONFIG["SUMMARIZER"]["INPUT_ENCODING_MAX_LENGTH"])
         inputs = tokenizer. \
             batch_encode_plus([transcript_text], truncation=True,
                               padding='longest',
@@ -180,8 +214,8 @@ def summarize(transcript_text, timestamp,
         inputs = inputs.to(device)

         with torch.no_grad():
-            num_beans = int(config["SUMMARIZER"]["BEAM_SIZE"])
-            max_length = int(config["SUMMARIZER"]["MAX_LENGTH"])
+            num_beans = int(CONFIG["SUMMARIZER"]["BEAM_SIZE"])
+            max_length = int(CONFIG["SUMMARIZER"]["MAX_LENGTH"])
             summaries = model.generate(inputs['input_ids'],
                                        num_beams=num_beans,
                                        length_penalty=2.0,
@@ -194,16 +228,16 @@ def summarize(transcript_text, timestamp,
                                        clean_up_tokenization_spaces=False)
                      for summary in summaries]
         summary = " ".join(decoded_summaries)
-        with open("./artefacts/" + output_file, 'w') as f:
-            f.write(summary.strip() + "\n")
+        with open("./artefacts/" + output_file, 'w', encoding="utf-8") as file:
+            file.write(summary.strip() + "\n")
     else:
-        logger.info("Breaking transcript into smaller chunks")
+        LOGGER.info("Breaking transcript into smaller chunks")
         chunks = chunk_text(transcript_text)
-        logger.info(f"Transcript broken into {len(chunks)} "
-                    f"chunks of at most 500 words")
-        logger.info(f"Writing summary text to: {output_file}")
+        LOGGER.info(f"Transcript broken into {len(chunks)} "
+                    f"chunks of at most 500 words")
+        LOGGER.info(f"Writing summary text to: {output_file}")
         with open(output_file, 'w') as f:
             summaries = summarize_chunks(chunks, tokenizer, model)
             for summary in summaries:
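
Editor's note: one context line in this last hunk still reads with open(output_file, 'w') as f:, so the chunked branch keeps the locale-dependent encoding and the short handle name the rest of the commit removes. Applied consistently it would become the following; the loop body is assumed to mirror the single-shot branch above, since the diff cuts off here:

with open(output_file, 'w', encoding="utf-8") as file:
    summaries = summarize_chunks(chunks, tokenizer, model)
    for summary in summaries:
        file.write(summary.strip() + "\n")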

View File

@@ -1,3 +1,7 @@
+"""
+Utility file for all visualization related functions
+"""
 import ast
 import collections
 import os
@@ -81,8 +85,8 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
     else:
         filename = "./artefacts/transcript_with_timestamp_" + \
                    timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
-    with open(filename) as f:
-        transcription_timestamp_text = f.read()
+    with open(filename) as file:
+        transcription_timestamp_text = file.read()
     res = ast.literal_eval(transcription_timestamp_text)
     chunks = res["chunks"]