Fix model download

This commit is contained in:
2024-12-27 14:23:03 +01:00
parent f2193c7175
commit 7ff201f3ff

View File

@@ -30,16 +30,18 @@ def migrate_cache_llm():
def download_pyannote_audio(): def download_pyannote_audio():
from pyannote.audio import Pipeline from pyannote.audio import Pipeline
Pipeline.from_pretrained( Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.0", PYANNOTE_MODEL_NAME,
cache_dir=MODEL_DIR, cache_dir=MODEL_DIR,
use_auth_token=os.environ["HF_TOKEN"],
) )
diarizer_image = ( diarizer_image = (
Image.debian_slim(python_version="3.10.8") Image.debian_slim(python_version="3.10.8")
.pip_install( .pip_install(
"pyannote.audio", "pyannote.audio==3.1.0",
"requests", "requests",
"onnx", "onnx",
"torchaudio", "torchaudio",
@@ -50,10 +52,12 @@ diarizer_image = (
"protobuf", "protobuf",
"numpy", "numpy",
"huggingface_hub", "huggingface_hub",
"hf-transfer" "hf-transfer",
)
.run_function(
download_pyannote_audio, secrets=[Secret.from_name("my-huggingface-secret")]
) )
.run_function(migrate_cache_llm) .run_function(migrate_cache_llm)
.run_function(download_pyannote_audio)
.env( .env(
{ {
"LD_LIBRARY_PATH": ( "LD_LIBRARY_PATH": (
@@ -64,6 +68,7 @@ diarizer_image = (
) )
) )
@app.cls( @app.cls(
gpu=modal.gpu.A100(size="40GB"), gpu=modal.gpu.A100(size="40GB"),
timeout=60 * 30, timeout=60 * 30,
@@ -80,18 +85,12 @@ class Diarizer:
self.use_gpu = torch.cuda.is_available() self.use_gpu = torch.cuda.is_available()
self.device = "cuda" if self.use_gpu else "cpu" self.device = "cuda" if self.use_gpu else "cpu"
self.diarization_pipeline = Pipeline.from_pretrained( self.diarization_pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.0", PYANNOTE_MODEL_NAME, cache_dir=MODEL_DIR
cache_dir=MODEL_DIR
) )
self.diarization_pipeline.to(torch.device(self.device)) self.diarization_pipeline.to(torch.device(self.device))
@method() @method()
def diarize( def diarize(self, audio_data: str, audio_suffix: str, timestamp: float):
self,
audio_data: str,
audio_suffix: str,
timestamp: float
):
import tempfile import tempfile
import torchaudio import torchaudio
@@ -101,21 +100,24 @@ class Diarizer:
print("Diarizing audio") print("Diarizing audio")
waveform, sample_rate = torchaudio.load(fp.name) waveform, sample_rate = torchaudio.load(fp.name)
diarization = self.diarization_pipeline({"waveform": waveform, "sample_rate": sample_rate}) diarization = self.diarization_pipeline(
{"waveform": waveform, "sample_rate": sample_rate}
)
words = [] words = []
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True): for diarization_segment, _, speaker in diarization.itertracks(
yield_label=True
):
words.append( words.append(
{ {
"start": round(timestamp + diarization_segment.start, 3), "start": round(timestamp + diarization_segment.start, 3),
"end": round(timestamp + diarization_segment.end, 3), "end": round(timestamp + diarization_segment.end, 3),
"speaker": int(speaker[-2:]) "speaker": int(speaker[-2:]),
} }
) )
print("Diarization complete") print("Diarization complete")
return { return {"diarization": words}
"diarization": words
}
# ------------------------------------------------------------------- # -------------------------------------------------------------------
# Web API # Web API
@@ -129,7 +131,7 @@ class Diarizer:
secrets=[ secrets=[
Secret.from_name("reflector-gpu"), Secret.from_name("reflector-gpu"),
], ],
image=diarizer_image image=diarizer_image,
) )
@asgi_app() @asgi_app()
def web(): def web():
@@ -157,16 +159,17 @@ def web():
if response.status_code == 404: if response.status_code == 404:
raise HTTPException( raise HTTPException(
status_code=response.status_code, status_code=response.status_code,
detail="The audio file does not exist." detail="The audio file does not exist.",
) )
class DiarizationResponse(BaseModel): class DiarizationResponse(BaseModel):
result: dict result: dict
@app.post("/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)]) @app.post(
"/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)]
)
def diarize( def diarize(
audio_file_url: str, audio_file_url: str, timestamp: float = 0.0
timestamp: float = 0.0
) -> HTTPException | DiarizationResponse: ) -> HTTPException | DiarizationResponse:
# Currently the uploaded files are in mp3 format # Currently the uploaded files are in mp3 format
audio_suffix = "mp3" audio_suffix = "mp3"
@@ -176,9 +179,7 @@ def web():
print("Audio file downloaded successfully") print("Audio file downloaded successfully")
func = diarizerstub.diarize.spawn( func = diarizerstub.diarize.spawn(
audio_data=response.content, audio_data=response.content, audio_suffix=audio_suffix, timestamp=timestamp
audio_suffix=audio_suffix,
timestamp=timestamp
) )
result = func.get() result = func.get()
return result return result