diff --git a/server/gpu/modal/reflector_transcriber.py b/server/gpu/modal/reflector_transcriber.py
index 631233cc..55df052b 100644
--- a/server/gpu/modal/reflector_transcriber.py
+++ b/server/gpu/modal/reflector_transcriber.py
@@ -3,11 +3,11 @@ Reflector GPU backend - transcriber
 ===================================
 """

-import tempfile
 import os
-from modal import Image, method, Stub, asgi_app, Secret
-from pydantic import BaseModel
+import tempfile
+
+from modal import Image, Secret, Stub, asgi_app, method
+from pydantic import BaseModel

 # Whisper
 WHISPER_MODEL: str = "large-v2"
@@ -15,6 +15,9 @@ WHISPER_COMPUTE_TYPE: str = "float16"
 WHISPER_NUM_WORKERS: int = 1
 WHISPER_CACHE_DIR: str = "/cache/whisper"

+# Translation Model
+TRANSLATION_MODEL = "facebook/m2m100_418M"
+
 stub = Stub(name="reflector-transcriber")
@@ -31,6 +34,9 @@ whisper_image = (
         "faster-whisper",
         "requests",
         "torch",
+        "transformers",
+        "sentencepiece",
+        "protobuf",
     )
     .run_function(download_whisper)
     .env(
@@ -51,17 +57,21 @@ whisper_image = (
 )
 class Whisper:
     def __enter__(self):
-        import torch
         import faster_whisper
+        import torch
+        from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

         self.use_gpu = torch.cuda.is_available()
-        device = "cuda" if self.use_gpu else "cpu"
+        self.device = "cuda" if self.use_gpu else "cpu"
         self.model = faster_whisper.WhisperModel(
             WHISPER_MODEL,
-            device=device,
+            device=self.device,
             compute_type=WHISPER_COMPUTE_TYPE,
             num_workers=WHISPER_NUM_WORKERS,
         )
+        self.translation_model = M2M100ForConditionalGeneration.from_pretrained(TRANSLATION_MODEL).to(self.device)
+        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL)
+

     @method()
     def warmup(self):
@@ -72,28 +82,30 @@ class Whisper:
     def transcribe_segment(
         self,
         audio_data: str,
         audio_suffix: str,
-        timestamp: float = 0,
-        language: str = "en",
+        source_language: str,
+        target_language: str,
+        timestamp: float = 0
     ):
         with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
             fp.write(audio_data)

             segments, _ = self.model.transcribe(
                 fp.name,
-                language=language,
+                language=source_language,
                 beam_size=5,
                 word_timestamps=True,
                 vad_filter=True,
                 vad_parameters={"min_silence_duration_ms": 500},
             )

-            transcript = ""
+            multilingual_transcript = {}
+            transcript_source_lang = ""
             words = []
             if segments:
                 segments = list(segments)
                 for segment in segments:
-                    transcript += segment.text
+                    transcript_source_lang += segment.text
                     for word in segment.words:
                         words.append(
                             {
@@ -102,9 +114,24 @@ class Whisper:
                                 "end": round(timestamp + word.end, 3),
                             }
                         )
+
+            multilingual_transcript[source_language] = transcript_source_lang
+
+            if target_language != source_language:
+                self.translation_tokenizer.src_lang = source_language
+                forced_bos_token_id = self.translation_tokenizer.get_lang_id(target_language)
+                encoded_transcript = self.translation_tokenizer(transcript_source_lang, return_tensors="pt").to(self.device)
+                generated_tokens = self.translation_model.generate(
+                    **encoded_transcript,
+                    forced_bos_token_id=forced_bos_token_id
+                )
+                result = self.translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+                translation = result[0].strip()
+                multilingual_transcript[target_language] = translation
+
             return {
-                "text": transcript,
-                "words": words,
+                "text": multilingual_transcript,
+                "words": words
             }
@@ -122,7 +149,7 @@ class Whisper:
 )
 @asgi_app()
 def web():
-    from fastapi import FastAPI, UploadFile, Form, Depends, HTTPException, status
+    from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile, status
     from fastapi.security import OAuth2PasswordBearer
     from typing_extensions import Annotated
@@ -131,6 +158,7 @@ def web():
     app = FastAPI()
     oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+    supported_audio_file_types = ["wav", "mp3", "ogg", "flac"]

     def apikey_auth(apikey: str = Depends(oauth2_scheme)):
         if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
@@ -140,28 +168,26 @@ def web():
                 headers={"WWW-Authenticate": "Bearer"},
             )

-    class TranscriptionRequest(BaseModel):
-        timestamp: float = 0
-        language: str = "en"
-
     class TranscriptResponse(BaseModel):
-        result: str
+        result: dict

     @app.post("/transcribe", dependencies=[Depends(apikey_auth)])
     async def transcribe(
         file: UploadFile,
         timestamp: Annotated[float, Form()] = 0,
-        language: Annotated[str, Form()] = "en",
-    ):
+        source_language: Annotated[str, Form()] = "en",
+        target_language: Annotated[str, Form()] = "en"
+    ) -> TranscriptResponse:
         audio_data = await file.read()
         audio_suffix = file.filename.split(".")[-1]
-        assert audio_suffix in ["wav", "mp3", "ogg", "flac"]
+        assert audio_suffix in supported_audio_file_types

         func = transcriberstub.transcribe_segment.spawn(
             audio_data=audio_data,
             audio_suffix=audio_suffix,
-            language=language,
-            timestamp=timestamp,
+            source_language=source_language,
+            target_language=target_language,
+            timestamp=timestamp
         )
         result = func.get()
         return result
diff --git a/server/reflector/processors/audio_transcript_modal.py b/server/reflector/processors/audio_transcript_modal.py
index 1ed727d6..80b6e582 100644
--- a/server/reflector/processors/audio_transcript_modal.py
+++ b/server/reflector/processors/audio_transcript_modal.py
@@ -5,19 +5,22 @@ API will be a POST request to TRANSCRIPT_URL:

 ```form
 "timestamp": 123.456
-"language": "en"
+"source_language": "en"
+"target_language": "en"
 "file":
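
For context, a minimal client-side sketch of the updated /transcribe contract. The deployment URL, API key value, and file name are placeholders (assumptions, not part of the diff); the bearer-token auth and the form fields (timestamp, source_language, target_language, plus the file upload) mirror the FastAPI endpoint above.

# Hypothetical client for the updated /transcribe endpoint.
import requests

TRANSCRIBE_URL = "https://<your-modal-app>.modal.run/transcribe"  # placeholder URL
API_KEY = "..."  # must match REFLECTOR_GPU_APIKEY on the server

with open("segment.wav", "rb") as f:
    response = requests.post(
        TRANSCRIBE_URL,
        headers={"Authorization": f"Bearer {API_KEY}"},
        files={"file": ("segment.wav", f, "audio/wav")},
        data={
            "timestamp": 0,
            "source_language": "fr",  # language spoken in the audio
            "target_language": "en",  # language to translate the transcript into
        },
    )
response.raise_for_status()
# Per transcribe_segment above, "text" maps language codes to transcripts
# and "words" carries per-word timings.
print(response.json())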