From 5b0883730f05382bb561b18e7d9bee4deaba19b4 Mon Sep 17 00:00:00 2001
From: Gokul Mohanarangan
Date: Mon, 21 Aug 2023 11:46:28 +0530
Subject: [PATCH] translation update

---
 server/gpu/modal/reflector_transcriber.py    |  80 ++++++++----
 .../processors/audio_transcript_modal.py     |  34 ++++-
 server/reflector/processors/types.py         | 121 +++++++++++++++++-
 3 files changed, 202 insertions(+), 33 deletions(-)

diff --git a/server/gpu/modal/reflector_transcriber.py b/server/gpu/modal/reflector_transcriber.py
index 631233cc..f2db9225 100644
--- a/server/gpu/modal/reflector_transcriber.py
+++ b/server/gpu/modal/reflector_transcriber.py
@@ -3,19 +3,22 @@
 Reflector GPU backend - transcriber
 ===================================
 """
-import tempfile
 import os
-from modal import Image, method, Stub, asgi_app, Secret
+import tempfile
+
+from modal import Image, Secret, Stub, asgi_app, method
 from pydantic import BaseModel
 
-
 # Whisper
-WHISPER_MODEL: str = "large-v2"
+WHISPER_MODEL: str = "tiny"
 WHISPER_COMPUTE_TYPE: str = "float16"
 WHISPER_NUM_WORKERS: int = 1
 WHISPER_CACHE_DIR: str = "/cache/whisper"
 
-stub = Stub(name="reflector-transcriber")
+# Translation Model
+TRANSLATION_MODEL = "facebook/m2m100_418M"
+
+stub = Stub(name="reflector-translator")
 
 
 def download_whisper():
@@ -31,6 +34,9 @@ whisper_image = (
         "faster-whisper",
         "requests",
         "torch",
+        "transformers",
+        "sentencepiece",
+        "protobuf",
     )
     .run_function(download_whisper)
     .env(
@@ -51,17 +57,21 @@ whisper_image = (
 )
 class Whisper:
     def __enter__(self):
-        import torch
         import faster_whisper
+        import torch
+        from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
 
         self.use_gpu = torch.cuda.is_available()
-        device = "cuda" if self.use_gpu else "cpu"
+        self.device = "cuda" if self.use_gpu else "cpu"
         self.model = faster_whisper.WhisperModel(
             WHISPER_MODEL,
            device=device,
             compute_type=WHISPER_COMPUTE_TYPE,
             num_workers=WHISPER_NUM_WORKERS,
         )
+        self.translation_model = M2M100ForConditionalGeneration.from_pretrained(TRANSLATION_MODEL).to(self.device)
+        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL)
+

     @method()
     def warmup(self):
@@ -73,27 +83,29 @@ class Whisper:
         audio_data: str,
         audio_suffix: str,
         timestamp: float = 0,
-        language: str = "en",
+        source_language: str = "en",
+        target_language: str = "fr"
     ):
         with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
             fp.write(audio_data)

             segments, _ = self.model.transcribe(
                 fp.name,
-                language=language,
+                language=source_language,
                 beam_size=5,
                 word_timestamps=True,
                 vad_filter=True,
                 vad_parameters={"min_silence_duration_ms": 500},
             )

-            transcript = ""
+            multilingual_transcript = {}
+            transcript_en = ""
             words = []

             if segments:
                 segments = list(segments)
                 for segment in segments:
-                    transcript += segment.text
+                    transcript_en += segment.text
                     for word in segment.words:
                         words.append(
                             {
@@ -102,9 +114,23 @@ class Whisper:
                                 "end": round(timestamp + word.end, 3),
                             }
                         )
+
+            multilingual_transcript["en"] = transcript_en
+
+            if target_language != "en":
+                self.translation_tokenizer.src_lang = source_language
+                forced_bos_token_id = self.translation_tokenizer.get_lang_id(target_language)
+                encoded_transcript = self.translation_tokenizer(transcript_en, return_tensors="pt").to(self.device)
+                generated_tokens = self.translation_model.generate(
+                    **encoded_transcript,
+                    forced_bos_token_id=forced_bos_token_id
+                )
+                result = self.translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+                multilingual_transcript[target_language] = result[0].strip()
+
         return {
-            "text": transcript,
-            "words": words,
+            "text": multilingual_transcript,
+            "words": words
         }


@@ -122,7 +148,7 @@ class Whisper:
 )
 @asgi_app()
 def web():
-    from fastapi import FastAPI, UploadFile, Form, Depends, HTTPException, status
+    from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile, status
     from fastapi.security import OAuth2PasswordBearer
     from typing_extensions import Annotated

@@ -131,6 +157,7 @@ def web():

     app = FastAPI()
     oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+    supported_audio_file_types = ["wav", "mp3", "ogg", "flac"]

     def apikey_auth(apikey: str = Depends(oauth2_scheme)):
         if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
@@ -141,27 +168,30 @@ def web():
         )

     class TranscriptionRequest(BaseModel):
         timestamp: float = 0
-        language: str = "en"
+        source_language: str = "en"
+        target_language: str = "en"

     class TranscriptResponse(BaseModel):
-        result: str
+        result: dict

     @app.post("/transcribe", dependencies=[Depends(apikey_auth)])
     async def transcribe(
         file: UploadFile,
         timestamp: Annotated[float, Form()] = 0,
-        language: Annotated[str, Form()] = "en",
+        source_language: Annotated[str, Form()] = "en",
+        target_language: Annotated[str, Form()] = "en",
     ):
         audio_data = await file.read()
         audio_suffix = file.filename.split(".")[-1]
-        assert audio_suffix in ["wav", "mp3", "ogg", "flac"]
+        assert audio_suffix in supported_audio_file_types
         func = transcriberstub.transcribe_segment.spawn(
             audio_data=audio_data,
             audio_suffix=audio_suffix,
-            language=language,
+            source_language=source_language,
+            target_language=target_language,
             timestamp=timestamp,
         )
         result = func.get()
         return result

diff --git a/server/reflector/processors/audio_transcript_modal.py b/server/reflector/processors/audio_transcript_modal.py
index 1ed727d6..5d7a6b85 100644
--- a/server/reflector/processors/audio_transcript_modal.py
+++ b/server/reflector/processors/audio_transcript_modal.py
@@ -11,13 +11,15 @@
 API will be a POST request to TRANSCRIPT_URL:

 """
+from time import monotonic
+
+import httpx
+
 from reflector.processors.audio_transcript import AudioTranscriptProcessor
 from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
-from reflector.processors.types import AudioFile, Transcript, Word
+from reflector.processors.types import AudioFile, Transcript, TranslationLanguages, Word
 from reflector.settings import settings
 from reflector.utils.retry import retry
-from time import monotonic
-import httpx


 class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
@@ -28,6 +30,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
         self.timeout = settings.TRANSCRIPT_TIMEOUT
         self.headers = {
             "Authorization": f"Bearer {modal_api_key}",
+            # "Content-Type": "multipart/form-data"
         }

     async def _warmup(self):
@@ -52,11 +55,28 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
         files = {
             "file": (data.name, data.fd),
         }
+        # TODO: get the source / target languages dynamically from the UI
+        # preferences (e.g. via the context or session objects)
+        source_language = "en"
+        target_language = "fr"
+        languages = TranslationLanguages()
+
+        # The target language should only ever be set from a UI element such as
+        # a dropdown, so this assert should never fail.
+        assert languages.is_supported(target_language)
+        data = {
+            "source_language": source_language,
+            "target_language": target_language,
+        }
+
+        self.logger.debug("Requesting transcript from modal")
+
         response = await retry(client.post)(
             self.transcript_url,
             files=files,
             timeout=self.timeout,
             headers=self.headers,
+            data=data,
         )

         self.logger.debug(
@@ -64,8 +84,14 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
         )
         response.raise_for_status()
         result = response.json()
+
+        # Sanity check: fall back to English if the requested translation is missing
+        if target_language in result["text"]:
+            text = result["text"][target_language]
+        else:
+            text = result["text"]["en"]
         transcript = Transcript(
-            text=result["text"],
+            text=text,
             words=[
                 Word(
                     text=word["text"],
diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py
index 0c7c48d4..1e5c84f2 100644
--- a/server/reflector/processors/types.py
+++ b/server/reflector/processors/types.py
@@ -1,7 +1,8 @@
-from pydantic import BaseModel, PrivateAttr
-from pathlib import Path
-import tempfile
 import io
+import tempfile
+from pathlib import Path
+
+from pydantic import BaseModel, PrivateAttr


 class AudioFile(BaseModel):
@@ -104,3 +105,117 @@ class TitleSummary(BaseModel):
 class FinalSummary(BaseModel):
     summary: str
     duration: float
+
+
+class TranslationLanguages(BaseModel):
+    language_to_id_mapping: dict = {
+        "Afrikaans": "af",
+        "Albanian": "sq",
+        "Amharic": "am",
+        "Arabic": "ar",
+        "Armenian": "hy",
+        "Asturian": "ast",
+        "Azerbaijani": "az",
+        "Bashkir": "ba",
+        "Belarusian": "be",
+        "Bengali": "bn",
+        "Bosnian": "bs",
+        "Breton": "br",
+        "Bulgarian": "bg",
+        "Burmese": "my",
+        "Catalan; Valencian": "ca",
+        "Cebuano": "ceb",
+        "Central Khmer": "km",
+        "Chinese": "zh",
+        "Croatian": "hr",
+        "Czech": "cs",
+        "Danish": "da",
+        "Dutch; Flemish": "nl",
+        "English": "en",
+        "Estonian": "et",
+        "Finnish": "fi",
+        "French": "fr",
+        "Fulah": "ff",
+        "Gaelic; Scottish Gaelic": "gd",
+        "Galician": "gl",
+        "Ganda": "lg",
+        "Georgian": "ka",
+        "German": "de",
+        "Greek": "el",
+        "Gujarati": "gu",
+        "Haitian; Haitian Creole": "ht",
+        "Hausa": "ha",
+        "Hebrew": "he",
+        "Hindi": "hi",
+        "Hungarian": "hu",
+        "Icelandic": "is",
+        "Igbo": "ig",
+        "Iloko": "ilo",
+        "Indonesian": "id",
+        "Irish": "ga",
+        "Italian": "it",
+        "Japanese": "ja",
+        "Javanese": "jv",
+        "Kannada": "kn",
+        "Kazakh": "kk",
+        "Korean": "ko",
+        "Lao": "lo",
+        "Latvian": "lv",
+        "Lingala": "ln",
+        "Lithuanian": "lt",
+        "Luxembourgish; Letzeburgesch": "lb",
+        "Macedonian": "mk",
+        "Malagasy": "mg",
+        "Malay": "ms",
+        "Malayalam": "ml",
+        "Marathi": "mr",
+        "Mongolian": "mn",
+        "Nepali": "ne",
+        "Northern Sotho": "ns",
+        "Norwegian": "no",
+        "Occitan": "oc",
+        "Oriya": "or",
+        "Panjabi; Punjabi": "pa",
+        "Persian": "fa",
+        "Polish": "pl",
+        "Portuguese": "pt",
+        "Pushto; Pashto": "ps",
+        "Romanian; Moldavian; Moldovan": "ro",
+        "Russian": "ru",
+        "Serbian": "sr",
+        "Sindhi": "sd",
+        "Sinhala; Sinhalese": "si",
+        "Slovak": "sk",
+        "Slovenian": "sl",
+        "Somali": "so",
+        "Spanish": "es",
+        "Sundanese": "su",
+        "Swahili": "sw",
+        "Swati": "ss",
+        "Swedish": "sv",
+        "Tagalog": "tl",
+        "Tamil": "ta",
+        "Thai": "th",
+        "Tswana": "tn",
+        "Turkish": "tr",
+        "Ukrainian": "uk",
+        "Urdu": "ur",
+        "Uzbek": "uz",
+        "Vietnamese": "vi",
+        "Welsh": "cy",
+        "Western Frisian": "fy",
+        "Wolof": "wo",
+        "Xhosa": "xh",
+        "Yiddish": "yi",
+        "Yoruba": "yo",
+        "Zulu": "zu",
+    }
+
+    @property
+    def supported_languages(self):
+        return self.language_to_id_mapping.values()
+
+    def is_supported(self, lang_id: str) -> bool:
+        if lang_id in self.supported_languages:
+            return True
+        return False
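
A minimal standalone sketch of the M2M100 translation path used in Whisper.transcribe_segment above, for reference when testing outside Modal. The model id mirrors TRANSLATION_MODEL; the sample sentence and the French target are illustrative only:

    # Sketch only: exercises the same tokenizer/model calls as the patch above.
    from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

    model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
    tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

    transcript_en = "Hello, and welcome to the weekly sync."  # stand-in for the Whisper output
    tokenizer.src_lang = "en"
    encoded = tokenizer(transcript_en, return_tensors="pt")
    generated = model.generate(
        **encoded,
        # Force the decoder to start in the target language.
        forced_bos_token_id=tokenizer.get_lang_id("fr"),
    )
    print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])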
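And a rough sketch of how a client could call the updated /transcribe endpoint once the patch is applied. The deployment URL, API key value and audio file name are placeholders; the form fields, Bearer auth and response shape follow the endpoint and AudioTranscriptModalProcessor above:

    import httpx

    TRANSCRIPT_URL = "https://<your-modal-deployment>/transcribe"  # placeholder

    with open("segment.wav", "rb") as fp:
        response = httpx.post(
            TRANSCRIPT_URL,
            headers={"Authorization": "Bearer <REFLECTOR_GPU_APIKEY>"},
            files={"file": ("segment.wav", fp)},
            data={
                "timestamp": 0,
                "source_language": "en",
                "target_language": "fr",
            },
            timeout=120,
        )
    response.raise_for_status()
    result = response.json()
    # result["text"] is keyed by language code, e.g. {"en": "...", "fr": "..."}
    # result["words"] holds per-word timings for the English transcript.
    print(result["text"].get("fr", result["text"]["en"]))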