mirror of https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
tests: rework tests and fix bugs along the way
.github/workflows/test_server.yml (vendored): 16 changed lines
@@ -39,7 +39,7 @@ jobs:
           cd server
           poetry run python -m pytest -v tests

-  pep8:
+  formatting:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3
@@ -53,6 +53,20 @@ jobs:
           cd server
           black --check reflector tests

+  linting:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python 3.x
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.11
+      - name: Validate formatting
+        run: |
+          pip install ruff
+          cd server
+          ruff reflector tests
+
   docker:
     runs-on: ubuntu-latest
     steps:

@@ -2,6 +2,7 @@ from reflector.logger import logger
 from reflector.settings import settings
 import asyncio
 import json
+import re


 class LLM:
@@ -48,11 +49,15 @@ class LLM:
     def _parse_json(self, result: str) -> dict:
         result = result.strip()
         # try detecting code block if exist
-        if result.startswith("```json\n") and result.endswith("```"):
-            result = result[8:-3]
-        elif result.startswith("```\n") and result.endswith("```"):
-            result = result[4:-3]
-        print(">>>", result)
+        # starts with ```json\n, ends with ```
+        # or starts with ```\n, ends with ```
+        # or starts with \n```javascript\n, ends with ```
+        regex = r"```(json|javascript|)?(.*)```"
+        matches = re.findall(regex, result.strip(), re.MULTILINE | re.DOTALL)
+        if not matches:
+            return result
+
+        # we have a match, try to parse it
+        result = matches[0][1]
         return json.loads(result.strip())
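
The fence-stripping behaviour introduced above can be exercised on its own. The sketch below mirrors the regex from the hunk; the helper name parse_llm_json and the sample strings are illustrative, not part of the commit.

import json
import re

# standalone sketch of the fence-stripping logic; parse_llm_json is an
# illustrative name, not a reflector API
FENCE_RE = r"```(json|javascript|)?(.*)```"

def parse_llm_json(result: str):
    result = result.strip()
    matches = re.findall(FENCE_RE, result, re.MULTILINE | re.DOTALL)
    if not matches:
        # no fenced block found: the raw string comes back unchanged,
        # matching the behaviour in the diff
        return result
    # the second capture group of the first match holds the fenced body
    return json.loads(matches[0][1].strip())

print(parse_llm_json('```json\n{"title": "Weekly sync"}\n```'))  # {'title': 'Weekly sync'}
print(parse_llm_json('```\n{"title": "Weekly sync"}\n```'))      # bare fences also match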

@@ -1,7 +1,6 @@
 from reflector.llm.base import LLM
 from reflector.logger import logger
 from reflector.settings import settings
-import json
 import httpx


@@ -13,8 +13,8 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
     BACKEND_DEFAULT = "whisper"

     def __init__(self, backend=None, **kwargs):
-        super().__init__(**kwargs)
         self.processor = self.BACKENDS[backend or self.BACKEND_DEFAULT]()
+        super().__init__(**kwargs)

     def connect(self, processor: Processor):
         self.processor.connect(processor)

@@ -32,7 +32,11 @@ class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor):
             transcript.text += segment.text
             for word in segment.words:
                 transcript.words.append(
-                    Word(text=word.word, start=ts + word.start, end=ts + word.end)
+                    Word(
+                        text=word.word,
+                        start=round(ts + word.start, 3),
+                        end=round(ts + word.end, 3),
+                    )
                 )

         return transcript
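
The whisper change above wraps the absolute word timestamps in round(..., 3), so stored start/end values keep millisecond precision instead of the float noise produced by summing the chunk timestamp and the per-word offset. A tiny illustration; the Word dataclass and the numbers here are stand-ins, not reflector's model:

from dataclasses import dataclass

@dataclass
class Word:              # stand-in for reflector's Word model
    text: str
    start: float
    end: float

ts, offset = 1.1, 2.2            # chunk timestamp + per-word offset (made-up values)
print(ts + offset)               # 3.3000000000000003  (raw float sum)
print(round(ts + offset, 3))     # 3.3                 (what the new code stores)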

@@ -28,4 +28,3 @@ class TranscriptSummarizerProcessor(Processor):
             return
         self.logger.info(f"Writing to {self.filename}")
         await self.emit(self.filename)
-

@@ -1,4 +1,3 @@
-from pathlib import Path
 import av
 from reflector.logger import logger
 from reflector.processors import (
@@ -8,11 +7,48 @@ from reflector.processors import (
     AudioTranscriptAutoProcessor,
     TranscriptLinerProcessor,
     TranscriptTopicDetectorProcessor,
-    TranscriptSummarizerProcessor,
+    # TranscriptSummarizerProcessor,
 )
 import asyncio


+async def process_audio_file(filename, event_callback):
+    async def on_transcript(data):
+        await event_callback("transcript", data)
+
+    async def on_topic(data):
+        await event_callback("topic", data)
+
+    async def on_summary(data):
+        await event_callback("summary", data)
+
+    # transcription output
+    pipeline = Pipeline(
+        AudioChunkerProcessor(),
+        AudioMergeProcessor(),
+        AudioTranscriptAutoProcessor.as_threaded(),
+        TranscriptLinerProcessor(callback=on_transcript),
+        TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
+        # TranscriptSummarizerProcessor.as_threaded(
+        #     callback=on_summary
+        # ),
+    )
+    pipeline.describe()
+
+    # start processing audio
+    logger.info(f"Opening {filename}")
+    container = av.open(filename)
+    try:
+        logger.info("Start pushing audio into the pipeline")
+        for frame in container.decode(audio=0):
+            await pipeline.push(frame)
+    finally:
+        logger.info("Flushing the pipeline")
+        await pipeline.flush()
+
+    logger.info("All done !")
+
+
 if __name__ == "__main__":
     import argparse

@@ -20,42 +56,12 @@ if __name__ == "__main__":
     parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
     args = parser.parse_args()

-    async def main():
-        async def on_transcript(transcript):
-            print(f"Transcript: [{transcript.human_timestamp}]: {transcript.text}")
-
-        async def on_summary(summary):
-            print(f"Summary: {summary.title} - {summary.summary}")
-
-        async def on_final_summary(path):
-            print(f"Final Summary: {path}")
-
-        # transcription output
-        result_fn = Path(args.source).with_suffix(".jsonl")
-
-        pipeline = Pipeline(
-            AudioChunkerProcessor(),
-            AudioMergeProcessor(),
-            AudioTranscriptAutoProcessor.as_threaded(),
-            TranscriptLinerProcessor(callback=on_transcript),
-            TranscriptTopicDetectorProcessor.as_threaded(callback=on_summary),
-            TranscriptSummarizerProcessor.as_threaded(
-                filename=result_fn, callback=on_final_summary
-            ),
-        )
-        pipeline.describe()
-
-        # start processing audio
-        logger.info(f"Opening {args.source}")
-        container = av.open(args.source)
-        try:
-            logger.info("Start pushing audio into the pipeline")
-            for frame in container.decode(audio=0):
-                await pipeline.push(frame)
-        finally:
-            logger.info("Flushing the pipeline")
-            await pipeline.flush()
-
-        logger.info("All done !")
-
-    asyncio.run(main())
+    async def event_callback(event, data):
+        if event == "transcript":
+            print(f"Transcript[{data.human_timestamp}]: {data.text}")
+        elif event == "topic":
+            print(f"Topic: {data}")
+        elif event == "summary":
+            print(f"Summary: {data}")
+
+    asyncio.run(process_audio_file(args.source, event_callback))
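
With the refactor above, the file-processing path is importable as reflector.tools.process.process_audio_file, the same entry point the new test uses. A minimal usage sketch assuming only that import; the wav filename and the event-collecting callback are illustrative:

import asyncio
from reflector.tools.process import process_audio_file

async def main():
    events = []

    async def event_callback(event, data):
        # collect (event, payload) pairs instead of printing them
        events.append((event, data))

    await process_audio_file("meeting.wav", event_callback)  # hypothetical input file
    print(f"received {len(events)} events")

asyncio.run(main())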

@@ -73,8 +73,8 @@ async def rtc_offer(params: RtcOffer, request: Request):
     ctx.pipeline = Pipeline(
         AudioChunkerProcessor(),
         AudioMergeProcessor(),
-        AudioTranscriptAutoProcessor.as_threaded(),
-        TranscriptLinerProcessor(callback=on_transcript),
+        AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
+        TranscriptLinerProcessor(),
         TranscriptTopicDetectorProcessor.as_threaded(callback=on_summary),
         # FinalSummaryProcessor.as_threaded(
         #     filename=result_fn, callback=on_final_summary

server/tests/test_processors_pipeline.py (new file): 45 lines
@@ -0,0 +1,45 @@
+import pytest
+from unittest.mock import patch
+
+
+@pytest.mark.asyncio
+async def test_basic_process(event_loop):
+    # goal is to start the server, and send rtc audio to it
+    # validate the events received
+    from reflector.tools.process import process_audio_file
+    from reflector.settings import settings
+    from reflector.llm.base import LLM
+    from pathlib import Path
+
+    # use an LLM test backend
+    settings.LLM_BACKEND = "test"
+
+    class LLMTest(LLM):
+        async def _generate(self, prompt: str, **kwargs) -> str:
+            return {
+                "title": "TITLE",
+                "summary": "SUMMARY",
+            }
+
+    LLM.register("test", LLMTest)
+
+    # event callback
+    marks = {
+        "transcript": 0,
+        "topic": 0,
+        # "summary": 0,
+    }
+
+    async def event_callback(event, data):
+        print(f"{event}: {data}")
+        marks[event] += 1
+
+    # invoke the process and capture events
+    path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
+    await process_audio_file(path.as_posix(), event_callback)
+    print(marks)
+
+    # validate the events
+    assert marks["transcript"] == 5
+    assert marks["topic"] == 4
+    # assert marks["summary"] == 1
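
The test swaps in a fake LLM backend through settings.LLM_BACKEND = "test" and LLM.register("test", LLMTest). Reflector's actual registry is not shown in this diff; the sketch below is only one plausible shape consistent with those two calls, with the _backends dict and get_instance factory being hypothetical names.

class LLM:
    _backends: dict = {}          # hypothetical registry storage

    @classmethod
    def register(cls, name, backend):
        cls._backends[name] = backend

    @classmethod
    def get_instance(cls, name):
        # hypothetical factory; reflector's real lookup may differ
        return cls._backends[name]()

    async def _generate(self, prompt: str, **kwargs) -> str:
        raise NotImplementedError


class LLMTest(LLM):
    async def _generate(self, prompt: str, **kwargs) -> str:
        return {"title": "TITLE", "summary": "SUMMARY"}


LLM.register("test", LLMTest)
assert isinstance(LLM.get_instance("test"), LLMTest)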