server: pass source and target language from api to pipeline

This commit is contained in:
2023-08-29 11:16:23 +02:00
parent cce8a9137a
commit 68dce235ec
10 changed files with 330 additions and 48 deletions

View File

@@ -0,0 +1,32 @@
"""add source and target language
Revision ID: b3df9681cae9
Revises: 543ed284d69a
Create Date: 2023-08-29 10:55:37.690469
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's own id; `down_revision` is the id of the
# migration it applies on top of (the "Revises" entry in the module docstring).
revision: str = 'b3df9681cae9'
down_revision: Union[str, None] = '543ed284d69a'
# No branch label and no explicit dependency are declared for this revision.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add the `source_language` and `target_language` columns to `transcript`.

    Both columns remain nullable, but carry a server-side default of "en" so
    that rows created before this migration are backfilled with a usable
    language code instead of NULL (the application models type these fields
    as plain strings defaulting to "en").
    """
    op.add_column(
        'transcript',
        sa.Column('source_language', sa.String(), nullable=True, server_default='en'),
    )
    op.add_column(
        'transcript',
        sa.Column('target_language', sa.String(), nullable=True, server_default='en'),
    )
def downgrade() -> None:
    """Remove the language columns from `transcript`, newest first."""
    # Dropped in the reverse order of their creation in upgrade().
    for column_name in ("target_language", "source_language"):
        op.drop_column("transcript", column_name)

View File

@@ -1,9 +1,8 @@
import databases
import sqlalchemy
from reflector.events import subscribers_startup, subscribers_shutdown
from reflector.events import subscribers_shutdown, subscribers_startup
from reflector.settings import settings
database = databases.Database(settings.DATABASE_URL)
metadata = sqlalchemy.MetaData()
@@ -20,6 +19,8 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Column("summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("topics", sqlalchemy.JSON),
sqlalchemy.Column("events", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
# with user attached, optional
sqlalchemy.Column("user_id", sqlalchemy.String),
)

View File

@@ -1,8 +1,9 @@
from reflector.processors.base import Processor
import importlib
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.base import Pipeline, Processor
from reflector.processors.types import AudioFile
from reflector.settings import settings
import importlib
class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
@@ -35,6 +36,10 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
super().__init__(**kwargs)
def set_pipeline(self, pipeline: Pipeline):
    """
    Attach the pipeline to this processor and to the wrapped backend processor
    """
    super().set_pipeline(pipeline)
    # the backend instance held in self.processor gets the same pipeline,
    # so calls it makes through its own pipeline reference resolve too
    backend = self.processor
    backend.set_pipeline(pipeline)
def connect(self, processor: Processor):
self.processor.connect(processor)

View File

@@ -15,7 +15,6 @@ API will be a POST request to TRANSCRIPT_URL:
from time import monotonic
import httpx
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
from reflector.processors.types import AudioFile, Transcript, TranslationLanguages, Word
@@ -54,14 +53,10 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
"file": (data.name, data.fd),
}
# TODO: Get the source / target language from the UI preferences dynamically
# Update code here once this is possible.
# i.e) extract from context/session objects
source_language = "en"
# TODO: target lang is set to "fr" for demo purposes
# Revert back once language selection is implemented
target_language = "fr"
# FIXME this should be a processor after, as each user may want
# different languages
source_language = self.get_pref("audio:source_language", "en")
target_language = self.get_pref("audio:target_language", "en")
languages = TranslationLanguages()
# Only way to set the target should be the UI element like dropdown.
@@ -87,8 +82,8 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
result = response.json()
# Sanity check for translation status in the result
translation = ""
if target_language in result["text"]:
translation = None
if source_language != target_language and target_language in result["text"]:
translation = result["text"][target_language]
text = result["text"][source_language]

View File

@@ -1,7 +1,9 @@
from reflector.logger import logger
from uuid import uuid4
from concurrent.futures import ThreadPoolExecutor
import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from uuid import uuid4
from reflector.logger import logger
class Processor:
@@ -17,9 +19,11 @@ class Processor:
self.uid = uuid4().hex
self.flushed = False
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
self.pipeline = None
def set_pipeline(self, pipeline: "Pipeline"):
# if pipeline is used, pipeline logger will be used instead
self.pipeline = pipeline
self.logger = pipeline.logger.bind(processor=self.__class__.__name__)
def connect(self, processor: "Processor"):
@@ -54,6 +58,14 @@ class Processor:
"""
self._callbacks.remove(callback)
def get_pref(self, key: str, default: Any = None):
"""
Get a preference from the pipeline prefs
"""
if self.pipeline:
return self.pipeline.get_pref(key, default)
return default
async def emit(self, data):
for callback in self._callbacks:
await callback(data)
@@ -191,6 +203,7 @@ class Pipeline(Processor):
self.logger.info("Pipeline created")
self.processors = processors
self.prefs = {}
for processor in processors:
processor.set_pipeline(self)
@@ -220,3 +233,17 @@ class Pipeline(Processor):
for processor in self.processors:
processor.describe(level + 1)
logger.info("")
def set_pref(self, key: str, value: Any):
"""
Set a preference for this pipeline
"""
self.prefs[key] = value
def get_pref(self, key: str, default=None):
"""
Get a preference for this pipeline
"""
if key not in self.prefs:
self.logger.warning(f"Pref {key} not found, using default")
return self.prefs.get(key, default)

View File

@@ -47,7 +47,7 @@ class Word(BaseModel):
class Transcript(BaseModel):
text: str = ""
translation: str = ""
translation: str | None = None
words: list[Word] = None
@property

View File

@@ -7,7 +7,6 @@ import av
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from fastapi import APIRouter, Request
from pydantic import BaseModel
from reflector.events import subscribers_shutdown
from reflector.logger import logger
from reflector.processors import (
@@ -81,6 +80,8 @@ async def rtc_offer_base(
event_callback=None,
event_callback_args=None,
audio_filename: Path | None = None,
source_language: str = "en",
target_language: str = "en",
):
# build an rtc session
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
@@ -110,7 +111,6 @@ async def rtc_offer_base(
result = {
"cmd": "SHOW_TRANSCRIPTION",
"text": transcript.text,
"translation": transcript.translation,
}
ctx.data_channel.send(dumps(result))
@@ -179,6 +179,8 @@ async def rtc_offer_base(
TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
]
ctx.pipeline = Pipeline(*processors)
ctx.pipeline.set_pref("audio:source_language", source_language)
ctx.pipeline.set_pref("audio:target_language", target_language)
# FIXME: warmup is not working well yet
# await ctx.pipeline.warmup()

View File

@@ -7,6 +7,7 @@ from typing import Annotated, Optional
from uuid import uuid4
import av
import reflector.auth as auth
from fastapi import (
APIRouter,
Depends,
@@ -17,13 +18,11 @@ from fastapi import (
)
from fastapi_pagination import Page, paginate
from pydantic import BaseModel, Field
from starlette.concurrency import run_in_threadpool
import reflector.auth as auth
from reflector.db import database, transcripts
from reflector.logger import logger
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
@@ -50,7 +49,7 @@ class AudioWaveform(BaseModel):
class TranscriptText(BaseModel):
text: str
translation: str
translation: str | None
class TranscriptTopic(BaseModel):
@@ -81,6 +80,8 @@ class Transcript(BaseModel):
summary: str | None = None
topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = []
source_language: str = "en"
target_language: str = "en"
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump())
@@ -186,8 +187,19 @@ class TranscriptController:
return None
return Transcript(**result)
async def add(self, name: str, user_id: str | None = None):
transcript = Transcript(name=name, user_id=user_id)
async def add(
    self,
    name: str,
    source_language: str = "en",
    target_language: str = "en",
    user_id: str | None = None,
):
    """
    Create a transcript with the requested languages, insert it into the
    transcripts table and return the created model
    """
    fields = {
        "name": name,
        "source_language": source_language,
        "target_language": target_language,
        "user_id": user_id,
    }
    new_transcript = Transcript(**fields)
    # the row is built from the full model dump, so model defaults are stored too
    insert_stmt = transcripts.insert().values(**new_transcript.model_dump())
    await database.execute(insert_stmt)
    return new_transcript
@@ -231,10 +243,14 @@ class GetTranscript(BaseModel):
duration: int
summary: str | None
created_at: datetime
source_language: str
target_language: str
class CreateTranscript(BaseModel):
    # Request body for creating a transcript: only `name` is required,
    # both language fields fall back to "en" when omitted.
    name: str
    source_language: str = "en"
    target_language: str = "en"
class UpdateTranscript(BaseModel):
@@ -243,10 +259,6 @@ class UpdateTranscript(BaseModel):
summary: Optional[str] = Field(None)
class TranscriptEntryCreate(BaseModel):
name: str
class DeletionStatus(BaseModel):
status: str
@@ -268,7 +280,12 @@ async def transcripts_create(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
return await transcripts_controller.add(info.name, user_id=user_id)
return await transcripts_controller.add(
info.name,
source_language=info.source_language,
target_language=info.target_language,
user_id=user_id,
)
# ==============================================================
@@ -573,4 +590,6 @@ async def transcript_record_webrtc(
event_callback=handle_rtc_event,
event_callback_args=transcript_id,
audio_filename=transcript.audio_filename,
source_language=transcript.source_language,
target_language=transcript.target_language,
)

View File

@@ -39,8 +39,20 @@ async def dummy_transcript():
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
async def _transcript(self, data: AudioFile):
source_language = self.get_pref("audio:source_language", "en")
target_language = self.get_pref("audio:target_language", "en")
print("transcripting", source_language, target_language)
print("pipeline", self.pipeline)
print("prefs", self.pipeline.prefs)
translation = None
if source_language != target_language:
if target_language == "fr":
translation = "Bonjour le monde"
return Transcript(
text="Hello world",
translation=translation,
words=[
Word(start=0.0, end=1.0, text="Hello"),
Word(start=1.0, end=2.0, text="world"),
@@ -165,6 +177,147 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
assert "TRANSCRIPT" in eventnames
ev = events[eventnames.index("TRANSCRIPT")]
assert ev["data"]["text"] == "Hello world"
assert ev["data"]["translation"] is None
assert "TOPIC" in eventnames
ev = events[eventnames.index("TOPIC")]
assert ev["data"]["id"]
assert ev["data"]["summary"] == "LLM SUMMARY"
assert ev["data"]["transcript"].startswith("Hello world")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SUMMARY")]
assert ev["data"]["summary"] == "LLM SUMMARY"
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert statuses == ["recording", "processing", "ended"]
# ensure the last event received is ended
assert events[-1]["event"] == "STATUS"
assert events[-1]["data"]["value"] == "ended"
# check that transcript status in model is updated
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
assert resp.json()["status"] == "ended"
# check that audio is available
resp = await ac.get(f"/transcripts/{tid}/audio")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/wav"
# check that audio/mp3 is available
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/mp3"
# stop server
server.stop()
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket_and_fr(tmpdir, dummy_transcript, dummy_llm):
# goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread
# to be able to connect with aiortc
# with target french language
from reflector.settings import settings
from reflector.app import app
settings.DATA_DIR = Path(tmpdir)
# start server
host = "127.0.0.1"
port = 1255
base_url = f"http://{host}:{port}/v1"
config = Config(app=app, host=host, port=port)
server = ThreadedUvicorn(config)
await server.start()
# create a transcript
ac = AsyncClient(base_url=base_url)
response = await ac.post(
"/transcripts", json={"name": "Test RTC", "target_language": "fr"}
)
assert response.status_code == 200
tid = response.json()["id"]
# create a websocket connection as a task
events = []
async def websocket_task():
    """
    Listen on the transcript events websocket and append every JSON
    message received to `events`, until a None message or an error
    ends the loop.
    """
    print("Test websocket: TASK STARTED")
    async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
        print("Test websocket: CONNECTED")
        try:
            while True:
                msg = await ws.receive_json()
                print(f"Test websocket: JSON {msg}")
                if msg is None:
                    break
                events.append(msg)
        except Exception as e:
            # best effort: a disconnect from the server simply ends the collection
            print(f"Test websocket: EXCEPTION {e}")
        finally:
            # No explicit ws.close() here: the previous un-awaited call only
            # created a coroutine that never ran; leaving the `async with`
            # block is what closes the session.
            print("Test websocket: DISCONNECTED")
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
# create stream client
import argparse
from reflector.stream_client import StreamClient
from aiortc.contrib.signaling import add_signaling_arguments, create_signaling
parser = argparse.ArgumentParser()
add_signaling_arguments(parser)
args = parser.parse_args(["-s", "tcp-socket"])
signaling = create_signaling(args)
url = f"{base_url}/transcripts/{tid}/record/webrtc"
path = Path(__file__).parent / "records" / "test_short.wav"
client = StreamClient(signaling, url=url, play_from=path.as_posix())
await client.start()
timeout = 20
while not client.is_ended():
await asyncio.sleep(1)
timeout -= 1
if timeout < 0:
raise TimeoutError("Timeout while waiting for RTC to end")
# XXX aiortc is long to close the connection
# instead of waiting a long time, we just send a STOP
client.channel.send(json.dumps({"cmd": "STOP"}))
# wait the processing to finish
await asyncio.sleep(2)
await client.stop()
# wait the processing to finish
await asyncio.sleep(2)
# stop websocket task
websocket_task.cancel()
# check events
assert len(events) > 0
from pprint import pprint
pprint(events)
# get events list
eventnames = [e["event"] for e in events]
# check events
assert "TRANSCRIPT" in eventnames
ev = events[eventnames.index("TRANSCRIPT")]
assert ev["data"]["text"] == "Hello world"
assert ev["data"]["translation"] == "Bonjour le monde"
assert "TOPIC" in eventnames
ev = events[eventnames.index("TOPIC")]
@@ -186,19 +339,4 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
assert events[-1]["data"]["value"] == "ended"
# stop server
# server.stop()
# check that transcript status in model is updated
resp = await ac.get(f"/transcripts/{tid}")
assert resp.status_code == 200
assert resp.json()["status"] == "ended"
# check that audio is available
resp = await ac.get(f"/transcripts/{tid}/audio")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/wav"
# check that audio/mp3 is available
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
assert resp.status_code == 200
assert resp.headers["Content-Type"] == "audio/mp3"
server.stop()

View File

@@ -0,0 +1,63 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_create_default_translation():
    """A transcript created without languages reports "en" for both, on create and on fetch."""
    from reflector.app import app

    expected = {
        "name": "test en",
        "source_language": "en",
        "target_language": "en",
    }

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        created = await ac.post("/transcripts", json={"name": "test en"})
        assert created.status_code == 200
        created_body = created.json()
        for field, value in expected.items():
            assert created_body[field] == value

        tid = created_body["id"]
        fetched = await ac.get(f"/transcripts/{tid}")
        assert fetched.status_code == 200
        fetched_body = fetched.json()
        for field, value in expected.items():
            assert fetched_body[field] == value
@pytest.mark.asyncio
async def test_transcript_create_en_fr_translation():
    """A transcript created with target "fr" keeps source "en" and target "fr", on create and on fetch."""
    from reflector.app import app

    payload = {"name": "test en/fr", "target_language": "fr"}
    expected = {
        "name": "test en/fr",
        "source_language": "en",
        "target_language": "fr",
    }

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        created = await ac.post("/transcripts", json=payload)
        assert created.status_code == 200
        created_body = created.json()
        for field, value in expected.items():
            assert created_body[field] == value

        tid = created_body["id"]
        fetched = await ac.get(f"/transcripts/{tid}")
        assert fetched.status_code == 200
        fetched_body = fetched.json()
        for field, value in expected.items():
            assert fetched_body[field] == value
@pytest.mark.asyncio
async def test_transcript_create_fr_en_translation():
    """A transcript created with source "fr" keeps source "fr" and the default target "en", on create and on fetch."""
    from reflector.app import app

    payload = {"name": "test fr/en", "source_language": "fr"}
    expected = {
        "name": "test fr/en",
        "source_language": "fr",
        "target_language": "en",
    }

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        created = await ac.post("/transcripts", json=payload)
        assert created.status_code == 200
        created_body = created.json()
        for field, value in expected.items():
            assert created_body[field] == value

        tid = created_body["id"]
        fetched = await ac.get(f"/transcripts/{tid}")
        assert fetched.status_code == 200
        fetched_body = fetched.json()
        for field, value in expected.items():
            assert fetched_body[field] == value