diff --git a/server/gpu/modal/reflector_transcriber.py b/server/gpu/modal/reflector_transcriber.py index e1fde227..a56dba7a 100644 --- a/server/gpu/modal/reflector_transcriber.py +++ b/server/gpu/modal/reflector_transcriber.py @@ -6,7 +6,6 @@ Reflector GPU backend - transcriber import os import tempfile -from fastapi import File from modal import Image, Secret, Stub, asgi_app, method from pydantic import BaseModel @@ -152,9 +151,7 @@ class Whisper: ) @asgi_app() def web(): - from typing import List - - from fastapi import Body, Depends, FastAPI, Form, HTTPException, UploadFile, status + from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile, status from fastapi.security import OAuth2PasswordBearer from typing_extensions import Annotated @@ -176,23 +173,26 @@ def web(): class TranscriptResponse(BaseModel): result: dict + class TranscriptRequest(BaseModel): + file: UploadFile + timestamp: Annotated[float, Form()] = 0 + source_language: Annotated[str, Form()] = "en" + target_language: Annotated[str, Form()] = "en" + @app.post("/transcribe", dependencies=[Depends(apikey_auth)]) async def transcribe( - file: UploadFile, - source_language: Annotated[str, Form()] = "en", - target_language: Annotated[str, Form()] = "fr", - timestamp: Annotated[float, Form()] = 0.0 + req: TranscriptRequest ) -> TranscriptResponse: - audio_data = await file.read() - audio_suffix = file.filename.split(".")[-1] + audio_data = await req.file.read() + audio_suffix = req.file.filename.split(".")[-1] assert audio_suffix in supported_audio_file_types func = transcriberstub.transcribe_segment.spawn( audio_data=audio_data, audio_suffix=audio_suffix, - source_language=source_language, - target_language=target_language, - timestamp=timestamp + source_language=req.source_language, + target_language=req.target_language, + timestamp=req.timestamp ) result = func.get() return result diff --git a/server/reflector/processors/transcript_liner.py b/server/reflector/processors/transcript_liner.py index 5e9d6683..c7ec2f64 100644 --- a/server/reflector/processors/transcript_liner.py +++ b/server/reflector/processors/transcript_liner.py @@ -27,19 +27,19 @@ class TranscriptLinerProcessor(Processor): return # cut to the next . - partial = Transcript(words=[]) + partial = Transcript(translation=self.transcript.translation, words=[]) for word in self.transcript.words[:]: partial.text += word.text partial.words.append(word) if "." not in word.text: continue - partial.translation = self.transcript.translation # emit line await self.emit(partial) # create new transcript - partial = Transcript(words=[]) + partial = Transcript(translation=self.transcript.translation, words=[]) + self.transcript = partial async def _flush(self):