Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-20 20:29:06 +00:00)
Merge branch 'feat-user-auth-fief' into feat-user-auth-www
@@ -29,3 +29,10 @@ repos:
     hooks:
       - id: black
         files: ^server/(reflector|tests)/
+
+  - repo: https://github.com/pycqa/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+        name: isort (python)
+        files: ^server/(gpu|evaluate|reflector)/
@@ -59,8 +59,8 @@
 ## llm backend implementation
 ## =======================================================
 
-## Use oobagooda (default)
-#LLM_BACKEND=oobagooda
+## Use oobabooga (default)
+#LLM_BACKEND=oobabooga
 #LLM_URL=http://xxx:7860/api/generate/v1
 
 ## Using serverless modal.com (require reflector-gpu-modal deployed)
@@ -3,14 +3,15 @@ Reflector GPU backend - LLM
 ===========================
 
 """
+import json
 import os
-from modal import Image, method, Stub, asgi_app, Secret
+from typing import Optional
+
+from modal import Image, Secret, Stub, asgi_app, method
 
 # LLM
 LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"
-LLM_LOW_CPU_MEM_USAGE: bool = False
+LLM_LOW_CPU_MEM_USAGE: bool = True
 LLM_TORCH_DTYPE: str = "bfloat16"
 LLM_MAX_NEW_TOKENS: int = 300
 
@@ -49,6 +50,8 @@ llm_image = (
         "torch",
         "sentencepiece",
         "protobuf",
+        "jsonformer==0.12.0",
+        "accelerate==0.21.0",
         "einops==0.6.1",
         "hf-transfer~=0.1",
         "huggingface_hub==0.16.4",
@@ -81,6 +84,7 @@ class LLM:
 
         # generation configuration
         print("Instance llm generation config")
+        # JSONFormer doesn't yet support generation configs, but keeping for future usage
         model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
         gen_cfg = GenerationConfig.from_model_config(model.config)
         gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
@@ -107,16 +111,30 @@ class LLM:
         return {"status": "ok"}
 
     @method()
-    def generate(self, prompt: str):
+    def generate(self, prompt: str, schema: str = None):
         print(f"Generate {prompt=}")
-        # tokenize prompt
-        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
-            self.model.device
-        )
-        output = self.model.generate(input_ids, generation_config=self.gen_cfg)
-
-        # decode output
-        response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
+        # If a schema is given, conform to schema
+        if schema:
+            print(f"Schema {schema=}")
+            import jsonformer
+
+            jsonformer_llm = jsonformer.Jsonformer(model=self.model,
+                                                   tokenizer=self.tokenizer,
+                                                   json_schema=json.loads(schema),
+                                                   prompt=prompt,
+                                                   max_string_token_length=self.gen_cfg.max_new_tokens)
+            response = jsonformer_llm()
+        else:
+            # If no schema, perform prompt only generation
+
+            # tokenize prompt
+            input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
+                self.model.device
+            )
+            output = self.model.generate(input_ids, generation_config=self.gen_cfg)
+
+            # decode output
+            response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
         print(f"Generated {response=}")
         return {"text": response}
 
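Note: jsonformer constrains decoding so the model can only fill in values that satisfy the JSON schema, which is why the schema branch above yields a dict instead of raw text. A minimal standalone sketch of the same pattern, assuming a local transformers checkpoint; the model name and schema here are illustrative, not taken from this commit (the Modal app above uses LLM_MODEL = "lmsys/vicuna-13b-v1.5"):

import json

from jsonformer import Jsonformer
from transformers import AutoModelForCausalLM, AutoTokenizer

# Small illustrative model so the sketch runs without a large GPU.
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

schema = {"type": "object", "properties": {"title": {"type": "string"}}}

# Jsonformer drives generation field by field, so the result is always a dict
# matching the schema rather than free-form text.
jsonformer_llm = Jsonformer(
    model=model,
    tokenizer=tokenizer,
    json_schema=schema,
    prompt="Give a short title for a meeting about quarterly planning.",
    max_string_token_length=300,
)
print(json.dumps(jsonformer_llm()))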
@@ -135,7 +153,7 @@ class LLM:
 )
 @asgi_app()
 def web():
-    from fastapi import FastAPI, HTTPException, status, Depends
+    from fastapi import Depends, FastAPI, HTTPException, status
     from fastapi.security import OAuth2PasswordBearer
     from pydantic import BaseModel
 
@@ -154,12 +172,16 @@ def web():
 
     class LLMRequest(BaseModel):
         prompt: str
+        schema: Optional[dict] = None
 
     @app.post("/llm", dependencies=[Depends(apikey_auth)])
     async def llm(
         req: LLMRequest,
     ):
-        func = llmstub.generate.spawn(prompt=req.prompt)
+        if req.schema:
+            func = llmstub.generate.spawn(prompt=req.prompt, schema=json.dumps(req.schema))
+        else:
+            func = llmstub.generate.spawn(prompt=req.prompt)
         result = func.get()
         return result
 
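For reference, a hedged client-side sketch of calling the updated /llm endpoint; the URL, bearer token, and prompt format below are assumptions, not part of this commit. When "schema" is included in the request body, the "text" field of the response carries the schema-conforming object produced by jsonformer instead of a plain completion.

import httpx

# Hypothetical values: the real URL comes from the Modal deployment and the
# key from the configured API key secret.
LLM_URL = "https://example--reflector-gpu-llm-web.modal.run/llm"
API_KEY = "changeme"

payload = {
    "prompt": "### Human:\nSuggest a title for this meeting.\n### Assistant:",
    "schema": {"type": "object", "properties": {"title": {"type": "string"}}},
}

response = httpx.post(
    LLM_URL,
    headers={"Authorization": f"Bearer {API_KEY}"},
    json=payload,
    timeout=300,
)
response.raise_for_status()
print(response.json()["text"])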
@@ -1,15 +1,17 @@
+from contextlib import asynccontextmanager
+
+import reflector.auth  # noqa
+import reflector.db  # noqa
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi_pagination import add_pagination
 from fastapi.routing import APIRoute
-import reflector.db  # noqa
-import reflector.auth  # noqa
-from reflector.views.rtc_offer import router as rtc_offer_router
-from reflector.views.transcripts import router as transcripts_router
-from reflector.events import subscribers_startup, subscribers_shutdown
+from fastapi_pagination import add_pagination
+from reflector.events import subscribers_shutdown, subscribers_startup
 from reflector.logger import logger
 from reflector.settings import settings
-from contextlib import asynccontextmanager
+from reflector.views.rtc_offer import router as rtc_offer_router
+from reflector.views.transcripts import router as transcripts_router
+from reflector.views.user import router as user_router
 
 try:
     import sentry_sdk
@@ -50,6 +52,7 @@ app.add_middleware(
 # register views
 app.include_router(rtc_offer_router)
 app.include_router(transcripts_router, prefix="/v1")
+app.include_router(user_router, prefix="/v1")
 add_pagination(app)
 
 
@@ -1,10 +1,11 @@
-from reflector.settings import settings
-from reflector.utils.retry import retry
-from reflector.logger import logger as reflector_logger
-from time import monotonic
 import importlib
 import json
 import re
+from time import monotonic
+
+from reflector.logger import logger as reflector_logger
+from reflector.settings import settings
+from reflector.utils.retry import retry
 
 
 class LLM:
@@ -20,7 +21,7 @@ class LLM:
         Return an instance depending on the settings.
         Settings used:
 
-        - `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
+        - `LLM_BACKEND`: key of the backend, defaults to `oobabooga`
         - `LLM_URL`: url of the backend
         """
         if name is None:
@@ -44,10 +45,16 @@ class LLM:
     async def _warmup(self, logger: reflector_logger):
         pass
 
-    async def generate(self, prompt: str, logger: reflector_logger, **kwargs) -> dict:
+    async def generate(
+        self,
+        prompt: str,
+        logger: reflector_logger,
+        schema: dict | None = None,
+        **kwargs,
+    ) -> dict:
         logger.info("LLM generate", prompt=repr(prompt))
         try:
-            result = await retry(self._generate)(prompt=prompt, **kwargs)
+            result = await retry(self._generate)(prompt=prompt, schema=schema, **kwargs)
         except Exception:
             logger.exception("Failed to call llm after retrying")
             raise
@@ -59,7 +66,7 @@ class LLM:
 
         return result
 
-    async def _generate(self, prompt: str, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
         raise NotImplementedError
 
     def _parse_json(self, result: str) -> dict:
@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx
 
 
 class BananaLLM(LLM):
@@ -13,18 +13,22 @@ class BananaLLM(LLM):
             "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
         }
 
-    async def _generate(self, prompt: str, **kwargs):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+        json_payload = {"prompt": prompt}
+        if schema:
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 settings.LLM_URL,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=300,  # as per their sdk
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if not schema:
+                text = text[len(prompt) :]
             return text
 
 
@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx
 
 
 class ModalLLM(LLM):
@@ -23,18 +23,22 @@ class ModalLLM(LLM):
             )
             response.raise_for_status()
 
-    async def _generate(self, prompt: str, **kwargs):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+        json_payload = {"prompt": prompt}
+        if schema:
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 self.llm_url,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=60 * 5,
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if not schema:
+                text = text[len(prompt) :]
             return text
 
 
@@ -48,6 +52,14 @@ if __name__ == "__main__":
         result = await llm.generate("Hello, my name is", logger=logger)
         print(result)
 
+        schema = {
+            "type": "object",
+            "properties": {"name": {"type": "string"}},
+        }
+
+        result = await llm.generate("Hello, my name is", schema=schema, logger=logger)
+        print(result)
+
     import asyncio
 
     asyncio.run(main())
@@ -1,18 +1,21 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
-import httpx
 
 
-class OobagoodaLLM(LLM):
-    async def _generate(self, prompt: str, **kwargs):
+class OobaboogaLLM(LLM):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+        json_payload = {"prompt": prompt}
+        if schema:
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.LLM_URL,
                 headers={"Content-Type": "application/json"},
-                json={"prompt": prompt},
+                json=json_payload,
             )
             response.raise_for_status()
             return response.json()
 
 
-LLM.register("oobagooda", OobagoodaLLM)
+LLM.register("oobabooga", OobaboogaLLM)
@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.logger import logger
 from reflector.settings import settings
-import httpx
 
 
 class OpenAILLM(LLM):
@@ -15,7 +15,7 @@ class OpenAILLM(LLM):
         self.max_tokens = settings.LLM_MAX_TOKENS
         logger.info(f"LLM use openai backend at {self.openai_url}")
 
-    async def _generate(self, prompt: str, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
         headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {self.openai_key}",
@@ -1,7 +1,7 @@
-from reflector.processors.base import Processor
-from reflector.processors.types import Transcript, TitleSummary
-from reflector.utils.retry import retry
 from reflector.llm import LLM
+from reflector.processors.base import Processor
+from reflector.processors.types import TitleSummary, Transcript
+from reflector.utils.retry import retry
 
 
 class TranscriptTopicDetectorProcessor(Processor):
@@ -15,9 +15,11 @@ class TranscriptTopicDetectorProcessor(Processor):
     PROMPT = """
     ### Human:
    Create a JSON object as response.The JSON object must have 2 fields:
-    i) title and ii) summary.For the title field,generate a short title
-    for the given text. For the summary field, summarize the given text
-    in three sentences.
+    i) title and ii) summary.
+
+    For the title field, generate a short title for the given text.
+    For the summary field, summarize the given text in a maximum of
+    three sentences.
 
     {input_text}
 
@@ -30,6 +32,13 @@ class TranscriptTopicDetectorProcessor(Processor):
         self.transcript = None
         self.min_transcript_length = min_transcript_length
         self.llm = LLM.get_instance()
+        self.topic_detector_schema = {
+            "type": "object",
+            "properties": {
+                "title": {"type": "string"},
+                "summary": {"type": "string"},
+            },
+        }
 
     async def _warmup(self):
         await self.llm.warmup(logger=self.logger)
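The schema above is plain JSON Schema, so the shape a schema-aware backend is expected to hand back can be checked locally. A small sketch, assuming the jsonschema package, which is not part of this commit; the example values are invented:

from jsonschema import validate

topic_detector_schema = {
    "type": "object",
    "properties": {
        "title": {"type": "string"},
        "summary": {"type": "string"},
    },
}

# Raises jsonschema.ValidationError if the dict does not match the schema.
validate(
    instance={"title": "Weekly sync", "summary": "A three-sentence recap."},
    schema=topic_detector_schema,
)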
@@ -52,7 +61,9 @@ class TranscriptTopicDetectorProcessor(Processor):
         text = self.transcript.text
         self.logger.info(f"Topic detector got {len(text)} length transcript")
         prompt = self.PROMPT.format(input_text=text)
-        result = await retry(self.llm.generate)(prompt=prompt, logger=self.logger)
+        result = await retry(self.llm.generate)(
+            prompt=prompt, schema=self.topic_detector_schema, logger=self.logger
+        )
         summary = TitleSummary(
             title=result["title"],
             summary=result["summary"],
@@ -55,8 +55,8 @@ class Settings(BaseSettings):
     TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
 
     # LLM
-    # available backend: openai, banana, modal, oobagooda
-    LLM_BACKEND: str = "oobagooda"
+    # available backend: openai, banana, modal, oobabooga
+    LLM_BACKEND: str = "oobabooga"
 
     # LLM common configuration
     LLM_URL: str | None = None
@@ -136,9 +136,7 @@ class TranscriptController:
         query = transcripts.select()
         if user_id is not None:
             query = query.where(transcripts.c.user_id == user_id)
-        print(query)
         results = await database.fetch_all(query)
-        print(results)
         return results
 
     async def get_by_id(
server/reflector/views/user.py (new file, 20 lines)
@@ -0,0 +1,20 @@
+from typing import Annotated, Optional
+
+import reflector.auth as auth
+from fastapi import APIRouter, Depends
+from pydantic import BaseModel
+
+router = APIRouter()
+
+
+class UserInfo(BaseModel):
+    sub: str
+    email: Optional[str]
+    email_verified: Optional[bool]
+
+
+@router.get("/me")
+async def user_me(
+    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+) -> UserInfo | None:
+    return user
@@ -15,7 +15,7 @@ async def test_basic_process(event_loop):
     settings.TRANSCRIPT_BACKEND = "whisper"
 
     class LLMTest(LLM):
-        async def _generate(self, prompt: str, **kwargs) -> str:
+        async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
             return {
                 "title": "TITLE",
                 "summary": "SUMMARY",
@@ -3,16 +3,16 @@
 # FIXME test websocket connection after RTC is finished still send the full events
 # FIXME try with locked session, RTC should not work
 
-import pytest
-import json
-from unittest.mock import patch
-from httpx import AsyncClient
-
-from uvicorn import Config, Server
-import threading
 import asyncio
+import json
+import threading
 from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+from httpx import AsyncClient
 from httpx_ws import aconnect_ws
+from uvicorn import Config, Server
 
 
 class ThreadedUvicorn:
@@ -60,7 +60,7 @@ async def dummy_llm():
     from reflector.llm.base import LLM
 
     class TestLLM(LLM):
-        async def _generate(self, prompt: str, **kwargs):
+        async def _generate(self, prompt: str, schema: dict | None, **kwargs):
             return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})
 
     with patch("reflector.llm.base.LLM.get_instance") as mock_llm: