mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
Feature additions (#210)
* initial * add LLM features * update LLM logic * update llm functions: change control flow * add generation config * update return types * update processors and tests * update rtc_offer * revert new title processor change * fix unit tests * add comments and fix HTTP 500 * adjust prompt * test with reflector app * revert new event for final title * update * move onus onto processors * move onus onto processors * stash * add provision for gen config * dynamically pack the LLM input using context length * tune final summary params * update consolidated class structures * update consolidated class structures * update precommit * add broadcast processors * working baseline * Organize LLMParams * minor fixes * minor fixes * minor fixes * fix unit tests * fix unit tests * fix unit tests * update tests * update tests * edit pipeline response events * update summary return types * configure tests * alembic db migration * change LLM response flow * edit main llm functions * edit main llm functions * change llm name and gen cf * Update transcript_topic_detector.py * PR review comments * checkpoint before db event migration * update DB migration of past events * update DB migration of past events * edit LLM classes * Delete unwanted file * remove List typing * remove List typing * update oobabooga API call * topic enhancements * update UI event handling * move ensure_casing to llm base * update tests * update tests
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@@ -14,3 +16,50 @@ async def setup_database():
|
||||
metadata.create_all(bind=engine)
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_processors():
|
||||
with patch(
|
||||
"reflector.processors.transcript_topic_detector.TranscriptTopicDetectorProcessor.get_topic"
|
||||
) as mock_topic, patch(
|
||||
"reflector.processors.transcript_final_title.TranscriptFinalTitleProcessor.get_title"
|
||||
) as mock_title, patch(
|
||||
"reflector.processors.transcript_final_long_summary.TranscriptFinalLongSummaryProcessor.get_long_summary"
|
||||
) as mock_long_summary, patch(
|
||||
"reflector.processors.transcript_final_short_summary.TranscriptFinalShortSummaryProcessor.get_short_summary"
|
||||
) as mock_short_summary:
|
||||
mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"}
|
||||
mock_title.return_value = {"title": "LLM FINAL TITLE"}
|
||||
mock_long_summary.return_value = "LLM LONG SUMMARY"
|
||||
mock_short_summary.return_value = {"short_summary": "LLM SHORT SUMMARY"}
|
||||
|
||||
yield mock_topic, mock_title, mock_long_summary, mock_short_summary
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def dummy_llm():
|
||||
from reflector.llm.base import LLM
|
||||
|
||||
class TestLLM(LLM):
|
||||
def __init__(self):
|
||||
self.model_name = "DUMMY MODEL"
|
||||
self.llm_tokenizer = "DUMMY TOKENIZER"
|
||||
|
||||
with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
|
||||
mock_llm.return_value = TestLLM()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def nltk():
|
||||
with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
|
||||
mock_nltk.return_value = "NLTK PACKAGE"
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ensure_casing():
|
||||
with patch("reflector.llm.base.LLM.ensure_casing") as mock_casing:
|
||||
mock_casing.return_value = "LLM TITLE"
|
||||
yield
|
||||
|
||||
@@ -2,7 +2,7 @@ import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processor_broadcast():
|
||||
async def test_processor_broadcast(nltk):
|
||||
from reflector.processors.base import Processor, BroadcastProcessor, Pipeline
|
||||
|
||||
class TestProcessor(Processor):
|
||||
|
||||
@@ -2,27 +2,19 @@ import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_process(event_loop):
|
||||
async def test_basic_process(
|
||||
event_loop, nltk, dummy_llm, dummy_processors, ensure_casing
|
||||
):
|
||||
# goal is to start the server, and send rtc audio to it
|
||||
# validate the events received
|
||||
from reflector.tools.process import process_audio_file
|
||||
from reflector.settings import settings
|
||||
from reflector.llm.base import LLM
|
||||
from pathlib import Path
|
||||
|
||||
# use an LLM test backend
|
||||
settings.LLM_BACKEND = "test"
|
||||
settings.TRANSCRIPT_BACKEND = "whisper"
|
||||
|
||||
class LLMTest(LLM):
|
||||
async def _generate(self, prompt: str, schema: dict | None, **kwargs) -> str:
|
||||
return {
|
||||
"title": "TITLE",
|
||||
"summary": "SUMMARY",
|
||||
}
|
||||
|
||||
LLM.register("test", LLMTest)
|
||||
|
||||
# event callback
|
||||
marks = {}
|
||||
|
||||
@@ -39,4 +31,6 @@ async def test_basic_process(event_loop):
|
||||
# validate the events
|
||||
assert marks["TranscriptLinerProcessor"] == 5
|
||||
assert marks["TranscriptTopicDetectorProcessor"] == 1
|
||||
assert marks["TranscriptFinalSummaryProcessor"] == 1
|
||||
assert marks["TranscriptFinalLongSummaryProcessor"] == 1
|
||||
assert marks["TranscriptFinalShortSummaryProcessor"] == 1
|
||||
assert marks["TranscriptFinalTitleProcessor"] == 1
|
||||
|
||||
@@ -75,21 +75,52 @@ async def test_transcript_get_update_summary():
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
response = await ac.post("/transcripts", json={"name": "test"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["summary"] is None
|
||||
assert response.json()["long_summary"] is None
|
||||
assert response.json()["short_summary"] is None
|
||||
|
||||
tid = response.json()["id"]
|
||||
|
||||
response = await ac.get(f"/transcripts/{tid}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["summary"] is None
|
||||
assert response.json()["long_summary"] is None
|
||||
assert response.json()["short_summary"] is None
|
||||
|
||||
response = await ac.patch(f"/transcripts/{tid}", json={"summary": "test"})
|
||||
response = await ac.patch(
|
||||
f"/transcripts/{tid}",
|
||||
json={"long_summary": "test_long", "short_summary": "test_short"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.json()["summary"] == "test"
|
||||
assert response.json()["long_summary"] == "test_long"
|
||||
assert response.json()["short_summary"] == "test_short"
|
||||
|
||||
response = await ac.get(f"/transcripts/{tid}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["summary"] == "test"
|
||||
assert response.json()["long_summary"] == "test_long"
|
||||
assert response.json()["short_summary"] == "test_short"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_get_update_title():
|
||||
from reflector.app import app
|
||||
|
||||
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
|
||||
response = await ac.post("/transcripts", json={"name": "test"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["title"] is None
|
||||
|
||||
tid = response.json()["id"]
|
||||
|
||||
response = await ac.get(f"/transcripts/{tid}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["title"] is None
|
||||
|
||||
response = await ac.patch(f"/transcripts/{tid}", json={"title": "test_title"})
|
||||
assert response.status_code == 200
|
||||
assert response.json()["title"] == "test_title"
|
||||
|
||||
response = await ac.get(f"/transcripts/{tid}")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["title"] == "test_title"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -67,21 +67,10 @@ async def dummy_transcript():
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def dummy_llm():
|
||||
from reflector.llm.base import LLM
|
||||
|
||||
class TestLLM(LLM):
|
||||
async def _generate(self, prompt: str, schema: dict | None, **kwargs):
|
||||
return json.dumps({"title": "LLM TITLE", "summary": "LLM SUMMARY"})
|
||||
|
||||
with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
|
||||
mock_llm.return_value = TestLLM()
|
||||
yield
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm):
|
||||
async def test_transcript_rtc_and_websocket(
|
||||
tmpdir, dummy_llm, dummy_transcript, dummy_processors, ensure_casing
|
||||
):
|
||||
# goal: start the server, exchange RTC, receive websocket events
|
||||
# because of that, we need to start the server in a thread
|
||||
# to be able to connect with aiortc
|
||||
@@ -186,9 +175,17 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
|
||||
assert ev["data"]["transcript"].startswith("Hello world")
|
||||
assert ev["data"]["timestamp"] == 0.0
|
||||
|
||||
assert "FINAL_SUMMARY" in eventnames
|
||||
ev = events[eventnames.index("FINAL_SUMMARY")]
|
||||
assert ev["data"]["summary"] == "LLM SUMMARY"
|
||||
assert "FINAL_LONG_SUMMARY" in eventnames
|
||||
ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
|
||||
assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"
|
||||
|
||||
assert "FINAL_SHORT_SUMMARY" in eventnames
|
||||
ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
|
||||
assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"
|
||||
|
||||
assert "FINAL_TITLE" in eventnames
|
||||
ev = events[eventnames.index("FINAL_TITLE")]
|
||||
assert ev["data"]["title"] == "LLM FINAL TITLE"
|
||||
|
||||
# check status order
|
||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||
@@ -218,7 +215,9 @@ async def test_transcript_rtc_and_websocket(tmpdir, dummy_transcript, dummy_llm)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_rtc_and_websocket_and_fr(tmpdir, dummy_transcript, dummy_llm):
|
||||
async def test_transcript_rtc_and_websocket_and_fr(
|
||||
tmpdir, dummy_llm, dummy_transcript, dummy_processors, ensure_casing
|
||||
):
|
||||
# goal: start the server, exchange RTC, receive websocket events
|
||||
# because of that, we need to start the server in a thread
|
||||
# to be able to connect with aiortc
|
||||
@@ -326,9 +325,17 @@ async def test_transcript_rtc_and_websocket_and_fr(tmpdir, dummy_transcript, dum
|
||||
assert ev["data"]["transcript"].startswith("Hello world")
|
||||
assert ev["data"]["timestamp"] == 0.0
|
||||
|
||||
assert "FINAL_SUMMARY" in eventnames
|
||||
ev = events[eventnames.index("FINAL_SUMMARY")]
|
||||
assert ev["data"]["summary"] == "LLM SUMMARY"
|
||||
assert "FINAL_LONG_SUMMARY" in eventnames
|
||||
ev = events[eventnames.index("FINAL_LONG_SUMMARY")]
|
||||
assert ev["data"]["long_summary"] == "LLM LONG SUMMARY"
|
||||
|
||||
assert "FINAL_SHORT_SUMMARY" in eventnames
|
||||
ev = events[eventnames.index("FINAL_SHORT_SUMMARY")]
|
||||
assert ev["data"]["short_summary"] == "LLM SHORT SUMMARY"
|
||||
|
||||
assert "FINAL_TITLE" in eventnames
|
||||
ev = events[eventnames.index("FINAL_TITLE")]
|
||||
assert ev["data"]["title"] == "LLM FINAL TITLE"
|
||||
|
||||
# check status order
|
||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||
|
||||
Reference in New Issue
Block a user