server: implement warmup event for llm and transcription

This commit is contained in:
Mathieu Virbel
2023-08-11 15:32:41 +02:00
parent a2518df3bd
commit 38a5ee0da2
8 changed files with 85 additions and 5 deletions

View File

@@ -1,6 +1,7 @@
from reflector.settings import settings from reflector.settings import settings
from reflector.utils.retry import retry from reflector.utils.retry import retry
from reflector.logger import logger as reflector_logger from reflector.logger import logger as reflector_logger
from time import monotonic
import importlib import importlib
import json import json
import re import re
@@ -29,6 +30,21 @@ class LLM:
importlib.import_module(module_name) importlib.import_module(module_name)
return cls._registry[name]() return cls._registry[name]()
async def warmup(self, logger: reflector_logger):
    """
    Warm up this LLM and log how long the warmup took.

    The `_warmup` hook is wrapped with `retry` and awaited with the given
    logger. On success the elapsed time is logged at info level.

    :param logger: logger used for the progress, timing and failure messages;
        it is also forwarded to `_warmup` as the `logger` keyword argument.
    :raises Exception: any exception escaping the retried `_warmup` call is
        logged through `logger.exception` and then re-raised unchanged.
    """
    started_at = monotonic()
    name = type(self).__name__
    logger.info(f"LLM[{name}] warming up...")
    try:
        # Build the retried callable first, then await it with the logger.
        retried_warmup = retry(self._warmup)
        await retried_warmup(logger=logger)
        duration = monotonic() - started_at
        logger.info(f"LLM[{name}] warmup took {duration:.2f} seconds")
    except Exception:
        # Record the failure, then let the caller decide what to do with it.
        logger.exception(f"LLM[{name}] warmup failed")
        raise
async def _warmup(self, logger: reflector_logger):
    """
    Backend-specific warmup hook; the default implementation does nothing.

    `warmup` awaits this hook (through `retry`) with the caller's logger.
    Subclasses such as `ModalLLM` override it to contact their backend.

    :param logger: logger forwarded by `warmup`; unused here.
    """
    return None
async def generate(self, prompt: str, logger: reflector_logger, **kwargs) -> dict: async def generate(self, prompt: str, logger: reflector_logger, **kwargs) -> dict:
logger.info("LLM generate", prompt=repr(prompt)) logger.info("LLM generate", prompt=repr(prompt))
try: try:

View File

@@ -9,10 +9,20 @@ class ModalLLM(LLM):
super().__init__() super().__init__()
self.timeout = settings.LLM_TIMEOUT self.timeout = settings.LLM_TIMEOUT
self.llm_url = settings.LLM_URL + "/llm" self.llm_url = settings.LLM_URL + "/llm"
self.llm_warmup_url = settings.LLM_URL + "/warmup"
self.headers = { self.headers = {
"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}", "Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
} }
async def _warmup(self, logger):
    """
    Send a POST request to the LLM warmup URL.

    The request carries this instance's headers and timeout. A response with
    an error status raises through `raise_for_status()`.

    :param logger: accepted to match the `LLM._warmup` hook signature; unused.
    """
    async with httpx.AsyncClient() as http:
        request_options = {
            "headers": self.headers,
            "timeout": self.timeout,
        }
        reply = await http.post(self.llm_warmup_url, **request_options)
        reply.raise_for_status()
async def _generate(self, prompt: str, **kwargs): async def _generate(self, prompt: str, **kwargs):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await retry(client.post)( response = await retry(client.post)(

View File

@@ -47,6 +47,9 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
def off(self, callback): def off(self, callback):
self.processor.off(callback) self.processor.off(callback)
async def _warmup(self):
return await self.processor._warmup()
async def _push(self, data: AudioFile): async def _push(self, data: AudioFile):
return await self.processor._push(data) return await self.processor._push(data)

View File

@@ -16,6 +16,7 @@ from reflector.processors.audio_transcript_auto import AudioTranscriptAutoProces
from reflector.processors.types import AudioFile, Transcript, Word from reflector.processors.types import AudioFile, Transcript, Word
from reflector.settings import settings from reflector.settings import settings
from reflector.utils.retry import retry from reflector.utils.retry import retry
from time import monotonic
import httpx import httpx
@@ -23,24 +24,37 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
def __init__(self, modal_api_key: str): def __init__(self, modal_api_key: str):
super().__init__() super().__init__()
self.transcript_url = settings.TRANSCRIPT_URL + "/transcribe" self.transcript_url = settings.TRANSCRIPT_URL + "/transcribe"
self.warmup_url = settings.TRANSCRIPT_URL + "/warmup"
self.timeout = settings.TRANSCRIPT_TIMEOUT self.timeout = settings.TRANSCRIPT_TIMEOUT
self.headers = { self.headers = {
"Authorization": f"Bearer {modal_api_key}", "Authorization": f"Bearer {modal_api_key}",
} }
async def _warmup(self):
    """
    Send a POST request to the transcription warmup URL and log its duration.

    Best-effort: any exception (connection error, error status, ...) is
    logged through `self.logger.exception` and NOT re-raised, so a failed
    warmup never propagates to the caller.
    """
    try:
        async with httpx.AsyncClient() as http:
            began_at = monotonic()
            self.logger.debug("Transcribe modal: warming up...")
            reply = await http.post(
                self.warmup_url,
                headers=self.headers,
                timeout=self.timeout,
            )
            # An error status is turned into an exception handled below.
            reply.raise_for_status()
            duration = monotonic() - began_at
            self.logger.debug(f"Transcribe modal: warmup took {duration:.2f}s")
    except Exception:
        # Deliberately swallowed after logging: warmup is optional.
        self.logger.exception("Transcribe modal: warmup failed")
async def _transcript(self, data: AudioFile): async def _transcript(self, data: AudioFile):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
print(f"Try to transcribe audio {data.path.name}") print(f"Try to transcribe audio {data.path.name}")
files = { files = {
"file": (data.path.name, data.path.open("rb")), "file": (data.path.name, data.path.open("rb")),
} }
form = {
"timestamp": float(round(data.timestamp, 2)),
}
response = await retry(client.post)( response = await retry(client.post)(
self.transcript_url, self.transcript_url,
files=files, files=files,
data=form,
timeout=self.timeout, timeout=self.timeout,
headers=self.headers, headers=self.headers,
) )
@@ -51,10 +65,15 @@ class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
transcript = Transcript( transcript = Transcript(
text=result["text"], text=result["text"],
words=[ words=[
Word(text=word["text"], start=word["start"], end=word["end"]) Word(
text=word["text"],
start=word["start"],
end=word["end"],
)
for word in result["words"] for word in result["words"]
], ],
) )
transcript.add_offset(data.timestamp)
return transcript return transcript

View File

@@ -7,6 +7,7 @@ import asyncio
class Processor: class Processor:
INPUT_TYPE: type = None INPUT_TYPE: type = None
OUTPUT_TYPE: type = None OUTPUT_TYPE: type = None
WARMUP_EVENT: str = "WARMUP_EVENT"
def __init__(self, callback=None, custom_logger=None): def __init__(self, callback=None, custom_logger=None):
self._processors = [] self._processors = []
@@ -85,12 +86,21 @@ class Processor:
def describe(self, level=0): def describe(self, level=0):
logger.info(" " * level + self.__class__.__name__) logger.info(" " * level + self.__class__.__name__)
async def warmup(self):
    """
    Warm up the processor.

    Public entry point: awaits the `_warmup` hook and discards its result.
    """
    hook = self._warmup
    await hook()
async def _push(self, data): async def _push(self, data):
raise NotImplementedError raise NotImplementedError
async def _flush(self): async def _flush(self):
pass pass
async def _warmup(self):
pass
@classmethod @classmethod
def as_threaded(cls, *args, **kwargs): def as_threaded(cls, *args, **kwargs):
""" """
@@ -129,10 +139,17 @@ class ThreadedProcessor(Processor):
if data is None: if data is None:
await self.processor.flush() await self.processor.flush()
break break
if data == self.WARMUP_EVENT:
self.logger.debug(f"Warming up {self.processor.__class__.__name__}")
await self.processor.warmup()
continue
await self.processor.push(data) await self.processor.push(data)
finally: finally:
self.queue.task_done() self.queue.task_done()
async def _warmup(self):
await self.queue.put(self.WARMUP_EVENT)
async def _push(self, data): async def _push(self, data):
await self.queue.put(data) await self.queue.put(data)
@@ -163,6 +180,7 @@ class Pipeline(Processor):
OUTPUT_TYPE = None OUTPUT_TYPE = None
def __init__(self, *processors: Processor): def __init__(self, *processors: Processor):
self._warmed_up = False
super().__init__() super().__init__()
self.logger = logger.bind(pipeline=self.uid) self.logger = logger.bind(pipeline=self.uid)
self.logger.info("Pipeline created") self.logger.info("Pipeline created")
@@ -178,6 +196,11 @@ class Pipeline(Processor):
self.INPUT_TYPE = processors[0].INPUT_TYPE self.INPUT_TYPE = processors[0].INPUT_TYPE
self.OUTPUT_TYPE = processors[-1].OUTPUT_TYPE self.OUTPUT_TYPE = processors[-1].OUTPUT_TYPE
async def _warmup(self):
for processor in self.processors:
self.logger.debug(f"Warming up {processor.__class__.__name__}")
await processor.warmup()
async def _push(self, data): async def _push(self, data):
await self.processors[0].push(data) await self.processors[0].push(data)

View File

@@ -31,6 +31,9 @@ class TranscriptTopicDetectorProcessor(Processor):
self.min_transcript_length = min_transcript_length self.min_transcript_length = min_transcript_length
self.llm = LLM.get_instance() self.llm = LLM.get_instance()
async def _warmup(self):
await self.llm.warmup(logger=self.logger)
async def _push(self, data: Transcript): async def _push(self, data: Transcript):
if self.transcript is None: if self.transcript is None:
self.transcript = data self.transcript = data

View File

@@ -49,6 +49,11 @@ class Transcript(BaseModel):
self.words.extend(other.words) self.words.extend(other.words)
self.text += other.text self.text += other.text
def add_offset(self, offset: float):
    """
    Shift the timing of every word by `offset`, in place.

    :param offset: amount added to both `start` and `end` of each word in
        `self.words`.
    """
    for entry in self.words:
        entry.start += offset
        entry.end += offset
def clone(self): def clone(self):
words = [ words = [
Word(text=word.text, start=word.start, end=word.end) for word in self.words Word(text=word.text, start=word.start, end=word.end) for word in self.words

View File

@@ -159,6 +159,7 @@ async def rtc_offer_base(
TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic), TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary), TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
) )
await ctx.pipeline.warmup()
# handle RTC peer connection # handle RTC peer connection
pc = RTCPeerConnection() pc = RTCPeerConnection()