mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
feat: use llamaindex everywhere (#525)
* feat: use llamaindex for transcript final title too * refactor: removed llm backend, replaced with one single class+llamaindex * refactor: self-review * fix: typing * fix: tests * refactor: extract clean_title and add tests * test: fix * test: remove ensure_casing/nltk * fix: tiny mistake
This commit is contained in:
@@ -37,8 +37,12 @@ def dummy_processors():
|
||||
"reflector.processors.transcript_translator.TranscriptTranslatorProcessor.get_translation"
|
||||
) as mock_translate,
|
||||
):
|
||||
mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"}
|
||||
mock_title.return_value = {"title": "LLM TITLE"}
|
||||
from reflector.processors.transcript_topic_detector import TopicResponse
|
||||
|
||||
mock_topic.return_value = TopicResponse(
|
||||
title="LLM TITLE", summary="LLM SUMMARY"
|
||||
)
|
||||
mock_title.return_value = "LLM Title"
|
||||
mock_long_summary.return_value = "LLM LONG SUMMARY"
|
||||
mock_short_summary.return_value = "LLM SHORT SUMMARY"
|
||||
mock_translate.return_value = "Bonjour le monde"
|
||||
@@ -103,14 +107,15 @@ async def dummy_diarization():
|
||||
|
||||
@pytest.fixture
|
||||
async def dummy_llm():
|
||||
from reflector.llm.base import LLM
|
||||
from reflector.llm import LLM
|
||||
|
||||
class TestLLM(LLM):
|
||||
def __init__(self):
|
||||
self.model_name = "DUMMY MODEL"
|
||||
self.llm_tokenizer = "DUMMY TOKENIZER"
|
||||
|
||||
with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
|
||||
# LLM doesn't have get_instance anymore, mocking constructor instead
|
||||
with patch("reflector.llm.LLM") as mock_llm:
|
||||
mock_llm.return_value = TestLLM()
|
||||
yield
|
||||
|
||||
@@ -129,22 +134,19 @@ async def dummy_storage():
|
||||
async def _get_file_url(self, *args, **kwargs):
|
||||
return "http://fake_server/audio.mp3"
|
||||
|
||||
with patch("reflector.storage.base.Storage.get_instance") as mock_storage:
|
||||
mock_storage.return_value = DummyStorage()
|
||||
yield
|
||||
async def _get_file(self, *args, **kwargs):
|
||||
from pathlib import Path
|
||||
|
||||
test_mp3 = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
|
||||
return test_mp3.read_bytes()
|
||||
|
||||
@pytest.fixture
|
||||
def nltk():
|
||||
with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
|
||||
mock_nltk.return_value = "NLTK PACKAGE"
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ensure_casing():
|
||||
with patch("reflector.llm.base.LLM.ensure_casing") as mock_casing:
|
||||
mock_casing.return_value = "LLM TITLE"
|
||||
dummy = DummyStorage()
|
||||
with (
|
||||
patch("reflector.storage.base.Storage.get_instance") as mock_storage,
|
||||
patch("reflector.storage.get_transcripts_storage") as mock_get_transcripts,
|
||||
):
|
||||
mock_storage.return_value = dummy
|
||||
mock_get_transcripts.return_value = dummy
|
||||
yield
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_processor_broadcast(nltk):
|
||||
async def test_processor_broadcast():
|
||||
from reflector.processors.base import BroadcastProcessor, Pipeline, Processor
|
||||
|
||||
class TestProcessor(Processor):
|
||||
|
||||
@@ -3,11 +3,9 @@ import pytest
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_process(
|
||||
nltk,
|
||||
dummy_transcript,
|
||||
dummy_llm,
|
||||
dummy_processors,
|
||||
ensure_casing,
|
||||
):
|
||||
# goal is to start the server, and send rtc audio to it
|
||||
# validate the events received
|
||||
@@ -16,8 +14,8 @@ async def test_basic_process(
|
||||
from reflector.settings import settings
|
||||
from reflector.tools.process import process_audio_file
|
||||
|
||||
# use an LLM test backend
|
||||
settings.LLM_BACKEND = "test"
|
||||
# LLM_BACKEND no longer exists in settings
|
||||
# settings.LLM_BACKEND = "test"
|
||||
settings.TRANSCRIPT_BACKEND = "whisper"
|
||||
|
||||
# event callback
|
||||
|
||||
@@ -10,7 +10,6 @@ from httpx import AsyncClient
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_process(
|
||||
tmpdir,
|
||||
ensure_casing,
|
||||
dummy_llm,
|
||||
dummy_processors,
|
||||
dummy_diarization,
|
||||
@@ -69,7 +68,7 @@ async def test_transcript_process(
|
||||
transcript = resp.json()
|
||||
assert transcript["status"] == "ended"
|
||||
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
|
||||
assert transcript["title"] == "LLM TITLE"
|
||||
assert transcript["title"] == "Llm Title"
|
||||
|
||||
# check topics and transcript
|
||||
response = await ac.get(f"/transcripts/{tid}/topics")
|
||||
|
||||
@@ -69,8 +69,6 @@ async def test_transcript_rtc_and_websocket(
|
||||
dummy_diarization,
|
||||
dummy_storage,
|
||||
fake_mp3_upload,
|
||||
ensure_casing,
|
||||
nltk,
|
||||
appserver,
|
||||
):
|
||||
# goal: start the server, exchange RTC, receive websocket events
|
||||
@@ -185,7 +183,7 @@ async def test_transcript_rtc_and_websocket(
|
||||
|
||||
assert "FINAL_TITLE" in eventnames
|
||||
ev = events[eventnames.index("FINAL_TITLE")]
|
||||
assert ev["data"]["title"] == "LLM TITLE"
|
||||
assert ev["data"]["title"] == "Llm Title"
|
||||
|
||||
assert "WAVEFORM" in eventnames
|
||||
ev = events[eventnames.index("WAVEFORM")]
|
||||
@@ -228,8 +226,6 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
||||
dummy_diarization,
|
||||
dummy_storage,
|
||||
fake_mp3_upload,
|
||||
ensure_casing,
|
||||
nltk,
|
||||
appserver,
|
||||
):
|
||||
# goal: start the server, exchange RTC, receive websocket events
|
||||
@@ -353,7 +349,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
||||
|
||||
assert "FINAL_TITLE" in eventnames
|
||||
ev = events[eventnames.index("FINAL_TITLE")]
|
||||
assert ev["data"]["title"] == "LLM TITLE"
|
||||
assert ev["data"]["title"] == "Llm Title"
|
||||
|
||||
# check status order
|
||||
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
|
||||
|
||||
@@ -10,7 +10,6 @@ from httpx import AsyncClient
|
||||
@pytest.mark.asyncio
|
||||
async def test_transcript_upload_file(
|
||||
tmpdir,
|
||||
ensure_casing,
|
||||
dummy_llm,
|
||||
dummy_processors,
|
||||
dummy_diarization,
|
||||
@@ -53,7 +52,7 @@ async def test_transcript_upload_file(
|
||||
transcript = resp.json()
|
||||
assert transcript["status"] == "ended"
|
||||
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
|
||||
assert transcript["title"] == "LLM TITLE"
|
||||
assert transcript["title"] == "Llm Title"
|
||||
|
||||
# check topics and transcript
|
||||
response = await ac.get(f"/transcripts/{tid}/topics")
|
||||
|
||||
21
server/tests/test_utils_text.py
Normal file
21
server/tests/test_utils_text.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import pytest
|
||||
|
||||
from reflector.utils.text import clean_title
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_title,expected",
|
||||
[
|
||||
("hello world", "Hello World"),
|
||||
("HELLO WORLD", "Hello World"),
|
||||
("hello WORLD", "Hello World"),
|
||||
("the quick brown fox", "The Quick Brown fox"),
|
||||
("discussion about API design", "Discussion About api Design"),
|
||||
("Q1 2024 budget review", "Q1 2024 Budget Review"),
|
||||
("'Title with quotes'", "Title With Quotes"),
|
||||
("'title with quotes'", "Title With Quotes"),
|
||||
("MiXeD CaSe WoRdS", "Mixed Case Words"),
|
||||
],
|
||||
)
|
||||
def test_clean_title(input_title, expected):
|
||||
assert clean_title(input_title) == expected
|
||||
Reference in New Issue
Block a user