From 33d82bc9af8540fbab5555caf79bd56d3f8f74b2 Mon Sep 17 00:00:00 2001 From: Koper Date: Fri, 13 Oct 2023 10:03:13 +0100 Subject: [PATCH 1/9] Force MP3 download --- server/reflector/views/transcripts.py | 11 +++++++---- www/app/transcripts/recorder.tsx | 3 +++ 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 410839d7..c8fdf6c2 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -13,6 +13,7 @@ from fastapi import ( WebSocket, WebSocketDisconnect, ) +from fastapi.responses import FileResponse from fastapi_pagination import Page, paginate from pydantic import BaseModel, Field from reflector.db import database, transcripts @@ -21,7 +22,6 @@ from reflector.settings import settings from reflector.utils.audio_waveform import get_audio_waveform from starlette.concurrency import run_in_threadpool -from ._range_requests_response import range_requests_response from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base router = APIRouter() @@ -356,10 +356,13 @@ async def transcript_get_audio_mp3( if not transcript.audio_mp3_filename.exists(): raise HTTPException(status_code=404, detail="Audio not found") - return range_requests_response( - request, + truncated_id = str(transcript.id).split("-")[0] + filename = f"recording_{truncated_id}.mp3" + + return FileResponse( transcript.audio_mp3_filename, - content_type="audio/mp3", + headers={"Content-Disposition": f"attachment; filename={filename}"}, + media_type="audio/mpeg", ) diff --git a/www/app/transcripts/recorder.tsx b/www/app/transcripts/recorder.tsx index a550a840..c1ba0774 100644 --- a/www/app/transcripts/recorder.tsx +++ b/www/app/transcripts/recorder.tsx @@ -311,6 +311,9 @@ export default function Recorder(props: RecorderProps) { From abf9dbcaf1a8f99f293aeb3a9df136c5f55043a6 Mon Sep 17 00:00:00 2001 From: Koper Date: Fri, 13 Oct 2023 10:17:36 +0100 Subject: [PATCH 2/9] Keep 
range_requests_response --- server/reflector/views/transcripts.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index c8fdf6c2..10e30b26 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -13,7 +13,6 @@ from fastapi import ( WebSocket, WebSocketDisconnect, ) -from fastapi.responses import FileResponse from fastapi_pagination import Page, paginate from pydantic import BaseModel, Field from reflector.db import database, transcripts @@ -22,6 +21,7 @@ from reflector.settings import settings from reflector.utils.audio_waveform import get_audio_waveform from starlette.concurrency import run_in_threadpool +from ._range_requests_response import range_requests_response from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base router = APIRouter() @@ -359,10 +359,11 @@ async def transcript_get_audio_mp3( truncated_id = str(transcript.id).split("-")[0] filename = f"recording_{truncated_id}.mp3" - return FileResponse( + return range_requests_response( + request, transcript.audio_mp3_filename, + content_type="audio/mpeg", headers={"Content-Disposition": f"attachment; filename={filename}"}, - media_type="audio/mpeg", ) From 1a7da94cae6f440aa507ea1a24625b0bf26019a9 Mon Sep 17 00:00:00 2001 From: Koper Date: Fri, 13 Oct 2023 10:37:45 +0100 Subject: [PATCH 3/9] Fix MP3 download python error --- server/reflector/views/_range_requests_response.py | 8 +++++++- server/reflector/views/transcripts.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/server/reflector/views/_range_requests_response.py b/server/reflector/views/_range_requests_response.py index 1a584a3c..f0c628e9 100644 --- a/server/reflector/views/_range_requests_response.py +++ b/server/reflector/views/_range_requests_response.py @@ -38,7 +38,9 @@ def _get_range_header(range_header: str, file_size: int) -> tuple[int, int]: return start, end -def 
range_requests_response(request: Request, file_path: str, content_type: str): +def range_requests_response( + request: Request, file_path: str, content_type: str, content_disposition: str +): """Returns StreamingResponse using Range Requests of a given file""" file_size = os.stat(file_path).st_size @@ -54,6 +56,10 @@ def range_requests_response(request: Request, file_path: str, content_type: str) "content-range, content-encoding" ), } + + if content_disposition: + headers["Content-Disposition"] = content_disposition + start = 0 end = file_size - 1 status_code = status.HTTP_200_OK diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 10e30b26..bf6d967b 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -363,7 +363,7 @@ async def transcript_get_audio_mp3( request, transcript.audio_mp3_filename, content_type="audio/mpeg", - headers={"Content-Disposition": f"attachment; filename={filename}"}, + content_disposition=f"attachment; filename={filename}", ) From 149342f854276a5ef03d7f631acd2b20300782d0 Mon Sep 17 00:00:00 2001 From: Koper Date: Fri, 13 Oct 2023 10:42:52 +0100 Subject: [PATCH 4/9] Fix unit tests --- server/tests/test_transcripts_audio_download.py | 6 +++--- server/tests/test_transcripts_rtc_ws.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/server/tests/test_transcripts_audio_download.py b/server/tests/test_transcripts_audio_download.py index f37b7a4f..79cb25bf 100644 --- a/server/tests/test_transcripts_audio_download.py +++ b/server/tests/test_transcripts_audio_download.py @@ -35,7 +35,7 @@ async def fake_transcript(tmpdir): @pytest.mark.parametrize( "url_suffix,content_type", [ - ["/mp3", "audio/mp3"], + ["/mp3", "audio/mpeg"], ], ) async def test_transcript_audio_download(fake_transcript, url_suffix, content_type): @@ -51,7 +51,7 @@ async def test_transcript_audio_download(fake_transcript, url_suffix, content_ty @pytest.mark.parametrize( 
"url_suffix,content_type", [ - ["/mp3", "audio/mp3"], + ["/mp3", "audio/mpeg"], ], ) async def test_transcript_audio_download_range( @@ -74,7 +74,7 @@ async def test_transcript_audio_download_range( @pytest.mark.parametrize( "url_suffix,content_type", [ - ["/mp3", "audio/mp3"], + ["/mp3", "audio/mpeg"], ], ) async def test_transcript_audio_download_range_with_seek( diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 2485ca6b..494fcd36 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -167,7 +167,7 @@ async def test_transcript_rtc_and_websocket( # check that audio/mp3 is available resp = await ac.get(f"/transcripts/{tid}/audio/mp3") assert resp.status_code == 200 - assert resp.headers["Content-Type"] == "audio/mp3" + assert resp.headers["Content-Type"] == "audio/mpeg" # stop server server.stop() From 0e4ef90e621d3673e1db2a118ae215f4724b0e43 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 13 Oct 2023 11:46:43 +0200 Subject: [PATCH 5/9] add github cache for docker --- .github/workflows/deploy.yml | 2 ++ .github/workflows/test_server.yml | 2 ++ 2 files changed, 4 insertions(+) diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index fa2ff154..ecac11b4 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -43,3 +43,5 @@ jobs: platforms: linux/amd64,linux/arm64 push: true tags: ${{ steps.login-ecr.outputs.registry }}/${{ env.ECR_REPOSITORY }}:latest + cache-from: type=gha + cache-to: type=gha,mode=max diff --git a/.github/workflows/test_server.yml b/.github/workflows/test_server.yml index 4c08de22..6761b738 100644 --- a/.github/workflows/test_server.yml +++ b/.github/workflows/test_server.yml @@ -81,3 +81,5 @@ jobs: with: context: server platforms: linux/amd64,linux/arm64 + cache-from: type=gha + cache-to: type=gha,mode=max From 4e40cc511abbfb45e6cab8d141ff5c2afcbd5359 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel 
Date: Fri, 13 Oct 2023 14:51:57 +0200 Subject: [PATCH 6/9] server: create fixture for starting the server, and always close server even if one test fail --- server/tests/test_transcripts_rtc_ws.py | 61 +++++++++++++------------ 1 file changed, 32 insertions(+), 29 deletions(-) diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index 494fcd36..517567d0 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -31,28 +31,43 @@ class ThreadedUvicorn: continue -@pytest.mark.asyncio -async def test_transcript_rtc_and_websocket( - tmpdir, dummy_llm, dummy_transcript, dummy_processors, ensure_casing -): - # goal: start the server, exchange RTC, receive websocket events - # because of that, we need to start the server in a thread - # to be able to connect with aiortc - +@pytest.fixture +async def appserver(tmpdir): from reflector.settings import settings from reflector.app import app + DATA_DIR = settings.DATA_DIR settings.DATA_DIR = Path(tmpdir) # start server host = "127.0.0.1" port = 1255 - base_url = f"http://{host}:{port}/v1" config = Config(app=app, host=host, port=port) server = ThreadedUvicorn(config) await server.start() + yield (server, host, port) + + server.stop() + settings.DATA_DIR = DATA_DIR + + +@pytest.mark.asyncio +async def test_transcript_rtc_and_websocket( + tmpdir, + dummy_llm, + dummy_transcript, + dummy_processors, + ensure_casing, + appserver, +): + # goal: start the server, exchange RTC, receive websocket events + # because of that, we need to start the server in a thread + # to be able to connect with aiortc + server, host, port = appserver + # create a transcript + base_url = f"http://{host}:{port}/v1" ac = AsyncClient(base_url=base_url) response = await ac.post("/transcripts", json={"name": "Test RTC"}) assert response.status_code == 200 @@ -169,33 +184,24 @@ async def test_transcript_rtc_and_websocket( assert resp.status_code == 200 assert 
resp.headers["Content-Type"] == "audio/mpeg" - # stop server - server.stop() - @pytest.mark.asyncio async def test_transcript_rtc_and_websocket_and_fr( - tmpdir, dummy_llm, dummy_transcript, dummy_processors, ensure_casing + tmpdir, + dummy_llm, + dummy_transcript, + dummy_processors, + ensure_casing, + appserver, ): # goal: start the server, exchange RTC, receive websocket events # because of that, we need to start the server in a thread # to be able to connect with aiortc # with target french language - - from reflector.settings import settings - from reflector.app import app - - settings.DATA_DIR = Path(tmpdir) - - # start server - host = "127.0.0.1" - port = 1255 - base_url = f"http://{host}:{port}/v1" - config = Config(app=app, host=host, port=port) - server = ThreadedUvicorn(config) - await server.start() + server, host, port = appserver # create a transcript + base_url = f"http://{host}:{port}/v1" ac = AsyncClient(base_url=base_url) response = await ac.post( "/transcripts", json={"name": "Test RTC", "target_language": "fr"} @@ -303,6 +309,3 @@ async def test_transcript_rtc_and_websocket_and_fr( # ensure the last event received is ended assert events[-1]["event"] == "STATUS" assert events[-1]["data"]["value"] == "ended" - - # stop server - server.stop() From c5297be924a598c820ec090739bce60427103154 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 13 Oct 2023 15:29:54 +0200 Subject: [PATCH 7/9] gh: use poetry cache from setup-python and remove old deps (#281) * gh: use poetry cache from setup-python and remove old deps * gh: use pipx and not setup-poetry, as per setup-python example * server: remove pyaudio unused in current reflector --- .github/workflows/test_server.yml | 19 +++++-------------- server/poetry.lock | 25 +------------------------ server/pyproject.toml | 4 ---- 3 files changed, 6 insertions(+), 42 deletions(-) diff --git a/.github/workflows/test_server.yml b/.github/workflows/test_server.yml index 6761b738..1191fe92 100644 --- 
a/.github/workflows/test_server.yml +++ b/.github/workflows/test_server.yml @@ -13,23 +13,14 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + - name: Install poetry + run: pipx install poetry - name: Set up Python 3.x uses: actions/setup-python@v4 with: - python-version: 3.11 - - uses: Gr1N/setup-poetry@v8 - - name: Cache Python requirements - uses: actions/cache@v2 - id: cache-pip - with: - path: ~/.cache/pypoetry/virtualenvs - key: ${{ runner.os }}-poetry-${{ hashFiles('poetry.lock') }} - restore-keys: | - - ${{ runner.os }}-poetry- - - name: Install tests dependencies - run: | - sudo apt-get update - sudo apt-get install -y portaudio19-dev build-essential + python-version: '3.11' + cache: 'poetry' + cache-dependency-path: 'server/poetry.lock' - name: Install requirements run: | cd server diff --git a/server/poetry.lock b/server/poetry.lock index 20870235..330c23e3 100644 --- a/server/poetry.lock +++ b/server/poetry.lock @@ -2173,29 +2173,6 @@ files = [ {file = "protobuf-4.24.4.tar.gz", hash = "sha256:5a70731910cd9104762161719c3d883c960151eea077134458503723b60e3667"}, ] -[[package]] -name = "pyaudio" -version = "0.2.13" -description = "Cross-platform audio I/O with PortAudio" -optional = false -python-versions = "*" -files = [ - {file = "PyAudio-0.2.13-cp310-cp310-win32.whl", hash = "sha256:48e29537ea22ae2ae323eebe297bfb2683831cee4f20d96964e131f65ab2161d"}, - {file = "PyAudio-0.2.13-cp310-cp310-win_amd64.whl", hash = "sha256:87137cfd0ef8608a2a383be3f6996f59505e322dab9d16531f14cf542fa294f1"}, - {file = "PyAudio-0.2.13-cp311-cp311-win32.whl", hash = "sha256:13915faaa780e6bbbb6d745ef0e761674fd461b1b1b3f9c1f57042a534bfc0c3"}, - {file = "PyAudio-0.2.13-cp311-cp311-win_amd64.whl", hash = "sha256:59cc3cc5211b729c7854e3989058a145872cc58b1a7b46c6d4d88448a343d890"}, - {file = "PyAudio-0.2.13-cp37-cp37m-win32.whl", hash = "sha256:d294e3f85b2238649b1ff49ce3412459a8a312569975a89d14646536362d7576"}, - {file = "PyAudio-0.2.13-cp37-cp37m-win_amd64.whl", hash 
= "sha256:ff7f5e44ef51fe61da1e09c6f632f0b5808198edd61b363855cc7dd03bf4a8ac"}, - {file = "PyAudio-0.2.13-cp38-cp38-win32.whl", hash = "sha256:c6b302b048c054b7463936d8ba884b73877dc47012f3c94665dba92dd658ae04"}, - {file = "PyAudio-0.2.13-cp38-cp38-win_amd64.whl", hash = "sha256:1505d766ee718df6f5a18b73ac42307ba1cb4d2c0397873159254a34f67515d6"}, - {file = "PyAudio-0.2.13-cp39-cp39-win32.whl", hash = "sha256:eb128e4a6ea9b98d9a31f33c44978885af27dbe8ae53d665f8790cbfe045517e"}, - {file = "PyAudio-0.2.13-cp39-cp39-win_amd64.whl", hash = "sha256:910ef09225cce227adbba92622d4a3e3c8375117f7dd64039f287d9ffc0e02a1"}, - {file = "PyAudio-0.2.13.tar.gz", hash = "sha256:26bccc81e4243d1c0ff5487e6b481de6329fcd65c79365c267cef38f363a2b56"}, -] - -[package.extras] -test = ["numpy"] - [[package]] name = "pycparser" version = "2.21" @@ -3861,4 +3838,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "a85cb09a0e4b68b29c4272d550e618d2e24ace5f16b707f29e8ac4ce915c1fae" +content-hash = "61578467a70980ff9c2dc0cd787b6410b91d7c5fd2bb4c46b6951ec82690ef67" diff --git a/server/pyproject.toml b/server/pyproject.toml index ffe790f2..e3b44774 100644 --- a/server/pyproject.toml +++ b/server/pyproject.toml @@ -40,10 +40,6 @@ black = "^23.7.0" stamina = "^23.1.0" -[tool.poetry.group.client.dependencies] -pyaudio = "^0.2.13" - - [tool.poetry.group.tests.dependencies] pytest-cov = "^4.1.0" pytest-aiohttp = "^1.0.4" From 10a86fb036409df5ef6ffc51590be59b8887698a Mon Sep 17 00:00:00 2001 From: Koper Date: Fri, 13 Oct 2023 15:12:05 +0100 Subject: [PATCH 8/9] Display Transcript Title --- www/app/api/apis/DefaultApi.ts | 67 --------------------- www/app/transcripts/[transcriptId]/page.tsx | 20 +++--- www/app/transcripts/transcriptTitle.tsx | 13 ++++ 3 files changed, 26 insertions(+), 74 deletions(-) create mode 100644 www/app/transcripts/transcriptTitle.tsx diff --git a/www/app/api/apis/DefaultApi.ts b/www/app/api/apis/DefaultApi.ts index ce094f1e..d51d42ca 100644 --- 
a/www/app/api/apis/DefaultApi.ts +++ b/www/app/api/apis/DefaultApi.ts @@ -54,10 +54,6 @@ export interface V1TranscriptGetRequest { transcriptId: any; } -export interface V1TranscriptGetAudioRequest { - transcriptId: any; -} - export interface V1TranscriptGetAudioMp3Request { transcriptId: any; } @@ -310,69 +306,6 @@ export class DefaultApi extends runtime.BaseAPI { return await response.value(); } - /** - * Transcript Get Audio - */ - async v1TranscriptGetAudioRaw( - requestParameters: V1TranscriptGetAudioRequest, - initOverrides?: RequestInit | runtime.InitOverrideFunction, - ): Promise> { - if ( - requestParameters.transcriptId === null || - requestParameters.transcriptId === undefined - ) { - throw new runtime.RequiredError( - "transcriptId", - "Required parameter requestParameters.transcriptId was null or undefined when calling v1TranscriptGetAudio.", - ); - } - - const queryParameters: any = {}; - - const headerParameters: runtime.HTTPHeaders = {}; - - if (this.configuration && this.configuration.accessToken) { - // oauth required - headerParameters["Authorization"] = await this.configuration.accessToken( - "OAuth2AuthorizationCodeBearer", - [], - ); - } - - const response = await this.request( - { - path: `/v1/transcripts/{transcript_id}/audio`.replace( - `{${"transcript_id"}}`, - encodeURIComponent(String(requestParameters.transcriptId)), - ), - method: "GET", - headers: headerParameters, - query: queryParameters, - }, - initOverrides, - ); - - if (this.isJsonMime(response.headers.get("content-type"))) { - return new runtime.JSONApiResponse(response); - } else { - return new runtime.TextApiResponse(response) as any; - } - } - - /** - * Transcript Get Audio - */ - async v1TranscriptGetAudio( - requestParameters: V1TranscriptGetAudioRequest, - initOverrides?: RequestInit | runtime.InitOverrideFunction, - ): Promise { - const response = await this.v1TranscriptGetAudioRaw( - requestParameters, - initOverrides, - ); - return await response.value(); - } - /** * 
Transcript Get Audio Mp3 */ diff --git a/www/app/transcripts/[transcriptId]/page.tsx b/www/app/transcripts/[transcriptId]/page.tsx index 74fede74..d1d985c7 100644 --- a/www/app/transcripts/[transcriptId]/page.tsx +++ b/www/app/transcripts/[transcriptId]/page.tsx @@ -12,6 +12,7 @@ import "../../styles/button.css"; import FinalSummary from "../finalSummary"; import ShareLink from "../shareLink"; import QRCode from "react-qr-code"; +import TranscriptTitle from "../transcriptTitle"; type TranscriptDetails = { params: { @@ -50,13 +51,18 @@ export default function TranscriptDetails(details: TranscriptDetails) { ) : ( <> - +
+ {transcript?.response?.title && ( + + )} + +
{ + return ( +

+ {props.title} +

+ ); +}; + +export default TranscriptTitle; From 628c69f81c69ecb69aa110f42707bc804e6adc4b Mon Sep 17 00:00:00 2001 From: projects-g <63178974+projects-g@users.noreply.github.com> Date: Fri, 13 Oct 2023 22:01:21 +0530 Subject: [PATCH 9/9] Separate out transcription and translation into own Modal deployments (#268) * abstract transcript/translate into separate GPU apps * update app names * update transformers library version * update env.example file --- server/env.example | 1 + server/gpu/modal/reflector_transcriber.py | 138 +--------- server/gpu/modal/reflector_translator.py | 237 ++++++++++++++++++ .../processors/transcript_translator.py | 6 +- server/reflector/settings.py | 4 + 5 files changed, 246 insertions(+), 140 deletions(-) create mode 100644 server/gpu/modal/reflector_translator.py diff --git a/server/env.example b/server/env.example index 4e9b7311..8c4dcdab 100644 --- a/server/env.example +++ b/server/env.example @@ -48,6 +48,7 @@ ## Using serverless modal.com (require reflector-gpu-modal deployed) #TRANSCRIPT_BACKEND=modal #TRANSCRIPT_URL=https://xxxxx--reflector-transcriber-web.modal.run +#TRANSLATE_URL=https://xxxxx--reflector-translator-web.modal.run #TRANSCRIPT_MODAL_API_KEY=xxxxx ## Using serverless banana.dev (require reflector-gpu-banana deployed) diff --git a/server/gpu/modal/reflector_transcriber.py b/server/gpu/modal/reflector_transcriber.py index 3404cfe4..69558c8e 100644 --- a/server/gpu/modal/reflector_transcriber.py +++ b/server/gpu/modal/reflector_transcriber.py @@ -14,34 +14,12 @@ WHISPER_MODEL: str = "large-v2" WHISPER_COMPUTE_TYPE: str = "float16" WHISPER_NUM_WORKERS: int = 1 -# Seamless M4T -SEAMLESSM4T_MODEL_SIZE: str = "medium" -SEAMLESSM4T_MODEL_CARD_NAME: str = f"seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}" -SEAMLESSM4T_VOCODER_CARD_NAME: str = "vocoder_36langs" - -HF_SEAMLESS_M4TEPO: str = f"facebook/seamless-m4t-{SEAMLESSM4T_MODEL_SIZE}" -HF_SEAMLESS_M4T_VOCODEREPO: str = "facebook/seamless-m4t-vocoder" - -SEAMLESS_GITEPO: str = 
"https://github.com/facebookresearch/seamless_communication.git" -SEAMLESS_MODEL_DIR: str = "m4t" WHISPER_MODEL_DIR = "/root/transcription_models" stub = Stub(name="reflector-transcriber") -def install_seamless_communication(): - import os - import subprocess - initial_dir = os.getcwd() - subprocess.run(["ssh-keyscan", "-t", "rsa", "github.com", ">>", "~/.ssh/known_hosts"]) - subprocess.run(["rm", "-rf", "seamless_communication"]) - subprocess.run(["git", "clone", SEAMLESS_GITEPO, "." + "/seamless_communication"]) - os.chdir("seamless_communication") - subprocess.run(["pip", "install", "-e", "."]) - os.chdir(initial_dir) - - def download_whisper(): from faster_whisper.utils import download_model @@ -50,18 +28,6 @@ def download_whisper(): print("Whisper model downloaded") -def download_seamlessm4t_model(): - from huggingface_hub import snapshot_download - - print("Downloading Transcriber model & tokenizer") - snapshot_download(HF_SEAMLESS_M4TEPO, cache_dir=SEAMLESS_MODEL_DIR) - print("Transcriber model & tokenizer downloaded") - - print("Downloading vocoder weights") - snapshot_download(HF_SEAMLESS_M4T_VOCODEREPO, cache_dir=SEAMLESS_MODEL_DIR) - print("Vocoder weights downloaded") - - def migrate_cache_llm(): """ XXX The cache for model files in Transformers v4.22.0 has been updated. 
@@ -76,52 +42,6 @@ def migrate_cache_llm(): print("LLM cache moved") -def configure_seamless_m4t(): - import os - - import yaml - - ASSETS_DIR: str = "./seamless_communication/src/seamless_communication/assets/cards" - - with open(f'{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml', 'r') as file: - model_yaml_data = yaml.load(file, Loader=yaml.FullLoader) - with open(f'{ASSETS_DIR}/vocoder_36langs.yaml', 'r') as file: - vocoder_yaml_data = yaml.load(file, Loader=yaml.FullLoader) - with open(f'{ASSETS_DIR}/unity_nllb-100.yaml', 'r') as file: - unity_100_yaml_data = yaml.load(file, Loader=yaml.FullLoader) - with open(f'{ASSETS_DIR}/unity_nllb-200.yaml', 'r') as file: - unity_200_yaml_data = yaml.load(file, Loader=yaml.FullLoader) - - model_dir = f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-{SEAMLESSM4T_MODEL_SIZE}/snapshots" - available_model_versions = os.listdir(model_dir) - latest_model_version = sorted(available_model_versions)[-1] - model_name = f"multitask_unity_{SEAMLESSM4T_MODEL_SIZE}.pt" - model_path = os.path.join(os.getcwd(), model_dir, latest_model_version, model_name) - - vocoder_dir = f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-vocoder/snapshots" - available_vocoder_versions = os.listdir(vocoder_dir) - latest_vocoder_version = sorted(available_vocoder_versions)[-1] - vocoder_name = "vocoder_36langs.pt" - vocoder_path = os.path.join(os.getcwd(), vocoder_dir, latest_vocoder_version, vocoder_name) - - tokenizer_name = "tokenizer.model" - tokenizer_path = os.path.join(os.getcwd(), model_dir, latest_model_version, tokenizer_name) - - model_yaml_data['checkpoint'] = f"file:/{model_path}" - vocoder_yaml_data['checkpoint'] = f"file:/{vocoder_path}" - unity_100_yaml_data['tokenizer'] = f"file:/{tokenizer_path}" - unity_200_yaml_data['tokenizer'] = f"file:/{tokenizer_path}" - - with open(f'{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml', 'w') as file: - yaml.dump(model_yaml_data, file) - with 
open(f'{ASSETS_DIR}/vocoder_36langs.yaml', 'w') as file: - yaml.dump(vocoder_yaml_data, file) - with open(f'{ASSETS_DIR}/unity_nllb-100.yaml', 'w') as file: - yaml.dump(unity_100_yaml_data, file) - with open(f'{ASSETS_DIR}/unity_nllb-200.yaml', 'w') as file: - yaml.dump(unity_200_yaml_data, file) - - transcriber_image = ( Image.debian_slim(python_version="3.10.8") .apt_install("git") @@ -131,7 +51,7 @@ transcriber_image = ( "faster-whisper", "requests", "torch", - "transformers", + "transformers==4.34.0", "sentencepiece", "protobuf", "huggingface_hub==0.16.4", @@ -141,9 +61,6 @@ transcriber_image = ( "pyyaml", "hf-transfer~=0.1" ) - .run_function(install_seamless_communication) - .run_function(download_seamlessm4t_model) - .run_function(configure_seamless_m4t) .run_function(download_whisper) .run_function(migrate_cache_llm) .env( @@ -167,7 +84,6 @@ class Transcriber: def __enter__(self): import faster_whisper import torch - from seamless_communication.models.inference.translator import Translator self.use_gpu = torch.cuda.is_available() self.device = "cuda" if self.use_gpu else "cpu" @@ -178,12 +94,6 @@ class Transcriber: num_workers=WHISPER_NUM_WORKERS, download_root=WHISPER_MODEL_DIR ) - self.translator = Translator( - SEAMLESSM4T_MODEL_CARD_NAME, - SEAMLESSM4T_VOCODER_CARD_NAME, - torch.device(self.device), - dtype=torch.float32 - ) @method() def transcribe_segment( @@ -229,38 +139,6 @@ class Transcriber: "words": words } - def get_seamless_lang_code(self, lang_code: str): - """ - The codes for SeamlessM4T is different from regular standards. - For ex, French is "fra" and not "fr". 
- """ - # TODO: Enhance with complete list of lang codes - seamless_lang_code = { - "en": "eng", - "fr": "fra" - } - return seamless_lang_code.get(lang_code, "eng") - - @method() - def translate_text( - self, - text: str, - source_language: str, - target_language: str - ): - translated_text, _, _ = self.translator.predict( - text, - "t2tt", - src_lang=self.get_seamless_lang_code(source_language), - tgt_lang=self.get_seamless_lang_code(target_language), - ngram_filtering=True - ) - return { - "text": { - source_language: text, - target_language: str(translated_text) - } - } # ------------------------------------------------------------------- # Web API # ------------------------------------------------------------------- @@ -316,18 +194,4 @@ def web(): result = func.get() return result - @app.post("/translate", dependencies=[Depends(apikey_auth)]) - async def translate( - text: str, - source_language: Annotated[str, Body(...)] = "en", - target_language: Annotated[str, Body(...)] = "fr", - ) -> TranscriptResponse: - func = transcriberstub.translate_text.spawn( - text=text, - source_language=source_language, - target_language=target_language, - ) - result = func.get() - return result - return app diff --git a/server/gpu/modal/reflector_translator.py b/server/gpu/modal/reflector_translator.py new file mode 100644 index 00000000..69ea719a --- /dev/null +++ b/server/gpu/modal/reflector_translator.py @@ -0,0 +1,237 @@ +""" +Reflector GPU backend - translator +================================== +""" + +import os +import tempfile + +from modal import Image, Secret, Stub, asgi_app, method +from pydantic import BaseModel + +# Seamless M4T +SEAMLESSM4T_MODEL_SIZE: str = "medium" +SEAMLESSM4T_MODEL_CARD_NAME: str = f"seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}" +SEAMLESSM4T_VOCODER_CARD_NAME: str = "vocoder_36langs" + +HF_SEAMLESS_M4TEPO: str = f"facebook/seamless-m4t-{SEAMLESSM4T_MODEL_SIZE}" +HF_SEAMLESS_M4T_VOCODEREPO: str = "facebook/seamless-m4t-vocoder" + +SEAMLESS_GITEPO: str 
= "https://github.com/facebookresearch/seamless_communication.git" +SEAMLESS_MODEL_DIR: str = "m4t" + +stub = Stub(name="reflector-translator") + + +def install_seamless_communication(): + import os + import subprocess + initial_dir = os.getcwd() + subprocess.run(["ssh-keyscan", "-t", "rsa", "github.com", ">>", "~/.ssh/known_hosts"]) + subprocess.run(["rm", "-rf", "seamless_communication"]) + subprocess.run(["git", "clone", SEAMLESS_GITEPO, "." + "/seamless_communication"]) + os.chdir("seamless_communication") + subprocess.run(["pip", "install", "-e", "."]) + os.chdir(initial_dir) + + +def download_seamlessm4t_model(): + from huggingface_hub import snapshot_download + + print("Downloading Transcriber model & tokenizer") + snapshot_download(HF_SEAMLESS_M4TEPO, cache_dir=SEAMLESS_MODEL_DIR) + print("Transcriber model & tokenizer downloaded") + + print("Downloading vocoder weights") + snapshot_download(HF_SEAMLESS_M4T_VOCODEREPO, cache_dir=SEAMLESS_MODEL_DIR) + print("Vocoder weights downloaded") + + +def configure_seamless_m4t(): + import os + + import yaml + + ASSETS_DIR: str = "./seamless_communication/src/seamless_communication/assets/cards" + + with open(f'{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml', 'r') as file: + model_yaml_data = yaml.load(file, Loader=yaml.FullLoader) + with open(f'{ASSETS_DIR}/vocoder_36langs.yaml', 'r') as file: + vocoder_yaml_data = yaml.load(file, Loader=yaml.FullLoader) + with open(f'{ASSETS_DIR}/unity_nllb-100.yaml', 'r') as file: + unity_100_yaml_data = yaml.load(file, Loader=yaml.FullLoader) + with open(f'{ASSETS_DIR}/unity_nllb-200.yaml', 'r') as file: + unity_200_yaml_data = yaml.load(file, Loader=yaml.FullLoader) + + model_dir = f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-{SEAMLESSM4T_MODEL_SIZE}/snapshots" + available_model_versions = os.listdir(model_dir) + latest_model_version = sorted(available_model_versions)[-1] + model_name = f"multitask_unity_{SEAMLESSM4T_MODEL_SIZE}.pt" + model_path = 
os.path.join(os.getcwd(), model_dir, latest_model_version, model_name) + + vocoder_dir = f"{SEAMLESS_MODEL_DIR}/models--facebook--seamless-m4t-vocoder/snapshots" + available_vocoder_versions = os.listdir(vocoder_dir) + latest_vocoder_version = sorted(available_vocoder_versions)[-1] + vocoder_name = "vocoder_36langs.pt" + vocoder_path = os.path.join(os.getcwd(), vocoder_dir, latest_vocoder_version, vocoder_name) + + tokenizer_name = "tokenizer.model" + tokenizer_path = os.path.join(os.getcwd(), model_dir, latest_model_version, tokenizer_name) + + model_yaml_data['checkpoint'] = f"file:/{model_path}" + vocoder_yaml_data['checkpoint'] = f"file:/{vocoder_path}" + unity_100_yaml_data['tokenizer'] = f"file:/{tokenizer_path}" + unity_200_yaml_data['tokenizer'] = f"file:/{tokenizer_path}" + + with open(f'{ASSETS_DIR}/seamlessM4T_{SEAMLESSM4T_MODEL_SIZE}.yaml', 'w') as file: + yaml.dump(model_yaml_data, file) + with open(f'{ASSETS_DIR}/vocoder_36langs.yaml', 'w') as file: + yaml.dump(vocoder_yaml_data, file) + with open(f'{ASSETS_DIR}/unity_nllb-100.yaml', 'w') as file: + yaml.dump(unity_100_yaml_data, file) + with open(f'{ASSETS_DIR}/unity_nllb-200.yaml', 'w') as file: + yaml.dump(unity_200_yaml_data, file) + + +transcriber_image = ( + Image.debian_slim(python_version="3.10.8") + .apt_install("git") + .apt_install("wget") + .apt_install("libsndfile-dev") + .pip_install( + "requests", + "torch", + "transformers==4.34.0", + "sentencepiece", + "protobuf", + "huggingface_hub==0.16.4", + "gitpython", + "torchaudio", + "fairseq2", + "pyyaml", + "hf-transfer~=0.1" + ) + .run_function(install_seamless_communication) + .run_function(download_seamlessm4t_model) + .run_function(configure_seamless_m4t) + .env( + { + "LD_LIBRARY_PATH": ( + "/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:" + "/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/" + ) + } + ) +) + + +@stub.cls( + gpu="A10G", + timeout=60 * 5, + container_idle_timeout=60 * 5, + image=transcriber_image, +) 
+class Translator: + def __enter__(self): + import torch + from seamless_communication.models.inference.translator import Translator + + self.use_gpu = torch.cuda.is_available() + self.device = "cuda" if self.use_gpu else "cpu" + self.translator = Translator( + SEAMLESSM4T_MODEL_CARD_NAME, + SEAMLESSM4T_VOCODER_CARD_NAME, + torch.device(self.device), + dtype=torch.float32 + ) + + @method() + def warmup(self): + return {"status": "ok"} + + def get_seamless_lang_code(self, lang_code: str): + """ + The codes for SeamlessM4T are different from regular standards. + For example, French is "fra" and not "fr". + """ + # TODO: Enhance with complete list of lang codes + seamless_lang_code = { + "en": "eng", + "fr": "fra" + } + return seamless_lang_code.get(lang_code, "eng") + + @method() + def translate_text( + self, + text: str, + source_language: str, + target_language: str + ): + translated_text, _, _ = self.translator.predict( + text, + "t2tt", + src_lang=self.get_seamless_lang_code(source_language), + tgt_lang=self.get_seamless_lang_code(target_language), + ngram_filtering=True + ) + return { + "text": { + source_language: text, + target_language: str(translated_text) + } + } +# ------------------------------------------------------------------- +# Web API +# ------------------------------------------------------------------- + + +@stub.function( + container_idle_timeout=60, + timeout=60, + secrets=[ + Secret.from_name("reflector-gpu"), + ], +) +@asgi_app() +def web(): + from fastapi import Body, Depends, FastAPI, HTTPException, status + from fastapi.security import OAuth2PasswordBearer + from typing_extensions import Annotated + + translatorstub = Translator() + + app = FastAPI() + + oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + def apikey_auth(apikey: str = Depends(oauth2_scheme)): + if apikey != os.environ["REFLECTOR_GPU_APIKEY"]: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + headers={"WWW-Authenticate": 
"Bearer"}, + ) + + class TranslateResponse(BaseModel): + result: dict + + @app.post("/translate", dependencies=[Depends(apikey_auth)]) + async def translate( + text: str, + source_language: Annotated[str, Body(...)] = "en", + target_language: Annotated[str, Body(...)] = "fr", + ) -> TranslateResponse: + func = translatorstub.translate_text.spawn( + text=text, + source_language=source_language, + target_language=target_language, + ) + result = func.get() + return result + + @app.post("/warmup", dependencies=[Depends(apikey_auth)]) + async def warmup(): + return translatorstub.warmup.spawn().get() + + return app diff --git a/server/reflector/processors/transcript_translator.py b/server/reflector/processors/transcript_translator.py index ae2c68e1..77b8f5be 100644 --- a/server/reflector/processors/transcript_translator.py +++ b/server/reflector/processors/transcript_translator.py @@ -16,8 +16,8 @@ class TranscriptTranslatorProcessor(Processor): def __init__(self, **kwargs): super().__init__(**kwargs) - self.transcript_url = settings.TRANSCRIPT_URL - self.timeout = settings.TRANSCRIPT_TIMEOUT + self.translate_url = settings.TRANSLATE_URL + self.timeout = settings.TRANSLATE_TIMEOUT self.headers = {"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}"} async def _push(self, data: Transcript): @@ -46,7 +46,7 @@ class TranscriptTranslatorProcessor(Processor): async with httpx.AsyncClient() as client: response = await retry(client.post)( - settings.TRANSCRIPT_URL + "/translate", + self.translate_url + "/translate", headers=self.headers, params=json_payload, timeout=self.timeout, diff --git a/server/reflector/settings.py b/server/reflector/settings.py index 4cec6b96..fa2d1296 100644 --- a/server/reflector/settings.py +++ b/server/reflector/settings.py @@ -38,6 +38,10 @@ class Settings(BaseSettings): TRANSCRIPT_URL: str | None = None TRANSCRIPT_TIMEOUT: int = 90 + # Translate into the target language + TRANSLATE_URL: str | None = None + TRANSLATE_TIMEOUT: int = 90 + # Audio 
transcription banana.dev configuration TRANSCRIPT_BANANA_API_KEY: str | None = None TRANSCRIPT_BANANA_MODEL_KEY: str | None = None