mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: pass source and target language from api to pipeline
This commit is contained in:
@@ -39,8 +39,20 @@ async def dummy_transcript():
|
||||
|
||||
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
|
||||
async def _transcript(self, data: AudioFile):
|
||||
source_language = self.get_pref("audio:source_language", "en")
|
||||
target_language = self.get_pref("audio:target_language", "en")
|
||||
print("transcripting", source_language, target_language)
|
||||
print("pipeline", self.pipeline)
|
||||
print("prefs", self.pipeline.prefs)
|
||||
|
||||
translation = None
|
||||
if source_language != target_language:
|
||||
if target_language == "fr":
|
||||
translation = "Bonjour le monde"
|
||||
|
||||
return Transcript(
|
||||
text="Hello world",
|
||||
translation=translation,
|
||||
words=[
|
||||
Word(start=0.0, end=1.0, text="Hello"),
|
||||
Word(start=1.0, end=2.0, text="world"),
|
||||
@@ -165,6 +177,147 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
|
||||
assert "TRANSCRIPT" in eventnames
|
||||
ev = events[eventnames.index("TRANSCRIPT")]
|
||||
assert ev["data"]["text"] == "Hello world"
|
||||
assert ev["data"]["translation"] is None
|
||||
|
||||
assert "TOPIC" in eventnames
|
||||
ev = events[eventnames.index("TOPIC")]
|
||||
assert ev["data"]["id"]
|
||||
assert ev["data"]["summary"] == "LLM SUMMARY"
|
||||
assert ev["data"]["transcript"].startswith("Hello world")
|
||||
assert ev["data"]["timestamp"] == 0.0
|
||||
|
||||
assert "FINAL_SUMMARY" in eventnames
|
||||
ev = events[eventnames.index("FINAL_SUMMARY")]
|
||||
assert ev["data"]["summary"] == "LLM SUMMARY"
|
||||
|
||||
# check status order
|
||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||
assert statuses == ["recording", "processing", "ended"]
|
||||
|
||||
# ensure the last event received is ended
|
||||
assert events[-1]["event"] == "STATUS"
|
||||
assert events[-1]["data"]["value"] == "ended"
|
||||
|
||||
# check that transcript status in model is updated
|
||||
resp = await ac.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "ended"
|
||||
|
||||
# check that audio is available
|
||||
resp = await ac.get(f"/transcripts/{tid}/audio")
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["Content-Type"] == "audio/wav"
|
||||
|
||||
# check that audio/mp3 is available
|
||||
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["Content-Type"] == "audio/mp3"
|
||||
|
||||
# stop server
|
||||
server.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_rtc_and_websocket_and_fr(tmpdir, dummy_transcript, dummy_llm):
|
||||
# goal: start the server, exchange RTC, receive websocket events
|
||||
# because of that, we need to start the server in a thread
|
||||
# to be able to connect with aiortc
|
||||
# with target french language
|
||||
|
||||
from reflector.settings import settings
|
||||
from reflector.app import app
|
||||
|
||||
settings.DATA_DIR = Path(tmpdir)
|
||||
|
||||
# start server
|
||||
host = "127.0.0.1"
|
||||
port = 1255
|
||||
base_url = f"http://{host}:{port}/v1"
|
||||
config = Config(app=app, host=host, port=port)
|
||||
server = ThreadedUvicorn(config)
|
||||
await server.start()
|
||||
|
||||
# create a transcript
|
||||
ac = AsyncClient(base_url=base_url)
|
||||
response = await ac.post(
|
||||
"/transcripts", json={"name": "Test RTC", "target_language": "fr"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
tid = response.json()["id"]
|
||||
|
||||
# create a websocket connection as a task
|
||||
events = []
|
||||
|
||||
async def websocket_task():
|
||||
print("Test websocket: TASK STARTED")
|
||||
async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
|
||||
print("Test websocket: CONNECTED")
|
||||
try:
|
||||
while True:
|
||||
msg = await ws.receive_json()
|
||||
print(f"Test websocket: JSON {msg}")
|
||||
if msg is None:
|
||||
break
|
||||
events.append(msg)
|
||||
except Exception as e:
|
||||
print(f"Test websocket: EXCEPTION {e}")
|
||||
finally:
|
||||
ws.close()
|
||||
print("Test websocket: DISCONNECTED")
|
||||
|
||||
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
|
||||
|
||||
# create stream client
|
||||
import argparse
|
||||
from reflector.stream_client import StreamClient
|
||||
from aiortc.contrib.signaling import add_signaling_arguments, create_signaling
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
add_signaling_arguments(parser)
|
||||
args = parser.parse_args(["-s", "tcp-socket"])
|
||||
signaling = create_signaling(args)
|
||||
|
||||
url = f"{base_url}/transcripts/{tid}/record/webrtc"
|
||||
path = Path(__file__).parent / "records" / "test_short.wav"
|
||||
client = StreamClient(signaling, url=url, play_from=path.as_posix())
|
||||
await client.start()
|
||||
|
||||
timeout = 20
|
||||
while not client.is_ended():
|
||||
await asyncio.sleep(1)
|
||||
timeout -= 1
|
||||
if timeout < 0:
|
||||
raise TimeoutError("Timeout while waiting for RTC to end")
|
||||
|
||||
# XXX aiortc is long to close the connection
|
||||
# instead of waiting a long time, we just send a STOP
|
||||
client.channel.send(json.dumps({"cmd": "STOP"}))
|
||||
|
||||
# wait the processing to finish
|
||||
await asyncio.sleep(2)
|
||||
|
||||
await client.stop()
|
||||
|
||||
# wait the processing to finish
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# stop websocket task
|
||||
websocket_task.cancel()
|
||||
|
||||
# check events
|
||||
assert len(events) > 0
|
||||
from pprint import pprint
|
||||
|
||||
pprint(events)
|
||||
|
||||
# get events list
|
||||
eventnames = [e["event"] for e in events]
|
||||
|
||||
# check events
|
||||
assert "TRANSCRIPT" in eventnames
|
||||
ev = events[eventnames.index("TRANSCRIPT")]
|
||||
assert ev["data"]["text"] == "Hello world"
|
||||
assert ev["data"]["translation"] == "Bonjour le monde"
|
||||
|
||||
assert "TOPIC" in eventnames
|
||||
ev = events[eventnames.index("TOPIC")]
|
||||
@@ -186,19 +339,4 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
|
||||
assert events[-1]["data"]["value"] == "ended"
|
||||
|
||||
# stop server
|
||||
# server.stop()
|
||||
|
||||
# check that transcript status in model is updated
|
||||
resp = await ac.get(f"/transcripts/{tid}")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["status"] == "ended"
|
||||
|
||||
# check that audio is available
|
||||
resp = await ac.get(f"/transcripts/{tid}/audio")
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["Content-Type"] == "audio/wav"
|
||||
|
||||
# check that audio/mp3 is available
|
||||
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers["Content-Type"] == "audio/mp3"
|
||||
server.stop()
|
||||
|
||||
63
server/tests/test_transcripts_translation.py
Normal file
63
server/tests/test_transcripts_translation.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_create_default_translation():
|
||||
from reflector.app import app
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
response = await ac.post("/transcripts", json={"name": "test en"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "test en"
|
||||
assert response.json()["source_language"] == "en"
|
||||
assert response.json()["target_language"] == "en"
|
||||
tid = response.json()["id"]
|
||||
|
||||
response = await ac.get(f"/transcripts/{tid}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "test en"
|
||||
assert response.json()["source_language"] == "en"
|
||||
assert response.json()["target_language"] == "en"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_create_en_fr_translation():
|
||||
from reflector.app import app
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
response = await ac.post(
|
||||
"/transcripts", json={"name": "test en/fr", "target_language": "fr"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "test en/fr"
|
||||
assert response.json()["source_language"] == "en"
|
||||
assert response.json()["target_language"] == "fr"
|
||||
tid = response.json()["id"]
|
||||
|
||||
response = await ac.get(f"/transcripts/{tid}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "test en/fr"
|
||||
assert response.json()["source_language"] == "en"
|
||||
assert response.json()["target_language"] == "fr"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_create_fr_en_translation():
|
||||
from reflector.app import app
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
response = await ac.post(
|
||||
"/transcripts", json={"name": "test fr/en", "source_language": "fr"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "test fr/en"
|
||||
assert response.json()["source_language"] == "fr"
|
||||
assert response.json()["target_language"] == "en"
|
||||
tid = response.json()["id"]
|
||||
|
||||
response = await ac.get(f"/transcripts/{tid}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["name"] == "test fr/en"
|
||||
assert response.json()["source_language"] == "fr"
|
||||
assert response.json()["target_language"] == "en"
|
||||
Reference in New Issue
Block a user