mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
Merge pull request #1 from Monadical-SAS/whisper-jax-gokul
Whisper JAX pipeline for demo
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -158,3 +158,6 @@ cython_debug/
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
*.mp4
|
||||
*.txt
|
||||
45
README.md
45
README.md
@@ -1,2 +1,43 @@
|
||||
# transcription
|
||||
Transcription experiments
|
||||
# Reflector
|
||||
|
||||
This is the code base for the Reflector demo (formerly called agenda-talk-diff) for the leads : Troy Web Consulting panel (A Chat with AWS about AI: Real AI/ML AWS projects and what you should know) on 6/14 at 430PM.
|
||||
|
||||
The target deliverable is a local-first live transcription and visualization tool to compare a discussion's target agenda/objectives to the actual discussion live.
|
||||
|
||||
To setup,
|
||||
|
||||
1) Check values in config.ini file. Specifically add your OPENAI_APIKEY.
|
||||
2) Run ``` export KMP_DUPLICATE_LIB_OK=True``` in Terminal. [This is taken care of in code, but not reflecting, Will fix this issue later.]
|
||||
3) Run the script setup_depedencies.sh.
|
||||
|
||||
``` chmod +x setup_dependecies.sh ```
|
||||
|
||||
``` sh setup_dependencies.sh <ENV>```
|
||||
|
||||
|
||||
ENV refers to the intended environment for JAX. JAX is available in several variants, [CPU | GPU | Colab TPU | Google Cloud TPU]
|
||||
|
||||
```ENV``` is :
|
||||
|
||||
cpu -> JAX CPU installation
|
||||
|
||||
cuda11 -> JAX CUDA 11.x version
|
||||
|
||||
cuda12 -> JAX CUDA 12.x version (Core Weave has CUDA 12 version, can check with ```nvidia-smi```)
|
||||
|
||||
sh setup_dependencies.sh cuda12
|
||||
|
||||
4) Run the Whisper-JAX pipeline. Currently, the repo takes a Youtube video and transcribes/summarizes it.
|
||||
|
||||
``` python3 whisjax.py "https://www.youtube.com/watch?v=ihf0S97oxuQ" --transcript transcript.txt summary.txt ```
|
||||
|
||||
|
||||
|
||||
NEXT STEPS:
|
||||
|
||||
1) Run this demo on a local Mac M1 to test flow and observe the performance
|
||||
2) Create a pipeline using microphone to listen to audio chunks to perform transcription realtime (and also efficiently
|
||||
summarize it as well)
|
||||
3) Create a RunPod setup for this feature (mentioned in 1 & 2) and test it end-to-end
|
||||
4) Perform Speaker Diarization using Whisper-JAX
|
||||
5) Based on feasibility of above points, explore suitable visualizations for transcription & summarization.
|
||||
|
||||
7
config.ini
Normal file
7
config.ini
Normal file
@@ -0,0 +1,7 @@
|
||||
[DEFAULT]
|
||||
# Set exception rule for OpenMP error to allow duplicate lib initialization
|
||||
KMP_DUPLICATE_LIB_OK=TRUE
|
||||
# Export OpenAI API Key
|
||||
OPENAI_APIKEY=API_KEY
|
||||
# Export Whisper Model Size
|
||||
WHISPER_MODEL_SIZE=tiny
|
||||
26
setup_dependencies.sh
Executable file
26
setup_dependencies.sh
Executable file
@@ -0,0 +1,26 @@
|
||||
# Upgrade pip
|
||||
pip install --upgrade pip
|
||||
|
||||
# Default to CPU Installation of JAX
|
||||
jax_mode="jax[cpu]"
|
||||
|
||||
# Install JAX
|
||||
if [ "$1" == "cpu" ]
|
||||
then
|
||||
jax_mode="jax[cpu]"
|
||||
elif [ "$1" == "cuda11" ]
|
||||
then
|
||||
jax_mode="jax[cuda11_pip]"
|
||||
elif [ "$1" == "cuda12" ]
|
||||
then
|
||||
jax_mode="jax[cuda12_pip]"
|
||||
fi
|
||||
|
||||
pip install --upgrade "$jax_mode"
|
||||
|
||||
# Install Whisper-JAX base
|
||||
pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
|
||||
|
||||
# Update to latest version
|
||||
pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git
|
||||
|
||||
169
whisjax.py
Normal file
169
whisjax.py
Normal file
@@ -0,0 +1,169 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# summarize https://www.youtube.com/watch?v=imzTxoEDH_g --transcript=transcript.txt summary.txt
|
||||
# summarize https://www.sprocket.org/video/cheesemaking.mp4 summary.txt
|
||||
# summarize podcast.mp3 summary.txt
|
||||
|
||||
from urllib.parse import urlparse
|
||||
from pytube import YouTube
|
||||
from loguru import logger
|
||||
from whisper_jax import FlaxWhisperPipline
|
||||
import jax.numpy as jnp
|
||||
import moviepy.editor
|
||||
import argparse
|
||||
import tempfile
|
||||
import whisper
|
||||
import openai
|
||||
import re
|
||||
import configparser
|
||||
import os
|
||||
|
||||
config = configparser.ConfigParser()
|
||||
config.read('config.ini')
|
||||
|
||||
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
|
||||
OPENAI_APIKEY = config['DEFAULT']["OPENAI_APIKEY"]
|
||||
|
||||
MAX_WORDS_IN_CHUNK = 2500
|
||||
MAX_OUTPUT_TOKENS = 1000
|
||||
|
||||
|
||||
def init_argparse() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
usage="%(prog)s [OPTIONS] <LOCATION> <OUTPUT>",
|
||||
description="Creates a transcript of a video or audio file, then summarizes it using ChatGPT."
|
||||
)
|
||||
|
||||
parser.add_argument("-l", "--language", help="Language that the summary should be written in", type=str,
|
||||
default="english", choices=['english', 'spanish', 'french', 'german', 'romanian'])
|
||||
parser.add_argument("-t", "--transcript", help="Save a copy of the intermediary transcript file", type=str)
|
||||
parser.add_argument("location")
|
||||
parser.add_argument("output")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def chunk_text(txt):
|
||||
sentences = re.split('[.!?]', txt)
|
||||
|
||||
chunks = []
|
||||
chunk = ""
|
||||
size = 0
|
||||
|
||||
for s in sentences:
|
||||
# Get the number of words in this sentence.
|
||||
n = len(re.findall(r'\w+', s))
|
||||
|
||||
# Skip over empty sentences.
|
||||
if n == 0:
|
||||
continue
|
||||
|
||||
# We need to break the text up into chunks so as not to exceed the max
|
||||
# number of tokens accepted by the ChatGPT model.
|
||||
if size + n > MAX_WORDS_IN_CHUNK:
|
||||
chunks.append(chunk)
|
||||
size = n
|
||||
chunk = s
|
||||
else:
|
||||
chunk = chunk + s
|
||||
size = size + n
|
||||
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
|
||||
return chunks
|
||||
|
||||
|
||||
def main():
|
||||
parser = init_argparse()
|
||||
args = parser.parse_args()
|
||||
|
||||
# Parse the location string that was given to us, and figure out if it's a
|
||||
# local file (audio or video), a YouTube URL, or a URL referencing an
|
||||
# audio or video file.
|
||||
url = urlparse(args.location)
|
||||
|
||||
media_file = ""
|
||||
if url.scheme == 'http' or url.scheme == 'https':
|
||||
# Check if we're being asked to retreive a YouTube URL, which is handled
|
||||
# diffrently, as we'll use a secondary site to download the video first.
|
||||
if re.search('youtube.com', url.netloc, re.IGNORECASE):
|
||||
# Download the lowest resolution YouTube video (since we're just interested in the audio).
|
||||
# It will be saved to the current directory.
|
||||
logger.info("Downloading YouTube video at url: " + args.location)
|
||||
|
||||
youtube = YouTube(args.location)
|
||||
media_file = youtube.streams.filter(progressive=True, file_extension='mp4').order_by(
|
||||
'resolution').asc().first().download()
|
||||
|
||||
logger.info("Saved downloaded YouTube video to: " + media_file)
|
||||
else:
|
||||
# XXX - Download file using urllib, check if file is audio/video using python-magic
|
||||
logger.info(f"Downloading file at url: {args.location}")
|
||||
logger.info(" XXX - This method hasn't been implemented yet.")
|
||||
elif url.scheme == '':
|
||||
media_file = url.path
|
||||
else:
|
||||
print("Unsupported URL scheme: " + url.scheme)
|
||||
quit()
|
||||
|
||||
# If the media file we just retrieved is a video, extract its audio stream.
|
||||
# XXX - We should be checking if we've downloaded an audio file (eg .mp3),
|
||||
# XXX - in which case we can skip this step. For now we'll assume that
|
||||
# XXX - everything is an mp4 video.
|
||||
audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
|
||||
logger.info(f"Extracting audio to: {audio_filename}")
|
||||
|
||||
video = moviepy.editor.VideoFileClip(media_file)
|
||||
video.audio.write_audiofile(audio_filename, logger=None)
|
||||
|
||||
logger.info("Finished extracting audio")
|
||||
|
||||
# Convert the audio to text using the OpenAI Whisper model
|
||||
pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE, dtype=jnp.float16, batch_size=16)
|
||||
whisper_result = pipeline(audio_filename, return_timestamps=True)
|
||||
logger.info("Finished transcribing file")
|
||||
|
||||
# If we got the transcript parameter on the command line, save the transcript to the specified file.
|
||||
if args.transcript:
|
||||
logger.info(f"Saving transcript to: {args.transcript}")
|
||||
transcript_file = open(args.transcript, "w")
|
||||
transcript_file.write(whisper_result["text"])
|
||||
transcript_file.close()
|
||||
|
||||
# Summarize the generated transcript using OpenAI
|
||||
openai.api_key = OPENAI_APIKEY
|
||||
|
||||
# Break the text up into smaller chunks for ChatGPT to summarize.
|
||||
logger.info(f"Breaking transcript up into smaller chunks with MAX_WORDS_IN_CHUNK = {MAX_WORDS_IN_CHUNK}")
|
||||
chunks = chunk_text(whisper_result['text'])
|
||||
logger.info(f"Transcript broken up into {len(chunks)} chunks")
|
||||
|
||||
language = args.language
|
||||
|
||||
logger.info(f"Writing summary text in {language} to: {args.output}")
|
||||
with open(args.output, 'w') as f:
|
||||
f.write('Summary of: ' + args.location + "\n\n")
|
||||
|
||||
for c in chunks:
|
||||
response = openai.ChatCompletion.create(
|
||||
frequency_penalty=0.0,
|
||||
max_tokens=1000,
|
||||
model="gpt-3.5-turbo",
|
||||
presence_penalty=1.0,
|
||||
temperature=0.2,
|
||||
messages=[
|
||||
{"role": "system",
|
||||
"content": f"You are an assistant helping to summarize transcipts of an audio or video conversation. The summary should be written in the {language} language."},
|
||||
{"role": "user", "content": c}
|
||||
],
|
||||
)
|
||||
f.write(response['choices'][0]['message']['content'] + "\n\n")
|
||||
|
||||
logger.info("Summarization completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# os.environ['KMP_DUPLICATE_LIB_OK'] = "1"
|
||||
print("Gokul", os.environ['KMP_DUPLICATE_LIB_OK'])
|
||||
main()
|
||||
Reference in New Issue
Block a user