translation update

Gokul Mohanarangan
2023-08-21 11:46:28 +05:30
parent 218bb9c91f
commit 5b0883730f
3 changed files with 202 additions and 33 deletions

View File

@@ -3,19 +3,22 @@ Reflector GPU backend - transcriber
 ===================================
 """
-import tempfile
 import os
-from modal import Image, method, Stub, asgi_app, Secret
+import tempfile
+from modal import Image, Secret, Stub, asgi_app, method
 from pydantic import BaseModel
 # Whisper
-WHISPER_MODEL: str = "large-v2"
+WHISPER_MODEL: str = "tiny"
 WHISPER_COMPUTE_TYPE: str = "float16"
 WHISPER_NUM_WORKERS: int = 1
 WHISPER_CACHE_DIR: str = "/cache/whisper"
-stub = Stub(name="reflector-transcriber")
+# Translation Model
+TRANSLATION_MODEL = "facebook/m2m100_418M"
+stub = Stub(name="reflector-translator")
 def download_whisper():
@@ -31,6 +34,9 @@ whisper_image = (
"faster-whisper",
"requests",
"torch",
"transformers",
"sentencepiece",
"protobuf",
)
.run_function(download_whisper)
.env(
@@ -51,17 +57,21 @@ whisper_image = (
 )
 class Whisper:
     def __enter__(self):
-        import torch
         import faster_whisper
+        import torch
+        from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
         self.use_gpu = torch.cuda.is_available()
-        device = "cuda" if self.use_gpu else "cpu"
+        self.device = "cuda" if self.use_gpu else "cpu"
         self.model = faster_whisper.WhisperModel(
             WHISPER_MODEL,
-            device=device,
+            device=self.device,
             compute_type=WHISPER_COMPUTE_TYPE,
             num_workers=WHISPER_NUM_WORKERS,
         )
+        self.translation_model = M2M100ForConditionalGeneration.from_pretrained(TRANSLATION_MODEL).to(self.device)
+        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL)
     @method()
     def warmup(self):
@@ -73,27 +83,29 @@ class Whisper:
         audio_data: str,
         audio_suffix: str,
         timestamp: float = 0,
-        language: str = "en",
+        source_language: str = "en",
+        target_language: str = "fr"
     ):
         with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
             fp.write(audio_data)
             segments, _ = self.model.transcribe(
                 fp.name,
-                language=language,
+                language=source_language,
                 beam_size=5,
                 word_timestamps=True,
                 vad_filter=True,
                 vad_parameters={"min_silence_duration_ms": 500},
             )
-            transcript = ""
+            multilingual_transcript = {}
+            transcript_en = ""
             words = []
             if segments:
                 segments = list(segments)
                 for segment in segments:
-                    transcript += segment.text
+                    transcript_en += segment.text
                     for word in segment.words:
                         words.append(
                             {
@@ -102,9 +114,23 @@ class Whisper:
"end": round(timestamp + word.end, 3),
}
)
multilingual_transcript["en"] = transcript_en
if target_language != "en":
self.translation_tokenizer.src_lang = source_language
forced_bos_token_id = self.translation_tokenizer.get_lang_id(target_language)
encoded_transcript = self.translation_tokenizer(transcript_en, return_tensors="pt").to(self.device)
generated_tokens = self.translation_model.generate(
**encoded_transcript,
forced_bos_token_id=forced_bos_token_id
)
result = self.translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
multilingual_transcript[target_language] = result[0].strip()
return {
"text": transcript,
"words": words,
"text": multilingual_transcript,
"words": words
}
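
The translation block added in this hunk is the standard M2M100 usage pattern from transformers. For reference, a self-contained sketch of the same calls; the model name comes from this commit, while the input string and language pair are illustrative:

from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

tokenizer.src_lang = "en"  # language code of the input text
encoded = tokenizer("Hello everyone", return_tensors="pt")
generated = model.generate(
    **encoded,
    # Force the decoder to start in the target language
    forced_bos_token_id=tokenizer.get_lang_id("fr"),
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])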
@@ -122,7 +148,7 @@ class Whisper:
 )
 @asgi_app()
 def web():
-    from fastapi import FastAPI, UploadFile, Form, Depends, HTTPException, status
+    from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile, status
     from fastapi.security import OAuth2PasswordBearer
     from typing_extensions import Annotated
@@ -131,6 +157,7 @@ def web():
     app = FastAPI()
     oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+    supported_audio_file_types = ["wav", "mp3", "ogg", "flac"]
     def apikey_auth(apikey: str = Depends(oauth2_scheme)):
         if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
@@ -141,27 +168,28 @@ def web():
         )
     class TranscriptionRequest(BaseModel):
-        timestamp: float = 0
-        language: str = "en"
+        file: UploadFile
+        timestamp: Annotated[float, Form()] = 0
+        source_language: Annotated[str, Form()] = "en"
+        target_language: Annotated[str, Form()] = "en"
     class TranscriptResponse(BaseModel):
-        result: str
+        result: dict
     @app.post("/transcribe", dependencies=[Depends(apikey_auth)])
     async def transcribe(
-        file: UploadFile,
-        timestamp: Annotated[float, Form()] = 0,
-        language: Annotated[str, Form()] = "en",
+        req: TranscriptionRequest = Depends(),
     ):
-        audio_data = await file.read()
-        audio_suffix = file.filename.split(".")[-1]
-        assert audio_suffix in ["wav", "mp3", "ogg", "flac"]
+        print(req)
+        audio_data = await req.file.read()
+        audio_suffix = req.file.filename.split(".")[-1]
+        assert audio_suffix in supported_audio_file_types
         func = transcriberstub.transcribe_segment.spawn(
             audio_data=audio_data,
             audio_suffix=audio_suffix,
-            language=language,
-            timestamp=timestamp,
+            source_language="en",
+            timestamp=req.timestamp
         )
         result = func.get()
         return result
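
With this change, /transcribe takes multipart form fields instead of bare query parameters, and "text" in the response becomes a dict keyed by language code. A minimal client sketch, assuming a deployed endpoint URL and API key (both are placeholders, not values from this repo):

import httpx

TRANSCRIBE_URL = "https://<your-modal-app>.modal.run/transcribe"  # placeholder
API_KEY = "<REFLECTOR_GPU_APIKEY value>"  # placeholder

with open("sample.wav", "rb") as f:
    response = httpx.post(
        TRANSCRIBE_URL,
        headers={"Authorization": f"Bearer {API_KEY}"},
        files={"file": ("sample.wav", f)},
        data={"timestamp": 0, "source_language": "en", "target_language": "fr"},
        timeout=120,
    )
response.raise_for_status()
result = response.json()
print(result["text"]["en"])       # always present
print(result["text"].get("fr"))   # present when target_language != "en"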

View File

@@ -11,13 +11,15 @@ API will be a POST request to TRANSCRIPT_URL:
"""
from time import monotonic
import httpx
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
from reflector.processors.types import AudioFile, Transcript, Word
from reflector.processors.types import AudioFile, Transcript, TranslationLanguages, Word
from reflector.settings import settings
from reflector.utils.retry import retry
from time import monotonic
import httpx
class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
@@ -28,6 +30,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
         self.timeout = settings.TRANSCRIPT_TIMEOUT
         self.headers = {
             "Authorization": f"Bearer {modal_api_key}",
+            # "Content-Type": "multipart/form-data"
         }
     async def _warmup(self):
@@ -52,11 +55,28 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
             files = {
                 "file": (data.name, data.fd),
             }
+            # TODO: Get the source / target language dynamically from the UI
+            # preferences, e.g. via context or session objects
+            source_language = "en"
+            target_language = "fr"
+            languages = TranslationLanguages()
+            # The only way to set the target should be a UI element like a
+            # dropdown, hence this assert should never fail.
+            assert languages.is_supported(target_language)
+            data = {
+                "source_language": source_language,
+                "target_language": target_language,
+            }
+            print("TRYING TO TRANSCRIBE")
             response = await retry(client.post)(
                 self.transcript_url,
                 files=files,
                 timeout=self.timeout,
                 headers=self.headers,
+                # data=data
             )
             self.logger.debug(
@@ -64,8 +84,14 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
             )
             response.raise_for_status()
             result = response.json()
+            # Sanity check for translation status in result
+            if "target_language" in result["text"]:
+                text = result["text"]["target_language"]
+            else:
+                text = result["text"]["en"]
             transcript = Transcript(
-                text=result["text"],
+                text=text,
                 words=[
                     Word(
                         text=word["text"],
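
For context, a sketch of the response shape this processor consumes, with illustrative values; note that the transcriber above keys the "text" dict by language code:

# Illustrative response from the transcriber (values are made up)
result = {
    "text": {"en": "Hello everyone", "fr": "Bonjour à tous"},
    "words": [{"text": "Hello", "start": 0.0, "end": 0.42}],
}

target_language = "fr"
# Look up the translation by its language code, falling back to English
text = result["text"].get(target_language, result["text"]["en"])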

View File

@@ -1,7 +1,8 @@
-from pydantic import BaseModel, PrivateAttr
-from pathlib import Path
-import tempfile
 import io
+import tempfile
+from pathlib import Path
+from pydantic import BaseModel, PrivateAttr
 class AudioFile(BaseModel):
@@ -104,3 +105,117 @@ class TitleSummary(BaseModel):
 class FinalSummary(BaseModel):
     summary: str
     duration: float
+class TranslationLanguages(BaseModel):
+    language_to_id_mapping: dict = {
+        "Afrikaans": "af",
+        "Albanian": "sq",
+        "Amharic": "am",
+        "Arabic": "ar",
+        "Armenian": "hy",
+        "Asturian": "ast",
+        "Azerbaijani": "az",
+        "Bashkir": "ba",
+        "Belarusian": "be",
+        "Bengali": "bn",
+        "Bosnian": "bs",
+        "Breton": "br",
+        "Bulgarian": "bg",
+        "Burmese": "my",
+        "Catalan; Valencian": "ca",
+        "Cebuano": "ceb",
+        "Central Khmer": "km",
+        "Chinese": "zh",
+        "Croatian": "hr",
+        "Czech": "cs",
+        "Danish": "da",
+        "Dutch; Flemish": "nl",
+        "English": "en",
+        "Estonian": "et",
+        "Finnish": "fi",
+        "French": "fr",
+        "Fulah": "ff",
+        "Gaelic; Scottish Gaelic": "gd",
+        "Galician": "gl",
+        "Ganda": "lg",
+        "Georgian": "ka",
+        "German": "de",
"Greeek": "el",
"Gujarati": "gu",
"Haitian; Haitian Creole": "ht",
"Hausa": "ha",
"Hebrew": "he",
"Hindi": "hi",
"Hungarian": "hu",
"Icelandic": "is",
"Igbo": "ig",
"Iloko": "ilo",
"Indonesian": "id",
"Irish": "ga",
"Italian": "it",
"Japanese": "ja",
"Javanese": "jv",
"Kannada": "kn",
"Kazakh": "kk",
"Korean": "ko",
"Lao": "lo",
"Latvian": "lv",
"Lingala": "ln",
"Lithuanian": "lt",
"Luxembourgish; Letzeburgesch": "lb",
"Macedonian": "mk",
"Malagasy": "mg",
"Malay": "ms",
"Malayalam": "ml",
"Marathi": "mr",
"Mongolian": "mn",
"Nepali": "ne",
"Northern Sotho": "ns",
"Norwegian": "no",
"Occitan": "oc",
"Oriya": "or",
"Panjabi; Punjabi": "pa",
"Persian": "fa",
"Polish": "pl",
"Portuguese": "pt",
"Pushto; Pashto": "ps",
"Romanian; Moldavian; Moldovan": "ro",
"Russian": "ru",
"Serbian": "sr",
"Sindhi": "sd",
"Sinhala; Sinhalese": "si",
"Slovak": "sk",
"Slovenian": "sl",
"Somali": "so",
"Spanish": "es",
"Sundanese": "su",
"Swahili": "sw",
"Swati": "ss",
"Swedish": "sv",
"Tagalog": "tl",
"Tamil": "ta",
"Thai": "th",
"Tswana": "tn",
"Turkish": "tr",
"Ukrainian": "uk",
"Urdu": "ur",
"Uzbek": "uz",
"Vietnamese": "vi",
"Welsh": "cy",
"Western Frisian": "fy",
"Wolof": "wo",
"Xhosa": "xh",
"Yiddish": "yi",
"Yoruba": "yo",
"Zulu": "zu",
}
@property
def supported_languages(self):
return self.language_to_id_mapping.values()
def is_supported(self, lang_id: str) -> bool:
if lang_id in self.supported_languages:
return True
return False
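
A quick usage sketch of the new TranslationLanguages model, using codes and names from the mapping above:

languages = TranslationLanguages()

languages.is_supported("fr")   # True, French is in the mapping
languages.is_supported("xx")   # False, unknown code

# Display name to m2m100 language code
languages.language_to_id_mapping["French"]  # "fr"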