Translation enhancements (#247)

This commit is contained in:
projects-g
2023-09-26 19:49:54 +05:30
committed by GitHub
parent 4dbec9b154
commit 6a43297309
11 changed files with 303 additions and 126 deletions

View File

@@ -28,13 +28,43 @@ def dummy_processors():
"reflector.processors.transcript_final_long_summary.TranscriptFinalLongSummaryProcessor.get_long_summary"
) as mock_long_summary, patch(
"reflector.processors.transcript_final_short_summary.TranscriptFinalShortSummaryProcessor.get_short_summary"
) as mock_short_summary:
) as mock_short_summary, patch(
"reflector.processors.transcript_translator.TranscriptTranslatorProcessor.get_translation"
) as mock_translate:
mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"}
mock_title.return_value = {"title": "LLM TITLE"}
mock_long_summary.return_value = "LLM LONG SUMMARY"
mock_short_summary.return_value = {"short_summary": "LLM SHORT SUMMARY"}
mock_translate.return_value = "Bonjour le monde"
yield mock_translate, mock_topic, mock_title, mock_long_summary, mock_short_summary # noqa
yield mock_topic, mock_title, mock_long_summary, mock_short_summary
@pytest.fixture
async def dummy_transcript():
    """Swap the automatic audio transcript processor for a canned stub.

    While the fixture is active,
    ``AudioTranscriptAutoProcessor.get_instance`` is patched so that it
    returns a processor whose ``_transcript`` ignores the audio it is given
    and always produces the fixed two-word transcript "Hello world.".
    """
    from reflector.processors.audio_transcript import AudioTranscriptProcessor
    from reflector.processors.types import AudioFile, Transcript, Word

    class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
        async def _transcript(self, data: AudioFile):
            # The preference is only read for the debug output below.
            lang = self.get_pref("audio:source_language", "en")
            print("transcripting", lang)
            print("pipeline", self.pipeline)
            print("prefs", self.pipeline.prefs)

            canned_words = [
                Word(start=0.0, end=1.0, text="Hello"),
                Word(start=1.0, end=2.0, text=" world."),
            ]
            return Transcript(text="Hello world.", words=canned_words)

    patch_target = (
        "reflector.processors.audio_transcript_auto"
        ".AudioTranscriptAutoProcessor.get_instance"
    )
    with patch(patch_target) as get_instance_mock:
        get_instance_mock.return_value = TestAudioTranscriptProcessor()
        yield
@pytest.fixture

View File

@@ -3,7 +3,12 @@ import pytest
@pytest.mark.asyncio
async def test_basic_process(
event_loop, nltk, dummy_llm, dummy_processors, ensure_casing
event_loop,
nltk,
dummy_transcript,
dummy_llm,
dummy_processors,
ensure_casing,
):
# goal is to start the server, and send rtc audio to it
# validate the events received
@@ -29,7 +34,8 @@ async def test_basic_process(
print(marks)
# validate the events
assert marks["TranscriptLinerProcessor"] == 5
assert marks["TranscriptLinerProcessor"] == 4
assert marks["TranscriptTranslatorProcessor"] == 4
assert marks["TranscriptTopicDetectorProcessor"] == 1
assert marks["TranscriptFinalLongSummaryProcessor"] == 1
assert marks["TranscriptFinalShortSummaryProcessor"] == 1

View File

@@ -7,7 +7,6 @@ import asyncio
import json
import threading
from pathlib import Path
from unittest.mock import patch
import pytest
from httpx import AsyncClient
@@ -32,41 +31,6 @@ class ThreadedUvicorn:
continue
@pytest.fixture
async def dummy_transcript():
    """Swap the automatic audio transcript processor for a canned stub.

    While the fixture is active,
    ``AudioTranscriptAutoProcessor.get_instance`` is patched so that it
    returns a processor whose ``_transcript`` always yields the transcript
    "Hello world". A French translation is attached only when the source
    and target language preferences differ and the target is "fr";
    otherwise the translation is ``None``.
    """
    from reflector.processors.audio_transcript import AudioTranscriptProcessor
    from reflector.processors.types import AudioFile, Transcript, Word

    class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
        async def _transcript(self, data: AudioFile):
            src = self.get_pref("audio:source_language", "en")
            dst = self.get_pref("audio:target_language", "en")
            print("transcripting", src, dst)
            print("pipeline", self.pipeline)
            print("prefs", self.pipeline.prefs)

            # Only the French target has a canned translation.
            wants_french = src != dst and dst == "fr"
            translation = "Bonjour le monde" if wants_french else None

            canned_words = [
                Word(start=0.0, end=1.0, text="Hello"),
                Word(start=1.0, end=2.0, text="world"),
            ]
            return Transcript(
                text="Hello world",
                translation=translation,
                words=canned_words,
            )

    patch_target = (
        "reflector.processors.audio_transcript_auto"
        ".AudioTranscriptAutoProcessor.get_instance"
    )
    with patch(patch_target) as get_instance_mock:
        get_instance_mock.return_value = TestAudioTranscriptProcessor()
        yield
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(
tmpdir, dummy_llm, dummy_transcript, dummy_processors, ensure_casing
@@ -165,14 +129,14 @@ async def test_transcript_rtc_and_websocket(
# check events
assert "TRANSCRIPT" in eventnames
ev = events[eventnames.index("TRANSCRIPT")]
assert ev["data"]["text"] == "Hello world"
assert ev["data"]["translation"] is None
assert ev["data"]["text"].startswith("Hello world.")
assert ev["data"]["translation"] == "Bonjour le monde"
assert "TOPIC" in eventnames
ev = events[eventnames.index("TOPIC")]
assert ev["data"]["id"]
assert ev["data"]["summary"] == "LLM SUMMARY"
assert ev["data"]["transcript"].startswith("Hello world")
assert ev["data"]["transcript"].startswith("Hello world.")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_LONG_SUMMARY" in eventnames
@@ -310,14 +274,14 @@ async def test_transcript_rtc_and_websocket_and_fr(
# check events
assert "TRANSCRIPT" in eventnames
ev = events[eventnames.index("TRANSCRIPT")]
assert ev["data"]["text"] == "Hello world"
assert ev["data"]["text"].startswith("Hello world.")
assert ev["data"]["translation"] == "Bonjour le monde"
assert "TOPIC" in eventnames
ev = events[eventnames.index("TOPIC")]
assert ev["data"]["id"]
assert ev["data"]["summary"] == "LLM SUMMARY"
assert ev["data"]["transcript"].startswith("Hello world")
assert ev["data"]["transcript"].startswith("Hello world.")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_LONG_SUMMARY" in eventnames