From 38a5ee0da2c9737e5d28074c755bb963d43a7af3 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Fri, 11 Aug 2023 15:32:41 +0200 Subject: [PATCH] server: implement warmup event for llm and transcription --- server/reflector/llm/base.py | 16 ++++++++++ server/reflector/llm/llm_modal.py | 10 +++++++ .../processors/audio_transcript_auto.py | 3 ++ .../processors/audio_transcript_modal.py | 29 +++++++++++++++---- server/reflector/processors/base.py | 23 +++++++++++++++ .../processors/transcript_topic_detector.py | 3 ++ server/reflector/processors/types.py | 5 ++++ server/reflector/views/rtc_offer.py | 1 + 8 files changed, 85 insertions(+), 5 deletions(-) diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py index 7f65a4bc..accd738a 100644 --- a/server/reflector/llm/base.py +++ b/server/reflector/llm/base.py @@ -1,6 +1,7 @@ from reflector.settings import settings from reflector.utils.retry import retry from reflector.logger import logger as reflector_logger +from time import monotonic import importlib import json import re @@ -29,6 +30,21 @@ class LLM: importlib.import_module(module_name) return cls._registry[name]() + async def warmup(self, logger: reflector_logger): + start = monotonic() + name = self.__class__.__name__ + logger.info(f"LLM[{name}] warming up...") + try: + await retry(self._warmup)(logger=logger) + duration = monotonic() - start + logger.info(f"LLM[{name}] warmup took {duration:.2f} seconds") + except Exception: + logger.exception(f"LLM[{name}] warmup failed") + raise + + async def _warmup(self, logger: reflector_logger): + pass + async def generate(self, prompt: str, logger: reflector_logger, **kwargs) -> dict: logger.info("LLM generate", prompt=repr(prompt)) try: diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py index b971153b..63ebc67c 100644 --- a/server/reflector/llm/llm_modal.py +++ b/server/reflector/llm/llm_modal.py @@ -9,10 +9,20 @@ class ModalLLM(LLM): super().__init__() self.timeout = settings.LLM_TIMEOUT self.llm_url = settings.LLM_URL + "/llm" + self.llm_warmup_url = settings.LLM_URL + "/warmup" self.headers = { "Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}", } + async def _warmup(self, logger): + async with httpx.AsyncClient() as client: + response = await client.post( + self.llm_warmup_url, + headers=self.headers, + timeout=self.timeout, + ) + response.raise_for_status() + async def _generate(self, prompt: str, **kwargs): async with httpx.AsyncClient() as client: response = await retry(client.post)( diff --git a/server/reflector/processors/audio_transcript_auto.py b/server/reflector/processors/audio_transcript_auto.py index 339e5633..fdae7663 100644 --- a/server/reflector/processors/audio_transcript_auto.py +++ b/server/reflector/processors/audio_transcript_auto.py @@ -47,6 +47,9 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor): def off(self, callback): self.processor.off(callback) + async def _warmup(self): + return await self.processor._warmup() + async def _push(self, data: AudioFile): return await self.processor._push(data) diff --git a/server/reflector/processors/audio_transcript_modal.py b/server/reflector/processors/audio_transcript_modal.py index 71cbdadb..a058df77 100644 --- a/server/reflector/processors/audio_transcript_modal.py +++ b/server/reflector/processors/audio_transcript_modal.py @@ -16,6 +16,7 @@ from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProces from reflector.processors.types import AudioFile, Transcript, Word from reflector.settings import settings from reflector.utils.retry import retry +from time import monotonic import httpx @@ -23,24 +24,37 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor): def __init__(self, modal_api_key: str): super().__init__() self.transcript_url = settings.TRANSCRIPT_URL + "/transcribe" + self.warmup_url = settings.TRANSCRIPT_URL + "/warmup" self.timeout = settings.TRANSCRIPT_TIMEOUT self.headers = { "Authorization": f"Bearer {modal_api_key}", } + async def _warmup(self): + try: + async with httpx.AsyncClient() as client: + start = monotonic() + self.logger.debug("Transcribe modal: warming up...") + response = await client.post( + self.warmup_url, + headers=self.headers, + timeout=self.timeout, + ) + response.raise_for_status() + duration = monotonic() - start + self.logger.debug(f"Transcribe modal: warmup took {duration:.2f}s") + except Exception: + self.logger.exception("Transcribe modal: warmup failed") + async def _transcript(self, data: AudioFile): async with httpx.AsyncClient() as client: print(f"Try to transcribe audio {data.path.name}") files = { "file": (data.path.name, data.path.open("rb")), } - form = { - "timestamp": float(round(data.timestamp, 2)), - } response = await retry(client.post)( self.transcript_url, files=files, - data=form, timeout=self.timeout, headers=self.headers, ) @@ -51,10 +65,15 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor): transcript = Transcript( text=result["text"], words=[ - Word(text=word["text"], start=word["start"], end=word["end"]) + Word( + text=word["text"], + start=word["start"], + end=word["end"], + ) for word in result["words"] ], ) + transcript.add_offset(data.timestamp) return transcript diff --git a/server/reflector/processors/base.py b/server/reflector/processors/base.py index 692a490b..85cbc3fd 100644 --- a/server/reflector/processors/base.py +++ b/server/reflector/processors/base.py @@ -7,6 +7,7 @@ import asyncio class Processor: INPUT_TYPE: type = None OUTPUT_TYPE: type = None + WARMUP_EVENT: str = "WARMUP_EVENT" def __init__(self, callback=None, custom_logger=None): self._processors = [] @@ -85,12 +86,21 @@ class Processor: def describe(self, level=0): logger.info(" " * level + self.__class__.__name__) + async def warmup(self): + """ + Warmup the processor + """ + await self._warmup() + async def _push(self, data): raise NotImplementedError async def _flush(self): pass + async def _warmup(self): + pass + @classmethod def as_threaded(cls, *args, **kwargs): """ @@ -129,10 +139,17 @@ class ThreadedProcessor(Processor): if data is None: await self.processor.flush() break + if data == self.WARMUP_EVENT: + self.logger.debug(f"Warming up {self.processor.__class__.__name__}") + await self.processor.warmup() + continue await self.processor.push(data) finally: self.queue.task_done() + async def _warmup(self): + await self.queue.put(self.WARMUP_EVENT) + async def _push(self, data): await self.queue.put(data) @@ -163,6 +180,7 @@ class Pipeline(Processor): OUTPUT_TYPE = None def __init__(self, *processors: Processor): + self._warmed_up = False super().__init__() self.logger = logger.bind(pipeline=self.uid) self.logger.info("Pipeline created") @@ -178,6 +196,11 @@ class Pipeline(Processor): self.INPUT_TYPE = processors[0].INPUT_TYPE self.OUTPUT_TYPE = processors[-1].OUTPUT_TYPE + async def _warmup(self): + for processor in self.processors: + self.logger.debug(f"Warming up {processor.__class__.__name__}") + await processor.warmup() + async def _push(self, data): await self.processors[0].push(data) diff --git a/server/reflector/processors/transcript_topic_detector.py b/server/reflector/processors/transcript_topic_detector.py index 6bcc2497..f4a9286a 100644 --- a/server/reflector/processors/transcript_topic_detector.py +++ b/server/reflector/processors/transcript_topic_detector.py @@ -31,6 +31,9 @@ class TranscriptTopicDetectorProcessor(Processor): self.min_transcript_length = min_transcript_length self.llm = LLM.get_instance() + async def _warmup(self): + await self.llm.warmup(logger=self.logger) + async def _push(self, data: Transcript): if self.transcript is None: self.transcript = data diff --git a/server/reflector/processors/types.py b/server/reflector/processors/types.py index bdf98b7a..6b193882 100644 --- a/server/reflector/processors/types.py +++ b/server/reflector/processors/types.py @@ -49,6 +49,11 @@ class Transcript(BaseModel): self.words.extend(other.words) self.text += other.text + def add_offset(self, offset: float): + for word in self.words: + word.start += offset + word.end += offset + def clone(self): words = [ Word(text=word.text, start=word.start, end=word.end) for word in self.words diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py index cbc0a4dc..2e9ed1b6 100644 --- a/server/reflector/views/rtc_offer.py +++ b/server/reflector/views/rtc_offer.py @@ -159,6 +159,7 @@ async def rtc_offer_base( TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic), TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary), ) + await ctx.pipeline.warmup() # handle RTC peer connection pc = RTCPeerConnection()