From 49d6e2d1dcf97578bbf99ed2cf27d95a0898ae11 Mon Sep 17 00:00:00 2001
From: Gokul Mohanarangan
Date: Mon, 28 Aug 2023 14:25:44 +0530
Subject: [PATCH] return both en and fr in transcription

---
 server/gpu/modal/reflector_transcriber.py    | 13 ++++++---
 .../processors/audio_transcript_modal.py     | 15 +++++++----
 .../reflector/processors/transcript_liner.py |  2 +-
 server/reflector/processors/types.py         |  3 ++-
 server/reflector/views/rtc_offer.py          | 27 ++++++++++---------
 server/reflector/views/transcripts.py        | 11 +++++---
 6 files changed, 45 insertions(+), 26 deletions(-)

diff --git a/server/gpu/modal/reflector_transcriber.py b/server/gpu/modal/reflector_transcriber.py
index 55df052b..e1fde227 100644
--- a/server/gpu/modal/reflector_transcriber.py
+++ b/server/gpu/modal/reflector_transcriber.py
@@ -6,6 +6,7 @@ Reflector GPU backend - transcriber
 import os
 import tempfile
 
+from fastapi import File
 from modal import Image, Secret, Stub, asgi_app, method
 from pydantic import BaseModel
 
@@ -18,7 +19,7 @@ WHISPER_CACHE_DIR: str = "/cache/whisper"
 # Translation Model
 TRANSLATION_MODEL = "facebook/m2m100_418M"
 
-stub = Stub(name="reflector-transcriber")
+stub = Stub(name="reflector-lang")
 
 
 def download_whisper():
@@ -129,6 +130,8 @@ class Whisper:
             translation = result[0].strip()
             multilingual_transcript[target_language] = translation
 
+        print(multilingual_transcript)
+
         return {
             "text": multilingual_transcript,
             "words": words
@@ -149,7 +152,9 @@ class Whisper:
 )
 @asgi_app()
 def web():
-    from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile, status
+    from typing import List
+
+    from fastapi import Body, Depends, FastAPI, Form, HTTPException, UploadFile, status
     from fastapi.security import OAuth2PasswordBearer
     from typing_extensions import Annotated
 
@@ -174,9 +179,9 @@ def web():
     @app.post("/transcribe", dependencies=[Depends(apikey_auth)])
     async def transcribe(
         file: UploadFile,
-        timestamp: Annotated[float, Form()] = 0,
         source_language: Annotated[str, Form()] = "en",
-        target_language: Annotated[str, Form()] = "en"
+        target_language: Annotated[str, Form()] = "fr",
+        timestamp: Annotated[float, Form()] = 0.0
     ) -> TranscriptResponse:
         audio_data = await file.read()
         audio_suffix = file.filename.split(".")[-1]
diff --git a/server/reflector/processors/audio_transcript_modal.py b/server/reflector/processors/audio_transcript_modal.py
index 80b6e582..a65dd278 100644
--- a/server/reflector/processors/audio_transcript_modal.py
+++ b/server/reflector/processors/audio_transcript_modal.py
@@ -58,7 +58,10 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
         # Update code here once this is possible.
         # i.e) extract from context/session objects
         source_language = "en"
-        target_language = "en"
+
+        # TODO: target lang is set to "fr" for demo purposes
+        # Revert once language selection is implemented
+        target_language = "fr"
 
         languages = TranslationLanguages()
         # Only way to set the target should be the UI element like dropdown.
@@ -74,7 +77,7 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
             files=files,
             timeout=self.timeout,
             headers=self.headers,
-            json=json_payload,
+            data=json_payload,
         )
 
         self.logger.debug(
@@ -84,12 +87,14 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
         result = response.json()
 
         # Sanity check for translation status in the result
+        translation = ""
         if target_language in result["text"]:
-            text = result["text"][target_language]
-        else:
-            text = result["text"][source_language]
+            translation = result["text"][target_language]
+        text = result["text"][source_language]
+
         transcript = Transcript(
             text=text,
+            translation=translation,
             words=[
                 Word(
                     text=word["text"],
diff --git a/server/reflector/processors/transcript_liner.py b/server/reflector/processors/transcript_liner.py
index cca5e6a2..5e9d6683 100644
--- a/server/reflector/processors/transcript_liner.py
+++ b/server/reflector/processors/transcript_liner.py
@@ -34,12 +34,12 @@ class TranscriptLinerProcessor(Processor):
             if "." not in word.text:
                 continue
 
+            partial.translation = self.transcript.translation
             # emit line
             await self.emit(partial)
 
             # create new transcript
             partial = Transcript(words=[])
-
             self.transcript = partial
 
     async def _flush(self):
diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py
index 1e5c84f2..537de415 100644
--- a/server/reflector/processors/types.py
+++ b/server/reflector/processors/types.py
@@ -47,6 +47,7 @@ class Word(BaseModel):
 
 class Transcript(BaseModel):
     text: str = ""
+    translation: str = ""
     words: list[Word] = None
 
     @property
@@ -84,7 +85,7 @@ class Transcript(BaseModel):
         words = [
             Word(text=word.text, start=word.start, end=word.end) for word in self.words
         ]
-        return Transcript(text=self.text, words=words)
+        return Transcript(text=self.text, translation=self.translation, words=words)
 
 
 class TitleSummary(BaseModel):
diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py
index f28eb021..f909cc9c 100644
--- a/server/reflector/views/rtc_offer.py
+++ b/server/reflector/views/rtc_offer.py
@@ -1,25 +1,27 @@
 import asyncio
-from fastapi import Request, APIRouter
-from reflector.events import subscribers_shutdown
-from pydantic import BaseModel
-from reflector.logger import logger
-from aiortc import RTCPeerConnection, RTCSessionDescription, MediaStreamTrack
-from json import loads, dumps
 from enum import StrEnum
+from json import dumps, loads
 from pathlib import Path
+
 import av
+from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
+from fastapi import APIRouter, Request
+from pydantic import BaseModel
+
+from reflector.events import subscribers_shutdown
+from reflector.logger import logger
 from reflector.processors import (
-    Pipeline,
     AudioChunkerProcessor,
+    AudioFileWriterProcessor,
     AudioMergeProcessor,
     AudioTranscriptAutoProcessor,
-    AudioFileWriterProcessor,
+    FinalSummary,
+    Pipeline,
+    TitleSummary,
+    Transcript,
+    TranscriptFinalSummaryProcessor,
     TranscriptLinerProcessor,
     TranscriptTopicDetectorProcessor,
-    TranscriptFinalSummaryProcessor,
-    Transcript,
-    TitleSummary,
-    FinalSummary,
 )
 
 sessions = []
@@ -108,6 +110,7 @@ async def rtc_offer_base(
             result = {
                 "cmd": "SHOW_TRANSCRIPTION",
                 "text": transcript.text,
+                "translation": transcript.translation,
             }
             ctx.data_channel.send(dumps(result))
 
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index c92079a6..b153765a 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -7,7 +7,6 @@ from typing import Annotated, Optional
 from uuid import uuid4
 
 import av
-import reflector.auth as auth
 from fastapi import (
     APIRouter,
     Depends,
@@ -18,11 +17,13 @@ from fastapi import (
 )
 from fastapi_pagination import Page, paginate
 from pydantic import BaseModel, Field
+from starlette.concurrency import run_in_threadpool
+
+import reflector.auth as auth
 from reflector.db import database, transcripts
 from reflector.logger import logger
 from reflector.settings import settings
 from reflector.utils.audio_waveform import get_audio_waveform
-from starlette.concurrency import run_in_threadpool
 
 from ._range_requests_response import range_requests_response
 from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
@@ -49,6 +50,7 @@ class AudioWaveform(BaseModel):
 
 class TranscriptText(BaseModel):
     text: str
+    translation: str
 
 
 class TranscriptTopic(BaseModel):
@@ -491,7 +493,10 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
 
     # FIXME don't do copy
     if event == PipelineEvent.TRANSCRIPT:
-        resp = transcript.add_event(event=event, data=TranscriptText(text=data.text))
+        resp = transcript.add_event(
+            event=event,
+            data=TranscriptText(text=data.text, translation=data.translation),
+        )
         await transcripts_controller.update(
             transcript,
             {
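
A minimal client sketch against the updated /transcribe endpoint (illustrative only, not part of the patch): MODAL_URL and API_KEY are placeholders, `requests` stands in for any HTTP client, and bearer-token auth is assumed from the OAuth2PasswordBearer dependency above. It shows the language options traveling as form fields (the reason for the json= -> data= switch in audio_transcript_modal.py) and the new response shape, where "text" maps each language code to its transcript.

# Illustrative client for POST /transcribe; names below are placeholders.
import requests

MODAL_URL = "https://example.modal.run"  # placeholder deployment URL
API_KEY = "changeme"                     # placeholder; validated by apikey_auth

with open("chunk.wav", "rb") as f:
    response = requests.post(
        f"{MODAL_URL}/transcribe",
        headers={"Authorization": f"Bearer {API_KEY}"},
        # multipart upload plus form fields, matching the endpoint signature
        files={"file": ("chunk.wav", f, "audio/wav")},
        data={"source_language": "en", "target_language": "fr", "timestamp": "0.0"},
    )
response.raise_for_status()

result = response.json()
# "text" now carries both languages, e.g.
# {"en": "hello world.", "fr": "bonjour le monde."}
text = result["text"]["en"]
translation = result["text"].get("fr", "")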