Translation enhancements (#247)

Author: projects-g
Date: 2023-09-26 19:49:54 +05:30
Committed by: GitHub
Parent: 4dbec9b154
Commit: 6a43297309
11 changed files with 303 additions and 126 deletions

View File

@@ -14,40 +14,52 @@ WHISPER_MODEL: str = "large-v2"
 WHISPER_COMPUTE_TYPE: str = "float16"
 WHISPER_NUM_WORKERS: int = 1

-# Translation Model
-TRANSLATION_MODEL = "facebook/m2m100_1.2B"
-IMAGE_MODEL_DIR = f"/root/transcription_models/{TRANSLATION_MODEL}"
+# Seamless M4T
+SEAMLESSM4T_MODEL_SIZE: str = "medium"
+SEAMLESSM4T_MODEL_CARD_NAME: str = f"seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}"
+SEAMLESSM4T_VOCODER_CARD_NAME: str = "vocoder_36langs"
+HF_SEAMLESS_M4T_REPO: str = f"facebook/seamless-m4t-{SEAMLESSM4T_MODEL_SIZE}"
+HF_SEAMLESS_M4T_VOCODER_REPO: str = "facebook/seamless-m4t-vocoder"
+SEAMLESS_GIT_REPO: str = "https://github.com/facebookresearch/seamless_communication.git"
+SEAMLESS_MODEL_DIR: str = "m4t"
+WHISPER_MODEL_DIR = "/root/transcription_models"

 stub = Stub(name="reflector-transcriber")

-def download_whisper(cache_dir: str | None = None):
+def install_seamless_communication():
+    import os
+    import subprocess
+
+    initial_dir = os.getcwd()
+    subprocess.run(["ssh-keyscan", "-t", "rsa", "github.com", ">>", "~/.ssh/known_hosts"])
+    subprocess.run(["rm", "-rf", "seamless_communication"])
+    subprocess.run(["git", "clone", SEAMLESS_GIT_REPO, "./seamless_communication"])
+    os.chdir("seamless_communication")
+    subprocess.run(["pip", "install", "-e", "."])
+    os.chdir(initial_dir)
+
+
+def download_whisper():
     from faster_whisper.utils import download_model

     print("Downloading Whisper model")
-    download_model(WHISPER_MODEL, cache_dir=cache_dir)
+    download_model(WHISPER_MODEL, cache_dir=WHISPER_MODEL_DIR)
     print("Whisper model downloaded")

-def download_translation_model(cache_dir: str | None = None):
+
+def download_seamlessm4t_model():
     from huggingface_hub import snapshot_download

-    print("Downloading Translation model")
-    ignore_patterns = ["*.ot"]
-    snapshot_download(
-        TRANSLATION_MODEL,
-        cache_dir=cache_dir,
-        ignore_patterns=ignore_patterns
-    )
-    print("Translation model downloaded")
-
-
-def download_models():
-    print(f"Downloading models to {IMAGE_MODEL_DIR=}")
-    download_whisper(cache_dir=IMAGE_MODEL_DIR)
-    download_translation_model(cache_dir=IMAGE_MODEL_DIR)
-    print(f"Model downloads complete.")
+    print("Downloading Transcriber model & tokenizer")
+    snapshot_download(HF_SEAMLESS_M4T_REPO, cache_dir=SEAMLESS_MODEL_DIR)
+    print("Transcriber model & tokenizer downloaded")
+
+    print("Downloading vocoder weights")
+    snapshot_download(HF_SEAMLESS_M4T_VOCODER_REPO, cache_dir=SEAMLESS_MODEL_DIR)
+    print("Vocoder weights downloaded")


 def migrate_cache_llm():
@@ -60,13 +72,61 @@ def migrate_cache_llm():
     from transformers.utils.hub import move_cache

     print("Moving LLM cache")
-    move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR)
+    move_cache(cache_dir=WHISPER_MODEL_DIR, new_cache_dir=WHISPER_MODEL_DIR)
     print("LLM cache moved")


-whisper_image = (
+def configure_seamless_m4t():
+    import os
+    import yaml
+
+    ASSETS_DIR: str = "./seamless_communication/src/seamless_communication/assets/cards"
+
+    with open(f'{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml', 'r') as file:
+        model_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
+    with open(f'{ASSETS_DIR}/vocoder_36langs.yaml', 'r') as file:
+        vocoder_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
+    with open(f'{ASSETS_DIR}/unity_nllb-100.yaml', 'r') as file:
+        unity_100_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
+    with open(f'{ASSETS_DIR}/unity_nllb-200.yaml', 'r') as file:
+        unity_200_yaml_data = yaml.load(file, Loader=yaml.FullLoader)
+
+    model_dir = f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-{SEAMLESSM4T_MODEL_SIZE}/snapshots"
+    available_model_versions = os.listdir(model_dir)
+    latest_model_version = sorted(available_model_versions)[-1]
+    model_name = f"multitask_unity_{SEAMLESSM4T_MODEL_SIZE}.pt"
+    model_path = os.path.join(os.getcwd(), model_dir, latest_model_version, model_name)
+
+    vocoder_dir = f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-vocoder/snapshots"
+    available_vocoder_versions = os.listdir(vocoder_dir)
+    latest_vocoder_version = sorted(available_vocoder_versions)[-1]
+    vocoder_name = "vocoder_36langs.pt"
+    vocoder_path = os.path.join(os.getcwd(), vocoder_dir, latest_vocoder_version, vocoder_name)
+
+    tokenizer_name = "tokenizer.model"
+    tokenizer_path = os.path.join(os.getcwd(), model_dir, latest_model_version, tokenizer_name)
+
+    model_yaml_data['checkpoint'] = f"file:/{model_path}"
+    vocoder_yaml_data['checkpoint'] = f"file:/{vocoder_path}"
+    unity_100_yaml_data['tokenizer'] = f"file:/{tokenizer_path}"
+    unity_200_yaml_data['tokenizer'] = f"file:/{tokenizer_path}"
+
+    with open(f'{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml', 'w') as file:
+        yaml.dump(model_yaml_data, file)
+    with open(f'{ASSETS_DIR}/vocoder_36langs.yaml', 'w') as file:
+        yaml.dump(vocoder_yaml_data, file)
+    with open(f'{ASSETS_DIR}/unity_nllb-100.yaml', 'w') as file:
+        yaml.dump(unity_100_yaml_data, file)
+    with open(f'{ASSETS_DIR}/unity_nllb-200.yaml', 'w') as file:
+        yaml.dump(unity_200_yaml_data, file)
+
+
+transcriber_image = (
     Image.debian_slim(python_version="3.10.8")
     .apt_install("git")
+    .apt_install("wget")
+    .apt_install("libsndfile-dev")
     .pip_install(
         "faster-whisper",
         "requests",
@@ -75,8 +135,16 @@ whisper_image = (
         "sentencepiece",
         "protobuf",
         "huggingface_hub==0.16.4",
+        "gitpython",
+        "torchaudio",
+        "fairseq2",
+        "pyyaml",
+        "hf-transfer~=0.1"
     )
-    .run_function(download_models)
+    .run_function(install_seamless_communication)
+    .run_function(download_seamlessm4t_model)
+    .run_function(configure_seamless_m4t)
+    .run_function(download_whisper)
     .run_function(migrate_cache_llm)
     .env(
         {
@@ -90,15 +158,17 @@ whisper_image = (
 @stub.cls(
-    gpu="A10G",
-    container_idle_timeout=60,
-    image=whisper_image,
+    gpu="A100",
+    timeout=60 * 5,
+    container_idle_timeout=60 * 5,
+    concurrency_limit=3,
+    image=transcriber_image,
 )
-class Whisper:
+class Transcriber:
     def __enter__(self):
         import faster_whisper
         import torch
-        from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
+        from seamless_communication.models.inference.translator import Translator

         self.use_gpu = torch.cuda.is_available()
         self.device = "cuda" if self.use_gpu else "cpu"
@@ -107,15 +177,13 @@ class Whisper:
             device=self.device,
             compute_type=WHISPER_COMPUTE_TYPE,
             num_workers=WHISPER_NUM_WORKERS,
-            download_root=IMAGE_MODEL_DIR
+            download_root=WHISPER_MODEL_DIR
         )
-        self.translation_model = M2M100ForConditionalGeneration.from_pretrained(
-            TRANSLATION_MODEL,
-            cache_dir=IMAGE_MODEL_DIR
-        ).to(self.device)
-        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(
-            TRANSLATION_MODEL,
-            cache_dir=IMAGE_MODEL_DIR
-        )
+        self.translator = Translator(
+            SEAMLESSM4T_MODEL_CARD_NAME,
+            SEAMLESSM4T_VOCODER_CARD_NAME,
+            torch.device(self.device),
+            dtype=torch.float32
+        )

     @method()
@@ -128,7 +196,6 @@ class Whisper:
audio_data: str, audio_data: str,
audio_suffix: str, audio_suffix: str,
source_language: str, source_language: str,
target_language: str,
timestamp: float = 0 timestamp: float = 0
): ):
with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp: with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
@@ -162,25 +229,43 @@ class Whisper:
         multilingual_transcript[source_language] = transcript_source_lang

-        if target_language != source_language:
-            self.translation_tokenizer.src_lang = source_language
-            forced_bos_token_id = self.translation_tokenizer.get_lang_id(target_language)
-            encoded_transcript = self.translation_tokenizer(transcript_source_lang, return_tensors="pt").to(self.device)
-            generated_tokens = self.translation_model.generate(
-                **encoded_transcript,
-                forced_bos_token_id=forced_bos_token_id
-            )
-            result = self.translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
-            translation = result[0].strip()
-            multilingual_transcript[target_language] = translation
-
         return {
             "text": multilingual_transcript,
             "words": words
         }

+    def get_seamless_lang_code(self, lang_code: str):
+        """
+        The language codes for SeamlessM4T differ from the usual
+        two-letter standard. For example, French is "fra", not "fr".
+        """
+        # TODO: Enhance with complete list of lang codes
+        seamless_lang_code = {
+            "en": "eng",
+            "fr": "fra"
+        }
+        return seamless_lang_code.get(lang_code, "eng")
+
+    @method()
+    def translate_text(
+        self,
+        text: str,
+        source_language: str,
+        target_language: str
+    ):
+        translated_text, _, _ = self.translator.predict(
+            text,
+            "t2tt",
+            src_lang=self.get_seamless_lang_code(source_language),
+            tgt_lang=self.get_seamless_lang_code(target_language),
+            ngram_filtering=True
+        )
+        return {
+            "text": {
+                source_language: text,
+                target_language: str(translated_text)
+            }
+        }

 # -------------------------------------------------------------------
 # Web API
 # -------------------------------------------------------------------
@@ -199,7 +284,7 @@ def web():
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from typing_extensions import Annotated from typing_extensions import Annotated
transcriberstub = Whisper() transcriberstub = Transcriber()
app = FastAPI() app = FastAPI()
@@ -221,7 +306,6 @@ def web():
     async def transcribe(
         file: UploadFile,
         source_language: Annotated[str, Body(...)] = "en",
-        target_language: Annotated[str, Body(...)] = "en",
         timestamp: Annotated[float, Body()] = 0.0
     ) -> TranscriptResponse:
         audio_data = await file.read()
@@ -232,12 +316,25 @@ def web():
             audio_data=audio_data,
             audio_suffix=audio_suffix,
             source_language=source_language,
-            target_language=target_language,
             timestamp=timestamp
         )
         result = func.get()
         return result

+    @app.post("/translate", dependencies=[Depends(apikey_auth)])
+    async def translate(
+        text: str,
+        source_language: Annotated[str, Body(...)] = "en",
+        target_language: Annotated[str, Body(...)] = "fr",
+    ) -> TranscriptResponse:
+        func = transcriberstub.translate_text.spawn(
+            text=text,
+            source_language=source_language,
+            target_language=target_language,
+        )
+        result = func.get()
+        return result
+
     @app.post("/warmup", dependencies=[Depends(apikey_auth)])
     async def warmup():
         return transcriberstub.warmup.spawn().get()
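
For reference, a minimal client-side sketch of calling the new /translate endpoint. The deployment URL and API key are placeholders (not values from this commit); the parameter names and response shape mirror translate_text above and the TranscriptTranslatorProcessor request further down.

import httpx

# Placeholders: substitute your own Modal deployment URL and API key.
TRANSCRIPT_URL = "https://your-deployment.modal.run"
HEADERS = {"Authorization": "Bearer <LLM_MODAL_API_KEY>"}

response = httpx.post(
    TRANSCRIPT_URL + "/translate",
    headers=HEADERS,
    # TranscriptTranslatorProcessor sends these as query params
    params={
        "text": "Hello world.",
        "source_language": "en",
        "target_language": "fr",
    },
    timeout=60,
)
response.raise_for_status()
result = response.json()["text"]
# expected shape: {"en": "Hello world.", "fr": "<translated text>"}
print(result["fr"])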

View File

@@ -124,7 +124,7 @@ class TopicParams(LLMTaskParams):
         For the title field, generate a very detailed and self-explanatory
         title for the given text. Let the title be as descriptive as possible.
         For the summary field, summarize the given text in a maximum of
-        three sentences.
+        two sentences.
         """
         self._schema = {
             "type": "object",

View File

@@ -13,6 +13,7 @@ from .transcript_final_short_summary import (  # noqa: F401
 from .transcript_final_title import TranscriptFinalTitleProcessor  # noqa: F401
 from .transcript_liner import TranscriptLinerProcessor  # noqa: F401
 from .transcript_topic_detector import TranscriptTopicDetectorProcessor  # noqa: F401
+from .transcript_translator import TranscriptTranslatorProcessor  # noqa: F401
 from .types import (  # noqa: F401
     AudioFile,
     FinalLongSummary,

View File

@@ -18,7 +18,7 @@ import httpx
 from reflector.processors.audio_transcript import AudioTranscriptProcessor
 from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
-from reflector.processors.types import AudioFile, Transcript, TranslationLanguages, Word
+from reflector.processors.types import AudioFile, Transcript, Word
 from reflector.settings import settings
 from reflector.utils.retry import retry
@@ -53,21 +53,8 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
         files = {
             "file": (data.name, data.fd),
         }
-        # FIXME this should be a processor after, as each user may want
-        # different languages
         source_language = self.get_pref("audio:source_language", "en")
-        target_language = self.get_pref("audio:target_language", "en")
-        languages = TranslationLanguages()
-        # Only way to set the target should be the UI element like dropdown.
-        # Hence, this assert should never fail.
-        assert languages.is_supported(target_language)
-        json_payload = {
-            "source_language": source_language,
-            "target_language": target_language,
-        }
+        json_payload = {"source_language": source_language}
         response = await retry(client.post)(
             self.transcript_url,
             files=files,
@@ -81,16 +68,10 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
) )
response.raise_for_status() response.raise_for_status()
result = response.json() result = response.json()
# Sanity check for translation status in the result
translation = None
if source_language != target_language and target_language in result["text"]:
translation = result["text"][target_language]
text = result["text"][source_language] text = result["text"][source_language]
text = self.filter_profanity(text) text = self.filter_profanity(text)
transcript = Transcript( transcript = Transcript(
text=text, text=text,
translation=translation,
words=[ words=[
Word( Word(
text=word["text"], text=word["text"],

View File

@@ -16,29 +16,35 @@ class TranscriptLinerProcessor(Processor):
         self.transcript = Transcript(words=[])
         self.max_text = max_text

+    def is_sentence_terminated(self, sentence) -> bool:
+        sentence_terminators = [".", "?", "!"]
+        for terminator in sentence_terminators:
+            if terminator in sentence:
+                return True
+        return False
+
     async def _push(self, data: Transcript):
         # merge both transcripts
         self.transcript.merge(data)

         # check if a line is complete
-        if "." not in self.transcript.text:
+        if not self.is_sentence_terminated(self.transcript.text):
             # if the transcription text is still not too long, wait for more
             if len(self.transcript.text) < self.max_text:
                 return

         # cut at the next sentence terminator
-        partial = Transcript(translation=self.transcript.translation, words=[])
+        partial = Transcript(words=[])
         for word in self.transcript.words[:]:
             partial.text += word.text
             partial.words.append(word)
-            if "." not in word.text:
+            if not self.is_sentence_terminated(word.text):
                 continue
             # emit line
             await self.emit(partial)
             # create new transcript
-            partial = Transcript(translation=self.transcript.translation, words=[])
+            partial = Transcript(words=[])

         self.transcript = partial
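
The liner previously split lines only on "."; is_sentence_terminated extends this to "?" and "!". A standalone sketch of the predicate, for illustration only:

def is_sentence_terminated(sentence: str) -> bool:
    # True if any sentence terminator appears anywhere in the text
    return any(t in sentence for t in (".", "?", "!"))

assert is_sentence_terminated("Is anyone there?")
assert is_sentence_terminated("Stop!")
assert not is_sentence_terminated("still talking")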

View File

@@ -0,0 +1,88 @@
+from time import monotonic
+
+import httpx
+
+from reflector.processors.base import Processor
+from reflector.processors.types import Transcript, TranslationLanguages
+from reflector.settings import settings
+from reflector.utils.retry import retry
+
+
+class TranscriptTranslatorProcessor(Processor):
+    """
+    Translate the transcript into the target language
+    """
+
+    INPUT_TYPE = Transcript
+    OUTPUT_TYPE = Transcript
+    TASK = "translate"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.transcript_url = settings.TRANSCRIPT_URL
+        self.timeout = settings.TRANSCRIPT_TIMEOUT
+        self.headers = {"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}"}
+
+    async def _warmup(self):
+        try:
+            async with httpx.AsyncClient() as client:
+                start = monotonic()
+                self.logger.debug("Translate modal: warming up...")
+                response = await client.post(
+                    settings.TRANSCRIPT_URL + "/warmup",
+                    headers=self.headers,
+                    timeout=self.timeout,
+                )
+                response.raise_for_status()
+                duration = monotonic() - start
+                self.logger.debug(f"Translate modal: warmup took {duration:.2f}s")
+        except Exception:
+            self.logger.exception("Translate modal: warmup failed")
+
+    async def _push(self, data: Transcript):
+        self.transcript = data
+        await self.flush()
+
+    async def get_translation(self, text: str) -> str:
+        self.logger.debug(f"Try to translate {text=}")
+        # FIXME this should be a processor after, as each user may want
+        # different languages
+        source_language = self.get_pref("audio:source_language", "en")
+        target_language = self.get_pref("audio:target_language", "en")
+        languages = TranslationLanguages()
+        # Only way to set the target should be the UI element like dropdown.
+        # Hence, this assert should never fail.
+        assert languages.is_supported(target_language)
+        assert target_language != source_language
+        json_payload = {
+            "text": text,
+            "source_language": source_language,
+            "target_language": target_language,
+        }
+        translation = None
+        async with httpx.AsyncClient() as client:
+            response = await retry(client.post)(
+                settings.TRANSCRIPT_URL + "/translate",
+                headers=self.headers,
+                params=json_payload,
+                timeout=self.timeout,
+            )
+            response.raise_for_status()
+            result = response.json()["text"]
+            # Sanity check for translation status in the result
+            if source_language != target_language and target_language in result:
+                translation = result[target_language]
+        self.logger.debug(f"Translation response: {text=}, {translation=}")
+        return translation
+
+    async def _flush(self):
+        if not self.transcript:
+            return
+        translation = await self.get_translation(text=self.transcript.text)
+        self.transcript.translation = translation
+        await self.emit(self.transcript)
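
Since get_translation goes over the network, tests can stub it out. A minimal sketch using the same patch target the fixtures below rely on:

from unittest.mock import patch

# Stub the HTTP call so pipelines run without the Modal backend;
# same patch target as the dummy_processors fixture below.
with patch(
    "reflector.processors.transcript_translator"
    ".TranscriptTranslatorProcessor.get_translation"
) as mock_translate:
    mock_translate.return_value = "Bonjour le monde"
    # run the pipeline here; each emitted Transcript will carry
    # transcript.translation == "Bonjour le monde"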

View File

@@ -14,6 +14,7 @@ from reflector.processors import (
TranscriptFinalTitleProcessor, TranscriptFinalTitleProcessor,
TranscriptLinerProcessor, TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor, TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
) )
from reflector.processors.base import BroadcastProcessor from reflector.processors.base import BroadcastProcessor
@@ -31,6 +32,7 @@ async def process_audio_file(
AudioMergeProcessor(), AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(), AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(), TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(),
] ]
if not only_transcript: if not only_transcript:
processors += [ processors += [

View File

@@ -26,6 +26,7 @@ from reflector.processors import (
TranscriptFinalTitleProcessor, TranscriptFinalTitleProcessor,
TranscriptLinerProcessor, TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor, TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
) )
from reflector.processors.base import BroadcastProcessor from reflector.processors.base import BroadcastProcessor
from reflector.processors.types import FinalTitle from reflector.processors.types import FinalTitle
@@ -219,8 +220,9 @@ async def rtc_offer_base(
processors += [ processors += [
AudioChunkerProcessor(), AudioChunkerProcessor(),
AudioMergeProcessor(), AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript), AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(), TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic), TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
BroadcastProcessor( BroadcastProcessor(
processors=[ processors=[

View File

@@ -28,13 +28,43 @@ def dummy_processors():
"reflector.processors.transcript_final_long_summary.TranscriptFinalLongSummaryProcessor.get_long_summary" "reflector.processors.transcript_final_long_summary.TranscriptFinalLongSummaryProcessor.get_long_summary"
) as mock_long_summary, patch( ) as mock_long_summary, patch(
"reflector.processors.transcript_final_short_summary.TranscriptFinalShortSummaryProcessor.get_short_summary" "reflector.processors.transcript_final_short_summary.TranscriptFinalShortSummaryProcessor.get_short_summary"
) as mock_short_summary: ) as mock_short_summary, patch(
"reflector.processors.transcript_translator.TranscriptTranslatorProcessor.get_translation"
) as mock_translate:
mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"} mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"}
mock_title.return_value = {"title": "LLM TITLE"} mock_title.return_value = {"title": "LLM TITLE"}
mock_long_summary.return_value = "LLM LONG SUMMARY" mock_long_summary.return_value = "LLM LONG SUMMARY"
mock_short_summary.return_value = {"short_summary": "LLM SHORT SUMMARY"} mock_short_summary.return_value = {"short_summary": "LLM SHORT SUMMARY"}
mock_translate.return_value = "Bonjour le monde"
yield mock_translate, mock_topic, mock_title, mock_long_summary, mock_short_summary # noqa
yield mock_topic, mock_title, mock_long_summary, mock_short_summary
@pytest.fixture
async def dummy_transcript():
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.types import AudioFile, Transcript, Word
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
async def _transcript(self, data: AudioFile):
source_language = self.get_pref("audio:source_language", "en")
print("transcripting", source_language)
print("pipeline", self.pipeline)
print("prefs", self.pipeline.prefs)
return Transcript(
text="Hello world.",
words=[
Word(start=0.0, end=1.0, text="Hello"),
Word(start=1.0, end=2.0, text=" world."),
],
)
with patch(
"reflector.processors.audio_transcript_auto"
".AudioTranscriptAutoProcessor.get_instance"
) as mock_audio:
mock_audio.return_value = TestAudioTranscriptProcessor()
yield
@pytest.fixture @pytest.fixture

View File

@@ -3,7 +3,12 @@ import pytest

 @pytest.mark.asyncio
 async def test_basic_process(
-    event_loop, nltk, dummy_llm, dummy_processors, ensure_casing
+    event_loop,
+    nltk,
+    dummy_transcript,
+    dummy_llm,
+    dummy_processors,
+    ensure_casing,
 ):
     # goal is to start the server, and send rtc audio to it
     # validate the events received
@@ -29,7 +34,8 @@ async def test_basic_process(
     print(marks)

     # validate the events
-    assert marks["TranscriptLinerProcessor"] == 5
+    assert marks["TranscriptLinerProcessor"] == 4
+    assert marks["TranscriptTranslatorProcessor"] == 4
     assert marks["TranscriptTopicDetectorProcessor"] == 1
     assert marks["TranscriptFinalLongSummaryProcessor"] == 1
     assert marks["TranscriptFinalShortSummaryProcessor"] == 1

View File

@@ -7,7 +7,6 @@ import asyncio
 import json
 import threading
 from pathlib import Path
-from unittest.mock import patch

 import pytest
 from httpx import AsyncClient
@@ -32,41 +31,6 @@ class ThreadedUvicorn:
             continue


-@pytest.fixture
-async def dummy_transcript():
-    from reflector.processors.audio_transcript import AudioTranscriptProcessor
-    from reflector.processors.types import AudioFile, Transcript, Word
-
-    class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
-        async def _transcript(self, data: AudioFile):
-            source_language = self.get_pref("audio:source_language", "en")
-            target_language = self.get_pref("audio:target_language", "en")
-            print("transcripting", source_language, target_language)
-            print("pipeline", self.pipeline)
-            print("prefs", self.pipeline.prefs)
-            translation = None
-            if source_language != target_language:
-                if target_language == "fr":
-                    translation = "Bonjour le monde"
-            return Transcript(
-                text="Hello world",
-                translation=translation,
-                words=[
-                    Word(start=0.0, end=1.0, text="Hello"),
-                    Word(start=1.0, end=2.0, text="world"),
-                ],
-            )
-
-    with patch(
-        "reflector.processors.audio_transcript_auto"
-        ".AudioTranscriptAutoProcessor.get_instance"
-    ) as mock_audio:
-        mock_audio.return_value = TestAudioTranscriptProcessor()
-        yield

 @pytest.mark.asyncio
 async def test_transcript_rtc_and_websocket(
     tmpdir, dummy_llm, dummy_transcript, dummy_processors, ensure_casing
@@ -165,14 +129,14 @@ async def test_transcript_rtc_and_websocket(
     # check events
     assert "TRANSCRIPT" in eventnames
     ev = events[eventnames.index("TRANSCRIPT")]
-    assert ev["data"]["text"] == "Hello world"
-    assert ev["data"]["translation"] is None
+    assert ev["data"]["text"].startswith("Hello world.")
+    assert ev["data"]["translation"] == "Bonjour le monde"

     assert "TOPIC" in eventnames
     ev = events[eventnames.index("TOPIC")]
     assert ev["data"]["id"]
     assert ev["data"]["summary"] == "LLM SUMMARY"
-    assert ev["data"]["transcript"].startswith("Hello world")
+    assert ev["data"]["transcript"].startswith("Hello world.")
     assert ev["data"]["timestamp"] == 0.0

     assert "FINAL_LONG_SUMMARY" in eventnames
@@ -310,14 +274,14 @@ async def test_transcript_rtc_and_websocket_and_fr(
     # check events
     assert "TRANSCRIPT" in eventnames
     ev = events[eventnames.index("TRANSCRIPT")]
-    assert ev["data"]["text"] == "Hello world"
+    assert ev["data"]["text"].startswith("Hello world.")
     assert ev["data"]["translation"] == "Bonjour le monde"

     assert "TOPIC" in eventnames
     ev = events[eventnames.index("TOPIC")]
     assert ev["data"]["id"]
     assert ev["data"]["summary"] == "LLM SUMMARY"
-    assert ev["data"]["transcript"].startswith("Hello world")
+    assert ev["data"]["transcript"].startswith("Hello world.")
     assert ev["data"]["timestamp"] == 0.0

     assert "FINAL_LONG_SUMMARY" in eventnames