server: pass source and target language from api to pipeline

This commit is contained in:
2023-08-29 11:16:23 +02:00
parent cce8a9137a
commit 68dce235ec
10 changed files with 330 additions and 48 deletions

View File

@@ -0,0 +1,32 @@
"""add source and target language
Revision ID: b3df9681cae9
Revises: 543ed284d69a
Create Date: 2023-08-29 10:55:37.690469
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'b3df9681cae9'
# Revision this migration is applied on top of (the "Revises" entry in the
# module docstring).
down_revision: Union[str, None] = '543ed284d69a'
# No branch labels or cross-revision dependencies are declared here.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add nullable ``source_language`` / ``target_language`` columns to ``transcript``."""
    # NOTE(review): rows that exist before this migration get NULL in both
    # columns, while the API models in this commit declare the fields as plain
    # ``str`` -- confirm that reading old rows is handled.
    # Columns are added source first, then target.
    for column_name in ("source_language", "target_language"):
        op.add_column(
            "transcript",
            sa.Column(column_name, sa.String(), nullable=True),
        )
def downgrade() -> None:
    """Drop the ``target_language`` and ``source_language`` columns from ``transcript``."""
    # Reverse order of upgrade(): target first, then source.
    for column_name in ("target_language", "source_language"):
        op.drop_column("transcript", column_name)

View File

@@ -1,9 +1,8 @@
import databases import databases
import sqlalchemy import sqlalchemy
from reflector.events import subscribers_startup, subscribers_shutdown from reflector.events import subscribers_shutdown, subscribers_startup
from reflector.settings import settings from reflector.settings import settings
database = databases.Database(settings.DATABASE_URL) database = databases.Database(settings.DATABASE_URL)
metadata = sqlalchemy.MetaData() metadata = sqlalchemy.MetaData()
@@ -20,6 +19,8 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Column("summary", sqlalchemy.String, nullable=True), sqlalchemy.Column("summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("topics", sqlalchemy.JSON), sqlalchemy.Column("topics", sqlalchemy.JSON),
sqlalchemy.Column("events", sqlalchemy.JSON), sqlalchemy.Column("events", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
# with user attached, optional # with user attached, optional
sqlalchemy.Column("user_id", sqlalchemy.String), sqlalchemy.Column("user_id", sqlalchemy.String),
) )

View File

@@ -1,8 +1,9 @@
from reflector.processors.base import Processor import importlib
from reflector.processors.audio_transcript import AudioTranscriptProcessor from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.base import Pipeline, Processor
from reflector.processors.types import AudioFile from reflector.processors.types import AudioFile
from reflector.settings import settings from reflector.settings import settings
import importlib
class AudioTranscriptAutoProcessor(AudioTranscriptProcessor): class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
@@ -35,6 +36,10 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND) self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
super().__init__(**kwargs) super().__init__(**kwargs)
def set_pipeline(self, pipeline: Pipeline):
    """Attach ``pipeline`` to this wrapper and to the backend processor it wraps."""
    # Let the parent class register the pipeline on this instance first.
    super().set_pipeline(pipeline)
    # The wrapped backend must know the pipeline as well, otherwise it only
    # ever sees its own (unset) pipeline reference.
    backend = self.processor
    backend.set_pipeline(pipeline)
def connect(self, processor: Processor): def connect(self, processor: Processor):
self.processor.connect(processor) self.processor.connect(processor)

View File

@@ -15,7 +15,6 @@ API will be a POST request to TRANSCRIPT_URL:
from time import monotonic from time import monotonic
import httpx import httpx
from reflector.processors.audio_transcript import AudioTranscriptProcessor from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
from reflector.processors.types import AudioFile, Transcript, TranslationLanguages, Word from reflector.processors.types import AudioFile, Transcript, TranslationLanguages, Word
@@ -54,14 +53,10 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
"file": (data.name, data.fd), "file": (data.name, data.fd),
} }
# TODO: Get the source / target language from the UI preferences dynamically # FIXME this should be a processor after, as each user may want
# Update code here once this is possible. # different languages
# i.e) extract from context/session objects source_language = self.get_pref("audio:source_language", "en")
source_language = "en" target_language = self.get_pref("audio:target_language", "en")
# TODO: target lang is set to "fr" for demo purposes
# Revert back once language selection is implemented
target_language = "fr"
languages = TranslationLanguages() languages = TranslationLanguages()
# Only way to set the target should be the UI element like dropdown. # Only way to set the target should be the UI element like dropdown.
@@ -87,8 +82,8 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
result = response.json() result = response.json()
# Sanity check for translation status in the result # Sanity check for translation status in the result
translation = "" translation = None
if target_language in result["text"]: if source_language != target_language and target_language in result["text"]:
translation = result["text"][target_language] translation = result["text"][target_language]
text = result["text"][source_language] text = result["text"][source_language]

View File

@@ -1,7 +1,9 @@
from reflector.logger import logger
from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor
import asyncio import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from uuid import uuid4
from reflector.logger import logger
class Processor: class Processor:
@@ -17,9 +19,11 @@ class Processor:
self.uid = uuid4().hex self.uid = uuid4().hex
self.flushed = False self.flushed = False
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__) self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
self.pipeline = None
def set_pipeline(self, pipeline: "Pipeline"): def set_pipeline(self, pipeline: "Pipeline"):
# if pipeline is used, pipeline logger will be used instead # if pipeline is used, pipeline logger will be used instead
self.pipeline = pipeline
self.logger = pipeline.logger.bind(processor=self.__class__.__name__) self.logger = pipeline.logger.bind(processor=self.__class__.__name__)
def connect(self, processor: "Processor"): def connect(self, processor: "Processor"):
@@ -54,6 +58,14 @@ class Processor:
""" """
self._callbacks.remove(callback) self._callbacks.remove(callback)
def get_pref(self, key: str, default: Any = None):
    """
    Look up a preference through the pipeline this processor is attached to.

    Returns ``default`` when no pipeline is set; otherwise the lookup (and
    the handling of ``default``) is delegated to the pipeline's ``get_pref``.
    """
    pipeline = self.pipeline
    if not pipeline:
        # Not attached to a pipeline: nothing to look up.
        return default
    return pipeline.get_pref(key, default)
async def emit(self, data): async def emit(self, data):
for callback in self._callbacks: for callback in self._callbacks:
await callback(data) await callback(data)
@@ -191,6 +203,7 @@ class Pipeline(Processor):
self.logger.info("Pipeline created") self.logger.info("Pipeline created")
self.processors = processors self.processors = processors
self.prefs = {}
for processor in processors: for processor in processors:
processor.set_pipeline(self) processor.set_pipeline(self)
@@ -220,3 +233,17 @@ class Pipeline(Processor):
for processor in self.processors: for processor in self.processors:
processor.describe(level + 1) processor.describe(level + 1)
logger.info("") logger.info("")
def set_pref(self, key: str, value: Any):
    """
    Store ``value`` under ``key`` in this pipeline's preferences,
    replacing any value previously stored for that key.
    """
    self.prefs.update({key: value})
def get_pref(self, key: str, default=None):
    """
    Return the preference stored under ``key`` for this pipeline.

    When the key has never been set, a warning is logged and ``default``
    is returned instead.
    """
    try:
        return self.prefs[key]
    except KeyError:
        self.logger.warning(f"Pref {key} not found, using default")
        return default

View File

@@ -47,7 +47,7 @@ class Word(BaseModel):
class Transcript(BaseModel): class Transcript(BaseModel):
text: str = "" text: str = ""
translation: str = "" translation: str | None = None
words: list[Word] = None words: list[Word] = None
@property @property

View File

@@ -7,7 +7,6 @@ import av
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from fastapi import APIRouter, Request from fastapi import APIRouter, Request
from pydantic import BaseModel from pydantic import BaseModel
from reflector.events import subscribers_shutdown from reflector.events import subscribers_shutdown
from reflector.logger import logger from reflector.logger import logger
from reflector.processors import ( from reflector.processors import (
@@ -81,6 +80,8 @@ async def rtc_offer_base(
event_callback=None, event_callback=None,
event_callback_args=None, event_callback_args=None,
audio_filename: Path | None = None, audio_filename: Path | None = None,
source_language: str = "en",
target_language: str = "en",
): ):
# build an rtc session # build an rtc session
offer = RTCSessionDescription(sdp=params.sdp, type=params.type) offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
@@ -110,7 +111,6 @@ async def rtc_offer_base(
result = { result = {
"cmd": "SHOW_TRANSCRIPTION", "cmd": "SHOW_TRANSCRIPTION",
"text": transcript.text, "text": transcript.text,
"translation": transcript.translation,
} }
ctx.data_channel.send(dumps(result)) ctx.data_channel.send(dumps(result))
@@ -179,6 +179,8 @@ async def rtc_offer_base(
TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary), TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
] ]
ctx.pipeline = Pipeline(*processors) ctx.pipeline = Pipeline(*processors)
ctx.pipeline.set_pref("audio:source_language", source_language)
ctx.pipeline.set_pref("audio:target_language", target_language)
# FIXME: warmup is not working well yet # FIXME: warmup is not working well yet
# await ctx.pipeline.warmup() # await ctx.pipeline.warmup()

View File

@@ -7,6 +7,7 @@ from typing import Annotated, Optional
from uuid import uuid4 from uuid import uuid4
import av import av
import reflector.auth as auth
from fastapi import ( from fastapi import (
APIRouter, APIRouter,
Depends, Depends,
@@ -17,13 +18,11 @@ from fastapi import (
) )
from fastapi_pagination import Page, paginate from fastapi_pagination import Page, paginate
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from starlette.concurrency import run_in_threadpool
import reflector.auth as auth
from reflector.db import database, transcripts from reflector.db import database, transcripts
from reflector.logger import logger from reflector.logger import logger
from reflector.settings import settings from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform from reflector.utils.audio_waveform import get_audio_waveform
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response from ._range_requests_response import range_requests_response
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
@@ -50,7 +49,7 @@ class AudioWaveform(BaseModel):
class TranscriptText(BaseModel): class TranscriptText(BaseModel):
text: str text: str
translation: str translation: str | None
class TranscriptTopic(BaseModel): class TranscriptTopic(BaseModel):
@@ -81,6 +80,8 @@ class Transcript(BaseModel):
summary: str | None = None summary: str | None = None
topics: list[TranscriptTopic] = [] topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = [] events: list[TranscriptEvent] = []
source_language: str = "en"
target_language: str = "en"
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent: def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump()) ev = TranscriptEvent(event=event, data=data.model_dump())
@@ -186,8 +187,19 @@ class TranscriptController:
return None return None
return Transcript(**result) return Transcript(**result)
async def add(self, name: str, user_id: str | None = None): async def add(
transcript = Transcript(name=name, user_id=user_id) self,
name: str,
source_language: str = "en",
target_language: str = "en",
user_id: str | None = None,
):
transcript = Transcript(
name=name,
source_language=source_language,
target_language=target_language,
user_id=user_id,
)
query = transcripts.insert().values(**transcript.model_dump()) query = transcripts.insert().values(**transcript.model_dump())
await database.execute(query) await database.execute(query)
return transcript return transcript
@@ -231,10 +243,14 @@ class GetTranscript(BaseModel):
duration: int duration: int
summary: str | None summary: str | None
created_at: datetime created_at: datetime
source_language: str
target_language: str
class CreateTranscript(BaseModel): class CreateTranscript(BaseModel):
name: str name: str
source_language: str = Field("en")
target_language: str = Field("en")
class UpdateTranscript(BaseModel): class UpdateTranscript(BaseModel):
@@ -243,10 +259,6 @@ class UpdateTranscript(BaseModel):
summary: Optional[str] = Field(None) summary: Optional[str] = Field(None)
class TranscriptEntryCreate(BaseModel):
name: str
class DeletionStatus(BaseModel): class DeletionStatus(BaseModel):
status: str status: str
@@ -268,7 +280,12 @@ async def transcripts_create(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
return await transcripts_controller.add(info.name, user_id=user_id) return await transcripts_controller.add(
info.name,
source_language=info.source_language,
target_language=info.target_language,
user_id=user_id,
)
# ============================================================== # ==============================================================
@@ -573,4 +590,6 @@ async def transcript_record_webrtc(
event_callback=handle_rtc_event, event_callback=handle_rtc_event,
event_callback_args=transcript_id, event_callback_args=transcript_id,
audio_filename=transcript.audio_filename, audio_filename=transcript.audio_filename,
source_language=transcript.source_language,
target_language=transcript.target_language,
) )

View File

@@ -39,8 +39,20 @@ async def dummy_transcript():
class TestAudioTranscriptProcessor(AudioTranscriptProcessor): class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
async def _transcript(self, data: AudioFile): async def _transcript(self, data: AudioFile):
source_language = self.get_pref("audio:source_language", "en")
target_language = self.get_pref("audio:target_language", "en")
print("transcripting", source_language, target_language)
print("pipeline", self.pipeline)
print("prefs", self.pipeline.prefs)
translation = None
if source_language != target_language:
if target_language == "fr":
translation = "Bonjour le monde"
return Transcript( return Transcript(
text="Hello world", text="Hello world",
translation=translation,
words=[ words=[
Word(start=0.0, end=1.0, text="Hello"), Word(start=0.0, end=1.0, text="Hello"),
Word(start=1.0, end=2.0, text="world"), Word(start=1.0, end=2.0, text="world"),
@@ -165,6 +177,147 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
assert "TRANSCRIPT" in eventnames assert "TRANSCRIPT" in eventnames
ev = events[eventnames.index("TRANSCRIPT")] ev = events[eventnames.index("TRANSCRIPT")]
assert ev["data"]["text"] == "Hello world" assert ev["data"]["text"] == "Hello world"
assert ev["data"]["translation"] is None
assert "TOPIC" in eventnames
ev = events[eventnames.index("TOPIC")]
assert ev["data"]["id"]
assert ev["data"]["summary"] == "LLM SUMMARY"
assert ev["data"]["transcript"].startswith("Hello world")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SUMMARY")]
assert ev["data"]["summary"] == "LLM SUMMARY"
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert statuses == ["recording", "processing", "ended"]
# ensure the last event received is ended
assert events[-1]["event"] == "STATUS"
assert events[-1]["data"]["value"] == "ended"
# check that transcript status in model is updated
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
assert resp.json()["status"] == "ended"
# check that audio is available
resp = await ac.get(f"/transcripts/{tid}/audio")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/wav"
# check that audio/mp3 is available
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/mp3"
# stop server
server.stop()
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket_and_fr(tmpdir, dummy_transcript, dummy_llm):
# goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread
# to be able to connect with aiortc
# with target french language
from reflector.settings import settings
from reflector.app import app
settings.DATA_DIR = Path(tmpdir)
# start server
host = "127.0.0.1"
port = 1255
base_url = f"http://{host}:{port}/v1"
config = Config(app=app, host=host, port=port)
server = ThreadedUvicorn(config)
await server.start()
# create a transcript
ac = AsyncClient(base_url=base_url)
response = await ac.post(
"/transcripts", json={"name": "Test RTC", "target_language": "fr"}
)
assert response.status_code == 200
tid = response.json()["id"]
# create a websocket connection as a task
events = []
async def websocket_task():
    """
    Collect every JSON message received on the transcript events websocket.

    Messages are appended to the enclosing ``events`` list until a ``None``
    message is received, receiving raises, or the task is cancelled by the
    test. Errors while receiving are printed rather than raised.
    """
    print("Test websocket: TASK STARTED")
    async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
        print("Test websocket: CONNECTED")
        try:
            while True:
                msg = await ws.receive_json()
                print(f"Test websocket: JSON {msg}")
                if msg is None:
                    break
                events.append(msg)
        except Exception as e:
            # Best effort: a receive failure ends the collection loop but
            # must not fail the test by itself.
            print(f"Test websocket: EXCEPTION {e}")
        finally:
            # No explicit ws.close() here: leaving the `async with` block
            # closes the session. A bare `ws.close()` without `await` would
            # only create a never-awaited coroutine and close nothing
            # (assuming httpx-ws' async session -- TODO confirm).
            print("Test websocket: DISCONNECTED")
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
# create stream client
import argparse
from reflector.stream_client import StreamClient
from aiortc.contrib.signaling import add_signaling_arguments, create_signaling
parser = argparse.ArgumentParser()
add_signaling_arguments(parser)
args = parser.parse_args(["-s", "tcp-socket"])
signaling = create_signaling(args)
url = f"{base_url}/transcripts/{tid}/record/webrtc"
path = Path(__file__).parent / "records" / "test_short.wav"
client = StreamClient(signaling, url=url, play_from=path.as_posix())
await client.start()
timeout = 20
while not client.is_ended():
await asyncio.sleep(1)
timeout -= 1
if timeout < 0:
raise TimeoutError("Timeout while waiting for RTC to end")
# XXX aiortc is long to close the connection
# instead of waiting a long time, we just send a STOP
client.channel.send(json.dumps({"cmd": "STOP"}))
# wait the processing to finish
await asyncio.sleep(2)
await client.stop()
# wait the processing to finish
await asyncio.sleep(2)
# stop websocket task
websocket_task.cancel()
# check events
assert len(events) > 0
from pprint import pprint
pprint(events)
# get events list
eventnames = [e["event"] for e in events]
# check events
assert "TRANSCRIPT" in eventnames
ev = events[eventnames.index("TRANSCRIPT")]
assert ev["data"]["text"] == "Hello world"
assert ev["data"]["translation"] == "Bonjour le monde"
assert "TOPIC" in eventnames assert "TOPIC" in eventnames
ev = events[eventnames.index("TOPIC")] ev = events[eventnames.index("TOPIC")]
@@ -186,19 +339,4 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
assert events[-1]["data"]["value"] == "ended" assert events[-1]["data"]["value"] == "ended"
# stop server # stop server
# server.stop() server.stop()
# check that transcript status in model is updated
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
assert resp.json()["status"] == "ended"
# check that audio is available
resp = await ac.get(f"/transcripts/{tid}/audio")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/wav"
# check that audio/mp3 is available
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/mp3"

View File

@@ -0,0 +1,63 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_create_default_translation():
    """Creating a transcript without languages yields "en"/"en", on create and on read."""
    from reflector.app import app

    expected = {
        "name": "test en",
        "source_language": "en",
        "target_language": "en",
    }

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        created = await ac.post("/transcripts", json={"name": "test en"})
        assert created.status_code == 200
        created_body = created.json()
        for field, value in expected.items():
            assert created_body[field] == value

        # The stored transcript must expose the same values when read back.
        tid = created_body["id"]
        fetched = await ac.get(f"/transcripts/{tid}")
        assert fetched.status_code == 200
        fetched_body = fetched.json()
        for field, value in expected.items():
            assert fetched_body[field] == value
@pytest.mark.asyncio
async def test_transcript_create_en_fr_translation():
    """Setting only ``target_language`` to "fr" keeps the source at "en"."""
    from reflector.app import app

    payload = {"name": "test en/fr", "target_language": "fr"}
    expected = {
        "name": "test en/fr",
        "source_language": "en",
        "target_language": "fr",
    }

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        created = await ac.post("/transcripts", json=payload)
        assert created.status_code == 200
        created_body = created.json()
        for field, value in expected.items():
            assert created_body[field] == value

        # The stored transcript must expose the same values when read back.
        tid = created_body["id"]
        fetched = await ac.get(f"/transcripts/{tid}")
        assert fetched.status_code == 200
        fetched_body = fetched.json()
        for field, value in expected.items():
            assert fetched_body[field] == value
@pytest.mark.asyncio
async def test_transcript_create_fr_en_translation():
    """Setting only ``source_language`` to "fr" keeps the target at "en"."""
    from reflector.app import app

    payload = {"name": "test fr/en", "source_language": "fr"}
    expected = {
        "name": "test fr/en",
        "source_language": "fr",
        "target_language": "en",
    }

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        created = await ac.post("/transcripts", json=payload)
        assert created.status_code == 200
        created_body = created.json()
        for field, value in expected.items():
            assert created_body[field] == value

        # The stored transcript must expose the same values when read back.
        tid = created_body["id"]
        fetched = await ac.get(f"/transcripts/{tid}")
        assert fetched.status_code == 200
        fetched_body = fetched.json()
        for field, value in expected.items():
            assert fetched_body[field] == value