New summary (#283)

* hand over final summary generation to Zephyr deployment

* fix display error

* push new summary feature

* fix failing test case

* add markdown support for final summary

* fix UI render issue

* retain sentence tokenizer call

---------

Co-authored-by: Koper <andreas@monadical.com>
Author: projects-g
Date: 2023-10-13 22:53:29 +05:30
Committed by: GitHub
Parent: 38cd0385b4
Commit: 1d92d43fe0

13 changed files with 933 additions and 23 deletions

View File

@@ -0,0 +1,208 @@
"""
Reflector GPU backend - LLM
===========================
"""
import json
import os
from typing import Optional
import modal
from modal import Image, Secret, Stub, asgi_app, method
# LLM
LLM_MODEL: str = "HuggingFaceH4/zephyr-7b-alpha"
LLM_LOW_CPU_MEM_USAGE: bool = True
LLM_TORCH_DTYPE: str = "bfloat16"
LLM_MAX_NEW_TOKENS: int = 300
IMAGE_MODEL_DIR = "/root/llm_models"
stub = Stub(name="reflector-llm-zephyr")
def download_llm():
from huggingface_hub import snapshot_download
print("Downloading LLM model")
snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
print("LLM model downloaded")
def migrate_cache_llm():
"""
XXX The cache for model files in Transformers v4.22.0 has been updated.
Migrating your old cache. This is a one-time only operation. You can
interrupt this and resume the migration later on by calling
`transformers.utils.move_cache()`.
"""
from transformers.utils.hub import move_cache
print("Moving LLM cache")
move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR)
print("LLM cache moved")
llm_image = (
Image.debian_slim(python_version="3.10.8")
.apt_install("git")
.pip_install(
"transformers==4.34.0",
"torch",
"sentencepiece",
"protobuf",
"jsonformer==0.12.0",
"accelerate==0.21.0",
"einops==0.6.1",
"hf-transfer~=0.1",
"huggingface_hub==0.16.4"
)
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
.run_function(download_llm)
.run_function(migrate_cache_llm)
)
@stub.cls(
gpu="A10G",
timeout=60 * 5,
container_idle_timeout=60 * 5,
concurrency_limit=2,
image=llm_image,
)
class LLM:
def __enter__(self):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
print("Instance llm model")
model = AutoModelForCausalLM.from_pretrained(
LLM_MODEL,
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
cache_dir=IMAGE_MODEL_DIR
)
# JSONFormer doesn't yet support generation configs
print("Instance llm generation config")
model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
# generation configuration
gen_cfg = GenerationConfig.from_model_config(model.config)
gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
# load tokenizer
print("Instance llm tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
LLM_MODEL,
cache_dir=IMAGE_MODEL_DIR
)
gen_cfg.pad_token_id = tokenizer.eos_token_id
gen_cfg.eos_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
# move model to gpu
print("Move llm model to GPU")
model = model.cuda()
print("Warmup llm done")
self.model = model
self.tokenizer = tokenizer
self.gen_cfg = gen_cfg
self.GenerationConfig = GenerationConfig
def __exit__(self, *args):
print("Exit llm")
@method()
def generate(self, prompt: str, gen_schema: str | None, gen_cfg: str | None) -> dict:
"""
Perform a generation action using the LLM
"""
print(f"Generate {prompt=}")
if gen_cfg:
gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
gen_cfg.pad_token_id = self.tokenizer.eos_token_id
gen_cfg.eos_token_id = self.tokenizer.eos_token_id
else:
gen_cfg = self.gen_cfg
# If a gen_schema is given, conform to gen_schema
if gen_schema:
import jsonformer
print(f"Schema {gen_schema=}")
jsonformer_llm = jsonformer.Jsonformer(
model=self.model,
tokenizer=self.tokenizer,
json_schema=json.loads(gen_schema),
prompt=prompt,
max_string_token_length=gen_cfg.max_new_tokens
)
response = jsonformer_llm()
else:
# If no gen_schema, perform prompt only generation
# tokenize prompt
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
self.model.device
)
output = self.model.generate(input_ids, generation_config=gen_cfg)
# decode output
response = self.tokenizer.decode(output[0].cpu(), skip_special_tokens=True)
response = response[len(prompt):]
response = {
"long_summary": response
}
print(f"Generated {response=}")
return {"text": response}
# -------------------------------------------------------------------
# Web API
# -------------------------------------------------------------------
@stub.function(
container_idle_timeout=60 * 10,
timeout=60 * 5,
secrets=[
Secret.from_name("reflector-gpu"),
],
)
@asgi_app()
def web():
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
llmstub = LLM()
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
class LLMRequest(BaseModel):
prompt: str
gen_schema: Optional[dict] = None
gen_cfg: Optional[dict] = None
@app.post("/llm", dependencies=[Depends(apikey_auth)])
async def llm(
req: LLMRequest,
):
gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
func = llmstub.generate.spawn(prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg)
result = func.get()
return result
return app
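For reference, a minimal sketch of how a client could call the /llm endpoint exposed by web() above. The deployment URL is a placeholder, the API key is assumed to be available as REFLECTOR_GPU_APIKEY in the caller's environment, and the gen_cfg values are illustrative only.

import os

import httpx

ZEPHYR_LLM_URL = "https://example--reflector-llm-zephyr-web.modal.run"  # placeholder

response = httpx.post(
    f"{ZEPHYR_LLM_URL}/llm",
    headers={"Authorization": f"Bearer {os.environ['REFLECTOR_GPU_APIKEY']}"},
    json={
        "prompt": "Summarize the key points of the following meeting transcript: ...",
        "gen_schema": None,
        "gen_cfg": {"max_new_tokens": 300, "do_sample": True, "temperature": 0.2},
    },
    timeout=300,
)
response.raise_for_status()
print(response.json())  # {"text": ...}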

View File

@@ -258,7 +258,7 @@ class LLM:
"""
Choose the token size to set as the threshold to pack the LLM calls
"""
-buffer_token_size = 25
+buffer_token_size = 100
default_output_tokens = 1000
context_window = self.tokenizer.model_max_length
tokens = self.tokenizer.tokenize(

View File

@@ -23,7 +23,7 @@ class ModalLLM(LLM):
"""
# TODO: Query the specific GPU platform
# Replace this with a HTTP call
-return ["lmsys/vicuna-13b-v1.5"]
+return ["lmsys/vicuna-13b-v1.5", "HuggingFaceH4/zephyr-7b-alpha"]
async def _generate(
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
@@ -33,6 +33,13 @@ class ModalLLM(LLM):
json_payload["gen_schema"] = gen_schema
if gen_cfg:
json_payload["gen_cfg"] = gen_cfg
# Hand over generation of the final summary to the Zephyr model;
# fully replacing the Vicuna model will happen after more testing
# TODO: Create a mapping of model names and cloud deployments
if self.model_name == "HuggingFaceH4/zephyr-7b-alpha":
self.llm_url = settings.ZEPHYR_LLM_URL + "/llm"
async with httpx.AsyncClient() as client:
response = await retry(client.post)(
self.llm_url,

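The TODO above about mapping model names to cloud deployments could look roughly like the sketch below. This is hypothetical and not part of the change: only ZEPHYR_LLM_URL is added to settings in this PR, and the field name for the Vicuna deployment URL is assumed.

from reflector.settings import settings  # import path assumed for illustration

# Hypothetical mapping of model names to deployment base URLs (not part of this PR).
# settings.LLM_URL for the Vicuna deployment is an assumed field name.
LLM_DEPLOYMENTS: dict[str, str] = {
    "lmsys/vicuna-13b-v1.5": settings.LLM_URL,
    "HuggingFaceH4/zephyr-7b-alpha": settings.ZEPHYR_LLM_URL,
}

def resolve_llm_url(model_name: str, default_url: str) -> str:
    """Pick the deployment URL for a model, falling back to the default."""
    return LLM_DEPLOYMENTS.get(model_name, default_url) + "/llm"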
View File

@@ -144,7 +144,76 @@ class TopicParams(LLMTaskParams):
return self._task_params
class BulletedSummaryParams(LLMTaskParams):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._gen_cfg = GenerationConfig(
max_new_tokens=800,
num_beams=1,
do_sample=True,
temperature=0.2,
early_stopping=True,
)
self._instruct = """
Given a meeting transcript, extract the key things discussed in the
form of a list.
While generating the response, follow the constraints mentioned below.
Summary constraints:
i) Do not add new content, except to fix spelling or punctuation.
ii) Do not add any prefixes or numbering in the response.
iii) The summarization should be as information dense as possible.
iv) Do not add any additional sections like Note, Conclusion, etc. in
the response.
Response format:
i) The response should be in the form of a bulleted list.
ii) Iteratively merge all the relevant paragraphs together to keep the
number of paragraphs to a minimum.
iii) Remove any unfinished sentences from the final response.
iv) Do not include narrative or reporting clauses.
v) Use "*" as the bullet icon.
"""
self._task_params = TaskParams(
instruct=self._instruct, gen_schema=None, gen_cfg=self._gen_cfg
)
def _get_task_params(self) -> TaskParams:
"""
Return the parameters associated with a specific LLM task
"""
return self._task_params
class MergedSummaryParams(LLMTaskParams):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._gen_cfg = GenerationConfig(
max_new_tokens=600,
num_beams=1,
do_sample=True,
temperature=0.2,
early_stopping=True,
)
self._instruct = """
Given the key points of a meeting, summarize the points to describe the
meeting in the form of paragraphs.
"""
self._task_params = TaskParams(
instruct=self._instruct, gen_schema=None, gen_cfg=self._gen_cfg
)
def _get_task_params(self) -> TaskParams:
"""
Return the parameters associated with a specific LLM task
"""
return self._task_params
LLMTaskParams.register("topic", TopicParams)
LLMTaskParams.register("final_title", FinalTitleParams)
LLMTaskParams.register("final_short_summary", FinalShortSummaryParams)
LLMTaskParams.register("final_long_summary", FinalLongSummaryParams)
LLMTaskParams.register("bullet_summary", BulletedSummaryParams)
LLMTaskParams.register("merged_summary", MergedSummaryParams)
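For orientation, a short sketch of the two-stage flow that the two new registrations enable; it mirrors TranscriptFinalLongSummaryProcessor later in this diff, and run_task is a hypothetical stand-in for the chunk/prompt/generate loop used there.

from reflector.llm import LLMTaskParams

async def two_stage_summary(transcript: str, run_task) -> str:
    # Stage 1: condense the transcript into an information-dense bulleted list.
    bullet_params = LLMTaskParams.get_instance("bullet_summary").task_params
    bullets = await run_task(transcript, bullet_params)
    # Stage 2: merge the bullet points into flowing paragraphs.
    merged_params = LLMTaskParams.get_instance("merged_summary").task_params
    return await run_task(bullets, merged_params)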

View File

@@ -1,3 +1,4 @@
import nltk
from reflector.llm import LLM, LLMTaskParams
from reflector.processors.base import Processor
from reflector.processors.types import FinalLongSummary, TitleSummary
@@ -10,36 +11,58 @@ class TranscriptFinalLongSummaryProcessor(Processor):
INPUT_TYPE = TitleSummary
OUTPUT_TYPE = FinalLongSummary
TASK = "final_long_summary"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.chunks: list[TitleSummary] = []
-self.llm = LLM.get_instance()
self.params = LLMTaskParams.get_instance(self.TASK).task_params
+self.llm = LLM.get_instance(model_name="HuggingFaceH4/zephyr-7b-alpha")
async def _push(self, data: TitleSummary):
self.chunks.append(data)
async def get_bullet_summary(self, text: str) -> str:
params = LLMTaskParams.get_instance("bullet_summary").task_params
chunks = list(self.llm.split_corpus(corpus=text, task_params=params))
bullet_summary = ""
for chunk in chunks:
prompt = self.llm.create_prompt(instruct=params.instruct, text=chunk)
summary_result = await self.llm.generate(
prompt=prompt,
gen_schema=params.gen_schema,
gen_cfg=params.gen_cfg,
logger=self.logger,
)
bullet_summary += summary_result["long_summary"]
return bullet_summary
async def get_merged_summary(self, text: str) -> str:
params = LLMTaskParams.get_instance("merged_summary").task_params
chunks = list(self.llm.split_corpus(corpus=text, task_params=params))
merged_summary = ""
for chunk in chunks:
prompt = self.llm.create_prompt(instruct=params.instruct, text=chunk)
summary_result = await self.llm.generate(
prompt=prompt,
gen_schema=params.gen_schema,
gen_cfg=params.gen_cfg,
logger=self.logger,
)
merged_summary += summary_result["long_summary"]
return merged_summary
async def get_long_summary(self, text: str) -> str:
"""
Generate a long version of the final summary
"""
self.logger.info(f"Smoothing out {len(text)} length summary to a long summary")
-chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))
+bullet_summary = await self.get_bullet_summary(text)
+merged_summary = await self.get_merged_summary(bullet_summary)
-accumulated_summaries = ""
-for chunk in chunks:
-prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
-summary_result = await self.llm.generate(
-prompt=prompt,
-gen_schema=self.params.gen_schema,
-gen_cfg=self.params.gen_cfg,
-logger=self.logger,
-)
-accumulated_summaries += summary_result["long_summary"]
+return merged_summary
-return accumulated_summaries
def sentence_tokenize(self, text: str) -> list[str]:
return nltk.sent_tokenize(text)
async def _flush(self):
if not self.chunks:
@@ -49,11 +72,25 @@ class TranscriptFinalLongSummaryProcessor(Processor):
accumulated_summaries = " ".join([chunk.summary for chunk in self.chunks])
long_summary = await self.get_long_summary(accumulated_summaries)
# Format the output as much as possible to be handled
# by front-end for displaying
summary_sentences = []
for sentence in self.sentence_tokenize(long_summary):
sentence = str(sentence).strip()
if sentence.startswith("- "):
sentence = sentence.replace("- ", "* ")
else:
sentence = "* " + sentence
sentence += " \n"
summary_sentences.append(sentence)
formatted_long_summary = "".join(summary_sentences)
last_chunk = self.chunks[-1]
duration = last_chunk.timestamp + last_chunk.duration
final_long_summary = FinalLongSummary(
-long_summary=long_summary,
+long_summary=formatted_long_summary,
duration=duration,
)
await self.emit(final_long_summary)
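To illustrate the markdown formatting added in _flush above, here is a standalone sketch with made-up input sentences; the real code first splits the merged summary with nltk.sent_tokenize.

sentences = ["- Discussed the Q3 roadmap.", "Agreed to ship the new summary feature."]

formatted = ""
for sentence in sentences:
    sentence = sentence.strip()
    if sentence.startswith("- "):
        # Normalize existing dash bullets to the "*" bullet the front-end expects.
        sentence = sentence.replace("- ", "* ")
    else:
        sentence = "* " + sentence
    formatted += sentence + " \n"

# formatted == "* Discussed the Q3 roadmap. \n* Agreed to ship the new summary feature. \n"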

View File

@@ -72,6 +72,7 @@ class Settings(BaseSettings):
LLM_TIMEOUT: int = 60 * 5 # take cold start into account
LLM_MAX_TOKENS: int = 1024
LLM_TEMPERATURE: float = 0.7
ZEPHYR_LLM_URL: str | None = None
# LLM Banana configuration
LLM_BANANA_API_KEY: str | None = None
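Since Settings is a pydantic BaseSettings class, the new field is typically supplied via the environment. A hedged sketch, with a placeholder URL and an assumed import path:

import os

# Placeholder URL for the Modal deployment of reflector-llm-zephyr.
os.environ["ZEPHYR_LLM_URL"] = "https://example--reflector-llm-zephyr-web.modal.run"

from reflector.settings import Settings  # import path assumed for illustration

settings = Settings()
assert settings.ZEPHYR_LLM_URL is not None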

View File

@@ -93,3 +93,12 @@ def ensure_casing():
with patch("reflector.llm.base.LLM.ensure_casing") as mock_casing:
mock_casing.return_value = "LLM TITLE"
yield
@pytest.fixture
def sentence_tokenize():
with patch(
"reflector.processors.TranscriptFinalLongSummaryProcessor" ".sentence_tokenize"
) as mock_sent_tokenize:
mock_sent_tokenize.return_value = ["LLM LONG SUMMARY"]
yield

View File

@@ -9,6 +9,7 @@ async def test_basic_process(
dummy_llm,
dummy_processors,
ensure_casing,
sentence_tokenize,
):
# goal is to start the server, and send rtc audio to it
# validate the events received

View File

@@ -60,6 +60,7 @@ async def test_transcript_rtc_and_websocket(
dummy_processors,
ensure_casing,
appserver,
sentence_tokenize,
):
# goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread
@@ -156,7 +157,7 @@ async def test_transcript_rtc_and_websocket(
assert "FINAL_LONG_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
-assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"
+assert ev["data"]["long_summary"] == "* LLM LONG SUMMARY \n"
assert "FINAL_SHORT_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
@@ -193,6 +194,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
dummy_processors,
ensure_casing,
appserver,
sentence_tokenize,
):
# goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread
@@ -292,7 +294,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
assert "FINAL_LONG_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
-assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"
+assert ev["data"]["long_summary"] == "* LLM LONG SUMMARY \n"
assert "FINAL_SHORT_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]