Merge branch 'feat-user-auth-fief' into feat-user-auth-www

Koper committed 2023-08-18 17:33:22 +07:00
15 changed files with 152 additions and 65 deletions

View File

@@ -29,3 +29,10 @@ repos:
     hooks:
       - id: black
         files: ^server/(reflector|tests)/
+  - repo: https://github.com/pycqa/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+        name: isort (python)
+        files: ^server/(gpu|evaluate|reflector)/

View File

@@ -59,8 +59,8 @@
 ## llm backend implementation
 ## =======================================================
-## Use oobagooda (default)
-#LLM_BACKEND=oobagooda
+## Use oobabooga (default)
+#LLM_BACKEND=oobabooga
 #LLM_URL=http://xxx:7860/api/generate/v1
 ## Using serverless modal.com (require reflector-gpu-modal deployed)

View File

@@ -3,14 +3,15 @@ Reflector GPU backend - LLM
 ===========================
 """
+import json
 import os
-from modal import Image, method, Stub, asgi_app, Secret
+from typing import Optional
+from modal import Image, Secret, Stub, asgi_app, method
 # LLM
 LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"
-LLM_LOW_CPU_MEM_USAGE: bool = False
+LLM_LOW_CPU_MEM_USAGE: bool = True
 LLM_TORCH_DTYPE: str = "bfloat16"
 LLM_MAX_NEW_TOKENS: int = 300
@@ -49,6 +50,8 @@ llm_image = (
"torch", "torch",
"sentencepiece", "sentencepiece",
"protobuf", "protobuf",
"jsonformer==0.12.0",
"accelerate==0.21.0",
"einops==0.6.1", "einops==0.6.1",
"hf-transfer~=0.1", "hf-transfer~=0.1",
"huggingface_hub==0.16.4", "huggingface_hub==0.16.4",
@@ -81,6 +84,7 @@ class LLM:
         # generation configuration
         print("Instance llm generation config")
+        # JSONFormer doesn't yet support generation configs, but keeping for future usage
         model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
         gen_cfg = GenerationConfig.from_model_config(model.config)
         gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
@@ -107,16 +111,30 @@ class LLM:
return {"status": "ok"} return {"status": "ok"}
@method() @method()
def generate(self, prompt: str): def generate(self, prompt: str, schema: str = None):
print(f"Generate {prompt=}") print(f"Generate {prompt=}")
# tokenize prompt # If a schema is given, conform to schema
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to( if schema:
self.model.device print(f"Schema {schema=}")
) import jsonformer
output = self.model.generate(input_ids, generation_config=self.gen_cfg)
# decode output jsonformer_llm = jsonformer.Jsonformer(model=self.model,
response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True) tokenizer=self.tokenizer,
json_schema=json.loads(schema),
prompt=prompt,
max_string_token_length=self.gen_cfg.max_new_tokens)
response = jsonformer_llm()
else:
# If no schema, perform prompt only generation
# tokenize prompt
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
self.model.device
)
output = self.model.generate(input_ids, generation_config=self.gen_cfg)
# decode output
response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
print(f"Generated {response=}") print(f"Generated {response=}")
return {"text": response} return {"text": response}
@@ -135,7 +153,7 @@ class LLM:
 )
 @asgi_app()
 def web():
-    from fastapi import FastAPI, HTTPException, status, Depends
+    from fastapi import Depends, FastAPI, HTTPException, status
     from fastapi.security import OAuth2PasswordBearer
     from pydantic import BaseModel
@@ -154,12 +172,16 @@ def web():
     class LLMRequest(BaseModel):
         prompt: str
+        schema: Optional[dict] = None
     @app.post("/llm", dependencies=[Depends(apikey_auth)])
     async def llm(
         req: LLMRequest,
     ):
-        func = llmstub.generate.spawn(prompt=req.prompt)
+        if req.schema:
+            func = llmstub.generate.spawn(prompt=req.prompt, schema=json.dumps(req.schema))
+        else:
+            func = llmstub.generate.spawn(prompt=req.prompt)
         result = func.get()
         return result
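
A note on the schema path added above: jsonformer walks the JSON schema and only lets the model fill in the value tokens, so the returned object always matches the schema and never needs post-hoc JSON repair. A minimal standalone sketch of that API under the same pattern, using a small placeholder model and prompt for illustration rather than the Vicuna checkpoint configured in this file:

    import json

    import jsonformer
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Placeholder model; the Modal app above loads LLM_MODEL = "lmsys/vicuna-13b-v1.5"
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # Same shape the /llm endpoint forwards: a schema serialized with json.dumps
    schema = json.dumps({"type": "object", "properties": {"title": {"type": "string"}}})

    former = jsonformer.Jsonformer(
        model=model,
        tokenizer=tokenizer,
        json_schema=json.loads(schema),
        prompt="Give a short title for: quarterly results were discussed.",
        max_string_token_length=300,
    )
    print(former())  # e.g. {"title": "..."}, guaranteed to validate against the schema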

View File

@@ -1,15 +1,17 @@
+from contextlib import asynccontextmanager
+import reflector.auth  # noqa
+import reflector.db  # noqa
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi_pagination import add_pagination
 from fastapi.routing import APIRoute
-import reflector.db  # noqa
-import reflector.auth  # noqa
-from reflector.views.rtc_offer import router as rtc_offer_router
-from reflector.views.transcripts import router as transcripts_router
-from reflector.events import subscribers_startup, subscribers_shutdown
+from fastapi_pagination import add_pagination
+from reflector.events import subscribers_shutdown, subscribers_startup
 from reflector.logger import logger
 from reflector.settings import settings
-from contextlib import asynccontextmanager
+from reflector.views.rtc_offer import router as rtc_offer_router
+from reflector.views.transcripts import router as transcripts_router
+from reflector.views.user import router as user_router
 try:
     import sentry_sdk
@@ -50,6 +52,7 @@ app.add_middleware(
 # register views
 app.include_router(rtc_offer_router)
 app.include_router(transcripts_router, prefix="/v1")
+app.include_router(user_router, prefix="/v1")
 add_pagination(app)

View File

@@ -1,10 +1,11 @@
-from reflector.settings import settings
-from reflector.utils.retry import retry
-from reflector.logger import logger as reflector_logger
-from time import monotonic
 import importlib
 import json
 import re
+from time import monotonic
+from reflector.logger import logger as reflector_logger
+from reflector.settings import settings
+from reflector.utils.retry import retry
 class LLM:
@@ -20,7 +21,7 @@ class LLM:
         Return an instance depending on the settings.
         Settings used:
-        - `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
+        - `LLM_BACKEND`: key of the backend, defaults to `oobabooga`
         - `LLM_URL`: url of the backend
         """
         if name is None:
@@ -44,10 +45,16 @@ class LLM:
     async def _warmup(self, logger: reflector_logger):
         pass
-    async def generate(self, prompt: str, logger: reflector_logger, **kwargs) -> dict:
+    async def generate(
+        self,
+        prompt: str,
+        logger: reflector_logger,
+        schema: dict | None = None,
+        **kwargs,
+    ) -> dict:
         logger.info("LLM generate", prompt=repr(prompt))
         try:
-            result = await retry(self._generate)(prompt=prompt, **kwargs)
+            result = await retry(self._generate)(prompt=prompt, schema=schema, **kwargs)
         except Exception:
             logger.exception("Failed to call llm after retrying")
             raise
@@ -59,7 +66,7 @@ class LLM:
         return result
-    async def _generate(self, prompt: str, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
         raise NotImplementedError
     def _parse_json(self, result: str) -> dict:
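
For anyone adding another backend after this change: every subclass now receives the schema in _generate and registers itself by name, as the bundled backends in the following files do. A hypothetical sketch of a new backend (MyLLM and "mybackend" are illustrative names, not part of this commit):

    import httpx

    from reflector.llm.base import LLM
    from reflector.settings import settings


    class MyLLM(LLM):
        async def _generate(self, prompt: str, schema: dict | None, **kwargs):
            payload = {"prompt": prompt}
            if schema:
                # the remote service is expected to conform its output to the schema
                payload["schema"] = schema
            async with httpx.AsyncClient() as client:
                response = await client.post(settings.LLM_URL, json=payload)
                response.raise_for_status()
                return response.json()["text"]


    # selectable at runtime with LLM_BACKEND=mybackend
    LLM.register("mybackend", MyLLM)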

View File

@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx
 class BananaLLM(LLM):
@@ -13,18 +13,22 @@ class BananaLLM(LLM):
"X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY, "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
} }
async def _generate(self, prompt: str, **kwargs): async def _generate(self, prompt: str, schema: dict | None, **kwargs):
json_payload = {"prompt": prompt}
if schema:
json_payload["schema"] = schema
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await retry(client.post)( response = await retry(client.post)(
settings.LLM_URL, settings.LLM_URL,
headers=self.headers, headers=self.headers,
json={"prompt": prompt}, json=json_payload,
timeout=self.timeout, timeout=self.timeout,
retry_timeout=300, # as per their sdk retry_timeout=300, # as per their sdk
) )
response.raise_for_status() response.raise_for_status()
text = response.json()["text"] text = response.json()["text"]
text = text[len(prompt) :] # remove prompt if not schema:
text = text[len(prompt) :]
return text return text

View File

@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx
 class ModalLLM(LLM):
@@ -23,18 +23,22 @@ class ModalLLM(LLM):
             )
             response.raise_for_status()
-    async def _generate(self, prompt: str, **kwargs):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+        json_payload = {"prompt": prompt}
+        if schema:
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 self.llm_url,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=60 * 5,
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if not schema:
+                text = text[len(prompt) :]
             return text
@@ -48,6 +52,14 @@ if __name__ == "__main__":
         result = await llm.generate("Hello, my name is", logger=logger)
         print(result)
+        schema = {
+            "type": "object",
+            "properties": {"name": {"type": "string"}},
+        }
+        result = await llm.generate("Hello, my name is", schema=schema, logger=logger)
+        print(result)
     import asyncio
     asyncio.run(main())

View File

@@ -1,18 +1,21 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
-import httpx
-class OobagoodaLLM(LLM):
-    async def _generate(self, prompt: str, **kwargs):
+class OobaboogaLLM(LLM):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+        json_payload = {"prompt": prompt}
+        if schema:
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.LLM_URL,
                 headers={"Content-Type": "application/json"},
-                json={"prompt": prompt},
+                json=json_payload,
             )
             response.raise_for_status()
             return response.json()
-LLM.register("oobagooda", OobagoodaLLM)
+LLM.register("oobabooga", OobaboogaLLM)

View File

@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.logger import logger
 from reflector.settings import settings
-import httpx
 class OpenAILLM(LLM):
@@ -15,7 +15,7 @@ class OpenAILLM(LLM):
         self.max_tokens = settings.LLM_MAX_TOKENS
         logger.info(f"LLM use openai backend at {self.openai_url}")
-    async def _generate(self, prompt: str, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
         headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {self.openai_key}",

View File

@@ -1,7 +1,7 @@
-from reflector.processors.base import Processor
-from reflector.processors.types import Transcript, TitleSummary
-from reflector.utils.retry import retry
 from reflector.llm import LLM
+from reflector.processors.base import Processor
+from reflector.processors.types import TitleSummary, Transcript
+from reflector.utils.retry import retry
 class TranscriptTopicDetectorProcessor(Processor):
@@ -15,9 +15,11 @@ class TranscriptTopicDetectorProcessor(Processor):
PROMPT = """ PROMPT = """
### Human: ### Human:
Create a JSON object as response.The JSON object must have 2 fields: Create a JSON object as response.The JSON object must have 2 fields:
i) title and ii) summary.For the title field,generate a short title i) title and ii) summary.
for the given text. For the summary field, summarize the given text
in three sentences. For the title field, generate a short title for the given text.
For the summary field, summarize the given text in a maximum of
three sentences.
{input_text} {input_text}
@@ -30,6 +32,13 @@ class TranscriptTopicDetectorProcessor(Processor):
         self.transcript = None
         self.min_transcript_length = min_transcript_length
         self.llm = LLM.get_instance()
+        self.topic_detector_schema = {
+            "type": "object",
+            "properties": {
+                "title": {"type": "string"},
+                "summary": {"type": "string"},
+            },
+        }
     async def _warmup(self):
         await self.llm.warmup(logger=self.logger)
@@ -52,7 +61,9 @@ class TranscriptTopicDetectorProcessor(Processor):
         text = self.transcript.text
         self.logger.info(f"Topic detector got {len(text)} length transcript")
         prompt = self.PROMPT.format(input_text=text)
-        result = await retry(self.llm.generate)(prompt=prompt, logger=self.logger)
+        result = await retry(self.llm.generate)(
+            prompt=prompt, schema=self.topic_detector_schema, logger=self.logger
+        )
         summary = TitleSummary(
             title=result["title"],
             summary=result["summary"],
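
The topic detector is the main consumer of the new schema argument: instead of relying on the model to emit parseable JSON, it hands topic_detector_schema down to the configured backend. A rough standalone sketch of the same call, assuming LLM_BACKEND/LLM_URL point at a deployed backend and that the backend returns the schema-conformed result as a dict, as the processor expects (prompt text is illustrative):

    import asyncio

    from reflector.llm import LLM
    from reflector.logger import logger


    async def main():
        llm = LLM.get_instance()
        schema = {
            "type": "object",
            "properties": {"title": {"type": "string"}, "summary": {"type": "string"}},
        }
        # same call shape as TranscriptTopicDetectorProcessor above
        result = await llm.generate(
            prompt="Create a JSON object with a title and summary for: <transcript text>",
            schema=schema,
            logger=logger,
        )
        print(result["title"], result["summary"])


    asyncio.run(main())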

View File

@@ -55,8 +55,8 @@ class Settings(BaseSettings):
     TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
     # LLM
-    # available backend: openai, banana, modal, oobagooda
-    LLM_BACKEND: str = "oobagooda"
+    # available backend: openai, banana, modal, oobabooga
+    LLM_BACKEND: str = "oobabooga"
     # LLM common configuration
     LLM_URL: str | None = None

View File

@@ -136,9 +136,7 @@ class TranscriptController:
         query = transcripts.select()
         if user_id is not None:
             query = query.where(transcripts.c.user_id == user_id)
-        print(query)
         results = await database.fetch_all(query)
-        print(results)
         return results
     async def get_by_id(

View File

@@ -0,0 +1,20 @@
+from typing import Annotated, Optional
+
+import reflector.auth as auth
+from fastapi import APIRouter, Depends
+from pydantic import BaseModel
+
+router = APIRouter()
+
+
+class UserInfo(BaseModel):
+    sub: str
+    email: Optional[str]
+    email_verified: Optional[bool]
+
+
+@router.get("/me")
+async def user_me(
+    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+) -> UserInfo | None:
+    return user
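
The new route is mounted under the /v1 prefix in main.py, so it is reachable at /v1/me and returns the caller's claims, or null for anonymous requests since the auth dependency is optional. A quick client-side check, assuming the API is served locally on port 8000 and bearer-token auth as wired in reflector.auth (URL and token are placeholders):

    import httpx

    BASE_URL = "http://localhost:8000"  # placeholder; use your deployment URL
    TOKEN = "<access token issued by the auth provider>"

    resp = httpx.get(f"{BASE_URL}/v1/me", headers={"Authorization": f"Bearer {TOKEN}"})
    resp.raise_for_status()
    print(resp.json())  # {"sub": "...", "email": "...", "email_verified": ...} or null when anonymous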

View File

@@ -15,7 +15,7 @@ async def test_basic_process(event_loop):
     settings.TRANSCRIPT_BACKEND = "whisper"
     class LLMTest(LLM):
-        async def _generate(self, prompt: str, **kwargs) -> str:
+        async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
             return {
                 "title": "TITLE",
                 "summary": "SUMMARY",

View File

@@ -3,16 +3,16 @@
 # FIXME test websocket connection after RTC is finished still send the full events
 # FIXME try with locked session, RTC should not work
-import pytest
-import json
-from unittest.mock import patch
-from httpx import AsyncClient
-from uvicorn import Config, Server
-import threading
 import asyncio
+import json
+import threading
 from pathlib import Path
+from unittest.mock import patch
+import pytest
+from httpx import AsyncClient
 from httpx_ws import aconnect_ws
+from uvicorn import Config, Server
 class ThreadedUvicorn:
@@ -60,7 +60,7 @@ async def dummy_llm():
     from reflector.llm.base import LLM
     class TestLLM(LLM):
-        async def _generate(self, prompt: str, **kwargs):
+        async def _generate(self, prompt: str, schema: dict | None, **kwargs):
             return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})
     with patch("reflector.llm.base.LLM.get_instance") as mock_llm: