Merge pull request #25 from Monadical-SAS/whisper-jax-gokul

Refactor codebase and clean-up code
projects-g
2023-07-11 18:52:54 +05:30
committed by GitHub
23 changed files with 454 additions and 401 deletions

.gitignore vendored

@@ -160,9 +160,6 @@ cython_debug/
 #.idea/
 *.mp4
-summary.txt
-transcript.txt
-transcript_timestamps.txt
 *.html
 *.pkl
 transcript_*.txt
@@ -174,4 +171,6 @@ test_samples/
 *.mp3
 *.m4a
 .DS_Store/
+.DS_Store
 .vscode/
+artefacts/

README.md

@@ -1,32 +1,34 @@
# Reflector

This is the code base for the Reflector demo (formerly called agenda-talk-diff) for the leads: Troy Web Consulting panel (A Chat with AWS about AI: Real AI/ML AWS projects and what you should know) on 6/14 at 4:30 PM.

The target deliverable is a local-first live transcription and visualization tool to compare a discussion's target agenda/objectives to the actual discussion live.

**S3 bucket:**
Everything you need for S3 is already configured in config.ini. Only edit it if you need to change it deliberately.
The S3 bucket name is mentioned in config.ini. All transfers will happen between this bucket and the local computer where the script is run. You need AWS_ACCESS_KEY / AWS_SECRET_KEY to authenticate your calls to S3 (done in config.ini).

For the AWS S3 Web UI:
1) Log in to the AWS Management Console.
2) Search for S3 in the search bar at the top.
3) Navigate to the list of buckets under the current account, if needed, and choose your bucket [```reflector-bucket```].
4) You should be able to see items in the bucket. You can upload/download files here directly.

For the CLI, refer to the FILE UTIL section below.

**FILE UTIL MODULE:**
A file_util module has been created to upload/download files to/from the AWS S3 bucket pre-configured using config.ini.
Though not needed for the workflow, if you need to upload or download a file separately on your own, apart from the pipeline workflow in the script, you can do so by:

Upload:
@@ -39,37 +41,37 @@ Download:
If you want to access the S3 artefacts from another machine, you can either use the python file_util with the commands mentioned above or simply use the GUI of the AWS Management Console.

To set up:
1) Check the values in the config.ini file. Specifically, add your OPENAI_APIKEY if you plan to use OpenAI API requests.
2) Run ```export KMP_DUPLICATE_LIB_OK=True``` in Terminal. [This is taken care of in code but is not taking effect; will fix this issue later.]
NOTE: If you don't have portaudio installed already, run ```brew install portaudio```
3) Run the script setup_dependencies.sh.
``` chmod +x setup_dependencies.sh ```
``` sh setup_dependencies.sh <ENV>```
ENV refers to the intended environment for JAX. JAX is available in several variants: [CPU | GPU | Colab TPU | Google Cloud TPU]
```ENV``` is one of:
cpu -> JAX CPU installation
cuda11 -> JAX CUDA 11.x version
cuda12 -> JAX CUDA 12.x version (CoreWeave has CUDA 12; you can check with ```nvidia-smi```)
```sh setup_dependencies.sh cuda12```
4) If not already done, install ffmpeg: ```brew install ffmpeg```
For the NLTK SSL error, check [here](https://stackoverflow.com/questions/38916452/nltk-download-ssl-certificate-verify-failed)
5) Run the Whisper-JAX pipeline. Currently, the repo can take a YouTube video and transcribe/summarize it.
@@ -79,83 +81,92 @@ You can even run it on local file or a file in your configured S3 bucket.
``` python3 whisjax.py "startup.mp4"```

The script will take care of a few cases like a YouTube file, local file, video file, audio-only file, file in S3, etc. If the local file is not present, it can automatically take the file from S3.

**OFFLINE WORKFLOW:**
1) Specify the input source file from a local path or a YouTube link, or upload it to S3 if needed, and pass it as input to the script. If the source file is in ```.m4a``` format, it will get converted to ```.mp4``` automatically.
2) Keep the agenda header topics in a local file named ```agenda-headers.txt```. This needs to be present where the script is run.
This version of the pipeline compares covered agenda topics using agenda headers in the following format.
   1) ```agenda_topic : <short description>```
3) Check all the values in ```config.ini```. You need to predefine 2 categories for which to scatter plot the topic modelling visualization in the config file. This is the default visualization. But, from the dataframe artefact called ```df_<timestamp>.pkl```, you can load the df and choose different topics to plot. You can filter using certain words to search the transcriptions, and you can see the top influencers and characteristics of each topic we have chosen to plot in the interactive HTML document. I have added a new Jupyter notebook that gives the base template to play around with, named ```Viz_experiments.ipynb``` (a minimal loading sketch also follows at the end of this README section).
4) Run the script. The script automatically transcribes, summarizes and creates a scatter plot of words & topics in the form of an interactive HTML file and a sample word cloud, and uploads them to the S3 bucket.
5) Additional artefacts pushed to S3:
   1) HTML visualization file
   2) pandas df in pickle format for others to collaborate and make their own visualizations
   3) Summary, transcript and transcript-with-timestamps files in text format.

The script also creates 2 types of mappings.
1) Timestamp -> The top 2 matched agenda topics
2) Topic -> All matched timestamps in the transcription

Other visualizations can be planned based on available artefacts, or new ones can be created. Refer to the section ```Viz-experiments```.

**Visualization experiments:**
This is a Jupyter notebook playground with template instructions on handling the metadata and data artefacts generated from the pipeline. Follow the instructions given and tweak your own logic into it, or use it as a playground to experiment with libraries and visualizations on top of the metadata.

**WHISPER-JAX REALTIME TRANSCRIPTION PIPELINE:**
We also support a provision to perform real-time transcription using the whisper-jax pipeline. But there are a few pre-requisites before you run it on your local machine. The instructions are for configuring on macOS.

We need a way to route audio from an application opened via the browser, ex. "Whereby", together with the audio from your local microphone input which you will be using for speaking. We use [Blackhole](https://github.com/ExistentialAudio/BlackHole).
1) Install Blackhole-2ch (2 ch is enough) by one of the two options listed.
2) Set up an [Aggregate device](https://github.com/ExistentialAudio/BlackHole/wiki/Aggregate-Device) to route web audio and local microphone input.
Be sure to mirror the settings given ![here](./images/aggregate_input.png)
3) Set up a [Multi-Output device](https://github.com/ExistentialAudio/BlackHole/wiki/Multi-Output-Device)
Refer ![here](./images/multi-output.png)
4) Set the aggregator input device name created in step 2 in config.ini as ```BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME```
5) Then go to ```System Preferences -> Sound``` and choose the devices created from the Output and Input tabs.
6) The input from your local microphone and the browser-run meeting should be aggregated into one virtual stream to listen to, and the output should be fed back to your specified output devices if everything is configured properly. Check this before trying out the trial.

**Permissions:**
You may have to grant "Terminal"/code editors [PyCharm/VSCode, etc.] microphone access to record audio in
```System Preferences -> Privacy & Security -> Microphone```,
```System Preferences -> Privacy & Security -> Accessibility```,
```System Preferences -> Privacy & Security -> Input Monitoring```.

From the reflector root folder, run ```python3 whisjax_realtime.py```
The transcription text should be written to ```real_time_transcription_<timestamp>.txt```.

NEXT STEPS:
1) Create a RunPod setup for this feature (mentioned in 1 & 2) and test it end-to-end
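For quick reference, here is the minimal loading sketch mentioned in step 3 of the offline workflow (illustration only; the artefact path and the column names are assumptions, not taken from the repo):

```python
import pandas as pd

# Hypothetical file name: substitute the timestamp your run produced.
df = pd.read_pickle("artefacts/df_2023_06_14_16_30.pkl")

# Inspect what the pipeline stored before picking topics to plot.
print(df.columns.tolist())
print(df.head())

# Example filter: rows whose transcription mentions a search word
# ("text" is an assumed column name).
matches = df[df["text"].str.contains("agenda", case=False, na=False)]
print(len(matches), "matching rows")
```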


@@ -1,35 +1,33 @@
 import argparse
 import asyncio
-import logging
 import signal
 from aiortc.contrib.signaling import (add_signaling_arguments,
                                       create_signaling)
 from stream_client import StreamClient
+from utils.log_utils import logger
-logger = logging.getLogger("pc")
 async def main():
     parser = argparse.ArgumentParser(description="Data channels ping/pong")
     parser.add_argument(
         "--url", type=str, nargs="?", default="http://127.0.0.1:1250/offer"
     )
     parser.add_argument(
         "--ping-pong",
         help="Benchmark data channel with ping pong",
         type=eval,
         choices=[True, False],
         default="False",
     )
     parser.add_argument(
         "--play-from",
         type=str,
         default="",
     )
     add_signaling_arguments(parser)
@@ -39,34 +37,33 @@ async def main():
     async def shutdown(signal, loop):
         """Cleanup tasks tied to the service's shutdown."""
-        logging.info(f"Received exit signal {signal.name}...")
-        logging.info("Closing database connections")
-        logging.info("Nacking outstanding messages")
+        logger.info(f"Received exit signal {signal.name}...")
+        logger.info("Closing database connections")
+        logger.info("Nacking outstanding messages")
         tasks = [t for t in asyncio.all_tasks() if t is not
                  asyncio.current_task()]
         [task.cancel() for task in tasks]
-        logging.info(f"Cancelling {len(tasks)} outstanding tasks")
+        logger.info(f"Cancelling {len(tasks)} outstanding tasks")
         await asyncio.gather(*tasks, return_exceptions=True)
-        logging.info(f"Flushing metrics")
+        logger.info(f'{"Flushing metrics"}')
         loop.stop()
     signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
     loop = asyncio.get_event_loop()
     for s in signals:
         loop.add_signal_handler(
             s, lambda s=s: asyncio.create_task(shutdown(s, loop)))
     # Init client
     sc = StreamClient(
         signaling=signaling,
         url=args.url,
         play_from=args.play_from,
         ping_pong=args.ping_pong
     )
     await sc.start()
-    print("Stream client started")
     async for msg in sc.get_reader():
         print(msg)


@@ -1,22 +0,0 @@
[DEFAULT]
# Set exception rule for OpenMP error to allow duplicate lib initialization
KMP_DUPLICATE_LIB_OK=TRUE
# Export OpenAI API Key
OPENAI_APIKEY=
# Export Whisper Model Size
WHISPER_MODEL_SIZE=tiny
WHISPER_REAL_TIME_MODEL_SIZE=tiny
# AWS config
AWS_ACCESS_KEY=***REMOVED***
AWS_SECRET_KEY=***REMOVED***
BUCKET_NAME='reflector-bucket'
# Summarizer config
SUMMARY_MODEL=facebook/bart-large-cnn
INPUT_ENCODING_MAX_LENGTH=1024
MAX_LENGTH=2048
BEAM_SIZE=6
MAX_CHUNK_LENGTH=1024
SUMMARIZE_USING_CHUNKS=YES
# Audio device
BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME=aggregator
AV_FOUNDATION_DEVICE_ID=2


@@ -1,10 +1,11 @@
 import os
 import subprocess
 import sys
 from loguru import logger
 # Get the input file name from the command line argument
 input_file = sys.argv[1]
 # example use: python 0-reflector-local.py input.m4a agenda.txt
 # Get the agenda file name from the command line argument if provided
@@ -21,7 +22,7 @@ if not os.path.exists(agenda_file):
 # Check if the input file is .m4a, if so convert to .mp4
 if input_file.endswith(".m4a"):
     subprocess.run(["ffmpeg", "-i", input_file, f"{input_file}.mp4"])
     input_file = f"{input_file}.mp4"
 # Run the first script to generate the transcript
 subprocess.run(["python3", "1-transcript-generator.py", input_file, f"{input_file}_transcript.txt"])
@@ -30,4 +31,4 @@ subprocess.run(["python3", "1-transcript-generator.py", input_file, f"{input_fil
subprocess.run(["python3", "2-agenda-transcript-diff.py", agenda_file, f"{input_file}_transcript.txt"]) subprocess.run(["python3", "2-agenda-transcript-diff.py", agenda_file, f"{input_file}_transcript.txt"])
# Run the third script to summarize the transcript # Run the third script to summarize the transcript
subprocess.run(["python3", "3-transcript-summarizer.py", f"{input_file}_transcript.txt", f"{input_file}_summary.txt"]) subprocess.run(["python3", "3-transcript-summarizer.py", f"{input_file}_transcript.txt", f"{input_file}_summary.txt"])


@@ -1,11 +1,13 @@
 import argparse
 import os
 import moviepy.editor
-from loguru import logger
 import whisper
+from loguru import logger
 WHISPER_MODEL_SIZE = "base"
 def init_argparse() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(
         usage="%(prog)s <LOCATION> <OUTPUT>",
@@ -15,6 +17,7 @@ def init_argparse() -> argparse.ArgumentParser:
parser.add_argument("output", help="Output file path") parser.add_argument("output", help="Output file path")
return parser return parser
def main(): def main():
import sys import sys
sys.setrecursionlimit(10000) sys.setrecursionlimit(10000)
@@ -26,10 +29,11 @@ def main():
logger.info(f"Processing file: {media_file}") logger.info(f"Processing file: {media_file}")
# Check if the media file is a valid audio or video file # Check if the media file is a valid audio or video file
if os.path.isfile(media_file) and not media_file.endswith(('.mp3', '.wav', '.ogg', '.flac', '.mp4', '.avi', '.flv')): if os.path.isfile(media_file) and not media_file.endswith(
('.mp3', '.wav', '.ogg', '.flac', '.mp4', '.avi', '.flv')):
logger.error(f"Invalid file format: {media_file}") logger.error(f"Invalid file format: {media_file}")
return return
# If the media file we just retrieved is an audio file then skip extraction step # If the media file we just retrieved is an audio file then skip extraction step
audio_filename = media_file audio_filename = media_file
logger.info(f"Found audio-only file, skipping audio extraction") logger.info(f"Found audio-only file, skipping audio extraction")
@@ -53,5 +57,6 @@ def main():
     transcript_file.write(whisper_result["text"])
     transcript_file.close()
 if __name__ == "__main__":
     main()


@@ -1,7 +1,9 @@
 import argparse
 import spacy
 from loguru import logger
 # Define the paths for agenda and transcription files
 def init_argparse() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(
@@ -11,6 +13,8 @@ def init_argparse() -> argparse.ArgumentParser:
parser.add_argument("agenda", help="Location of the agenda file") parser.add_argument("agenda", help="Location of the agenda file")
parser.add_argument("transcription", help="Location of the transcription file") parser.add_argument("transcription", help="Location of the transcription file")
return parser return parser
args = init_argparse().parse_args() args = init_argparse().parse_args()
agenda_path = args.agenda agenda_path = args.agenda
transcription_path = args.transcription transcription_path = args.transcription
@@ -19,7 +23,7 @@ transcription_path = args.transcription
spaCy_model = "en_core_web_md" spaCy_model = "en_core_web_md"
nlp = spacy.load(spaCy_model) nlp = spacy.load(spaCy_model)
nlp.add_pipe('sentencizer') nlp.add_pipe('sentencizer')
logger.info("Loaded spaCy model " + spaCy_model ) logger.info("Loaded spaCy model " + spaCy_model)
# Load the agenda # Load the agenda
with open(agenda_path, "r") as f: with open(agenda_path, "r") as f:


@@ -1,11 +1,14 @@
 import argparse
 import nltk
 nltk.download('stopwords')
 from nltk.corpus import stopwords
 from nltk.tokenize import word_tokenize, sent_tokenize
 from heapq import nlargest
 from loguru import logger
 # Function to initialize the argument parser
 def init_argparse():
     parser = argparse.ArgumentParser(
@@ -17,12 +20,14 @@ def init_argparse():
parser.add_argument("--num_sentences", type=int, default=5, help="Number of sentences to include in the summary") parser.add_argument("--num_sentences", type=int, default=5, help="Number of sentences to include in the summary")
return parser return parser
# Function to read the input transcript file # Function to read the input transcript file
def read_transcript(file_path): def read_transcript(file_path):
with open(file_path, "r") as file: with open(file_path, "r") as file:
transcript = file.read() transcript = file.read()
return transcript return transcript
# Function to preprocess the text by removing stop words and special characters # Function to preprocess the text by removing stop words and special characters
def preprocess_text(text): def preprocess_text(text):
stop_words = set(stopwords.words('english')) stop_words = set(stopwords.words('english'))
@@ -30,6 +35,7 @@ def preprocess_text(text):
     words = [w.lower() for w in words if w.isalpha() and w.lower() not in stop_words]
     return words
 # Function to score each sentence based on the frequency of its words and return the top sentences
 def summarize_text(text, num_sentences):
     # Tokenize the text into sentences
@@ -61,6 +67,7 @@ def summarize_text(text, num_sentences):
return " ".join(summary) return " ".join(summary)
def main(): def main():
# Initialize the argument parser and parse the arguments # Initialize the argument parser and parse the arguments
parser = init_argparse() parser = init_argparse()
@@ -82,5 +89,6 @@ def main():
logger.info("Summarization completed") logger.info("Summarization completed")
if __name__ == "__main__": if __name__ == "__main__":
main() main()


@@ -1,15 +1,18 @@
 import argparse
 import os
 import tempfile
 import moviepy.editor
+import nltk
+import whisper
 from loguru import logger
 from transformers import BartTokenizer, BartForConditionalGeneration
-import whisper
-import nltk
 nltk.download('punkt', quiet=True)
 WHISPER_MODEL_SIZE = "base"
 def init_argparse() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(
         usage="%(prog)s [OPTIONS] <LOCATION> <OUTPUT>",
@@ -30,6 +33,7 @@ def init_argparse() -> argparse.ArgumentParser:
     return parser
 # NLTK chunking function
 def chunk_text(txt, max_chunk_length=500):
     "Split text into smaller chunks."
@@ -45,6 +49,7 @@ def chunk_text(txt, max_chunk_length=500):
         chunks.append(current_chunk.strip())
     return chunks
 # BART summary function
 def summarize_chunks(chunks, tokenizer, model):
     summaries = []
@@ -56,6 +61,7 @@ def summarize_chunks(chunks, tokenizer, model):
         summaries.append(summary)
     return summaries
 def main():
     import sys
     sys.setrecursionlimit(10000)
@@ -103,7 +109,7 @@ def main():
     chunks = chunk_text(whisper_result['text'])
     logger.info(
         f"Transcript broken into {len(chunks)} chunks of at most 500 words")  # TODO fix variable
     logger.info(f"Writing summary text in {args.language} to: {args.output}")
     with open(args.output, 'w') as f:
@@ -114,5 +120,6 @@ def main():
logger.info("Summarization completed") logger.info("Summarization completed")
if __name__ == "__main__": if __name__ == "__main__":
main() main()


@@ -26,7 +26,7 @@ networkx==3.1
 numba==0.57.0
 numpy==1.24.3
 openai==0.27.7
-openai-whisper @ git+https://github.com/openai/whisper.git@248b6cb124225dd263bb9bd32d060b6517e067f8
+openai-whisper@ git+https://github.com/openai/whisper.git@248b6cb124225dd263bb9bd32d060b6517e067f8
 Pillow==9.5.0
 proglog==0.1.10
 pytube==15.0.0
@@ -56,5 +56,4 @@ cached_property==1.5.2
 stamina==23.1.0
 httpx==0.24.1
 sortedcontainers==2.4.0
-openai-whisper @ git+https://github.com/openai/whisper.git@248b6cb124225dd263bb9bd32d060b6517e067f8
 https://github.com/yt-dlp/yt-dlp/archive/master.tar.gz


@@ -1,15 +1,24 @@
 #!/bin/bash
 # Directory to search for Python files
-directory="."
+cwd=$(pwd)
+last_component="${cwd##*/}"
+if [ "$last_component" = "reflector" ]; then
+  directory="./artefacts"
+elif [ "$last_component" = "scripts" ]; then
+  directory="../artefacts"
+fi
 # Pattern to match Python files (e.g., "*.py" for all .py files)
-text_file_pattern="transcript_*.txt"
+transcript_file_pattern="transcript_*.txt"
+summary_file_pattern="summary_*.txt"
 pickle_file_pattern="*.pkl"
 html_file_pattern="*.html"
 png_file_pattern="wordcloud*.png"
-find "$directory" -type f -name "$text_file_pattern" -delete
+find "$directory" -type f -name "$transcript_file_pattern" -delete
+find "$directory" -type f -name "$summary_file_pattern" -delete
 find "$directory" -type f -name "$pickle_file_pattern" -delete
 find "$directory" -type f -name "$html_file_pattern" -delete
 find "$directory" -type f -name "$png_file_pattern" -delete


@@ -1,9 +1,6 @@
 import asyncio
-import datetime
 import io
 import json
-import logging
-import sys
 import uuid
 import wave
 from concurrent.futures import ThreadPoolExecutor
@@ -13,24 +10,21 @@ from aiohttp import web
 from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
 from aiortc.contrib.media import MediaRelay
 from av import AudioFifo
+from loguru import logger
 from whisper_jax import FlaxWhisperPipline
-from utils.server_utils import run_in_executor
+from utils.run_utils import run_in_executor
-logger = logging.getLogger(__name__)
-transcription = ""
 pcs = set()
 relay = MediaRelay()
 data_channel = None
-total_bytes_handled = 0
 pipeline = FlaxWhisperPipline("openai/whisper-tiny",
                               dtype=jnp.float16,
                               batch_size=16)
 CHANNELS = 2
 RATE = 48000
 audio_buffer = AudioFifo()
-start_time = datetime.datetime.now()
 executor = ThreadPoolExecutor()
@@ -40,30 +34,12 @@ def channel_log(channel, t, message):
 def channel_send(channel, message):
     # channel_log(channel, ">", message)
-    global start_time
     if channel:
         channel.send(message)
-        print(
-            "Bytes handled :",
-            total_bytes_handled,
-            " Time : ",
-            datetime.datetime.now() - start_time,
-        )
 def get_transcription(frames):
     print("Transcribing..")
-    # samples = np.ndarray(
-    #     np.concatenate([f.to_ndarray() for f in frames], axis=None),
-    #     dtype=np.float32,
-    # )
-    # whisper_result = pipeline(
-    #     {
-    #         "array": samples,
-    #         "sampling_rate": 48000,
-    #     },
-    #     return_timestamps=True,
-    # )
     out_file = io.BytesIO()
     wf = wave.open(out_file, "wb")
     wf.setnchannels(CHANNELS)
@@ -73,8 +49,6 @@ def get_transcription(frames):
     for frame in frames:
         wf.writeframes(b"".join(frame.to_ndarray()))
     wf.close()
-    global total_bytes_handled
-    total_bytes_handled += sys.getsizeof(wf)
     whisper_result = pipeline(out_file.getvalue(), return_timestamps=True)
     with open("test_exec.txt", "a") as f:
         f.write(whisper_result["text"])
@@ -98,19 +72,19 @@ class AudioStreamTrack(MediaStreamTrack):
         audio_buffer.write(frame)
         if local_frames := audio_buffer.read_many(256 * 960, partial=False):
             whisper_result = run_in_executor(
                 get_transcription, local_frames, executor=executor
             )
             whisper_result.add_done_callback(
                 lambda f: channel_send(data_channel,
                                        str(whisper_result.result()))
                 if (f.result())
                 else None
             )
         return frame
 async def offer(request):
     params = await request.json()
-    print("Request received")
     offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
     pc = RTCPeerConnection()
@@ -120,55 +94,47 @@ async def offer(request):
     def log_info(msg, *args):
         logger.info(pc_id + " " + msg, *args)
-    log_info("Created for %s", request.remote)
+    log_info("Created for " + request.remote)
     @pc.on("datachannel")
     def on_datachannel(channel):
-        global data_channel, start_time
+        global data_channel
         data_channel = channel
         channel_log(channel, "-", "created by remote party")
-        start_time = datetime.datetime.now()
         @channel.on("message")
         def on_message(message):
             channel_log(channel, "<", message)
             if isinstance(message, str) and message.startswith("ping"):
-                # reply
                 channel_send(channel, "pong" + message[4:])
     @pc.on("connectionstatechange")
     async def on_connectionstatechange():
-        log_info("Connection state is %s", pc.connectionState)
+        log_info("Connection state is " + pc.connectionState)
         if pc.connectionState == "failed":
             await pc.close()
             pcs.discard(pc)
     @pc.on("track")
     def on_track(track):
-        print("Track %s received" % track.kind)
-        log_info("Track %s received", track.kind)
-        # Trials to listen to the correct track
+        log_info("Track " + track.kind + " received")
         pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
-        # pc.addTrack(AudioStreamTrack(track))
-    # handle offer
     await pc.setRemoteDescription(offer)
-    # send answer
     answer = await pc.createAnswer()
     await pc.setLocalDescription(answer)
-    print("Response sent")
     return web.Response(
         content_type="application/json",
         text=json.dumps(
             {"sdp": pc.localDescription.sdp,
              "type": pc.localDescription.type}
         ),
     )
 async def on_shutdown(app):
-    # close peer connections
     coros = [pc.close() for pc in pcs]
     await asyncio.gather(*coros)
     pcs.clear()


@@ -1,54 +1,38 @@
 import asyncio
-import configparser
 import datetime
 import io
 import json
-import logging
-import os
 import threading
 import uuid
 import wave
 from concurrent.futures import ThreadPoolExecutor
 import jax.numpy as jnp
-from aiohttp import webq
+import requests
+from aiohttp import web
 from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
-from aiortc.contrib.media import (MediaRelay)
+from aiortc.contrib.media import MediaRelay
 from av import AudioFifo
 from sortedcontainers import SortedDict
 from whisper_jax import FlaxWhisperPipline
-from utils.server_utils import Mutex
+from utils.log_utils import logger
+from utils.run_utils import config, Mutex
-ROOT = os.path.dirname(__file__)
-config = configparser.ConfigParser()
-config.read('config.ini')
-WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
-logger = logging.getLogger("pc")
+WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_REAL_TIME_MODEL_SIZE"]
 pcs = set()
 relay = MediaRelay()
 data_channel = None
 sorted_message_queue = SortedDict()
 CHANNELS = 2
 RATE = 44100
 CHUNK_SIZE = 256
-audio_buffer = AudioFifo()
 pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE,
                               dtype=jnp.float16,
                               batch_size=16)
-transcription = ""
 start_time = datetime.datetime.now()
-total_bytes_handled = 0
 executor = ThreadPoolExecutor()
+audio_buffer = AudioFifo()
 frame_lock = Mutex(audio_buffer)
@@ -81,6 +65,7 @@ def get_transcription():
         transcribe = True
     if transcribe:
+        print("Transcribing..")
         try:
             sorted_message_queue[frames[0].time] = None
             out_file = io.BytesIO()
@@ -94,10 +79,11 @@ def get_transcription():
             wf.close()
             whisper_result = pipeline(out_file.getvalue())
             item = {
                 'text': whisper_result["text"],
                 'start_time': str(frames[0].time),
                 'time': str(datetime.datetime.now())
             }
             sorted_message_queue[frames[0].time] = str(item)
             start_messaging_thread()
         except Exception as e:
@@ -106,7 +92,7 @@ def get_transcription():
 class AudioStreamTrack(MediaStreamTrack):
     """
-    A video stream track that transforms frames from an another track.
+    An audio stream track to send audio frames.
     """
     kind = "audio"
@@ -126,15 +112,13 @@ def start_messaging_thread():
     message_thread.start()
-def start_transcription_thread(max_threads):
-    t_threads = []
+def start_transcription_thread(max_threads: int):
     for i in range(max_threads):
-        t_thread = threading.Thread(target=get_transcription, args=(i,))
-        t_threads.append(t_thread)
+        t_thread = threading.Thread(target=get_transcription)
         t_thread.start()
-async def offer(request):
+async def offer(request: requests.Request):
     params = await request.json()
     offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
@@ -142,10 +126,10 @@ async def offer(request):
pc_id = "PeerConnection(%s)" % uuid.uuid4() pc_id = "PeerConnection(%s)" % uuid.uuid4()
pcs.add(pc) pcs.add(pc)
def log_info(msg, *args): def log_info(msg: str, *args):
logger.info(pc_id + " " + msg, *args) logger.info(pc_id + " " + msg, *args)
log_info("Created for %s", request.remote) log_info("Created for " + request.remote)
@pc.on("datachannel") @pc.on("datachannel")
def on_datachannel(channel): def on_datachannel(channel):
@@ -155,7 +139,7 @@ async def offer(request):
         start_time = datetime.datetime.now()
         @channel.on("message")
-        def on_message(message):
+        def on_message(message: str):
             channel_log(channel, "<", message)
             if isinstance(message, str) and message.startswith("ping"):
                 # reply
@@ -163,14 +147,14 @@ async def offer(request):
@pc.on("connectionstatechange") @pc.on("connectionstatechange")
async def on_connectionstatechange(): async def on_connectionstatechange():
log_info("Connection state is %s", pc.connectionState) log_info("Connection state is " + pc.connectionState)
if pc.connectionState == "failed": if pc.connectionState == "failed":
await pc.close() await pc.close()
pcs.discard(pc) pcs.discard(pc)
@pc.on("track") @pc.on("track")
def on_track(track): def on_track(track):
log_info("Track %s received", track.kind) log_info("Track " + track.kind + " received")
pc.addTrack(AudioStreamTrack(relay.subscribe(track))) pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
# handle offer # handle offer
@@ -180,14 +164,15 @@ async def offer(request):
     answer = await pc.createAnswer()
     await pc.setLocalDescription(answer)
     return web.Response(
         content_type="application/json",
         text=json.dumps({
             "sdp": pc.localDescription.sdp,
             "type": pc.localDescription.type
         }),
     )
-async def on_shutdown(app):
+async def on_shutdown(app: web.Application):
     coros = [pc.close() for pc in pcs]
     await asyncio.gather(*coros)
     pcs.clear()
@@ -199,5 +184,5 @@ if __name__ == "__main__":
     start_transcription_thread(6)
     app.router.add_post("/offer", offer)
     web.run_app(
         app, access_log=None, host="127.0.0.1", port=1250
     )
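For orientation (not part of this commit), a minimal client-side sketch of exercising the ```/offer``` endpoint started above, using aiortc and aiohttp; the data-channel label is arbitrary and error handling is omitted:

```python
import asyncio
import json

import aiohttp
from aiortc import RTCPeerConnection


async def send_offer(url="http://127.0.0.1:1250/offer"):
    # Build a local offer; the data channel makes the SDP non-trivial.
    pc = RTCPeerConnection()
    pc.createDataChannel("reflector")  # label is an arbitrary choice
    await pc.setLocalDescription(await pc.createOffer())

    # POST the offer; the server replies with its answer SDP.
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json={
            "sdp": pc.localDescription.sdp,
            "type": pc.localDescription.type,
        }) as resp:
            answer = await resp.json()

    print(json.dumps(answer)[:120], "...")
    await pc.close()


asyncio.run(send_offer())
```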


@@ -1,7 +1,5 @@
 import ast
 import asyncio
-import configparser
-import logging
 import time
 import uuid
@@ -12,14 +10,11 @@ import stamina
 from aiortc import (RTCPeerConnection, RTCSessionDescription)
 from aiortc.contrib.media import (MediaPlayer, MediaRelay)
-from utils.server_utils import Mutex
+from utils.log_utils import logger
+from utils.run_utils import config, Mutex
-logger = logging.getLogger("pc")
 file_lock = Mutex(open("test_sm_6.txt", "a"))
-config = configparser.ConfigParser()
-config.read('config.ini')
 class StreamClient:
     def __init__(
@@ -42,14 +37,15 @@ class StreamClient:
         self.pcs = set()
         self.time_start = None
         self.queue = asyncio.Queue()
         self.player = MediaPlayer(
             ':' + str(config['DEFAULT']["AV_FOUNDATION_DEVICE_ID"]),
             format='avfoundation',
             options={'channels': '2'})
     def stop(self):
         self.loop.run_until_complete(self.signaling.close())
         self.loop.run_until_complete(self.pc.close())
         # self.loop.close()
-        print("ended")
     def create_local_tracks(self, play_from):
         if play_from:
@@ -58,7 +54,6 @@ class StreamClient:
         else:
             if self.relay is None:
                 self.relay = MediaRelay()
-            print("Created local track from microphone stream")
             return self.relay.subscribe(self.player.audio), None
     def channel_log(self, channel, t, message):
@@ -122,14 +117,15 @@ class StreamClient:
                 self.channel_log(channel, "<", message)
                 if isinstance(message, str) and message.startswith("pong"):
                     elapsed_ms = (self.current_stamp() - int(message[5:])) \
                         / 1000
                     print(" RTT %.2f ms" % elapsed_ms)
         await pc.setLocalDescription(await pc.createOffer())
         sdp = {
             "sdp": pc.localDescription.sdp,
             "type": pc.localDescription.type
         }
         @stamina.retry(on=httpx.HTTPError, attempts=5)
@@ -142,7 +138,7 @@ class StreamClient:
         answer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
         await pc.setRemoteDescription(answer)
-        self.reader = self.worker(f"worker", self.queue)
+        self.reader = self.worker(f'{"worker"}', self.queue)
     def get_reader(self):
         return self.reader

utils/__init__.py Normal file


@@ -1,13 +1,12 @@
-import configparser
+import sys
 import boto3
 import botocore
-from loguru import logger
-config = configparser.ConfigParser()
-config.read('config.ini')
-BUCKET_NAME = 'reflector-bucket'
+from .log_utils import logger
+from .run_utils import config
+BUCKET_NAME = config["DEFAULT"]["BUCKET_NAME"]
 s3 = boto3.client('s3',
                   aws_access_key_id=config["DEFAULT"]["AWS_ACCESS_KEY"],
@@ -17,8 +16,8 @@ s3 = boto3.client('s3',
 def upload_files(files_to_upload):
     """
     Upload a list of files to the configured S3 bucket
-    :param files_to_upload:
-    :return:
+    :param files_to_upload: List of files to upload
+    :return: None
     """
     for KEY in files_to_upload:
         logger.info("Uploading file " + KEY)
@@ -31,8 +30,8 @@ def upload_files(files_to_upload):
 def download_files(files_to_download):
     """
     Download a list of files from the configured S3 bucket
-    :param files_to_download:
-    :return:
+    :param files_to_download: List of files to download
+    :return: None
     """
     for KEY in files_to_download:
         logger.info("Downloading file " + KEY)
@@ -46,8 +45,6 @@ def download_files(files_to_download):
if __name__ == "__main__": if __name__ == "__main__":
import sys
if sys.argv[1] == "download": if sys.argv[1] == "download":
download_files([sys.argv[2]]) download_files([sys.argv[2]])
elif sys.argv[1] == "upload": elif sys.argv[1] == "upload":
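A hedged usage sketch of the refactored helpers (the module path ```utils/file_util.py``` is inferred from the relative imports above; file names and keys are placeholders):

```python
# Programmatic use from the repo root (paths/keys are illustrative):
from utils.file_util import download_files, upload_files

upload_files(["artefacts/summary_2023_07_11.txt"])   # push an artefact to the configured bucket
download_files(["transcript_2023_07_11.txt"])        # pull an object back next to the script

# CLI form driven by the __main__ block above (run as a module so the
# relative imports resolve):
#   python3 -m utils.file_util upload artefacts/summary_2023_07_11.txt
#   python3 -m utils.file_util download transcript_2023_07_11.txt
```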

utils/log_utils.py Normal file

@@ -0,0 +1,18 @@
import loguru


class SingletonLogger:
    __instance = None

    @staticmethod
    def get_logger():
        """
        Create or return the singleton instance for the SingletonLogger class
        :return: SingletonLogger instance
        """
        if not SingletonLogger.__instance:
            SingletonLogger.__instance = loguru.logger
        return SingletonLogger.__instance


logger = SingletonLogger.get_logger()
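A short usage sketch (illustrative, not in the commit): every module imports the same pre-built loguru instance, so formatting and sinks stay in one place:

```python
# any pipeline module (hypothetical) - grab the shared loguru instance
from utils.log_utils import logger

logger.info("Starting Whisper-JAX transcription")
logger.warning("CUDA not found, falling back to CPU")
```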

utils/run_utils.py Normal file

@@ -0,0 +1,66 @@
import asyncio
import configparser
import contextlib
from functools import partial
from threading import Lock
from typing import ContextManager, Generic, TypeVar


class ReflectorConfig:
    __config = None

    @staticmethod
    def get_config():
        if ReflectorConfig.__config is None:
            ReflectorConfig.__config = configparser.ConfigParser()
            ReflectorConfig.__config.read('utils/config.ini')
        return ReflectorConfig.__config


config = ReflectorConfig.get_config()


def run_in_executor(func, *args, executor=None, **kwargs):
    """
    Run the function in an executor, unblocking the main loop
    :param func: Function to be run in executor
    :param args: function parameters
    :param executor: executor instance [Thread | Process]
    :param kwargs: Additional parameters
    :return: Future of function result upon completion
    """
    callback = partial(func, *args, **kwargs)
    loop = asyncio.get_event_loop()
    return loop.run_in_executor(executor, callback)


# Generic type template
T = TypeVar("T")


class Mutex(Generic[T]):
    """
    Mutex class to implement lock/release of a shared
    protected variable
    """

    def __init__(self, value: T):
        """
        Create an instance of Mutex wrapper for the given resource
        :param value: Shared resources to be thread protected
        """
        self.__value = value
        self.__lock = Lock()

    @contextlib.contextmanager
    def lock(self) -> ContextManager[T]:
        """
        Lock the resource with a mutex to be used within a context block
        The lock is automatically released on context exit
        :return: Shared resource
        """
        self.__lock.acquire()
        try:
            yield self.__value
        finally:
            self.__lock.release()
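A brief usage sketch of the two helpers above (illustration only, not part of the commit): a Mutex-wrapped shared list plus run_in_executor to keep blocking work off the event loop:

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

from utils.run_utils import Mutex, run_in_executor

shared_words = Mutex([])          # thread-protected shared list
executor = ThreadPoolExecutor()


def slow_append(word):
    # Blocking work; access to the shared list goes through the mutex.
    with shared_words.lock() as words:
        words.append(word)
    return word


async def main():
    # Offload the blocking call so the event loop stays responsive.
    result = await run_in_executor(slow_append, "hello", executor=executor)
    print(result)


asyncio.run(main())
```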


@@ -1,28 +0,0 @@
import asyncio
import contextlib
from functools import partial
from threading import Lock
from typing import ContextManager, Generic, TypeVar


def run_in_executor(func, *args, executor=None, **kwargs):
    callback = partial(func, *args, **kwargs)
    loop = asyncio.get_event_loop()
    return asyncio.get_event_loop().run_in_executor(executor, callback)


T = TypeVar("T")


class Mutex(Generic[T]):
    def __init__(self, value: T):
        self.__value = value
        self.__lock = Lock()

    @contextlib.contextmanager
    def lock(self) -> ContextManager[T]:
        self.__lock.acquire()
        try:
            yield self.__value
        finally:
            self.__lock.release()


@@ -1,24 +1,22 @@
-import configparser
 import nltk
 import torch
-from loguru import logger
 from nltk.corpus import stopwords
 from nltk.tokenize import word_tokenize
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
-from transformers import BartTokenizer, BartForConditionalGeneration
+from transformers import BartForConditionalGeneration, BartTokenizer
+from utils.log_utils import logger
+from utils.run_utils import config
 nltk.download('punkt', quiet=True)
-config = configparser.ConfigParser()
-config.read('config.ini')
 def preprocess_sentence(sentence):
     stop_words = set(stopwords.words('english'))
     tokens = word_tokenize(sentence.lower())
     tokens = [token for token in tokens
               if token.isalnum() and token not in stop_words]
     return ' '.join(tokens)
@@ -52,12 +50,14 @@ def remove_almost_alike_sentences(sentences, threshold=0.7):
            sentence1 = preprocess_sentence(sentences[i])
            sentence2 = preprocess_sentence(sentences[j])
            if len(sentence1) != 0 and len(sentence2) != 0:
                similarity = compute_similarity(sentence1,
                                                sentence2)
                if similarity >= threshold:
                    removed_indices.add(max(i, j))
    filtered_sentences = [sentences[i] for i in range(num_sentences)
                          if i not in removed_indices]
    return filtered_sentences
@@ -77,11 +77,13 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
        words = nltk.word_tokenize(sent)
        n_gram_filter = 3
        for i in range(len(words)):
            if str(words[i:i + n_gram_filter]) in seen and \
                    seen[str(words[i:i + n_gram_filter])] == \
                    words[i + 1:i + n_gram_filter + 2]:
                pass
            else:
                seen[str(words[i:i + n_gram_filter])] = \
                    words[i + 1:i + n_gram_filter + 2]
                temp_result += words[i]
                temp_result += " "
        chunk_sentences.append(temp_result)
@@ -91,9 +93,12 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
 def post_process_transcription(whisper_result):
     transcript_text = ""
     for chunk in whisper_result["chunks"]:
         nonduplicate_sentences = \
             remove_outright_duplicate_sentences_from_chunk(chunk)
         chunk_sentences = \
             remove_whisper_repetitive_hallucination(nonduplicate_sentences)
         similarity_matched_sentences = \
             remove_almost_alike_sentences(chunk_sentences)
         chunk["text"] = " ".join(similarity_matched_sentences)
         transcript_text += chunk["text"]
     whisper_result["text"] = transcript_text
@@ -114,18 +119,23 @@ def summarize_chunks(chunks, tokenizer, model):
         input_ids = tokenizer.encode(c, return_tensors='pt')
         input_ids = input_ids.to(device)
         with torch.no_grad():
-            summary_ids = model.generate(input_ids,
-                                         num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
-                                         max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
-        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+            summary_ids = \
+                model.generate(input_ids,
+                               num_beams=int(config["DEFAULT"]["BEAM_SIZE"]),
+                               length_penalty=2.0,
+                               max_length=int(config["DEFAULT"]["MAX_LENGTH"]),
+                               early_stopping=True)
+        summary = tokenizer.decode(summary_ids[0],
+                                   skip_special_tokens=True)
         summaries.append(summary)
     return summaries


-def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
+def chunk_text(text,
+               max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
     """
     Split text into smaller chunks.
-    :param txt: Text to be chunked
+    :param text: Text to be chunked
     :param max_chunk_length: length of chunk
     :return: chunked texts
     """
@@ -143,7 +153,8 @@ def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])
 def summarize(transcript_text, timestamp,
-              real_time=False, summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
+              real_time=False,
+              summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
     if not summary_model:
@@ -160,9 +171,11 @@ def summarize(transcript_text, timestamp,
         output_filename = "real_time_" + output_filename
     if summarize_using_chunks != "YES":
-        inputs = tokenizer.batch_encode_plus([transcript_text], truncation=True, padding='longest',
-                                             max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]),
-                                             return_tensors='pt')
+        inputs = tokenizer.\
+            batch_encode_plus([transcript_text], truncation=True,
+                              padding='longest',
+                              max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]),
+                              return_tensors='pt')
         inputs = inputs.to(device)
         with torch.no_grad():
@@ -170,16 +183,17 @@ def summarize(transcript_text, timestamp,
                 num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
                 max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
-        decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) for
-                             summary in summaries]
+        decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+                             for summary in summaries]
         summary = " ".join(decoded_summaries)
-        with open(output_filename, 'w') as f:
+        with open("./artefacts/" + output_filename, 'w') as f:
             f.write(summary.strip() + "\n")
     else:
         logger.info("Breaking transcript into smaller chunks")
         chunks = chunk_text(transcript_text)
-        logger.info(f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable
+        logger.info(f"Transcript broken into {len(chunks)} "
+                    f"chunks of at most 500 words")
         logger.info(f"Writing summary text to: {output_filename}")
         with open(output_filename, 'w') as f:
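A usage sketch for `summarize` (assuming config.ini supplies SUMMARY_MODEL, BEAM_SIZE, MAX_LENGTH and the other keys it reads): the batch pipeline calls it once on the assembled transcript, and the `real_time` flag only changes the output file prefix.

```python
from datetime import datetime

# Single-pass summary of the whole transcript (no chunking).
summarize(transcript_text="... full transcript text ...",
          timestamp=datetime.now(),
          real_time=False,
          summarize_using_chunks="NO")
```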

View File

@@ -1,24 +1,20 @@
 import ast
 import collections
-import configparser
 import os
 import pickle
-from pathlib import Path
 
 import matplotlib.pyplot as plt
 import pandas as pd
 import scattertext as st
 import spacy
 from nltk.corpus import stopwords
-from wordcloud import WordCloud, STOPWORDS
+from wordcloud import STOPWORDS, WordCloud
 
-config = configparser.ConfigParser()
-config.read('config.ini')
-
 en = spacy.load('en_core_web_md')
 spacy_stopwords = en.Defaults.stop_words
-STOPWORDS = set(STOPWORDS).union(set(stopwords.words("english"))).union(set(spacy_stopwords))
+STOPWORDS = set(STOPWORDS).union(set(stopwords.words("english"))).\
+    union(set(spacy_stopwords))


 def create_wordcloud(timestamp, real_time=False):
@@ -28,7 +24,8 @@ def create_wordcloud(timestamp, real_time=False):
""" """
filename = "transcript" filename = "transcript"
if real_time: if real_time:
filename = "real_time_" + filename + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" filename = "real_time_" + filename + "_" +\
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
else: else:
filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt" filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
@@ -50,11 +47,12 @@ def create_wordcloud(timestamp, real_time=False):
     wordcloud_name = "wordcloud"
     if real_time:
-        wordcloud_name = "real_time_" + wordcloud_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
+        wordcloud_name = "real_time_" + wordcloud_name + "_" +\
+            timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
     else:
         wordcloud_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
-    plt.savefig(wordcloud_name)
+    plt.savefig("./artefacts/" + wordcloud_name)


 def create_talk_diff_scatter_viz(timestamp, real_time=False):
@@ -70,7 +68,6 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
     agenda_topics = []
     agenda = []
     # Load the agenda
-    path = Path(__file__)
     with open(os.path.join(os.getcwd(), "agenda-headers.txt"), "r") as f:
         for line in f.readlines():
             if line.strip():
@@ -80,9 +77,11 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
     # Load the transcription with timestamp
     filename = ""
     if real_time:
-        filename = "real_time_transcript_with_timestamp_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
+        filename = "./artefacts/real_time_transcript_with_timestamp_" +\
+            timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
     else:
-        filename = "transcript_with_timestamp_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
+        filename = "./artefacts/transcript_with_timestamp_" +\
+            timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
 
     with open(filename) as f:
         transcription_timestamp_text = f.read()
@@ -98,7 +97,8 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
     ts_to_topic_mapping_top_1 = {}
     ts_to_topic_mapping_top_2 = {}
-    # Also create a mapping of the different timestamps in which each topic was covered
+    # Also create a mapping of the different timestamps
+    # in which each topic was covered
     topic_to_ts_mapping_top_1 = collections.defaultdict(list)
     topic_to_ts_mapping_top_2 = collections.defaultdict(list)
@@ -109,7 +109,8 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
     topic_similarities = []
     for item in range(len(agenda)):
         item_doc = nlp(agenda[item])
-        # if not doc_transcription or not all(token.has_vector for token in doc_transcription):
+        # if not doc_transcription or not all
+        # (token.has_vector for token in doc_transcription):
         if not doc_transcription:
             continue
         similarity = doc_transcription.similarity(item_doc)
@@ -133,8 +134,10 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
         :param record:
         :return:
         """
-        record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[record["timestamp"]]
-        record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[record["timestamp"]]
+        record["ts_to_topic_mapping_top_1"] = \
+            ts_to_topic_mapping_top_1[record["timestamp"]]
+        record["ts_to_topic_mapping_top_2"] = \
+            ts_to_topic_mapping_top_2[record["timestamp"]]
         return record
 
     df = df.apply(create_new_columns, axis=1)
@@ -155,20 +158,22 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
     # Save df, mappings for further experimentation
     df_name = "df"
     if real_time:
-        df_name = "real_time_" + df_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
+        df_name = "real_time_" + df_name + "_" +\
+            timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
     else:
         df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
-    df.to_pickle(df_name)
+    df.to_pickle("./artefacts/" + df_name)
 
     my_mappings = [ts_to_topic_mapping_top_1, ts_to_topic_mapping_top_2,
                    topic_to_ts_mapping_top_1, topic_to_ts_mapping_top_2]
     mappings_name = "mappings"
     if real_time:
-        mappings_name = "real_time_" + mappings_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
+        mappings_name = "real_time_" + mappings_name + "_" +\
+            timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
     else:
         mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
-    pickle.dump(my_mappings, open(mappings_name, "wb"))
+    pickle.dump(my_mappings, open("./artefacts/" + mappings_name, "wb"))
     # to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") )
@@ -182,25 +187,28 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
     topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True)
 
-    cat_1 = topic_times[0][0]
-    cat_1_name = topic_times[0][0]
-    cat_2_name = topic_times[1][0]
+    if len(topic_times) > 1:
+        cat_1 = topic_times[0][0]
+        cat_1_name = topic_times[0][0]
+        cat_2_name = topic_times[1][0]
 
     # Scatter plot of topics
     df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences))
     corpus = st.CorpusFromParsedDocuments(
         df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse'
     ).build().get_unigram_corpus().compact(st.AssociationCompactor(2000))
 
     html = st.produce_scattertext_explorer(
         corpus,
         category=cat_1,
         category_name=cat_1_name,
         not_category_name=cat_2_name,
         minimum_term_frequency=0, pmi_threshold_coefficient=0,
         width_in_pixels=1000,
         transform=st.Scalers.dense_rank
     )
     if real_time:
-        open('./artefacts/real_time_scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
+        open('./artefacts/real_time_scatter_' +
+             timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
     else:
-        open('./artefacts/scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
+        open('./artefacts/scatter_' +
+             timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
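Both visualisation helpers key their output files off the same run timestamp, so a driver typically calls them back to back; a minimal sketch (the actual calls live in the entry-point scripts):

```python
from datetime import datetime

now = datetime.now()
create_wordcloud(now, real_time=False)              # ./artefacts/wordcloud_<ts>.png
create_talk_diff_scatter_viz(now, real_time=False)  # ./artefacts/scatter_<ts>.html
```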

View File

@@ -1,11 +1,10 @@
 #!/usr/bin/env python3
-# summarize https://www.youtube.com/watch?v=imzTxoEDH_g --transcript=transcript.txt summary.txt
+# summarize https://www.youtube.com/watch?v=imzTxoEDH_g
 # summarize https://www.sprocket.org/video/cheesemaking.mp4 summary.txt
 # summarize podcast.mp3 summary.txt
 
 import argparse
-import configparser
 import os
 import re
 import subprocess
@@ -15,23 +14,19 @@ from urllib.parse import urlparse
 import jax.numpy as jnp
 import moviepy.editor
-import moviepy.editor
 import nltk
 import yt_dlp as youtube_dl
-from loguru import logger
 from whisper_jax import FlaxWhisperPipline
 
-from utils.file_utilities import upload_files, download_files
-from utils.text_utilities import summarize, post_process_transcription
-from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
+from utils.file_utils import download_files, upload_files
+from utils.log_utils import logger
+from utils.run_utils import config
+from utils.text_utilities import post_process_transcription, summarize
+from utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud
 
 nltk.download('punkt', quiet=True)
 nltk.download('stopwords', quiet=True)
 
-# Configurations can be found in config.ini. Set them properly before executing
-config = configparser.ConfigParser()
-config.read('config.ini')
-
 WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
 NOW = datetime.now()
@@ -42,12 +37,17 @@ def init_argparse() -> argparse.ArgumentParser:
     :return: parser object
     """
     parser = argparse.ArgumentParser(
         usage="%(prog)s [OPTIONS] <LOCATION> <OUTPUT>",
-        description="Creates a transcript of a video or audio file, then summarizes it using ChatGPT."
+        description="Creates a transcript of a video or audio file, then"
+                    " summarizes it using ChatGPT."
     )
-    parser.add_argument("-l", "--language", help="Language that the summary should be written in", type=str,
-                        default="english", choices=['english', 'spanish', 'french', 'german', 'romanian'])
+    parser.add_argument("-l", "--language",
+                        help="Language that the summary should be written in",
+                        type=str,
+                        default="english",
+                        choices=['english', 'spanish', 'french', 'german',
+                                 'romanian'])
     parser.add_argument("location")
     return parser
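For reference, the parser can be exercised on its own; a hypothetical call with illustrative argument values:

```python
parser = init_argparse()
args = parser.parse_args(["--language", "english",
                          "https://www.youtube.com/watch?v=imzTxoEDH_g"])
print(args.language, args.location)
```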
@@ -65,22 +65,24 @@ def main():
     media_file = ""
     if url.scheme == 'http' or url.scheme == 'https':
-        # Check if we're being asked to retreive a YouTube URL, which is handled
-        # diffrently, as we'll use a secondary site to download the video first.
+        # Check if we're being asked to retrieve a YouTube URL, which is
+        # handled differently, as we'll use a secondary site to download
+        # the video first.
         if re.search('youtube.com', url.netloc, re.IGNORECASE):
-            # Download the lowest resolution YouTube video (since we're just interested in the audio).
+            # Download the lowest resolution YouTube video
+            # (since we're just interested in the audio).
             # It will be saved to the current directory.
             logger.info("Downloading YouTube video at url: " + args.location)
 
             # Create options for the download
             ydl_opts = {
                 'format': 'bestaudio/best',
                 'postprocessors': [{
                     'key': 'FFmpegExtractAudio',
                     'preferredcodec': 'mp3',
                     'preferredquality': '192',
                 }],
-                'outtmpl': 'audio',  # Specify the output file path and name
+                'outtmpl': 'audio',  # Specify output file path and name
             }
 
             # Download the audio
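The download call itself sits in unchanged context that this hunk does not show; with yt_dlp it would look roughly like this (a sketch reusing the `ydl_opts` dict built above; the resulting file name is an assumption based on those options):

```python
import yt_dlp as youtube_dl

# Download and extract the audio track using the options built above.
with youtube_dl.YoutubeDL(ydl_opts) as ydl:
    ydl.download([args.location])
media_file = "audio.mp3"  # assumed: 'outtmpl' of 'audio' plus the mp3 postprocessor
```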
@@ -90,7 +92,8 @@ def main():
logger.info("Saved downloaded YouTube video to: " + media_file) logger.info("Saved downloaded YouTube video to: " + media_file)
else: else:
# XXX - Download file using urllib, check if file is audio/video using python-magic # XXX - Download file using urllib, check if file is
# audio/video using python-magic
logger.info(f"Downloading file at url: {args.location}") logger.info(f"Downloading file at url: {args.location}")
logger.info(" XXX - This method hasn't been implemented yet.") logger.info(" XXX - This method hasn't been implemented yet.")
elif url.scheme == '': elif url.scheme == '':
@@ -101,7 +104,7 @@ def main():
         if media_file.endswith(".m4a"):
             subprocess.run(["ffmpeg", "-i", media_file, f"{media_file}.mp4"])
-            input_file = f"{media_file}.mp4"
+            media_file = f"{media_file}.mp4"
     else:
         print("Unsupported URL scheme: " + url.scheme)
         quit()
@@ -110,19 +113,21 @@ def main():
     if not media_file.endswith(".mp3"):
         try:
             video = moviepy.editor.VideoFileClip(media_file)
-            audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
+            audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3",
+                                                         delete=False).name
             video.audio.write_audiofile(audio_filename, logger=None)
             logger.info(f"Extracting audio to: {audio_filename}")
         # Handle audio only file
-        except:
+        except Exception:
             audio = moviepy.editor.AudioFileClip(media_file)
-            audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
+            audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3",
+                                                         delete=False).name
             audio.write_audiofile(audio_filename, logger=None)
     else:
         audio_filename = media_file
     logger.info("Finished extracting audio")
-    logger.info("Transcribing")
 
     # Convert the audio to text using the OpenAI Whisper model
     pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE,
                                   dtype=jnp.float16,
@@ -136,10 +141,12 @@ def main():
     for chunk in whisper_result["chunks"]:
         transcript_text += chunk["text"]
 
-    with open("./artefacts/transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file:
+    with open("./artefacts/transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") +
+              ".txt", "w") as transcript_file:
         transcript_file.write(transcript_text)
 
-    with open("./artefacts/transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt",
-              "w") as transcript_file_timestamps:
+    with open("./artefacts/transcript_with_timestamp_" +
+              NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt",
+              "w") as transcript_file_timestamps:
         transcript_file_timestamps.write(str(whisper_result))
@@ -150,13 +157,14 @@ def main():
     create_talk_diff_scatter_viz(NOW)
 
     # S3 : Push artefacts to S3 bucket
+    prefix = "./artefacts/"
     suffix = NOW.strftime("%m-%d-%Y_%H:%M:%S")
-    files_to_upload = ["transcript_" + suffix + ".txt",
-                       "transcript_with_timestamp_" + suffix + ".txt",
-                       "df_" + suffix + ".pkl",
-                       "wordcloud_" + suffix + ".png",
-                       "mappings_" + suffix + ".pkl",
-                       "scatter_" + suffix + ".html"]
+    files_to_upload = [prefix + "transcript_" + suffix + ".txt",
+                       prefix + "transcript_with_timestamp_" + suffix + ".txt",
+                       prefix + "df_" + suffix + ".pkl",
+                       prefix + "wordcloud_" + suffix + ".png",
+                       prefix + "mappings_" + suffix + ".pkl",
+                       prefix + "scatter_" + suffix + ".html"]
 
     upload_files(files_to_upload)
     summarize(transcript_text, NOW, False, False)
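`upload_files` comes from utils/file_utils.py, which is not part of this excerpt. A minimal sketch of what such a helper typically looks like with boto3, assuming credentials and the bucket name are read from config.ini (the exact key names here are assumptions):

```python
import configparser
import os

import boto3

config = configparser.ConfigParser()
config.read("config.ini")


def upload_files_sketch(paths):
    """Upload each local file to the configured S3 bucket, keyed by basename."""
    s3 = boto3.client(
        "s3",
        aws_access_key_id=config["DEFAULT"]["AWS_ACCESS_KEY"],
        aws_secret_access_key=config["DEFAULT"]["AWS_SECRET_KEY"],
    )
    for path in paths:
        s3.upload_file(path, config["DEFAULT"]["BUCKET_NAME"],
                       os.path.basename(path))
```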

View File

@@ -1,23 +1,20 @@
 #!/usr/bin/env python3
-import configparser
 import time
 import wave
 from datetime import datetime
 
 import jax.numpy as jnp
 import pyaudio
-from loguru import logger
 from pynput import keyboard
 from termcolor import colored
 from whisper_jax import FlaxWhisperPipline
 
-from utils.file_utilities import upload_files
-from utils.text_utilities import summarize, post_process_transcription
-from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
-
-config = configparser.ConfigParser()
-config.read('config.ini')
+from utils.file_utils import upload_files
+from utils.log_utils import logger
+from utils.run_utils import config
+from utils.text_utilities import post_process_transcription, summarize
+from utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud
 
 WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
@@ -33,19 +30,21 @@ def main():
     p = pyaudio.PyAudio()
     AUDIO_DEVICE_ID = -1
     for i in range(p.get_device_count()):
-        if p.get_device_info_by_index(i)["name"] == config["DEFAULT"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]:
+        if p.get_device_info_by_index(i)["name"] == \
+                config["DEFAULT"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]:
             AUDIO_DEVICE_ID = i
     audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID)
 
     stream = p.open(
         format=FORMAT,
         channels=CHANNELS,
         rate=RATE,
         input=True,
         frames_per_buffer=FRAMES_PER_BUFFER,
         input_device_index=int(audio_devices['index'])
     )
 
-    pipeline = FlaxWhisperPipline("openai/whisper-" + config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"],
+    pipeline = FlaxWhisperPipline("openai/whisper-" +
+                                  config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"],
                                   dtype=jnp.float16,
                                   batch_size=16)
@@ -72,7 +71,8 @@ def main():
             frames = []
             start_time = time.time()
             for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)):
-                data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False)
+                data = stream.read(FRAMES_PER_BUFFER,
+                                   exception_on_overflow=False)
                 frames.append(data)
             end_time = time.time()
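Between this hunk and the next, the unchanged code presumably writes the captured frames to a WAV file and hands it to the whisper-jax pipeline; a minimal sketch of that step (file name and sample-width handling are assumptions):

```python
import wave

# Persist the recorded window so the pipeline can read it.
with wave.open("chunk.wav", "wb") as wf:
    wf.setnchannels(CHANNELS)
    wf.setsampwidth(p.get_sample_size(FORMAT))
    wf.setframerate(RATE)
    wf.writeframes(b"".join(frames))

# Transcribe the recorded window with segment timestamps.
whisper_result = pipeline("chunk.wav", task="transcribe",
                          return_timestamps=True)
```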
@@ -90,7 +90,8 @@ def main():
             if end is None:
                 end = start + 15.0
             duration = end - start
-            item = {'timestamp': (last_transcribed_time, last_transcribed_time + duration),
+            item = {'timestamp': (last_transcribed_time,
+                                  last_transcribed_time + duration),
                     'text': whisper_result['text'],
                     'stats': (str(end_time - start_time), str(duration))
                     }
@@ -100,15 +101,19 @@ def main():
print(colored("<START>", "yellow")) print(colored("<START>", "yellow"))
print(colored(whisper_result['text'], 'green')) print(colored(whisper_result['text'], 'green'))
print(colored("<END> Recorded duration: " + str(end_time - start_time) + " | Transcribed duration: " + print(colored("<END> Recorded duration: " +
str(end_time - start_time) +
" | Transcribed duration: " +
str(duration), "yellow")) str(duration), "yellow"))
except Exception as e: except Exception as e:
print(e) print(e)
finally: finally:
with open("real_time_transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f: with open("real_time_transcript_" +
NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
f.write(transcription) f.write(transcription)
with open("real_time_transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f: with open("real_time_transcript_with_timestamp_" +
NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
transcript_with_timestamp["text"] = transcription transcript_with_timestamp["text"] = transcription
f.write(str(transcript_with_timestamp)) f.write(str(transcript_with_timestamp))