diff --git a/server/migrations/versions/b3df9681cae9_add_source_and_target_language.py b/server/migrations/versions/b3df9681cae9_add_source_and_target_language.py
new file mode 100644
index 00000000..ed8a85b2
--- /dev/null
+++ b/server/migrations/versions/b3df9681cae9_add_source_and_target_language.py
@@ -0,0 +1,34 @@
+"""add source and target language
+
+Revision ID: b3df9681cae9
+Revises: 543ed284d69a
+Create Date: 2023-08-29 10:55:37.690469
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = 'b3df9681cae9'
+down_revision: Union[str, None] = '543ed284d69a'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    # server_default backfills rows that predate this revision: the Transcript
+    # model types both fields as plain ``str``, so NULLs would fail validation.
+    op.add_column('transcript', sa.Column('source_language', sa.String(), nullable=True, server_default='en'))
+    op.add_column('transcript', sa.Column('target_language', sa.String(), nullable=True, server_default='en'))
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column('transcript', 'target_language')
+    op.drop_column('transcript', 'source_language')
+    # ### end Alembic commands ###
diff --git a/server/reflector/db/__init__.py b/server/reflector/db/__init__.py
index 3864b13a..2ac68029 100644
--- a/server/reflector/db/__init__.py
+++ b/server/reflector/db/__init__.py
@@ -1,9 +1,8 @@
 import databases
 import sqlalchemy
-from reflector.events import subscribers_startup, subscribers_shutdown
+from reflector.events import subscribers_shutdown, subscribers_startup
 from reflector.settings import settings
 
-
 database = databases.Database(settings.DATABASE_URL)
 metadata = sqlalchemy.MetaData()
 
@@ -20,6 +19,8 @@ transcripts = sqlalchemy.Table(
     sqlalchemy.Column("summary", sqlalchemy.String, nullable=True),
     sqlalchemy.Column("topics", sqlalchemy.JSON),
     sqlalchemy.Column("events", sqlalchemy.JSON),
+    sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
+    sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
     # with user attached, optional
     sqlalchemy.Column("user_id", sqlalchemy.String),
 )
diff --git a/server/reflector/processors/audio_transcript_auto.py b/server/reflector/processors/audio_transcript_auto.py
index fdae7663..3bc10102 100644
--- a/server/reflector/processors/audio_transcript_auto.py
+++ b/server/reflector/processors/audio_transcript_auto.py
@@ -1,8 +1,9 @@
-from reflector.processors.base import Processor
+import importlib
+
 from reflector.processors.audio_transcript import AudioTranscriptProcessor
+from reflector.processors.base import Pipeline, Processor
 from reflector.processors.types import AudioFile
 from reflector.settings import settings
-import importlib
 
 
 class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
@@ -35,6 +36,10 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
         self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
         super().__init__(**kwargs)
 
+    def set_pipeline(self, pipeline: Pipeline):
+        super().set_pipeline(pipeline)
+        self.processor.set_pipeline(pipeline)
+
     def connect(self, processor: Processor):
         self.processor.connect(processor)
 
diff --git a/server/reflector/processors/audio_transcript_modal.py b/server/reflector/processors/audio_transcript_modal.py
index 335d1f0f..2ecdc2ec 100644
--- a/server/reflector/processors/audio_transcript_modal.py
+++ b/server/reflector/processors/audio_transcript_modal.py
@@ -15,7 +15,6 @@ API will be a POST request to TRANSCRIPT_URL:
 from time import monotonic
 
 import httpx
-
 from reflector.processors.audio_transcript import AudioTranscriptProcessor
 from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
 from reflector.processors.types import AudioFile, Transcript, TranslationLanguages, Word
@@ -54,14 +53,10 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
                 "file": (data.name, data.fd),
             }
 
-            # TODO: Get the source / target language from the UI preferences dynamically
-            # Update code here once this is possible.
-            # i.e) extract from context/session objects
-            source_language = "en"
-
-            # TODO: target lang is set to "fr" for demo purposes
-            # Revert back once language selection is implemented
-            target_language = "fr"
+            # FIXME this should be a processor after, as each user may want
+            # different languages
+            source_language = self.get_pref("audio:source_language", "en")
+            target_language = self.get_pref("audio:target_language", "en")
             languages = TranslationLanguages()
 
             # Only way to set the target should be the UI element like dropdown.
@@ -87,8 +82,8 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
             result = response.json()
 
             # Sanity check for translation status in the result
-            translation = ""
-            if target_language in result["text"]:
+            translation = None
+            if source_language != target_language and target_language in result["text"]:
                 translation = result["text"][target_language]
             text = result["text"][source_language]
 
diff --git a/server/reflector/processors/base.py b/server/reflector/processors/base.py
index 4a7f2bc2..35e836bc 100644
--- a/server/reflector/processors/base.py
+++ b/server/reflector/processors/base.py
@@ -1,7 +1,9 @@
-from reflector.logger import logger
-from uuid import uuid4
-from concurrent.futures import ThreadPoolExecutor
 import asyncio
+from concurrent.futures import ThreadPoolExecutor
+from typing import Any
+from uuid import uuid4
+
+from reflector.logger import logger
 
 
 class Processor:
@@ -17,9 +19,11 @@ class Processor:
         self.uid = uuid4().hex
         self.flushed = False
         self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
+        self.pipeline = None
 
     def set_pipeline(self, pipeline: "Pipeline"):
         # if pipeline is used, pipeline logger will be used instead
+        self.pipeline = pipeline
         self.logger = pipeline.logger.bind(processor=self.__class__.__name__)
 
     def connect(self, processor: "Processor"):
@@ -54,6 +58,14 @@ class Processor:
         """
         self._callbacks.remove(callback)
 
+    def get_pref(self, key: str, default: Any = None):
+        """
+        Return pref `key` from the attached pipeline, or `default` if none is attached
+        """
+        if self.pipeline is not None:
+            return self.pipeline.get_pref(key, default)
+        return default
+
     async def emit(self, data):
         for callback in self._callbacks:
             await callback(data)
@@ -191,6 +203,7 @@ class Pipeline(Processor):
         self.logger.info("Pipeline created")
 
         self.processors = processors
+        self.prefs = {}
 
         for processor in processors:
             processor.set_pipeline(self)
@@ -220,3 +233,17 @@ class Pipeline(Processor):
         for processor in self.processors:
             processor.describe(level + 1)
         logger.info("")
+
+    def set_pref(self, key: str, value: Any):
+        """
+        Set a preference for this pipeline
+        """
+        self.prefs[key] = value
+
+    def get_pref(self, key: str, default=None):
+        """
+        Get a preference for this pipeline
+        """
+        if key not in self.prefs:
+            self.logger.warning(f"Pref {key} not found, using default")
+        return self.prefs.get(key, default)
diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py
index 537de415..8aab2a0d 100644
--- a/server/reflector/processors/types.py
+++ b/server/reflector/processors/types.py
@@ -47,7 +47,7 @@ class Word(BaseModel):
 
 class Transcript(BaseModel):
     text: str = ""
-    translation: str = ""
+    translation: str | None = None
     words: list[Word] = None
 
     @property
diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py
index f909cc9c..90f44434 100644
--- a/server/reflector/views/rtc_offer.py
+++ b/server/reflector/views/rtc_offer.py
@@ -7,7 +7,6 @@ import av
 from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
 from fastapi import APIRouter, Request
 from pydantic import BaseModel
-
 from reflector.events import subscribers_shutdown
 from reflector.logger import logger
 from reflector.processors import (
@@ -81,6 +80,8 @@ async def rtc_offer_base(
     event_callback=None,
     event_callback_args=None,
     audio_filename: Path | None = None,
+    source_language: str = "en",
+    target_language: str = "en",
 ):
     # build an rtc session
     offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
@@ -110,7 +111,6 @@ async def rtc_offer_base(
             result = {
                 "cmd": "SHOW_TRANSCRIPTION",
                 "text": transcript.text,
-                "translation": transcript.translation,
             }
             ctx.data_channel.send(dumps(result))
 
@@ -179,6 +179,8 @@ async def rtc_offer_base(
         TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
     ]
     ctx.pipeline = Pipeline(*processors)
+    ctx.pipeline.set_pref("audio:source_language", source_language)
+    ctx.pipeline.set_pref("audio:target_language", target_language)
     # FIXME: warmup is not working well yet
     # await ctx.pipeline.warmup()
 
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index b153765a..5aed7141 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -7,6 +7,7 @@ from typing import Annotated, Optional
 from uuid import uuid4
 
 import av
+import reflector.auth as auth
 from fastapi import (
     APIRouter,
     Depends,
@@ -17,13 +18,11 @@ from fastapi import (
 )
 from fastapi_pagination import Page, paginate
 from pydantic import BaseModel, Field
-from starlette.concurrency import run_in_threadpool
-
-import reflector.auth as auth
 from reflector.db import database, transcripts
 from reflector.logger import logger
 from reflector.settings import settings
 from reflector.utils.audio_waveform import get_audio_waveform
+from starlette.concurrency import run_in_threadpool
 
 from ._range_requests_response import range_requests_response
 from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
@@ -50,7 +49,7 @@ class AudioWaveform(BaseModel):
 
 class TranscriptText(BaseModel):
     text: str
-    translation: str
+    translation: str | None
 
 
 class TranscriptTopic(BaseModel):
@@ -81,6 +80,8 @@ class Transcript(BaseModel):
     summary: str | None = None
     topics: list[TranscriptTopic] = []
     events: list[TranscriptEvent] = []
+    source_language: str = "en"
+    target_language: str = "en"
 
     def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
         ev = TranscriptEvent(event=event, data=data.model_dump())
@@ -186,8 +187,19 @@ class TranscriptController:
             return None
         return Transcript(**result)
 
-    async def add(self, name: str, user_id: str | None = None):
-        transcript = Transcript(name=name, user_id=user_id)
+    async def add(
+        self,
+        name: str,
+        source_language: str = "en",
+        target_language: str = "en",
+        user_id: str | None = None,
+    ):
+        transcript = Transcript(
+            name=name,
+            source_language=source_language,
+            target_language=target_language,
+            user_id=user_id,
+        )
         query = transcripts.insert().values(**transcript.model_dump())
         await database.execute(query)
         return transcript
@@ -231,10 +243,14 @@ class GetTranscript(BaseModel):
     duration: int
     summary: str | None
     created_at: datetime
+    source_language: str
+    target_language: str
 
 
 class CreateTranscript(BaseModel):
     name: str
+    source_language: str = Field("en")
+    target_language: str = Field("en")
 
 
 class UpdateTranscript(BaseModel):
@@ -243,10 +259,6 @@ class UpdateTranscript(BaseModel):
     summary: Optional[str] = Field(None)
 
 
-class TranscriptEntryCreate(BaseModel):
-    name: str
-
-
 class DeletionStatus(BaseModel):
     status: str
 
@@ -268,7 +280,12 @@ async def transcripts_create(
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
 ):
     user_id = user["sub"] if user else None
-    return await transcripts_controller.add(info.name, user_id=user_id)
+    return await transcripts_controller.add(
+        info.name,
+        source_language=info.source_language,
+        target_language=info.target_language,
+        user_id=user_id,
+    )
 
 
 # ==============================================================
@@ -573,4 +590,6 @@ async def transcript_record_webrtc(
         event_callback=handle_rtc_event,
         event_callback_args=transcript_id,
         audio_filename=transcript.audio_filename,
+        source_language=transcript.source_language,
+        target_language=transcript.target_language,
     )
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index 8237d4ab..c6adf320 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -39,8 +39,20 @@ async def dummy_transcript():
 
     class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
         async def _transcript(self, data: AudioFile):
+            source_language = self.get_pref("audio:source_language", "en")
+            target_language = self.get_pref("audio:target_language", "en")
+            print("transcripting", source_language, target_language)
+            print("pipeline", self.pipeline)
+            print("prefs", self.pipeline.prefs)
+
+            translation = None
+            if source_language != target_language:
+                if target_language == "fr":
+                    translation = "Bonjour le monde"
+
             return Transcript(
                 text="Hello world",
+                translation=translation,
                 words=[
                     Word(start=0.0, end=1.0, text="Hello"),
                     Word(start=1.0, end=2.0, text="world"),
@@ -165,6 +177,147 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
     assert "TRANSCRIPT" in eventnames
     ev = events[eventnames.index("TRANSCRIPT")]
     assert ev["data"]["text"] == "Hello world"
+    assert ev["data"]["translation"] is None
+
+    assert "TOPIC" in eventnames
+    ev = events[eventnames.index("TOPIC")]
+    assert ev["data"]["id"]
+    assert ev["data"]["summary"] == "LLM SUMMARY"
+    assert ev["data"]["transcript"].startswith("Hello world")
+    assert ev["data"]["timestamp"] == 0.0
+
+    assert "FINAL_SUMMARY" in eventnames
+    ev = events[eventnames.index("FINAL_SUMMARY")]
+    assert ev["data"]["summary"] == "LLM SUMMARY"
+
+    # check status order
+    statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
+    assert statuses == ["recording", "processing", "ended"]
+
+    # ensure the last event received is ended
+    assert events[-1]["event"] == "STATUS"
+    assert events[-1]["data"]["value"] == "ended"
+
+    # check that transcript status in model is updated
+    resp = await ac.get(f"/transcripts/{tid}")
+    assert resp.status_code == 200
+    assert resp.json()["status"] == "ended"
+
+    # check that audio is available
+    resp = await ac.get(f"/transcripts/{tid}/audio")
+    assert resp.status_code == 200
+    assert resp.headers["Content-Type"] == "audio/wav"
+
+    # check that audio/mp3 is available
+    resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
+    assert resp.status_code == 200
+    assert resp.headers["Content-Type"] == "audio/mp3"
+
+    # stop server
+    server.stop()
+
+
+@pytest.mark.asyncio
+async def test_transcript_rtc_and_websocket_and_fr(tmpdir, dummy_transcript, dummy_llm):
+    # goal: start the server, exchange RTC, receive websocket events
+    # because of that, we need to start the server in a thread
+    # to be able to connect with aiortc
+    # with target french language
+
+    from reflector.settings import settings
+    from reflector.app import app
+
+    settings.DATA_DIR = Path(tmpdir)
+
+    # start server
+    host = "127.0.0.1"
+    port = 1255
+    base_url = f"http://{host}:{port}/v1"
+    config = Config(app=app, host=host, port=port)
+    server = ThreadedUvicorn(config)
+    await server.start()
+
+    # create a transcript
+    ac = AsyncClient(base_url=base_url)
+    response = await ac.post(
+        "/transcripts", json={"name": "Test RTC", "target_language": "fr"}
+    )
+    assert response.status_code == 200
+    tid = response.json()["id"]
+
+    # create a websocket connection as a task
+    events = []
+
+    async def websocket_task():
+        print("Test websocket: TASK STARTED")
+        async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
+            print("Test websocket: CONNECTED")
+            try:
+                while True:
+                    msg = await ws.receive_json()
+                    print(f"Test websocket: JSON {msg}")
+                    if msg is None:
+                        break
+                    events.append(msg)
+            except Exception as e:
+                print(f"Test websocket: EXCEPTION {e}")
+            finally:
+                ws.close()
+                print("Test websocket: DISCONNECTED")
+
+    websocket_task = asyncio.get_event_loop().create_task(websocket_task())
+
+    # create stream client
+    import argparse
+    from reflector.stream_client import StreamClient
+    from aiortc.contrib.signaling import add_signaling_arguments, create_signaling
+
+    parser = argparse.ArgumentParser()
+    add_signaling_arguments(parser)
+    args = parser.parse_args(["-s", "tcp-socket"])
+    signaling = create_signaling(args)
+
+    url = f"{base_url}/transcripts/{tid}/record/webrtc"
+    path = Path(__file__).parent / "records" / "test_short.wav"
+    client = StreamClient(signaling, url=url, play_from=path.as_posix())
+    await client.start()
+
+    timeout = 20
+    while not client.is_ended():
+        await asyncio.sleep(1)
+        timeout -= 1
+        if timeout < 0:
+            raise TimeoutError("Timeout while waiting for RTC to end")
+
+    # XXX aiortc is long to close the connection
+    # instead of waiting a long time, we just send a STOP
+    client.channel.send(json.dumps({"cmd": "STOP"}))
+
+    # wait the processing to finish
+    await asyncio.sleep(2)
+
+    await client.stop()
+
+    # wait the processing to finish
+    await asyncio.sleep(2)
+
+    # stop websocket task
+    websocket_task.cancel()
+
+    # check events
+    assert len(events) > 0
+    from pprint import pprint
+
+    pprint(events)
+
+    # get events list
+    eventnames = [e["event"] for e in events]
+
+    # check events
+    assert "TRANSCRIPT" in eventnames
+    ev = events[eventnames.index("TRANSCRIPT")]
+    assert ev["data"]["text"] == "Hello world"
+    assert ev["data"]["translation"] == "Bonjour le monde"
 
     assert "TOPIC" in eventnames
     ev = events[eventnames.index("TOPIC")]
@@ -186,19 +339,4 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
     assert events[-1]["data"]["value"] == "ended"
 
     # stop server
-    # server.stop()
-
-    # check that transcript status in model is updated
-    resp = await ac.get(f"/transcripts/{tid}")
-    assert resp.status_code == 200
-    assert resp.json()["status"] == "ended"
-
-    # check that audio is available
-    resp = await ac.get(f"/transcripts/{tid}/audio")
-    assert resp.status_code == 200
-    assert resp.headers["Content-Type"] == "audio/wav"
-
-    # check that audio/mp3 is available
-    resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
-    assert resp.status_code == 200
-    assert resp.headers["Content-Type"] == "audio/mp3"
+    server.stop()
diff --git a/server/tests/test_transcripts_translation.py b/server/tests/test_transcripts_translation.py
new file mode 100644
index 00000000..adae55e9
--- /dev/null
+++ b/server/tests/test_transcripts_translation.py
@@ -0,0 +1,63 @@
+import pytest
+from httpx import AsyncClient
+
+
+@pytest.mark.asyncio
+async def test_transcript_create_default_translation():
+    from reflector.app import app
+
+    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
+        response = await ac.post("/transcripts", json={"name": "test en"})
+        assert response.status_code == 200
+        assert response.json()["name"] == "test en"
+        assert response.json()["source_language"] == "en"
+        assert response.json()["target_language"] == "en"
+        tid = response.json()["id"]
+
+        response = await ac.get(f"/transcripts/{tid}")
+        assert response.status_code == 200
+        assert response.json()["name"] == "test en"
+        assert response.json()["source_language"] == "en"
+        assert response.json()["target_language"] == "en"
+
+
+@pytest.mark.asyncio
+async def test_transcript_create_en_fr_translation():
+    from reflector.app import app
+
+    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
+        response = await ac.post(
+            "/transcripts", json={"name": "test en/fr", "target_language": "fr"}
+        )
+        assert response.status_code == 200
+        assert response.json()["name"] == "test en/fr"
+        assert response.json()["source_language"] == "en"
+        assert response.json()["target_language"] == "fr"
+        tid = response.json()["id"]
+
+        response = await ac.get(f"/transcripts/{tid}")
+        assert response.status_code == 200
+        assert response.json()["name"] == "test en/fr"
+        assert response.json()["source_language"] == "en"
+        assert response.json()["target_language"] == "fr"
+
+
+@pytest.mark.asyncio
+async def test_transcript_create_fr_en_translation():
+    from reflector.app import app
+
+    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
+        response = await ac.post(
+            "/transcripts", json={"name": "test fr/en", "source_language": "fr"}
+        )
+        assert response.status_code == 200
+        assert response.json()["name"] == "test fr/en"
+        assert response.json()["source_language"] == "fr"
+        assert response.json()["target_language"] == "en"
+        tid = response.json()["id"]
+
+        response = await ac.get(f"/transcripts/{tid}")
+        assert response.status_code == 200
+        assert response.json()["name"] == "test fr/en"
+        assert response.json()["source_language"] == "fr"
+        assert response.json()["target_language"] == "en"