diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml
index 015a4e01..3e06d513 100644
--- a/.github/workflows/deploy.yml
+++ b/.github/workflows/deploy.yml
@@ -1,6 +1,6 @@
 name: Deploy to Amazon ECS
 
-on: [deployment, workflow_dispatch]
+on: [workflow_dispatch]
 
 env:
   # 384658522150.dkr.ecr.us-east-1.amazonaws.com/reflector
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7132e09c..3dcbe202 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -29,3 +29,11 @@ repos:
     hooks:
       - id: black
         files: ^server/(reflector|tests)/
+
+  - repo: https://github.com/pycqa/isort
+    rev: 5.12.0
+    hooks:
+      - id: isort
+        name: isort (python)
+        files: ^server/(gpu|evaluate|reflector)/
+        args: ["--profile", "black", "--filter-files"]
diff --git a/server/.gitignore b/server/.gitignore
index 7d66d6f0..dbabe979 100644
--- a/server/.gitignore
+++ b/server/.gitignore
@@ -178,3 +178,4 @@ audio_*.wav
 
 # ignore local database
 reflector.sqlite3
+data/
diff --git a/server/env.example b/server/env.example
index 11e0927b..5c91b9d2 100644
--- a/server/env.example
+++ b/server/env.example
@@ -11,6 +11,29 @@
 
 #DATABASE_URL=postgresql://reflector:reflector@localhost:5432/reflector
 
+## =======================================================
+## User authentication
+## =======================================================
+
+## No authentication
+#AUTH_BACKEND=none
+
+## Using fief (fief.dev)
+#AUTH_BACKEND=fief
+#AUTH_FIEF_URL=https://your-fief-instance....
+#AUTH_FIEF_CLIENT_ID=xxx
+#AUTH_FIEF_CLIENT_SECRET=xxx
+
+
+## =======================================================
+## Public mode
+## =======================================================
+## If set to true, anonymous transcripts will be
+## accessible to anybody.
+
+#PUBLIC_MODE=false
+
+
 ## =======================================================
 ## Transcription backend
 ##
@@ -45,8 +68,8 @@
 ## llm backend implementation
 ## =======================================================
 
-## Use oobagooda (default)
-#LLM_BACKEND=oobagooda
+## Use oobabooga (default)
+#LLM_BACKEND=oobabooga
 #LLM_URL=http://xxx:7860/api/generate/v1
 
 ## Using serverless modal.com (require reflector-gpu-modal deployed)
diff --git a/server/gpu/modal/reflector_llm.py b/server/gpu/modal/reflector_llm.py
index bf6f4cf5..1a3f77d6 100644
--- a/server/gpu/modal/reflector_llm.py
+++ b/server/gpu/modal/reflector_llm.py
@@ -3,14 +3,15 @@
 Reflector GPU backend - LLM
 ===========================
 """
-
+import json
 import os
-from modal import Image, method, Stub, asgi_app, Secret
+from typing import Optional
+from modal import Image, Secret, Stub, asgi_app, method
 
 
 # LLM
 LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"
-LLM_LOW_CPU_MEM_USAGE: bool = False
+LLM_LOW_CPU_MEM_USAGE: bool = True
 LLM_TORCH_DTYPE: str = "bfloat16"
 LLM_MAX_NEW_TOKENS: int = 300
 
@@ -49,6 +50,8 @@ llm_image = (
         "torch",
         "sentencepiece",
         "protobuf",
+        "jsonformer==0.12.0",
+        "accelerate==0.21.0",
         "einops==0.6.1",
         "hf-transfer~=0.1",
         "huggingface_hub==0.16.4",
@@ -81,6 +84,7 @@ class LLM:
 
         # generation configuration
         print("Instance llm generation config")
+        # JSONFormer doesn't yet support generation configs, but keeping for future usage
         model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
         gen_cfg = GenerationConfig.from_model_config(model.config)
         gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
@@ -107,16 +111,30 @@ class LLM:
         return {"status": "ok"}
 
     @method()
-    def generate(self, prompt: str):
+    def generate(self, prompt: str, schema: str = None):
         print(f"Generate {prompt=}")
 
-        # tokenize prompt
-        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
-            self.model.device
-        )
-        output = self.model.generate(input_ids, generation_config=self.gen_cfg)
+        # If a schema is given, conform to schema
+        if schema:
+            print(f"Schema {schema=}")
+            import jsonformer
 
-        # decode output
-        response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
+            jsonformer_llm = jsonformer.Jsonformer(model=self.model,
+                                                   tokenizer=self.tokenizer,
+                                                   json_schema=json.loads(schema),
+                                                   prompt=prompt,
+                                                   max_string_token_length=self.gen_cfg.max_new_tokens)
+            response = jsonformer_llm()
+        else:
+            # If no schema, perform prompt only generation
+
+            # tokenize prompt
+            input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
+                self.model.device
+            )
+            output = self.model.generate(input_ids, generation_config=self.gen_cfg)
+
+            # decode output
+            response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
 
         print(f"Generated {response=}")
         return {"text": response}
 
@@ -135,9 +153,9 @@ class LLM:
 )
 @asgi_app()
 def web():
-    from fastapi import FastAPI, HTTPException, status, Depends
+    from fastapi import Depends, FastAPI, HTTPException, status
     from fastapi.security import OAuth2PasswordBearer
-    from pydantic import BaseModel
+    from pydantic import BaseModel, Field
 
     llmstub = LLM()
 
@@ -154,12 +172,16 @@ def web():
 
     class LLMRequest(BaseModel):
         prompt: str
+        schema_: Optional[dict] = Field(None, alias="schema")
 
     @app.post("/llm", dependencies=[Depends(apikey_auth)])
     async def llm(
         req: LLMRequest,
     ):
-        func = llmstub.generate.spawn(prompt=req.prompt)
+        if req.schema_:
+            func = llmstub.generate.spawn(prompt=req.prompt, schema=json.dumps(req.schema_))
+        else:
+            func = llmstub.generate.spawn(prompt=req.prompt)
         result = func.get()
         return result
 
diff --git a/server/gpu/modal/reflector_transcriber.py b/server/gpu/modal/reflector_transcriber.py
index 631233cc..55df052b 100644
--- a/server/gpu/modal/reflector_transcriber.py
+++ b/server/gpu/modal/reflector_transcriber.py
@@ -3,11 +3,11 @@
 Reflector GPU backend - transcriber
 ===================================
 """
-import tempfile
 import os
-from modal import Image, method, Stub, asgi_app, Secret
-from pydantic import BaseModel
+import tempfile
+from modal import Image, Secret, Stub, asgi_app, method
+from pydantic import BaseModel
 
 
 # Whisper
 WHISPER_MODEL: str = "large-v2"
@@ -15,6 +15,9 @@
 WHISPER_COMPUTE_TYPE: str = "float16"
 WHISPER_NUM_WORKERS: int = 1
 WHISPER_CACHE_DIR: str = "/cache/whisper"
 
+# Translation Model
+TRANSLATION_MODEL = "facebook/m2m100_418M"
+
 stub = Stub(name="reflector-transcriber")
 
@@ -31,6 +34,9 @@ whisper_image = (
         "faster-whisper",
         "requests",
         "torch",
+        "transformers",
+        "sentencepiece",
+        "protobuf",
     )
     .run_function(download_whisper)
     .env(
@@ -51,17 +57,21 @@ whisper_image = (
 )
 class Whisper:
     def __enter__(self):
-        import torch
         import faster_whisper
+        import torch
+        from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
 
         self.use_gpu = torch.cuda.is_available()
-        device = "cuda" if self.use_gpu else "cpu"
+        self.device = "cuda" if self.use_gpu else "cpu"
        self.model = faster_whisper.WhisperModel(
             WHISPER_MODEL,
-            device=device,
+            device=self.device,
             compute_type=WHISPER_COMPUTE_TYPE,
             num_workers=WHISPER_NUM_WORKERS,
         )
+        self.translation_model = M2M100ForConditionalGeneration.from_pretrained(TRANSLATION_MODEL).to(self.device)
+        self.translation_tokenizer = M2M100Tokenizer.from_pretrained(TRANSLATION_MODEL)
+
 
     @method()
     def warmup(self):
@@ -72,28 +82,30 @@ class Whisper:
         self,
         audio_data: str,
         audio_suffix: str,
-        timestamp: float = 0,
-        language: str = "en",
+        source_language: str,
+        target_language: str,
+        timestamp: float = 0
     ):
         with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
             fp.write(audio_data)
 
             segments, _ = self.model.transcribe(
                 fp.name,
-                language=language,
+                language=source_language,
                 beam_size=5,
                 word_timestamps=True,
                 vad_filter=True,
                 vad_parameters={"min_silence_duration_ms": 500},
             )
 
-        transcript = ""
+        multilingual_transcript = {}
+        transcript_source_lang = ""
         words = []
 
         if segments:
             segments = list(segments)
             for segment in segments:
-                transcript += segment.text
+                transcript_source_lang += segment.text
                 for word in segment.words:
                     words.append(
                         {
@@ -102,9 +114,24 @@ class Whisper:
                             "end": round(timestamp + word.end, 3),
                         }
                     )
+
+        multilingual_transcript[source_language] = transcript_source_lang
+
+        if target_language != source_language:
+            self.translation_tokenizer.src_lang = source_language
+            forced_bos_token_id = self.translation_tokenizer.get_lang_id(target_language)
+            encoded_transcript = self.translation_tokenizer(transcript_source_lang, return_tensors="pt").to(self.device)
+            generated_tokens = self.translation_model.generate(
+                **encoded_transcript,
+                forced_bos_token_id=forced_bos_token_id
+            )
+            result = self.translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
+            translation = result[0].strip()
+            multilingual_transcript[target_language] = translation
+
 
         return {
-            "text": transcript,
-            "words": words,
+            "text": multilingual_transcript,
+            "words": words
         }
 
@@ -122,7 +149,7 @@ class Whisper:
 )
 @asgi_app()
 def web():
-    from fastapi import FastAPI, UploadFile, Form, Depends, HTTPException, status
+    from fastapi import Depends, FastAPI, Form, HTTPException, UploadFile, status
     from fastapi.security import OAuth2PasswordBearer
     from typing_extensions import Annotated
 
@@ -131,6 +158,7 @@ def web():
 
     app = FastAPI()
     oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+    supported_audio_file_types = ["wav", "mp3", "ogg", "flac"]
 
     def apikey_auth(apikey: str = Depends(oauth2_scheme)):
         if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
@@ -140,28 +168,26 @@ def web():
                 headers={"WWW-Authenticate": "Bearer"},
             )
 
-    class TranscriptionRequest(BaseModel):
-        timestamp: float = 0
-        language: str = "en"
-
     class TranscriptResponse(BaseModel):
-        result: str
+        result: dict
 
     @app.post("/transcribe", dependencies=[Depends(apikey_auth)])
     async def transcribe(
         file: UploadFile,
         timestamp: Annotated[float, Form()] = 0,
-        language: Annotated[str, Form()] = "en",
-    ):
+        source_language: Annotated[str, Form()] = "en",
+        target_language: Annotated[str, Form()] = "en"
+    ) -> TranscriptResponse:
         audio_data = await file.read()
         audio_suffix = file.filename.split(".")[-1]
-        assert audio_suffix in ["wav", "mp3", "ogg", "flac"]
+        assert audio_suffix in supported_audio_file_types
 
         func = transcriberstub.transcribe_segment.spawn(
             audio_data=audio_data,
             audio_suffix=audio_suffix,
-            language=language,
-            timestamp=timestamp,
+            source_language=source_language,
+            target_language=target_language,
+            timestamp=timestamp
         )
         result = func.get()
         return result
diff --git a/server/poetry.lock b/server/poetry.lock
index 9ad03bcf..b3b122f4 100644
--- a/server/poetry.lock
+++ b/server/poetry.lock
@@ -930,6 +930,23 @@ mysql = ["aiomysql"]
 postgresql = ["asyncpg"]
 sqlite = ["aiosqlite"]
 
+[[package]]
+name = "deprecated"
+version = "1.2.14"
+description = "Python @deprecated decorator to deprecate old python classes, functions or methods."
+optional = false
+python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
+files = [
+    {file = "Deprecated-1.2.14-py2.py3-none-any.whl", hash = "sha256:6fac8b097794a90302bdbb17b9b815e732d3c4720583ff1b198499d78470466c"},
+    {file = "Deprecated-1.2.14.tar.gz", hash = "sha256:e5323eb936458dccc2582dc6f9c322c852a775a27065ff2b0c4970b9d53d01b3"},
+]
+
+[package.dependencies]
+wrapt = ">=1.10,<2"
+
+[package.extras]
+dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
+
 [[package]]
 name = "dnspython"
 version = "2.4.1"
@@ -1022,6 +1039,28 @@ tokenizers = "==0.13.*"
 conversion = ["transformers[torch] (>=4.23)"]
 dev = ["black (==23.*)", "flake8 (==6.*)", "isort (==5.*)", "pytest (==7.*)"]
 
+[[package]]
+name = "fief-client"
+version = "0.17.0"
+description = "Fief Client for Python"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "fief_client-0.17.0-py3-none-any.whl", hash = "sha256:ecc8674ecaf58fc7d2926f5a0f49fabd3a1a03e278f030977a97ecb716b8884d"},
+    {file = "fief_client-0.17.0.tar.gz", hash = "sha256:f1f9a10c760c29811a8cce2c1d58938090901772826dda973b67dde1bce3bafd"},
+]
+
+[package.dependencies]
+fastapi = {version = "*", optional = true, markers = "extra == \"fastapi\""}
+httpx = ">=0.21.3,<0.25.0"
+jwcrypto = ">=1.4,<2.0.0"
+makefun = {version = ">=1.14.0,<2.0.0", optional = true, markers = "extra == \"fastapi\""}
+
+[package.extras]
+cli = ["halo"]
+fastapi = ["fastapi", "makefun (>=1.14.0,<2.0.0)"]
+flask = ["flask"]
+
 [[package]]
 name = "filelock"
 version = "3.12.2"
@@ -1530,6 +1569,20 @@ files = [
     {file = "jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe"},
 ]
 
+[[package]]
+name = "jwcrypto"
+version = "1.5.0"
+description = "Implementation of JOSE Web standards"
+optional = false
+python-versions = ">= 3.6"
+files = [
+    {file = "jwcrypto-1.5.0.tar.gz", hash = "sha256:2c1dc51cf8e38ddf324795dfe9426dee9dd46caf47f535ccbc18781fba810b8d"},
+]
+
+[package.dependencies]
+cryptography = ">=3.4"
+deprecated = "*"
+
 [[package]]
 name = "levenshtein"
 version = "0.21.1"
@@ -1662,6 +1715,17 @@ win32-setctime = {version = ">=1.0.0", markers = "sys_platform == \"win32\""}
 
 [package.extras]
 dev = ["Sphinx (==5.3.0)", "colorama (==0.4.5)", "colorama (==0.4.6)", "freezegun (==1.1.0)", "freezegun (==1.2.2)", "mypy (==v0.910)", "mypy (==v0.971)", "mypy (==v0.990)", "pre-commit (==3.2.1)", "pytest (==6.1.2)", "pytest (==7.2.1)", "pytest-cov (==2.12.1)", "pytest-cov (==4.0.0)", "pytest-mypy-plugins (==1.10.1)", "pytest-mypy-plugins (==1.9.3)", "sphinx-autobuild (==2021.3.14)", "sphinx-rtd-theme (==1.2.0)", "tox (==3.27.1)", "tox (==4.4.6)"]
 
+[[package]]
+name = "makefun"
+version = "1.15.1"
+description = "Small library to dynamically create python functions."
+optional = false
+python-versions = "*"
+files = [
+    {file = "makefun-1.15.1-py2.py3-none-any.whl", hash = "sha256:a63cfc7b47a539c76d97bd4fdb833c7d0461e759fd1225f580cb4be6200294d4"},
+    {file = "makefun-1.15.1.tar.gz", hash = "sha256:40b0f118b6ded0d8d78c78f1eb679b8b6b2462e3c1b3e05fb1b2da8cd46b48a5"},
+]
+
 [[package]]
 name = "mpmath"
 version = "1.3.0"
@@ -3234,4 +3298,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "ea523f9b74581a7867097a6249d416d8836f4daaf33fde65ea343e4d3502c71c"
+content-hash = "d84edfea8ac7a849340af8eb5db47df9c13a7cc1c640062ebedb2a808be0de4e"
diff --git a/server/pyproject.toml b/server/pyproject.toml
index e3e75843..895be79d 100644
--- a/server/pyproject.toml
+++ b/server/pyproject.toml
@@ -25,6 +25,7 @@ httpx = "^0.24.1"
 fastapi-pagination = "^0.12.6"
 databases = {extras = ["aiosqlite", "asyncpg"], version = "^0.7.0"}
 sqlalchemy = "<1.5"
+fief-client = {extras = ["fastapi"], version = "^0.17.0"}
 
 
 [tool.poetry.group.dev.dependencies]
diff --git a/server/reflector/app.py b/server/reflector/app.py
index 8383bf32..14d55d7a 100644
--- a/server/reflector/app.py
+++ b/server/reflector/app.py
@@ -1,14 +1,17 @@
+from contextlib import asynccontextmanager
+
+import reflector.auth  # noqa
+import reflector.db  # noqa
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi_pagination import add_pagination
 from fastapi.routing import APIRoute
-import reflector.db  # noqa
-from reflector.views.rtc_offer import router as rtc_offer_router
-from reflector.views.transcripts import router as transcripts_router
-from reflector.events import subscribers_startup, subscribers_shutdown
+from fastapi_pagination import add_pagination
+from reflector.events import subscribers_shutdown, subscribers_startup
 from reflector.logger import logger
 from reflector.settings import settings
-from contextlib import asynccontextmanager
+from reflector.views.rtc_offer import router as rtc_offer_router
+from reflector.views.transcripts import router as transcripts_router
+from reflector.views.user import router as user_router
 
 try:
     import sentry_sdk
@@ -49,6 +52,7 @@ app.add_middleware(
 
 # register views
 app.include_router(rtc_offer_router)
 app.include_router(transcripts_router, prefix="/v1")
+app.include_router(user_router, prefix="/v1")
 add_pagination(app)
 
diff --git a/server/reflector/auth/__init__.py b/server/reflector/auth/__init__.py
new file mode 100644
index 00000000..65e75d9b
--- /dev/null
+++ b/server/reflector/auth/__init__.py
@@ -0,0 +1,13 @@
+from reflector.settings import settings
+from reflector.logger import logger
+import importlib
+
+logger.info(f"User authentication using {settings.AUTH_BACKEND}")
+module_name = f"reflector.auth.auth_{settings.AUTH_BACKEND}"
+auth_module = importlib.import_module(module_name)
+
+UserInfo = auth_module.UserInfo
+AccessTokenInfo = auth_module.AccessTokenInfo
+authenticated = auth_module.authenticated
+current_user = auth_module.current_user
+current_user_optional = auth_module.current_user_optional
diff --git a/server/reflector/auth/auth_fief.py b/server/reflector/auth/auth_fief.py
new file mode 100644
index 00000000..0b363fc0
--- /dev/null
+++ b/server/reflector/auth/auth_fief.py
@@ -0,0 +1,25 @@
+from fastapi.security import OAuth2AuthorizationCodeBearer
+from fief_client import FiefAccessTokenInfo, FiefAsync, FiefUserInfo
+from fief_client.integrations.fastapi import FiefAuth
+from reflector.settings import settings
+
+fief = FiefAsync(
+    settings.AUTH_FIEF_URL,
+    settings.AUTH_FIEF_CLIENT_ID,
+    settings.AUTH_FIEF_CLIENT_SECRET,
+)
+
+scheme = OAuth2AuthorizationCodeBearer(
+    f"{settings.AUTH_FIEF_URL}/authorize",
+    f"{settings.AUTH_FIEF_URL}/api/token",
+    scopes={"openid": "openid", "offline_access": "offline_access"},
+    auto_error=False,
+)
+
+auth = FiefAuth(fief, scheme)
+
+UserInfo = FiefUserInfo
+AccessTokenInfo = FiefAccessTokenInfo
+authenticated = auth.authenticated()
+current_user = auth.current_user()
+current_user_optional = auth.current_user(optional=True)
diff --git a/server/reflector/auth/auth_none.py b/server/reflector/auth/auth_none.py
new file mode 100644
index 00000000..1c1dd0fd
--- /dev/null
+++ b/server/reflector/auth/auth_none.py
@@ -0,0 +1,26 @@
+from pydantic import BaseModel
+from typing import Annotated
+from fastapi import Depends
+from fastapi.security import OAuth2PasswordBearer
+
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
+
+
+class UserInfo(BaseModel):
+    sub: str
+
+
+class AccessTokenInfo(BaseModel):
+    pass
+
+
+def authenticated(token: Annotated[str, Depends(oauth2_scheme)]):
+    return None
+
+
+def current_user(token: Annotated[str, Depends(oauth2_scheme)]):
+    return None
+
+
+def current_user_optional(token: Annotated[str, Depends(oauth2_scheme)]):
+    return None
diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py
index e528a3e6..d046ffe7 100644
--- a/server/reflector/llm/base.py
+++ b/server/reflector/llm/base.py
@@ -1,10 +1,11 @@
-from reflector.settings import settings
-from reflector.utils.retry import retry
-from reflector.logger import logger as reflector_logger
-from time import monotonic
 import importlib
 import json
 import re
+from time import monotonic
+
+from reflector.logger import logger as reflector_logger
+from reflector.settings import settings
+from reflector.utils.retry import retry
 
 
 class LLM:
@@ -20,7 +21,7 @@ class LLM:
         Return an instance depending on the settings.
 
         Settings used:
-        - `LLM_BACKEND`: key of the backend, defaults to `oobagooda`
+        - `LLM_BACKEND`: key of the backend, defaults to `oobabooga`
         - `LLM_URL`: url of the backend
         """
         if name is None:
@@ -44,10 +45,16 @@ class LLM:
     async def _warmup(self, logger: reflector_logger):
         pass
 
-    async def generate(self, prompt: str, logger: reflector_logger, **kwargs) -> dict:
+    async def generate(
+        self,
+        prompt: str,
+        logger: reflector_logger,
+        schema: dict | None = None,
+        **kwargs,
+    ) -> dict:
         logger.info("LLM generate", prompt=repr(prompt))
         try:
-            result = await retry(self._generate)(prompt=prompt, **kwargs)
+            result = await retry(self._generate)(prompt=prompt, schema=schema, **kwargs)
         except Exception:
             logger.exception("Failed to call llm after retrying")
             raise
@@ -59,7 +66,7 @@ class LLM:
 
         return result
 
-    async def _generate(self, prompt: str, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
         raise NotImplementedError
 
     def _parse_json(self, result: str) -> dict:
diff --git a/server/reflector/llm/llm_banana.py b/server/reflector/llm/llm_banana.py
index d6a0fa07..56fc0e69 100644
--- a/server/reflector/llm/llm_banana.py
+++ b/server/reflector/llm/llm_banana.py
@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx
 
 
 class BananaLLM(LLM):
@@ -13,18 +13,22 @@ class BananaLLM(LLM):
             "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
         }
 
-    async def _generate(self, prompt: str, **kwargs):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+        json_payload = {"prompt": prompt}
+        if schema:
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 settings.LLM_URL,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=300,  # as per their sdk
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if not schema:
+                text = text[len(prompt) :]
             return text
 
 
diff --git a/server/reflector/llm/llm_modal.py b/server/reflector/llm/llm_modal.py
index 7f23aa0d..ce0de02a 100644
--- a/server/reflector/llm/llm_modal.py
+++ b/server/reflector/llm/llm_modal.py
@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
 from reflector.utils.retry import retry
-import httpx
 
 
 class ModalLLM(LLM):
@@ -23,18 +23,22 @@ class ModalLLM(LLM):
             )
             response.raise_for_status()
 
-    async def _generate(self, prompt: str, **kwargs):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+        json_payload = {"prompt": prompt}
+        if schema:
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await retry(client.post)(
                 self.llm_url,
                 headers=self.headers,
-                json={"prompt": prompt},
+                json=json_payload,
                 timeout=self.timeout,
                 retry_timeout=60 * 5,
             )
             response.raise_for_status()
             text = response.json()["text"]
-            text = text[len(prompt) :]  # remove prompt
+            if not schema:
+                text = text[len(prompt) :]
             return text
 
 
@@ -48,6 +52,14 @@ if __name__ == "__main__":
         result = await llm.generate("Hello, my name is", logger=logger)
         print(result)
 
+        schema = {
+            "type": "object",
+            "properties": {"name": {"type": "string"}},
+        }
+
+        result = await llm.generate("Hello, my name is", schema=schema, logger=logger)
+        print(result)
+
     import asyncio
 
     asyncio.run(main())
diff --git a/server/reflector/llm/llm_oobagooda.py b/server/reflector/llm/llm_oobabooga.py
similarity index 56%
rename from server/reflector/llm/llm_oobagooda.py
rename to server/reflector/llm/llm_oobabooga.py
index be7d8133..411014c5 100644
--- a/server/reflector/llm/llm_oobagooda.py
+++ b/server/reflector/llm/llm_oobabooga.py
@@ -1,18 +1,21 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.settings import settings
-import httpx
 
 
-class OobagoodaLLM(LLM):
-    async def _generate(self, prompt: str, **kwargs):
+class OobaboogaLLM(LLM):
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
+        json_payload = {"prompt": prompt}
+        if schema:
+            json_payload["schema"] = schema
         async with httpx.AsyncClient() as client:
             response = await client.post(
                 settings.LLM_URL,
                 headers={"Content-Type": "application/json"},
-                json={"prompt": prompt},
+                json=json_payload,
             )
             response.raise_for_status()
             return response.json()
 
 
-LLM.register("oobagooda", OobagoodaLLM)
+LLM.register("oobabooga", OobaboogaLLM)
diff --git a/server/reflector/llm/llm_openai.py b/server/reflector/llm/llm_openai.py
index dd438704..7ed532b7 100644
--- a/server/reflector/llm/llm_openai.py
+++ b/server/reflector/llm/llm_openai.py
@@ -1,7 +1,7 @@
+import httpx
 from reflector.llm.base import LLM
 from reflector.logger import logger
 from reflector.settings import settings
-import httpx
 
 
 class OpenAILLM(LLM):
@@ -15,7 +15,7 @@ class OpenAILLM(LLM):
         self.max_tokens = settings.LLM_MAX_TOKENS
         logger.info(f"LLM use openai backend at {self.openai_url}")
 
-    async def _generate(self, prompt: str, **kwargs) -> str:
+    async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
         headers = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {self.openai_key}",
diff --git a/server/reflector/processors/audio_file_writer.py b/server/reflector/processors/audio_file_writer.py
index d67db65e..00ab2529 100644
--- a/server/reflector/processors/audio_file_writer.py
+++ b/server/reflector/processors/audio_file_writer.py
@@ -26,13 +26,13 @@ class AudioFileWriterProcessor(Processor):
             self.out_stream = self.out_container.add_stream(
                 "pcm_s16le", rate=data.sample_rate
             )
-            for packet in self.out_stream.encode(data):
-                self.out_container.mux(packet)
+        for packet in self.out_stream.encode(data):
+            self.out_container.mux(packet)
         await self.emit(data)
 
     async def _flush(self):
         if self.out_container:
-            for packet in self.out_stream.encode(None):
+            for packet in self.out_stream.encode():
                 self.out_container.mux(packet)
             self.out_container.close()
             self.out_container = None
diff --git a/server/reflector/processors/audio_transcript_modal.py b/server/reflector/processors/audio_transcript_modal.py
index 1ed727d6..80b6e582 100644
--- a/server/reflector/processors/audio_transcript_modal.py
+++ b/server/reflector/processors/audio_transcript_modal.py
@@ -5,19 +5,22 @@ API will be a POST request to TRANSCRIPT_URL:
 
 ```form
 "timestamp": 123.456
-"language": "en"
+"source_language": "en"
+"target_language": "en"
 "file":