Translation enhancements (#247)

This commit is contained in:
projects-g
2023-09-26 19:49:54 +05:30
committed by GitHub
parent 4dbec9b154
commit 6a43297309
11 changed files with 303 additions and 126 deletions

View File

@@ -7,7 +7,6 @@ import asyncio
import json
import threading
from pathlib import Path
from unittest.mock import patch
import pytest
from httpx import AsyncClient
@@ -32,41 +31,6 @@ class ThreadedUvicorn:
continue
@pytest.fixture
async def dummy_transcript():
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.types import AudioFile, Transcript, Word
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
async def _transcript(self, data: AudioFile):
source_language = self.get_pref("audio:source_language", "en")
target_language = self.get_pref("audio:target_language", "en")
print("transcripting", source_language, target_language)
print("pipeline", self.pipeline)
print("prefs", self.pipeline.prefs)
translation = None
if source_language != target_language:
if target_language == "fr":
translation = "Bonjour le monde"
return Transcript(
text="Hello world",
translation=translation,
words=[
Word(start=0.0, end=1.0, text="Hello"),
Word(start=1.0, end=2.0, text="world"),
],
)
with patch(
"reflector.processors.audio_transcript_auto"
".AudioTranscriptAutoProcessor.get_instance"
) as mock_audio:
mock_audio.return_value = TestAudioTranscriptProcessor()
yield
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(
tmpdir, dummy_llm, dummy_transcript, dummy_processors, ensure_casing
@@ -165,14 +129,14 @@ async def test_transcript_rtc_and_websocket(
# check events
assert "TRANSCRIPT" in eventnames
ev = events[eventnames.index("TRANSCRIPT")]
assert ev["data"]["text"] == "Hello world"
assert ev["data"]["translation"] is None
assert ev["data"]["text"].startswith("Hello world.")
assert ev["data"]["translation"] == "Bonjour le monde"
assert "TOPIC" in eventnames
ev = events[eventnames.index("TOPIC")]
assert ev["data"]["id"]
assert ev["data"]["summary"] == "LLM SUMMARY"
assert ev["data"]["transcript"].startswith("Hello world")
assert ev["data"]["transcript"].startswith("Hello world.")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_LONG_SUMMARY" in eventnames
@@ -310,14 +274,14 @@ async def test_transcript_rtc_and_websocket_and_fr(
# check events
assert "TRANSCRIPT" in eventnames
ev = events[eventnames.index("TRANSCRIPT")]
assert ev["data"]["text"] == "Hello world"
assert ev["data"]["text"].startswith("Hello world.")
assert ev["data"]["translation"] == "Bonjour le monde"
assert "TOPIC" in eventnames
ev = events[eventnames.index("TOPIC")]
assert ev["data"]["id"]
assert ev["data"]["summary"] == "LLM SUMMARY"
assert ev["data"]["transcript"].startswith("Hello world")
assert ev["data"]["transcript"].startswith("Hello world.")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_LONG_SUMMARY" in eventnames