mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
Merge pull request #25 from Monadical-SAS/whisper-jax-gokul
Refactor codebase and clean-up code
This commit is contained in:
5
.gitignore
vendored
5
.gitignore
vendored
@@ -160,9 +160,6 @@ cython_debug/
|
|||||||
#.idea/
|
#.idea/
|
||||||
|
|
||||||
*.mp4
|
*.mp4
|
||||||
summary.txt
|
|
||||||
transcript.txt
|
|
||||||
transcript_timestamps.txt
|
|
||||||
*.html
|
*.html
|
||||||
*.pkl
|
*.pkl
|
||||||
transcript_*.txt
|
transcript_*.txt
|
||||||
@@ -174,4 +171,6 @@ test_samples/
|
|||||||
*.mp3
|
*.mp3
|
||||||
*.m4a
|
*.m4a
|
||||||
.DS_Store/
|
.DS_Store/
|
||||||
|
.DS_Store
|
||||||
.vscode/
|
.vscode/
|
||||||
|
artefacts/
|
||||||
|
|||||||
125
README.md
125
README.md
@@ -1,32 +1,34 @@
|
|||||||
# Reflector
|
# Reflector
|
||||||
|
|
||||||
This is the code base for the Reflector demo (formerly called agenda-talk-diff) for the leads : Troy Web Consulting panel (A Chat with AWS about AI: Real AI/ML AWS projects and what you should know) on 6/14 at 430PM.
|
This is the code base for the Reflector demo (formerly called agenda-talk-diff) for the leads : Troy Web Consulting
|
||||||
|
panel (A Chat with AWS about AI: Real AI/ML AWS projects and what you should know) on 6/14 at 430PM.
|
||||||
The target deliverable is a local-first live transcription and visualization tool to compare a discussion's target agenda/objectives to the actual discussion live.
|
|
||||||
|
|
||||||
|
The target deliverable is a local-first live transcription and visualization tool to compare a discussion's target
|
||||||
|
agenda/objectives to the actual discussion live.
|
||||||
|
|
||||||
**S3 bucket:**
|
**S3 bucket:**
|
||||||
|
|
||||||
Everything you need for S3 is already configured in config.ini. Only edit it if you need to change it deliberately.
|
Everything you need for S3 is already configured in config.ini. Only edit it if you need to change it deliberately.
|
||||||
|
|
||||||
S3 bucket name is mentioned in config.ini. All transfers will happen between this bucket and the local computer where the
|
S3 bucket name is mentioned in config.ini. All transfers will happen between this bucket and the local computer where
|
||||||
script is run. You need AWS_ACCESS_KEY / AWS_SECRET_KEY to authenticate your calls to S3 (done in config.ini).
|
the
|
||||||
|
script is run. You need AWS_ACCESS_KEY / AWS_SECRET_KEY to authenticate your calls to S3 (done in config.ini).
|
||||||
|
|
||||||
For AWS S3 Web UI,
|
For AWS S3 Web UI,
|
||||||
|
|
||||||
1) Login to AWS management console.
|
1) Login to AWS management console.
|
||||||
2) Search for S3 in the search bar at the top.
|
2) Search for S3 in the search bar at the top.
|
||||||
3) Navigate to list the buckets under the current account, if needed and choose your bucket [```reflector-bucket```]
|
3) Navigate to list the buckets under the current account, if needed and choose your bucket [```reflector-bucket```]
|
||||||
4) You should be able to see items in the bucket. You can upload/download files here directly.
|
4) You should be able to see items in the bucket. You can upload/download files here directly.
|
||||||
|
|
||||||
|
For CLI,
|
||||||
For CLI,
|
|
||||||
Refer to the FILE UTIL section below.
|
Refer to the FILE UTIL section below.
|
||||||
|
|
||||||
|
|
||||||
**FILE UTIL MODULE:**
|
**FILE UTIL MODULE:**
|
||||||
|
|
||||||
A file_util module has been created to upload/download files with AWS S3 bucket pre-configured using config.ini.
|
A file_util module has been created to upload/download files with AWS S3 bucket pre-configured using config.ini.
|
||||||
Though not needed for the workflow, if you need to upload / download file, separately on your own, apart from the pipeline workflow in the script, you can do so by :
|
Though not needed for the workflow, if you need to upload / download file, separately on your own, apart from the
|
||||||
|
pipeline workflow in the script, you can do so by :
|
||||||
|
|
||||||
Upload:
|
Upload:
|
||||||
|
|
||||||
@@ -39,37 +41,37 @@ Download:
|
|||||||
If you want to access the S3 artefacts, from another machine, you can either use the python file_util with the commands
|
If you want to access the S3 artefacts, from another machine, you can either use the python file_util with the commands
|
||||||
mentioned above or simply use the GUI of AWS Management Console.
|
mentioned above or simply use the GUI of AWS Management Console.
|
||||||
|
|
||||||
|
To setup,
|
||||||
To setup,
|
|
||||||
|
|
||||||
1) Check values in config.ini file. Specifically add your OPENAI_APIKEY if you plan to use OpenAI API requests.
|
1) Check values in config.ini file. Specifically add your OPENAI_APIKEY if you plan to use OpenAI API requests.
|
||||||
2) Run ``` export KMP_DUPLICATE_LIB_OK=True``` in Terminal. [This is taken care of in code, but not reflecting, Will fix this issue later.]
|
2) Run ``` export KMP_DUPLICATE_LIB_OK=True``` in
|
||||||
|
Terminal. [This is taken care of in code, but not reflecting, Will fix this issue later.]
|
||||||
|
|
||||||
NOTE: If you don't have portaudio installed already, run ```brew install portaudio```
|
NOTE: If you don't have portaudio installed already, run ```brew install portaudio```
|
||||||
|
|
||||||
3) Run the script setup_depedencies.sh.
|
3) Run the script setup_depedencies.sh.
|
||||||
|
|
||||||
``` chmod +x setup_dependencies.sh ```
|
``` chmod +x setup_dependencies.sh ```
|
||||||
|
|
||||||
``` sh setup_dependencies.sh <ENV>```
|
``` sh setup_dependencies.sh <ENV>```
|
||||||
|
|
||||||
|
ENV refers to the intended environment for JAX. JAX is available in several
|
||||||
ENV refers to the intended environment for JAX. JAX is available in several variants, [CPU | GPU | Colab TPU | Google Cloud TPU]
|
variants, [CPU | GPU | Colab TPU | Google Cloud TPU]
|
||||||
|
|
||||||
```ENV``` is :
|
|
||||||
|
|
||||||
cpu -> JAX CPU installation
|
|
||||||
|
|
||||||
cuda11 -> JAX CUDA 11.x version
|
```ENV``` is :
|
||||||
|
|
||||||
cuda12 -> JAX CUDA 12.x version (Core Weave has CUDA 12 version, can check with ```nvidia-smi```)
|
cpu -> JAX CPU installation
|
||||||
|
|
||||||
|
cuda11 -> JAX CUDA 11.x version
|
||||||
|
|
||||||
|
cuda12 -> JAX CUDA 12.x version (Core Weave has CUDA 12 version, can check with ```nvidia-smi```)
|
||||||
|
|
||||||
```sh setup_dependencies.sh cuda12```
|
```sh setup_dependencies.sh cuda12```
|
||||||
|
|
||||||
4) If not already done, install ffmpeg. ```brew install ffmpeg```
|
4) If not already done, install ffmpeg. ```brew install ffmpeg```
|
||||||
|
|
||||||
For NLTK SSL error, check [here](https://stackoverflow.com/questions/38916452/nltk-download-ssl-certificate-verify-failed)
|
For NLTK SSL error,
|
||||||
|
check [here](https://stackoverflow.com/questions/38916452/nltk-download-ssl-certificate-verify-failed)
|
||||||
|
|
||||||
5) Run the Whisper-JAX pipeline. Currently, the repo can take a Youtube video and transcribes/summarizes it.
|
5) Run the Whisper-JAX pipeline. Currently, the repo can take a Youtube video and transcribes/summarizes it.
|
||||||
|
|
||||||
@@ -79,83 +81,92 @@ You can even run it on local file or a file in your configured S3 bucket.
|
|||||||
|
|
||||||
``` python3 whisjax.py "startup.mp4"```
|
``` python3 whisjax.py "startup.mp4"```
|
||||||
|
|
||||||
The script will take care of a few cases like youtube file, local file, video file, audio-only file,
|
The script will take care of a few cases like youtube file, local file, video file, audio-only file,
|
||||||
file in S3, etc. If local file is not present, it can automatically take the file from S3.
|
file in S3, etc. If local file is not present, it can automatically take the file from S3.
|
||||||
|
|
||||||
**OFFLINE WORKFLOW:**
|
**OFFLINE WORKFLOW:**
|
||||||
|
|
||||||
1) Specify the input source file] from a local, youtube link or upload to S3 if needed and pass it as input to the script.If the source file is in
|
1) Specify the input source file] from a local, youtube link or upload to S3 if needed and pass it as input to the
|
||||||
|
script.If the source file is in
|
||||||
```.m4a``` format, it will get converted to ```.mp4``` automatically.
|
```.m4a``` format, it will get converted to ```.mp4``` automatically.
|
||||||
2) Keep the agenda header topics in a local file named ```agenda-headers.txt```. This needs to be present where the script is run.
|
2) Keep the agenda header topics in a local file named ```agenda-headers.txt```. This needs to be present where the
|
||||||
|
script is run.
|
||||||
This version of the pipeline compares covered agenda topics using agenda headers in the following format.
|
This version of the pipeline compares covered agenda topics using agenda headers in the following format.
|
||||||
1) ```agenda_topic : <short description>```
|
1) ```agenda_topic : <short description>```
|
||||||
3) Check all the values in ```config.ini```. You need to predefine 2 categories for which you need to scatter plot the
|
3) Check all the values in ```config.ini```. You need to predefine 2 categories for which you need to scatter plot the
|
||||||
topic modelling visualization in the config file. This is the default visualization. But, from the dataframe artefact called
|
topic modelling visualization in the config file. This is the default visualization. But, from the dataframe artefact
|
||||||
```df_<timestamp>.pkl``` , you can load the df and choose different topics to plot. You can filter using certain words to search for the
|
called
|
||||||
|
```df_<timestamp>.pkl``` , you can load the df and choose different topics to plot. You can filter using certain
|
||||||
|
words to search for the
|
||||||
transcriptions and you can see the top influencers and characteristic in each topic we have chosen to plot in the
|
transcriptions and you can see the top influencers and characteristic in each topic we have chosen to plot in the
|
||||||
interactive HTML document. I have added a new jupyter notebook that gives the base template to play around with, named
|
interactive HTML document. I have added a new jupyter notebook that gives the base template to play around with,
|
||||||
|
named
|
||||||
```Viz_experiments.ipynb```.
|
```Viz_experiments.ipynb```.
|
||||||
4) Run the script. The script automatically transcribes, summarizes and creates a scatter plot of words & topics in the form of an interactive
|
4) Run the script. The script automatically transcribes, summarizes and creates a scatter plot of words & topics in the
|
||||||
HTML file, a sample word cloud and uploads them to the S3 bucket
|
form of an interactive
|
||||||
|
HTML file, a sample word cloud and uploads them to the S3 bucket
|
||||||
5) Additional artefacts pushed to S3:
|
5) Additional artefacts pushed to S3:
|
||||||
1) HTML visualization file
|
1) HTML visualization file
|
||||||
2) pandas df in pickle format for others to collaborate and make their own visualizations
|
2) pandas df in pickle format for others to collaborate and make their own visualizations
|
||||||
3) Summary, transcript and transcript with timestamps file in text format.
|
3) Summary, transcript and transcript with timestamps file in text format.
|
||||||
|
|
||||||
The script also creates 2 types of mappings.
|
The script also creates 2 types of mappings.
|
||||||
1) Timestamp -> The top 2 matched agenda topic
|
1) Timestamp -> The top 2 matched agenda topic
|
||||||
2) Topic -> All matched timestamps in the transcription
|
2) Topic -> All matched timestamps in the transcription
|
||||||
|
|
||||||
Other visualizations can be planned based on available artefacts or new ones can be created. Refer the section ```Viz-experiments```.
|
|
||||||
|
|
||||||
|
Other visualizations can be planned based on available artefacts or new ones can be created. Refer the
|
||||||
|
section ```Viz-experiments```.
|
||||||
|
|
||||||
**Visualization experiments:**
|
**Visualization experiments:**
|
||||||
|
|
||||||
This is a jupyter notebook playground with template instructions on handling the metadata and data artefacts generated from the
|
This is a jupyter notebook playground with template instructions on handling the metadata and data artefacts generated
|
||||||
pipeline. Follow the instructions given and tweak your own logic into it or use it as a playground to experiment libraries and
|
from the
|
||||||
|
pipeline. Follow the instructions given and tweak your own logic into it or use it as a playground to experiment
|
||||||
|
libraries and
|
||||||
visualizations on top of the metadata.
|
visualizations on top of the metadata.
|
||||||
|
|
||||||
**WHISPER-JAX REALTIME TRANSCRIPTION PIPELINE:**
|
**WHISPER-JAX REALTIME TRANSCRIPTION PIPELINE:**
|
||||||
|
|
||||||
We also support a provision to perform real-time transcripton using whisper-jax pipeline. But, there are
|
We also support a provision to perform real-time transcripton using whisper-jax pipeline. But, there are
|
||||||
a few pre-requisites before you run it on your local machine. The instructions are for
|
a few pre-requisites before you run it on your local machine. The instructions are for
|
||||||
configuring on a MacOS.
|
configuring on a MacOS.
|
||||||
|
|
||||||
We need to way to route audio from an application opened via the browser, ex. "Whereby" and audio from your local
|
We need to way to route audio from an application opened via the browser, ex. "Whereby" and audio from your local
|
||||||
microphone input which you will be using for speaking. We use [Blackhole](https://github.com/ExistentialAudio/BlackHole).
|
microphone input which you will be using for speaking. We
|
||||||
|
use [Blackhole](https://github.com/ExistentialAudio/BlackHole).
|
||||||
|
|
||||||
1) Install Blackhole-2ch (2 ch is enough) by 1 of 2 options listed.
|
1) Install Blackhole-2ch (2 ch is enough) by 1 of 2 options listed.
|
||||||
2) Setup [Aggregate device](https://github.com/ExistentialAudio/BlackHole/wiki/Aggregate-Device) to route web audio and
|
2) Setup [Aggregate device](https://github.com/ExistentialAudio/BlackHole/wiki/Aggregate-Device) to route web audio and
|
||||||
local microphone input.
|
local microphone input.
|
||||||
|
|
||||||
Be sure to mirror the settings given 
|
Be sure to mirror the settings given 
|
||||||
3) Setup [Multi-Output device](https://github.com/ExistentialAudio/BlackHole/wiki/Multi-Output-Device)
|
3) Setup [Multi-Output device](https://github.com/ExistentialAudio/BlackHole/wiki/Multi-Output-Device)
|
||||||
|
|
||||||
Refer 
|
Refer 
|
||||||
|
|
||||||
4) Set the aggregator input device name created in step 2 in config.ini as ```BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME```
|
4) Set the aggregator input device name created in step 2 in config.ini as ```BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME```
|
||||||
|
|
||||||
5) Then goto ``` System Preferences -> Sound ``` and choose the devices created from the Output and
|
5) Then goto ``` System Preferences -> Sound ``` and choose the devices created from the Output and
|
||||||
Input tabs.
|
Input tabs.
|
||||||
|
|
||||||
6) The input from your local microphone, the browser run meeting should be aggregated into one virtual stream to listen to
|
6) The input from your local microphone, the browser run meeting should be aggregated into one virtual stream to listen
|
||||||
and the output should be fed back to your specified output devices if everything is configured properly. Check this
|
to
|
||||||
before trying out the trial.
|
and the output should be fed back to your specified output devices if everything is configured properly. Check this
|
||||||
|
before trying out the trial.
|
||||||
|
|
||||||
**Permissions:**
|
**Permissions:**
|
||||||
|
|
||||||
You may have to add permission for "Terminal"/Code Editors [Pycharm/VSCode, etc.] microphone access to record audio in
|
You may have to add permission for "Terminal"/Code Editors [Pycharm/VSCode, etc.] microphone access to record audio in
|
||||||
```System Preferences -> Privacy & Security -> Microphone```,
|
```System Preferences -> Privacy & Security -> Microphone```,
|
||||||
```System Preferences -> Privacy & Security -> Accessibility```,
|
```System Preferences -> Privacy & Security -> Accessibility```,
|
||||||
```System Preferences -> Privacy & Security -> Input Monitoring```.
|
```System Preferences -> Privacy & Security -> Input Monitoring```.
|
||||||
|
|
||||||
From the reflector root folder,
|
From the reflector root folder,
|
||||||
|
|
||||||
run ```python3 whisjax_realtime.py```
|
run ```python3 whisjax_realtime.py```
|
||||||
|
|
||||||
The transcription text should be written to ```real_time_transcription_<timestamp>.txt```.
|
The transcription text should be written to ```real_time_transcription_<timestamp>.txt```.
|
||||||
|
|
||||||
|
|
||||||
NEXT STEPS:
|
NEXT STEPS:
|
||||||
|
|
||||||
1) Create a RunPod setup for this feature (mentioned in 1 & 2) and test it end-to-end
|
1) Create a RunPod setup for this feature (mentioned in 1 & 2) and test it end-to-end
|
||||||
|
|||||||
43
client.py
43
client.py
@@ -1,35 +1,33 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
|
||||||
import signal
|
import signal
|
||||||
|
|
||||||
from aiortc.contrib.signaling import (add_signaling_arguments,
|
from aiortc.contrib.signaling import (add_signaling_arguments,
|
||||||
create_signaling)
|
create_signaling)
|
||||||
|
|
||||||
from stream_client import StreamClient
|
from stream_client import StreamClient
|
||||||
|
from utils.log_utils import logger
|
||||||
logger = logging.getLogger("pc")
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
parser = argparse.ArgumentParser(description="Data channels ping/pong")
|
parser = argparse.ArgumentParser(description="Data channels ping/pong")
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--url", type=str, nargs="?", default="http://127.0.0.1:1250/offer"
|
"--url", type=str, nargs="?", default="http://127.0.0.1:1250/offer"
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ping-pong",
|
"--ping-pong",
|
||||||
help="Benchmark data channel with ping pong",
|
help="Benchmark data channel with ping pong",
|
||||||
type=eval,
|
type=eval,
|
||||||
choices=[True, False],
|
choices=[True, False],
|
||||||
default="False",
|
default="False",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--play-from",
|
"--play-from",
|
||||||
type=str,
|
type=str,
|
||||||
default="",
|
default="",
|
||||||
)
|
)
|
||||||
add_signaling_arguments(parser)
|
add_signaling_arguments(parser)
|
||||||
|
|
||||||
@@ -39,34 +37,33 @@ async def main():
|
|||||||
|
|
||||||
async def shutdown(signal, loop):
|
async def shutdown(signal, loop):
|
||||||
"""Cleanup tasks tied to the service's shutdown."""
|
"""Cleanup tasks tied to the service's shutdown."""
|
||||||
logging.info(f"Received exit signal {signal.name}...")
|
logger.info(f"Received exit signal {signal.name}...")
|
||||||
logging.info("Closing database connections")
|
logger.info("Closing database connections")
|
||||||
logging.info("Nacking outstanding messages")
|
logger.info("Nacking outstanding messages")
|
||||||
tasks = [t for t in asyncio.all_tasks() if t is not
|
tasks = [t for t in asyncio.all_tasks() if t is not
|
||||||
asyncio.current_task()]
|
asyncio.current_task()]
|
||||||
|
|
||||||
[task.cancel() for task in tasks]
|
[task.cancel() for task in tasks]
|
||||||
|
|
||||||
logging.info(f"Cancelling {len(tasks)} outstanding tasks")
|
logger.info(f"Cancelling {len(tasks)} outstanding tasks")
|
||||||
await asyncio.gather(*tasks, return_exceptions=True)
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
logging.info(f"Flushing metrics")
|
logger.info(f'{"Flushing metrics"}')
|
||||||
loop.stop()
|
loop.stop()
|
||||||
|
|
||||||
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
|
signals = (signal.SIGHUP, signal.SIGTERM, signal.SIGINT)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
for s in signals:
|
for s in signals:
|
||||||
loop.add_signal_handler(
|
loop.add_signal_handler(
|
||||||
s, lambda s=s: asyncio.create_task(shutdown(s, loop)))
|
s, lambda s=s: asyncio.create_task(shutdown(s, loop)))
|
||||||
|
|
||||||
# Init client
|
# Init client
|
||||||
sc = StreamClient(
|
sc = StreamClient(
|
||||||
signaling=signaling,
|
signaling=signaling,
|
||||||
url=args.url,
|
url=args.url,
|
||||||
play_from=args.play_from,
|
play_from=args.play_from,
|
||||||
ping_pong=args.ping_pong
|
ping_pong=args.ping_pong
|
||||||
)
|
)
|
||||||
await sc.start()
|
await sc.start()
|
||||||
print("Stream client started")
|
|
||||||
async for msg in sc.get_reader():
|
async for msg in sc.get_reader():
|
||||||
print(msg)
|
print(msg)
|
||||||
|
|
||||||
|
|||||||
22
config.ini
22
config.ini
@@ -1,22 +0,0 @@
|
|||||||
[DEFAULT]
|
|
||||||
# Set exception rule for OpenMP error to allow duplicate lib initialization
|
|
||||||
KMP_DUPLICATE_LIB_OK=TRUE
|
|
||||||
# Export OpenAI API Key
|
|
||||||
OPENAI_APIKEY=
|
|
||||||
# Export Whisper Model Size
|
|
||||||
WHISPER_MODEL_SIZE=tiny
|
|
||||||
WHISPER_REAL_TIME_MODEL_SIZE=tiny
|
|
||||||
# AWS config
|
|
||||||
AWS_ACCESS_KEY=***REMOVED***
|
|
||||||
AWS_SECRET_KEY=***REMOVED***
|
|
||||||
BUCKET_NAME='reflector-bucket'
|
|
||||||
# Summarizer config
|
|
||||||
SUMMARY_MODEL=facebook/bart-large-cnn
|
|
||||||
INPUT_ENCODING_MAX_LENGTH=1024
|
|
||||||
MAX_LENGTH=2048
|
|
||||||
BEAM_SIZE=6
|
|
||||||
MAX_CHUNK_LENGTH=1024
|
|
||||||
SUMMARIZE_USING_CHUNKS=YES
|
|
||||||
# Audio device
|
|
||||||
BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME=aggregator
|
|
||||||
AV_FOUNDATION_DEVICE_ID=2
|
|
||||||
@@ -1,10 +1,11 @@
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
# Get the input file name from the command line argument
|
# Get the input file name from the command line argument
|
||||||
input_file = sys.argv[1]
|
input_file = sys.argv[1]
|
||||||
# example use: python 0-reflector-local.py input.m4a agenda.txt
|
# example use: python 0-reflector-local.py input.m4a agenda.txt
|
||||||
|
|
||||||
# Get the agenda file name from the command line argument if provided
|
# Get the agenda file name from the command line argument if provided
|
||||||
@@ -21,7 +22,7 @@ if not os.path.exists(agenda_file):
|
|||||||
# Check if the input file is .m4a, if so convert to .mp4
|
# Check if the input file is .m4a, if so convert to .mp4
|
||||||
if input_file.endswith(".m4a"):
|
if input_file.endswith(".m4a"):
|
||||||
subprocess.run(["ffmpeg", "-i", input_file, f"{input_file}.mp4"])
|
subprocess.run(["ffmpeg", "-i", input_file, f"{input_file}.mp4"])
|
||||||
input_file = f"{input_file}.mp4"
|
input_file = f"{input_file}.mp4"
|
||||||
|
|
||||||
# Run the first script to generate the transcript
|
# Run the first script to generate the transcript
|
||||||
subprocess.run(["python3", "1-transcript-generator.py", input_file, f"{input_file}_transcript.txt"])
|
subprocess.run(["python3", "1-transcript-generator.py", input_file, f"{input_file}_transcript.txt"])
|
||||||
@@ -30,4 +31,4 @@ subprocess.run(["python3", "1-transcript-generator.py", input_file, f"{input_fil
|
|||||||
subprocess.run(["python3", "2-agenda-transcript-diff.py", agenda_file, f"{input_file}_transcript.txt"])
|
subprocess.run(["python3", "2-agenda-transcript-diff.py", agenda_file, f"{input_file}_transcript.txt"])
|
||||||
|
|
||||||
# Run the third script to summarize the transcript
|
# Run the third script to summarize the transcript
|
||||||
subprocess.run(["python3", "3-transcript-summarizer.py", f"{input_file}_transcript.txt", f"{input_file}_summary.txt"])
|
subprocess.run(["python3", "3-transcript-summarizer.py", f"{input_file}_transcript.txt", f"{input_file}_summary.txt"])
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import moviepy.editor
|
import moviepy.editor
|
||||||
from loguru import logger
|
|
||||||
import whisper
|
import whisper
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
WHISPER_MODEL_SIZE = "base"
|
WHISPER_MODEL_SIZE = "base"
|
||||||
|
|
||||||
|
|
||||||
def init_argparse() -> argparse.ArgumentParser:
|
def init_argparse() -> argparse.ArgumentParser:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
usage="%(prog)s <LOCATION> <OUTPUT>",
|
usage="%(prog)s <LOCATION> <OUTPUT>",
|
||||||
@@ -15,6 +17,7 @@ def init_argparse() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument("output", help="Output file path")
|
parser.add_argument("output", help="Output file path")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
import sys
|
import sys
|
||||||
sys.setrecursionlimit(10000)
|
sys.setrecursionlimit(10000)
|
||||||
@@ -26,10 +29,11 @@ def main():
|
|||||||
logger.info(f"Processing file: {media_file}")
|
logger.info(f"Processing file: {media_file}")
|
||||||
|
|
||||||
# Check if the media file is a valid audio or video file
|
# Check if the media file is a valid audio or video file
|
||||||
if os.path.isfile(media_file) and not media_file.endswith(('.mp3', '.wav', '.ogg', '.flac', '.mp4', '.avi', '.flv')):
|
if os.path.isfile(media_file) and not media_file.endswith(
|
||||||
|
('.mp3', '.wav', '.ogg', '.flac', '.mp4', '.avi', '.flv')):
|
||||||
logger.error(f"Invalid file format: {media_file}")
|
logger.error(f"Invalid file format: {media_file}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# If the media file we just retrieved is an audio file then skip extraction step
|
# If the media file we just retrieved is an audio file then skip extraction step
|
||||||
audio_filename = media_file
|
audio_filename = media_file
|
||||||
logger.info(f"Found audio-only file, skipping audio extraction")
|
logger.info(f"Found audio-only file, skipping audio extraction")
|
||||||
@@ -53,5 +57,6 @@ def main():
|
|||||||
transcript_file.write(whisper_result["text"])
|
transcript_file.write(whisper_result["text"])
|
||||||
transcript_file.close()
|
transcript_file.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import spacy
|
import spacy
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
# Define the paths for agenda and transcription files
|
# Define the paths for agenda and transcription files
|
||||||
def init_argparse() -> argparse.ArgumentParser:
|
def init_argparse() -> argparse.ArgumentParser:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@@ -11,6 +13,8 @@ def init_argparse() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument("agenda", help="Location of the agenda file")
|
parser.add_argument("agenda", help="Location of the agenda file")
|
||||||
parser.add_argument("transcription", help="Location of the transcription file")
|
parser.add_argument("transcription", help="Location of the transcription file")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
args = init_argparse().parse_args()
|
args = init_argparse().parse_args()
|
||||||
agenda_path = args.agenda
|
agenda_path = args.agenda
|
||||||
transcription_path = args.transcription
|
transcription_path = args.transcription
|
||||||
@@ -19,7 +23,7 @@ transcription_path = args.transcription
|
|||||||
spaCy_model = "en_core_web_md"
|
spaCy_model = "en_core_web_md"
|
||||||
nlp = spacy.load(spaCy_model)
|
nlp = spacy.load(spaCy_model)
|
||||||
nlp.add_pipe('sentencizer')
|
nlp.add_pipe('sentencizer')
|
||||||
logger.info("Loaded spaCy model " + spaCy_model )
|
logger.info("Loaded spaCy model " + spaCy_model)
|
||||||
|
|
||||||
# Load the agenda
|
# Load the agenda
|
||||||
with open(agenda_path, "r") as f:
|
with open(agenda_path, "r") as f:
|
||||||
|
|||||||
@@ -1,11 +1,14 @@
|
|||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import nltk
|
import nltk
|
||||||
|
|
||||||
nltk.download('stopwords')
|
nltk.download('stopwords')
|
||||||
from nltk.corpus import stopwords
|
from nltk.corpus import stopwords
|
||||||
from nltk.tokenize import word_tokenize, sent_tokenize
|
from nltk.tokenize import word_tokenize, sent_tokenize
|
||||||
from heapq import nlargest
|
from heapq import nlargest
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
|
||||||
# Function to initialize the argument parser
|
# Function to initialize the argument parser
|
||||||
def init_argparse():
|
def init_argparse():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@@ -17,12 +20,14 @@ def init_argparse():
|
|||||||
parser.add_argument("--num_sentences", type=int, default=5, help="Number of sentences to include in the summary")
|
parser.add_argument("--num_sentences", type=int, default=5, help="Number of sentences to include in the summary")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
# Function to read the input transcript file
|
# Function to read the input transcript file
|
||||||
def read_transcript(file_path):
|
def read_transcript(file_path):
|
||||||
with open(file_path, "r") as file:
|
with open(file_path, "r") as file:
|
||||||
transcript = file.read()
|
transcript = file.read()
|
||||||
return transcript
|
return transcript
|
||||||
|
|
||||||
|
|
||||||
# Function to preprocess the text by removing stop words and special characters
|
# Function to preprocess the text by removing stop words and special characters
|
||||||
def preprocess_text(text):
|
def preprocess_text(text):
|
||||||
stop_words = set(stopwords.words('english'))
|
stop_words = set(stopwords.words('english'))
|
||||||
@@ -30,6 +35,7 @@ def preprocess_text(text):
|
|||||||
words = [w.lower() for w in words if w.isalpha() and w.lower() not in stop_words]
|
words = [w.lower() for w in words if w.isalpha() and w.lower() not in stop_words]
|
||||||
return words
|
return words
|
||||||
|
|
||||||
|
|
||||||
# Function to score each sentence based on the frequency of its words and return the top sentences
|
# Function to score each sentence based on the frequency of its words and return the top sentences
|
||||||
def summarize_text(text, num_sentences):
|
def summarize_text(text, num_sentences):
|
||||||
# Tokenize the text into sentences
|
# Tokenize the text into sentences
|
||||||
@@ -61,6 +67,7 @@ def summarize_text(text, num_sentences):
|
|||||||
|
|
||||||
return " ".join(summary)
|
return " ".join(summary)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# Initialize the argument parser and parse the arguments
|
# Initialize the argument parser and parse the arguments
|
||||||
parser = init_argparse()
|
parser = init_argparse()
|
||||||
@@ -82,5 +89,6 @@ def main():
|
|||||||
|
|
||||||
logger.info("Summarization completed")
|
logger.info("Summarization completed")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -1,15 +1,18 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
|
|
||||||
import moviepy.editor
|
import moviepy.editor
|
||||||
|
import nltk
|
||||||
|
import whisper
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from transformers import BartTokenizer, BartForConditionalGeneration
|
from transformers import BartTokenizer, BartForConditionalGeneration
|
||||||
import whisper
|
|
||||||
import nltk
|
|
||||||
nltk.download('punkt', quiet=True)
|
nltk.download('punkt', quiet=True)
|
||||||
|
|
||||||
WHISPER_MODEL_SIZE = "base"
|
WHISPER_MODEL_SIZE = "base"
|
||||||
|
|
||||||
|
|
||||||
def init_argparse() -> argparse.ArgumentParser:
|
def init_argparse() -> argparse.ArgumentParser:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
usage="%(prog)s [OPTIONS] <LOCATION> <OUTPUT>",
|
usage="%(prog)s [OPTIONS] <LOCATION> <OUTPUT>",
|
||||||
@@ -30,6 +33,7 @@ def init_argparse() -> argparse.ArgumentParser:
|
|||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
# NLTK chunking function
|
# NLTK chunking function
|
||||||
def chunk_text(txt, max_chunk_length=500):
|
def chunk_text(txt, max_chunk_length=500):
|
||||||
"Split text into smaller chunks."
|
"Split text into smaller chunks."
|
||||||
@@ -45,6 +49,7 @@ def chunk_text(txt, max_chunk_length=500):
|
|||||||
chunks.append(current_chunk.strip())
|
chunks.append(current_chunk.strip())
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
||||||
# BART summary function
|
# BART summary function
|
||||||
def summarize_chunks(chunks, tokenizer, model):
|
def summarize_chunks(chunks, tokenizer, model):
|
||||||
summaries = []
|
summaries = []
|
||||||
@@ -56,6 +61,7 @@ def summarize_chunks(chunks, tokenizer, model):
|
|||||||
summaries.append(summary)
|
summaries.append(summary)
|
||||||
return summaries
|
return summaries
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
import sys
|
import sys
|
||||||
sys.setrecursionlimit(10000)
|
sys.setrecursionlimit(10000)
|
||||||
@@ -103,7 +109,7 @@ def main():
|
|||||||
chunks = chunk_text(whisper_result['text'])
|
chunks = chunk_text(whisper_result['text'])
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable
|
f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable
|
||||||
|
|
||||||
logger.info(f"Writing summary text in {args.language} to: {args.output}")
|
logger.info(f"Writing summary text in {args.language} to: {args.output}")
|
||||||
with open(args.output, 'w') as f:
|
with open(args.output, 'w') as f:
|
||||||
@@ -114,5 +120,6 @@ def main():
|
|||||||
|
|
||||||
logger.info("Summarization completed")
|
logger.info("Summarization completed")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ networkx==3.1
|
|||||||
numba==0.57.0
|
numba==0.57.0
|
||||||
numpy==1.24.3
|
numpy==1.24.3
|
||||||
openai==0.27.7
|
openai==0.27.7
|
||||||
openai-whisper @ git+https://github.com/openai/whisper.git@248b6cb124225dd263bb9bd32d060b6517e067f8
|
openai-whisper@ git+https://github.com/openai/whisper.git@248b6cb124225dd263bb9bd32d060b6517e067f8
|
||||||
Pillow==9.5.0
|
Pillow==9.5.0
|
||||||
proglog==0.1.10
|
proglog==0.1.10
|
||||||
pytube==15.0.0
|
pytube==15.0.0
|
||||||
@@ -56,5 +56,4 @@ cached_property==1.5.2
|
|||||||
stamina==23.1.0
|
stamina==23.1.0
|
||||||
httpx==0.24.1
|
httpx==0.24.1
|
||||||
sortedcontainers==2.4.0
|
sortedcontainers==2.4.0
|
||||||
openai-whisper @ git+https://github.com/openai/whisper.git@248b6cb124225dd263bb9bd32d060b6517e067f8
|
|
||||||
https://github.com/yt-dlp/yt-dlp/archive/master.tar.gz
|
https://github.com/yt-dlp/yt-dlp/archive/master.tar.gz
|
||||||
|
|||||||
@@ -1,15 +1,24 @@
|
|||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
# Directory to search for Python files
|
# Directory to search for Python files
|
||||||
directory="."
|
cwd=$(pwd)
|
||||||
|
last_component="${cwd##*/}"
|
||||||
|
|
||||||
|
if [ "$last_component" = "reflector" ]; then
|
||||||
|
directory="./artefacts"
|
||||||
|
elif [ "$last_component" = "scripts" ]; then
|
||||||
|
directory="../artefacts"
|
||||||
|
fi
|
||||||
|
|
||||||
# Pattern to match Python files (e.g., "*.py" for all .py files)
|
# Pattern to match Python files (e.g., "*.py" for all .py files)
|
||||||
text_file_pattern="transcript_*.txt"
|
transcript_file_pattern="transcript_*.txt"
|
||||||
|
summary_file_pattern="summary_*.txt"
|
||||||
pickle_file_pattern="*.pkl"
|
pickle_file_pattern="*.pkl"
|
||||||
html_file_pattern="*.html"
|
html_file_pattern="*.html"
|
||||||
png_file_pattern="wordcloud*.png"
|
png_file_pattern="wordcloud*.png"
|
||||||
|
|
||||||
find "$directory" -type f -name "$text_file_pattern" -delete
|
find "$directory" -type f -name "$transcript_file_pattern" -delete
|
||||||
|
find "$directory" -type f -name "$summary_file_pattern" -delete
|
||||||
find "$directory" -type f -name "$pickle_file_pattern" -delete
|
find "$directory" -type f -name "$pickle_file_pattern" -delete
|
||||||
find "$directory" -type f -name "$html_file_pattern" -delete
|
find "$directory" -type f -name "$html_file_pattern" -delete
|
||||||
find "$directory" -type f -name "$png_file_pattern" -delete
|
find "$directory" -type f -name "$png_file_pattern" -delete
|
||||||
|
|||||||
@@ -1,9 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import datetime
|
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
import uuid
|
import uuid
|
||||||
import wave
|
import wave
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
@@ -13,24 +10,21 @@ from aiohttp import web
|
|||||||
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
|
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
|
||||||
from aiortc.contrib.media import MediaRelay
|
from aiortc.contrib.media import MediaRelay
|
||||||
from av import AudioFifo
|
from av import AudioFifo
|
||||||
|
from loguru import logger
|
||||||
from whisper_jax import FlaxWhisperPipline
|
from whisper_jax import FlaxWhisperPipline
|
||||||
|
|
||||||
from utils.server_utils import run_in_executor
|
from utils.run_utils import run_in_executor
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
transcription = ""
|
|
||||||
|
|
||||||
pcs = set()
|
pcs = set()
|
||||||
relay = MediaRelay()
|
relay = MediaRelay()
|
||||||
data_channel = None
|
data_channel = None
|
||||||
total_bytes_handled = 0
|
pipeline = FlaxWhisperPipline("openai/whisper-tiny",
|
||||||
pipeline = FlaxWhisperPipline("openai/whisper-tiny", dtype=jnp.float16, batch_size=16)
|
dtype=jnp.float16,
|
||||||
|
batch_size=16)
|
||||||
|
|
||||||
CHANNELS = 2
|
CHANNELS = 2
|
||||||
RATE = 48000
|
RATE = 48000
|
||||||
audio_buffer = AudioFifo()
|
audio_buffer = AudioFifo()
|
||||||
start_time = datetime.datetime.now()
|
|
||||||
executor = ThreadPoolExecutor()
|
executor = ThreadPoolExecutor()
|
||||||
|
|
||||||
|
|
||||||
@@ -40,30 +34,12 @@ def channel_log(channel, t, message):
|
|||||||
|
|
||||||
def channel_send(channel, message):
|
def channel_send(channel, message):
|
||||||
# channel_log(channel, ">", message)
|
# channel_log(channel, ">", message)
|
||||||
global start_time
|
|
||||||
if channel:
|
if channel:
|
||||||
channel.send(message)
|
channel.send(message)
|
||||||
print(
|
|
||||||
"Bytes handled :",
|
|
||||||
total_bytes_handled,
|
|
||||||
" Time : ",
|
|
||||||
datetime.datetime.now() - start_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_transcription(frames):
|
def get_transcription(frames):
|
||||||
print("Transcribing..")
|
print("Transcribing..")
|
||||||
# samples = np.ndarray(
|
|
||||||
# np.concatenate([f.to_ndarray() for f in frames], axis=None),
|
|
||||||
# dtype=np.float32,
|
|
||||||
# )
|
|
||||||
# whisper_result = pipeline(
|
|
||||||
# {
|
|
||||||
# "array": samples,
|
|
||||||
# "sampling_rate": 48000,
|
|
||||||
# },
|
|
||||||
# return_timestamps=True,
|
|
||||||
# )
|
|
||||||
out_file = io.BytesIO()
|
out_file = io.BytesIO()
|
||||||
wf = wave.open(out_file, "wb")
|
wf = wave.open(out_file, "wb")
|
||||||
wf.setnchannels(CHANNELS)
|
wf.setnchannels(CHANNELS)
|
||||||
@@ -73,8 +49,6 @@ def get_transcription(frames):
|
|||||||
for frame in frames:
|
for frame in frames:
|
||||||
wf.writeframes(b"".join(frame.to_ndarray()))
|
wf.writeframes(b"".join(frame.to_ndarray()))
|
||||||
wf.close()
|
wf.close()
|
||||||
global total_bytes_handled
|
|
||||||
total_bytes_handled += sys.getsizeof(wf)
|
|
||||||
whisper_result = pipeline(out_file.getvalue(), return_timestamps=True)
|
whisper_result = pipeline(out_file.getvalue(), return_timestamps=True)
|
||||||
with open("test_exec.txt", "a") as f:
|
with open("test_exec.txt", "a") as f:
|
||||||
f.write(whisper_result["text"])
|
f.write(whisper_result["text"])
|
||||||
@@ -98,19 +72,19 @@ class AudioStreamTrack(MediaStreamTrack):
|
|||||||
audio_buffer.write(frame)
|
audio_buffer.write(frame)
|
||||||
if local_frames := audio_buffer.read_many(256 * 960, partial=False):
|
if local_frames := audio_buffer.read_many(256 * 960, partial=False):
|
||||||
whisper_result = run_in_executor(
|
whisper_result = run_in_executor(
|
||||||
get_transcription, local_frames, executor=executor
|
get_transcription, local_frames, executor=executor
|
||||||
)
|
)
|
||||||
whisper_result.add_done_callback(
|
whisper_result.add_done_callback(
|
||||||
lambda f: channel_send(data_channel, str(whisper_result.result()))
|
lambda f: channel_send(data_channel,
|
||||||
if (f.result())
|
str(whisper_result.result()))
|
||||||
else None
|
if (f.result())
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
return frame
|
return frame
|
||||||
|
|
||||||
|
|
||||||
async def offer(request):
|
async def offer(request):
|
||||||
params = await request.json()
|
params = await request.json()
|
||||||
print("Request received")
|
|
||||||
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
|
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
|
||||||
|
|
||||||
pc = RTCPeerConnection()
|
pc = RTCPeerConnection()
|
||||||
@@ -120,55 +94,47 @@ async def offer(request):
|
|||||||
def log_info(msg, *args):
|
def log_info(msg, *args):
|
||||||
logger.info(pc_id + " " + msg, *args)
|
logger.info(pc_id + " " + msg, *args)
|
||||||
|
|
||||||
log_info("Created for %s", request.remote)
|
log_info("Created for " + request.remote)
|
||||||
|
|
||||||
@pc.on("datachannel")
|
@pc.on("datachannel")
|
||||||
def on_datachannel(channel):
|
def on_datachannel(channel):
|
||||||
global data_channel, start_time
|
global data_channel
|
||||||
data_channel = channel
|
data_channel = channel
|
||||||
channel_log(channel, "-", "created by remote party")
|
channel_log(channel, "-", "created by remote party")
|
||||||
start_time = datetime.datetime.now()
|
|
||||||
|
|
||||||
@channel.on("message")
|
@channel.on("message")
|
||||||
def on_message(message):
|
def on_message(message):
|
||||||
channel_log(channel, "<", message)
|
channel_log(channel, "<", message)
|
||||||
|
|
||||||
if isinstance(message, str) and message.startswith("ping"):
|
if isinstance(message, str) and message.startswith("ping"):
|
||||||
# reply
|
|
||||||
channel_send(channel, "pong" + message[4:])
|
channel_send(channel, "pong" + message[4:])
|
||||||
|
|
||||||
@pc.on("connectionstatechange")
|
@pc.on("connectionstatechange")
|
||||||
async def on_connectionstatechange():
|
async def on_connectionstatechange():
|
||||||
log_info("Connection state is %s", pc.connectionState)
|
log_info("Connection state is " + pc.connectionState)
|
||||||
if pc.connectionState == "failed":
|
if pc.connectionState == "failed":
|
||||||
await pc.close()
|
await pc.close()
|
||||||
pcs.discard(pc)
|
pcs.discard(pc)
|
||||||
|
|
||||||
@pc.on("track")
|
@pc.on("track")
|
||||||
def on_track(track):
|
def on_track(track):
|
||||||
print("Track %s received" % track.kind)
|
log_info("Track " + track.kind + " received")
|
||||||
log_info("Track %s received", track.kind)
|
|
||||||
# Trials to listen to the correct track
|
|
||||||
pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
|
pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
|
||||||
# pc.addTrack(AudioStreamTrack(track))
|
|
||||||
|
|
||||||
# handle offer
|
|
||||||
await pc.setRemoteDescription(offer)
|
await pc.setRemoteDescription(offer)
|
||||||
|
|
||||||
# send answer
|
|
||||||
answer = await pc.createAnswer()
|
answer = await pc.createAnswer()
|
||||||
await pc.setLocalDescription(answer)
|
await pc.setLocalDescription(answer)
|
||||||
print("Response sent")
|
|
||||||
return web.Response(
|
return web.Response(
|
||||||
content_type="application/json",
|
content_type="application/json",
|
||||||
text=json.dumps(
|
text=json.dumps(
|
||||||
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
|
{"sdp": pc.localDescription.sdp,
|
||||||
),
|
"type": pc.localDescription.type}
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def on_shutdown(app):
|
async def on_shutdown(app):
|
||||||
# close peer connections
|
|
||||||
coros = [pc.close() for pc in pcs]
|
coros = [pc.close() for pc in pcs]
|
||||||
await asyncio.gather(*coros)
|
await asyncio.gather(*coros)
|
||||||
pcs.clear()
|
pcs.clear()
|
||||||
|
|||||||
@@ -1,54 +1,38 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import configparser
|
|
||||||
import datetime
|
import datetime
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
import wave
|
import wave
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
from aiohttp import webq
|
import requests
|
||||||
|
from aiohttp import web
|
||||||
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
|
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
|
||||||
from aiortc.contrib.media import (MediaRelay)
|
from aiortc.contrib.media import MediaRelay
|
||||||
from av import AudioFifo
|
from av import AudioFifo
|
||||||
from sortedcontainers import SortedDict
|
from sortedcontainers import SortedDict
|
||||||
from whisper_jax import FlaxWhisperPipline
|
from whisper_jax import FlaxWhisperPipline
|
||||||
|
|
||||||
from utils.server_utils import Mutex
|
from utils.log_utils import logger
|
||||||
|
from utils.run_utils import config, Mutex
|
||||||
|
|
||||||
ROOT = os.path.dirname(__file__)
|
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_REAL_TIME_MODEL_SIZE"]
|
||||||
|
|
||||||
config = configparser.ConfigParser()
|
|
||||||
config.read('config.ini')
|
|
||||||
|
|
||||||
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
|
|
||||||
|
|
||||||
logger = logging.getLogger("pc")
|
|
||||||
pcs = set()
|
pcs = set()
|
||||||
relay = MediaRelay()
|
relay = MediaRelay()
|
||||||
data_channel = None
|
data_channel = None
|
||||||
sorted_message_queue = SortedDict()
|
sorted_message_queue = SortedDict()
|
||||||
|
|
||||||
CHANNELS = 2
|
CHANNELS = 2
|
||||||
RATE = 44100
|
RATE = 44100
|
||||||
CHUNK_SIZE = 256
|
CHUNK_SIZE = 256
|
||||||
|
|
||||||
audio_buffer = AudioFifo()
|
|
||||||
pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE,
|
pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE,
|
||||||
dtype=jnp.float16,
|
dtype=jnp.float16,
|
||||||
batch_size=16)
|
batch_size=16)
|
||||||
|
|
||||||
transcription = ""
|
|
||||||
start_time = datetime.datetime.now()
|
start_time = datetime.datetime.now()
|
||||||
total_bytes_handled = 0
|
|
||||||
|
|
||||||
executor = ThreadPoolExecutor()
|
executor = ThreadPoolExecutor()
|
||||||
|
audio_buffer = AudioFifo()
|
||||||
frame_lock = Mutex(audio_buffer)
|
frame_lock = Mutex(audio_buffer)
|
||||||
|
|
||||||
|
|
||||||
@@ -81,6 +65,7 @@ def get_transcription():
|
|||||||
transcribe = True
|
transcribe = True
|
||||||
|
|
||||||
if transcribe:
|
if transcribe:
|
||||||
|
print("Transcribing..")
|
||||||
try:
|
try:
|
||||||
sorted_message_queue[frames[0].time] = None
|
sorted_message_queue[frames[0].time] = None
|
||||||
out_file = io.BytesIO()
|
out_file = io.BytesIO()
|
||||||
@@ -94,10 +79,11 @@ def get_transcription():
|
|||||||
wf.close()
|
wf.close()
|
||||||
|
|
||||||
whisper_result = pipeline(out_file.getvalue())
|
whisper_result = pipeline(out_file.getvalue())
|
||||||
item = {'text': whisper_result["text"],
|
item = {
|
||||||
|
'text': whisper_result["text"],
|
||||||
'start_time': str(frames[0].time),
|
'start_time': str(frames[0].time),
|
||||||
'time': str(datetime.datetime.now())
|
'time': str(datetime.datetime.now())
|
||||||
}
|
}
|
||||||
sorted_message_queue[frames[0].time] = str(item)
|
sorted_message_queue[frames[0].time] = str(item)
|
||||||
start_messaging_thread()
|
start_messaging_thread()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -106,7 +92,7 @@ def get_transcription():
|
|||||||
|
|
||||||
class AudioStreamTrack(MediaStreamTrack):
|
class AudioStreamTrack(MediaStreamTrack):
|
||||||
"""
|
"""
|
||||||
A video stream track that transforms frames from an another track.
|
An audio stream track to send audio frames.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
kind = "audio"
|
kind = "audio"
|
||||||
@@ -126,15 +112,13 @@ def start_messaging_thread():
|
|||||||
message_thread.start()
|
message_thread.start()
|
||||||
|
|
||||||
|
|
||||||
def start_transcription_thread(max_threads):
|
def start_transcription_thread(max_threads: int):
|
||||||
t_threads = []
|
|
||||||
for i in range(max_threads):
|
for i in range(max_threads):
|
||||||
t_thread = threading.Thread(target=get_transcription, args=(i,))
|
t_thread = threading.Thread(target=get_transcription)
|
||||||
t_threads.append(t_thread)
|
|
||||||
t_thread.start()
|
t_thread.start()
|
||||||
|
|
||||||
|
|
||||||
async def offer(request):
|
async def offer(request: requests.Request):
|
||||||
params = await request.json()
|
params = await request.json()
|
||||||
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
|
offer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
|
||||||
|
|
||||||
@@ -142,10 +126,10 @@ async def offer(request):
|
|||||||
pc_id = "PeerConnection(%s)" % uuid.uuid4()
|
pc_id = "PeerConnection(%s)" % uuid.uuid4()
|
||||||
pcs.add(pc)
|
pcs.add(pc)
|
||||||
|
|
||||||
def log_info(msg, *args):
|
def log_info(msg: str, *args):
|
||||||
logger.info(pc_id + " " + msg, *args)
|
logger.info(pc_id + " " + msg, *args)
|
||||||
|
|
||||||
log_info("Created for %s", request.remote)
|
log_info("Created for " + request.remote)
|
||||||
|
|
||||||
@pc.on("datachannel")
|
@pc.on("datachannel")
|
||||||
def on_datachannel(channel):
|
def on_datachannel(channel):
|
||||||
@@ -155,7 +139,7 @@ async def offer(request):
|
|||||||
start_time = datetime.datetime.now()
|
start_time = datetime.datetime.now()
|
||||||
|
|
||||||
@channel.on("message")
|
@channel.on("message")
|
||||||
def on_message(message):
|
def on_message(message: str):
|
||||||
channel_log(channel, "<", message)
|
channel_log(channel, "<", message)
|
||||||
if isinstance(message, str) and message.startswith("ping"):
|
if isinstance(message, str) and message.startswith("ping"):
|
||||||
# reply
|
# reply
|
||||||
@@ -163,14 +147,14 @@ async def offer(request):
|
|||||||
|
|
||||||
@pc.on("connectionstatechange")
|
@pc.on("connectionstatechange")
|
||||||
async def on_connectionstatechange():
|
async def on_connectionstatechange():
|
||||||
log_info("Connection state is %s", pc.connectionState)
|
log_info("Connection state is " + pc.connectionState)
|
||||||
if pc.connectionState == "failed":
|
if pc.connectionState == "failed":
|
||||||
await pc.close()
|
await pc.close()
|
||||||
pcs.discard(pc)
|
pcs.discard(pc)
|
||||||
|
|
||||||
@pc.on("track")
|
@pc.on("track")
|
||||||
def on_track(track):
|
def on_track(track):
|
||||||
log_info("Track %s received", track.kind)
|
log_info("Track " + track.kind + " received")
|
||||||
pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
|
pc.addTrack(AudioStreamTrack(relay.subscribe(track)))
|
||||||
|
|
||||||
# handle offer
|
# handle offer
|
||||||
@@ -180,14 +164,15 @@ async def offer(request):
|
|||||||
answer = await pc.createAnswer()
|
answer = await pc.createAnswer()
|
||||||
await pc.setLocalDescription(answer)
|
await pc.setLocalDescription(answer)
|
||||||
return web.Response(
|
return web.Response(
|
||||||
content_type="application/json",
|
content_type="application/json",
|
||||||
text=json.dumps(
|
text=json.dumps({
|
||||||
{"sdp": pc.localDescription.sdp, "type": pc.localDescription.type}
|
"sdp": pc.localDescription.sdp,
|
||||||
),
|
"type": pc.localDescription.type
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def on_shutdown(app):
|
async def on_shutdown(app: web.Application):
|
||||||
coros = [pc.close() for pc in pcs]
|
coros = [pc.close() for pc in pcs]
|
||||||
await asyncio.gather(*coros)
|
await asyncio.gather(*coros)
|
||||||
pcs.clear()
|
pcs.clear()
|
||||||
@@ -199,5 +184,5 @@ if __name__ == "__main__":
|
|||||||
start_transcription_thread(6)
|
start_transcription_thread(6)
|
||||||
app.router.add_post("/offer", offer)
|
app.router.add_post("/offer", offer)
|
||||||
web.run_app(
|
web.run_app(
|
||||||
app, access_log=None, host="127.0.0.1", port=1250
|
app, access_log=None, host="127.0.0.1", port=1250
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
import ast
|
import ast
|
||||||
import asyncio
|
import asyncio
|
||||||
import configparser
|
|
||||||
import logging
|
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
@@ -12,14 +10,11 @@ import stamina
|
|||||||
from aiortc import (RTCPeerConnection, RTCSessionDescription)
|
from aiortc import (RTCPeerConnection, RTCSessionDescription)
|
||||||
from aiortc.contrib.media import (MediaPlayer, MediaRelay)
|
from aiortc.contrib.media import (MediaPlayer, MediaRelay)
|
||||||
|
|
||||||
from utils.server_utils import Mutex
|
from utils.log_utils import logger
|
||||||
|
from utils.run_utils import config, Mutex
|
||||||
|
|
||||||
logger = logging.getLogger("pc")
|
|
||||||
file_lock = Mutex(open("test_sm_6.txt", "a"))
|
file_lock = Mutex(open("test_sm_6.txt", "a"))
|
||||||
|
|
||||||
config = configparser.ConfigParser()
|
|
||||||
config.read('config.ini')
|
|
||||||
|
|
||||||
|
|
||||||
class StreamClient:
|
class StreamClient:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -42,14 +37,15 @@ class StreamClient:
|
|||||||
self.pcs = set()
|
self.pcs = set()
|
||||||
self.time_start = None
|
self.time_start = None
|
||||||
self.queue = asyncio.Queue()
|
self.queue = asyncio.Queue()
|
||||||
self.player = MediaPlayer(':' + str(config['DEFAULT']["AV_FOUNDATION_DEVICE_ID"]),
|
self.player = MediaPlayer(
|
||||||
format='avfoundation', options={'channels': '2'})
|
':' + str(config['DEFAULT']["AV_FOUNDATION_DEVICE_ID"]),
|
||||||
|
format='avfoundation',
|
||||||
|
options={'channels': '2'})
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self.loop.run_until_complete(self.signaling.close())
|
self.loop.run_until_complete(self.signaling.close())
|
||||||
self.loop.run_until_complete(self.pc.close())
|
self.loop.run_until_complete(self.pc.close())
|
||||||
# self.loop.close()
|
# self.loop.close()
|
||||||
print("ended")
|
|
||||||
|
|
||||||
def create_local_tracks(self, play_from):
|
def create_local_tracks(self, play_from):
|
||||||
if play_from:
|
if play_from:
|
||||||
@@ -58,7 +54,6 @@ class StreamClient:
|
|||||||
else:
|
else:
|
||||||
if self.relay is None:
|
if self.relay is None:
|
||||||
self.relay = MediaRelay()
|
self.relay = MediaRelay()
|
||||||
print("Created local track from microphone stream")
|
|
||||||
return self.relay.subscribe(self.player.audio), None
|
return self.relay.subscribe(self.player.audio), None
|
||||||
|
|
||||||
def channel_log(self, channel, t, message):
|
def channel_log(self, channel, t, message):
|
||||||
@@ -122,14 +117,15 @@ class StreamClient:
|
|||||||
self.channel_log(channel, "<", message)
|
self.channel_log(channel, "<", message)
|
||||||
|
|
||||||
if isinstance(message, str) and message.startswith("pong"):
|
if isinstance(message, str) and message.startswith("pong"):
|
||||||
elapsed_ms = (self.current_stamp() - int(message[5:])) / 1000
|
elapsed_ms = (self.current_stamp() - int(message[5:]))\
|
||||||
|
/ 1000
|
||||||
print(" RTT %.2f ms" % elapsed_ms)
|
print(" RTT %.2f ms" % elapsed_ms)
|
||||||
|
|
||||||
await pc.setLocalDescription(await pc.createOffer())
|
await pc.setLocalDescription(await pc.createOffer())
|
||||||
|
|
||||||
sdp = {
|
sdp = {
|
||||||
"sdp": pc.localDescription.sdp,
|
"sdp": pc.localDescription.sdp,
|
||||||
"type": pc.localDescription.type
|
"type": pc.localDescription.type
|
||||||
}
|
}
|
||||||
|
|
||||||
@stamina.retry(on=httpx.HTTPError, attempts=5)
|
@stamina.retry(on=httpx.HTTPError, attempts=5)
|
||||||
@@ -142,7 +138,7 @@ class StreamClient:
|
|||||||
answer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
|
answer = RTCSessionDescription(sdp=params["sdp"], type=params["type"])
|
||||||
await pc.setRemoteDescription(answer)
|
await pc.setRemoteDescription(answer)
|
||||||
|
|
||||||
self.reader = self.worker(f"worker", self.queue)
|
self.reader = self.worker(f'{"worker"}', self.queue)
|
||||||
|
|
||||||
def get_reader(self):
|
def get_reader(self):
|
||||||
return self.reader
|
return self.reader
|
||||||
|
|||||||
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
@@ -1,13 +1,12 @@
|
|||||||
import configparser
|
import sys
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
import botocore
|
import botocore
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
config = configparser.ConfigParser()
|
from .log_utils import logger
|
||||||
config.read('config.ini')
|
from .run_utils import config
|
||||||
|
|
||||||
BUCKET_NAME = 'reflector-bucket'
|
BUCKET_NAME = config["DEFAULT"]["BUCKET_NAME"]
|
||||||
|
|
||||||
s3 = boto3.client('s3',
|
s3 = boto3.client('s3',
|
||||||
aws_access_key_id=config["DEFAULT"]["AWS_ACCESS_KEY"],
|
aws_access_key_id=config["DEFAULT"]["AWS_ACCESS_KEY"],
|
||||||
@@ -17,8 +16,8 @@ s3 = boto3.client('s3',
|
|||||||
def upload_files(files_to_upload):
|
def upload_files(files_to_upload):
|
||||||
"""
|
"""
|
||||||
Upload a list of files to the configured S3 bucket
|
Upload a list of files to the configured S3 bucket
|
||||||
:param files_to_upload:
|
:param files_to_upload: List of files to upload
|
||||||
:return:
|
:return: None
|
||||||
"""
|
"""
|
||||||
for KEY in files_to_upload:
|
for KEY in files_to_upload:
|
||||||
logger.info("Uploading file " + KEY)
|
logger.info("Uploading file " + KEY)
|
||||||
@@ -31,8 +30,8 @@ def upload_files(files_to_upload):
|
|||||||
def download_files(files_to_download):
|
def download_files(files_to_download):
|
||||||
"""
|
"""
|
||||||
Download a list of files from the configured S3 bucket
|
Download a list of files from the configured S3 bucket
|
||||||
:param files_to_download:
|
:param files_to_download: List of files to download
|
||||||
:return:
|
:return: None
|
||||||
"""
|
"""
|
||||||
for KEY in files_to_download:
|
for KEY in files_to_download:
|
||||||
logger.info("Downloading file " + KEY)
|
logger.info("Downloading file " + KEY)
|
||||||
@@ -46,8 +45,6 @@ def download_files(files_to_download):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import sys
|
|
||||||
|
|
||||||
if sys.argv[1] == "download":
|
if sys.argv[1] == "download":
|
||||||
download_files([sys.argv[2]])
|
download_files([sys.argv[2]])
|
||||||
elif sys.argv[1] == "upload":
|
elif sys.argv[1] == "upload":
|
||||||
18
utils/log_utils.py
Normal file
18
utils/log_utils.py
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
import loguru
|
||||||
|
|
||||||
|
|
||||||
|
class SingletonLogger:
|
||||||
|
__instance = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_logger():
|
||||||
|
"""
|
||||||
|
Create or return the singleton instance for the SingletonLogger class
|
||||||
|
:return: SingletonLogger instance
|
||||||
|
"""
|
||||||
|
if not SingletonLogger.__instance:
|
||||||
|
SingletonLogger.__instance = loguru.logger
|
||||||
|
return SingletonLogger.__instance
|
||||||
|
|
||||||
|
|
||||||
|
logger = SingletonLogger.get_logger()
|
||||||
66
utils/run_utils.py
Normal file
66
utils/run_utils.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
import asyncio
|
||||||
|
import configparser
|
||||||
|
import contextlib
|
||||||
|
from functools import partial
|
||||||
|
from threading import Lock
|
||||||
|
from typing import ContextManager, Generic, TypeVar
|
||||||
|
|
||||||
|
|
||||||
|
class ReflectorConfig:
|
||||||
|
__config = None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_config():
|
||||||
|
if ReflectorConfig.__config is None:
|
||||||
|
ReflectorConfig.__config = configparser.ConfigParser()
|
||||||
|
ReflectorConfig.__config.read('utils/config.ini')
|
||||||
|
return ReflectorConfig.__config
|
||||||
|
|
||||||
|
|
||||||
|
config = ReflectorConfig.get_config()
|
||||||
|
|
||||||
|
|
||||||
|
def run_in_executor(func, *args, executor=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Run the function in an executor, unblocking the main loop
|
||||||
|
:param func: Function to be run in executor
|
||||||
|
:param args: function parameters
|
||||||
|
:param executor: executor instance [Thread | Process]
|
||||||
|
:param kwargs: Additional parameters
|
||||||
|
:return: Future of function result upon completion
|
||||||
|
"""
|
||||||
|
callback = partial(func, *args, **kwargs)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_in_executor(executor, callback)
|
||||||
|
|
||||||
|
|
||||||
|
# Genetic type template
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class Mutex(Generic[T]):
|
||||||
|
"""
|
||||||
|
Mutex class to implement lock/release of a shared
|
||||||
|
protected variable
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, value: T):
|
||||||
|
"""
|
||||||
|
Create an instance of Mutex wrapper for the given resource
|
||||||
|
:param value: Shared resources to be thread protected
|
||||||
|
"""
|
||||||
|
self.__value = value
|
||||||
|
self.__lock = Lock()
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def lock(self) -> ContextManager[T]:
|
||||||
|
"""
|
||||||
|
Lock the resource with a mutex to be used within a context block
|
||||||
|
The lock is automatically released on context exit
|
||||||
|
:return: Shared resource
|
||||||
|
"""
|
||||||
|
self.__lock.acquire()
|
||||||
|
try:
|
||||||
|
yield self.__value
|
||||||
|
finally:
|
||||||
|
self.__lock.release()
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import contextlib
|
|
||||||
from functools import partial
|
|
||||||
from threading import Lock
|
|
||||||
from typing import ContextManager, Generic, TypeVar
|
|
||||||
|
|
||||||
|
|
||||||
def run_in_executor(func, *args, executor=None, **kwargs):
|
|
||||||
callback = partial(func, *args, **kwargs)
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
return asyncio.get_event_loop().run_in_executor(executor, callback)
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
class Mutex(Generic[T]):
|
|
||||||
def __init__(self, value: T):
|
|
||||||
self.__value = value
|
|
||||||
self.__lock = Lock()
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def lock(self) -> ContextManager[T]:
|
|
||||||
self.__lock.acquire()
|
|
||||||
try:
|
|
||||||
yield self.__value
|
|
||||||
finally:
|
|
||||||
self.__lock.release()
|
|
||||||
@@ -1,24 +1,22 @@
|
|||||||
import configparser
|
|
||||||
|
|
||||||
import nltk
|
import nltk
|
||||||
import torch
|
import torch
|
||||||
from loguru import logger
|
|
||||||
from nltk.corpus import stopwords
|
from nltk.corpus import stopwords
|
||||||
from nltk.tokenize import word_tokenize
|
from nltk.tokenize import word_tokenize
|
||||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||||
from sklearn.metrics.pairwise import cosine_similarity
|
from sklearn.metrics.pairwise import cosine_similarity
|
||||||
from transformers import BartTokenizer, BartForConditionalGeneration
|
from transformers import BartForConditionalGeneration, BartTokenizer
|
||||||
|
|
||||||
|
from utils.log_utils import logger
|
||||||
|
from utils.run_utils import config
|
||||||
|
|
||||||
nltk.download('punkt', quiet=True)
|
nltk.download('punkt', quiet=True)
|
||||||
|
|
||||||
config = configparser.ConfigParser()
|
|
||||||
config.read('config.ini')
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_sentence(sentence):
|
def preprocess_sentence(sentence):
|
||||||
stop_words = set(stopwords.words('english'))
|
stop_words = set(stopwords.words('english'))
|
||||||
tokens = word_tokenize(sentence.lower())
|
tokens = word_tokenize(sentence.lower())
|
||||||
tokens = [token for token in tokens if token.isalnum() and token not in stop_words]
|
tokens = [token for token in tokens
|
||||||
|
if token.isalnum() and token not in stop_words]
|
||||||
return ' '.join(tokens)
|
return ' '.join(tokens)
|
||||||
|
|
||||||
|
|
||||||
@@ -52,12 +50,14 @@ def remove_almost_alike_sentences(sentences, threshold=0.7):
|
|||||||
sentence1 = preprocess_sentence(sentences[i])
|
sentence1 = preprocess_sentence(sentences[i])
|
||||||
sentence2 = preprocess_sentence(sentences[j])
|
sentence2 = preprocess_sentence(sentences[j])
|
||||||
if len(sentence1) != 0 and len(sentence2) != 0:
|
if len(sentence1) != 0 and len(sentence2) != 0:
|
||||||
similarity = compute_similarity(sentence1, sentence2)
|
similarity = compute_similarity(sentence1,
|
||||||
|
sentence2)
|
||||||
|
|
||||||
if similarity >= threshold:
|
if similarity >= threshold:
|
||||||
removed_indices.add(max(i, j))
|
removed_indices.add(max(i, j))
|
||||||
|
|
||||||
filtered_sentences = [sentences[i] for i in range(num_sentences) if i not in removed_indices]
|
filtered_sentences = [sentences[i] for i in range(num_sentences)
|
||||||
|
if i not in removed_indices]
|
||||||
return filtered_sentences
|
return filtered_sentences
|
||||||
|
|
||||||
|
|
||||||
@@ -77,11 +77,13 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
|
|||||||
words = nltk.word_tokenize(sent)
|
words = nltk.word_tokenize(sent)
|
||||||
n_gram_filter = 3
|
n_gram_filter = 3
|
||||||
for i in range(len(words)):
|
for i in range(len(words)):
|
||||||
if str(words[i:i + n_gram_filter]) in seen and seen[str(words[i:i + n_gram_filter])] == words[
|
if str(words[i:i + n_gram_filter]) in seen and \
|
||||||
i + 1:i + n_gram_filter + 2]:
|
seen[str(words[i:i + n_gram_filter])] == \
|
||||||
|
words[i + 1:i + n_gram_filter + 2]:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
seen[str(words[i:i + n_gram_filter])] = words[i + 1:i + n_gram_filter + 2]
|
seen[str(words[i:i + n_gram_filter])] = \
|
||||||
|
words[i + 1:i + n_gram_filter + 2]
|
||||||
temp_result += words[i]
|
temp_result += words[i]
|
||||||
temp_result += " "
|
temp_result += " "
|
||||||
chunk_sentences.append(temp_result)
|
chunk_sentences.append(temp_result)
|
||||||
@@ -91,9 +93,12 @@ def remove_whisper_repetitive_hallucination(nonduplicate_sentences):
|
|||||||
def post_process_transcription(whisper_result):
|
def post_process_transcription(whisper_result):
|
||||||
transcript_text = ""
|
transcript_text = ""
|
||||||
for chunk in whisper_result["chunks"]:
|
for chunk in whisper_result["chunks"]:
|
||||||
nonduplicate_sentences = remove_outright_duplicate_sentences_from_chunk(chunk)
|
nonduplicate_sentences = \
|
||||||
chunk_sentences = remove_whisper_repetitive_hallucination(nonduplicate_sentences)
|
remove_outright_duplicate_sentences_from_chunk(chunk)
|
||||||
similarity_matched_sentences = remove_almost_alike_sentences(chunk_sentences)
|
chunk_sentences = \
|
||||||
|
remove_whisper_repetitive_hallucination(nonduplicate_sentences)
|
||||||
|
similarity_matched_sentences = \
|
||||||
|
remove_almost_alike_sentences(chunk_sentences)
|
||||||
chunk["text"] = " ".join(similarity_matched_sentences)
|
chunk["text"] = " ".join(similarity_matched_sentences)
|
||||||
transcript_text += chunk["text"]
|
transcript_text += chunk["text"]
|
||||||
whisper_result["text"] = transcript_text
|
whisper_result["text"] = transcript_text
|
||||||
@@ -114,18 +119,23 @@ def summarize_chunks(chunks, tokenizer, model):
|
|||||||
input_ids = tokenizer.encode(c, return_tensors='pt')
|
input_ids = tokenizer.encode(c, return_tensors='pt')
|
||||||
input_ids = input_ids.to(device)
|
input_ids = input_ids.to(device)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
summary_ids = model.generate(input_ids,
|
summary_ids = \
|
||||||
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
|
model.generate(input_ids,
|
||||||
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
|
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]),
|
||||||
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
|
length_penalty=2.0,
|
||||||
|
max_length=int(config["DEFAULT"]["MAX_LENGTH"]),
|
||||||
|
early_stopping=True)
|
||||||
|
summary = tokenizer.decode(summary_ids[0],
|
||||||
|
skip_special_tokens=True)
|
||||||
summaries.append(summary)
|
summaries.append(summary)
|
||||||
return summaries
|
return summaries
|
||||||
|
|
||||||
|
|
||||||
def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
|
def chunk_text(text,
|
||||||
|
max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])):
|
||||||
"""
|
"""
|
||||||
Split text into smaller chunks.
|
Split text into smaller chunks.
|
||||||
:param txt: Text to be chunked
|
:param text: Text to be chunked
|
||||||
:param max_chunk_length: length of chunk
|
:param max_chunk_length: length of chunk
|
||||||
:return: chunked texts
|
:return: chunked texts
|
||||||
"""
|
"""
|
||||||
@@ -143,7 +153,8 @@ def chunk_text(text, max_chunk_length=int(config["DEFAULT"]["MAX_CHUNK_LENGTH"])
|
|||||||
|
|
||||||
|
|
||||||
def summarize(transcript_text, timestamp,
|
def summarize(transcript_text, timestamp,
|
||||||
real_time=False, summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
|
real_time=False,
|
||||||
|
summarize_using_chunks=config["DEFAULT"]["SUMMARIZE_USING_CHUNKS"]):
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
|
summary_model = config["DEFAULT"]["SUMMARY_MODEL"]
|
||||||
if not summary_model:
|
if not summary_model:
|
||||||
@@ -160,9 +171,11 @@ def summarize(transcript_text, timestamp,
|
|||||||
output_filename = "real_time_" + output_filename
|
output_filename = "real_time_" + output_filename
|
||||||
|
|
||||||
if summarize_using_chunks != "YES":
|
if summarize_using_chunks != "YES":
|
||||||
inputs = tokenizer.batch_encode_plus([transcript_text], truncation=True, padding='longest',
|
inputs = tokenizer.\
|
||||||
max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]),
|
batch_encode_plus([transcript_text], truncation=True,
|
||||||
return_tensors='pt')
|
padding='longest',
|
||||||
|
max_length=int(config["DEFAULT"]["INPUT_ENCODING_MAX_LENGTH"]),
|
||||||
|
return_tensors='pt')
|
||||||
inputs = inputs.to(device)
|
inputs = inputs.to(device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@@ -170,16 +183,17 @@ def summarize(transcript_text, timestamp,
|
|||||||
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
|
num_beams=int(config["DEFAULT"]["BEAM_SIZE"]), length_penalty=2.0,
|
||||||
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
|
max_length=int(config["DEFAULT"]["MAX_LENGTH"]), early_stopping=True)
|
||||||
|
|
||||||
decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False) for
|
decoded_summaries = [tokenizer.decode(summary, skip_special_tokens=True, clean_up_tokenization_spaces=False)
|
||||||
summary in summaries]
|
for summary in summaries]
|
||||||
summary = " ".join(decoded_summaries)
|
summary = " ".join(decoded_summaries)
|
||||||
with open(output_filename, 'w') as f:
|
with open("./artefacts/" + output_filename, 'w') as f:
|
||||||
f.write(summary.strip() + "\n")
|
f.write(summary.strip() + "\n")
|
||||||
else:
|
else:
|
||||||
logger.info("Breaking transcript into smaller chunks")
|
logger.info("Breaking transcript into smaller chunks")
|
||||||
chunks = chunk_text(transcript_text)
|
chunks = chunk_text(transcript_text)
|
||||||
|
|
||||||
logger.info(f"Transcript broken into {len(chunks)} chunks of at most 500 words") # TODO fix variable
|
logger.info(f"Transcript broken into {len(chunks)} "
|
||||||
|
f"chunks of at most 500 words")
|
||||||
|
|
||||||
logger.info(f"Writing summary text to: {output_filename}")
|
logger.info(f"Writing summary text to: {output_filename}")
|
||||||
with open(output_filename, 'w') as f:
|
with open(output_filename, 'w') as f:
|
||||||
|
|||||||
@@ -1,24 +1,20 @@
|
|||||||
import ast
|
import ast
|
||||||
import collections
|
import collections
|
||||||
import configparser
|
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import scattertext as st
|
import scattertext as st
|
||||||
import spacy
|
import spacy
|
||||||
from nltk.corpus import stopwords
|
from nltk.corpus import stopwords
|
||||||
from wordcloud import WordCloud, STOPWORDS
|
from wordcloud import STOPWORDS, WordCloud
|
||||||
|
|
||||||
config = configparser.ConfigParser()
|
|
||||||
config.read('config.ini')
|
|
||||||
|
|
||||||
en = spacy.load('en_core_web_md')
|
en = spacy.load('en_core_web_md')
|
||||||
spacy_stopwords = en.Defaults.stop_words
|
spacy_stopwords = en.Defaults.stop_words
|
||||||
|
|
||||||
STOPWORDS = set(STOPWORDS).union(set(stopwords.words("english"))).union(set(spacy_stopwords))
|
STOPWORDS = set(STOPWORDS).union(set(stopwords.words("english"))).\
|
||||||
|
union(set(spacy_stopwords))
|
||||||
|
|
||||||
|
|
||||||
def create_wordcloud(timestamp, real_time=False):
|
def create_wordcloud(timestamp, real_time=False):
|
||||||
@@ -28,7 +24,8 @@ def create_wordcloud(timestamp, real_time=False):
|
|||||||
"""
|
"""
|
||||||
filename = "transcript"
|
filename = "transcript"
|
||||||
if real_time:
|
if real_time:
|
||||||
filename = "real_time_" + filename + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
filename = "real_time_" + filename + "_" +\
|
||||||
|
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
||||||
else:
|
else:
|
||||||
filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
filename += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
||||||
|
|
||||||
@@ -50,11 +47,12 @@ def create_wordcloud(timestamp, real_time=False):
|
|||||||
|
|
||||||
wordcloud_name = "wordcloud"
|
wordcloud_name = "wordcloud"
|
||||||
if real_time:
|
if real_time:
|
||||||
wordcloud_name = "real_time_" + wordcloud_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
|
wordcloud_name = "real_time_" + wordcloud_name + "_" +\
|
||||||
|
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
|
||||||
else:
|
else:
|
||||||
wordcloud_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
|
wordcloud_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".png"
|
||||||
|
|
||||||
plt.savefig(wordcloud_name)
|
plt.savefig("./artefacts/" + wordcloud_name)
|
||||||
|
|
||||||
|
|
||||||
def create_talk_diff_scatter_viz(timestamp, real_time=False):
|
def create_talk_diff_scatter_viz(timestamp, real_time=False):
|
||||||
@@ -70,7 +68,6 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
|
|||||||
agenda_topics = []
|
agenda_topics = []
|
||||||
agenda = []
|
agenda = []
|
||||||
# Load the agenda
|
# Load the agenda
|
||||||
path = Path(__file__)
|
|
||||||
with open(os.path.join(os.getcwd(), "agenda-headers.txt"), "r") as f:
|
with open(os.path.join(os.getcwd(), "agenda-headers.txt"), "r") as f:
|
||||||
for line in f.readlines():
|
for line in f.readlines():
|
||||||
if line.strip():
|
if line.strip():
|
||||||
@@ -80,9 +77,11 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
|
|||||||
# Load the transcription with timestamp
|
# Load the transcription with timestamp
|
||||||
filename = ""
|
filename = ""
|
||||||
if real_time:
|
if real_time:
|
||||||
filename = "real_time_transcript_with_timestamp_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
filename = "./artefacts/real_time_transcript_with_timestamp_" +\
|
||||||
|
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
||||||
else:
|
else:
|
||||||
filename = "transcript_with_timestamp_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
filename = "./artefacts/transcript_with_timestamp_" +\
|
||||||
|
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".txt"
|
||||||
with open(filename) as f:
|
with open(filename) as f:
|
||||||
transcription_timestamp_text = f.read()
|
transcription_timestamp_text = f.read()
|
||||||
|
|
||||||
@@ -98,7 +97,8 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
|
|||||||
ts_to_topic_mapping_top_1 = {}
|
ts_to_topic_mapping_top_1 = {}
|
||||||
ts_to_topic_mapping_top_2 = {}
|
ts_to_topic_mapping_top_2 = {}
|
||||||
|
|
||||||
# Also create a mapping of the different timestamps in which each topic was covered
|
# Also create a mapping of the different timestamps
|
||||||
|
# in which each topic was covered
|
||||||
topic_to_ts_mapping_top_1 = collections.defaultdict(list)
|
topic_to_ts_mapping_top_1 = collections.defaultdict(list)
|
||||||
topic_to_ts_mapping_top_2 = collections.defaultdict(list)
|
topic_to_ts_mapping_top_2 = collections.defaultdict(list)
|
||||||
|
|
||||||
@@ -109,7 +109,8 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
|
|||||||
topic_similarities = []
|
topic_similarities = []
|
||||||
for item in range(len(agenda)):
|
for item in range(len(agenda)):
|
||||||
item_doc = nlp(agenda[item])
|
item_doc = nlp(agenda[item])
|
||||||
# if not doc_transcription or not all(token.has_vector for token in doc_transcription):
|
# if not doc_transcription or not all
|
||||||
|
# (token.has_vector for token in doc_transcription):
|
||||||
if not doc_transcription:
|
if not doc_transcription:
|
||||||
continue
|
continue
|
||||||
similarity = doc_transcription.similarity(item_doc)
|
similarity = doc_transcription.similarity(item_doc)
|
||||||
@@ -133,8 +134,10 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
|
|||||||
:param record:
|
:param record:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
record["ts_to_topic_mapping_top_1"] = ts_to_topic_mapping_top_1[record["timestamp"]]
|
record["ts_to_topic_mapping_top_1"] = \
|
||||||
record["ts_to_topic_mapping_top_2"] = ts_to_topic_mapping_top_2[record["timestamp"]]
|
ts_to_topic_mapping_top_1[record["timestamp"]]
|
||||||
|
record["ts_to_topic_mapping_top_2"] = \
|
||||||
|
ts_to_topic_mapping_top_2[record["timestamp"]]
|
||||||
return record
|
return record
|
||||||
|
|
||||||
df = df.apply(create_new_columns, axis=1)
|
df = df.apply(create_new_columns, axis=1)
|
||||||
@@ -155,20 +158,22 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
|
|||||||
# Save df, mappings for further experimentation
|
# Save df, mappings for further experimentation
|
||||||
df_name = "df"
|
df_name = "df"
|
||||||
if real_time:
|
if real_time:
|
||||||
df_name = "real_time_" + df_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
df_name = "real_time_" + df_name + "_" +\
|
||||||
|
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
||||||
else:
|
else:
|
||||||
df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
df_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
||||||
df.to_pickle(df_name)
|
df.to_pickle("./artefacts/" + df_name)
|
||||||
|
|
||||||
my_mappings = [ts_to_topic_mapping_top_1, ts_to_topic_mapping_top_2,
|
my_mappings = [ts_to_topic_mapping_top_1, ts_to_topic_mapping_top_2,
|
||||||
topic_to_ts_mapping_top_1, topic_to_ts_mapping_top_2]
|
topic_to_ts_mapping_top_1, topic_to_ts_mapping_top_2]
|
||||||
|
|
||||||
mappings_name = "mappings"
|
mappings_name = "mappings"
|
||||||
if real_time:
|
if real_time:
|
||||||
mappings_name = "real_time_" + mappings_name + "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
mappings_name = "real_time_" + mappings_name + "_" +\
|
||||||
|
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
||||||
else:
|
else:
|
||||||
mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
mappings_name += "_" + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + ".pkl"
|
||||||
pickle.dump(my_mappings, open(mappings_name, "wb"))
|
pickle.dump(my_mappings, open("./artefacts/" + mappings_name, "wb"))
|
||||||
|
|
||||||
# to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") )
|
# to load, my_mappings = pickle.load( open ("mappings.pkl", "rb") )
|
||||||
|
|
||||||
@@ -182,25 +187,28 @@ def create_talk_diff_scatter_viz(timestamp, real_time=False):
|
|||||||
|
|
||||||
topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True)
|
topic_times = sorted(topic_times.items(), key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
cat_1 = topic_times[0][0]
|
if len(topic_times) > 1:
|
||||||
cat_1_name = topic_times[0][0]
|
cat_1 = topic_times[0][0]
|
||||||
cat_2_name = topic_times[1][0]
|
cat_1_name = topic_times[0][0]
|
||||||
|
cat_2_name = topic_times[1][0]
|
||||||
|
|
||||||
# Scatter plot of topics
|
# Scatter plot of topics
|
||||||
df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences))
|
df = df.assign(parse=lambda df: df.text.apply(st.whitespace_nlp_with_sentences))
|
||||||
corpus = st.CorpusFromParsedDocuments(
|
corpus = st.CorpusFromParsedDocuments(
|
||||||
df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse'
|
df, category_col='ts_to_topic_mapping_top_1', parsed_col='parse'
|
||||||
).build().get_unigram_corpus().compact(st.AssociationCompactor(2000))
|
).build().get_unigram_corpus().compact(st.AssociationCompactor(2000))
|
||||||
html = st.produce_scattertext_explorer(
|
html = st.produce_scattertext_explorer(
|
||||||
corpus,
|
corpus,
|
||||||
category=cat_1,
|
category=cat_1,
|
||||||
category_name=cat_1_name,
|
category_name=cat_1_name,
|
||||||
not_category_name=cat_2_name,
|
not_category_name=cat_2_name,
|
||||||
minimum_term_frequency=0, pmi_threshold_coefficient=0,
|
minimum_term_frequency=0, pmi_threshold_coefficient=0,
|
||||||
width_in_pixels=1000,
|
width_in_pixels=1000,
|
||||||
transform=st.Scalers.dense_rank
|
transform=st.Scalers.dense_rank
|
||||||
)
|
)
|
||||||
if real_time:
|
if real_time:
|
||||||
open('./artefacts/real_time_scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
|
open('./artefacts/real_time_scatter_' +
|
||||||
else:
|
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
|
||||||
open('./artefacts/scatter_' + timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
|
else:
|
||||||
|
open('./artefacts/scatter_' +
|
||||||
|
timestamp.strftime("%m-%d-%Y_%H:%M:%S") + '.html', 'w').write(html)
|
||||||
|
|||||||
86
whisjax.py
86
whisjax.py
@@ -1,11 +1,10 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
# summarize https://www.youtube.com/watch?v=imzTxoEDH_g --transcript=transcript.txt summary.txt
|
# summarize https://www.youtube.com/watch?v=imzTxoEDH_g
|
||||||
# summarize https://www.sprocket.org/video/cheesemaking.mp4 summary.txt
|
# summarize https://www.sprocket.org/video/cheesemaking.mp4 summary.txt
|
||||||
# summarize podcast.mp3 summary.txt
|
# summarize podcast.mp3 summary.txt
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import configparser
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
@@ -15,23 +14,19 @@ from urllib.parse import urlparse
|
|||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import moviepy.editor
|
import moviepy.editor
|
||||||
import moviepy.editor
|
|
||||||
import nltk
|
import nltk
|
||||||
import yt_dlp as youtube_dl
|
import yt_dlp as youtube_dl
|
||||||
from loguru import logger
|
|
||||||
from whisper_jax import FlaxWhisperPipline
|
from whisper_jax import FlaxWhisperPipline
|
||||||
|
|
||||||
from utils.file_utilities import upload_files, download_files
|
from utils.file_utils import download_files, upload_files
|
||||||
from utils.text_utilities import summarize, post_process_transcription
|
from utils.log_utils import logger
|
||||||
from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
|
from utils.run_utils import config
|
||||||
|
from utils.text_utilities import post_process_transcription, summarize
|
||||||
|
from utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud
|
||||||
|
|
||||||
nltk.download('punkt', quiet=True)
|
nltk.download('punkt', quiet=True)
|
||||||
nltk.download('stopwords', quiet=True)
|
nltk.download('stopwords', quiet=True)
|
||||||
|
|
||||||
# Configurations can be found in config.ini. Set them properly before executing
|
|
||||||
config = configparser.ConfigParser()
|
|
||||||
config.read('config.ini')
|
|
||||||
|
|
||||||
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
|
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
|
||||||
NOW = datetime.now()
|
NOW = datetime.now()
|
||||||
|
|
||||||
@@ -42,12 +37,17 @@ def init_argparse() -> argparse.ArgumentParser:
|
|||||||
:return: parser object
|
:return: parser object
|
||||||
"""
|
"""
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
usage="%(prog)s [OPTIONS] <LOCATION> <OUTPUT>",
|
usage="%(prog)s [OPTIONS] <LOCATION> <OUTPUT>",
|
||||||
description="Creates a transcript of a video or audio file, then summarizes it using ChatGPT."
|
description="Creates a transcript of a video or audio file, then"
|
||||||
|
" summarizes it using ChatGPT."
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("-l", "--language", help="Language that the summary should be written in", type=str,
|
parser.add_argument("-l", "--language",
|
||||||
default="english", choices=['english', 'spanish', 'french', 'german', 'romanian'])
|
help="Language that the summary should be written in",
|
||||||
|
type=str,
|
||||||
|
default="english",
|
||||||
|
choices=['english', 'spanish', 'french', 'german',
|
||||||
|
'romanian'])
|
||||||
parser.add_argument("location")
|
parser.add_argument("location")
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
@@ -65,22 +65,24 @@ def main():
|
|||||||
|
|
||||||
media_file = ""
|
media_file = ""
|
||||||
if url.scheme == 'http' or url.scheme == 'https':
|
if url.scheme == 'http' or url.scheme == 'https':
|
||||||
# Check if we're being asked to retreive a YouTube URL, which is handled
|
# Check if we're being asked to retreive a YouTube URL, which is
|
||||||
# diffrently, as we'll use a secondary site to download the video first.
|
# handled differently, as we'll use a secondary site to download
|
||||||
|
# the video first.
|
||||||
if re.search('youtube.com', url.netloc, re.IGNORECASE):
|
if re.search('youtube.com', url.netloc, re.IGNORECASE):
|
||||||
# Download the lowest resolution YouTube video (since we're just interested in the audio).
|
# Download the lowest resolution YouTube video
|
||||||
|
# (since we're just interested in the audio).
|
||||||
# It will be saved to the current directory.
|
# It will be saved to the current directory.
|
||||||
logger.info("Downloading YouTube video at url: " + args.location)
|
logger.info("Downloading YouTube video at url: " + args.location)
|
||||||
|
|
||||||
# Create options for the download
|
# Create options for the download
|
||||||
ydl_opts = {
|
ydl_opts = {
|
||||||
'format': 'bestaudio/best',
|
'format': 'bestaudio/best',
|
||||||
'postprocessors': [{
|
'postprocessors': [{
|
||||||
'key': 'FFmpegExtractAudio',
|
'key': 'FFmpegExtractAudio',
|
||||||
'preferredcodec': 'mp3',
|
'preferredcodec': 'mp3',
|
||||||
'preferredquality': '192',
|
'preferredquality': '192',
|
||||||
}],
|
}],
|
||||||
'outtmpl': 'audio', # Specify the output file path and name
|
'outtmpl': 'audio', # Specify output file path and name
|
||||||
}
|
}
|
||||||
|
|
||||||
# Download the audio
|
# Download the audio
|
||||||
@@ -90,7 +92,8 @@ def main():
|
|||||||
|
|
||||||
logger.info("Saved downloaded YouTube video to: " + media_file)
|
logger.info("Saved downloaded YouTube video to: " + media_file)
|
||||||
else:
|
else:
|
||||||
# XXX - Download file using urllib, check if file is audio/video using python-magic
|
# XXX - Download file using urllib, check if file is
|
||||||
|
# audio/video using python-magic
|
||||||
logger.info(f"Downloading file at url: {args.location}")
|
logger.info(f"Downloading file at url: {args.location}")
|
||||||
logger.info(" XXX - This method hasn't been implemented yet.")
|
logger.info(" XXX - This method hasn't been implemented yet.")
|
||||||
elif url.scheme == '':
|
elif url.scheme == '':
|
||||||
@@ -101,7 +104,7 @@ def main():
|
|||||||
|
|
||||||
if media_file.endswith(".m4a"):
|
if media_file.endswith(".m4a"):
|
||||||
subprocess.run(["ffmpeg", "-i", media_file, f"{media_file}.mp4"])
|
subprocess.run(["ffmpeg", "-i", media_file, f"{media_file}.mp4"])
|
||||||
input_file = f"{media_file}.mp4"
|
media_file = f"{media_file}.mp4"
|
||||||
else:
|
else:
|
||||||
print("Unsupported URL scheme: " + url.scheme)
|
print("Unsupported URL scheme: " + url.scheme)
|
||||||
quit()
|
quit()
|
||||||
@@ -110,19 +113,21 @@ def main():
|
|||||||
if not media_file.endswith(".mp3"):
|
if not media_file.endswith(".mp3"):
|
||||||
try:
|
try:
|
||||||
video = moviepy.editor.VideoFileClip(media_file)
|
video = moviepy.editor.VideoFileClip(media_file)
|
||||||
audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
|
audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3",
|
||||||
|
delete=False).name
|
||||||
video.audio.write_audiofile(audio_filename, logger=None)
|
video.audio.write_audiofile(audio_filename, logger=None)
|
||||||
logger.info(f"Extracting audio to: {audio_filename}")
|
logger.info(f"Extracting audio to: {audio_filename}")
|
||||||
# Handle audio only file
|
# Handle audio only file
|
||||||
except:
|
except Exception:
|
||||||
audio = moviepy.editor.AudioFileClip(media_file)
|
audio = moviepy.editor.AudioFileClip(media_file)
|
||||||
audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
|
audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3",
|
||||||
|
delete=False).name
|
||||||
audio.write_audiofile(audio_filename, logger=None)
|
audio.write_audiofile(audio_filename, logger=None)
|
||||||
else:
|
else:
|
||||||
audio_filename = media_file
|
audio_filename = media_file
|
||||||
|
|
||||||
logger.info("Finished extracting audio")
|
logger.info("Finished extracting audio")
|
||||||
|
logger.info("Transcribing")
|
||||||
# Convert the audio to text using the OpenAI Whisper model
|
# Convert the audio to text using the OpenAI Whisper model
|
||||||
pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE,
|
pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE,
|
||||||
dtype=jnp.float16,
|
dtype=jnp.float16,
|
||||||
@@ -136,10 +141,12 @@ def main():
|
|||||||
for chunk in whisper_result["chunks"]:
|
for chunk in whisper_result["chunks"]:
|
||||||
transcript_text += chunk["text"]
|
transcript_text += chunk["text"]
|
||||||
|
|
||||||
with open("./artefacts/transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as transcript_file:
|
with open("./artefacts/transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") +
|
||||||
|
".txt", "w") as transcript_file:
|
||||||
transcript_file.write(transcript_text)
|
transcript_file.write(transcript_text)
|
||||||
|
|
||||||
with open("./artefacts/transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt",
|
with open("./artefacts/transcript_with_timestamp_" +
|
||||||
|
NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt",
|
||||||
"w") as transcript_file_timestamps:
|
"w") as transcript_file_timestamps:
|
||||||
transcript_file_timestamps.write(str(whisper_result))
|
transcript_file_timestamps.write(str(whisper_result))
|
||||||
|
|
||||||
@@ -150,13 +157,14 @@ def main():
|
|||||||
create_talk_diff_scatter_viz(NOW)
|
create_talk_diff_scatter_viz(NOW)
|
||||||
|
|
||||||
# S3 : Push artefacts to S3 bucket
|
# S3 : Push artefacts to S3 bucket
|
||||||
|
prefix = "./artefacts/"
|
||||||
suffix = NOW.strftime("%m-%d-%Y_%H:%M:%S")
|
suffix = NOW.strftime("%m-%d-%Y_%H:%M:%S")
|
||||||
files_to_upload = ["transcript_" + suffix + ".txt",
|
files_to_upload = [prefix + "transcript_" + suffix + ".txt",
|
||||||
"transcript_with_timestamp_" + suffix + ".txt",
|
prefix + "transcript_with_timestamp_" + suffix + ".txt",
|
||||||
"df_" + suffix + ".pkl",
|
prefix + "df_" + suffix + ".pkl",
|
||||||
"wordcloud_" + suffix + ".png",
|
prefix + "wordcloud_" + suffix + ".png",
|
||||||
"mappings_" + suffix + ".pkl",
|
prefix + "mappings_" + suffix + ".pkl",
|
||||||
"scatter_" + suffix + ".html"]
|
prefix + "scatter_" + suffix + ".html"]
|
||||||
upload_files(files_to_upload)
|
upload_files(files_to_upload)
|
||||||
|
|
||||||
summarize(transcript_text, NOW, False, False)
|
summarize(transcript_text, NOW, False, False)
|
||||||
|
|||||||
@@ -1,23 +1,20 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
import configparser
|
|
||||||
import time
|
import time
|
||||||
import wave
|
import wave
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
import jax.numpy as jnp
|
import jax.numpy as jnp
|
||||||
import pyaudio
|
import pyaudio
|
||||||
from loguru import logger
|
|
||||||
from pynput import keyboard
|
from pynput import keyboard
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from whisper_jax import FlaxWhisperPipline
|
from whisper_jax import FlaxWhisperPipline
|
||||||
|
|
||||||
from utils.file_utilities import upload_files
|
from utils.file_utils import upload_files
|
||||||
from utils.text_utilities import summarize, post_process_transcription
|
from utils.log_utils import logger
|
||||||
from utils.viz_utilities import create_wordcloud, create_talk_diff_scatter_viz
|
from utils.run_utils import config
|
||||||
|
from utils.text_utilities import post_process_transcription, summarize
|
||||||
config = configparser.ConfigParser()
|
from utils.viz_utilities import create_talk_diff_scatter_viz, create_wordcloud
|
||||||
config.read('config.ini')
|
|
||||||
|
|
||||||
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
|
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
|
||||||
|
|
||||||
@@ -33,19 +30,21 @@ def main():
|
|||||||
p = pyaudio.PyAudio()
|
p = pyaudio.PyAudio()
|
||||||
AUDIO_DEVICE_ID = -1
|
AUDIO_DEVICE_ID = -1
|
||||||
for i in range(p.get_device_count()):
|
for i in range(p.get_device_count()):
|
||||||
if p.get_device_info_by_index(i)["name"] == config["DEFAULT"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]:
|
if p.get_device_info_by_index(i)["name"] == \
|
||||||
|
config["DEFAULT"]["BLACKHOLE_INPUT_AGGREGATOR_DEVICE_NAME"]:
|
||||||
AUDIO_DEVICE_ID = i
|
AUDIO_DEVICE_ID = i
|
||||||
audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID)
|
audio_devices = p.get_device_info_by_index(AUDIO_DEVICE_ID)
|
||||||
stream = p.open(
|
stream = p.open(
|
||||||
format=FORMAT,
|
format=FORMAT,
|
||||||
channels=CHANNELS,
|
channels=CHANNELS,
|
||||||
rate=RATE,
|
rate=RATE,
|
||||||
input=True,
|
input=True,
|
||||||
frames_per_buffer=FRAMES_PER_BUFFER,
|
frames_per_buffer=FRAMES_PER_BUFFER,
|
||||||
input_device_index=int(audio_devices['index'])
|
input_device_index=int(audio_devices['index'])
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline = FlaxWhisperPipline("openai/whisper-" + config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"],
|
pipeline = FlaxWhisperPipline("openai/whisper-" +
|
||||||
|
config["DEFAULT"]["WHISPER_REAL_TIME_MODEL_SIZE"],
|
||||||
dtype=jnp.float16,
|
dtype=jnp.float16,
|
||||||
batch_size=16)
|
batch_size=16)
|
||||||
|
|
||||||
@@ -72,7 +71,8 @@ def main():
|
|||||||
frames = []
|
frames = []
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)):
|
for i in range(0, int(RATE / FRAMES_PER_BUFFER * RECORD_SECONDS)):
|
||||||
data = stream.read(FRAMES_PER_BUFFER, exception_on_overflow=False)
|
data = stream.read(FRAMES_PER_BUFFER,
|
||||||
|
exception_on_overflow=False)
|
||||||
frames.append(data)
|
frames.append(data)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
@@ -90,7 +90,8 @@ def main():
|
|||||||
if end is None:
|
if end is None:
|
||||||
end = start + 15.0
|
end = start + 15.0
|
||||||
duration = end - start
|
duration = end - start
|
||||||
item = {'timestamp': (last_transcribed_time, last_transcribed_time + duration),
|
item = {'timestamp': (last_transcribed_time,
|
||||||
|
last_transcribed_time + duration),
|
||||||
'text': whisper_result['text'],
|
'text': whisper_result['text'],
|
||||||
'stats': (str(end_time - start_time), str(duration))
|
'stats': (str(end_time - start_time), str(duration))
|
||||||
}
|
}
|
||||||
@@ -100,15 +101,19 @@ def main():
|
|||||||
|
|
||||||
print(colored("<START>", "yellow"))
|
print(colored("<START>", "yellow"))
|
||||||
print(colored(whisper_result['text'], 'green'))
|
print(colored(whisper_result['text'], 'green'))
|
||||||
print(colored("<END> Recorded duration: " + str(end_time - start_time) + " | Transcribed duration: " +
|
print(colored("<END> Recorded duration: " +
|
||||||
|
str(end_time - start_time) +
|
||||||
|
" | Transcribed duration: " +
|
||||||
str(duration), "yellow"))
|
str(duration), "yellow"))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
finally:
|
finally:
|
||||||
with open("real_time_transcript_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
|
with open("real_time_transcript_" +
|
||||||
|
NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
|
||||||
f.write(transcription)
|
f.write(transcription)
|
||||||
with open("real_time_transcript_with_timestamp_" + NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
|
with open("real_time_transcript_with_timestamp_" +
|
||||||
|
NOW.strftime("%m-%d-%Y_%H:%M:%S") + ".txt", "w") as f:
|
||||||
transcript_with_timestamp["text"] = transcription
|
transcript_with_timestamp["text"] = transcription
|
||||||
f.write(str(transcript_with_timestamp))
|
f.write(str(transcript_with_timestamp))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user