From 7ff201f3ff3b903560d82ca0c79a7d8fbdcf677c Mon Sep 17 00:00:00 2001 From: Sergey Mankovsky Date: Fri, 27 Dec 2024 14:23:03 +0100 Subject: [PATCH] Fix model download --- .../modal_deployments/reflector_diarizer.py | 53 ++++++++++--------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/server/gpu/modal_deployments/reflector_diarizer.py b/server/gpu/modal_deployments/reflector_diarizer.py index bbb75b4c..0f9178f4 100644 --- a/server/gpu/modal_deployments/reflector_diarizer.py +++ b/server/gpu/modal_deployments/reflector_diarizer.py @@ -30,16 +30,18 @@ def migrate_cache_llm(): def download_pyannote_audio(): from pyannote.audio import Pipeline + Pipeline.from_pretrained( - "pyannote/speaker-diarization-3.0", + PYANNOTE_MODEL_NAME, cache_dir=MODEL_DIR, + use_auth_token=os.environ["HF_TOKEN"], ) diarizer_image = ( Image.debian_slim(python_version="3.10.8") .pip_install( - "pyannote.audio", + "pyannote.audio==3.1.0", "requests", "onnx", "torchaudio", @@ -50,10 +52,12 @@ diarizer_image = ( "protobuf", "numpy", "huggingface_hub", - "hf-transfer" + "hf-transfer", + ) + .run_function( + download_pyannote_audio, secrets=[Secret.from_name("my-huggingface-secret")] ) .run_function(migrate_cache_llm) - .run_function(download_pyannote_audio) .env( { "LD_LIBRARY_PATH": ( @@ -64,6 +68,7 @@ diarizer_image = ( ) ) + @app.cls( gpu=modal.gpu.A100(size="40GB"), timeout=60 * 30, @@ -80,18 +85,12 @@ class Diarizer: self.use_gpu = torch.cuda.is_available() self.device = "cuda" if self.use_gpu else "cpu" self.diarization_pipeline = Pipeline.from_pretrained( - "pyannote/speaker-diarization-3.0", - cache_dir=MODEL_DIR + PYANNOTE_MODEL_NAME, cache_dir=MODEL_DIR ) self.diarization_pipeline.to(torch.device(self.device)) @method() - def diarize( - self, - audio_data: str, - audio_suffix: str, - timestamp: float - ): + def diarize(self, audio_data: str, audio_suffix: str, timestamp: float): import tempfile import torchaudio @@ -101,21 +100,24 @@ class Diarizer: print("Diarizing audio") waveform, sample_rate = torchaudio.load(fp.name) - diarization = self.diarization_pipeline({"waveform": waveform, "sample_rate": sample_rate}) + diarization = self.diarization_pipeline( + {"waveform": waveform, "sample_rate": sample_rate} + ) words = [] - for diarization_segment, _, speaker in diarization.itertracks(yield_label=True): + for diarization_segment, _, speaker in diarization.itertracks( + yield_label=True + ): words.append( { "start": round(timestamp + diarization_segment.start, 3), "end": round(timestamp + diarization_segment.end, 3), - "speaker": int(speaker[-2:]) + "speaker": int(speaker[-2:]), } ) print("Diarization complete") - return { - "diarization": words - } + return {"diarization": words} + # ------------------------------------------------------------------- # Web API @@ -129,7 +131,7 @@ class Diarizer: secrets=[ Secret.from_name("reflector-gpu"), ], - image=diarizer_image + image=diarizer_image, ) @asgi_app() def web(): @@ -157,16 +159,17 @@ def web(): if response.status_code == 404: raise HTTPException( status_code=response.status_code, - detail="The audio file does not exist." + detail="The audio file does not exist.", ) class DiarizationResponse(BaseModel): result: dict - @app.post("/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)]) + @app.post( + "/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)] + ) def diarize( - audio_file_url: str, - timestamp: float = 0.0 + audio_file_url: str, timestamp: float = 0.0 ) -> HTTPException | DiarizationResponse: # Currently the uploaded files are in mp3 format audio_suffix = "mp3" @@ -176,9 +179,7 @@ def web(): print("Audio file downloaded successfully") func = diarizerstub.diarize.spawn( - audio_data=response.content, - audio_suffix=audio_suffix, - timestamp=timestamp + audio_data=response.content, audio_suffix=audio_suffix, timestamp=timestamp ) result = func.get() return result