diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7132e09c..2a73b7c2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -29,3 +29,10 @@ repos:
     hooks:
       - id: black
         files: ^server/(reflector|tests)/
+
+  - repo: https://github.com/pycqa/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+        name: isort (python)
+        files: ^server/(gpu|evaluate|reflector)/
diff --git a/server/env.example b/server/env.example
index a8fd4128..0dc73c22 100644
--- a/server/env.example
+++ b/server/env.example
@@ -59,8 +59,8 @@
 ## llm backend implementation
 ## =======================================================
 
-## Use oobagooda (default)
-#LLM_BACKEND=oobagooda
+## Use oobabooga (default)
+#LLM_BACKEND=oobabooga
 #LLM_URL=http://xxx:7860/api/generate/v1
 
 ## Using serverless modal.com (require reflector-gpu-modal deployed)
diff --git a/server/gpu/modal/reflector_llm.py b/server/gpu/modal/reflector_llm.py
index bf6f4cf5..fd8a4aae 100644
--- a/server/gpu/modal/reflector_llm.py
+++ b/server/gpu/modal/reflector_llm.py
@@ -3,14 +3,15 @@ Reflector GPU backend - LLM
 ===========================
 """
-
+import json
 import os
-from modal import Image, method, Stub, asgi_app, Secret
+from typing import Optional
+from modal import Image, Secret, Stub, asgi_app, method
 
 # LLM
 LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"
-LLM_LOW_CPU_MEM_USAGE: bool = False
+LLM_LOW_CPU_MEM_USAGE: bool = True
 LLM_TORCH_DTYPE: str = "bfloat16"
 LLM_MAX_NEW_TOKENS: int = 300
@@ -49,6 +50,8 @@ llm_image = (
         "torch",
         "sentencepiece",
         "protobuf",
+        "jsonformer==0.12.0",
+        "accelerate==0.21.0",
         "einops==0.6.1",
         "hf-transfer~=0.1",
         "huggingface_hub==0.16.4",
@@ -81,6 +84,7 @@ class LLM:
         # generation configuration
         print("Instance llm generation config")
+        # JSONFormer doesn't yet support generation configs, but keeping for future usage
         model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
         gen_cfg = GenerationConfig.from_model_config(model.config)
         gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
@@ -107,16 +111,30 @@ class LLM:
         return {"status": "ok"}
 
     @method()
-    def generate(self, prompt: str):
+    def generate(self, prompt: str, schema: str = None):
         print(f"Generate {prompt=}")
 
-        # tokenize prompt
-        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
-            self.model.device
-        )
-        output = self.model.generate(input_ids, generation_config=self.gen_cfg)
+        # If a schema is given, conform to schema
+        if schema:
+            print(f"Schema {schema=}")
+            import jsonformer
+            jsonformer_llm = jsonformer.Jsonformer(model=self.model,
+                                                   tokenizer=self.tokenizer,
+                                                   json_schema=json.loads(schema),
+                                                   prompt=prompt,
+                                                   max_string_token_length=self.gen_cfg.max_new_tokens)
+            response = jsonformer_llm()
+        else:
+            # If no schema, perform prompt only generation
 
-        # decode output
-        response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
+            # tokenize prompt
+            input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
+                self.model.device
+            )
+            output = self.model.generate(input_ids, generation_config=self.gen_cfg)
+
+            # decode output
+            response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
 
         print(f"Generated {response=}")
         return {"text": response}
@@ -135,7 +153,7 @@ class LLM:
 )
 @asgi_app()
 def web():
-    from fastapi import FastAPI, HTTPException, status, Depends
+    from fastapi import Depends, FastAPI, HTTPException, status
     from fastapi.security import OAuth2PasswordBearer
     from pydantic import BaseModel
 
@@ -154,12 +172,16 @@ def web():
     class LLMRequest(BaseModel):
         prompt: str
+        schema: Optional[dict] = None
 
     @app.post("/llm", dependencies=[Depends(apikey_auth)])
     async def llm(
         req: LLMRequest,
     ):
-        func = llmstub.generate.spawn(prompt=req.prompt)
+        if req.schema:
+            func = llmstub.generate.spawn(prompt=req.prompt, schema=json.dumps(req.schema))
+        else:
+            func = llmstub.generate.spawn(prompt=req.prompt)
         result = func.get()
         return result
diff --git a/server/reflector/app.py b/server/reflector/app.py
index fa148240..14d55d7a 100644
--- a/server/reflector/app.py
+++ b/server/reflector/app.py
@@ -1,15 +1,17 @@
+from contextlib import asynccontextmanager
+
+import reflector.auth  # noqa
+import reflector.db  # noqa
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi_pagination import add_pagination
 from fastapi.routing import APIRoute
-import reflector.db  # noqa
-import reflector.auth  # noqa
-from reflector.views.rtc_offer import router as rtc_offer_router
-from reflector.views.transcripts import router as transcripts_router
-from reflector.events import subscribers_startup, subscribers_shutdown
+from fastapi_pagination import add_pagination
+from reflector.events import subscribers_shutdown, subscribers_startup
 from reflector.logger import logger
 from reflector.settings import settings
-from contextlib import asynccontextmanager
+from reflector.views.rtc_offer import router as rtc_offer_router
+from reflector.views.transcripts import router as transcripts_router
+from reflector.views.user import router as user_router
 
 try:
     import sentry_sdk
@@ -50,6 +52,7 @@ app.add_middleware(
 
 # register views
 app.include_router(rtc_offer_router)
 app.include_router(transcripts_router, prefix="/v1")
+app.include_router(user_router, prefix="/v1")
 add_pagination(app)
 
diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py
index e528a3e6..d046ffe7 100644
--- a/server/reflector/llm/base.py
+++ b/server/reflector/llm/base.py
@@ -1,10 +1,11 @@
-from reflector.settings import settings
-from reflector.utils.retry import retry
-from reflector.logger import logger as reflector_logger
-from time import monotonic
 import importlib
 import json
 import re
+from time import monotonic
+
+from reflector.logger import logger as reflector_logger
+from reflector.settings import settings
+from reflector.utils.retry import retry
 
 
 class LLM:
@@ -20,7 +21,7 @@ class LLM:
         Return an instance depending on the settings.
 
         Settings used:
-        - `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
+        - `LLM_BACKEND`: key of the backend, defaults to `oobabooga`
         - `LLM_URL`: url of the backend
         """
         if name is None:
@@ -44,10 +45,16 @@ class LLM:
     async def _warmup(self, logger: reflector_logger):
         pass
 
-    async def generate(self, prompt: str, logger: reflector_logger, **kwargs) -> dict:
+    async def generate(
+        self,
+        prompt: str,
+        logger: reflector_logger,
+        schema: dict | None = None,
+        **kwargs,
+    ) -> dict:
         logger.info("LLM generate", prompt=repr(prompt))
         try:
-            result = await retry(self._generate)(prompt=prompt, **kwargs)
+            result = await retry(self._generate)(prompt=prompt, schema=schema, **kwargs)
         except Exception:
             logger.exception("Failed to call llm after retrying")
             raise
@@ -59,7 +66,7 @@ class LLM:
 
         return result
 
-    async def _generate(self, prompt: str, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
         raise NotImplementedError
 
     def _parse_json(self, result: str) -> dict:
diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py
index d6a0fa07..56fc0e69 100644
--- a/server/reflector/llm/llm_banana.py
+++ b/server/reflector/llm/llm_banana.py
@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx
 
 
 class BananaLLM(LLM):
@@ -13,18 +13,22 @@ class BananaLLM(LLM):
             "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
         }
 
-    async def _generate(self, prompt: str, **kwargs):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+        json_payload = {"prompt": prompt}
+        if schema:
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 settings.LLM_URL,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=300,  # as per their sdk
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if not schema:
+                text = text[len(prompt) :]
             return text
diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py
index 7f23aa0d..ce0de02a 100644
--- a/server/reflector/llm/llm_modal.py
+++ b/server/reflector/llm/llm_modal.py
@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx
 
 
 class ModalLLM(LLM):
@@ -23,18 +23,22 @@ class ModalLLM(LLM):
             )
             response.raise_for_status()
 
-    async def _generate(self, prompt: str, **kwargs):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+        json_payload = {"prompt": prompt}
+        if schema:
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 self.llm_url,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=60 * 5,
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if not schema:
+                text = text[len(prompt) :]
             return text
@@ -48,6 +52,14 @@ if __name__ == "__main__":
         result = await llm.generate("Hello, my name is", logger=logger)
         print(result)
 
+        schema = {
+            "type": "object",
+            "properties": {"name": {"type": "string"}},
+        }
+
+        result = await llm.generate("Hello, my name is", schema=schema, logger=logger)
+        print(result)
+
     import asyncio
 
     asyncio.run(main())
diff --git a/server/reflector/llm/llm_oobagooda.py b/server/reflector/llm/llm_oobabooga.py
similarity index 56%
rename from server/reflector/llm/llm_oobagooda.py
rename to server/reflector/llm/llm_oobabooga.py
index be7d8133..411014c5 100644
--- a/server/reflector/llm/llm_oobagooda.py
+++ b/server/reflector/llm/llm_oobabooga.py
@@ -1,18 +1,21 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
-import httpx
 
 
-class OobagoodaLLM(LLM):
-    async def _generate(self, prompt: str, **kwargs):
+class OobaboogaLLM(LLM):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+        json_payload = {"prompt": prompt}
+        if schema:
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.LLM_URL,
                 headers={"Content-Type": "application/json"},
-                json={"prompt": prompt},
+                json=json_payload,
             )
             response.raise_for_status()
             return response.json()
 
 
-LLM.register("oobagooda", OobagoodaLLM)
+LLM.register("oobabooga", OobaboogaLLM)
diff --git a/server/reflector/llm/llm_openai.py b/server/reflector/llm/llm_openai.py
index dd438704..7ed532b7 100644
--- a/server/reflector/llm/llm_openai.py
+++ b/server/reflector/llm/llm_openai.py
@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.logger import logger
 from reflector.settings import settings
-import httpx
 
 
 class OpenAILLM(LLM):
@@ -15,7 +15,7 @@ class OpenAILLM(LLM):
         self.max_tokens = settings.LLM_MAX_TOKENS
         logger.info(f"LLM use openai backend at {self.openai_url}")
 
-    async def _generate(self, prompt: str, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
         headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {self.openai_key}",
diff --git a/server/reflector/processors/transcript_topic_detector.py b/server/reflector/processors/transcript_topic_detector.py
index b626e8a2..3d8e3965 100644
--- a/server/reflector/processors/transcript_topic_detector.py
+++ b/server/reflector/processors/transcript_topic_detector.py
@@ -1,7 +1,7 @@
-from reflector.processors.base import Processor
-from reflector.processors.types import Transcript, TitleSummary
-from reflector.utils.retry import retry
 from reflector.llm import LLM
+from reflector.processors.base import Processor
+from reflector.processors.types import TitleSummary, Transcript
+from reflector.utils.retry import retry
 
 
 class TranscriptTopicDetectorProcessor(Processor):
@@ -15,9 +15,11 @@ class TranscriptTopicDetectorProcessor(Processor):
     PROMPT = """
     ### Human: Create a JSON object as response.The JSON object must have 2 fields:
-    i) title and ii) summary.For the title field,generate a short title
-    for the given text. For the summary field, summarize the given text
-    in three sentences.
+    i) title and ii) summary.
+
+    For the title field, generate a short title for the given text.
+    For the summary field, summarize the given text in a maximum of
+    three sentences.
 
     {input_text}
 
@@ -30,6 +32,13 @@ class TranscriptTopicDetectorProcessor(Processor):
         self.transcript = None
         self.min_transcript_length = min_transcript_length
         self.llm = LLM.get_instance()
+        self.topic_detector_schema = {
+            "type": "object",
+            "properties": {
+                "title": {"type": "string"},
+                "summary": {"type": "string"},
+            },
+        }
 
     async def _warmup(self):
         await self.llm.warmup(logger=self.logger)
@@ -52,7 +61,9 @@ class TranscriptTopicDetectorProcessor(Processor):
         text = self.transcript.text
         self.logger.info(f"Topic detector got {len(text)} length transcript")
         prompt = self.PROMPT.format(input_text=text)
-        result = await retry(self.llm.generate)(prompt=prompt, logger=self.logger)
+        result = await retry(self.llm.generate)(
+            prompt=prompt, schema=self.topic_detector_schema, logger=self.logger
+        )
         summary = TitleSummary(
             title=result["title"],
             summary=result["summary"],
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index 2add7448..468dab2f 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -55,8 +55,8 @@ class Settings(BaseSettings):
     TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
 
     # LLM
-    # available backend: openai, banana, modal, oobagooda
-    LLM_BACKEND: str = "oobagooda"
+    # available backend: openai, banana, modal, oobabooga
+    LLM_BACKEND: str = "oobabooga"
 
     # LLM common configuration
     LLM_URL: str | None = None
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index bfe61473..29db9ec7 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -136,9 +136,7 @@ class TranscriptController:
         query = transcripts.select()
         if user_id is not None:
             query = query.where(transcripts.c.user_id == user_id)
-        print(query)
         results = await database.fetch_all(query)
-        print(results)
         return results
 
     async def get_by_id(
diff --git a/server/reflector/views/user.py b/server/reflector/views/user.py
new file mode 100644
index 00000000..4952c471
--- /dev/null
+++ b/server/reflector/views/user.py
@@ -0,0 +1,20 @@
+from typing import Annotated, Optional
+
+import reflector.auth as auth
+from fastapi import APIRouter, Depends
+from pydantic import BaseModel
+
+router = APIRouter()
+
+
+class UserInfo(BaseModel):
+    sub: str
+    email: Optional[str]
+    email_verified: Optional[bool]
+
+
+@router.get("/me")
+async def user_me(
+    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+) -> UserInfo | None:
+    return user
diff --git a/server/tests/test_processors_pipeline.py b/server/tests/test_processors_pipeline.py
index 95c296de..cc6a8574 100644
--- a/server/tests/test_processors_pipeline.py
+++ b/server/tests/test_processors_pipeline.py
@@ -15,7 +15,7 @@ async def test_basic_process(event_loop):
     settings.TRANSCRIPT_BACKEND = "whisper"
 
     class LLMTest(LLM):
-        async def _generate(self, prompt: str, **kwargs) -> str:
+        async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
             return {
                 "title": "TITLE",
                 "summary": "SUMMARY",
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index 0e764cca..8237d4ab 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -3,16 +3,16 @@
 # FIXME test websocket connection after RTC is finished still send the full events
 # FIXME try with locked session, RTC should not work
 
-import pytest
-import json
-from unittest.mock import patch
-from httpx import AsyncClient
-
-from uvicorn import Config, Server
-import threading
 import asyncio
+import json
+import threading
 from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+from httpx import AsyncClient
 from httpx_ws import aconnect_ws
+from uvicorn import Config, Server
 
 
 class ThreadedUvicorn:
@@ -60,7 +60,7 @@ async def dummy_llm():
     from reflector.llm.base import LLM
 
     class TestLLM(LLM):
-        async def _generate(self, prompt: str, **kwargs):
+        async def _generate(self, prompt: str, schema: dict | None, **kwargs):
             return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})
 
     with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
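
For reference, a minimal client-side sketch of the schema-constrained generation path exposed by the /llm endpoint in this patch: when the request body carries a JSON schema, the GPU backend routes generation through jsonformer and returns an object matching that schema in the "text" field; without a schema, plain prompt completion behaves as before. This is an illustrative sketch only, not part of the patch: the endpoint URL and API key are placeholders, and the bearer-token header is an assumption based on the OAuth2PasswordBearer dependency used by the modal app.

import asyncio

import httpx

# Placeholders -- substitute the deployed reflector-gpu-modal endpoint and its key.
LLM_URL = "https://example--reflector-llm-web.modal.run/llm"
LLM_API_KEY = "changeme"

SCHEMA = {
    "type": "object",
    "properties": {
        "title": {"type": "string"},
        "summary": {"type": "string"},
    },
}


async def main():
    async with httpx.AsyncClient(timeout=60 * 5) as client:
        # With "schema" present, the backend builds the response via jsonformer;
        # omit the field to get plain prompt completion instead.
        response = await client.post(
            LLM_URL,
            # Assumption: apikey_auth expects a bearer token.
            headers={"Authorization": f"Bearer {LLM_API_KEY}"},
            json={"prompt": "### Human: Summarize the meeting ...", "schema": SCHEMA},
        )
        response.raise_for_status()
        print(response.json()["text"])


asyncio.run(main())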