feat: use llamaindex everywhere (#525)

* feat: use llamaindex for transcript final title too

* refactor: removed llm backend, replaced with one single class+llamaindex

* refactor: self-review

* fix: typing

* fix: tests

* refactor: extract clean_title and add tests

* test: fix

* test: remove ensure_casing/nltk

* fix: tiny mistake
This commit is contained in:
2025-08-01 12:13:00 -06:00
committed by GitHub
parent 1878834ce6
commit 28ac031ff6
25 changed files with 284 additions and 1539 deletions

View File

@@ -37,8 +37,12 @@ def dummy_processors():
"reflector.processors.transcript_translator.TranscriptTranslatorProcessor.get_translation"
) as mock_translate,
):
mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"}
mock_title.return_value = {"title": "LLM TITLE"}
from reflector.processors.transcript_topic_detector import TopicResponse
mock_topic.return_value = TopicResponse(
title="LLM TITLE", summary="LLM SUMMARY"
)
mock_title.return_value = "LLM Title"
mock_long_summary.return_value = "LLM LONG SUMMARY"
mock_short_summary.return_value = "LLM SHORT SUMMARY"
mock_translate.return_value = "Bonjour le monde"
@@ -103,14 +107,15 @@ async def dummy_diarization():
@pytest.fixture
async def dummy_llm():
from reflector.llm.base import LLM
from reflector.llm import LLM
class TestLLM(LLM):
def __init__(self):
self.model_name = "DUMMY MODEL"
self.llm_tokenizer = "DUMMY TOKENIZER"
with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
# LLM doesn't have get_instance anymore, mocking constructor instead
with patch("reflector.llm.LLM") as mock_llm:
mock_llm.return_value = TestLLM()
yield
@@ -129,22 +134,19 @@ async def dummy_storage():
async def _get_file_url(self, *args, **kwargs):
return "http://fake_server/audio.mp3"
with patch("reflector.storage.base.Storage.get_instance") as mock_storage:
mock_storage.return_value = DummyStorage()
yield
async def _get_file(self, *args, **kwargs):
from pathlib import Path
test_mp3 = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
return test_mp3.read_bytes()
@pytest.fixture
def nltk():
with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
mock_nltk.return_value = "NLTK PACKAGE"
yield
@pytest.fixture
def ensure_casing():
with patch("reflector.llm.base.LLM.ensure_casing") as mock_casing:
mock_casing.return_value = "LLM TITLE"
dummy = DummyStorage()
with (
patch("reflector.storage.base.Storage.get_instance") as mock_storage,
patch("reflector.storage.get_transcripts_storage") as mock_get_transcripts,
):
mock_storage.return_value = dummy
mock_get_transcripts.return_value = dummy
yield

View File

@@ -2,7 +2,7 @@ import pytest
@pytest.mark.asyncio
async def test_processor_broadcast(nltk):
async def test_processor_broadcast():
from reflector.processors.base import BroadcastProcessor, Pipeline, Processor
class TestProcessor(Processor):

View File

@@ -3,11 +3,9 @@ import pytest
@pytest.mark.asyncio
async def test_basic_process(
nltk,
dummy_transcript,
dummy_llm,
dummy_processors,
ensure_casing,
):
# goal is to start the server, and send rtc audio to it
# validate the events received
@@ -16,8 +14,8 @@ async def test_basic_process(
from reflector.settings import settings
from reflector.tools.process import process_audio_file
# use an LLM test backend
settings.LLM_BACKEND = "test"
# LLM_BACKEND no longer exists in settings
# settings.LLM_BACKEND = "test"
settings.TRANSCRIPT_BACKEND = "whisper"
# event callback

View File

@@ -10,7 +10,6 @@ from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_process(
tmpdir,
ensure_casing,
dummy_llm,
dummy_processors,
dummy_diarization,
@@ -69,7 +68,7 @@ async def test_transcript_process(
transcript = resp.json()
assert transcript["status"] == "ended"
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
assert transcript["title"] == "LLM TITLE"
assert transcript["title"] == "Llm Title"
# check topics and transcript
response = await ac.get(f"/transcripts/{tid}/topics")

View File

@@ -69,8 +69,6 @@ async def test_transcript_rtc_and_websocket(
dummy_diarization,
dummy_storage,
fake_mp3_upload,
ensure_casing,
nltk,
appserver,
):
# goal: start the server, exchange RTC, receive websocket events
@@ -185,7 +183,7 @@ async def test_transcript_rtc_and_websocket(
assert "FINAL_TITLE" in eventnames
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "LLM TITLE"
assert ev["data"]["title"] == "Llm Title"
assert "WAVEFORM" in eventnames
ev = events[eventnames.index("WAVEFORM")]
@@ -228,8 +226,6 @@ async def test_transcript_rtc_and_websocket_and_fr(
dummy_diarization,
dummy_storage,
fake_mp3_upload,
ensure_casing,
nltk,
appserver,
):
# goal: start the server, exchange RTC, receive websocket events
@@ -353,7 +349,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
assert "FINAL_TITLE" in eventnames
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "LLM TITLE"
assert ev["data"]["title"] == "Llm Title"
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]

View File

@@ -10,7 +10,6 @@ from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_upload_file(
tmpdir,
ensure_casing,
dummy_llm,
dummy_processors,
dummy_diarization,
@@ -53,7 +52,7 @@ async def test_transcript_upload_file(
transcript = resp.json()
assert transcript["status"] == "ended"
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
assert transcript["title"] == "LLM TITLE"
assert transcript["title"] == "Llm Title"
# check topics and transcript
response = await ac.get(f"/transcripts/{tid}/topics")

View File

@@ -0,0 +1,21 @@
import pytest
from reflector.utils.text import clean_title
# Table of (raw title, expected cleaned title) pairs fed to test_clean_title.
# The expected strings are pinned exactly as written, including the entries
# where a word stays lowercase ("fox", "api") and where surrounding single
# quotes are absent from the result.
# NOTE(review): the exact casing rule lives in reflector.utils.text.clean_title
# and is not visible here -- confirm against that implementation before
# adding new cases.
_CLEAN_TITLE_CASES = [
    ("hello world", "Hello World"),
    ("HELLO WORLD", "Hello World"),
    ("hello WORLD", "Hello World"),
    ("the quick brown fox", "The Quick Brown fox"),
    ("discussion about API design", "Discussion About api Design"),
    ("Q1 2024 budget review", "Q1 2024 Budget Review"),
    ("'Title with quotes'", "Title With Quotes"),
    ("'title with quotes'", "Title With Quotes"),
    ("MiXeD CaSe WoRdS", "Mixed Case Words"),
]


@pytest.mark.parametrize("input_title,expected", _CLEAN_TITLE_CASES)
def test_clean_title(input_title, expected):
    """Check that clean_title returns the expected string for one raw title.

    Args:
        input_title: Raw title passed to ``clean_title``.
        expected: Exact string ``clean_title`` must return for that input.
    """
    cleaned = clean_title(input_title)
    assert cleaned == expected