Merge branch 'feat-user-auth-fief' into feat-user-auth-www
@@ -29,3 +29,10 @@ repos:
hooks:
- id: black
files: ^server/(reflector|tests)/

- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
name: isort (python)
files: ^server/(gpu|evaluate|reflector)/

@@ -59,8 +59,8 @@
## llm backend implementation
## =======================================================

## Use oobagooda (default)
#LLM_BACKEND=oobagooda
## Use oobabooga (default)
#LLM_BACKEND=oobabooga
#LLM_URL=http://xxx:7860/api/generate/v1

## Using serverless modal.com (require reflector-gpu-modal deployed)

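Note: a minimal sketch of the Modal-backed variant of this .env block, assuming the "modal" backend key from settings.py further down in this diff; the URL value is only a placeholder for a deployed reflector-gpu-modal endpoint.

## Use serverless modal.com (requires reflector-gpu-modal deployed)
#LLM_BACKEND=modal
#LLM_URL=https://<your-workspace>--reflector-llm-web.modal.run/llm
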
@@ -3,14 +3,15 @@ Reflector GPU backend - LLM
===========================

"""

import json
import os
from modal import Image, method, Stub, asgi_app, Secret
from typing import Optional

from modal import Image, Secret, Stub, asgi_app, method

# LLM
LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"
LLM_LOW_CPU_MEM_USAGE: bool = False
LLM_LOW_CPU_MEM_USAGE: bool = True
LLM_TORCH_DTYPE: str = "bfloat16"
LLM_MAX_NEW_TOKENS: int = 300

@@ -49,6 +50,8 @@ llm_image = (
"torch",
"sentencepiece",
"protobuf",
"jsonformer==0.12.0",
"accelerate==0.21.0",
"einops==0.6.1",
"hf-transfer~=0.1",
"huggingface_hub==0.16.4",
@@ -81,6 +84,7 @@ class LLM:

# generation configuration
print("Instance llm generation config")
# JSONFormer doesn't yet support generation configs, but keeping for future usage
model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
gen_cfg = GenerationConfig.from_model_config(model.config)
gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
@@ -107,16 +111,30 @@ class LLM:
return {"status": "ok"}

@method()
def generate(self, prompt: str):
def generate(self, prompt: str, schema: str = None):
print(f"Generate {prompt=}")
# tokenize prompt
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
self.model.device
)
output = self.model.generate(input_ids, generation_config=self.gen_cfg)
# If a schema is given, conform to schema
if schema:
print(f"Schema {schema=}")
import jsonformer

# decode output
response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
jsonformer_llm = jsonformer.Jsonformer(model=self.model,
tokenizer=self.tokenizer,
json_schema=json.loads(schema),
prompt=prompt,
max_string_token_length=self.gen_cfg.max_new_tokens)
response = jsonformer_llm()
else:
# If no schema, perform prompt only generation

# tokenize prompt
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
self.model.device
)
output = self.model.generate(input_ids, generation_config=self.gen_cfg)

# decode output
response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
print(f"Generated {response=}")
return {"text": response}

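Note: a standalone sketch of the schema-constrained path above, assuming a small Hugging Face causal LM as a stand-in for the Vicuna model used here; the schema mirrors the title/summary shape that appears later in this diff.

import json

from jsonformer import Jsonformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder model; any transformers causal LM can stand in for LLM_MODEL here.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

schema = {
    "type": "object",
    "properties": {
        "title": {"type": "string"},
        "summary": {"type": "string"},
    },
}

prompt = "Summarize: the team reviewed the Q3 roadmap and assigned owners."

# Jsonformer drives generation field by field, so the result always matches the schema.
jsonformer_llm = Jsonformer(
    model=model,
    tokenizer=tokenizer,
    json_schema=schema,
    prompt=prompt,
    max_string_token_length=300,
)
result = jsonformer_llm()  # returns a dict such as {"title": ..., "summary": ...}
print(json.dumps(result, indent=2))
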
@@ -135,7 +153,7 @@ class LLM:
)
@asgi_app()
def web():
from fastapi import FastAPI, HTTPException, status, Depends
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel

@@ -154,12 +172,16 @@ def web():

class LLMRequest(BaseModel):
prompt: str
schema: Optional[dict] = None

@app.post("/llm", dependencies=[Depends(apikey_auth)])
async def llm(
req: LLMRequest,
):
func = llmstub.generate.spawn(prompt=req.prompt)
if req.schema:
func = llmstub.generate.spawn(prompt=req.prompt, schema=json.dumps(req.schema))
else:
func = llmstub.generate.spawn(prompt=req.prompt)
result = func.get()
return result

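Note: a hypothetical client call against the /llm route above; the base URL and API key are placeholders, and passing the key as a bearer token is an assumption based on the OAuth2PasswordBearer import.

import asyncio

import httpx

LLM_URL = "https://example--reflector-llm-web.modal.run/llm"  # placeholder deployment URL
API_KEY = "replace-me"  # placeholder API key


async def main():
    payload = {
        "prompt": "Summarize: the team reviewed the Q3 roadmap.",
        "schema": {
            "type": "object",
            "properties": {"title": {"type": "string"}, "summary": {"type": "string"}},
        },
    }
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            LLM_URL,
            headers={"Authorization": f"Bearer {API_KEY}"},
            json=payload,
            timeout=300,
        )
        resp.raise_for_status()
        # With a schema, the worker returns {"text": {...}} where the inner dict matches the schema.
        print(resp.json())


asyncio.run(main())
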
@@ -1,15 +1,17 @@
from contextlib import asynccontextmanager

import reflector.auth # noqa
import reflector.db # noqa
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi_pagination import add_pagination
from fastapi.routing import APIRoute
import reflector.db # noqa
import reflector.auth # noqa
from reflector.views.rtc_offer import router as rtc_offer_router
from reflector.views.transcripts import router as transcripts_router
from reflector.events import subscribers_startup, subscribers_shutdown
from fastapi_pagination import add_pagination
from reflector.events import subscribers_shutdown, subscribers_startup
from reflector.logger import logger
from reflector.settings import settings
from contextlib import asynccontextmanager
from reflector.views.rtc_offer import router as rtc_offer_router
from reflector.views.transcripts import router as transcripts_router
from reflector.views.user import router as user_router

try:
import sentry_sdk
@@ -50,6 +52,7 @@ app.add_middleware(
# register views
app.include_router(rtc_offer_router)
app.include_router(transcripts_router, prefix="/v1")
app.include_router(user_router, prefix="/v1")
add_pagination(app)

@@ -1,10 +1,11 @@
from reflector.settings import settings
from reflector.utils.retry import retry
from reflector.logger import logger as reflector_logger
from time import monotonic
import importlib
import json
import re
from time import monotonic

from reflector.logger import logger as reflector_logger
from reflector.settings import settings
from reflector.utils.retry import retry


class LLM:
@@ -20,7 +21,7 @@ class LLM:
Return an instance depending on the settings.
Settings used:

- `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
- `LLM_BACKEND`: key of the backend, defaults to `oobabooga`
- `LLM_URL`: url of the backend
"""
if name is None:
@@ -44,10 +45,16 @@ class LLM:
async def _warmup(self, logger: reflector_logger):
pass

async def generate(self, prompt: str, logger: reflector_logger, **kwargs) -> dict:
async def generate(
self,
prompt: str,
logger: reflector_logger,
schema: dict | None = None,
**kwargs,
) -> dict:
logger.info("LLM generate", prompt=repr(prompt))
try:
result = await retry(self._generate)(prompt=prompt, **kwargs)
result = await retry(self._generate)(prompt=prompt, schema=schema, **kwargs)
except Exception:
logger.exception("Failed to call llm after retrying")
raise
@@ -59,7 +66,7 @@ class LLM:

return result

async def _generate(self, prompt: str, **kwargs) -> str:
async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
raise NotImplementedError

def _parse_json(self, result: str) -> dict:

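Note: to illustrate the updated contract, a minimal hypothetical backend following the same pattern as the existing ones; the EchoLLM class and the "echo" backend key are invented for illustration, and only the _generate signature and the LLM.register call mirror this diff.

import httpx

from reflector.llm.base import LLM
from reflector.settings import settings


class EchoLLM(LLM):
    # Posts the prompt, plus the optional JSON schema, to LLM_URL and returns the raw JSON body.
    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
        json_payload = {"prompt": prompt}
        if schema:
            json_payload["schema"] = schema
        async with httpx.AsyncClient() as client:
            response = await client.post(settings.LLM_URL, json=json_payload)
            response.raise_for_status()
            return response.json()


LLM.register("echo", EchoLLM)
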
@@ -1,7 +1,7 @@
import httpx
from reflector.llm.base import LLM
from reflector.settings import settings
from reflector.utils.retry import retry
import httpx


class BananaLLM(LLM):
@@ -13,18 +13,22 @@ class BananaLLM(LLM):
"X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
}

async def _generate(self, prompt: str, **kwargs):
async def _generate(self, prompt: str, schema: dict | None, **kwargs):
json_payload = {"prompt": prompt}
if schema:
json_payload["schema"] = schema
async with httpx.AsyncClient() as client:
response = await retry(client.post)(
settings.LLM_URL,
headers=self.headers,
json={"prompt": prompt},
json=json_payload,
timeout=self.timeout,
retry_timeout=300, # as per their sdk
)
response.raise_for_status()
text = response.json()["text"]
text = text[len(prompt) :] # remove prompt
if not schema:
text = text[len(prompt) :]
return text

@@ -1,7 +1,7 @@
import httpx
from reflector.llm.base import LLM
from reflector.settings import settings
from reflector.utils.retry import retry
import httpx


class ModalLLM(LLM):
@@ -23,18 +23,22 @@ class ModalLLM(LLM):
)
response.raise_for_status()

async def _generate(self, prompt: str, **kwargs):
async def _generate(self, prompt: str, schema: dict | None, **kwargs):
json_payload = {"prompt": prompt}
if schema:
json_payload["schema"] = schema
async with httpx.AsyncClient() as client:
response = await retry(client.post)(
self.llm_url,
headers=self.headers,
json={"prompt": prompt},
json=json_payload,
timeout=self.timeout,
retry_timeout=60 * 5,
)
response.raise_for_status()
text = response.json()["text"]
text = text[len(prompt) :] # remove prompt
if not schema:
text = text[len(prompt) :]
return text

@@ -48,6 +52,14 @@ if __name__ == "__main__":
|
||||
result = await llm.generate("Hello, my name is", logger=logger)
|
||||
print(result)
|
||||
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {"name": {"type": "string"}},
|
||||
}
|
||||
|
||||
result = await llm.generate("Hello, my name is", schema=schema, logger=logger)
|
||||
print(result)
|
||||
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,18 +1,21 @@
import httpx
from reflector.llm.base import LLM
from reflector.settings import settings
import httpx


class OobagoodaLLM(LLM):
async def _generate(self, prompt: str, **kwargs):
class OobaboogaLLM(LLM):
async def _generate(self, prompt: str, schema: dict | None, **kwargs):
json_payload = {"prompt": prompt}
if schema:
json_payload["schema"] = schema
async with httpx.AsyncClient() as client:
response = await client.post(
settings.LLM_URL,
headers={"Content-Type": "application/json"},
json={"prompt": prompt},
json=json_payload,
)
response.raise_for_status()
return response.json()


LLM.register("oobagooda", OobagoodaLLM)
LLM.register("oobabooga", OobaboogaLLM)
@@ -1,7 +1,7 @@
import httpx
from reflector.llm.base import LLM
from reflector.logger import logger
from reflector.settings import settings
import httpx


class OpenAILLM(LLM):
@@ -15,7 +15,7 @@ class OpenAILLM(LLM):
self.max_tokens = settings.LLM_MAX_TOKENS
logger.info(f"LLM use openai backend at {self.openai_url}")

async def _generate(self, prompt: str, **kwargs) -> str:
async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.openai_key}",

@@ -1,7 +1,7 @@
from reflector.processors.base import Processor
from reflector.processors.types import Transcript, TitleSummary
from reflector.utils.retry import retry
from reflector.llm import LLM
from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary, Transcript
from reflector.utils.retry import retry


class TranscriptTopicDetectorProcessor(Processor):
@@ -15,9 +15,11 @@ class TranscriptTopicDetectorProcessor(Processor):
PROMPT = """
### Human:
Create a JSON object as response.The JSON object must have 2 fields:
i) title and ii) summary.For the title field,generate a short title
for the given text. For the summary field, summarize the given text
in three sentences.
i) title and ii) summary.

For the title field, generate a short title for the given text.
For the summary field, summarize the given text in a maximum of
three sentences.

{input_text}

@@ -30,6 +32,13 @@ class TranscriptTopicDetectorProcessor(Processor):
self.transcript = None
self.min_transcript_length = min_transcript_length
self.llm = LLM.get_instance()
self.topic_detector_schema = {
"type": "object",
"properties": {
"title": {"type": "string"},
"summary": {"type": "string"},
},
}

async def _warmup(self):
await self.llm.warmup(logger=self.logger)
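Note: for context, a response conforming to topic_detector_schema would look roughly like the following; the values are invented for illustration.

example_result = {
    "title": "Q3 roadmap review",
    "summary": "The team walked through the roadmap, assigned owners to open risks, and scheduled a follow-up.",
}
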
@@ -52,7 +61,9 @@ class TranscriptTopicDetectorProcessor(Processor):
text = self.transcript.text
self.logger.info(f"Topic detector got {len(text)} length transcript")
prompt = self.PROMPT.format(input_text=text)
result = await retry(self.llm.generate)(prompt=prompt, logger=self.logger)
result = await retry(self.llm.generate)(
prompt=prompt, schema=self.topic_detector_schema, logger=self.logger
)
summary = TitleSummary(
title=result["title"],
summary=result["summary"],

@@ -55,8 +55,8 @@ class Settings(BaseSettings):
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None

# LLM
# available backend: openai, banana, modal, oobagooda
LLM_BACKEND: str = "oobagooda"
# available backend: openai, banana, modal, oobabooga
LLM_BACKEND: str = "oobabooga"

# LLM common configuration
LLM_URL: str | None = None

@@ -136,9 +136,7 @@ class TranscriptController:
query = transcripts.select()
if user_id is not None:
query = query.where(transcripts.c.user_id == user_id)
print(query)
results = await database.fetch_all(query)
print(results)
return results

async def get_by_id(

server/reflector/views/user.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from typing import Annotated, Optional

import reflector.auth as auth
from fastapi import APIRouter, Depends
from pydantic import BaseModel

router = APIRouter()


class UserInfo(BaseModel):
sub: str
email: Optional[str]
email_verified: Optional[bool]


@router.get("/me")
async def user_me(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> UserInfo | None:
return user
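Note: a quick way to exercise the new route once the server is running; the base URL and token are placeholders, and the bearer scheme is an assumption for this Fief-based auth branch. Unauthenticated requests should return null, since the dependency is optional.

import httpx

BASE_URL = "http://localhost:8000"  # placeholder local server
TOKEN = "replace-me"  # placeholder access token

# Anonymous call: current_user_optional should yield None, so the body is null.
print(httpx.get(f"{BASE_URL}/v1/me").json())

# Authenticated call: response fields match the UserInfo model above.
print(
    httpx.get(
        f"{BASE_URL}/v1/me",
        headers={"Authorization": f"Bearer {TOKEN}"},
    ).json()
)
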
@@ -15,7 +15,7 @@ async def test_basic_process(event_loop):
settings.TRANSCRIPT_BACKEND = "whisper"

class LLMTest(LLM):
async def _generate(self, prompt: str, **kwargs) -> str:
async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
return {
"title": "TITLE",
"summary": "SUMMARY",

@@ -3,16 +3,16 @@
# FIXME test websocket connection after RTC is finished still send the full events
# FIXME try with locked session, RTC should not work

import pytest
import json
from unittest.mock import patch
from httpx import AsyncClient

from uvicorn import Config, Server
import threading
import asyncio
import json
import threading
from pathlib import Path
from unittest.mock import patch

import pytest
from httpx import AsyncClient
from httpx_ws import aconnect_ws
from uvicorn import Config, Server


class ThreadedUvicorn:
@@ -60,7 +60,7 @@ async def dummy_llm():
from reflector.llm.base import LLM

class TestLLM(LLM):
async def _generate(self, prompt: str, **kwargs):
async def _generate(self, prompt: str, schema: dict | None, **kwargs):
return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})

with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
