Merge pull request #25 from Monadical-SAS/whisper-jax-gokul

Refactor codebase and clean-up code
This commit is contained in:
projects-g
2023-07-11 18:52:54 +05:30
committed by GitHub
23 changed files with 454 additions and 401 deletions

5
.gitignore vendored
View File

@@ -160,9 +160,6 @@ cython_debug/
#.idea/
*.mp4
summary.txt
transcript.txt
transcript_timestamps.txt
*.html
*.pkl
transcript_*.txt
@@ -174,4 +171,6 @@ test_samples/
*.mp3
*.m4a
.DS_Store/
.DS_Store
.vscode/
artefacts/

125
README.md
View File

@@ -1,32 +1,34 @@
# Reflector
This is the code base for the Reflector demo (formerly called agenda-talk-diff) for the leads : Troy Web Consulting panel (A Chat with AWS about AI: Real AI/ML AWS projects and what you should know) on 6/14 at 430PM.
The target deliverable is a local-first live transcription and visualization tool to compare a discussion's target agenda/objectives to the actual discussion live.
This is the code base for the Reflector demo (formerly called agenda-talk-diff) for the leads : Troy Web Consulting
panel (A Chat with AWS about AI: Real AI/ML AWS projects and what you should know) on 6/14 at 430PM.
The target deliverable is a local-first live transcription and visualization tool to compare a discussion's target
agenda/objectives to the actual discussion live.
**S3 bucket:**
Everything you need for S3 is already configured in config.ini. Only edit it if you need to change it deliberately.
S3 bucket name is mentioned in config.ini. All transfers will happen between this bucket and the local computer where the
script is run. You need AWS_ACCESS_KEY / AWS_SECRET_KEY to authenticate your calls to S3 (done in config.ini).
S3 bucket name is mentioned in config.ini. All transfers will happen between this bucket and the local computer where
the
script is run. You need AWS_ACCESS_KEY / AWS_SECRET_KEY to authenticate your calls to S3 (done in config.ini).
For AWS S3 Web UI,
1) Login to AWS management console.
2) Search for S3 in the search bar at the top.
3) Navigate to list the buckets under the current account, if needed and choose your bucket [```reflector-bucket```]
4) You should be able to see items in the bucket. You can upload/download files here directly.
For CLI,
For CLI,
Refer to the FILE UTIL section below.
**FILE UTIL MODULE:**
A file_util module has been created to upload/download files with AWS S3 bucket pre-configured using config.ini.
Though not needed for the workflow, if you need to upload / download file, separately on your own, apart from the pipeline workflow in the script, you can do so by :
A file_util module has been created to upload/download files with AWS S3 bucket pre-configured using config.ini.
Though not needed for the workflow, if you need to upload / download file, separately on your own, apart from the
pipeline workflow in the script, you can do so by :
Upload:
@@ -39,37 +41,37 @@ Download:
If you want to access the S3 artefacts, from another machine, you can either use the python file_util with the commands
mentioned above or simply use the GUI of AWS Management Console.
To setup,
To setup,
1) Check values in config.ini file. Specifically add your OPENAI_APIKEY if you plan to use OpenAI API requests.
2) Run ``` export KMP_DUPLICATE_LIB_OK=True``` in Terminal. [This is taken care of in code, but not reflecting, Will fix this issue later.]
2) Run ``` export KMP_DUPLICATE_LIB_OK=True``` in
Terminal. [This is taken care of in code, but not reflecting, Will fix this issue later.]
NOTE: If you don't have portaudio installed already, run ```brew install portaudio```
3) Run the script setup_depedencies.sh.
``` chmod +x setup_dependencies.sh ```
``` chmod +x setup_dependencies.sh ```
``` sh setup_dependencies.sh <ENV>```
``` sh setup_dependencies.sh <ENV>```
ENV refers to the intended environment for JAX. JAX is available in several variants, [CPU | GPU | Colab TPU | Google Cloud TPU]
```ENV``` is :
cpu -> JAX CPU installation
ENV refers to the intended environment for JAX. JAX is available in several
variants, [CPU | GPU | Colab TPU | Google Cloud TPU]
cuda11 -> JAX CUDA 11.x version
```ENV``` is :
cuda12 -> JAX CUDA 12.x version (Core Weave has CUDA 12 version, can check with ```nvidia-smi```)
cpu -> JAX CPU installation
cuda11 -> JAX CUDA 11.x version
cuda12 -> JAX CUDA 12.x version (Core Weave has CUDA 12 version, can check with ```nvidia-smi```)
```sh setup_dependencies.sh cuda12```
4) If not already done, install ffmpeg. ```brew install ffmpeg```
For NLTK SSL error, check [here](https://stackoverflow.com/questions/38916452/nltk-download-ssl-certificate-verify-failed)
For NLTK SSL error,
check [here](https://stackoverflow.com/questions/38916452/nltk-download-ssl-certificate-verify-failed)
5) Run the Whisper-JAX pipeline. Currently, the repo can take a Youtube video and transcribes/summarizes it.
@@ -79,83 +81,92 @@ You can even run it on local file or a file in your configured S3 bucket.
``` python3 whisjax.py "startup.mp4"```
The script will take care of a few cases like youtube file, local file, video file, audio-only file,
The script will take care of a few cases like youtube file, local file, video file, audio-only file,
file in S3, etc. If local file is not present, it can automatically take the file from S3.
**OFFLINE WORKFLOW:**
1) Specify the input source file] from a local, youtube link or upload to S3 if needed and pass it as input to the script.If the source file is in
1) Specify the input source file] from a local, youtube link or upload to S3 if needed and pass it as input to the
script.If the source file is in
```.m4a``` format, it will get converted to ```.mp4``` automatically.
2) Keep the agenda header topics in a local file named ```agenda-headers.txt```. This needs to be present where the script is run.
2) Keep the agenda header topics in a local file named ```agenda-headers.txt```. This needs to be present where the
script is run.
This version of the pipeline compares covered agenda topics using agenda headers in the following format.
1) ```agenda_topic : <short description>```
3) Check all the values in ```config.ini```. You need to predefine 2 categories for which you need to scatter plot the
topic modelling visualization in the config file. This is the default visualization. But, from the dataframe artefact called
```df_<timestamp>.pkl``` , you can load the df and choose different topics to plot. You can filter using certain words to search for the
1) ```agenda_topic : <short description>```
3) Check all the values in ```config.ini```. You need to predefine 2 categories for which you need to scatter plot the
topic modelling visualization in the config file. This is the default visualization. But, from the dataframe artefact
called
```df_<timestamp>.pkl``` , you can load the df and choose different topics to plot. You can filter using certain
words to search for the
transcriptions and you can see the top influencers and characteristic in each topic we have chosen to plot in the
interactive HTML document. I have added a new jupyter notebook that gives the base template to play around with, named
interactive HTML document. I have added a new jupyter notebook that gives the base template to play around with,
named
```Viz_experiments.ipynb```.
4) Run the script. The script automatically transcribes, summarizes and creates a scatter plot of words & topics in the form of an interactive
HTML file, a sample word cloud and uploads them to the S3 bucket
4) Run the script. The script automatically transcribes, summarizes and creates a scatter plot of words & topics in the
form of an interactive
HTML file, a sample word cloud and uploads them to the S3 bucket
5) Additional artefacts pushed to S3:
1) HTML visualization file
2) pandas df in pickle format for others to collaborate and make their own visualizations
3) Summary, transcript and transcript with timestamps file in text format.
1) HTML visualization file
2) pandas df in pickle format for others to collaborate and make their own visualizations
3) Summary, transcript and transcript with timestamps file in text format.
The script also creates 2 types of mappings.
1) Timestamp -> The top 2 matched agenda topic
2) Topic -> All matched timestamps in the transcription
Other visualizations can be planned based on available artefacts or new ones can be created. Refer the section ```Viz-experiments```.
The script also creates 2 types of mappings.
1) Timestamp -> The top 2 matched agenda topic
2) Topic -> All matched timestamps in the transcription
Other visualizations can be planned based on available artefacts or new ones can be created. Refer the
section ```Viz-experiments```.
**Visualization experiments:**
This is a jupyter notebook playground with template instructions on handling the metadata and data artefacts generated from the
pipeline. Follow the instructions given and tweak your own logic into it or use it as a playground to experiment libraries and
This is a jupyter notebook playground with template instructions on handling the metadata and data artefacts generated
from the
pipeline. Follow the instructions given and tweak your own logic into it or use it as a playground to experiment
libraries and
visualizations on top of the metadata.
**WHISPER-JAX REALTIME TRANSCRIPTION PIPELINE:**
We also support a provision to perform real-time transcripton using whisper-jax pipeline. But, there are
a few pre-requisites before you run it on your local machine. The instructions are for
We also support a provision to perform real-time transcripton using whisper-jax pipeline. But, there are
a few pre-requisites before you run it on your local machine. The instructions are for
configuring on a MacOS.
We need to way to route audio from an application opened via the browser, ex. "Whereby" and audio from your local
microphone input which you will be using for speaking. We use [Blackhole](https://github.com/ExistentialAudio/BlackHole).
microphone input which you will be using for speaking. We
use [Blackhole](https://github.com/ExistentialAudio/BlackHole).
1) Install Blackhole-2ch (2 ch is enough) by 1 of 2 options listed.
2) Setup [Aggregate device](https://github.com/ExistentialAudio/BlackHole/wiki/Aggregate-Device) to route web audio and
local microphone input.
Be sure to mirror the settings given ![here](./images/aggregate_input.png)
Be sure to mirror the settings given ![here](./images/aggregate_input.png)
3) Setup [Multi-Output device](https://github.com/ExistentialAudio/BlackHole/wiki/Multi-Output-Device)
Refer ![here](./images/multi-output.png)
4) Set the aggregator input device name created in step 2 in config.ini as ```BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME```
5) Then goto ``` System Preferences -> Sound ``` and choose the devices created from the Output and
Input tabs.
Input tabs.
6) The input from your local microphone, the browser run meeting should be aggregated into one virtual stream to listen to
and the output should be fed back to your specified output devices if everything is configured properly. Check this
before trying out the trial.
6) The input from your local microphone, the browser run meeting should be aggregated into one virtual stream to listen
to
and the output should be fed back to your specified output devices if everything is configured properly. Check this
before trying out the trial.
**Permissions:**
You may have to add permission for "Terminal"/Code Editors [Pycharm/VSCode, etc.] microphone access to record audio in
You may have to add permission for "Terminal"/Code Editors [Pycharm/VSCode, etc.] microphone access to record audio in
```System Preferences -> Privacy & Security -> Microphone```,
```System Preferences -> Privacy & Security -> Accessibility```,
```System Preferences -> Privacy & Security -> Input Monitoring```.
From the reflector root folder,
From the reflector root folder,
run ```python3 whisjax_realtime.py```
The transcription text should be written to ```real_time_transcription_<timestamp>.txt```.
NEXT STEPS:
1) Create a RunPod setup for this feature (mentioned in 1 & 2) and test it end-to-end

View File

@@ -1,35 +1,33 @@
import argparse
import asyncio
import logging
import signal
from aiortc.contrib.signaling import (add_signaling_arguments,
create_signaling)
from stream_client import StreamClient
logger = logging.getLogger("pc")
from utils.log_utils import logger
async def main():
parser = argparse.ArgumentParser(description="Data channels ping/pong")
parser.add_argument(
"--url", type=str, nargs="?", default="http://127.0.0.1:1250/offer"
"--url", type=str, nargs="?", default="http://127.0.0.1:1250/offer"
)
parser.add_argument(
"--ping-pong",
help="Benchmark data channel with ping pong",
type=eval,
choices=[True, False],
default="False",
"--ping-pong",
help="Benchmark data channel with ping pong",
type=eval,
choices=[True, False],
default="False",
)
parser.add_argument(
"--play-from",
type=str,
default="",
"--play-from",
type=str,
default="",
)
add_signaling_arguments(parser)
@@ -39,34 +37,33 @@ async def main():
async def shutdown(signal, loop):
"""Cleanup tasks tied to the service's shutdown."""
logging.info(f"Received exit signal {signal.name}...")
logging.info("Closing database connections")
logging.info("Nacking outstanding messages")
logger.info(f"Received exit signal {signal.name}...")
logger.info("Closing database connections")
logger.info("Nacking outstanding messages")
tasks = [t for t in asyncio.all_tasks() if t is not
asyncio.current_task()]
[task.cancel() for task in tasks]
logging.info(f"Cancelling {len(tasks)} outstanding tasks")
logger.info(f"Cancelling {len(tasks)} outstanding tasks")
await asyncio.gather(*tasks, return_exceptions=True)
logging.info(f"Flushing metrics")
logger.info(f'{"Flushing metrics"}')
loop.stop()
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
loop = asyncio.get_event_loop()
for s in signals:
loop.add_signal_handler(
s, lambda s=s: asyncio.create_task(shutdown(s, loop)))
s, lambda s=s: asyncio.create_task(shutdown(s, loop)))
# Init client
sc = StreamClient(
signaling=signaling,
url=args.url,
play_from=args.play_from,
ping_pong=args.ping_pong
signaling=signaling,
url=args.url,
play_from=args.play_from,
ping_pong=args.ping_pong
)
await sc.start()
print("Stream client started")
async for msg in sc.get_reader():
print(msg)

View File

@@ -1,22 +0,0 @@
[DEFAULT]
# Set exception rule for OpenMP error to allow duplicate lib initialization
KMP_DUPLICATE_LIB_OK=TRUE
# Export OpenAI API Key
OPENAI_APIKEY=
# Export Whisper Model Size
WHISPER_MODEL_SIZE=tiny
WHISPER_REAL_TIME_MODEL_SIZE=tiny
# AWS config
AWS_ACCESS_KEY=***REMOVED***
AWS_SECRET_KEY=***REMOVED***
BUCKET_NAME='reflector-bucket'
# Summarizer config
SUMMARY_MODEL=facebook/bart-large-cnn
INPUT_ENCODING_MAX_LENGTH=1024
MAX_LENGTH=2048
BEAM_SIZE=6
MAX_CHUNK_LENGTH=1024
SUMMARIZE_USING_CHUNKS=YES
# Audio device
BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME=aggregator
AV_FOUNDATION_DEVICE_ID=2

View File

@@ -1,10 +1,11 @@
import os
import subprocess
import sys
from loguru import logger
# Get the input file name from the command line argument
input_file = sys.argv[1]
input_file = sys.argv[1]
# example use: python 0-reflector-local.py input.m4a agenda.txt
# Get the agenda file name from the command line argument if provided
@@ -21,7 +22,7 @@ if not os.path.exists(agenda_file):
# Check if the input file is .m4a, if so convert to .mp4
if input_file.endswith(".m4a"):
subprocess.run(["ffmpeg", "-i", input_file, f"{input_file}.mp4"])
input_file = f"{input_file}.mp4"
input_file = f"{input_file}.mp4"
# Run the first script to generate the transcript
subprocess.run(["python3", "1-transcript-generator.py", input_file, f"{input_file}_transcript.txt"])
@@ -30,4 +31,4 @@ subprocess.run(["python3", "1-transcript-generator.py", input_file, f"{input_fil
subprocess.run(["python3", "2-agenda-transcript-diff.py", agenda_file, f"{input_file}_transcript.txt"])
# Run the third script to summarize the transcript
subprocess.run(["python3", "3-transcript-summarizer.py", f"{input_file}_transcript.txt", f"{input_file}_summary.txt"])
subprocess.run(["python3", "3-transcript-summarizer.py", f"{input_file}_transcript.txt", f"{input_file}_summary.txt"])

View File

@@ -1,11 +1,13 @@
import argparse
import os
import moviepy.editor
from loguru import logger
import whisper
from loguru import logger
WHISPER_MODEL_SIZE = "base"
def init_argparse() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
usage="%(prog)s <LOCATION> <OUTPUT>",
@@ -15,6 +17,7 @@ def init_argparse() -> argparse.ArgumentParser:
parser.add_argument("output", help="Output file path")
return parser
def main():
import sys
sys.setrecursionlimit(10000)
@@ -26,10 +29,11 @@ def main():
logger.info(f"Processing file: {media_file}")
# Check if the media file is a valid audio or video file
if os.path.isfile(media_file) and not media_file.endswith(('.mp3', '.wav', '.ogg', '.flac', '.mp4', '.avi', '.flv')):
if os.path.isfile(media_file) and not media_file.endswith(
('.mp3', '.wav', '.ogg', '.flac', '.mp4', '.avi', '.flv')):
logger.error(f"Invalid file format: {media_file}")
return
# If the media file we just retrieved is an audio file then skip extraction step
audio_filename = media_file
logger.info(f"Found audio-only file, skipping audio extraction")
@@ -53,5 +57,6 @@ def main():
transcript_file.write(whisper_result["text"])
transcript_file.close()
if __name__ == "__main__":
main()

View File

@@ -1,7 +1,9 @@
import argparse
import spacy
from loguru import logger
# Define the paths for agenda and transcription files
def init_argparse() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
@@ -11,6 +13,8 @@ def init_argparse() -> argparse.ArgumentParser:
parser.add_argument("agenda", help="Location of the agenda file")
parser.add_argument("transcription", help="Location of the transcription file")
return parser
args = init_argparse().parse_args()
agenda_path = args.agenda
transcription_path = args.transcription
@@ -19,7 +23,7 @@ transcription_path = args.transcription
spaCy_model = "en_core_web_md"
nlp = spacy.load(spaCy_model)
nlp.add_pipe('sentencizer')
logger.info("Loaded spaCy model " + spaCy_model )
logger.info("Loaded spaCy model " + spaCy_model)
# Load the agenda
with open(agenda_path, "r") as f:

View File

@@ -1,11 +1,14 @@
import argparse
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize, sent_tokenize
from heapq import nlargest
from loguru import logger
# Function to initialize the argument parser
def init_argparse():
parser = argparse.ArgumentParser(
@@ -17,12 +20,14 @@ def init_argparse():
parser.add_argument("--num_sentences", type=int, default=5, help="Number of sentences to include in the summary")
return parser
# Function to read the input transcript file
def read_transcript(file_path):
with open(file_path, "r") as file:
transcript = file.read()
return transcript
# Function to preprocess the text by removing stop words and special characters
def preprocess_text(text):
stop_words = set(stopwords.words('english'))
@@ -30,6 +35,7 @@ def preprocess_text(text):
words = [w.lower() for w in words if w.isalpha() and w.lower() not in stop_words]
return words
# Function to score each sentence based on the frequency of its words and return the top sentences
def summarize_text(text, num_sentences):
# Tokenize the text into sentences
@@ -61,6 +67,7 @@ def summarize_text(text, num_sentences):
return " ".join(summary)
def main():
# Initialize the argument parser and parse the arguments
parser = init_argparse()
@@ -82,5 +89,6 @@ def main():
logger.info("Summarization completed")
if __name__ == "__main__":
main()

View File

@@ -1,15 +1,18 @@
import argparse
import os
import tempfile
import moviepy.editor
import nltk
import whisper
from loguru import logger
from transformers import BartTokenizer, BartForConditionalGeneration
import whisper
import nltk
nltk.download('punkt', quiet=True)
WHISPER_MODEL_SIZE = "base"
def init_argparse() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
usage="%(prog)s [OPTIONS] <LOCATION> <OUTPUT>",
@@ -30,6 +33,7 @@ def init_argparse() -> argparse.ArgumentParser:
return parser
# NLTK chunking function
def chunk_text(txt, max_chunk_length=500):
"Split text into smaller chunks."
@@ -45,6 +49,7 @@ def chunk_text(txt, max_chunk_length=500):
chunks.append(current_chunk.strip())
return chunks
# BART summary function
def summarize_chunks(chunks, tokenizer, model):
summaries = []
@@ -56,6 +61,7 @@ def summarize_chunks(chunks, tokenizer, model):
summaries.append(summary)
return summaries
def main():
import sys
sys.setrecursionlimit(10000)
@@ -103,7 +109,7 @@ def main():
chunks = chunk_text(whisper_result['text'])
logger.info(
f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable
f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable
logger.info(f"Writing summary text in {args.language} to: {args.output}")
with open(args.output, 'w') as f:
@@ -114,5 +120,6 @@ def main():
logger.info("Summarization completed")
if __name__ == "__main__":
main()

View File

@@ -26,7 +26,7 @@ networkx==3.1
numba==0.57.0
numpy==1.24.3
openai==0.27.7
openai-whisper @ git+https://github.com/openai/whisper.git@248b6cb124225dd263bb9bd32d060b6517e067f8
openai-whisper@ git+https://github.com/openai/whisper.git@248b6cb124225dd263bb9bd32d060b6517e067f8
Pillow==9.5.0
proglog==0.1.10
pytube==15.0.0
@@ -56,5 +56,4 @@ cached_property==1.5.2
stamina==23.1.0
httpx==0.24.1
sortedcontainers==2.4.0
openai-whisper @ git+https://github.com/openai/whisper.git@248b6cb124225dd263bb9bd32d060b6517e067f8
https://github.com/yt-dlp/yt-dlp/archive/master.tar.gz

View File

@@ -1,15 +1,24 @@
#!/bin/bash
# Directory to search for Python files
directory="."
cwd=$(pwd)
last_component="${cwd##*/}"
if [ "$last_component" = "reflector" ]; then
directory="./artefacts"
elif [ "$last_component" = "scripts" ]; then
directory="../artefacts"
fi
# Pattern to match Python files (e.g., "*.py" for all .py files)
text_file_pattern="transcript_*.txt"
transcript_file_pattern="transcript_*.txt"
summary_file_pattern="summary_*.txt"
pickle_file_pattern="*.pkl"
html_file_pattern="*.html"
png_file_pattern="wordcloud*.png"
find "$directory" -type f -name "$text_file_pattern" -delete
find "$directory" -type f -name "$transcript_file_pattern" -delete
find "$directory" -type f -name "$summary_file_pattern" -delete
find "$directory" -type f -name "$pickle_file_pattern" -delete
find "$directory" -type f -name "$html_file_pattern" -delete
find "$directory" -type f -name "$png_file_pattern" -delete

View File

@@ -1,9 +1,6 @@
import asyncio
import datetime
import io
import json
import logging
import sys
import uuid
import wave
from concurrent.futures import ThreadPoolExecutor
@@ -13,24 +10,21 @@ from aiohttp import web
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import MediaRelay
from av import AudioFifo
from loguru import logger
from whisper_jax import FlaxWhisperPipline
from utils.server_utils import run_in_executor
logger = logging.getLogger(__name__)
transcription = ""
from utils.run_utils import run_in_executor
pcs = set()
relay = MediaRelay()
data_channel = None
total_bytes_handled = 0
pipeline = FlaxWhisperPipline("openai/whisper-tiny", dtype=jnp.float16, batch_size=16)
pipeline = FlaxWhisperPipline("openai/whisper-tiny",
dtype=jnp.float16,
batch_size=16)
CHANNELS = 2
RATE = 48000
audio_buffer = AudioFifo()
start_time = datetime.datetime.now()
executor = ThreadPoolExecutor()
@@ -40,30 +34,12 @@ def channel_log(channel, t, message):
def channel_send(channel, message):
# channel_log(channel, ">", message)
global start_time
if channel:
channel.send(message)
print(
"Bytes handled :",
total_bytes_handled,
" Time : ",
datetime.datetime.now() - start_time,
)
def get_transcription(frames):
print("Transcribing..")
# samples = np.ndarray(
# np.concatenate([f.to_ndarray() for f in frames], axis=None),
# dtype=np.float32,
# )
# whisper_result = pipeline(
# {
# "array": samples,
# "sampling_rate": 48000,
# },
# return_timestamps=True,
# )
out_file = io.BytesIO()
wf = wave.open(out_file, "wb")
wf.setnchannels(CHANNELS)
@@ -73,8 +49,6 @@ def get_transcription(frames):
for frame in frames:
wf.writeframes(b"".join(frame.to_ndarray()))
wf.close()
global total_bytes_handled
total_bytes_handled += sys.getsizeof(wf)
whisper_result = pipeline(out_file.getvalue(), return_timestamps=True)
with open("test_exec.txt", "a") as f:
f.write(whisper_result["text"])
@@ -98,19 +72,19 @@ class AudioStreamTrack(MediaStreamTrack):
audio_buffer.write(frame)
if local_frames := audio_buffer.read_many(256 * 960, partial=False):
whisper_result = run_in_executor(
get_transcription, local_frames, executor=executor
get_transcription, local_frames, executor=executor
)
whisper_result.add_done_callback(
lambda f: channel_send(data_channel, str(whisper_result.result()))
if (f.result())
else None
lambda f: channel_send(data_channel,
str(whisper_result.result()))
if (f.result())
else None
)
return frame
async def offer(request):
params = await request.json()
print("Request received")
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
pc = RTCPeerConnection()
@@ -120,55 +94,47 @@ async def offer(request):
def log_info(msg, *args):
logger.info(pc_id + " " + msg, *args)
log_info("Created for %s", request.remote)
log_info("Created for " + request.remote)
@pc.on("datachannel")
def on_datachannel(channel):
global data_channel, start_time
global data_channel
data_channel = channel
channel_log(channel, "-", "created by remote party")
start_time = datetime.datetime.now()
@channel.on("message")
def on_message(message):
channel_log(channel, "<", message)
if isinstance(message, str) and message.startswith("ping"):
# reply
channel_send(channel, "pong" + message[4:])
@pc.on("connectionstatechange")
async def on_connectionstatechange():
log_info("Connection state is %s", pc.connectionState)
log_info("Connection state is " + pc.connectionState)
if pc.connectionState == "failed":
await pc.close()
pcs.discard(pc)
@pc.on("track")
def on_track(track):
print("Track %s received" % track.kind)
log_info("Track %s received", track.kind)
# Trials to listen to the correct track
log_info("Track " + track.kind + " received")
pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
# pc.addTrack(AudioStreamTrack(track))
# handle offer
await pc.setRemoteDescription(offer)
# send answer
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
print("Response sent")
return web.Response(
content_type="application/json",
text=json.dumps(
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
),
content_type="application/json",
text=json.dumps(
{"sdp": pc.localDescription.sdp,
"type": pc.localDescription.type}
),
)
async def on_shutdown(app):
# close peer connections
coros = [pc.close() for pc in pcs]
await asyncio.gather(*coros)
pcs.clear()

View File

@@ -1,54 +1,38 @@
import asyncio
import configparser
import datetime
import io
import json
import logging
import os
import threading
import uuid
import wave
from concurrent.futures import ThreadPoolExecutor
import jax.numpy as jnp
from aiohttp import webq
import requests
from aiohttp import web
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from aiortc.contrib.media import (MediaRelay)
from aiortc.contrib.media import MediaRelay
from av import AudioFifo
from sortedcontainers import SortedDict
from whisper_jax import FlaxWhisperPipline
from utils.server_utils import Mutex
from utils.log_utils import logger
from utils.run_utils import config, Mutex
ROOT = os.path.dirname(__file__)
config = configparser.ConfigParser()
config.read('config.ini')
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
logger = logging.getLogger("pc")
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_REAL_TIME_MODEL_SIZE"]
pcs = set()
relay = MediaRelay()
data_channel = None
sorted_message_queue = SortedDict()
CHANNELS = 2
RATE = 44100
CHUNK_SIZE = 256
audio_buffer = AudioFifo()
pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE,
dtype=jnp.float16,
batch_size=16)
transcription = ""
start_time = datetime.datetime.now()
total_bytes_handled = 0
executor = ThreadPoolExecutor()
audio_buffer = AudioFifo()
frame_lock = Mutex(audio_buffer)
@@ -81,6 +65,7 @@ def get_transcription():
transcribe = True
if transcribe:
print("Transcribing..")
try:
sorted_message_queue[frames[0].time] = None
out_file = io.BytesIO()
@@ -94,10 +79,11 @@ def get_transcription():
wf.close()
whisper_result = pipeline(out_file.getvalue())
item = {'text': whisper_result["text"],
item = {
'text': whisper_result["text"],
'start_time': str(frames[0].time),
'time': str(datetime.datetime.now())
}
}
sorted_message_queue[frames[0].time] = str(item)
start_messaging_thread()
except Exception as e:
@@ -106,7 +92,7 @@ def get_transcription():
class AudioStreamTrack(MediaStreamTrack):
"""
A video stream track that transforms frames from an another track.
An audio stream track to send audio frames.
"""
kind = "audio"
@@ -126,15 +112,13 @@ def start_messaging_thread():
message_thread.start()
def start_transcription_thread(max_threads):
t_threads = []
def start_transcription_thread(max_threads: int):
for i in range(max_threads):
t_thread = threading.Thread(target=get_transcription, args=(i,))
t_threads.append(t_thread)
t_thread = threading.Thread(target=get_transcription)
t_thread.start()
async def offer(request):
async def offer(request: requests.Request):
params = await request.json()
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
@@ -142,10 +126,10 @@ async def offer(request):
pc_id = "PeerConnection(%s)" % uuid.uuid4()
pcs.add(pc)
def log_info(msg, *args):
def log_info(msg: str, *args):
logger.info(pc_id + " " + msg, *args)
log_info("Created for %s", request.remote)
log_info("Created for " + request.remote)
@pc.on("datachannel")
def on_datachannel(channel):
@@ -155,7 +139,7 @@ async def offer(request):
start_time = datetime.datetime.now()
@channel.on("message")
def on_message(message):
def on_message(message: str):
channel_log(channel, "<", message)
if isinstance(message, str) and message.startswith("ping"):
# reply
@@ -163,14 +147,14 @@ async def offer(request):
@pc.on("connectionstatechange")
async def on_connectionstatechange():
log_info("Connection state is %s", pc.connectionState)
log_info("Connection state is " + pc.connectionState)
if pc.connectionState == "failed":
await pc.close()
pcs.discard(pc)
@pc.on("track")
def on_track(track):
log_info("Track %s received", track.kind)
log_info("Track " + track.kind + " received")
pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
# handle offer
@@ -180,14 +164,15 @@ async def offer(request):
answer = await pc.createAnswer()
await pc.setLocalDescription(answer)
return web.Response(
content_type="application/json",
text=json.dumps(
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
),
content_type="application/json",
text=json.dumps({
"sdp": pc.localDescription.sdp,
"type": pc.localDescription.type
}),
)
async def on_shutdown(app):
async def on_shutdown(app: web.Application):
coros = [pc.close() for pc in pcs]
await asyncio.gather(*coros)
pcs.clear()
@@ -199,5 +184,5 @@ if __name__ == "__main__":
start_transcription_thread(6)
app.router.add_post("/offer", offer)
web.run_app(
app, access_log=None, host="127.0.0.1", port=1250
app, access_log=None, host="127.0.0.1", port=1250
)

View File

@@ -1,7 +1,5 @@
import ast
import asyncio
import configparser
import logging
import time
import uuid
@@ -12,14 +10,11 @@ import stamina
from aiortc import (RTCPeerConnection, RTCSessionDescription)
from aiortc.contrib.media import (MediaPlayer, MediaRelay)
from utils.server_utils import Mutex
from utils.log_utils import logger
from utils.run_utils import config, Mutex
logger = logging.getLogger("pc")
file_lock = Mutex(open("test_sm_6.txt", "a"))
config = configparser.ConfigParser()
config.read('config.ini')
class StreamClient:
def __init__(
@@ -42,14 +37,15 @@ class StreamClient:
self.pcs = set()
self.time_start = None
self.queue = asyncio.Queue()
self.player = MediaPlayer(':' + str(config['DEFAULT']["AV_FOUNDATION_DEVICE_ID"]),
format='avfoundation', options={'channels': '2'})
self.player = MediaPlayer(
':' + str(config['DEFAULT']["AV_FOUNDATION_DEVICE_ID"]),
format='avfoundation',
options={'channels': '2'})
def stop(self):
self.loop.run_until_complete(self.signaling.close())
self.loop.run_until_complete(self.pc.close())
# self.loop.close()
print("ended")
def create_local_tracks(self, play_from):
if play_from:
@@ -58,7 +54,6 @@ class StreamClient:
else:
if self.relay is None:
self.relay = MediaRelay()
print("Created local track from microphone stream")
return self.relay.subscribe(self.player.audio), None
def channel_log(self, channel, t, message):
@@ -122,14 +117,15 @@ class StreamClient:
self.channel_log(channel, "<", message)
if isinstance(message, str) and message.startswith("pong"):
elapsed_ms = (self.current_stamp() - int(message[5:])) / 1000
elapsed_ms = (self.current_stamp() - int(message[5:]))\
/ 1000
print(" RTT %.2f ms" % elapsed_ms)
await pc.setLocalDescription(await pc.createOffer())
sdp = {
"sdp": pc.localDescription.sdp,
"type": pc.localDescription.type
"sdp": pc.localDescription.sdp,
"type": pc.localDescription.type
}
@stamina.retry(on=httpx.HTTPError, attempts=5)
@@ -142,7 +138,7 @@ class StreamClient:
answer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
await pc.setRemoteDescription(answer)
self.reader = self.worker(f"worker", self.queue)
self.reader = self.worker(f'{"worker"}', self.queue)
def get_reader(self):
return self.reader

0
utils/__init__.py Normal file
View File

View File

@@ -1,13 +1,12 @@
import configparser
import sys
import boto3
import botocore
from loguru import logger
config = configparser.ConfigParser()
config.read('config.ini')
from .log_utils import logger
from .run_utils import config
BUCKET_NAME = 'reflector-bucket'
BUCKET_NAME = config["DEFAULT"]["BUCKET_NAME"]
s3 = boto3.client('s3',
aws_access_key_id=config["DEFAULT"]["AWS_ACCESS_KEY"],
@@ -17,8 +16,8 @@ s3 = boto3.client('s3',
def upload_files(files_to_upload):
"""
Upload a list of files to the configured S3 bucket
:param files_to_upload:
:return:
:param files_to_upload: List of files to upload
:return: None
"""
for KEY in files_to_upload:
logger.info("Uploading file " + KEY)
@@ -31,8 +30,8 @@ def upload_files(files_to_upload):
def download_files(files_to_download):
"""
Download a list of files from the configured S3 bucket
:param files_to_download:
:return:
:param files_to_download: List of files to download
:return: None
"""
for KEY in files_to_download:
logger.info("Downloading file " + KEY)
@@ -46,8 +45,6 @@ def download_files(files_to_download):
if __name__ == "__main__":
import sys
if sys.argv[1] == "download":
download_files([sys.argv[2]])
elif sys.argv[1] == "upload":

18
utils/log_utils.py Normal file
View File

@@ -0,0 +1,18 @@
import loguru
class SingletonLogger:
__instance = None
@staticmethod
def get_logger():
"""
Create or return the singleton instance for the SingletonLogger class
:return: SingletonLogger instance
"""
if not SingletonLogger.__instance:
SingletonLogger.__instance = loguru.logger
return SingletonLogger.__instance
logger = SingletonLogger.get_logger()

66
utils/run_utils.py Normal file
View File

@@ -0,0 +1,66 @@
import asyncio
import configparser
import contextlib
from functools import partial
from threading import Lock
from typing import ContextManager, Generic, TypeVar
class ReflectorConfig:
__config = None
@staticmethod
def get_config():
if ReflectorConfig.__config is None:
ReflectorConfig.__config = configparser.ConfigParser()
ReflectorConfig.__config.read('utils/config.ini')
return ReflectorConfig.__config
config = ReflectorConfig.get_config()
def run_in_executor(func, *args, executor=None, **kwargs):
"""
Run the function in an executor, unblocking the main loop
:param func: Function to be run in executor
:param args: function parameters
:param executor: executor instance [Thread | Process]
:param kwargs: Additional parameters
:return: Future of function result upon completion
"""
callback = partial(func, *args, **kwargs)
loop = asyncio.get_event_loop()
return loop.run_in_executor(executor, callback)
# Genetic type template
T = TypeVar("T")
class Mutex(Generic[T]):
"""
Mutex class to implement lock/release of a shared
protected variable
"""
def __init__(self, value: T):
"""
Create an instance of Mutex wrapper for the given resource
:param value: Shared resources to be thread protected
"""
self.__value = value
self.__lock = Lock()
@contextlib.contextmanager
def lock(self) -> ContextManager[T]:
"""
Lock the resource with a mutex to be used within a context block
The lock is automatically released on context exit
:return: Shared resource
"""
self.__lock.acquire()
try:
yield self.__value
finally:
self.__lock.release()

View File

@@ -1,28 +0,0 @@
import asyncio
import contextlib
from functools import partial
from threading import Lock
from typing import ContextManager, Generic, TypeVar
def run_in_executor(func, *args, executor=None, **kwargs):
callback = partial(func, *args, **kwargs)
loop = asyncio.get_event_loop()
return asyncio.get_event_loop().run_in_executor(executor, callback)
T = TypeVar("T")
class Mutex(Generic[T]):
def __init__(self, value: T):
self.__value = value
self.__lock = Lock()
@contextlib.contextmanager
def lock(self) -> ContextManager[T]:
self.__lock.acquire()
try:
yield self.__value
finally:
self.__lock.release()

View File

@@ -1,24 +1,22 @@
import configparser
import nltk
import torch
from loguru import logger
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import BartForConditionalGeneration, BartTokenizer
from utils.log_utils import logger
from utils.run_utils import config
nltk.download('punkt', quiet=True)
config = configparser.ConfigParser()
config.read('config.ini')
def preprocess_sentence(sentence):
stop_words = set(stopwords.words('english'))
tokens = word_tokenize(sentence.lower())
tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
tokens = [token for token in tokens
if token.isalnum() and token not in stop_words]
return ' '.join(tokens)
@@ -52,12 +50,14 @@ def remove_almost_alike_sentences(sentences, threshold=0.7):
sentence1 = preprocess_sentence(sentences[i])
sentence2 = preprocess_sentence(sentences[j])
if len(sentence1) != 0 and len(sentence2) != 0:
similarity = compute_similarity(sentence1, sentence2)
similarity = compute_similarity(sentence1,
sentence2)
if similarity >= threshold:
removed_indices.add(max(i, j))
filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices]
filtered_sentences = [sentences[i] for i in range(num_sentences)
if i not in removed_indices]
return filtered_sentences
@@ -77,11 +77,13 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
words = nltk.word_tokenize(sent)
n_gram_filter = 3
for i in range(len(words)):
if str(words[i:i + n_gram_filter]) in seen and seen[str(words[i:i + n_gram_filter])] == words[
i + 1:i + n_gram_filter + 2]:
if str(words[i:i + n_gram_filter]) in seen and \
seen[str(words[i:i + n_gram_filter])] == \
words[i + 1:i + n_gram_filter + 2]:
pass
else:
seen[str(words[i:i + n_gram_filter])] = words[i + 1:i + n_gram_filter + 2]
seen[str(words[i:i + n_gram_filter])] = \
words[i + 1:i + n_gram_filter + 2]
temp_result += words[i]
temp_result += " "
chunk_sentences.append(temp_result)
@@ -91,9 +93,12 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
def post_process_transcription(whisper_result):
transcript_text = ""
for chunk in whisper_result["chunks"]:
nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk)
chunk_sentences = remove_whisper_repetitive_hallucination(nonduplicate_sentences)
similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences)
nonduplicate_sentences = \
remove_outright_duplicate_sentences_from_chunk(chunk)
chunk_sentences = \
remove_whisper_repetitive_hallucination(nonduplicate_sentences)
similarity_matched_sentences = \
remove_almost_alike_sentences(chunk_sentences)
chunk["text"] = " ".join(similarity_matched_sentences)
transcript_text += chunk["text"]
whisper_result["text"] = transcript_text
@@ -114,18 +119,23 @@ def summarize_chunks(chunks, tokenizer, model):
input_ids = tokenizer.encode(c, return_tensors='pt')
input_ids = input_ids.to(device)
with torch.no_grad():
summary_ids = model.generate(input_ids,
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
summary_ids = \
model.generate(input_ids,
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]),
length_penalty=2.0,
max_length=int(config["DEFAULT"]["MAX_LENGTH"]),
early_stopping=True)
summary = tokenizer.decode(summary_ids[0],
skip_special_tokens=True)
summaries.append(summary)
return summaries
def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
def chunk_text(text,
max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
"""
Split text into smaller chunks.
:param txt: Text to be chunked
:param text: Text to be chunked
:param max_chunk_length: length of chunk
:return: chunked texts
"""
@@ -143,7 +153,8 @@ def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])
def summarize(transcript_text, timestamp,
real_time=False, summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
real_time=False,
summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
if not summary_model:
@@ -160,9 +171,11 @@ def summarize(transcript_text, timestamp,
output_filename = "real_time_" + output_filename
if summarize_using_chunks != "YES":
inputs = tokenizer.batch_encode_plus([transcript_text], truncation=True, padding='longest',
max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]),
return_tensors='pt')
inputs = tokenizer.\
batch_encode_plus([transcript_text], truncation=True,
padding='longest',
max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]),
return_tensors='pt')
inputs = inputs.to(device)
with torch.no_grad():
@@ -170,16 +183,17 @@ def summarize(transcript_text, timestamp,
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) for
summary in summaries]
decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False)
for summary in summaries]
summary = " ".join(decoded_summaries)
with open(output_filename, 'w') as f:
with open("./artefacts/" + output_filename, 'w') as f:
f.write(summary.strip() + "\n")
else:
logger.info("Breaking transcript into smaller chunks")
chunks = chunk_text(transcript_text)
logger.info(f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable
logger.info(f"Transcript broken into {len(chunks)} "
f"chunks of at most 500 words")
logger.info(f"Writing summary text to: {output_filename}")
with open(output_filename, 'w') as f:

View File

@@ -1,24 +1,20 @@
import ast
import collections
import configparser
import os
import pickle
from pathlib import Path
import matplotlib.pyplot as plt
import pandas as pd
import scattertext as st
import spacy
from nltk.corpus import stopwords
from wordcloud import WordCloud, STOPWORDS
config = configparser.ConfigParser()
config.read('config.ini')
from wordcloud import STOPWORDS, WordCloud
en = spacy.load('en_core_web_md')
spacy_stopwords = en.Defaults.stop_words
STOPWORDS = set(STOPWORDS).union(set(stopwords.words("english"))).union(set(spacy_stopwords))
STOPWORDS = set(STOPWORDS).union(set(stopwords.words("english"))).\
union(set(spacy_stopwords))
def create_wordcloud(timestamp, real_time=False):
@@ -28,7 +24,8 @@ def create_wordcloud(timestamp, real_time=False):
"""
filename = "transcript"
if real_time:
filename = "real_time_" + filename + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
filename = "real_time_" + filename + "_" +\
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
else:
filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
@@ -50,11 +47,12 @@ def create_wordcloud(timestamp, real_time=False):
wordcloud_name = "wordcloud"
if real_time:
wordcloud_name = "real_time_" + wordcloud_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
wordcloud_name = "real_time_" + wordcloud_name + "_" +\
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
else:
wordcloud_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
plt.savefig(wordcloud_name)
plt.savefig("./artefacts/" + wordcloud_name)
def create_talk_diff_scatter_viz(timestamp, real_time=False):
@@ -70,7 +68,6 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
agenda_topics = []
agenda = []
# Load the agenda
path = Path(__file__)
with open(os.path.join(os.getcwd(), "agenda-headers.txt"), "r") as f:
for line in f.readlines():
if line.strip():
@@ -80,9 +77,11 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
# Load the transcription with timestamp
filename = ""
if real_time:
filename = "real_time_transcript_with_timestamp_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
filename = "./artefacts/real_time_transcript_with_timestamp_" +\
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
else:
filename = "transcript_with_timestamp_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
filename = "./artefacts/transcript_with_timestamp_" +\
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
with open(filename) as f:
transcription_timestamp_text = f.read()
@@ -98,7 +97,8 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
ts_to_topic_mapping_top_1 = {}
ts_to_topic_mapping_top_2 = {}
# Also create a mapping of the different timestamps in which each topic was covered
# Also create a mapping of the different timestamps
# in which each topic was covered
topic_to_ts_mapping_top_1 = collections.defaultdict(list)
topic_to_ts_mapping_top_2 = collections.defaultdict(list)
@@ -109,7 +109,8 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
topic_similarities = []
for item in range(len(agenda)):
item_doc = nlp(agenda[item])
# if not doc_transcription or not all(token.has_vector for token in doc_transcription):
# if not doc_transcription or not all
# (token.has_vector for token in doc_transcription):
if not doc_transcription:
continue
similarity = doc_transcription.similarity(item_doc)
@@ -133,8 +134,10 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
:param record:
:return:
"""
record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[record["timestamp"]]
record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[record["timestamp"]]
record["ts_to_topic_mapping_top_1"] = \
ts_to_topic_mapping_top_1[record["timestamp"]]
record["ts_to_topic_mapping_top_2"] = \
ts_to_topic_mapping_top_2[record["timestamp"]]
return record
df = df.apply(create_new_columns, axis=1)
@@ -155,20 +158,22 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
# Save df, mappings for further experimentation
df_name = "df"
if real_time:
df_name = "real_time_" + df_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
df_name = "real_time_" + df_name + "_" +\
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
else:
df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
df.to_pickle(df_name)
df.to_pickle("./artefacts/" + df_name)
my_mappings = [ts_to_topic_mapping_top_1, ts_to_topic_mapping_top_2,
topic_to_ts_mapping_top_1, topic_to_ts_mapping_top_2]
mappings_name = "mappings"
if real_time:
mappings_name = "real_time_" + mappings_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
mappings_name = "real_time_" + mappings_name + "_" +\
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
else:
mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
pickle.dump(my_mappings, open(mappings_name, "wb"))
pickle.dump(my_mappings, open("./artefacts/" + mappings_name, "wb"))
# to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") )
@@ -182,25 +187,28 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True)
cat_1 = topic_times[0][0]
cat_1_name = topic_times[0][0]
cat_2_name = topic_times[1][0]
if len(topic_times) > 1:
cat_1 = topic_times[0][0]
cat_1_name = topic_times[0][0]
cat_2_name = topic_times[1][0]
# Scatter plot of topics
df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences))
corpus = st.CorpusFromParsedDocuments(
df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse'
).build().get_unigram_corpus().compact(st.AssociationCompactor(2000))
html = st.produce_scattertext_explorer(
corpus,
category=cat_1,
category_name=cat_1_name,
not_category_name=cat_2_name,
minimum_term_frequency=0, pmi_threshold_coefficient=0,
width_in_pixels=1000,
transform=st.Scalers.dense_rank
)
if real_time:
open('./artefacts/real_time_scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
else:
open('./artefacts/scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
# Scatter plot of topics
df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences))
corpus = st.CorpusFromParsedDocuments(
df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse'
).build().get_unigram_corpus().compact(st.AssociationCompactor(2000))
html = st.produce_scattertext_explorer(
corpus,
category=cat_1,
category_name=cat_1_name,
not_category_name=cat_2_name,
minimum_term_frequency=0, pmi_threshold_coefficient=0,
width_in_pixels=1000,
transform=st.Scalers.dense_rank
)
if real_time:
open('./artefacts/real_time_scatter_' +
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
else:
open('./artefacts/scatter_' +
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)

View File

@@ -1,11 +1,10 @@
#!/usr/bin/env python3
# summarize https://www.youtube.com/watch?v=imzTxoEDH_g --transcript=transcript.txt summary.txt
# summarize https://www.youtube.com/watch?v=imzTxoEDH_g
# summarize https://www.sprocket.org/video/cheesemaking.mp4 summary.txt
# summarize podcast.mp3 summary.txt
import argparse
import configparser
import os
import re
import subprocess
@@ -15,23 +14,19 @@ from urllib.parse import urlparse
import jax.numpy as jnp
import moviepy.editor
import moviepy.editor
import nltk
import yt_dlp as youtube_dl
from loguru import logger
from whisper_jax import FlaxWhisperPipline
from utils.file_utilities import upload_files, download_files
from utils.text_utilities import summarize, post_process_transcription
from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
from utils.file_utils import download_files, upload_files
from utils.log_utils import logger
from utils.run_utils import config
from utils.text_utilities import post_process_transcription, summarize
from utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)
# Configurations can be found in config.ini. Set them properly before executing
config = configparser.ConfigParser()
config.read('config.ini')
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
NOW = datetime.now()
@@ -42,12 +37,17 @@ def init_argparse() -> argparse.ArgumentParser:
:return: parser object
"""
parser = argparse.ArgumentParser(
usage="%(prog)s [OPTIONS] <LOCATION> <OUTPUT>",
description="Creates a transcript of a video or audio file, then summarizes it using ChatGPT."
usage="%(prog)s [OPTIONS] <LOCATION> <OUTPUT>",
description="Creates a transcript of a video or audio file, then"
" summarizes it using ChatGPT."
)
parser.add_argument("-l", "--language", help="Language that the summary should be written in", type=str,
default="english", choices=['english', 'spanish', 'french', 'german', 'romanian'])
parser.add_argument("-l", "--language",
help="Language that the summary should be written in",
type=str,
default="english",
choices=['english', 'spanish', 'french', 'german',
'romanian'])
parser.add_argument("location")
return parser
@@ -65,22 +65,24 @@ def main():
media_file = ""
if url.scheme == 'http' or url.scheme == 'https':
# Check if we're being asked to retreive a YouTube URL, which is handled
# diffrently, as we'll use a secondary site to download the video first.
# Check if we're being asked to retreive a YouTube URL, which is
# handled differently, as we'll use a secondary site to download
# the video first.
if re.search('youtube.com', url.netloc, re.IGNORECASE):
# Download the lowest resolution YouTube video (since we're just interested in the audio).
# Download the lowest resolution YouTube video
# (since we're just interested in the audio).
# It will be saved to the current directory.
logger.info("Downloading YouTube video at url: " + args.location)
# Create options for the download
ydl_opts = {
'format': 'bestaudio/best',
'postprocessors': [{
'key': 'FFmpegExtractAudio',
'preferredcodec': 'mp3',
'preferredquality': '192',
}],
'outtmpl': 'audio', # Specify the output file path and name
'format': 'bestaudio/best',
'postprocessors': [{
'key': 'FFmpegExtractAudio',
'preferredcodec': 'mp3',
'preferredquality': '192',
}],
'outtmpl': 'audio', # Specify output file path and name
}
# Download the audio
@@ -90,7 +92,8 @@ def main():
logger.info("Saved downloaded YouTube video to: " + media_file)
else:
# XXX - Download file using urllib, check if file is audio/video using python-magic
# XXX - Download file using urllib, check if file is
# audio/video using python-magic
logger.info(f"Downloading file at url: {args.location}")
logger.info(" XXX - This method hasn't been implemented yet.")
elif url.scheme == '':
@@ -101,7 +104,7 @@ def main():
if media_file.endswith(".m4a"):
subprocess.run(["ffmpeg", "-i", media_file, f"{media_file}.mp4"])
input_file = f"{media_file}.mp4"
media_file = f"{media_file}.mp4"
else:
print("Unsupported URL scheme: " + url.scheme)
quit()
@@ -110,19 +113,21 @@ def main():
if not media_file.endswith(".mp3"):
try:
video = moviepy.editor.VideoFileClip(media_file)
audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3",
delete=False).name
video.audio.write_audiofile(audio_filename, logger=None)
logger.info(f"Extracting audio to: {audio_filename}")
# Handle audio only file
except:
except Exception:
audio = moviepy.editor.AudioFileClip(media_file)
audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3",
delete=False).name
audio.write_audiofile(audio_filename, logger=None)
else:
audio_filename = media_file
logger.info("Finished extracting audio")
logger.info("Transcribing")
# Convert the audio to text using the OpenAI Whisper model
pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE,
dtype=jnp.float16,
@@ -136,10 +141,12 @@ def main():
for chunk in whisper_result["chunks"]:
transcript_text += chunk["text"]
with open("./artefacts/transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file:
with open("./artefacts/transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") +
".txt", "w") as transcript_file:
transcript_file.write(transcript_text)
with open("./artefacts/transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt",
with open("./artefacts/transcript_with_timestamp_" +
NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt",
"w") as transcript_file_timestamps:
transcript_file_timestamps.write(str(whisper_result))
@@ -150,13 +157,14 @@ def main():
create_talk_diff_scatter_viz(NOW)
# S3 : Push artefacts to S3 bucket
prefix = "./artefacts/"
suffix = NOW.strftime("%m-%d-%Y_%H:%M:%S")
files_to_upload = ["transcript_" + suffix + ".txt",
"transcript_with_timestamp_" + suffix + ".txt",
"df_" + suffix + ".pkl",
"wordcloud_" + suffix + ".png",
"mappings_" + suffix + ".pkl",
"scatter_" + suffix + ".html"]
files_to_upload = [prefix + "transcript_" + suffix + ".txt",
prefix + "transcript_with_timestamp_" + suffix + ".txt",
prefix + "df_" + suffix + ".pkl",
prefix + "wordcloud_" + suffix + ".png",
prefix + "mappings_" + suffix + ".pkl",
prefix + "scatter_" + suffix + ".html"]
upload_files(files_to_upload)
summarize(transcript_text, NOW, False, False)

View File

@@ -1,23 +1,20 @@
#!/usr/bin/env python3
import configparser
import time
import wave
from datetime import datetime
import jax.numpy as jnp
import pyaudio
from loguru import logger
from pynput import keyboard
from termcolor import colored
from whisper_jax import FlaxWhisperPipline
from utils.file_utilities import upload_files
from utils.text_utilities import summarize, post_process_transcription
from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
config = configparser.ConfigParser()
config.read('config.ini')
from utils.file_utils import upload_files
from utils.log_utils import logger
from utils.run_utils import config
from utils.text_utilities import post_process_transcription, summarize
from utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
@@ -33,19 +30,21 @@ def main():
p = pyaudio.PyAudio()
AUDIO_DEVICE_ID = -1
for i in range(p.get_device_count()):
if p.get_device_info_by_index(i)["name"] == config["DEFAULT"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]:
if p.get_device_info_by_index(i)["name"] == \
config["DEFAULT"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]:
AUDIO_DEVICE_ID = i
audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID)
stream = p.open(
format=FORMAT,
channels=CHANNELS,
rate=RATE,
input=True,
frames_per_buffer=FRAMES_PER_BUFFER,
input_device_index=int(audio_devices['index'])
format=FORMAT,
channels=CHANNELS,
rate=RATE,
input=True,
frames_per_buffer=FRAMES_PER_BUFFER,
input_device_index=int(audio_devices['index'])
)
pipeline = FlaxWhisperPipline("openai/whisper-" + config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"],
pipeline = FlaxWhisperPipline("openai/whisper-" +
config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"],
dtype=jnp.float16,
batch_size=16)
@@ -72,7 +71,8 @@ def main():
frames = []
start_time = time.time()
for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)):
data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False)
data = stream.read(FRAMES_PER_BUFFER,
exception_on_overflow=False)
frames.append(data)
end_time = time.time()
@@ -90,7 +90,8 @@ def main():
if end is None:
end = start + 15.0
duration = end - start
item = {'timestamp': (last_transcribed_time, last_transcribed_time + duration),
item = {'timestamp': (last_transcribed_time,
last_transcribed_time + duration),
'text': whisper_result['text'],
'stats': (str(end_time - start_time), str(duration))
}
@@ -100,15 +101,19 @@ def main():
print(colored("<START>", "yellow"))
print(colored(whisper_result['text'], 'green'))
print(colored("<END> Recorded duration: " + str(end_time - start_time) + " | Transcribed duration: " +
print(colored("<END> Recorded duration: " +
str(end_time - start_time) +
" | Transcribed duration: " +
str(duration), "yellow"))
except Exception as e:
print(e)
finally:
with open("real_time_transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
with open("real_time_transcript_" +
NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
f.write(transcription)
with open("real_time_transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
with open("real_time_transcript_with_timestamp_" +
NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
transcript_with_timestamp["text"] = transcription
f.write(str(transcript_with_timestamp))