Merge pull request #177 from Monadical-SAS/french

Live transcription (via translation in 100 languages)
This commit is contained in:
projects-g
2023-08-21 14:17:38 +05:30
committed by GitHub
3 changed files with 200 additions and 35 deletions

View File

@@ -3,11 +3,11 @@ Reflector GPU backend - transcriber
 ===================================
 """
-import tempfile
 import os
-from modal import Image, method, Stub, asgi_app, Secret
+import tempfile
+
+from modal import Image, Secret, Stub, asgi_app, method
 from pydantic import BaseModel
 
 # Whisper
 WHISPER_MODEL: str = "large-v2"
@@ -15,6 +15,9 @@ WHISPER_COMPUTE_TYPE: str = "float16"
 WHISPER_NUM_WORKERS: int = 1
 WHISPER_CACHE_DIR: str = "/cache/whisper"
 
+# Translation Model
+TRANSLATION_MODEL = "facebook/m2m100_418M"
+
 stub = Stub(name="reflector-transcriber")
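For context, M2M100 is a single many-to-many translation model: the tokenizer is pointed at the source language, and generation is forced to begin with the target-language token. A minimal standalone sketch of that pattern with the model named in the diff (the "en" → "fr" pair and sample sentence are illustrative, not from this commit):

```python
# Standalone M2M100 translation sketch; language pair and text are examples.
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

tokenizer.src_lang = "en"  # tell the tokenizer the input language
encoded = tokenizer("Hello, how are you?", return_tensors="pt")
generated = model.generate(
    **encoded,
    # Force the first generated token to be the target-language code.
    forced_bos_token_id=tokenizer.get_lang_id("fr"),
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True)[0])
```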
@@ -31,6 +34,9 @@ whisper_image = (
         "faster-whisper",
         "requests",
         "torch",
+        "transformers",
+        "sentencepiece",
+        "protobuf",
     )
     .run_function(download_whisper)
     .env(
@@ -51,17 +57,21 @@ whisper_image = (
) )
class Whisper: class Whisper:
def __enter__(self): def __enter__(self):
import torch
import faster_whisper import faster_whisper
import torch
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
self.use_gpu = torch.cuda.is_available() self.use_gpu = torch.cuda.is_available()
device = "cuda" if self.use_gpu else "cpu" self.device = "cuda" if self.use_gpu else "cpu"
self.model = faster_whisper.WhisperModel( self.model = faster_whisper.WhisperModel(
WHISPER_MODEL, WHISPER_MODEL,
device=device, device=self.device,
compute_type=WHISPER_COMPUTE_TYPE, compute_type=WHISPER_COMPUTE_TYPE,
num_workers=WHISPER_NUM_WORKERS, num_workers=WHISPER_NUM_WORKERS,
) )
self.translation_model = M2M100ForConditionalGeneration.from_pretrained(TRANSLATION_MODEL).to(self.device)
self.translation_tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL)
@method() @method()
def warmup(self): def warmup(self):
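A note on the lifecycle: because both models load in `__enter__`, Modal initializes them once per container start and reuses them across method calls, rather than reloading per request. A minimal sketch of that pattern (the decorator arguments and the toy model are hypothetical; this commit's actual `@stub.cls(...)` configuration sits outside the hunk):

```python
# Sketch of the Modal class lifecycle pattern, under the Stub/@method API
# this commit targets; the class and "model" here are illustrative only.
from modal import Stub, method

stub = Stub(name="lifecycle-example")


@stub.cls()
class Model:
    def __enter__(self):
        # Expensive setup runs once per container start.
        self.model = lambda text: text.upper()

    @method()
    def predict(self, text: str) -> str:
        # Each remote call reuses the already-loaded model.
        return self.model(text)
```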
@@ -72,28 +82,30 @@ class Whisper:
self, self,
audio_data: str, audio_data: str,
audio_suffix: str, audio_suffix: str,
timestamp: float = 0, source_language: str,
language: str = "en", target_language: str,
timestamp: float = 0
): ):
with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp: with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
fp.write(audio_data) fp.write(audio_data)
segments, _ = self.model.transcribe( segments, _ = self.model.transcribe(
fp.name, fp.name,
language=language, language=source_language,
beam_size=5, beam_size=5,
word_timestamps=True, word_timestamps=True,
vad_filter=True, vad_filter=True,
vad_parameters={"min_silence_duration_ms": 500}, vad_parameters={"min_silence_duration_ms": 500},
) )
transcript = "" multilingual_transcript = {}
transcript_source_lang = ""
words = [] words = []
if segments: if segments:
segments = list(segments) segments = list(segments)
for segment in segments: for segment in segments:
transcript += segment.text transcript_source_lang += segment.text
for word in segment.words: for word in segment.words:
words.append( words.append(
{ {
@@ -102,9 +114,24 @@ class Whisper:
                                 "end": round(timestamp + word.end, 3),
                             }
                         )
+            multilingual_transcript[source_language] = transcript_source_lang
+
+            if target_language != source_language:
+                self.translation_tokenizer.src_lang = source_language
+                forced_bos_token_id = self.translation_tokenizer.get_lang_id(target_language)
+                encoded_transcript = self.translation_tokenizer(transcript_source_lang, return_tensors="pt").to(self.device)
+                generated_tokens = self.translation_model.generate(
+                    **encoded_transcript,
+                    forced_bos_token_id=forced_bos_token_id,
+                )
+                result = self.translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+                translation = result[0].strip()
+                multilingual_transcript[target_language] = translation
+
             return {
-                "text": transcript,
-                "words": words,
+                "text": multilingual_transcript,
+                "words": words,
             }
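With this change, `text` in the response is a dict keyed by language code rather than a bare string. An illustrative payload for `source_language="en"`, `target_language="fr"` (values invented for the example):

```python
# Illustrative response shape; the transcript and offsets are made up.
{
    "text": {
        "en": " Hello everyone, welcome to the meeting.",
        "fr": "Bonjour à tous, bienvenue à la réunion.",
    },
    "words": [
        {"text": " Hello", "start": 0.12, "end": 0.48},
        # ... one entry per word, offsets shifted by `timestamp`
    ],
}
```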
@@ -122,7 +149,7 @@ class Whisper:
 )
 @asgi_app()
 def web():
-    from fastapi import FastAPI, UploadFile, Form, Depends, HTTPException, status
+    from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile, status
     from fastapi.security import OAuth2PasswordBearer
     from typing_extensions import Annotated
@@ -131,6 +158,7 @@ def web():
     app = FastAPI()
     oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+    supported_audio_file_types = ["wav", "mp3", "ogg", "flac"]
 
     def apikey_auth(apikey: str = Depends(oauth2_scheme)):
         if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
@@ -140,28 +168,26 @@ def web():
             headers={"WWW-Authenticate": "Bearer"},
         )
 
-    class TranscriptionRequest(BaseModel):
-        timestamp: float = 0
-        language: str = "en"
-
     class TranscriptResponse(BaseModel):
-        result: str
+        result: dict
 
     @app.post("/transcribe", dependencies=[Depends(apikey_auth)])
    async def transcribe(
         file: UploadFile,
         timestamp: Annotated[float, Form()] = 0,
-        language: Annotated[str, Form()] = "en",
-    ):
+        source_language: Annotated[str, Form()] = "en",
+        target_language: Annotated[str, Form()] = "en",
+    ) -> TranscriptResponse:
         audio_data = await file.read()
         audio_suffix = file.filename.split(".")[-1]
-        assert audio_suffix in ["wav", "mp3", "ogg", "flac"]
+        assert audio_suffix in supported_audio_file_types
         func = transcriberstub.transcribe_segment.spawn(
             audio_data=audio_data,
             audio_suffix=audio_suffix,
-            language=language,
-            timestamp=timestamp,
+            source_language=source_language,
+            target_language=target_language,
+            timestamp=timestamp,
         )
         result = func.get()
         return result
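A sketch of calling the updated endpoint from a client; the URL, token, and audio file are placeholders, and the language fields travel as multipart form data alongside the file:

```python
# Hypothetical client call; substitute the deployed Modal URL and real key.
import httpx

with open("meeting.wav", "rb") as f:
    response = httpx.post(
        "https://example--reflector-transcriber.modal.run/transcribe",
        headers={"Authorization": "Bearer <REFLECTOR_GPU_APIKEY>"},
        files={"file": ("meeting.wav", f)},
        data={
            "timestamp": "0",
            "source_language": "en",
            "target_language": "fr",
        },
    )
response.raise_for_status()
print(response.json()["text"])  # {"en": "...", "fr": "..."}
```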

View File

@@ -5,19 +5,22 @@ API will be a POST request to TRANSCRIPT_URL:
 ```form
 "timestamp": 123.456
-"language": "en"
+"source_language": "en"
+"target_language": "en"
 "file": <audio file>
 ```
 """
+from time import monotonic
+
+import httpx
+
 from reflector.processors.audio_transcript import AudioTranscriptProcessor
 from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
-from reflector.processors.types import AudioFile, Transcript, Word
+from reflector.processors.types import AudioFile, Transcript, TranslationLanguages, Word
 from reflector.settings import settings
 from reflector.utils.retry import retry
-from time import monotonic
-
-import httpx
 
 
 class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
@@ -26,9 +29,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
         self.transcript_url = settings.TRANSCRIPT_URL + "/transcribe"
         self.warmup_url = settings.TRANSCRIPT_URL + "/warmup"
         self.timeout = settings.TRANSCRIPT_TIMEOUT
-        self.headers = {
-            "Authorization": f"Bearer {modal_api_key}",
-        }
+        self.headers = {"Authorization": f"Bearer {modal_api_key}"}
 
     async def _warmup(self):
         try:
@@ -52,11 +53,28 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
             files = {
                 "file": (data.name, data.fd),
             }
+            # TODO: Get the source / target languages from the UI preferences
+            # dynamically once that is possible, e.g. extracted from the
+            # context/session objects.
+            source_language = "en"
+            target_language = "en"
+
+            languages = TranslationLanguages()
+            # The target language can only be set through a UI element such
+            # as a dropdown, so this assert should never fail.
+            assert languages.is_supported(target_language)
+
+            # The endpoint declares these as multipart form fields, so they
+            # must travel as form data next to the file (httpx silently drops
+            # `json=` when `files=` is present).
+            form_data = {
+                "source_language": source_language,
+                "target_language": target_language,
+            }
             response = await retry(client.post)(
                 self.transcript_url,
                 files=files,
                 timeout=self.timeout,
                 headers=self.headers,
+                data=form_data,
             )
             self.logger.debug(
@@ -64,8 +82,14 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
             )
             response.raise_for_status()
             result = response.json()
+
+            # Sanity check: prefer the translation, but fall back to the
+            # source-language transcript if it is missing from the result.
+            if target_language in result["text"]:
+                text = result["text"][target_language]
+            else:
+                text = result["text"][source_language]
+
             transcript = Transcript(
-                text=result["text"],
+                text=text,
                 words=[
                     Word(
                         text=word["text"],
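The fallback reduces to a small pure function; shown standalone against a sample payload (the helper name and values are illustrative, not part of the commit):

```python
def pick_text(result: dict, source_language: str, target_language: str) -> str:
    """Prefer the translated text; fall back to the source-language transcript."""
    if target_language in result["text"]:
        return result["text"][target_language]
    return result["text"][source_language]


payload = {"text": {"en": "Hello everyone."}}  # translation missing
assert pick_text(payload, "en", "fr") == "Hello everyone."
assert pick_text(payload, "en", "en") == "Hello everyone."
```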

View File

@@ -1,7 +1,8 @@
-from pydantic import BaseModel, PrivateAttr
-from pathlib import Path
-import tempfile
 import io
+import tempfile
+from pathlib import Path
+
+from pydantic import BaseModel, PrivateAttr
 
 
 class AudioFile(BaseModel):
@@ -104,3 +105,117 @@ class TitleSummary(BaseModel):
 class FinalSummary(BaseModel):
     summary: str
     duration: float
+
+
+class TranslationLanguages(BaseModel):
+    language_to_id_mapping: dict = {
+        "Afrikaans": "af",
+        "Albanian": "sq",
+        "Amharic": "am",
+        "Arabic": "ar",
+        "Armenian": "hy",
+        "Asturian": "ast",
+        "Azerbaijani": "az",
+        "Bashkir": "ba",
+        "Belarusian": "be",
+        "Bengali": "bn",
+        "Bosnian": "bs",
+        "Breton": "br",
+        "Bulgarian": "bg",
+        "Burmese": "my",
+        "Catalan; Valencian": "ca",
+        "Cebuano": "ceb",
+        "Central Khmer": "km",
+        "Chinese": "zh",
+        "Croatian": "hr",
+        "Czech": "cs",
+        "Danish": "da",
+        "Dutch; Flemish": "nl",
+        "English": "en",
+        "Estonian": "et",
+        "Finnish": "fi",
+        "French": "fr",
+        "Fulah": "ff",
+        "Gaelic; Scottish Gaelic": "gd",
+        "Galician": "gl",
+        "Ganda": "lg",
+        "Georgian": "ka",
+        "German": "de",
+        "Greek": "el",
+        "Gujarati": "gu",
+        "Haitian; Haitian Creole": "ht",
+        "Hausa": "ha",
+        "Hebrew": "he",
+        "Hindi": "hi",
+        "Hungarian": "hu",
+        "Icelandic": "is",
+        "Igbo": "ig",
+        "Iloko": "ilo",
+        "Indonesian": "id",
+        "Irish": "ga",
+        "Italian": "it",
+        "Japanese": "ja",
+        "Javanese": "jv",
+        "Kannada": "kn",
+        "Kazakh": "kk",
+        "Korean": "ko",
+        "Lao": "lo",
+        "Latvian": "lv",
+        "Lingala": "ln",
+        "Lithuanian": "lt",
+        "Luxembourgish; Letzeburgesch": "lb",
+        "Macedonian": "mk",
+        "Malagasy": "mg",
+        "Malay": "ms",
+        "Malayalam": "ml",
+        "Marathi": "mr",
+        "Mongolian": "mn",
+        "Nepali": "ne",
+        "Northern Sotho": "ns",
+        "Norwegian": "no",
+        "Occitan": "oc",
+        "Oriya": "or",
+        "Panjabi; Punjabi": "pa",
+        "Persian": "fa",
+        "Polish": "pl",
+        "Portuguese": "pt",
+        "Pushto; Pashto": "ps",
+        "Romanian; Moldavian; Moldovan": "ro",
+        "Russian": "ru",
+        "Serbian": "sr",
+        "Sindhi": "sd",
+        "Sinhala; Sinhalese": "si",
+        "Slovak": "sk",
+        "Slovenian": "sl",
+        "Somali": "so",
+        "Spanish": "es",
+        "Sundanese": "su",
+        "Swahili": "sw",
+        "Swati": "ss",
+        "Swedish": "sv",
+        "Tagalog": "tl",
+        "Tamil": "ta",
+        "Thai": "th",
+        "Tswana": "tn",
+        "Turkish": "tr",
+        "Ukrainian": "uk",
+        "Urdu": "ur",
+        "Uzbek": "uz",
+        "Vietnamese": "vi",
+        "Welsh": "cy",
+        "Western Frisian": "fy",
+        "Wolof": "wo",
+        "Xhosa": "xh",
+        "Yiddish": "yi",
+        "Yoruba": "yo",
+        "Zulu": "zu",
+    }
+
+    @property
+    def supported_languages(self):
+        return self.language_to_id_mapping.values()
+
+    def is_supported(self, lang_id: str) -> bool:
+        if lang_id in self.supported_languages:
+            return True
+        return False
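A quick usage sketch of the new model (REPL-style, illustrative):

```python
# TranslationLanguages checks use M2M100's short language ids.
languages = TranslationLanguages()

print(languages.is_supported("fr"))  # True: French is in the mapping
print(languages.is_supported("xx"))  # False: not an M2M100 language id
print(len(languages.supported_languages))  # 100, matching the PR title
```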