mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-22 05:09:05 +00:00
server: pass source and target language from api to pipeline
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
import databases
|
||||
import sqlalchemy
|
||||
from reflector.events import subscribers_startup, subscribers_shutdown
|
||||
from reflector.events import subscribers_shutdown, subscribers_startup
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
database = databases.Database(settings.DATABASE_URL)
|
||||
metadata = sqlalchemy.MetaData()
|
||||
|
||||
@@ -20,6 +19,8 @@ transcripts = sqlalchemy.Table(
|
||||
sqlalchemy.Column("summary", sqlalchemy.String, nullable=True),
|
||||
sqlalchemy.Column("topics", sqlalchemy.JSON),
|
||||
sqlalchemy.Column("events", sqlalchemy.JSON),
|
||||
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
|
||||
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
|
||||
# with user attached, optional
|
||||
sqlalchemy.Column("user_id", sqlalchemy.String),
|
||||
)
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
from reflector.processors.base import Processor
|
||||
import importlib
|
||||
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.base import Pipeline, Processor
|
||||
from reflector.processors.types import AudioFile
|
||||
from reflector.settings import settings
|
||||
import importlib
|
||||
|
||||
|
||||
class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
|
||||
@@ -35,6 +36,10 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
|
||||
self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def set_pipeline(self, pipeline: Pipeline):
    """
    Attach both this processor and the backend processor it wraps
    (`self.processor`) to `pipeline`.
    """
    super().set_pipeline(pipeline)
    # the wrapped backend must see the same pipeline as this wrapper
    backend = self.processor
    backend.set_pipeline(pipeline)
|
||||
|
||||
def connect(self, processor: Processor):
|
||||
self.processor.connect(processor)
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ API will be a POST request to TRANSCRIPT_URL:
|
||||
from time import monotonic
|
||||
|
||||
import httpx
|
||||
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProcessor
|
||||
from reflector.processors.types import AudioFile, Transcript, TranslationLanguages, Word
|
||||
@@ -54,14 +53,10 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
||||
"file": (data.name, data.fd),
|
||||
}
|
||||
|
||||
# TODO: Get the source / target language from the UI preferences dynamically
|
||||
# Update code here once this is possible.
|
||||
# i.e) extract from context/session objects
|
||||
source_language = "en"
|
||||
|
||||
# TODO: target lang is set to "fr" for demo purposes
|
||||
# Revert back once language selection is implemented
|
||||
target_language = "fr"
|
||||
# FIXME this should be a processor after, as each user may want
|
||||
# different languages
|
||||
source_language = self.get_pref("audio:source_language", "en")
|
||||
target_language = self.get_pref("audio:target_language", "en")
|
||||
languages = TranslationLanguages()
|
||||
|
||||
# Only way to set the target should be the UI element like dropdown.
|
||||
@@ -87,8 +82,8 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
|
||||
result = response.json()
|
||||
|
||||
# Sanity check for translation status in the result
|
||||
translation = ""
|
||||
if target_language in result["text"]:
|
||||
translation = None
|
||||
if source_language != target_language and target_language in result["text"]:
|
||||
translation = result["text"][target_language]
|
||||
text = result["text"][source_language]
|
||||
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from reflector.logger import logger
|
||||
from uuid import uuid4
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from reflector.logger import logger
|
||||
|
||||
|
||||
class Processor:
|
||||
@@ -17,9 +19,11 @@ class Processor:
|
||||
self.uid = uuid4().hex
|
||||
self.flushed = False
|
||||
self.logger = (custom_logger or logger).bind(processor=self.__class__.__name__)
|
||||
self.pipeline = None
|
||||
|
||||
def set_pipeline(self, pipeline: "Pipeline"):
    """
    Attach this processor to `pipeline`.

    Stores the pipeline on the processor and replaces the processor logger
    with the pipeline logger, bound to this processor's class name.
    """
    processor_name = type(self).__name__
    self.pipeline = pipeline
    # once attached, log through the pipeline logger instead of our own
    self.logger = pipeline.logger.bind(processor=processor_name)
|
||||
|
||||
def connect(self, processor: "Processor"):
|
||||
@@ -54,6 +58,14 @@ class Processor:
|
||||
"""
|
||||
self._callbacks.remove(callback)
|
||||
|
||||
def get_pref(self, key: str, default: Any = None):
    """
    Get a preference from the pipeline prefs

    :param key: name of the preference to look up
    :param default: value returned when no pipeline is attached; it is
        also forwarded to the pipeline lookup
    :return: the result of the pipeline's `get_pref`, or `default` when
        this processor has no pipeline
    """
    # `self.pipeline` stays None until `set_pipeline` is called. Compare by
    # identity so an attached pipeline is never skipped just because it
    # happens to evaluate as falsy.
    if self.pipeline is None:
        return default
    return self.pipeline.get_pref(key, default)
|
||||
|
||||
async def emit(self, data):
|
||||
for callback in self._callbacks:
|
||||
await callback(data)
|
||||
@@ -191,6 +203,7 @@ class Pipeline(Processor):
|
||||
self.logger.info("Pipeline created")
|
||||
|
||||
self.processors = processors
|
||||
self.prefs = {}
|
||||
|
||||
for processor in processors:
|
||||
processor.set_pipeline(self)
|
||||
@@ -220,3 +233,17 @@ class Pipeline(Processor):
|
||||
for processor in self.processors:
|
||||
processor.describe(level + 1)
|
||||
logger.info("")
|
||||
|
||||
def set_pref(self, key: str, value: Any):
    """
    Store `value` under `key` in this pipeline's preferences,
    replacing any value already stored for that key
    """
    self.prefs.update({key: value})
|
||||
|
||||
def get_pref(self, key: str, default: Any = None):
    """
    Get a preference for this pipeline

    :param key: name of the preference to look up
    :param default: value returned when `key` has not been set
    :return: the stored value (which may itself be None), or `default`
        when the key is missing; a missing key is also logged as a warning
    """
    # single lookup: only a genuinely missing key takes the warning path,
    # a key stored with a None value is returned as-is
    try:
        return self.prefs[key]
    except KeyError:
        self.logger.warning(f"Pref {key} not found, using default")
        return default
|
||||
|
||||
@@ -47,7 +47,7 @@ class Word(BaseModel):
|
||||
|
||||
class Transcript(BaseModel):
|
||||
text: str = ""
|
||||
translation: str = ""
|
||||
translation: str | None = None
|
||||
words: list[Word] = None
|
||||
|
||||
@property
|
||||
|
||||
@@ -7,7 +7,6 @@ import av
|
||||
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
|
||||
from fastapi import APIRouter, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
from reflector.events import subscribers_shutdown
|
||||
from reflector.logger import logger
|
||||
from reflector.processors import (
|
||||
@@ -81,6 +80,8 @@ async def rtc_offer_base(
|
||||
event_callback=None,
|
||||
event_callback_args=None,
|
||||
audio_filename: Path | None = None,
|
||||
source_language: str = "en",
|
||||
target_language: str = "en",
|
||||
):
|
||||
# build an rtc session
|
||||
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
|
||||
@@ -110,7 +111,6 @@ async def rtc_offer_base(
|
||||
result = {
|
||||
"cmd": "SHOW_TRANSCRIPTION",
|
||||
"text": transcript.text,
|
||||
"translation": transcript.translation,
|
||||
}
|
||||
ctx.data_channel.send(dumps(result))
|
||||
|
||||
@@ -179,6 +179,8 @@ async def rtc_offer_base(
|
||||
TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
|
||||
]
|
||||
ctx.pipeline = Pipeline(*processors)
|
||||
ctx.pipeline.set_pref("audio:source_language", source_language)
|
||||
ctx.pipeline.set_pref("audio:target_language", target_language)
|
||||
# FIXME: warmup is not working well yet
|
||||
# await ctx.pipeline.warmup()
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Annotated, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import av
|
||||
import reflector.auth as auth
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
Depends,
|
||||
@@ -17,13 +18,11 @@ from fastapi import (
|
||||
)
|
||||
from fastapi_pagination import Page, paginate
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import database, transcripts
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.audio_waveform import get_audio_waveform
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from ._range_requests_response import range_requests_response
|
||||
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
||||
@@ -50,7 +49,7 @@ class AudioWaveform(BaseModel):
|
||||
|
||||
class TranscriptText(BaseModel):
|
||||
text: str
|
||||
translation: str
|
||||
translation: str | None
|
||||
|
||||
|
||||
class TranscriptTopic(BaseModel):
|
||||
@@ -81,6 +80,8 @@ class Transcript(BaseModel):
|
||||
summary: str | None = None
|
||||
topics: list[TranscriptTopic] = []
|
||||
events: list[TranscriptEvent] = []
|
||||
source_language: str = "en"
|
||||
target_language: str = "en"
|
||||
|
||||
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
|
||||
ev = TranscriptEvent(event=event, data=data.model_dump())
|
||||
@@ -186,8 +187,19 @@ class TranscriptController:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
|
||||
async def add(self, name: str, user_id: str | None = None):
|
||||
transcript = Transcript(name=name, user_id=user_id)
|
||||
async def add(
|
||||
self,
|
||||
name: str,
|
||||
source_language: str = "en",
|
||||
target_language: str = "en",
|
||||
user_id: str | None = None,
|
||||
):
|
||||
transcript = Transcript(
|
||||
name=name,
|
||||
source_language=source_language,
|
||||
target_language=target_language,
|
||||
user_id=user_id,
|
||||
)
|
||||
query = transcripts.insert().values(**transcript.model_dump())
|
||||
await database.execute(query)
|
||||
return transcript
|
||||
@@ -231,10 +243,14 @@ class GetTranscript(BaseModel):
|
||||
duration: int
|
||||
summary: str | None
|
||||
created_at: datetime
|
||||
source_language: str
|
||||
target_language: str
|
||||
|
||||
|
||||
# Fields used to create a transcript: a name plus the source and target
# languages. NOTE(review): presumably the request body of
# `transcripts_create` -- confirm against the route signature.
class CreateTranscript(BaseModel):
    name: str
    # both languages fall back to "en" when omitted
    source_language: str = Field(default="en")
    target_language: str = Field(default="en")
|
||||
|
||||
|
||||
class UpdateTranscript(BaseModel):
|
||||
@@ -243,10 +259,6 @@ class UpdateTranscript(BaseModel):
|
||||
summary: Optional[str] = Field(None)
|
||||
|
||||
|
||||
class TranscriptEntryCreate(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class DeletionStatus(BaseModel):
|
||||
status: str
|
||||
|
||||
@@ -268,7 +280,12 @@ async def transcripts_create(
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
return await transcripts_controller.add(info.name, user_id=user_id)
|
||||
return await transcripts_controller.add(
|
||||
info.name,
|
||||
source_language=info.source_language,
|
||||
target_language=info.target_language,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
|
||||
# ==============================================================
|
||||
@@ -573,4 +590,6 @@ async def transcript_record_webrtc(
|
||||
event_callback=handle_rtc_event,
|
||||
event_callback_args=transcript_id,
|
||||
audio_filename=transcript.audio_filename,
|
||||
source_language=transcript.source_language,
|
||||
target_language=transcript.target_language,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user