return both en and fr in transcription

Gokul Mohanarangan
2023-08-28 14:25:44 +05:30
parent 3878c98357
commit 49d6e2d1dc
6 changed files with 45 additions and 26 deletions

View File

@@ -6,6 +6,7 @@ Reflector GPU backend - transcriber
 import os
 import tempfile

+from fastapi import File
 from modal import Image, Secret, Stub, asgi_app, method
 from pydantic import BaseModel
@@ -18,7 +19,7 @@ WHISPER_CACHE_DIR: str = "/cache/whisper"
 # Translation Model
 TRANSLATION_MODEL = "facebook/m2m100_418M"

-stub = Stub(name="reflector-transcriber")
+stub = Stub(name="reflector-lang")

 def download_whisper():
@@ -129,6 +130,8 @@ class Whisper:
             translation = result[0].strip()
             multilingual_transcript[target_language] = translation

+        print(multilingual_transcript)
+
         return {
             "text": multilingual_transcript,
             "words": words
@@ -149,7 +152,9 @@ class Whisper:
 )

 @asgi_app()
 def web():
-    from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile, status
+    from typing import List
+
+    from fastapi import Body, Depends, FastAPI, Form, HTTPException, UploadFile, status
     from fastapi.security import OAuth2PasswordBearer
     from typing_extensions import Annotated
@@ -174,9 +179,9 @@ def web():
     @app.post("/transcribe", dependencies=[Depends(apikey_auth)])
     async def transcribe(
         file: UploadFile,
-        timestamp: Annotated[float, Form()] = 0,
         source_language: Annotated[str, Form()] = "en",
-        target_language: Annotated[str, Form()] = "en"
+        target_language: Annotated[str, Form()] = "fr",
+        timestamp: Annotated[float, Form()] = 0.0
     ) -> TranscriptResponse:
         audio_data = await file.read()
         audio_suffix = file.filename.split(".")[-1]
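For context on what the reworked form fields mean in practice: a caller exercises this endpoint with a plain multipart request. The sketch below is not part of the commit; the deployment URL and API key are placeholders, and the response shape follows the multilingual_transcript dict returned above.

import httpx

# Hypothetical deployment URL and API key -- substitute your own.
MODAL_URL = "https://example--reflector-lang-web.modal.run"
API_KEY = "sk-..."

with open("meeting.wav", "rb") as f:
    response = httpx.post(
        f"{MODAL_URL}/transcribe",
        headers={"Authorization": f"Bearer {API_KEY}"},
        files={"file": ("meeting.wav", f, "audio/wav")},
        # Form fields mirror the Annotated[..., Form()] parameters above.
        data={"source_language": "en", "target_language": "fr", "timestamp": "0.0"},
        timeout=120,
    )
response.raise_for_status()
result = response.json()
# "text" now carries one entry per language, e.g.:
# {"text": {"en": "Hello everyone", "fr": "Bonjour à tous"}, "words": [...]}
print(result["text"]["en"], "/", result["text"]["fr"])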

View File

@@ -58,7 +58,10 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
         # Update code here once this is possible.
         # i.e) extract from context/session objects
         source_language = "en"
-        target_language = "en"
+        # TODO: target lang is set to "fr" for demo purposes
+        # Revert back once language selection is implemented
+        target_language = "fr"
+        languages = TranslationLanguages()

         # Only way to set the target should be the UI element like dropdown.
@@ -74,7 +77,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
             files=files,
             timeout=self.timeout,
             headers=self.headers,
-            json=json_payload,
+            data=json_payload,
         )
         self.logger.debug(
@@ -84,12 +87,14 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
         result = response.json()

+        # Sanity check for translation status in the result
+        translation = ""
         if target_language in result["text"]:
-            text = result["text"][target_language]
-        else:
+            translation = result["text"][target_language]
         text = result["text"][source_language]

         transcript = Transcript(
             text=text,
+            translation=translation,
             words=[
                 Word(
                     text=word["text"],
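The json= to data= switch above is load-bearing: the request already carries files=, so the language fields must travel as multipart form fields for the Form() parameters on the Modal side to receive them; a json= body is serialized as application/json and never reaches them. A standalone illustration of the difference, with a placeholder URL:

import httpx

url = "https://example.invalid/transcribe"  # placeholder, never sent
fields = {"source_language": "en", "target_language": "fr"}
files = {"file": ("chunk.wav", b"\x00\x01", "audio/wav")}

# data= encodes the dict as form fields alongside the file part,
# so FastAPI's Form() parameters receive them.
ok = httpx.Request("POST", url, files=files, data=fields)
print(ok.headers["content-type"])   # multipart/form-data; boundary=...

# json= would instead serialize the dict as an application/json body,
# which Form() parameters never see.
bad = httpx.Request("POST", url, json=fields)
print(bad.headers["content-type"])  # application/json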

View File

@@ -34,12 +34,12 @@ class TranscriptLinerProcessor(Processor):
             if "." not in word.text:
                 continue

+            partial.translation = self.transcript.translation
             # emit line
             await self.emit(partial)

             # create new transcript
             partial = Transcript(words=[])
             self.transcript = partial

     async def _flush(self):
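Roughly, the liner accumulates words into a partial transcript and emits it whenever a word closes a sentence; the new line above makes the emitted partial carry the transcript-level translation along. A toy, self-contained reduction of that loop (simplified types, not the processor's real interface):

from dataclasses import dataclass, field

@dataclass
class Line:
    words: list[str] = field(default_factory=list)
    translation: str = ""

def split_lines(words: list[str], translation: str) -> list[Line]:
    """Emit a Line whenever a word ends a sentence, as the processor does."""
    lines, partial = [], Line()
    for word in words:
        partial.words.append(word)
        if "." not in word:
            continue
        partial.translation = translation  # attach translation before emitting
        lines.append(partial)
        partial = Line()  # start a fresh partial transcript
    return lines

print(split_lines(["Hello", "world.", "Bye."], "Bonjour le monde. Au revoir."))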

View File

@@ -47,6 +47,7 @@ class Word(BaseModel):
 class Transcript(BaseModel):
     text: str = ""
+    translation: str = ""
     words: list[Word] = None

     @property
@@ -84,7 +85,7 @@ class Transcript(BaseModel):
         words = [
             Word(text=word.text, start=word.start, end=word.end) for word in self.words
         ]
-        return Transcript(text=self.text, words=words)
+        return Transcript(text=self.text, translation=self.translation, words=words)

 class TitleSummary(BaseModel):
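The empty-string default keeps existing Transcript(...) call sites working, while the second hunk fixes a copy that previously dropped the translation. A minimal sketch of the changed model; the field names come from the diff, the Word defaults are assumed:

from pydantic import BaseModel

class Word(BaseModel):
    text: str
    start: float = 0.0
    end: float = 0.0

class Transcript(BaseModel):
    text: str = ""
    translation: str = ""          # new field, empty by default
    words: list[Word] = None

t = Transcript(text="Hello.", translation="Bonjour.", words=[Word(text="Hello.")])

# Mirrors the fixed copy: without forwarding it, translation silently reset to "".
copy = Transcript(
    text=t.text,
    translation=t.translation,
    words=[Word(text=w.text, start=w.start, end=w.end) for w in t.words],
)
assert copy.translation == "Bonjour."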

View File

@@ -1,25 +1,27 @@
 import asyncio
-from fastapi import Request, APIRouter
-from reflector.events import subscribers_shutdown
-from pydantic import BaseModel
-from reflector.logger import logger
-from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
-from json import loads, dumps
 from enum import StrEnum
+from json import dumps, loads
 from pathlib import Path

 import av
+from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
+from fastapi import APIRouter, Request
+from pydantic import BaseModel
+from reflector.events import subscribers_shutdown
+from reflector.logger import logger
 from reflector.processors import (
-    Pipeline,
     AudioChunkerProcessor,
+    AudioFileWriterProcessor,
     AudioMergeProcessor,
     AudioTranscriptAutoProcessor,
-    AudioFileWriterProcessor,
+    FinalSummary,
+    Pipeline,
+    TitleSummary,
+    Transcript,
+    TranscriptFinalSummaryProcessor,
     TranscriptLinerProcessor,
     TranscriptTopicDetectorProcessor,
-    TranscriptFinalSummaryProcessor,
-    Transcript,
-    TitleSummary,
-    FinalSummary,
 )

 sessions = []
@@ -108,6 +110,7 @@ async def rtc_offer_base(
         result = {
             "cmd": "SHOW_TRANSCRIPTION",
             "text": transcript.text,
+            "translation": transcript.translation,
         }
         ctx.data_channel.send(dumps(result))
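On the other side of the data channel, a peer can now read both fields from the same message. A minimal sketch of a handler; only the message shape ("cmd", "text", "translation") comes from the commit, the channel wiring is assumed:

from json import loads

def on_message(message: str) -> None:
    """Handle a data-channel message from the reflector pipeline."""
    event = loads(message)
    if event.get("cmd") == "SHOW_TRANSCRIPTION":
        print("original:   ", event["text"])
        # New in this commit: the French rendering rides along.
        print("translation:", event["translation"])

# channel.on("message")(on_message)  # attach to an aiortc RTCDataChannel
on_message('{"cmd": "SHOW_TRANSCRIPTION", "text": "Hello.", "translation": "Bonjour."}')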

View File

@@ -7,7 +7,6 @@ from typing import Annotated, Optional
 from uuid import uuid4

 import av
-import reflector.auth as auth
 from fastapi import (
     APIRouter,
     Depends,
@@ -18,11 +17,13 @@ from fastapi import (
 )
 from fastapi_pagination import Page, paginate
 from pydantic import BaseModel, Field
+from starlette.concurrency import run_in_threadpool

+import reflector.auth as auth
 from reflector.db import database, transcripts
 from reflector.logger import logger
 from reflector.settings import settings
 from reflector.utils.audio_waveform import get_audio_waveform
-from starlette.concurrency import run_in_threadpool

 from ._range_requests_response import range_requests_response
 from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
@@ -49,6 +50,7 @@ class AudioWaveform(BaseModel):
 class TranscriptText(BaseModel):
     text: str
+    translation: str

 class TranscriptTopic(BaseModel):
@@ -491,7 +493,10 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
     # FIXME don't do copy
     if event == PipelineEvent.TRANSCRIPT:
-        resp = transcript.add_event(event=event, data=TranscriptText(text=data.text))
+        resp = transcript.add_event(
+            event=event,
+            data=TranscriptText(text=data.text, translation=data.translation),
+        )
         await transcripts_controller.update(
             transcript,
             {
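Worth noting: translation on TranscriptText is declared without a default, so every TRANSCRIPT event must now supply it; callers that predate this commit would fail validation. A quick check of that behaviour with a stand-in model:

from pydantic import BaseModel, ValidationError

class TranscriptText(BaseModel):
    text: str
    translation: str  # required: no default, unlike the Transcript model

TranscriptText(text="Hello.", translation="Bonjour.")  # fine

try:
    TranscriptText(text="Hello.")  # pre-commit call sites would now fail
except ValidationError as e:
    print(e.errors()[0]["loc"])  # ('translation',)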