Mirror of https://github.com/Monadical-SAS/reflector.git (synced 2025-12-22 05:09:05 +00:00)
Feature additions (#210)
* initial
* add LLM features
* update LLM logic
* update llm functions: change control flow
* add generation config
* update return types
* update processors and tests
* update rtc_offer
* revert new title processor change
* fix unit tests
* add comments and fix HTTP 500
* adjust prompt
* test with reflector app
* revert new event for final title
* update
* move onus onto processors
* move onus onto processors
* stash
* add provision for gen config
* dynamically pack the LLM input using context length
* tune final summary params
* update consolidated class structures
* update consolidated class structures
* update precommit
* add broadcast processors
* working baseline
* Organize LLMParams
* minor fixes
* minor fixes
* minor fixes
* fix unit tests
* fix unit tests
* fix unit tests
* update tests
* update tests
* edit pipeline response events
* update summary return types
* configure tests
* alembic db migration
* change LLM response flow
* edit main llm functions
* edit main llm functions
* change llm name and gen cf
* Update transcript_topic_detector.py
* PR review comments
* checkpoint before db event migration
* update DB migration of past events
* update DB migration of past events
* edit LLM classes
* Delete unwanted file
* remove List typing
* remove List typing
* update oobabooga API call
* topic enhancements
* update UI event handling
* move ensure_casing to llm base
* update tests
* update tests
@@ -1,12 +1,13 @@
from contextlib import asynccontextmanager

import reflector.auth  # noqa
import reflector.db  # noqa
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.routing import APIRoute
from fastapi_pagination import add_pagination
from prometheus_fastapi_instrumentator import Instrumentator

import reflector.auth  # noqa
import reflector.db  # noqa
from reflector.events import subscribers_shutdown, subscribers_startup
from reflector.logger import logger
from reflector.metrics import metrics_init
@@ -1,5 +1,6 @@
import databases
import sqlalchemy

from reflector.events import subscribers_shutdown, subscribers_startup
from reflector.settings import settings
@@ -16,7 +17,9 @@ transcripts = sqlalchemy.Table(
    sqlalchemy.Column("locked", sqlalchemy.Boolean),
    sqlalchemy.Column("duration", sqlalchemy.Integer),
    sqlalchemy.Column("created_at", sqlalchemy.DateTime),
    sqlalchemy.Column("summary", sqlalchemy.String, nullable=True),
    sqlalchemy.Column("title", sqlalchemy.String, nullable=True),
    sqlalchemy.Column("short_summary", sqlalchemy.String, nullable=True),
    sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
    sqlalchemy.Column("topics", sqlalchemy.JSON),
    sqlalchemy.Column("events", sqlalchemy.JSON),
    sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
@@ -1 +1,2 @@
from .base import LLM  # noqa: F401
from .llm_params import LLMTaskParams  # noqa: F401
@@ -2,12 +2,19 @@ import importlib
import json
import re
from time import monotonic
from typing import TypeVar

import nltk
from prometheus_client import Counter, Histogram
from transformers import GenerationConfig

from reflector.llm.llm_params import TaskParams
from reflector.logger import logger as reflector_logger
from reflector.settings import settings
from reflector.utils.retry import retry

T = TypeVar("T", bound="LLM")


class LLM:
    _registry = {}
@@ -32,12 +39,25 @@ class LLM:
        ["backend"],
    )

    def __enter__(self):
        self.ensure_nltk()

    @classmethod
    def ensure_nltk(cls):
        """
        Make sure NLTK package is installed. Searches in the cache and
        downloads only if needed.
        """
        nltk.download("punkt", download_dir=settings.CACHE_DIR)
        # For POS tagging
        nltk.download("averaged_perceptron_tagger", download_dir=settings.CACHE_DIR)

    @classmethod
    def register(cls, name, klass):
        cls._registry[name] = klass

    @classmethod
    def get_instance(cls, name=None):
    def get_instance(cls, model_name: str | None = None, name: str = None) -> T:
        """
        Return an instance depending on the settings.
        Settings used:
@@ -50,7 +70,39 @@ class LLM:
        if name not in cls._registry:
            module_name = f"reflector.llm.llm_{name}"
            importlib.import_module(module_name)
        return cls._registry[name]()
        return cls._registry[name](model_name)

    def get_model_name(self) -> str:
        """
        Get the currently set model name
        """
        return self._get_model_name()

    def _get_model_name(self) -> str:
        pass

    def set_model_name(self, model_name: str) -> bool:
        """
        Update the model name with the provided model name
        """
        return self._set_model_name(model_name)

    def _set_model_name(self, model_name: str) -> bool:
        raise NotImplementedError

    @property
    def template(self) -> str:
        """
        Return the LLM Prompt template
        """
        return """
### Human:
{instruct}

{text}

### Assistant:
"""

    def __init__(self):
        name = self.__class__.__name__
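Editor's note: the hunks above extend the backend registry so `get_instance` can thread an optional model name through to the backend constructor, importing `reflector.llm.llm_<name>` lazily so its `LLM.register()` call runs first. A minimal usage sketch (illustration only, not part of the commit):

```python
from reflector.llm import LLM

# Resolves the configured backend (e.g. "modal") from the registry and
# hands the optional model name to its constructor.
llm = LLM.get_instance(model_name="lmsys/vicuna-13b-v1.5")
print(llm.get_model_name())
```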
@@ -73,21 +125,39 @@ class LLM:
    async def _warmup(self, logger: reflector_logger):
        pass

    @property
    def tokenizer(self):
        """
        Return the tokenizer instance used by LLM
        """
        return self._get_tokenizer()

    def _get_tokenizer(self):
        pass

    async def generate(
        self,
        prompt: str,
        logger: reflector_logger,
        schema: dict | None = None,
        gen_schema: dict | None = None,
        gen_cfg: GenerationConfig | None = None,
        **kwargs,
    ) -> dict:
        logger.info("LLM generate", prompt=repr(prompt))

        if gen_cfg:
            gen_cfg = gen_cfg.to_dict()
        self.m_generate_call.inc()
        try:
            with self.m_generate.time():
                result = await retry(self._generate)(
                    prompt=prompt, schema=schema, **kwargs
                    prompt=prompt,
                    gen_schema=gen_schema,
                    gen_cfg=gen_cfg,
                    **kwargs,
                )
            self.m_generate_success.inc()

        except Exception:
            logger.exception("Failed to call llm after retrying")
            self.m_generate_failure.inc()
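Editor's note: `generate()` now accepts a structured `gen_schema` plus a transformers `GenerationConfig`, serializing the latter to a plain dict before the retried `_generate` call. An end-to-end sketch mirroring the `__main__` examples later in this diff:

```python
import asyncio

from transformers import GenerationConfig

from reflector.llm import LLM
from reflector.logger import logger


async def demo():
    llm = LLM.get_instance()
    prompt = llm.create_prompt(instruct="Summarize the following text.", text="...")
    gen_schema = {"type": "object", "properties": {"summary": {"type": "string"}}}
    # gen_cfg is converted with .to_dict() inside generate() before retrying
    result = await llm.generate(
        prompt=prompt,
        gen_schema=gen_schema,
        gen_cfg=GenerationConfig(max_new_tokens=200),
        logger=logger,
    )
    print(result)


asyncio.run(demo())
```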
@@ -100,7 +170,60 @@ class LLM:
        return result

    async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
    def ensure_casing(self, title: str) -> str:
        """
        LLM takes care of word casing, but in rare cases this
        can falter. This is a fallback to ensure the casing of
        topics is in a proper format.

        We select nouns, verbs and adjectives and check if camel
        casing is present and fix it, if not. Will not perform
        any other changes.
        """
        tokens = nltk.word_tokenize(title)
        pos_tags = nltk.pos_tag(tokens)
        camel_cased = []

        whitelisted_pos_tags = [
            "NN",
            "NNS",
            "NNP",
            "NNPS",  # Noun POS
            "VB",
            "VBD",
            "VBG",
            "VBN",
            "VBP",
            "VBZ",  # Verb POS
            "JJ",
            "JJR",
            "JJS",  # Adjective POS
        ]

        # If at all there is an exception, do not block other reflector
        # processes. Return the LLM generated title, at the least.
        try:
            for word, pos in pos_tags:
                if pos in whitelisted_pos_tags and word[0].islower():
                    camel_cased.append(word[0].upper() + word[1:])
                else:
                    camel_cased.append(word)
            modified_title = " ".join(camel_cased)

            # The result can have words in braces with additional space.
            # Change ( ABC ), [ ABC ], etc. ==> (ABC), [ABC], etc.
            pattern = r"(?<=[\[\{\(])\s+|\s+(?=[\]\}\)])"
            title = re.sub(pattern, "", modified_title)
        except Exception as e:
            reflector_logger.info(
                f"Failed to ensure casing on {title=} " f"with exception : {str(e)}"
            )

        return title

    async def _generate(
        self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
    ) -> str:
        raise NotImplementedError

    def _parse_json(self, result: str) -> dict:
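Editor's note: `ensure_casing` is a best-effort fallback. NLTK tags each word, and only lower-cased nouns, verbs and adjectives are capitalized (despite the "camel casing" wording, this amounts to title casing). Illustration, with the caveat that exact POS tags are NLTK's call:

```python
from reflector.llm import LLM

LLM.ensure_nltk()  # fetches punkt + averaged_perceptron_tagger into CACHE_DIR
llm = LLM.get_instance()
fixed = llm.ensure_casing("quarterly budget review ( Q3 )")
# Expected, tags permitting: "Quarterly Budget Review (Q3)";
# on any exception the LLM-generated title is returned unchanged.
```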
@@ -122,3 +245,62 @@ class LLM:
            result = result[:-3]

        return json.loads(result.strip())

    def text_token_threshold(self, task_params: TaskParams | None) -> int:
        """
        Choose the token size to set as the threshold to pack the LLM calls
        """
        buffer_token_size = 25
        default_output_tokens = 1000
        context_window = self.tokenizer.model_max_length
        tokens = self.tokenizer.tokenize(
            self.create_prompt(instruct=task_params.instruct, text="")
        )
        threshold = context_window - len(tokens) - buffer_token_size
        if task_params.gen_cfg:
            threshold -= task_params.gen_cfg.max_new_tokens
        else:
            threshold -= default_output_tokens
        return threshold

    def split_corpus(
        self,
        corpus: str,
        task_params: TaskParams,
        token_threshold: int | None = None,
    ) -> list[str]:
        """
        Split the input to the LLM due to CUDA memory limitations and LLM context window
        restrictions.

        Accumulate tokens from full sentences till threshold and yield accumulated
        tokens. Reset accumulation when threshold is reached and repeat process.
        """
        if not token_threshold:
            token_threshold = self.text_token_threshold(task_params=task_params)

        accumulated_tokens = []
        accumulated_sentences = []
        accumulated_token_count = 0
        corpus_sentences = nltk.sent_tokenize(corpus)

        for sentence in corpus_sentences:
            tokens = self.tokenizer.tokenize(sentence)
            if accumulated_token_count + len(tokens) <= token_threshold:
                accumulated_token_count += len(tokens)
                accumulated_tokens.extend(tokens)
                accumulated_sentences.append(sentence)
            else:
                yield "".join(accumulated_sentences)
                accumulated_token_count = len(tokens)
                accumulated_tokens = tokens
                accumulated_sentences = [sentence]

        if accumulated_tokens:
            yield " ".join(accumulated_sentences)

    def create_prompt(self, instruct: str, text: str) -> str:
        """
        Create a consumable prompt based on the prompt template
        """
        return self.template.format(instruct=instruct, text=text)
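Editor's note: two details in the packing code above are easy to miss. `text_token_threshold` budgets the context window as window minus prompt-template tokens, minus the expected output (`gen_cfg.max_new_tokens`, else 1000), minus a 25-token safety buffer; and `split_corpus` is a generator despite its `list[str]` annotation. Usage sketch (illustration; `transcript_text` is a placeholder):

```python
from reflector.llm import LLM, LLMTaskParams

llm = LLM.get_instance()
params = LLMTaskParams.get_instance("topic").task_params
transcript_text = "..."  # placeholder for a long transcript

for chunk in llm.split_corpus(corpus=transcript_text, task_params=params):
    # each chunk is a run of whole sentences that fits the token budget
    prompt = llm.create_prompt(instruct=params.instruct, text=chunk)
```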
@@ -1,4 +1,5 @@
import httpx

from reflector.llm.base import LLM
from reflector.settings import settings
from reflector.utils.retry import retry
@@ -13,10 +14,14 @@ class BananaLLM(LLM):
            "X-Banana-Model-Key": settings.LLM_BANANA_MODEL_KEY,
        }

    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
    async def _generate(
        self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
    ):
        json_payload = {"prompt": prompt}
        if schema:
            json_payload["schema"] = schema
        if gen_schema:
            json_payload["gen_schema"] = gen_schema
        if gen_cfg:
            json_payload["gen_cfg"] = gen_cfg
        async with httpx.AsyncClient() as client:
            response = await retry(client.post)(
                settings.LLM_URL,
@@ -27,18 +32,21 @@ class BananaLLM(LLM):
            )
        response.raise_for_status()
        text = response.json()["text"]
        if not schema:
            text = text[len(prompt) :]
        return text


LLM.register("banana", BananaLLM)

if __name__ == "__main__":
    from reflector.logger import logger

    async def main():
        llm = BananaLLM()
        result = await llm.generate("Hello, my name is")
        prompt = llm.create_prompt(
            instruct="Complete the following task",
            text="Tell me a joke about programming.",
        )
        result = await llm.generate(prompt=prompt, logger=logger)
        print(result)

    import asyncio
@@ -1,11 +1,14 @@
import httpx
from transformers import AutoTokenizer, GenerationConfig

from reflector.llm.base import LLM
from reflector.logger import logger as reflector_logger
from reflector.settings import settings
from reflector.utils.retry import retry


class ModalLLM(LLM):
    def __init__(self):
    def __init__(self, model_name: str | None = None):
        super().__init__()
        self.timeout = settings.LLM_TIMEOUT
        self.llm_url = settings.LLM_URL + "/llm"
@@ -13,6 +16,16 @@ class ModalLLM(LLM):
        self.headers = {
            "Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
        }
        self._set_model_name(model_name if model_name else settings.DEFAULT_LLM)

    @property
    def supported_models(self):
        """
        List of currently supported models on this GPU platform
        """
        # TODO: Query the specific GPU platform
        # Replace this with a HTTP call
        return ["lmsys/vicuna-13b-v1.5"]

    async def _warmup(self, logger):
        async with httpx.AsyncClient() as client:
@@ -23,10 +36,14 @@ class ModalLLM(LLM):
            )
        response.raise_for_status()

    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
    async def _generate(
        self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
    ):
        json_payload = {"prompt": prompt}
        if schema:
            json_payload["schema"] = schema
        if gen_schema:
            json_payload["gen_schema"] = gen_schema
        if gen_cfg:
            json_payload["gen_cfg"] = gen_cfg
        async with httpx.AsyncClient() as client:
            response = await retry(client.post)(
                self.llm_url,
@@ -37,10 +54,43 @@ class ModalLLM(LLM):
            )
        response.raise_for_status()
        text = response.json()["text"]
        if not schema:
            text = text[len(prompt) :]
        return text

    def _set_model_name(self, model_name: str) -> bool:
        """
        Set the model name
        """
        # Abort, if the model is not supported
        if model_name not in self.supported_models:
            reflector_logger.info(
                f"Attempted to change {model_name=}, but is not supported."
                f"Setting model and tokenizer failed !"
            )
            return False
        # Abort, if the model is already set
        elif hasattr(self, "model_name") and model_name == self._get_model_name():
            reflector_logger.info("No change in model. Setting model skipped.")
            return False
        # Update model name and tokenizer
        self.model_name = model_name
        self.llm_tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, cache_dir=settings.CACHE_DIR
        )
        reflector_logger.info(f"Model set to {model_name=}. Tokenizer updated.")
        return True

    def _get_tokenizer(self) -> AutoTokenizer:
        """
        Return the currently used LLM tokenizer
        """
        return self.llm_tokenizer

    def _get_model_name(self) -> str:
        """
        Return the current model name from the instance details
        """
        return self.model_name


LLM.register("modal", ModalLLM)
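Editor's note on the model-switching contract above: `_set_model_name` refuses names outside `supported_models`, no-ops when the name is unchanged, and otherwise reloads the Hugging Face tokenizer from `CACHE_DIR`. A short sketch (illustration only):

```python
llm = ModalLLM()  # __init__ applies settings.DEFAULT_LLM
ok = llm.set_model_name("lmsys/vicuna-13b-v1.5")
# ok is False here: the default model is already set, so nothing reloads
print(ok, llm.get_model_name())
```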
@@ -49,15 +99,25 @@ if __name__ == "__main__":

    async def main():
        llm = ModalLLM()
        result = await llm.generate("Hello, my name is", logger=logger)
        prompt = llm.create_prompt(
            instruct="Complete the following task",
            text="Tell me a joke about programming.",
        )
        result = await llm.generate(prompt=prompt, logger=logger)
        print(result)

        schema = {
        gen_schema = {
            "type": "object",
            "properties": {"name": {"type": "string"}},
            "properties": {"response": {"type": "string"}},
        }

        result = await llm.generate("Hello, my name is", schema=schema, logger=logger)
        result = await llm.generate(prompt=prompt, gen_schema=gen_schema, logger=logger)
        print(result)

        gen_cfg = GenerationConfig(max_new_tokens=150)
        result = await llm.generate(
            prompt=prompt, gen_cfg=gen_cfg, gen_schema=gen_schema, logger=logger
        )
        print(result)

    import asyncio
@@ -1,13 +1,21 @@
import httpx

from reflector.llm.base import LLM
from reflector.settings import settings


class OobaboogaLLM(LLM):
    async def _generate(self, prompt: str, schema: dict | None, **kwargs):
    def __init__(self, model_name: str | None = None):
        super().__init__()

    async def _generate(
        self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
    ):
        json_payload = {"prompt": prompt}
        if schema:
            json_payload["schema"] = schema
        if gen_schema:
            json_payload["gen_schema"] = gen_schema
        if gen_cfg:
            json_payload.update(gen_cfg)
        async with httpx.AsyncClient() as client:
            response = await client.post(
                settings.LLM_URL,
@@ -1,11 +1,13 @@
import httpx
from transformers import GenerationConfig

from reflector.llm.base import LLM
from reflector.logger import logger
from reflector.settings import settings


class OpenAILLM(LLM):
    def __init__(self, **kwargs):
    def __init__(self, model_name: str | None = None, **kwargs):
        super().__init__(**kwargs)
        self.openai_key = settings.LLM_OPENAI_KEY
        self.openai_url = settings.LLM_URL
|
||||
self.max_tokens = settings.LLM_MAX_TOKENS
|
||||
logger.info(f"LLM use openai backend at {self.openai_url}")
|
||||
|
||||
async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
|
||||
async def _generate(
|
||||
self,
|
||||
prompt: str,
|
||||
gen_schema: dict | None,
|
||||
gen_cfg: GenerationConfig | None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.openai_key}",
|
||||
|
||||
server/reflector/llm/llm_params.py (new file, 150 lines)
@@ -0,0 +1,150 @@
from typing import Optional, TypeVar

from pydantic import BaseModel
from transformers import GenerationConfig


class TaskParams(BaseModel, arbitrary_types_allowed=True):
    instruct: str
    gen_cfg: Optional[GenerationConfig] = None
    gen_schema: Optional[dict] = None


T = TypeVar("T", bound="LLMTaskParams")


class LLMTaskParams:
    _registry = {}

    @classmethod
    def register(cls, task, klass) -> None:
        cls._registry[task] = klass

    @classmethod
    def get_instance(cls, task: str) -> T:
        return cls._registry[task]()

    @property
    def task_params(self) -> TaskParams | None:
        """
        Fetch the task related parameters
        """
        return self._get_task_params()

    def _get_task_params(self) -> None:
        pass


class FinalLongSummaryParams(LLMTaskParams):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._gen_cfg = GenerationConfig(
            max_new_tokens=800, num_beams=3, do_sample=True, temperature=0.3
        )
        self._instruct = """
        Take the key ideas and takeaways from the text and create a short
        summary. Be sure to keep the length of the response to a minimum.
        Do not include trivial information in the summary.
        """
        self._schema = {
            "type": "object",
            "properties": {"long_summary": {"type": "string"}},
        }
        self._task_params = TaskParams(
            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
        )

    def _get_task_params(self) -> TaskParams:
        """
        Return the parameters associated with a specific LLM task
        """
        return self._task_params


class FinalShortSummaryParams(LLMTaskParams):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._gen_cfg = GenerationConfig(
            max_new_tokens=1300, num_beams=3, do_sample=True, temperature=0.3
        )
        self._instruct = """
        Take the key ideas and takeaways from the text and create a short
        summary. Be sure to keep the length of the response to a minimum.
        Do not include trivial information in the summary.
        """
        self._schema = {
            "type": "object",
            "properties": {"short_summary": {"type": "string"}},
        }
        self._task_params = TaskParams(
            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
        )

    def _get_task_params(self) -> TaskParams:
        """
        Return the parameters associated with a specific LLM task
        """
        return self._task_params


class FinalTitleParams(LLMTaskParams):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._gen_cfg = GenerationConfig(
            max_new_tokens=200, num_beams=5, do_sample=True, temperature=0.5
        )
        self._instruct = """
        Combine the following individual titles into one single short title that
        condenses the essence of all titles.
        """
        self._schema = {
            "type": "object",
            "properties": {"title": {"type": "string"}},
        }
        self._task_params = TaskParams(
            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
        )

    def _get_task_params(self) -> TaskParams:
        """
        Return the parameters associated with a specific LLM task
        """
        return self._task_params


class TopicParams(LLMTaskParams):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._gen_cfg = GenerationConfig(
            max_new_tokens=550, num_beams=6, do_sample=True, temperature=0.9
        )
        self._instruct = """
        Create a JSON object as response.The JSON object must have 2 fields:
        i) title and ii) summary.
        For the title field, generate a very detailed and self-explanatory
        title for the given text. Let the title be as descriptive as possible.
        For the summary field, summarize the given text in a maximum of
        three sentences.
        """
        self._schema = {
            "type": "object",
            "properties": {
                "title": {"type": "string"},
                "summary": {"type": "string"},
            },
        }
        self._task_params = TaskParams(
            instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
        )

    def _get_task_params(self) -> TaskParams:
        """
        Return the parameters associated with a specific LLM task
        """
        return self._task_params


LLMTaskParams.register("topic", TopicParams)
LLMTaskParams.register("final_title", FinalTitleParams)
LLMTaskParams.register("final_short_summary", FinalShortSummaryParams)
LLMTaskParams.register("final_long_summary", FinalLongSummaryParams)
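Editor's note: the new module gives each pipeline task one place for its instruction, JSON schema and decoding settings, keyed by task name. Consumers fetch them like this (sketch; the values shown come straight from `FinalTitleParams` above):

```python
from reflector.llm import LLMTaskParams

params = LLMTaskParams.get_instance("final_title").task_params
assert params.gen_cfg.max_new_tokens == 200
print(params.instruct)    # combine individual titles into one short title
print(params.gen_schema)  # {"type": "object", "properties": {"title": ...}}
```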
@@ -4,7 +4,20 @@ from .audio_merge import AudioMergeProcessor  # noqa: F401
from .audio_transcript import AudioTranscriptProcessor  # noqa: F401
from .audio_transcript_auto import AudioTranscriptAutoProcessor  # noqa: F401
from .base import Pipeline, PipelineEvent, Processor, ThreadedProcessor  # noqa: F401
from .transcript_final_summary import TranscriptFinalSummaryProcessor  # noqa: F401
from .transcript_final_long_summary import (  # noqa: F401
    TranscriptFinalLongSummaryProcessor,
)
from .transcript_final_short_summary import (  # noqa: F401
    TranscriptFinalShortSummaryProcessor,
)
from .transcript_final_title import TranscriptFinalTitleProcessor  # noqa: F401
from .transcript_liner import TranscriptLinerProcessor  # noqa: F401
from .transcript_topic_detector import TranscriptTopicDetectorProcessor  # noqa: F401
from .types import AudioFile, FinalSummary, TitleSummary, Transcript, Word  # noqa: F401
from .types import (  # noqa: F401
    AudioFile,
    FinalLongSummary,
    FinalShortSummary,
    TitleSummary,
    Transcript,
    Word,
)
@@ -5,6 +5,7 @@ from uuid import uuid4

from prometheus_client import Counter, Gauge, Histogram
from pydantic import BaseModel

from reflector.logger import logger
@@ -296,7 +297,7 @@ class BroadcastProcessor(Processor):
    types of input.
    """

    def __init__(self, processors: Processor):
    def __init__(self, processors: list[Processor]):
        super().__init__()
        self.processors = processors
        self.INPUT_TYPE = processors[0].INPUT_TYPE
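Editor's note: the type-hint fix above matches how `BroadcastProcessor` is constructed later in this diff. Its fan-out behavior is not shown in this hunk; a plausible sketch of the assumed semantics, inferred from the constructor:

```python
class BroadcastSketch:
    """Forward every pushed item to all children (assumed semantics)."""

    def __init__(self, processors):
        self.processors = processors
        # children must share an input type; the first one defines it
        self.INPUT_TYPE = processors[0].INPUT_TYPE

    async def _push(self, data):
        for processor in self.processors:
            await processor.push(data)
```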
server/reflector/processors/transcript_final_long_summary.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from reflector.llm import LLM, LLMTaskParams
from reflector.processors.base import Processor
from reflector.processors.types import FinalLongSummary, TitleSummary


class TranscriptFinalLongSummaryProcessor(Processor):
    """
    Get the final long summary
    """

    INPUT_TYPE = TitleSummary
    OUTPUT_TYPE = FinalLongSummary
    TASK = "final_long_summary"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.chunks: list[TitleSummary] = []
        self.llm = LLM.get_instance()
        self.params = LLMTaskParams.get_instance(self.TASK).task_params

    async def _push(self, data: TitleSummary):
        self.chunks.append(data)

    async def get_long_summary(self, text: str) -> str:
        """
        Generate a long version of the final summary
        """
        self.logger.info(f"Smoothing out {len(text)} length summary to a long summary")
        chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))

        accumulated_summaries = ""
        for chunk in chunks:
            prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
            summary_result = await self.llm.generate(
                prompt=prompt,
                gen_schema=self.params.gen_schema,
                gen_cfg=self.params.gen_cfg,
                logger=self.logger,
            )
            accumulated_summaries += summary_result["long_summary"]

        return accumulated_summaries

    async def _flush(self):
        if not self.chunks:
            self.logger.warning("No summary to output")
            return

        accumulated_summaries = " ".join([chunk.summary for chunk in self.chunks])
        long_summary = await self.get_long_summary(accumulated_summaries)

        last_chunk = self.chunks[-1]
        duration = last_chunk.timestamp + last_chunk.duration

        final_long_summary = FinalLongSummary(
            long_summary=long_summary,
            duration=duration,
        )
        await self.emit(final_long_summary)
@@ -0,0 +1,72 @@
from reflector.llm import LLM, LLMTaskParams
from reflector.processors.base import Processor
from reflector.processors.types import FinalShortSummary, TitleSummary


class TranscriptFinalShortSummaryProcessor(Processor):
    """
    Get the final summary using a tree summarizer
    """

    INPUT_TYPE = TitleSummary
    OUTPUT_TYPE = FinalShortSummary
    TASK = "final_short_summary"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.chunks: list[TitleSummary] = []
        self.llm = LLM.get_instance()
        self.params = LLMTaskParams.get_instance(self.TASK).task_params

    async def _push(self, data: TitleSummary):
        self.chunks.append(data)

    async def get_short_summary(self, text: str) -> dict:
        """
        Generate a short summary using tree summarizer
        """
        self.logger.info(f"Smoothing out {len(text)} length summary to a short summary")
        chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))

        if len(chunks) == 1:
            chunk = chunks[0]
            prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
            summary_result = await self.llm.generate(
                prompt=prompt,
                gen_schema=self.params.gen_schema,
                gen_cfg=self.params.gen_cfg,
                logger=self.logger,
            )
            return summary_result
        else:
            accumulated_summaries = ""
            for chunk in chunks:
                prompt = self.llm.create_prompt(
                    instruct=self.params.instruct, text=chunk
                )
                summary_result = await self.llm.generate(
                    prompt=prompt,
                    gen_schema=self.params.gen_schema,
                    gen_cfg=self.params.gen_cfg,
                    logger=self.logger,
                )
                accumulated_summaries += summary_result["short_summary"]

            return await self.get_short_summary(accumulated_summaries)

    async def _flush(self):
        if not self.chunks:
            self.logger.warning("No summary to output")
            return

        accumulated_summaries = " ".join([chunk.summary for chunk in self.chunks])
        short_summary_result = await self.get_short_summary(accumulated_summaries)

        last_chunk = self.chunks[-1]
        duration = last_chunk.timestamp + last_chunk.duration

        final_summary = FinalShortSummary(
            short_summary=short_summary_result["short_summary"],
            duration=duration,
        )
        await self.emit(final_summary)
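Editor's note: the processor above is a small tree reduce. When the joined topic summaries overflow one chunk, each chunk is summarized and the concatenation is recursively re-summarized until a single chunk remains. A schematic of that recursion (illustration, not part of the commit):

```python
async def tree_summarize(llm, params, text, logger):
    chunks = list(llm.split_corpus(corpus=text, task_params=params))
    if len(chunks) == 1:
        # base case: everything fits one context window
        prompt = llm.create_prompt(instruct=params.instruct, text=chunks[0])
        return await llm.generate(
            prompt=prompt, gen_schema=params.gen_schema,
            gen_cfg=params.gen_cfg, logger=logger,
        )
    partials = ""
    for chunk in chunks:
        prompt = llm.create_prompt(instruct=params.instruct, text=chunk)
        result = await llm.generate(
            prompt=prompt, gen_schema=params.gen_schema,
            gen_cfg=params.gen_cfg, logger=logger,
        )
        partials += result["short_summary"]
    return await tree_summarize(llm, params, partials, logger)  # shrink and retry
```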
@@ -1,30 +0,0 @@
from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary, FinalSummary


class TranscriptFinalSummaryProcessor(Processor):
    """
    Assemble all summary into a line-based json
    """

    INPUT_TYPE = TitleSummary
    OUTPUT_TYPE = FinalSummary

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.chunks: list[TitleSummary] = []

    async def _push(self, data: TitleSummary):
        self.chunks.append(data)

    async def _flush(self):
        if not self.chunks:
            self.logger.warning("No summary to output")
            return

        # FIXME improve final summary
        result = "\n".join([chunk.summary for chunk in self.chunks])
        last_chunk = self.chunks[-1]
        duration = last_chunk.timestamp + last_chunk.duration

        await self.emit(FinalSummary(summary=result, duration=duration))
server/reflector/processors/transcript_final_title.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from reflector.llm import LLM, LLMTaskParams
from reflector.processors.base import Processor
from reflector.processors.types import FinalTitle, TitleSummary


class TranscriptFinalTitleProcessor(Processor):
    """
    Assemble all summary into a line-based json
    """

    INPUT_TYPE = TitleSummary
    OUTPUT_TYPE = FinalTitle
    TASK = "final_title"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.chunks: list[TitleSummary] = []
        self.llm = LLM.get_instance()
        self.params = LLMTaskParams.get_instance(self.TASK).task_params

    async def _push(self, data: TitleSummary):
        self.chunks.append(data)

    async def get_title(self, text: str) -> dict:
        """
        Generate a title for the whole recording
        """
        chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))

        if len(chunks) == 1:
            chunk = chunks[0]
            prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
            title_result = await self.llm.generate(
                prompt=prompt,
                gen_schema=self.params.gen_schema,
                gen_cfg=self.params.gen_cfg,
                logger=self.logger,
            )
            return title_result
        else:
            accumulated_titles = ""
            for chunk in chunks:
                prompt = self.llm.create_prompt(
                    instruct=self.params.instruct, text=chunk
                )
                title_result = await self.llm.generate(
                    prompt=prompt,
                    gen_schema=self.params.gen_schema,
                    gen_cfg=self.params.gen_cfg,
                    logger=self.logger,
                )
                accumulated_titles += title_result["summary"]

            return await self.get_title(accumulated_titles)

    async def _flush(self):
        if not self.chunks:
            self.logger.warning("No summary to output")
            return

        accumulated_titles = ".".join([chunk.title for chunk in self.chunks])
        title_result = await self.get_title(accumulated_titles)

        final_title = FinalTitle(title=title_result["title"])
        await self.emit(final_title)
@@ -1,7 +1,6 @@
from reflector.llm import LLM
from reflector.llm import LLM, LLMTaskParams
from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary, Transcript
from reflector.utils.retry import retry


class TranscriptTopicDetectorProcessor(Processor):
@@ -11,34 +10,14 @@ class TranscriptTopicDetectorProcessor(Processor):

    INPUT_TYPE = Transcript
    OUTPUT_TYPE = TitleSummary
    TASK = "topic"

    PROMPT = """
### Human:
Create a JSON object as response.The JSON object must have 2 fields:
i) title and ii) summary.

For the title field, generate a short title for the given text.
For the summary field, summarize the given text in a maximum of
three sentences.

{input_text}

### Assistant:

"""

    def __init__(self, min_transcript_length=750, **kwargs):
    def __init__(self, min_transcript_length: int = 750, **kwargs):
        super().__init__(**kwargs)
        self.transcript = None
        self.min_transcript_length = min_transcript_length
        self.llm = LLM.get_instance()
        self.topic_detector_schema = {
            "type": "object",
            "properties": {
                "title": {"type": "string"},
                "summary": {"type": "string"},
            },
        }
        self.params = LLMTaskParams.get_instance(self.TASK).task_params

    async def _warmup(self):
        await self.llm.warmup(logger=self.logger)
@@ -55,18 +34,30 @@ class TranscriptTopicDetectorProcessor(Processor):
            return
        await self.flush()

    async def get_topic(self, text: str) -> dict:
        """
        Generate a topic and description for a transcription excerpt
        """
        prompt = self.llm.create_prompt(instruct=self.params.instruct, text=text)
        topic_result = await self.llm.generate(
            prompt=prompt,
            gen_schema=self.params.gen_schema,
            gen_cfg=self.params.gen_cfg,
            logger=self.logger,
        )
        return topic_result

    async def _flush(self):
        if not self.transcript:
            return

        text = self.transcript.text
        self.logger.info(f"Topic detector got {len(text)} length transcript")
        prompt = self.PROMPT.format(input_text=text)
        result = await retry(self.llm.generate)(
            prompt=prompt, schema=self.topic_detector_schema, logger=self.logger
        )
        topic_result = await self.get_topic(text=text)

        summary = TitleSummary(
            title=result["title"],
            summary=result["summary"],
            title=self.llm.ensure_casing(topic_result["title"]),
            summary=topic_result["summary"],
            timestamp=self.transcript.timestamp,
            duration=self.transcript.duration,
            transcript=self.transcript,
@@ -103,11 +103,20 @@ class TitleSummary(BaseModel):
        return f"{minutes:02d}:{seconds:02d}.{milliseconds:03d}"


class FinalSummary(BaseModel):
    summary: str
class FinalLongSummary(BaseModel):
    long_summary: str
    duration: float


class FinalShortSummary(BaseModel):
    short_summary: str
    duration: float


class FinalTitle(BaseModel):
    title: str


class TranslationLanguages(BaseModel):
    language_to_id_mapping: dict = {
        "Afrikaans": "af",
@@ -91,5 +91,11 @@ class Settings(BaseSettings):
    # if set, all anonymous record will be public
    PUBLIC_MODE: bool = False

    # Default LLM model name
    DEFAULT_LLM: str = "lmsys/vicuna-13b-v1.5"

    # Cache directory for all model storage
    CACHE_DIR: str = "data"


settings = Settings()
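Editor's note: both new fields are plain pydantic BaseSettings attributes, so they can be overridden through the environment; `CACHE_DIR` is the same directory used earlier for NLTK data and Hugging Face tokenizers. Quick check (illustration):

```python
from reflector.settings import settings

print(settings.DEFAULT_LLM)  # "lmsys/vicuna-13b-v1.5" unless set via env
print(settings.CACHE_DIR)    # "data"
```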
@@ -1,6 +1,7 @@
import asyncio

import av

from reflector.logger import logger
from reflector.processors import (
    AudioChunkerProcessor,
@@ -8,10 +9,13 @@ from reflector.processors import (
    AudioTranscriptAutoProcessor,
    Pipeline,
    PipelineEvent,
    TranscriptFinalSummaryProcessor,
    TranscriptFinalLongSummaryProcessor,
    TranscriptFinalShortSummaryProcessor,
    TranscriptFinalTitleProcessor,
    TranscriptLinerProcessor,
    TranscriptTopicDetectorProcessor,
)
from reflector.processors.base import BroadcastProcessor


async def process_audio_file(
@@ -31,7 +35,13 @@ async def process_audio_file(
    if not only_transcript:
        processors += [
            TranscriptTopicDetectorProcessor.as_threaded(),
            TranscriptFinalSummaryProcessor.as_threaded(),
            BroadcastProcessor(
                processors=[
                    TranscriptFinalTitleProcessor.as_threaded(),
                    TranscriptFinalLongSummaryProcessor.as_threaded(),
                    TranscriptFinalShortSummaryProcessor.as_threaded(),
                ],
            ),
        ]

    # transcription output
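Editor's note: the file pipeline now fans out after topic detection instead of ending in a single final-summary processor. The resulting topology (illustration, not part of the commit):

```python
# TranscriptTopicDetectorProcessor (emits TitleSummary)
#     └── BroadcastProcessor
#           ├── TranscriptFinalTitleProcessor        -> FinalTitle
#           ├── TranscriptFinalLongSummaryProcessor  -> FinalLongSummary
#           └── TranscriptFinalShortSummaryProcessor -> FinalShortSummary
```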
@@ -8,6 +8,7 @@ from aiortc import MediaStreamTrack, RTCPeerConnection, RTCSessionDescription
from fastapi import APIRouter, Request
from prometheus_client import Gauge
from pydantic import BaseModel

from reflector.events import subscribers_shutdown
from reflector.logger import logger
from reflector.processors import (
@@ -15,14 +16,19 @@ from reflector.processors import (
    AudioFileWriterProcessor,
    AudioMergeProcessor,
    AudioTranscriptAutoProcessor,
    FinalSummary,
    FinalLongSummary,
    FinalShortSummary,
    Pipeline,
    TitleSummary,
    Transcript,
    TranscriptFinalSummaryProcessor,
    TranscriptFinalLongSummaryProcessor,
    TranscriptFinalShortSummaryProcessor,
    TranscriptFinalTitleProcessor,
    TranscriptLinerProcessor,
    TranscriptTopicDetectorProcessor,
)
from reflector.processors.base import BroadcastProcessor
from reflector.processors.types import FinalTitle

sessions = []
router = APIRouter()
@@ -72,8 +78,10 @@ class StrValue(BaseModel):
class PipelineEvent(StrEnum):
    TRANSCRIPT = "TRANSCRIPT"
    TOPIC = "TOPIC"
    FINAL_SUMMARY = "FINAL_SUMMARY"
    FINAL_LONG_SUMMARY = "FINAL_LONG_SUMMARY"
    STATUS = "STATUS"
    FINAL_SHORT_SUMMARY = "FINAL_SHORT_SUMMARY"
    FINAL_TITLE = "FINAL_TITLE"


async def rtc_offer_base(
@@ -124,15 +132,15 @@ async def rtc_offer_base(
            data=transcript,
        )

    async def on_topic(summary: TitleSummary):
    async def on_topic(topic: TitleSummary):
        # FIXME: make it incremental with the frontend, not send everything
        ctx.logger.info("Summary", summary=summary)
        ctx.logger.info("Topic", topic=topic)
        ctx.topics.append(
            {
                "title": summary.title,
                "timestamp": summary.timestamp,
                "transcript": summary.transcript.text,
                "desc": summary.summary,
                "title": topic.title,
                "timestamp": topic.timestamp,
                "transcript": topic.transcript.text,
                "desc": topic.summary,
            }
        )
@@ -144,17 +152,17 @@ async def rtc_offer_base(
        # send to callback (eg. websocket)
        if event_callback:
            await event_callback(
                event=PipelineEvent.TOPIC, args=event_callback_args, data=summary
                event=PipelineEvent.TOPIC, args=event_callback_args, data=topic
            )

    async def on_final_summary(summary: FinalSummary):
        ctx.logger.info("FinalSummary", final_summary=summary)
    async def on_final_short_summary(summary: FinalShortSummary):
        ctx.logger.info("FinalShortSummary", final_short_summary=summary)

        # send to RTC
        if ctx.data_channel.readyState == "open":
            result = {
                "cmd": "DISPLAY_FINAL_SUMMARY",
                "summary": summary.summary,
                "cmd": "DISPLAY_FINAL_SHORT_SUMMARY",
                "summary": summary.short_summary,
                "duration": summary.duration,
            }
            ctx.data_channel.send(dumps(result))
@@ -162,11 +170,47 @@ async def rtc_offer_base(
        # send to callback (eg. websocket)
        if event_callback:
            await event_callback(
                event=PipelineEvent.FINAL_SUMMARY,
                event=PipelineEvent.FINAL_SHORT_SUMMARY,
                args=event_callback_args,
                data=summary,
            )

    async def on_final_long_summary(summary: FinalLongSummary):
        ctx.logger.info("FinalLongSummary", final_summary=summary)

        # send to RTC
        if ctx.data_channel.readyState == "open":
            result = {
                "cmd": "DISPLAY_FINAL_LONG_SUMMARY",
                "summary": summary.long_summary,
                "duration": summary.duration,
            }
            ctx.data_channel.send(dumps(result))

        # send to callback (eg. websocket)
        if event_callback:
            await event_callback(
                event=PipelineEvent.FINAL_LONG_SUMMARY,
                args=event_callback_args,
                data=summary,
            )

    async def on_final_title(title: FinalTitle):
        ctx.logger.info("FinalTitle", final_title=title)

        # send to RTC
        if ctx.data_channel.readyState == "open":
            result = {"cmd": "DISPLAY_FINAL_TITLE", "title": title.title}
            ctx.data_channel.send(dumps(result))

        # send to callback (eg. websocket)
        if event_callback:
            await event_callback(
                event=PipelineEvent.FINAL_TITLE,
                args=event_callback_args,
                data=title,
            )

    # create a context for the whole rtc transaction
    # add a customised logger to the context
    processors = []
@@ -178,7 +222,17 @@ async def rtc_offer_base(
        AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
        TranscriptLinerProcessor(),
        TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
        TranscriptFinalSummaryProcessor.as_threaded(callback=on_final_summary),
        BroadcastProcessor(
            processors=[
                TranscriptFinalTitleProcessor.as_threaded(callback=on_final_title),
                TranscriptFinalLongSummaryProcessor.as_threaded(
                    callback=on_final_long_summary
                ),
                TranscriptFinalShortSummaryProcessor.as_threaded(
                    callback=on_final_short_summary
                ),
            ]
        ),
    ]
    ctx.pipeline = Pipeline(*processors)
    ctx.pipeline.set_pref("audio:source_language", source_language)
@@ -7,7 +7,6 @@ from typing import Annotated, Optional
from uuid import uuid4

import av
import reflector.auth as auth
from fastapi import (
    APIRouter,
    Depends,
|
||||
)
|
||||
from fastapi_pagination import Page, paginate
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
import reflector.auth as auth
|
||||
from reflector.db import database, transcripts
|
||||
from reflector.logger import logger
|
||||
from reflector.settings import settings
|
||||
from reflector.utils.audio_waveform import get_audio_waveform
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from ._range_requests_response import range_requests_response
|
||||
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
||||
@@ -60,8 +61,16 @@ class TranscriptTopic(BaseModel):
    timestamp: float


class TranscriptFinalSummary(BaseModel):
    summary: str
class TranscriptFinalShortSummary(BaseModel):
    short_summary: str


class TranscriptFinalLongSummary(BaseModel):
    long_summary: str


class TranscriptFinalTitle(BaseModel):
    title: str


class TranscriptEvent(BaseModel):
@@ -77,7 +86,9 @@ class Transcript(BaseModel):
    locked: bool = False
    duration: float = 0
    created_at: datetime = Field(default_factory=datetime.utcnow)
    summary: str | None = None
    title: str | None = None
    short_summary: str | None = None
    long_summary: str | None = None
    topics: list[TranscriptTopic] = []
    events: list[TranscriptEvent] = []
    source_language: str = "en"
@@ -241,7 +252,9 @@ class GetTranscript(BaseModel):
    status: str
    locked: bool
    duration: int
    summary: str | None
    title: str | None
    short_summary: str | None
    long_summary: str | None
    created_at: datetime
    source_language: str
    target_language: str
@@ -256,7 +269,9 @@ class CreateTranscript(BaseModel):
class UpdateTranscript(BaseModel):
    name: Optional[str] = Field(None)
    locked: Optional[bool] = Field(None)
    summary: Optional[str] = Field(None)
    title: Optional[str] = Field(None)
    short_summary: Optional[str] = Field(None)
    long_summary: Optional[str] = Field(None)


class DeletionStatus(BaseModel):
@@ -315,20 +330,32 @@ async def transcript_update(
    transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
    if not transcript:
        raise HTTPException(status_code=404, detail="Transcript not found")
    values = {}
    values = {"events": []}
    if info.name is not None:
        values["name"] = info.name
    if info.locked is not None:
        values["locked"] = info.locked
    if info.summary is not None:
        values["summary"] = info.summary
        # also find FINAL_SUMMARY event and patch it
        for te in transcript.events:
            if te["event"] == PipelineEvent.FINAL_SUMMARY:
                te["summary"] = info.summary
    if info.long_summary is not None:
        values["long_summary"] = info.long_summary
        for transcript_event in transcript.events:
            if transcript_event["event"] == PipelineEvent.FINAL_LONG_SUMMARY:
                transcript_event["long_summary"] = info.long_summary
                break
        values["events"] = transcript.events

        values["events"].extend(transcript.events)
    if info.short_summary is not None:
        values["short_summary"] = info.short_summary
        for transcript_event in transcript.events:
            if transcript_event["event"] == PipelineEvent.FINAL_SHORT_SUMMARY:
                transcript_event["short_summary"] = info.short_summary
                break
        values["events"].extend(transcript.events)
    if info.title is not None:
        values["title"] = info.title
        for transcript_event in transcript.events:
            if transcript_event["event"] == PipelineEvent.FINAL_TITLE:
                transcript_event["title"] = info.title
                break
        values["events"].extend(transcript.events)
    await transcripts_controller.update(transcript, values)
    return transcript
@@ -539,14 +566,38 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
            },
        )

    elif event == PipelineEvent.FINAL_SUMMARY:
        final_summary = TranscriptFinalSummary(summary=data.summary)
        resp = transcript.add_event(event=event, data=final_summary)
    elif event == PipelineEvent.FINAL_TITLE:
        final_title = TranscriptFinalTitle(title=data.title)
        resp = transcript.add_event(event=event, data=final_title)
        await transcripts_controller.update(
            transcript,
            {
                "events": transcript.events_dump(),
                "summary": final_summary.summary,
                "title": final_title.title,
            },
        )

    elif event == PipelineEvent.FINAL_LONG_SUMMARY:
        final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
        resp = transcript.add_event(event=event, data=final_long_summary)
        await transcripts_controller.update(
            transcript,
            {
                "events": transcript.events_dump(),
                "long_summary": final_long_summary.long_summary,
            },
        )

    elif event == PipelineEvent.FINAL_SHORT_SUMMARY:
        final_short_summary = TranscriptFinalShortSummary(
            short_summary=data.short_summary
        )
        resp = transcript.add_event(event=event, data=final_short_summary)
        await transcripts_controller.update(
            transcript,
            {
                "events": transcript.events_dump(),
                "short_summary": final_short_summary.short_summary,
            },
        )