translation update

Gokul Mohanarangan
2023-08-21 11:46:28 +05:30
parent 218bb9c91f
commit 5b0883730f
3 changed files with 202 additions and 33 deletions


@@ -3,19 +3,22 @@ Reflector GPU backend - transcriber
 ===================================
 """
-import tempfile
 import os
-from modal import Image, method, Stub, asgi_app, Secret
+import tempfile
+from modal import Image, Secret, Stub, asgi_app, method
 from pydantic import BaseModel
 # Whisper
-WHISPER_MODEL: str = "large-v2"
+WHISPER_MODEL: str = "tiny"
 WHISPER_COMPUTE_TYPE: str = "float16"
 WHISPER_NUM_WORKERS: int = 1
 WHISPER_CACHE_DIR: str = "/cache/whisper"
-stub = Stub(name="reflector-transcriber")
+# Translation Model
+TRANSLATION_MODEL = "facebook/m2m100_418M"
+stub = Stub(name="reflector-translator")
 def download_whisper():
@@ -31,6 +34,9 @@ whisper_image = (
         "faster-whisper",
         "requests",
         "torch",
+        "transformers",
+        "sentencepiece",
+        "protobuf",
     )
     .run_function(download_whisper)
     .env(
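
Review note: the image build gains transformers, sentencepiece, and protobuf, but only the Whisper weights are fetched at build time via .run_function(download_whisper); the M2M100 checkpoint will still be downloaded on first container start. A minimal sketch of a companion build hook (the download_translation_model name is hypothetical, not part of this commit):

def download_translation_model():
    # Hypothetical build-time hook: pull the M2M100 weights into the image's
    # Hugging Face cache so container cold starts skip the download.
    from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

    M2M100ForConditionalGeneration.from_pretrained(TRANSLATION_MODEL)
    M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL)

It could be chained as a second .run_function(download_translation_model) on whisper_image.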
@@ -51,17 +57,21 @@ whisper_image = (
 )
 class Whisper:
     def __enter__(self):
-        import torch
         import faster_whisper
+        import torch
+        from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
         self.use_gpu = torch.cuda.is_available()
-        device = "cuda" if self.use_gpu else "cpu"
+        self.device = "cuda" if self.use_gpu else "cpu"
         self.model = faster_whisper.WhisperModel(
             WHISPER_MODEL,
             device=device,
+            device=self.device,
             compute_type=WHISPER_COMPUTE_TYPE,
             num_workers=WHISPER_NUM_WORKERS,
         )
+        self.translation_model = M2M100ForConditionalGeneration.from_pretrained(TRANSLATION_MODEL).to(self.device)
+        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL)
     @method()
     def warmup(self):
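
Review note: for reference, the M2M100 pattern introduced here, isolated from the Modal plumbing — set src_lang on the tokenizer before encoding, then force the target language as the first generated token. A self-contained sketch against the same facebook/m2m100_418M checkpoint:

from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

tokenizer.src_lang = "en"  # language of the input text
encoded = tokenizer("Hello, how are you?", return_tensors="pt")
tokens = model.generate(
    **encoded,
    # make the decoder start in the target language
    forced_bos_token_id=tokenizer.get_lang_id("fr"),
)
print(tokenizer.batch_decode(tokens, skip_special_tokens=True)[0])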
@@ -73,27 +83,29 @@ class Whisper:
         audio_data: str,
         audio_suffix: str,
         timestamp: float = 0,
-        language: str = "en",
+        source_language: str = "en",
+        target_language: str = "fr"
     ):
         with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
             fp.write(audio_data)
             segments, _ = self.model.transcribe(
                 fp.name,
-                language=language,
+                language=source_language,
                 beam_size=5,
                 word_timestamps=True,
                 vad_filter=True,
                 vad_parameters={"min_silence_duration_ms": 500},
             )
-            transcript = ""
+            multilingual_transcript = {}
+            transcript_en = ""
             words = []
             if segments:
                 segments = list(segments)
                 for segment in segments:
-                    transcript += segment.text
+                    transcript_en += segment.text
                     for word in segment.words:
                         words.append(
                             {
@@ -102,9 +114,23 @@ class Whisper:
                                 "end": round(timestamp + word.end, 3),
                             }
                         )
+            multilingual_transcript["en"] = transcript_en
+            if target_language != "en":
+                self.translation_tokenizer.src_lang = source_language
+                forced_bos_token_id = self.translation_tokenizer.get_lang_id(target_language)
+                encoded_transcript = self.translation_tokenizer(transcript_en, return_tensors="pt").to(self.device)
+                generated_tokens = self.translation_model.generate(
+                    **encoded_transcript,
+                    forced_bos_token_id=forced_bos_token_id
+                )
+                result = self.translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+                multilingual_transcript[target_language] = result[0].strip()
             return {
-                "text": transcript,
-                "words": words,
+                "text": multilingual_transcript,
+                "words": words
             }
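
Review note: the response's "text" field changes from a flat string to a dict keyed by language code, so existing consumers of result["text"] need updating. Illustrative shape for target_language="fr" (values invented):

{
    "text": {
        "en": "Hello everyone, welcome.",
        "fr": "Bonjour à tous, bienvenue.",
    },
    "words": [...],  # per-word timings built in the loop above
}

Also worth flagging: the whole transcript is translated in a single generate() call, and M2M100's encoder is limited to 1024 positions, so long transcripts would need chunking.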
@@ -122,7 +148,7 @@ class Whisper:
 )
 @asgi_app()
 def web():
-    from fastapi import FastAPI, UploadFile, Form, Depends, HTTPException, status
+    from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile, status
     from fastapi.security import OAuth2PasswordBearer
     from typing_extensions import Annotated
@@ -131,6 +157,7 @@ def web():
     app = FastAPI()
     oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+    supported_audio_file_types = ["wav", "mp3", "ogg", "flac"]
     def apikey_auth(apikey: str = Depends(oauth2_scheme)):
         if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
@@ -141,27 +168,28 @@ def web():
         )
     class TranscriptionRequest(BaseModel):
-        timestamp: float = 0
-        language: str = "en"
+        file: UploadFile
+        timestamp: Annotated[float, Form()] = 0
+        source_language: Annotated[str, Form()] = "en"
+        target_language: Annotated[str, Form()] = "en"
     class TranscriptResponse(BaseModel):
-        result: str
+        result: dict
     @app.post("/transcribe", dependencies=[Depends(apikey_auth)])
     async def transcribe(
-        file: UploadFile,
-        timestamp: Annotated[float, Form()] = 0,
-        language: Annotated[str, Form()] = "en",
+        req
     ):
-        audio_data = await file.read()
-        audio_suffix = file.filename.split(".")[-1]
-        assert audio_suffix in ["wav", "mp3", "ogg", "flac"]
+        print(req)
+        audio_data = await req.file.read()
+        audio_suffix = req.file.filename.split(".")[-1]
+        assert audio_suffix in supported_audio_file_types
         func = transcriberstub.transcribe_segment.spawn(
             audio_data=audio_data,
             audio_suffix=audio_suffix,
-            language=language,
-            timestamp=timestamp,
+            source_language="en",
+            timestamp=req.timestamp
         )
        result = func.get()
        return result
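
Review note: as committed, the handler hard-codes source_language="en" and never forwards req.source_language or req.target_language to transcribe_segment, so the worker always falls back to its own default target ("fr") regardless of what the client sends; the bare req parameter, as shown, also gives FastAPI nothing to bind the TranscriptionRequest form fields to (it may be truncated in this view). For completeness, a client-side sketch of the intended call, assuming a hypothetical deployment URL:

import requests

resp = requests.post(
    "https://reflector-translator.example.modal.run/transcribe",  # hypothetical URL
    headers={"Authorization": "Bearer <REFLECTOR_GPU_APIKEY>"},
    files={"file": open("sample.wav", "rb")},
    data={"timestamp": 0, "source_language": "en", "target_language": "fr"},
)
resp.raise_for_status()
print(resp.json()["text"])  # {"en": "...", "fr": "..."}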