Whisper JAX pipeline for demo

gokul
2023-06-09 11:17:32 +05:30
parent 8336296ab4
commit 39656d680c
5 changed files with 248 additions and 2 deletions

.gitignore vendored

@@ -158,3 +158,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
*.mp4
*.txt

README.md

@@ -1,2 +1,43 @@
# transcription
Transcription experiments
# Reflector
This is the code base for the Reflector demo (formerly called agenda-talk-diff), built for the leads of the Troy Web Consulting panel "A Chat with AWS about AI: Real AI/ML AWS projects and what you should know" on 6/14 at 4:30 PM.
The target deliverable is a local-first live transcription and visualization tool that compares a discussion's target agenda/objectives against the actual discussion in real time.
To set up:
1) Check the values in the config.ini file. In particular, add your OPENAI_APIKEY.
2) Run ``` export KMP_DUPLICATE_LIB_OK=True ``` in the terminal. [This is also set in code, but that setting is not taking effect yet; see the note after these steps. This will be fixed later.]
3) Run the script setup_dependencies.sh.
``` chmod +x setup_dependencies.sh ```
``` sh setup_dependencies.sh <ENV> ```
ENV refers to the intended environment for JAX, which is available in several variants: [CPU | GPU | Colab TPU | Google Cloud TPU].
```ENV``` is one of:
cpu -> JAX CPU installation
cuda11 -> JAX CUDA 11.x installation
cuda12 -> JAX CUDA 12.x installation (CoreWeave provides CUDA 12; you can check with ```nvidia-smi```)
For example: ``` sh setup_dependencies.sh cuda12 ```
4) Run the Whisper-JAX pipeline. Currently, the repo takes a YouTube video and transcribes/summarizes it.
``` python3 whisjax.py "https://www.youtube.com/watch?v=ihf0S97oxuQ" --transcript transcript.txt summary.txt ```
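
Note on step 2: a likely reason the in-code setting of KMP_DUPLICATE_LIB_OK does not take effect is that the variable has to be present in the environment before the libraries that initialize OpenMP are imported. A minimal sketch of that ordering (the config keys match config.ini; the import order is the point, everything else is illustrative):

```python
# Read config.ini and set KMP_DUPLICATE_LIB_OK *before* importing anything that loads OpenMP.
import configparser
import os

config = configparser.ConfigParser()
config.read("config.ini")
os.environ["KMP_DUPLICATE_LIB_OK"] = config["DEFAULT"]["KMP_DUPLICATE_LIB_OK"]

# Only now import the heavy libraries (jax, whisper_jax, moviepy, ...).
import jax.numpy as jnp  # noqa: E402
from whisper_jax import FlaxWhisperPipline  # noqa: E402
```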
NEXT STEPS:
1) Run this demo on a local Mac M1 to test the flow and observe the performance.
2) Create a pipeline that listens to microphone audio in chunks and performs transcription in real time (and summarizes it efficiently as well); a rough sketch follows this list.
3) Create a RunPod setup for this feature (mentioned in 1 & 2) and test it end-to-end.
4) Perform speaker diarization using Whisper-JAX.
5) Based on the feasibility of the above points, explore suitable visualizations for transcription & summarization.
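
A rough sketch of next step 2, assuming the ```sounddevice``` package for microphone capture and that the pipeline accepts the Hugging Face-style ```{"array", "sampling_rate"}``` input; the chunk length, sample rate, and model size below are placeholders, not part of the current code:

```python
# Illustrative only: capture fixed-length microphone chunks and transcribe each one.
import jax.numpy as jnp
import sounddevice as sd
from whisper_jax import FlaxWhisperPipline

SAMPLE_RATE = 16000   # Whisper models expect 16 kHz audio
CHUNK_SECONDS = 10    # length of each captured chunk (placeholder)

pipeline = FlaxWhisperPipline("openai/whisper-tiny", dtype=jnp.float16)

while True:
    # Record one chunk from the default microphone (blocks until the chunk is full).
    audio = sd.rec(int(CHUNK_SECONDS * SAMPLE_RATE), samplerate=SAMPLE_RATE,
                   channels=1, dtype="float32")
    sd.wait()
    # Hand the raw samples and sampling rate to the pipeline.
    result = pipeline({"array": audio.flatten(), "sampling_rate": SAMPLE_RATE})
    print(result["text"])
```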

config.ini Normal file

@@ -0,0 +1,7 @@
[DEFAULT]
# Set exception rule for OpenMP error to allow duplicate lib initialization
KMP_DUPLICATE_LIB_OK=TRUE
# OpenAI API key
OPENAI_APIKEY=API_KEY
# Whisper model size
WHISPER_MODEL_SIZE=tiny

setup_dependencies.sh Executable file

@@ -0,0 +1,26 @@
#!/bin/sh
# Upgrade pip
pip install --upgrade pip
# Default to CPU Installation of JAX
jax_mode="jax[cpu]"
# Install JAX
if [ "$1" == "cpu" ]
then
jax_mode="jax[cpu]"
elif [ "$1" == "cuda11" ]
then
jax_mode="jax[cuda11_pip]"
elif [ "$1" == "cuda12" ]
then
jax_mode="jax[cuda12_pip]"
fi
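# NOTE: per the JAX install docs, the CUDA variants may also need the extra find-links URL, e.g.:
#   pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html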
pip install --upgrade "$jax_mode"
# Install Whisper-JAX base
pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
# Update to latest version
pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git
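# Optional sanity check (illustrative, not part of the original setup): confirm which devices JAX sees.
python3 -c "import jax; print(jax.devices())"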

whisjax.py Normal file

@@ -0,0 +1,169 @@
#!/usr/bin/env python3
# Usage examples:
#   whisjax.py https://www.youtube.com/watch?v=imzTxoEDH_g --transcript=transcript.txt summary.txt
#   whisjax.py https://www.sprocket.org/video/cheesemaking.mp4 summary.txt
#   whisjax.py podcast.mp3 summary.txt
from urllib.parse import urlparse
from pytube import YouTube
from loguru import logger
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
import moviepy.editor
import argparse
import tempfile
import whisper
import openai
import re
import configparser
import os
config = configparser.ConfigParser()
config.read('config.ini')
WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
OPENAI_APIKEY = config['DEFAULT']["OPENAI_APIKEY"]
MAX_WORDS_IN_CHUNK = 2500
MAX_OUTPUT_TOKENS = 1000
def init_argparse() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
usage="%(prog)s [OPTIONS] <LOCATION> <OUTPUT>",
description="Creates a transcript of a video or audio file, then summarizes it using ChatGPT."
)
parser.add_argument("-l", "--language", help="Language that the summary should be written in", type=str,
default="english", choices=['english', 'spanish', 'french', 'german', 'romanian'])
parser.add_argument("-t", "--transcript", help="Save a copy of the intermediary transcript file", type=str)
parser.add_argument("location")
parser.add_argument("output")
return parser
def chunk_text(txt):
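    """Split txt into chunks of at most MAX_WORDS_IN_CHUNK words, breaking on sentence boundaries."""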
sentences = re.split('[.!?]', txt)
chunks = []
chunk = ""
size = 0
for s in sentences:
# Get the number of words in this sentence.
n = len(re.findall(r'\w+', s))
# Skip over empty sentences.
if n == 0:
continue
# We need to break the text up into chunks so as not to exceed the max
# number of tokens accepted by the ChatGPT model.
if size + n > MAX_WORDS_IN_CHUNK:
chunks.append(chunk)
size = n
chunk = s
else:
chunk = chunk + s
size = size + n
if chunk:
chunks.append(chunk)
return chunks
def main():
parser = init_argparse()
args = parser.parse_args()
# Parse the location string that was given to us, and figure out if it's a
# local file (audio or video), a YouTube URL, or a URL referencing an
# audio or video file.
url = urlparse(args.location)
media_file = ""
if url.scheme == 'http' or url.scheme == 'https':
        # Check if we're being asked to retrieve a YouTube URL, which is handled
        # differently: the video is downloaded first (via pytube).
if re.search('youtube.com', url.netloc, re.IGNORECASE):
# Download the lowest resolution YouTube video (since we're just interested in the audio).
# It will be saved to the current directory.
logger.info("Downloading YouTube video at url: " + args.location)
youtube = YouTube(args.location)
media_file = youtube.streams.filter(progressive=True, file_extension='mp4').order_by(
'resolution').asc().first().download()
logger.info("Saved downloaded YouTube video to: " + media_file)
else:
# XXX - Download file using urllib, check if file is audio/video using python-magic
logger.info(f"Downloading file at url: {args.location}")
logger.info(" XXX - This method hasn't been implemented yet.")
elif url.scheme == '':
media_file = url.path
else:
print("Unsupported URL scheme: " + url.scheme)
quit()
# If the media file we just retrieved is a video, extract its audio stream.
# XXX - We should be checking if we've downloaded an audio file (eg .mp3),
# XXX - in which case we can skip this step. For now we'll assume that
# XXX - everything is an mp4 video.
audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
logger.info(f"Extracting audio to: {audio_filename}")
video = moviepy.editor.VideoFileClip(media_file)
video.audio.write_audiofile(audio_filename, logger=None)
logger.info("Finished extracting audio")
# Convert the audio to text using the OpenAI Whisper model
pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE, dtype=jnp.float16, batch_size=16)
whisper_result = pipeline(audio_filename, return_timestamps=True)
logger.info("Finished transcribing file")
# If we got the transcript parameter on the command line, save the transcript to the specified file.
if args.transcript:
logger.info(f"Saving transcript to: {args.transcript}")
        with open(args.transcript, "w") as transcript_file:
            transcript_file.write(whisper_result["text"])
# Summarize the generated transcript using OpenAI
openai.api_key = OPENAI_APIKEY
# Break the text up into smaller chunks for ChatGPT to summarize.
logger.info(f"Breaking transcript up into smaller chunks with MAX_WORDS_IN_CHUNK = {MAX_WORDS_IN_CHUNK}")
chunks = chunk_text(whisper_result['text'])
logger.info(f"Transcript broken up into {len(chunks)} chunks")
language = args.language
logger.info(f"Writing summary text in {language} to: {args.output}")
with open(args.output, 'w') as f:
f.write('Summary of: ' + args.location + "\n\n")
for c in chunks:
response = openai.ChatCompletion.create(
frequency_penalty=0.0,
            max_tokens=MAX_OUTPUT_TOKENS,
model="gpt-3.5-turbo",
presence_penalty=1.0,
temperature=0.2,
messages=[
{"role": "system",
"content": f"You are an assistant helping to summarize transcipts of an audio or video conversation. The summary should be written in the {language} language."},
{"role": "user", "content": c}
],
)
f.write(response['choices'][0]['message']['content'] + "\n\n")
logger.info("Summarization completed")
if __name__ == "__main__":
# os.environ['KMP_DUPLICATE_LIB_OK'] = "1"
print("Gokul", os.environ['KMP_DUPLICATE_LIB_OK'])
main()