From 39656d680c23e39fac0bdf26a7c2f3c76945b878 Mon Sep 17 00:00:00 2001
From: gokul
Date: Fri, 9 Jun 2023 11:17:32 +0530
Subject: [PATCH] Whisper JAX pipeline for demo

---
 .gitignore | 3 +
 README.md | 45 ++++++++++-
 config.ini | 7 ++
 setup_dependencies.sh | 26 +++++++
 whisjax.py | 169 ++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 248 insertions(+), 2 deletions(-)
 create mode 100644 config.ini
 create mode 100755 setup_dependencies.sh
 create mode 100644 whisjax.py

diff --git a/.gitignore b/.gitignore
index 68bc17f9..be0a7417 100644
--- a/.gitignore
+++ b/.gitignore
@@ -158,3 +158,6 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+
+*.mp4
+*.txt
\ No newline at end of file
diff --git a/README.md b/README.md
index 3796cb91..f5d00dfd 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,43 @@
-# transcription
-Transcription experiments
+# Reflector
+
+This is the code base for the Reflector demo (formerly called agenda-talk-diff) for the leads: Troy Web Consulting panel (A Chat with AWS about AI: Real AI/ML AWS projects and what you should know) on 6/14 at 4:30 PM.
+
+The target deliverable is a local-first live transcription and visualization tool that compares a discussion's target agenda/objectives to the actual discussion as it happens.
+
+To set up:
+
+1) Check the values in the config.ini file. In particular, add your OPENAI_APIKEY.
+2) Run ```export KMP_DUPLICATE_LIB_OK=True``` in the terminal. [This is meant to be handled in code, but it is not taking effect yet; it will be fixed later.]
+3) Run the script setup_dependencies.sh.
+
+   ```chmod +x setup_dependencies.sh```
+
+   ```sh setup_dependencies.sh ENV```
+
+   ENV refers to the intended environment for JAX; omitting it defaults to the CPU build. JAX is available in several variants: [CPU | GPU | Colab TPU | Google Cloud TPU].
+
+   ```ENV``` is:
+
+   cpu -> JAX CPU installation
+
+   cuda11 -> JAX CUDA 11.x version
+
+   cuda12 -> JAX CUDA 12.x version (CoreWeave runs CUDA 12; check with ```nvidia-smi```)
+
+   Example: ```sh setup_dependencies.sh cuda12```
+
+   (A snippet to verify which JAX backend got installed is included at the end of this README.)
+
+4) Run the Whisper-JAX pipeline. Currently, the script takes a YouTube video and transcribes/summarizes it.
+
+```python3 whisjax.py "https://www.youtube.com/watch?v=ihf0S97oxuQ" --transcript transcript.txt summary.txt```
+
+NEXT STEPS:
+
+1) Run this demo on a local Mac M1 to test the flow and observe the performance.
+2) Create a pipeline that listens to microphone audio chunks and performs transcription in real time (and summarizes it efficiently as well).
+3) Create a RunPod setup for the features mentioned in 1 & 2 and test it end-to-end.
+4) Perform speaker diarization using Whisper-JAX.
+5) Based on the feasibility of the above points, explore suitable visualizations for transcription & summarization.
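+
+NOTE: To sanity-check which JAX backend the setup script installed, a minimal sketch (assuming the install above completed without errors) is:
+
+```python
+import jax
+
+# Lists the devices JAX can see; expect something like [CpuDevice(id=0)] for the
+# cpu install, or GPU devices for the cuda11/cuda12 installs.
+print(jax.devices())
+```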
diff --git a/config.ini b/config.ini
new file mode 100644
index 00000000..ad40ac0b
--- /dev/null
+++ b/config.ini
@@ -0,0 +1,7 @@
+[DEFAULT]
+# Set exception rule for the OpenMP error to allow duplicate lib initialization
+KMP_DUPLICATE_LIB_OK=TRUE
+# OpenAI API key used for summarization
+OPENAI_APIKEY=API_KEY
+# Whisper model size to load
+WHISPER_MODEL_SIZE=tiny
\ No newline at end of file
diff --git a/setup_dependencies.sh b/setup_dependencies.sh
new file mode 100755
index 00000000..fd7d9e32
--- /dev/null
+++ b/setup_dependencies.sh
@@ -0,0 +1,26 @@
+# Upgrade pip
+pip install --upgrade pip
+
+# Default to the CPU installation of JAX
+jax_mode="jax[cpu]"
+
+# Pick the JAX build based on the first argument (cpu | cuda11 | cuda12)
+if [ "$1" = "cpu" ]
+then
+    jax_mode="jax[cpu]"
+elif [ "$1" = "cuda11" ]
+then
+    jax_mode="jax[cuda11_pip]"
+elif [ "$1" = "cuda12" ]
+then
+    jax_mode="jax[cuda12_pip]"
+fi
+
+pip install --upgrade "$jax_mode"
+
+# Install Whisper-JAX base
+pip install git+https://github.com/sanchit-gandhi/whisper-jax.git
+
+# Update to the latest version
+pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git
+
diff --git a/whisjax.py b/whisjax.py
new file mode 100644
index 00000000..2022aa22
--- /dev/null
+++ b/whisjax.py
@@ -0,0 +1,169 @@
+#!/usr/bin/env python3
+
+# Example invocations (note: plain-URL download and direct audio input are not implemented yet; see the XXX notes below):
+#   python3 whisjax.py https://www.youtube.com/watch?v=imzTxoEDH_g --transcript=transcript.txt summary.txt
+#   python3 whisjax.py https://www.sprocket.org/video/cheesemaking.mp4 summary.txt
+#   python3 whisjax.py podcast.mp3 summary.txt
+
+from urllib.parse import urlparse
+from pytube import YouTube
+from loguru import logger
+from whisper_jax import FlaxWhisperPipline
+import jax.numpy as jnp
+import moviepy.editor
+import argparse
+import tempfile
+import openai
+import re
+import configparser
+import os
+
+config = configparser.ConfigParser()
+config.read('config.ini')
+
+WHISPER_MODEL_SIZE = config['DEFAULT']["WHISPER_MODEL_SIZE"]
+OPENAI_APIKEY = config['DEFAULT']["OPENAI_APIKEY"]
+
+MAX_WORDS_IN_CHUNK = 2500
+MAX_OUTPUT_TOKENS = 1000
+
+
+def init_argparse() -> argparse.ArgumentParser:
+    parser = argparse.ArgumentParser(
+        usage="%(prog)s [OPTIONS] location output",
+        description="Creates a transcript of a video or audio file, then summarizes it using ChatGPT."
+    )
+
+    parser.add_argument("-l", "--language", help="Language that the summary should be written in", type=str,
+                        default="english", choices=['english', 'spanish', 'french', 'german', 'romanian'])
+    parser.add_argument("-t", "--transcript", help="Save a copy of the intermediary transcript file", type=str)
+    parser.add_argument("location")
+    parser.add_argument("output")
+
+    return parser
+
+
+def chunk_text(txt):
+    sentences = re.split('[.!?]', txt)
+
+    chunks = []
+    chunk = ""
+    size = 0
+
+    for s in sentences:
+        # Get the number of words in this sentence.
+        n = len(re.findall(r'\w+', s))
+
+        # Skip over empty sentences.
+        if n == 0:
+            continue
+
+        # We need to break the text up into chunks so as not to exceed the max
+        # number of tokens accepted by the ChatGPT model.
+        if size + n > MAX_WORDS_IN_CHUNK:
+            chunks.append(chunk)
+            size = n
+            chunk = s
+        else:
+            chunk = chunk + s
+            size = size + n
+
+    if chunk:
+        chunks.append(chunk)
+
+    return chunks
+
+
+def main():
+    parser = init_argparse()
+    args = parser.parse_args()
+
+    # Parse the location string that was given to us, and figure out if it's a
+    # local file (audio or video), a YouTube URL, or a URL referencing an
+    # audio or video file.
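+    # Three cases are handled below:
+    #   - an http(s) YouTube link is downloaded with pytube,
+    #   - any other http(s) URL is logged as not implemented yet, and
+    #   - a location with no scheme is treated as a path to a local media file.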
+    url = urlparse(args.location)
+
+    media_file = ""
+    if url.scheme == 'http' or url.scheme == 'https':
+        # Check if we're being asked to retrieve a YouTube URL, which is handled
+        # differently: we use pytube to download the video first.
+        if re.search('youtube.com', url.netloc, re.IGNORECASE):
+            # Download the lowest resolution YouTube video (since we're just interested in the audio).
+            # It will be saved to the current directory.
+            logger.info("Downloading YouTube video at url: " + args.location)
+
+            youtube = YouTube(args.location)
+            media_file = youtube.streams.filter(progressive=True, file_extension='mp4').order_by(
+                'resolution').asc().first().download()
+
+            logger.info("Saved downloaded YouTube video to: " + media_file)
+        else:
+            # XXX - Download file using urllib, check if file is audio/video using python-magic
+            logger.info(f"Downloading file at url: {args.location}")
+            logger.info(" XXX - This method hasn't been implemented yet.")
+    elif url.scheme == '':
+        media_file = url.path
+    else:
+        print("Unsupported URL scheme: " + url.scheme)
+        quit()
+
+    # If the media file we just retrieved is a video, extract its audio stream.
+    # XXX - We should be checking if we've downloaded an audio file (eg .mp3),
+    # XXX - in which case we can skip this step. For now we'll assume that
+    # XXX - everything is an mp4 video.
+    audio_filename = tempfile.NamedTemporaryFile(suffix=".mp3", delete=False).name
+    logger.info(f"Extracting audio to: {audio_filename}")
+
+    video = moviepy.editor.VideoFileClip(media_file)
+    video.audio.write_audiofile(audio_filename, logger=None)
+
+    logger.info("Finished extracting audio")
+
+    # Convert the audio to text using the OpenAI Whisper model via the Whisper-JAX pipeline
+    pipeline = FlaxWhisperPipline("openai/whisper-" + WHISPER_MODEL_SIZE, dtype=jnp.float16, batch_size=16)
+    whisper_result = pipeline(audio_filename, return_timestamps=True)
+    logger.info("Finished transcribing file")
+
+    # If we got the transcript parameter on the command line, save the transcript to the specified file.
+    if args.transcript:
+        logger.info(f"Saving transcript to: {args.transcript}")
+        transcript_file = open(args.transcript, "w")
+        transcript_file.write(whisper_result["text"])
+        transcript_file.close()
+
+    # Summarize the generated transcript using OpenAI
+    openai.api_key = OPENAI_APIKEY
+
+    # Break the text up into smaller chunks for ChatGPT to summarize.
+    logger.info(f"Breaking transcript up into smaller chunks with MAX_WORDS_IN_CHUNK = {MAX_WORDS_IN_CHUNK}")
+    chunks = chunk_text(whisper_result['text'])
+    logger.info(f"Transcript broken up into {len(chunks)} chunks")
+
+    language = args.language
+
+    logger.info(f"Writing summary text in {language} to: {args.output}")
+    with open(args.output, 'w') as f:
+        f.write('Summary of: ' + args.location + "\n\n")
+
+        for c in chunks:
+            response = openai.ChatCompletion.create(
+                frequency_penalty=0.0,
+                max_tokens=MAX_OUTPUT_TOKENS,
+                model="gpt-3.5-turbo",
+                presence_penalty=1.0,
+                temperature=0.2,
+                messages=[
+                    {"role": "system",
+                     "content": f"You are an assistant helping to summarize transcripts of an audio or video conversation. The summary should be written in the {language} language."},
+                    {"role": "user", "content": c}
+                ],
+            )
+            f.write(response['choices'][0]['message']['content'] + "\n\n")
+
+    logger.info("Summarization completed")
+
+
+if __name__ == "__main__":
+    # Allow duplicate OpenMP runtime initialization (see KMP_DUPLICATE_LIB_OK in config.ini).
+    # Note: as the README mentions, this may need to be exported before the imports run to take effect.
+    os.environ.setdefault("KMP_DUPLICATE_LIB_OK", config["DEFAULT"]["KMP_DUPLICATE_LIB_OK"])
+    main()
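+
+# Example run (same invocation as the README; the --language flag is optional and defaults to english):
+#   python3 whisjax.py "https://www.youtube.com/watch?v=ihf0S97oxuQ" --transcript transcript.txt summary.txt
+# This downloads the video with pytube, transcribes it with Whisper-JAX, writes the transcript to
+# transcript.txt, and writes the chunked ChatGPT summary to summary.txt.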