server: pass source and target language from api to pipeline

This commit is contained in:
2023-08-29 11:16:23 +02:00
parent cce8a9137a
commit 68dce235ec
10 changed files with 330 additions and 48 deletions

View File

@@ -7,7 +7,6 @@ import av
from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from fastapi import APIRouter, Request
from pydantic import BaseModel
from reflector.events import subscribers_shutdown
from reflector.logger import logger
from reflector.processors import (
@@ -81,6 +80,8 @@ async def rtc_offer_base(
event_callback=None,
event_callback_args=None,
audio_filename: Path | None = None,
source_language: str = "en",
target_language: str = "en",
):
# build an rtc session
offer = RTCSessionDescription(sdp=params.sdp, type=params.type)
@@ -110,7 +111,6 @@ async def rtc_offer_base(
result = {
"cmd": "SHOW_TRANSCRIPTION",
"text": transcript.text,
"translation": transcript.translation,
}
ctx.data_channel.send(dumps(result))
@@ -179,6 +179,8 @@ async def rtc_offer_base(
TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
]
ctx.pipeline = Pipeline(*processors)
ctx.pipeline.set_pref("audio:source_language", source_language)
ctx.pipeline.set_pref("audio:target_language", target_language)
# FIXME: warmup is not working well yet
# await ctx.pipeline.warmup()

View File

@@ -7,6 +7,7 @@ from typing import Annotated, Optional
from uuid import uuid4
import av
import reflector.auth as auth
from fastapi import (
APIRouter,
Depends,
@@ -17,13 +18,11 @@ from fastapi import (
)
from fastapi_pagination import Page, paginate
from pydantic import BaseModel, Field
from starlette.concurrency import run_in_threadpool
import reflector.auth as auth
from reflector.db import database, transcripts
from reflector.logger import logger
from reflector.settings import settings
from reflector.utils.audio_waveform import get_audio_waveform
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
@@ -50,7 +49,7 @@ class AudioWaveform(BaseModel):
class TranscriptText(BaseModel):
text: str
translation: str
translation: str | None
class TranscriptTopic(BaseModel):
@@ -81,6 +80,8 @@ class Transcript(BaseModel):
summary: str | None = None
topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = []
source_language: str = "en"
target_language: str = "en"
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump())
@@ -186,8 +187,19 @@ class TranscriptController:
return None
return Transcript(**result)
async def add(self, name: str, user_id: str | None = None):
transcript = Transcript(name=name, user_id=user_id)
async def add(
self,
name: str,
source_language: str = "en",
target_language: str = "en",
user_id: str | None = None,
):
transcript = Transcript(
name=name,
source_language=source_language,
target_language=target_language,
user_id=user_id,
)
query = transcripts.insert().values(**transcript.model_dump())
await database.execute(query)
return transcript
@@ -231,10 +243,14 @@ class GetTranscript(BaseModel):
duration: int
summary: str | None
created_at: datetime
source_language: str
target_language: str
class CreateTranscript(BaseModel):
name: str
source_language: str = Field("en")
target_language: str = Field("en")
class UpdateTranscript(BaseModel):
@@ -243,10 +259,6 @@ class UpdateTranscript(BaseModel):
summary: Optional[str] = Field(None)
class TranscriptEntryCreate(BaseModel):
name: str
class DeletionStatus(BaseModel):
status: str
@@ -268,7 +280,12 @@ async def transcripts_create(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
return await transcripts_controller.add(info.name, user_id=user_id)
return await transcripts_controller.add(
info.name,
source_language=info.source_language,
target_language=info.target_language,
user_id=user_id,
)
# ==============================================================
@@ -573,4 +590,6 @@ async def transcript_record_webrtc(
event_callback=handle_rtc_event,
event_callback_args=transcript_id,
audio_filename=transcript.audio_filename,
source_language=transcript.source_language,
target_language=transcript.target_language,
)