From 6a43297309b58b01c082c2da31bf4ffb5db675cc Mon Sep 17 00:00:00 2001
From: projects-g <63178974+projects-g@users.noreply.github.com>
Date: Tue, 26 Sep 2023 19:49:54 +0530
Subject: [PATCH] Translation enhancements (#247)

---
 server/gpu/modal/reflector_transcriber.py     | 203 +++++++++++++-----
 server/reflector/llm/llm_params.py            |   2 +-
 server/reflector/processors/__init__.py       |   1 +
 .../processors/audio_transcript_modal.py      |  23 +-
 .../reflector/processors/transcript_liner.py  |  16 +-
 .../processors/transcript_translator.py       |  88 ++++++++
 server/reflector/tools/process.py             |   2 +
 server/reflector/views/rtc_offer.py           |   4 +-
 server/tests/conftest.py                      |  34 ++-
 server/tests/test_processors_pipeline.py      |  10 +-
 server/tests/test_transcripts_rtc_ws.py       |  46 +---
 11 files changed, 303 insertions(+), 126 deletions(-)
 create mode 100644 server/reflector/processors/transcript_translator.py

diff --git a/server/gpu/modal/reflector_transcriber.py b/server/gpu/modal/reflector_transcriber.py
index b662e05a..2059b10e 100644
--- a/server/gpu/modal/reflector_transcriber.py
+++ b/server/gpu/modal/reflector_transcriber.py
@@ -14,40 +14,52 @@
 WHISPER_MODEL: str = "large-v2"
 WHISPER_COMPUTE_TYPE: str = "float16"
 WHISPER_NUM_WORKERS: int = 1
 
-# Translation Model
-TRANSLATION_MODEL = "facebook/m2m100_1.2B"
+# Seamless M4T
+SEAMLESSM4T_MODEL_SIZE: str = "medium"
+SEAMLESSM4T_MODEL_CARD_NAME: str = f"seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}"
+SEAMLESSM4T_VOCODER_CARD_NAME: str = "vocoder_36langs"
 
-IMAGE_MODEL_DIR = f"/root/transcription_models/{TRANSLATION_MODEL}"
+HF_SEAMLESS_M4T_REPO: str = f"facebook/seamless-m4t-{SEAMLESSM4T_MODEL_SIZE}"
+HF_SEAMLESS_M4T_VOCODER_REPO: str = "facebook/seamless-m4t-vocoder"
+
+SEAMLESS_GIT_REPO: str = "https://github.com/facebookresearch/seamless_communication.git"
+SEAMLESS_MODEL_DIR: str = "m4t"
+
+WHISPER_MODEL_DIR = "/root/transcription_models"
 
 stub = Stub(name="reflector-transcriber")
 
 
-def download_whisper(cache_dir: str | None = None):
+def install_seamless_communication():
+    import os
+    import subprocess
+
+    initial_dir = os.getcwd()
+    # run through a shell so that ">>" and "~" are interpreted; a plain
+    # argument list would pass them to ssh-keyscan verbatim
+    subprocess.run(
+        "ssh-keyscan -t rsa github.com >> ~/.ssh/known_hosts", shell=True
+    )
+    subprocess.run(["rm", "-rf", "seamless_communication"])
+    subprocess.run(["git", "clone", SEAMLESS_GIT_REPO, "./seamless_communication"])
+ "/seamless_communication"]) + os.chdir("seamless_communication") + subprocess.run(["pip", "install", "-e", "."]) + os.chdir(initial_dir) + + +def download_whisper(): from faster_whisper.utils import download_model print("Downloading Whisper model") - download_model(WHISPER_MODEL, cache_dir=cache_dir) + download_model(WHISPER_MODEL, cache_dir=WHISPER_MODEL_DIR) print("Whisper model downloaded") -def download_translation_model(cache_dir: str | None = None): +def download_seamlessm4t_model(): from huggingface_hub import snapshot_download - print("Downloading Translation model") - ignore_patterns = ["*.ot"] - snapshot_download( - TRANSLATION_MODEL, - cache_dir=cache_dir, - ignore_patterns=ignore_patterns - ) - print("Translation model downloaded") + print("Downloading Transcriber model & tokenizer") + snapshot_download(HF_SEAMLESS_M4TEPO, cache_dir=SEAMLESS_MODEL_DIR) + print("Transcriber model & tokenizer downloaded") - -def download_models(): - print(f"Downloading models to {IMAGE_MODEL_DIR=}") - download_whisper(cache_dir=IMAGE_MODEL_DIR) - download_translation_model(cache_dir=IMAGE_MODEL_DIR) - print(f"Model downloads complete.") + print("Downloading vocoder weights") + snapshot_download(HF_SEAMLESS_M4T_VOCODEREPO, cache_dir=SEAMLESS_MODEL_DIR) + print("Vocoder weights downloaded") def migrate_cache_llm(): @@ -60,13 +72,61 @@ def migrate_cache_llm(): from transformers.utils.hub import move_cache print("Moving LLM cache") - move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR) + move_cache(cache_dir=WHISPER_MODEL_DIR, new_cache_dir=WHISPER_MODEL_DIR) print("LLM cache moved") -whisper_image = ( +def configure_seamless_m4t(): + import os + + import yaml + + ASSETS_DIR: str = "./seamless_communication/src/seamless_communication/assets/cards" + + with open(f'{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml', 'r') as file: + model_yaml_data = yaml.load(file, Loader=yaml.FullLoader) + with open(f'{ASSETS_DIR}/vocoder_36langs.yaml', 'r') as file: + vocoder_yaml_data = yaml.load(file, Loader=yaml.FullLoader) + with open(f'{ASSETS_DIR}/unity_nllb-100.yaml', 'r') as file: + unity_100_yaml_data = yaml.load(file, Loader=yaml.FullLoader) + with open(f'{ASSETS_DIR}/unity_nllb-200.yaml', 'r') as file: + unity_200_yaml_data = yaml.load(file, Loader=yaml.FullLoader) + + model_dir = f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-{SEAMLESSM4T_MODEL_SIZE}/snapshots" + available_model_versions = os.listdir(model_dir) + latest_model_version = sorted(available_model_versions)[-1] + model_name = f"multitask_unity_{SEAMLESSM4T_MODEL_SIZE}.pt" + model_path = os.path.join(os.getcwd(), model_dir, latest_model_version, model_name) + + vocoder_dir = f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-vocoder/snapshots" + available_vocoder_versions = os.listdir(vocoder_dir) + latest_vocoder_version = sorted(available_vocoder_versions)[-1] + vocoder_name = "vocoder_36langs.pt" + vocoder_path = os.path.join(os.getcwd(), vocoder_dir, latest_vocoder_version, vocoder_name) + + tokenizer_name = "tokenizer.model" + tokenizer_path = os.path.join(os.getcwd(), model_dir, latest_model_version, tokenizer_name) + + model_yaml_data['checkpoint'] = f"file:/{model_path}" + vocoder_yaml_data['checkpoint'] = f"file:/{vocoder_path}" + unity_100_yaml_data['tokenizer'] = f"file:/{tokenizer_path}" + unity_200_yaml_data['tokenizer'] = f"file:/{tokenizer_path}" + + with open(f'{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml', 'w') as file: + yaml.dump(model_yaml_data, file) + with 
+    with open(f'{ASSETS_DIR}/vocoder_36langs.yaml', 'w') as file:
+        yaml.dump(vocoder_yaml_data, file)
+    with open(f'{ASSETS_DIR}/unity_nllb-100.yaml', 'w') as file:
+        yaml.dump(unity_100_yaml_data, file)
+    with open(f'{ASSETS_DIR}/unity_nllb-200.yaml', 'w') as file:
+        yaml.dump(unity_200_yaml_data, file)
+
+
+transcriber_image = (
     Image.debian_slim(python_version="3.10.8")
     .apt_install("git")
+    .apt_install("wget")
+    .apt_install("libsndfile-dev")
     .pip_install(
         "faster-whisper",
         "requests",
@@ -75,8 +135,16 @@ whisper_image = (
         "sentencepiece",
         "protobuf",
         "huggingface_hub==0.16.4",
+        "gitpython",
+        "torchaudio",
+        "fairseq2",
+        "pyyaml",
+        "hf-transfer~=0.1"
     )
-    .run_function(download_models)
+    .run_function(install_seamless_communication)
+    .run_function(download_seamlessm4t_model)
+    .run_function(configure_seamless_m4t)
+    .run_function(download_whisper)
     .run_function(migrate_cache_llm)
     .env(
         {
@@ -90,15 +158,17 @@ whisper_image = (
 
 
 @stub.cls(
-    gpu="A10G",
-    container_idle_timeout=60,
-    image=whisper_image,
+    gpu="A100",
+    timeout=60 * 5,
+    container_idle_timeout=60 * 5,
+    concurrency_limit=3,
+    image=transcriber_image,
 )
-class Whisper:
+class Transcriber:
     def __enter__(self):
         import faster_whisper
         import torch
-        from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
+        from seamless_communication.models.inference.translator import Translator
 
         self.use_gpu = torch.cuda.is_available()
         self.device = "cuda" if self.use_gpu else "cpu"
@@ -107,15 +177,13 @@
             device=self.device,
             compute_type=WHISPER_COMPUTE_TYPE,
             num_workers=WHISPER_NUM_WORKERS,
-            download_root=IMAGE_MODEL_DIR
+            download_root=WHISPER_MODEL_DIR
         )
-        self.translation_model = M2M100ForConditionalGeneration.from_pretrained(
-            TRANSLATION_MODEL,
-            cache_dir=IMAGE_MODEL_DIR
-        ).to(self.device)
-        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(
-            TRANSLATION_MODEL,
-            cache_dir=IMAGE_MODEL_DIR
+        self.translator = Translator(
+            SEAMLESSM4T_MODEL_CARD_NAME,
+            SEAMLESSM4T_VOCODER_CARD_NAME,
+            torch.device(self.device),
+            dtype=torch.float32
         )
 
     @method()
@@ -128,7 +196,6 @@
         audio_data: str,
         audio_suffix: str,
         source_language: str,
-        target_language: str,
         timestamp: float = 0
    ):
         with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
@@ -162,25 +229,43 @@
 
         multilingual_transcript[source_language] = transcript_source_lang
 
-        if target_language != source_language:
-            self.translation_tokenizer.src_lang = source_language
-            forced_bos_token_id = self.translation_tokenizer.get_lang_id(target_language)
-            encoded_transcript = self.translation_tokenizer(transcript_source_lang, return_tensors="pt").to(self.device)
-            generated_tokens = self.translation_model.generate(
-                **encoded_transcript,
-                forced_bos_token_id=forced_bos_token_id
-            )
-            result = self.translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
-            translation = result[0].strip()
-            multilingual_transcript[target_language] = translation
-
         return { "text": multilingual_transcript, "words": words }
 
+    def get_seamless_lang_code(self, lang_code: str):
+        """
+        SeamlessM4T language codes differ from the usual standards.
+        For example, French is "fra", not "fr".
+ """ + # TODO: Enhance with complete list of lang codes + seamless_lang_code = { + "en": "eng", + "fr": "fra" + } + return seamless_lang_code.get(lang_code, "eng") + @method() + def translate_text( + self, + text: str, + source_language: str, + target_language: str + ): + translated_text, _, _ = self.translator.predict( + text, + "t2tt", + src_lang=self.get_seamless_lang_code(source_language), + tgt_lang=self.get_seamless_lang_code(target_language), + ngram_filtering=True + ) + return { + "text": { + source_language: text, + target_language: str(translated_text) + } + } # ------------------------------------------------------------------- # Web API # ------------------------------------------------------------------- @@ -199,7 +284,7 @@ def web(): from fastapi.security import OAuth2PasswordBearer from typing_extensions import Annotated - transcriberstub = Whisper() + transcriberstub = Transcriber() app = FastAPI() @@ -221,7 +306,6 @@ def web(): async def transcribe( file: UploadFile, source_language: Annotated[str, Body(...)] = "en", - target_language: Annotated[str, Body(...)] = "en", timestamp: Annotated[float, Body()] = 0.0 ) -> TranscriptResponse: audio_data = await file.read() @@ -232,12 +316,25 @@ def web(): audio_data=audio_data, audio_suffix=audio_suffix, source_language=source_language, - target_language=target_language, timestamp=timestamp ) result = func.get() return result + @app.post("/translate", dependencies=[Depends(apikey_auth)]) + async def translate( + text: str, + source_language: Annotated[str, Body(...)] = "en", + target_language: Annotated[str, Body(...)] = "fr", + ) -> TranscriptResponse: + func = transcriberstub.translate_text.spawn( + text=text, + source_language=source_language, + target_language=target_language, + ) + result = func.get() + return result + @app.post("/warmup", dependencies=[Depends(apikey_auth)]) async def warmup(): return transcriberstub.warmup.spawn().get() diff --git a/server/reflector/llm/llm_params.py b/server/reflector/llm/llm_params.py index e43956cc..3d960a7c 100644 --- a/server/reflector/llm/llm_params.py +++ b/server/reflector/llm/llm_params.py @@ -124,7 +124,7 @@ class TopicParams(LLMTaskParams): For the title field, generate a very detailed and self-explanatory title for the given text. Let the title be as descriptive as possible. For the summary field, summarize the given text in a maximum of - three sentences. + two sentences. 
""" self._schema = { "type": "object", diff --git a/server/reflector/processors/__init__.py b/server/reflector/processors/__init__.py index 349a41a9..96a3941d 100644 --- a/server/reflector/processors/__init__.py +++ b/server/reflector/processors/__init__.py @@ -13,6 +13,7 @@ from .transcript_final_short_summary import ( # noqa: F401 from .transcript_final_title import TranscriptFinalTitleProcessor # noqa: F401 from .transcript_liner import TranscriptLinerProcessor # noqa: F401 from .transcript_topic_detector import TranscriptTopicDetectorProcessor # noqa: F401 +from .transcript_translator import TranscriptTranslatorProcessor # noqa: F401 from .types import ( # noqa: F401 AudioFile, FinalLongSummary, diff --git a/server/reflector/processors/audio_transcript_modal.py b/server/reflector/processors/audio_transcript_modal.py index 13e0bd44..f3f36e61 100644 --- a/server/reflector/processors/audio_transcript_modal.py +++ b/server/reflector/processors/audio_transcript_modal.py @@ -18,7 +18,7 @@ import httpx from reflector.processors.audio_transcript import AudioTranscriptProcessor from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor -from reflector.processors.types import AudioFile, Transcript, TranslationLanguages, Word +from reflector.processors.types import AudioFile, Transcript, Word from reflector.settings import settings from reflector.utils.retry import retry @@ -53,21 +53,8 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor): files = { "file": (data.name, data.fd), } - - # FIXME this should be a processor after, as each user may want - # different languages source_language = self.get_pref("audio:source_language", "en") - target_language = self.get_pref("audio:target_language", "en") - languages = TranslationLanguages() - - # Only way to set the target should be the UI element like dropdown. - # Hence, this assert should never fail. - assert languages.is_supported(target_language) - json_payload = { - "source_language": source_language, - "target_language": target_language, - } - + json_payload = {"source_language": source_language} response = await retry(client.post)( self.transcript_url, files=files, @@ -81,16 +68,10 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor): ) response.raise_for_status() result = response.json() - - # Sanity check for translation status in the result - translation = None - if source_language != target_language and target_language in result["text"]: - translation = result["text"][target_language] text = result["text"][source_language] text = self.filter_profanity(text) transcript = Transcript( text=text, - translation=translation, words=[ Word( text=word["text"], diff --git a/server/reflector/processors/transcript_liner.py b/server/reflector/processors/transcript_liner.py index c7ec2f64..c1aa14a0 100644 --- a/server/reflector/processors/transcript_liner.py +++ b/server/reflector/processors/transcript_liner.py @@ -16,29 +16,35 @@ class TranscriptLinerProcessor(Processor): self.transcript = Transcript(words=[]) self.max_text = max_text + def is_sentence_terminated(self, sentence) -> bool: + sentence_terminators = [".", "?", "!"] + for terminator in sentence_terminators: + if terminator in sentence: + return True + return False + async def _push(self, data: Transcript): # merge both transcript self.transcript.merge(data) # check if a line is complete - if "." 
+        if not self.is_sentence_terminated(self.transcript.text):
             # if the transcription text is still not too long, wait for more
             if len(self.transcript.text) < self.max_text:
                 return
 
         # cut to the next .
-        partial = Transcript(translation=self.transcript.translation, words=[])
+        partial = Transcript(words=[])
         for word in self.transcript.words[:]:
             partial.text += word.text
             partial.words.append(word)
-            if "." not in word.text:
+            if not self.is_sentence_terminated(word.text):
                 continue
 
             # emit line
             await self.emit(partial)
 
-            # create new transcript
-            partial = Transcript(translation=self.transcript.translation, words=[])
+            partial = Transcript(words=[])
 
         self.transcript = partial
diff --git a/server/reflector/processors/transcript_translator.py b/server/reflector/processors/transcript_translator.py
new file mode 100644
index 00000000..354b1bd8
--- /dev/null
+++ b/server/reflector/processors/transcript_translator.py
@@ -0,0 +1,88 @@
+from time import monotonic
+
+import httpx
+
+from reflector.processors.base import Processor
+from reflector.processors.types import Transcript, TranslationLanguages
+from reflector.settings import settings
+from reflector.utils.retry import retry
+
+
+class TranscriptTranslatorProcessor(Processor):
+    """
+    Translate the transcript into the target language
+    """
+
+    INPUT_TYPE = Transcript
+    OUTPUT_TYPE = Transcript
+    TASK = "translate"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        # _flush reads self.transcript, so initialize it before any push
+        self.transcript = None
+        self.transcript_url = settings.TRANSCRIPT_URL
+        self.timeout = settings.TRANSCRIPT_TIMEOUT
+        self.headers = {"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}"}
+
+    async def _warmup(self):
+        try:
+            async with httpx.AsyncClient() as client:
+                start = monotonic()
+                self.logger.debug("Translate modal: warming up...")
+                response = await client.post(
+                    settings.TRANSCRIPT_URL + "/warmup",
+                    headers=self.headers,
+                    timeout=self.timeout,
+                )
+                response.raise_for_status()
+                duration = monotonic() - start
+                self.logger.debug(f"Translate modal: warmup took {duration:.2f}s")
+        except Exception:
+            self.logger.exception("Translate modal: warmup failed")
+
+    async def _push(self, data: Transcript):
+        self.transcript = data
+        await self.flush()
+
+    async def get_translation(self, text: str) -> str:
+        self.logger.debug(f"Try to translate {text=}")
+        # FIXME this should be a processor after, as each user may want
+        # different languages
+        source_language = self.get_pref("audio:source_language", "en")
+        target_language = self.get_pref("audio:target_language", "en")
+
+        languages = TranslationLanguages()
+
+        # The target should only ever be set from a UI element such as a
+        # dropdown, so these asserts should never fail.
+        assert languages.is_supported(target_language)
+        assert target_language != source_language
+        json_payload = {
+            "text": text,
+            "source_language": source_language,
+            "target_language": target_language,
+        }
+        translation = None
+        async with httpx.AsyncClient() as client:
+            response = await retry(client.post)(
+                settings.TRANSCRIPT_URL + "/translate",
+                headers=self.headers,
+                params=json_payload,
+                timeout=self.timeout,
+            )
+            response.raise_for_status()
+            result = response.json()["text"]
+
+        # Sanity check for translation status in the result
+        if source_language != target_language and target_language in result:
+            translation = result[target_language]
+        self.logger.debug(f"Translation response: {text=}, {translation=}")
+        return translation
+
+    async def _flush(self):
+        if not self.transcript:
+            return
+        translation = await self.get_translation(text=self.transcript.text)
+        self.transcript.translation = translation
+        await self.emit(self.transcript)
diff --git a/server/reflector/tools/process.py b/server/reflector/tools/process.py
index add1b104..37f44096 100644
--- a/server/reflector/tools/process.py
+++ b/server/reflector/tools/process.py
@@ -14,6 +14,7 @@ from reflector.processors import (
     TranscriptFinalTitleProcessor,
     TranscriptLinerProcessor,
     TranscriptTopicDetectorProcessor,
+    TranscriptTranslatorProcessor,
 )
 from reflector.processors.base import BroadcastProcessor
 
@@ -31,6 +32,7 @@ async def process_audio_file(
         AudioMergeProcessor(),
         AudioTranscriptAutoProcessor.as_threaded(),
         TranscriptLinerProcessor(),
+        TranscriptTranslatorProcessor.as_threaded(),
     ]
     if not only_transcript:
         processors += [
diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py
index 792ce244..d767153e 100644
--- a/server/reflector/views/rtc_offer.py
+++ b/server/reflector/views/rtc_offer.py
@@ -26,6 +26,7 @@ from reflector.processors import (
     TranscriptFinalTitleProcessor,
     TranscriptLinerProcessor,
     TranscriptTopicDetectorProcessor,
+    TranscriptTranslatorProcessor,
 )
 from reflector.processors.base import BroadcastProcessor
 from reflector.processors.types import FinalTitle
@@ -219,8 +220,9 @@ async def rtc_offer_base(
     processors += [
         AudioChunkerProcessor(),
         AudioMergeProcessor(),
-        AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
+        AudioTranscriptAutoProcessor.as_threaded(),
         TranscriptLinerProcessor(),
+        TranscriptTranslatorProcessor.as_threaded(callback=on_transcript),
         TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
         BroadcastProcessor(
             processors=[
diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index 27417bcb..ef56929a 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -28,13 +28,43 @@ def dummy_processors():
         "reflector.processors.transcript_final_long_summary.TranscriptFinalLongSummaryProcessor.get_long_summary"
     ) as mock_long_summary, patch(
         "reflector.processors.transcript_final_short_summary.TranscriptFinalShortSummaryProcessor.get_short_summary"
-    ) as mock_short_summary:
+    ) as mock_short_summary, patch(
+        "reflector.processors.transcript_translator.TranscriptTranslatorProcessor.get_translation"
+    ) as mock_translate:
         mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"}
         mock_title.return_value = {"title": "LLM TITLE"}
         mock_long_summary.return_value = "LLM LONG SUMMARY"
         mock_short_summary.return_value = {"short_summary": "LLM SHORT SUMMARY"}
mock_translate.return_value = "Bonjour le monde" + yield mock_translate, mock_topic, mock_title, mock_long_summary, mock_short_summary # noqa - yield mock_topic, mock_title, mock_long_summary, mock_short_summary + +@pytest.fixture +async def dummy_transcript(): + from reflector.processors.audio_transcript import AudioTranscriptProcessor + from reflector.processors.types import AudioFile, Transcript, Word + + class TestAudioTranscriptProcessor(AudioTranscriptProcessor): + async def _transcript(self, data: AudioFile): + source_language = self.get_pref("audio:source_language", "en") + print("transcripting", source_language) + print("pipeline", self.pipeline) + print("prefs", self.pipeline.prefs) + + return Transcript( + text="Hello world.", + words=[ + Word(start=0.0, end=1.0, text="Hello"), + Word(start=1.0, end=2.0, text=" world."), + ], + ) + + with patch( + "reflector.processors.audio_transcript_auto" + ".AudioTranscriptAutoProcessor.get_instance" + ) as mock_audio: + mock_audio.return_value = TestAudioTranscriptProcessor() + yield @pytest.fixture diff --git a/server/tests/test_processors_pipeline.py b/server/tests/test_processors_pipeline.py index 996c0908..69c3910d 100644 --- a/server/tests/test_processors_pipeline.py +++ b/server/tests/test_processors_pipeline.py @@ -3,7 +3,12 @@ import pytest @pytest.mark.asyncio async def test_basic_process( - event_loop, nltk, dummy_llm, dummy_processors, ensure_casing + event_loop, + nltk, + dummy_transcript, + dummy_llm, + dummy_processors, + ensure_casing, ): # goal is to start the server, and send rtc audio to it # validate the events received @@ -29,7 +34,8 @@ async def test_basic_process( print(marks) # validate the events - assert marks["TranscriptLinerProcessor"] == 5 + assert marks["TranscriptLinerProcessor"] == 4 + assert marks["TranscriptTranslatorProcessor"] == 4 assert marks["TranscriptTopicDetectorProcessor"] == 1 assert marks["TranscriptFinalLongSummaryProcessor"] == 1 assert marks["TranscriptFinalShortSummaryProcessor"] == 1 diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index d691dacb..2485ca6b 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -7,7 +7,6 @@ import asyncio import json import threading from pathlib import Path -from unittest.mock import patch import pytest from httpx import AsyncClient @@ -32,41 +31,6 @@ class ThreadedUvicorn: continue -@pytest.fixture -async def dummy_transcript(): - from reflector.processors.audio_transcript import AudioTranscriptProcessor - from reflector.processors.types import AudioFile, Transcript, Word - - class TestAudioTranscriptProcessor(AudioTranscriptProcessor): - async def _transcript(self, data: AudioFile): - source_language = self.get_pref("audio:source_language", "en") - target_language = self.get_pref("audio:target_language", "en") - print("transcripting", source_language, target_language) - print("pipeline", self.pipeline) - print("prefs", self.pipeline.prefs) - - translation = None - if source_language != target_language: - if target_language == "fr": - translation = "Bonjour le monde" - - return Transcript( - text="Hello world", - translation=translation, - words=[ - Word(start=0.0, end=1.0, text="Hello"), - Word(start=1.0, end=2.0, text="world"), - ], - ) - - with patch( - "reflector.processors.audio_transcript_auto" - ".AudioTranscriptAutoProcessor.get_instance" - ) as mock_audio: - mock_audio.return_value = TestAudioTranscriptProcessor() - yield - - @pytest.mark.asyncio async def 
     tmpdir, dummy_llm, dummy_transcript, dummy_processors, ensure_casing
 ):
@@ -165,14 +129,14 @@ async def test_transcript_rtc_and_websocket(
     # check events
     assert "TRANSCRIPT" in eventnames
     ev = events[eventnames.index("TRANSCRIPT")]
-    assert ev["data"]["text"] == "Hello world"
-    assert ev["data"]["translation"] is None
+    assert ev["data"]["text"].startswith("Hello world.")
+    assert ev["data"]["translation"] == "Bonjour le monde"
 
     assert "TOPIC" in eventnames
     ev = events[eventnames.index("TOPIC")]
     assert ev["data"]["id"]
     assert ev["data"]["summary"] == "LLM SUMMARY"
-    assert ev["data"]["transcript"].startswith("Hello world")
+    assert ev["data"]["transcript"].startswith("Hello world.")
     assert ev["data"]["timestamp"] == 0.0
 
     assert "FINAL_LONG_SUMMARY" in eventnames
@@ -310,14 +274,14 @@ async def test_transcript_rtc_and_websocket_and_fr(
     # check events
     assert "TRANSCRIPT" in eventnames
     ev = events[eventnames.index("TRANSCRIPT")]
-    assert ev["data"]["text"] == "Hello world"
+    assert ev["data"]["text"].startswith("Hello world.")
     assert ev["data"]["translation"] == "Bonjour le monde"
 
     assert "TOPIC" in eventnames
     ev = events[eventnames.index("TOPIC")]
     assert ev["data"]["id"]
     assert ev["data"]["summary"] == "LLM SUMMARY"
-    assert ev["data"]["transcript"].startswith("Hello world")
+    assert ev["data"]["transcript"].startswith("Hello world.")
     assert ev["data"]["timestamp"] == 0.0
 
     assert "FINAL_LONG_SUMMARY" in eventnames
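---

To exercise the new /translate route by hand, a minimal sketch along these
lines should work. It assumes the Modal web endpoint URL and API key are
exported as TRANSCRIPT_URL and LLM_MODAL_API_KEY environment variables
(names borrowed from the server settings; adjust to your deployment).

    # Hypothetical smoke test for the new /translate endpoint.
    # TRANSCRIPT_URL and LLM_MODAL_API_KEY are assumed to be set in the
    # environment, mirroring the names used in transcript_translator.py.
    import os

    import httpx

    url = os.environ["TRANSCRIPT_URL"] + "/translate"
    headers = {"Authorization": f"Bearer {os.environ['LLM_MODAL_API_KEY']}"}
    params = {
        "text": "Hello world.",
        "source_language": "en",
        "target_language": "fr",
    }

    # the processor sends these values as query parameters, so do the same here
    response = httpx.post(url, headers=headers, params=params, timeout=60)
    response.raise_for_status()
    print(response.json())  # expected shape: {"text": {"en": "Hello world.", "fr": "..."}}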