mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: pass source and target language from api to pipeline
This commit is contained in:
@@ -0,0 +1,32 @@
|
|||||||
|
"""add source and target language
|
||||||
|
|
||||||
|
Revision ID: b3df9681cae9
|
||||||
|
Revises: 543ed284d69a
|
||||||
|
Create Date: 2023-08-29 10:55:37.690469
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'b3df9681cae9'
|
||||||
|
down_revision: Union[str, None] = '543ed284d69a'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Add the ``source_language`` and ``target_language`` columns to ``transcript``.

    Both columns stay nullable (matching the table definition used by the
    application) but carry a server-side default of ``"en"``, so rows that
    already exist when this revision runs are backfilled instead of being
    left NULL.
    """
    # NOTE(review): the API models declare these fields as plain ``str``
    # defaulting to "en"; a NULL read back from a pre-existing row would
    # presumably fail validation there, hence the backfilling default.
    op.add_column(
        "transcript",
        sa.Column("source_language", sa.String(), nullable=True, server_default="en"),
    )
    op.add_column(
        "transcript",
        sa.Column("target_language", sa.String(), nullable=True, server_default="en"),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Drop the language columns from ``transcript``, newest first."""
    # Removed in the reverse order of their creation in upgrade().
    for column_name in ("target_language", "source_language"):
        op.drop_column("transcript", column_name)
|
||||||
@@ -1,9 +1,8 @@
|
|||||||
import databases
|
import databases
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from reflector.events import subscribers_startup, subscribers_shutdown
|
from reflector.events import subscribers_shutdown, subscribers_startup
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
|
||||||
database = databases.Database(settings.DATABASE_URL)
|
database = databases.Database(settings.DATABASE_URL)
|
||||||
metadata = sqlalchemy.MetaData()
|
metadata = sqlalchemy.MetaData()
|
||||||
|
|
||||||
@@ -20,6 +19,8 @@ transcripts = sqlalchemy.Table(
|
|||||||
sqlalchemy.Column("summary", sqlalchemy.String, nullable=True),
|
sqlalchemy.Column("summary", sqlalchemy.String, nullable=True),
|
||||||
sqlalchemy.Column("topics", sqlalchemy.JSON),
|
sqlalchemy.Column("topics", sqlalchemy.JSON),
|
||||||
sqlalchemy.Column("events", sqlalchemy.JSON),
|
sqlalchemy.Column("events", sqlalchemy.JSON),
|
||||||
|
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
|
||||||
|
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
|
||||||
# with user attached, optional
|
# with user attached, optional
|
||||||
sqlalchemy.Column("user_id", sqlalchemy.String),
|
sqlalchemy.Column("user_id", sqlalchemy.String),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
from reflector.processors.base import Processor
|
import importlib
|
||||||
|
|
||||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||||
|
from reflector.processors.base import Pipeline, Processor
|
||||||
from reflector.processors.types import AudioFile
|
from reflector.processors.types import AudioFile
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
import importlib
|
|
||||||
|
|
||||||
|
|
||||||
class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
|
class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
|
||||||
@@ -35,6 +36,10 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
|
|||||||
self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
|
self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def set_pipeline(self, pipeline: Pipeline):
    """
    Attach the pipeline to this processor and to the wrapped backend processor
    """
    # Register on this wrapper first, then hand the very same pipeline to
    # the backend instance held in self.processor.
    super().set_pipeline(pipeline)
    backend = self.processor
    backend.set_pipeline(pipeline)
|
||||||
|
|
||||||
def connect(self, processor: Processor):
|
def connect(self, processor: Processor):
|
||||||
self.processor.connect(processor)
|
self.processor.connect(processor)
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ API will be a POST request to TRANSCRIPT_URL:
|
|||||||
from time import monotonic
|
from time import monotonic
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||||
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
|
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
|
||||||
from reflector.processors.types import AudioFile, Transcript, TranslationLanguages, Word
|
from reflector.processors.types import AudioFile, Transcript, TranslationLanguages, Word
|
||||||
@@ -54,14 +53,10 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
|||||||
"file": (data.name, data.fd),
|
"file": (data.name, data.fd),
|
||||||
}
|
}
|
||||||
|
|
||||||
# TODO: Get the source / target language from the UI preferences dynamically
|
# FIXME this should be a processor after, as each user may want
|
||||||
# Update code here once this is possible.
|
# different languages
|
||||||
# i.e) extract from context/session objects
|
source_language = self.get_pref("audio:source_language", "en")
|
||||||
source_language = "en"
|
target_language = self.get_pref("audio:target_language", "en")
|
||||||
|
|
||||||
# TODO: target lang is set to "fr" for demo purposes
|
|
||||||
# Revert back once language selection is implemented
|
|
||||||
target_language = "fr"
|
|
||||||
languages = TranslationLanguages()
|
languages = TranslationLanguages()
|
||||||
|
|
||||||
# Only way to set the target should be the UI element like dropdown.
|
# Only way to set the target should be the UI element like dropdown.
|
||||||
@@ -87,8 +82,8 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
|||||||
result = response.json()
|
result = response.json()
|
||||||
|
|
||||||
# Sanity check for translation status in the result
|
# Sanity check for translation status in the result
|
||||||
translation = ""
|
translation = None
|
||||||
if target_language in result["text"]:
|
if source_language != target_language and target_language in result["text"]:
|
||||||
translation = result["text"][target_language]
|
translation = result["text"][target_language]
|
||||||
text = result["text"][source_language]
|
text = result["text"][source_language]
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
from reflector.logger import logger
|
|
||||||
from uuid import uuid4
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from typing import Any
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from reflector.logger import logger
|
||||||
|
|
||||||
|
|
||||||
class Processor:
|
class Processor:
|
||||||
@@ -17,9 +19,11 @@ class Processor:
|
|||||||
self.uid = uuid4().hex
|
self.uid = uuid4().hex
|
||||||
self.flushed = False
|
self.flushed = False
|
||||||
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
|
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
|
||||||
|
self.pipeline = None
|
||||||
|
|
||||||
def set_pipeline(self, pipeline: "Pipeline"):
    """
    Attach this processor to a pipeline and switch to the pipeline logger
    """
    # Keep a reference to the pipeline so it can be consulted later
    self.pipeline = pipeline
    # From now on log through the pipeline logger, tagged with the name
    # of this processor's class
    bound_logger = pipeline.logger.bind(processor=self.__class__.__name__)
    self.logger = bound_logger
|
||||||
|
|
||||||
def connect(self, processor: "Processor"):
|
def connect(self, processor: "Processor"):
|
||||||
@@ -54,6 +58,14 @@ class Processor:
|
|||||||
"""
|
"""
|
||||||
self._callbacks.remove(callback)
|
self._callbacks.remove(callback)
|
||||||
|
|
||||||
|
def get_pref(self, key: str, default: Any = None):
    """
    Get a preference from the pipeline prefs

    :param key: name of the preference to look up
    :param default: value returned when no pipeline is attached; it is also
        forwarded to the pipeline lookup as its fallback
    :return: whatever the pipeline's ``get_pref`` returns, or ``default``
        when ``self.pipeline`` is None
    """
    # Compare against None explicitly: None is the "no pipeline" sentinel,
    # and an attached pipeline must be consulted even if the object happens
    # to evaluate as falsy.
    if self.pipeline is not None:
        return self.pipeline.get_pref(key, default)
    return default
|
||||||
|
|
||||||
async def emit(self, data):
|
async def emit(self, data):
|
||||||
for callback in self._callbacks:
|
for callback in self._callbacks:
|
||||||
await callback(data)
|
await callback(data)
|
||||||
@@ -191,6 +203,7 @@ class Pipeline(Processor):
|
|||||||
self.logger.info("Pipeline created")
|
self.logger.info("Pipeline created")
|
||||||
|
|
||||||
self.processors = processors
|
self.processors = processors
|
||||||
|
self.prefs = {}
|
||||||
|
|
||||||
for processor in processors:
|
for processor in processors:
|
||||||
processor.set_pipeline(self)
|
processor.set_pipeline(self)
|
||||||
@@ -220,3 +233,17 @@ class Pipeline(Processor):
|
|||||||
for processor in self.processors:
|
for processor in self.processors:
|
||||||
processor.describe(level + 1)
|
processor.describe(level + 1)
|
||||||
logger.info("")
|
logger.info("")
|
||||||
|
|
||||||
|
def set_pref(self, key: str, value: Any):
    """
    Store a preference on this pipeline, replacing any previous value
    """
    self.prefs.update({key: value})
|
||||||
|
|
||||||
|
def get_pref(self, key: str, default=None):
    """
    Look up a preference stored on this pipeline

    A warning is logged and ``default`` is returned when the key was
    never set.
    """
    try:
        return self.prefs[key]
    except KeyError:
        self.logger.warning(f"Pref {key} not found, using default")
        return default
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ class Word(BaseModel):
|
|||||||
|
|
||||||
class Transcript(BaseModel):
|
class Transcript(BaseModel):
|
||||||
text: str = ""
|
text: str = ""
|
||||||
translation: str = ""
|
translation: str | None = None
|
||||||
words: list[Word] = None
|
words: list[Word] = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import av
|
|||||||
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
|
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
|
||||||
from fastapi import APIRouter, Request
|
from fastapi import APIRouter, Request
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from reflector.events import subscribers_shutdown
|
from reflector.events import subscribers_shutdown
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.processors import (
|
from reflector.processors import (
|
||||||
@@ -81,6 +80,8 @@ async def rtc_offer_base(
|
|||||||
event_callback=None,
|
event_callback=None,
|
||||||
event_callback_args=None,
|
event_callback_args=None,
|
||||||
audio_filename: Path | None = None,
|
audio_filename: Path | None = None,
|
||||||
|
source_language: str = "en",
|
||||||
|
target_language: str = "en",
|
||||||
):
|
):
|
||||||
# build an rtc session
|
# build an rtc session
|
||||||
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
|
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
|
||||||
@@ -110,7 +111,6 @@ async def rtc_offer_base(
|
|||||||
result = {
|
result = {
|
||||||
"cmd": "SHOW_TRANSCRIPTION",
|
"cmd": "SHOW_TRANSCRIPTION",
|
||||||
"text": transcript.text,
|
"text": transcript.text,
|
||||||
"translation": transcript.translation,
|
|
||||||
}
|
}
|
||||||
ctx.data_channel.send(dumps(result))
|
ctx.data_channel.send(dumps(result))
|
||||||
|
|
||||||
@@ -179,6 +179,8 @@ async def rtc_offer_base(
|
|||||||
TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
|
TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
|
||||||
]
|
]
|
||||||
ctx.pipeline = Pipeline(*processors)
|
ctx.pipeline = Pipeline(*processors)
|
||||||
|
ctx.pipeline.set_pref("audio:source_language", source_language)
|
||||||
|
ctx.pipeline.set_pref("audio:target_language", target_language)
|
||||||
# FIXME: warmup is not working well yet
|
# FIXME: warmup is not working well yet
|
||||||
# await ctx.pipeline.warmup()
|
# await ctx.pipeline.warmup()
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import Annotated, Optional
|
|||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
import av
|
import av
|
||||||
|
import reflector.auth as auth
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
APIRouter,
|
APIRouter,
|
||||||
Depends,
|
Depends,
|
||||||
@@ -17,13 +18,11 @@ from fastapi import (
|
|||||||
)
|
)
|
||||||
from fastapi_pagination import Page, paginate
|
from fastapi_pagination import Page, paginate
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from starlette.concurrency import run_in_threadpool
|
|
||||||
|
|
||||||
import reflector.auth as auth
|
|
||||||
from reflector.db import database, transcripts
|
from reflector.db import database, transcripts
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.utils.audio_waveform import get_audio_waveform
|
from reflector.utils.audio_waveform import get_audio_waveform
|
||||||
|
from starlette.concurrency import run_in_threadpool
|
||||||
|
|
||||||
from ._range_requests_response import range_requests_response
|
from ._range_requests_response import range_requests_response
|
||||||
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
||||||
@@ -50,7 +49,7 @@ class AudioWaveform(BaseModel):
|
|||||||
|
|
||||||
class TranscriptText(BaseModel):
|
class TranscriptText(BaseModel):
|
||||||
text: str
|
text: str
|
||||||
translation: str
|
translation: str | None
|
||||||
|
|
||||||
|
|
||||||
class TranscriptTopic(BaseModel):
|
class TranscriptTopic(BaseModel):
|
||||||
@@ -81,6 +80,8 @@ class Transcript(BaseModel):
|
|||||||
summary: str | None = None
|
summary: str | None = None
|
||||||
topics: list[TranscriptTopic] = []
|
topics: list[TranscriptTopic] = []
|
||||||
events: list[TranscriptEvent] = []
|
events: list[TranscriptEvent] = []
|
||||||
|
source_language: str = "en"
|
||||||
|
target_language: str = "en"
|
||||||
|
|
||||||
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
|
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
|
||||||
ev = TranscriptEvent(event=event, data=data.model_dump())
|
ev = TranscriptEvent(event=event, data=data.model_dump())
|
||||||
@@ -186,8 +187,19 @@ class TranscriptController:
|
|||||||
return None
|
return None
|
||||||
return Transcript(**result)
|
return Transcript(**result)
|
||||||
|
|
||||||
async def add(
    self,
    name: str,
    source_language: str = "en",
    target_language: str = "en",
    user_id: str | None = None,
):
    """
    Create a transcript, insert it into the database and return it
    """
    fields = {
        "name": name,
        "source_language": source_language,
        "target_language": target_language,
        "user_id": user_id,
    }
    transcript = Transcript(**fields)
    # Persist the full model dump as a single row of the transcripts table
    insert_stmt = transcripts.insert().values(**transcript.model_dump())
    await database.execute(insert_stmt)
    return transcript
|
||||||
@@ -231,10 +243,14 @@ class GetTranscript(BaseModel):
|
|||||||
duration: int
|
duration: int
|
||||||
summary: str | None
|
summary: str | None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
|
source_language: str
|
||||||
|
target_language: str
|
||||||
|
|
||||||
|
|
||||||
class CreateTranscript(BaseModel):
    # Request body used when creating a transcript.
    # Only the name is required; both languages fall back to "en".
    name: str
    source_language: str = Field(default="en")
    target_language: str = Field(default="en")
|
||||||
|
|
||||||
|
|
||||||
class UpdateTranscript(BaseModel):
|
class UpdateTranscript(BaseModel):
|
||||||
@@ -243,10 +259,6 @@ class UpdateTranscript(BaseModel):
|
|||||||
summary: Optional[str] = Field(None)
|
summary: Optional[str] = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class TranscriptEntryCreate(BaseModel):
|
|
||||||
name: str
|
|
||||||
|
|
||||||
|
|
||||||
class DeletionStatus(BaseModel):
|
class DeletionStatus(BaseModel):
|
||||||
status: str
|
status: str
|
||||||
|
|
||||||
@@ -268,7 +280,12 @@ async def transcripts_create(
|
|||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
return await transcripts_controller.add(info.name, user_id=user_id)
|
return await transcripts_controller.add(
|
||||||
|
info.name,
|
||||||
|
source_language=info.source_language,
|
||||||
|
target_language=info.target_language,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================
|
# ==============================================================
|
||||||
@@ -573,4 +590,6 @@ async def transcript_record_webrtc(
|
|||||||
event_callback=handle_rtc_event,
|
event_callback=handle_rtc_event,
|
||||||
event_callback_args=transcript_id,
|
event_callback_args=transcript_id,
|
||||||
audio_filename=transcript.audio_filename,
|
audio_filename=transcript.audio_filename,
|
||||||
|
source_language=transcript.source_language,
|
||||||
|
target_language=transcript.target_language,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -39,8 +39,20 @@ async def dummy_transcript():
|
|||||||
|
|
||||||
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
|
class TestAudioTranscriptProcessor(AudioTranscriptProcessor):
|
||||||
async def _transcript(self, data: AudioFile):
|
async def _transcript(self, data: AudioFile):
|
||||||
|
source_language = self.get_pref("audio:source_language", "en")
|
||||||
|
target_language = self.get_pref("audio:target_language", "en")
|
||||||
|
print("transcripting", source_language, target_language)
|
||||||
|
print("pipeline", self.pipeline)
|
||||||
|
print("prefs", self.pipeline.prefs)
|
||||||
|
|
||||||
|
translation = None
|
||||||
|
if source_language != target_language:
|
||||||
|
if target_language == "fr":
|
||||||
|
translation = "Bonjour le monde"
|
||||||
|
|
||||||
return Transcript(
|
return Transcript(
|
||||||
text="Hello world",
|
text="Hello world",
|
||||||
|
translation=translation,
|
||||||
words=[
|
words=[
|
||||||
Word(start=0.0, end=1.0, text="Hello"),
|
Word(start=0.0, end=1.0, text="Hello"),
|
||||||
Word(start=1.0, end=2.0, text="world"),
|
Word(start=1.0, end=2.0, text="world"),
|
||||||
@@ -165,6 +177,147 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
|
|||||||
assert "TRANSCRIPT" in eventnames
|
assert "TRANSCRIPT" in eventnames
|
||||||
ev = events[eventnames.index("TRANSCRIPT")]
|
ev = events[eventnames.index("TRANSCRIPT")]
|
||||||
assert ev["data"]["text"] == "Hello world"
|
assert ev["data"]["text"] == "Hello world"
|
||||||
|
assert ev["data"]["translation"] is None
|
||||||
|
|
||||||
|
assert "TOPIC" in eventnames
|
||||||
|
ev = events[eventnames.index("TOPIC")]
|
||||||
|
assert ev["data"]["id"]
|
||||||
|
assert ev["data"]["summary"] == "LLM SUMMARY"
|
||||||
|
assert ev["data"]["transcript"].startswith("Hello world")
|
||||||
|
assert ev["data"]["timestamp"] == 0.0
|
||||||
|
|
||||||
|
assert "FINAL_SUMMARY" in eventnames
|
||||||
|
ev = events[eventnames.index("FINAL_SUMMARY")]
|
||||||
|
assert ev["data"]["summary"] == "LLM SUMMARY"
|
||||||
|
|
||||||
|
# check status order
|
||||||
|
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||||
|
assert statuses == ["recording", "processing", "ended"]
|
||||||
|
|
||||||
|
# ensure the last event received is ended
|
||||||
|
assert events[-1]["event"] == "STATUS"
|
||||||
|
assert events[-1]["data"]["value"] == "ended"
|
||||||
|
|
||||||
|
# check that transcript status in model is updated
|
||||||
|
resp = await ac.get(f"/transcripts/{tid}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["status"] == "ended"
|
||||||
|
|
||||||
|
# check that audio is available
|
||||||
|
resp = await ac.get(f"/transcripts/{tid}/audio")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.headers["Content-Type"] == "audio/wav"
|
||||||
|
|
||||||
|
# check that audio/mp3 is available
|
||||||
|
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.headers["Content-Type"] == "audio/mp3"
|
||||||
|
|
||||||
|
# stop server
|
||||||
|
server.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_transcript_rtc_and_websocket_and_fr(tmpdir, dummy_transcript, dummy_llm):
|
||||||
|
# goal: start the server, exchange RTC, receive websocket events
|
||||||
|
# because of that, we need to start the server in a thread
|
||||||
|
# to be able to connect with aiortc
|
||||||
|
# with target french language
|
||||||
|
|
||||||
|
from reflector.settings import settings
|
||||||
|
from reflector.app import app
|
||||||
|
|
||||||
|
settings.DATA_DIR = Path(tmpdir)
|
||||||
|
|
||||||
|
# start server
|
||||||
|
host = "127.0.0.1"
|
||||||
|
port = 1255
|
||||||
|
base_url = f"http://{host}:{port}/v1"
|
||||||
|
config = Config(app=app, host=host, port=port)
|
||||||
|
server = ThreadedUvicorn(config)
|
||||||
|
await server.start()
|
||||||
|
|
||||||
|
# create a transcript
|
||||||
|
ac = AsyncClient(base_url=base_url)
|
||||||
|
response = await ac.post(
|
||||||
|
"/transcripts", json={"name": "Test RTC", "target_language": "fr"}
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
tid = response.json()["id"]
|
||||||
|
|
||||||
|
# create a websocket connection as a task
|
||||||
|
events = []
|
||||||
|
|
||||||
|
async def websocket_task():
|
||||||
|
print("Test websocket: TASK STARTED")
|
||||||
|
async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
|
||||||
|
print("Test websocket: CONNECTED")
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
msg = await ws.receive_json()
|
||||||
|
print(f"Test websocket: JSON {msg}")
|
||||||
|
if msg is None:
|
||||||
|
break
|
||||||
|
events.append(msg)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Test websocket: EXCEPTION {e}")
|
||||||
|
finally:
|
||||||
|
ws.close()
|
||||||
|
print("Test websocket: DISCONNECTED")
|
||||||
|
|
||||||
|
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
|
||||||
|
|
||||||
|
# create stream client
|
||||||
|
import argparse
|
||||||
|
from reflector.stream_client import StreamClient
|
||||||
|
from aiortc.contrib.signaling import add_signaling_arguments, create_signaling
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
add_signaling_arguments(parser)
|
||||||
|
args = parser.parse_args(["-s", "tcp-socket"])
|
||||||
|
signaling = create_signaling(args)
|
||||||
|
|
||||||
|
url = f"{base_url}/transcripts/{tid}/record/webrtc"
|
||||||
|
path = Path(__file__).parent / "records" / "test_short.wav"
|
||||||
|
client = StreamClient(signaling, url=url, play_from=path.as_posix())
|
||||||
|
await client.start()
|
||||||
|
|
||||||
|
timeout = 20
|
||||||
|
while not client.is_ended():
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
timeout -= 1
|
||||||
|
if timeout < 0:
|
||||||
|
raise TimeoutError("Timeout while waiting for RTC to end")
|
||||||
|
|
||||||
|
# XXX aiortc is long to close the connection
|
||||||
|
# instead of waiting a long time, we just send a STOP
|
||||||
|
client.channel.send(json.dumps({"cmd": "STOP"}))
|
||||||
|
|
||||||
|
# wait the processing to finish
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
await client.stop()
|
||||||
|
|
||||||
|
# wait the processing to finish
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
# stop websocket task
|
||||||
|
websocket_task.cancel()
|
||||||
|
|
||||||
|
# check events
|
||||||
|
assert len(events) > 0
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
pprint(events)
|
||||||
|
|
||||||
|
# get events list
|
||||||
|
eventnames = [e["event"] for e in events]
|
||||||
|
|
||||||
|
# check events
|
||||||
|
assert "TRANSCRIPT" in eventnames
|
||||||
|
ev = events[eventnames.index("TRANSCRIPT")]
|
||||||
|
assert ev["data"]["text"] == "Hello world"
|
||||||
|
assert ev["data"]["translation"] == "Bonjour le monde"
|
||||||
|
|
||||||
assert "TOPIC" in eventnames
|
assert "TOPIC" in eventnames
|
||||||
ev = events[eventnames.index("TOPIC")]
|
ev = events[eventnames.index("TOPIC")]
|
||||||
@@ -186,19 +339,4 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
|
|||||||
assert events[-1]["data"]["value"] == "ended"
|
assert events[-1]["data"]["value"] == "ended"
|
||||||
|
|
||||||
# stop server
|
# stop server
|
||||||
# server.stop()
|
server.stop()
|
||||||
|
|
||||||
# check that transcript status in model is updated
|
|
||||||
resp = await ac.get(f"/transcripts/{tid}")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert resp.json()["status"] == "ended"
|
|
||||||
|
|
||||||
# check that audio is available
|
|
||||||
resp = await ac.get(f"/transcripts/{tid}/audio")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert resp.headers["Content-Type"] == "audio/wav"
|
|
||||||
|
|
||||||
# check that audio/mp3 is available
|
|
||||||
resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
|
|
||||||
assert resp.status_code == 200
|
|
||||||
assert resp.headers["Content-Type"] == "audio/mp3"
|
|
||||||
|
|||||||
63
server/tests/test_transcripts_translation.py
Normal file
63
server/tests/test_transcripts_translation.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_transcript_create_default_translation():
    """Creating a transcript without languages reports "en" for both."""
    from reflector.app import app

    expected = {
        "name": "test en",
        "source_language": "en",
        "target_language": "en",
    }

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        created = await ac.post("/transcripts", json={"name": "test en"})
        assert created.status_code == 200
        created_body = created.json()
        for field, value in expected.items():
            assert created_body[field] == value
        tid = created_body["id"]

        # the same values must come back when the transcript is fetched
        fetched = await ac.get(f"/transcripts/{tid}")
        assert fetched.status_code == 200
        fetched_body = fetched.json()
        for field, value in expected.items():
            assert fetched_body[field] == value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_transcript_create_en_fr_translation():
    """An explicit target language is kept; the source stays at "en"."""
    from reflector.app import app

    expected = {
        "name": "test en/fr",
        "source_language": "en",
        "target_language": "fr",
    }

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        created = await ac.post(
            "/transcripts", json={"name": "test en/fr", "target_language": "fr"}
        )
        assert created.status_code == 200
        created_body = created.json()
        for field, value in expected.items():
            assert created_body[field] == value
        tid = created_body["id"]

        # the same values must come back when the transcript is fetched
        fetched = await ac.get(f"/transcripts/{tid}")
        assert fetched.status_code == 200
        fetched_body = fetched.json()
        for field, value in expected.items():
            assert fetched_body[field] == value
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_transcript_create_fr_en_translation():
    """An explicit source language is kept; the target stays at "en"."""
    from reflector.app import app

    expected = {
        "name": "test fr/en",
        "source_language": "fr",
        "target_language": "en",
    }

    async with AsyncClient(app=app, base_url="http://test/v1") as ac:
        created = await ac.post(
            "/transcripts", json={"name": "test fr/en", "source_language": "fr"}
        )
        assert created.status_code == 200
        created_body = created.json()
        for field, value in expected.items():
            assert created_body[field] == value
        tid = created_body["id"]

        # the same values must come back when the transcript is fetched
        fetched = await ac.get(f"/transcripts/{tid}")
        assert fetched.status_code == 200
        fetched_body = fetched.json()
        for field, value in expected.items():
            assert fetched_body[field] == value
|
||||||
Reference in New Issue
Block a user