Fix model download

This commit is contained in:
2024-12-27 14:23:03 +01:00
parent f2193c7175
commit 7ff201f3ff

View File

@@ -30,16 +30,18 @@ def migrate_cache_llm():
def download_pyannote_audio():
from pyannote.audio import Pipeline
Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.0",
PYANNOTE_MODEL_NAME,
cache_dir=MODEL_DIR,
use_auth_token=os.environ["HF_TOKEN"],
)
diarizer_image = (
Image.debian_slim(python_version="3.10.8")
.pip_install(
"pyannote.audio",
"pyannote.audio==3.1.0",
"requests",
"onnx",
"torchaudio",
@@ -50,10 +52,12 @@ diarizer_image = (
"protobuf",
"numpy",
"huggingface_hub",
"hf-transfer"
"hf-transfer",
)
.run_function(
download_pyannote_audio, secrets=[Secret.from_name("my-huggingface-secret")]
)
.run_function(migrate_cache_llm)
.run_function(download_pyannote_audio)
.env(
{
"LD_LIBRARY_PATH": (
@@ -64,6 +68,7 @@ diarizer_image = (
)
)
@app.cls(
gpu=modal.gpu.A100(size="40GB"),
timeout=60 * 30,
@@ -80,18 +85,12 @@ class Diarizer:
self.use_gpu = torch.cuda.is_available()
self.device = "cuda" if self.use_gpu else "cpu"
self.diarization_pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.0",
cache_dir=MODEL_DIR
PYANNOTE_MODEL_NAME, cache_dir=MODEL_DIR
)
self.diarization_pipeline.to(torch.device(self.device))
@method()
def diarize(
self,
audio_data: str,
audio_suffix: str,
timestamp: float
):
def diarize(self, audio_data: str, audio_suffix: str, timestamp: float):
import tempfile
import torchaudio
@@ -101,21 +100,24 @@ class Diarizer:
print("Diarizing audio")
waveform, sample_rate = torchaudio.load(fp.name)
diarization = self.diarization_pipeline({"waveform": waveform, "sample_rate": sample_rate})
diarization = self.diarization_pipeline(
{"waveform": waveform, "sample_rate": sample_rate}
)
words = []
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
for diarization_segment, _, speaker in diarization.itertracks(
yield_label=True
):
words.append(
{
"start": round(timestamp + diarization_segment.start, 3),
"end": round(timestamp + diarization_segment.end, 3),
"speaker": int(speaker[-2:])
"speaker": int(speaker[-2:]),
}
)
print("Diarization complete")
return {
"diarization": words
}
return {"diarization": words}
# -------------------------------------------------------------------
# Web API
@@ -129,7 +131,7 @@ class Diarizer:
secrets=[
Secret.from_name("reflector-gpu"),
],
image=diarizer_image
image=diarizer_image,
)
@asgi_app()
def web():
@@ -157,16 +159,17 @@ def web():
if response.status_code == 404:
raise HTTPException(
status_code=response.status_code,
detail="The audio file does not exist."
detail="The audio file does not exist.",
)
class DiarizationResponse(BaseModel):
result: dict
@app.post("/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)])
@app.post(
"/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)]
)
def diarize(
audio_file_url: str,
timestamp: float = 0.0
audio_file_url: str, timestamp: float = 0.0
) -> HTTPException | DiarizationResponse:
# Currently the uploaded files are in mp3 format
audio_suffix = "mp3"
@@ -176,9 +179,7 @@ def web():
print("Audio file downloaded successfully")
func = diarizerstub.diarize.spawn(
audio_data=response.content,
audio_suffix=audio_suffix,
timestamp=timestamp
audio_data=response.content, audio_suffix=audio_suffix, timestamp=timestamp
)
result = func.get()
return result