update fixes

This commit is contained in:
Gokul Mohanarangan
2023-08-28 14:32:21 +05:30
parent 49d6e2d1dc
commit ebbe01f282
2 changed files with 16 additions and 16 deletions

View File

@@ -6,7 +6,6 @@ Reflector GPU backend - transcriber
import os import os
import tempfile import tempfile
from fastapi import File
from modal import Image, Secret, Stub, asgi_app, method from modal import Image, Secret, Stub, asgi_app, method
from pydantic import BaseModel from pydantic import BaseModel
@@ -152,9 +151,7 @@ class Whisper:
) )
@asgi_app() @asgi_app()
def web(): def web():
from typing import List from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile, status
from fastapi import Body, Depends, FastAPI, Form, HTTPException, UploadFile, status
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from typing_extensions import Annotated from typing_extensions import Annotated
@@ -176,23 +173,26 @@ def web():
class TranscriptResponse(BaseModel): class TranscriptResponse(BaseModel):
result: dict result: dict
class TranscriptRequest(BaseModel):
file: UploadFile
timestamp: Annotated[float, Form()] = 0
source_language: Annotated[str, Form()] = "en"
target_language: Annotated[str, Form()] = "en"
@app.post("/transcribe", dependencies=[Depends(apikey_auth)]) @app.post("/transcribe", dependencies=[Depends(apikey_auth)])
async def transcribe( async def transcribe(
file: UploadFile, req: TranscriptRequest
source_language: Annotated[str, Form()] = "en",
target_language: Annotated[str, Form()] = "fr",
timestamp: Annotated[float, Form()] = 0.0
) -> TranscriptResponse: ) -> TranscriptResponse:
audio_data = await file.read() audio_data = await req.file.read()
audio_suffix = file.filename.split(".")[-1] audio_suffix = req.file.filename.split(".")[-1]
assert audio_suffix in supported_audio_file_types assert audio_suffix in supported_audio_file_types
func = transcriberstub.transcribe_segment.spawn( func = transcriberstub.transcribe_segment.spawn(
audio_data=audio_data, audio_data=audio_data,
audio_suffix=audio_suffix, audio_suffix=audio_suffix,
source_language=source_language, source_language=req.source_language,
target_language=target_language, target_language=req.target_language,
timestamp=timestamp timestamp=req.timestamp
) )
result = func.get() result = func.get()
return result return result

View File

@@ -27,19 +27,19 @@ class TranscriptLinerProcessor(Processor):
return return
# cut to the next . # cut to the next .
partial = Transcript(words=[]) partial = Transcript(translation=self.transcript.translation, words=[])
for word in self.transcript.words[:]: for word in self.transcript.words[:]:
partial.text += word.text partial.text += word.text
partial.words.append(word) partial.words.append(word)
if "." not in word.text: if "." not in word.text:
continue continue
partial.translation = self.transcript.translation
# emit line # emit line
await self.emit(partial) await self.emit(partial)
# create new transcript # create new transcript
partial = Transcript(words=[]) partial = Transcript(translation=self.transcript.translation, words=[])
self.transcript = partial self.transcript = partial
async def _flush(self): async def _flush(self):