Feature additions (#210)

* initial

* add LLM features

* update LLM logic

* update LLM functions: change control flow

* add generation config

* update return types

* update processors and tests

* update rtc_offer

* revert new title processor change

* fix unit tests

* add comments and fix HTTP 500

* adjust prompt

* test with reflector app

* revert new event for final title

* update

* move onus onto processors

* move onus onto processors

* stash

* add provision for gen config

* dynamically pack the LLM input using context length

* tune final summary params

* update consolidated class structures

* update consolidated class structures

* update precommit

* add broadcast processors

* working baseline

* Organize LLMParams

* minor fixes

* minor fixes

* minor fixes

* fix unit tests

* fix unit tests

* fix unit tests

* update tests

* update tests

* edit pipeline response events

* update summary return types

* configure tests

* alembic db migration

* change LLM response flow

* edit main llm functions

* edit main llm functions

* change LLM name and generation config

* Update transcript_topic_detector.py

* PR review comments

* checkpoint before db event migration

* update DB migration of past events

* update DB migration of past events

* edit LLM classes

* Delete unwanted file

* remove List typing

* remove List typing

* update oobabooga API call

* topic enhancements

* update UI event handling

* move ensure_casing to llm base

* update tests

* update tests
This commit is contained in:
projects-g
2023-09-13 11:26:08 +05:30
committed by GitHub
parent 762d7bfc3c
commit 9fe261406c
33 changed files with 1334 additions and 202 deletions

View File

@@ -1,3 +1,5 @@
from unittest.mock import patch
import pytest
@@ -14,3 +16,50 @@ async def setup_database():
metadata.create_all(bind=engine)
yield
@pytest.fixture
def dummy_processors():
with patch(
"reflector.processors.transcript_topic_detector.TranscriptTopicDetectorProcessor.get_topic"
) as mock_topic, patch(
"reflector.processors.transcript_final_title.TranscriptFinalTitleProcessor.get_title"
) as mock_title, patch(
"reflector.processors.transcript_final_long_summary.TranscriptFinalLongSummaryProcessor.get_long_summary"
) as mock_long_summary, patch(
"reflector.processors.transcript_final_short_summary.TranscriptFinalShortSummaryProcessor.get_short_summary"
) as mock_short_summary:
mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"}
mock_title.return_value = {"title": "LLM FINAL TITLE"}
mock_long_summary.return_value = "LLM LONG SUMMARY"
mock_short_summary.return_value = {"short_summary": "LLM SHORT SUMMARY"}
yield mock_topic, mock_title, mock_long_summary, mock_short_summary
@pytest.fixture
async def dummy_llm():
from reflector.llm.base import LLM
class TestLLM(LLM):
def __init__(self):
self.model_name = "DUMMY MODEL"
self.llm_tokenizer = "DUMMY TOKENIZER"
with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
mock_llm.return_value = TestLLM()
yield
@pytest.fixture
def nltk():
with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
mock_nltk.return_value = "NLTK PACKAGE"
yield
@pytest.fixture
def ensure_casing():
with patch("reflector.llm.base.LLM.ensure_casing") as mock_casing:
mock_casing.return_value = "LLM TITLE"
yield

View File

@@ -2,7 +2,7 @@ import pytest
@pytest.mark.asyncio
async def test_processor_broadcast():
async def test_processor_broadcast(nltk):
from reflector.processors.base import Processor, BroadcastProcessor, Pipeline
class TestProcessor(Processor):

View File

@@ -2,27 +2,19 @@ import pytest
@pytest.mark.asyncio
async def test_basic_process(event_loop):
async def test_basic_process(
event_loop, nltk, dummy_llm, dummy_processors, ensure_casing
):
# goal is to start the server, and send rtc audio to it
# validate the events received
from reflector.tools.process import process_audio_file
from reflector.settings import settings
from reflector.llm.base import LLM
from pathlib import Path
# use an LLM test backend
settings.LLM_BACKEND = "test"
settings.TRANSCRIPT_BACKEND = "whisper"
class LLMTest(LLM):
async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
return {
"title": "TITLE",
"summary": "SUMMARY",
}
LLM.register("test", LLMTest)
# event callback
marks = {}
@@ -39,4 +31,6 @@ async def test_basic_process(event_loop):
# validate the events
assert marks["TranscriptLinerProcessor"] == 5
assert marks["TranscriptTopicDetectorProcessor"] == 1
assert marks["TranscriptFinalSummaryProcessor"] == 1
assert marks["TranscriptFinalLongSummaryProcessor"] == 1
assert marks["TranscriptFinalShortSummaryProcessor"] == 1
assert marks["TranscriptFinalTitleProcessor"] == 1

View File

@@ -75,21 +75,52 @@ async def test_transcript_get_update_summary():
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["summary"] is None
assert response.json()["long_summary"] is None
assert response.json()["short_summary"] is None
tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["summary"] is None
assert response.json()["long_summary"] is None
assert response.json()["short_summary"] is None
response = await ac.patch(f"/transcripts/{tid}", json={"summary": "test"})
response = await ac.patch(
f"/transcripts/{tid}",
json={"long_summary": "test_long", "short_summary": "test_short"},
)
assert response.status_code == 200
assert response.json()["summary"] == "test"
assert response.json()["long_summary"] == "test_long"
assert response.json()["short_summary"] == "test_short"
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["summary"] == "test"
assert response.json()["long_summary"] == "test_long"
assert response.json()["short_summary"] == "test_short"
@pytest.mark.asyncio
async def test_transcript_get_update_title():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["title"] is None
tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["title"] is None
response = await ac.patch(f"/transcripts/{tid}", json={"title": "test_title"})
assert response.status_code == 200
assert response.json()["title"] == "test_title"
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["title"] == "test_title"
@pytest.mark.asyncio

View File

@@ -67,21 +67,10 @@ async def dummy_transcript():
yield
@pytest.fixture
async def dummy_llm():
from reflector.llm.base import LLM
class TestLLM(LLM):
async def _generate(self, prompt: str, schema: dict | None, **kwargs):
return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})
with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
mock_llm.return_value = TestLLM()
yield
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm):
async def test_transcript_rtc_and_websocket(
tmpdir, dummy_llm, dummy_transcript, dummy_processors, ensure_casing
):
# goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread
# to be able to connect with aiortc
@@ -186,9 +175,17 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
assert ev["data"]["transcript"].startswith("Hello world")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SUMMARY")]
assert ev["data"]["summary"] == "LLM SUMMARY"
assert "FINAL_LONG_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"
assert "FINAL_SHORT_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"
assert "FINAL_TITLE" in eventnames
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "LLM FINAL TITLE"
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
@@ -218,7 +215,9 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket_and_fr(tmpdir, dummy_transcript, dummy_llm):
async def test_transcript_rtc_and_websocket_and_fr(
tmpdir, dummy_llm, dummy_transcript, dummy_processors, ensure_casing
):
# goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread
# to be able to connect with aiortc
@@ -326,9 +325,17 @@ async def test_transcript_rtc_and_websocket_and_fr(tmpdir, dummy_transcript, dum
assert ev["data"]["transcript"].startswith("Hello world")
assert ev["data"]["timestamp"] == 0.0
assert "FINAL_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SUMMARY")]
assert ev["data"]["summary"] == "LLM SUMMARY"
assert "FINAL_LONG_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"
assert "FINAL_SHORT_SUMMARY" in eventnames
ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"
assert "FINAL_TITLE" in eventnames
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "LLM FINAL TITLE"
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]