tests: rework tests and fix bugs along the way

This commit is contained in:
Mathieu Virbel
2023-08-01 16:05:48 +02:00
parent bc55cfdea3
commit 1f8e4200fd
9 changed files with 126 additions and 54 deletions

View File

@@ -39,7 +39,7 @@ jobs:
cd server cd server
poetry run python -m pytest -v tests poetry run python -m pytest -v tests
pep8: formatting:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
@@ -53,6 +53,20 @@ jobs:
cd server cd server
black --check reflector tests black --check reflector tests
linting:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.x
uses: actions/setup-python@v4
with:
python-version: 3.11
- name: Validate formatting
run: |
pip install ruff
cd server
ruff reflector tests
docker: docker:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:

View File

@@ -2,6 +2,7 @@ from reflector.logger import logger
from reflector.settings import settings from reflector.settings import settings
import asyncio import asyncio
import json import json
import re
class LLM: class LLM:
@@ -48,11 +49,15 @@ class LLM:
def _parse_json(self, result: str) -> dict: def _parse_json(self, result: str) -> dict:
result = result.strip() result = result.strip()
# try detecting code block if exist # try detecting code block if exist
if result.startswith("```json\n") and result.endswith("```"): # starts with ```json\n, ends with ```
result = result[8:-3] # or starts with ```\n, ends with ```
elif result.startswith("```\n") and result.endswith("```"): # or starts with \n```javascript\n, ends with ```
result = result[4:-3]
print(">>>", result) regex = r"```(json|javascript|)?(.*)```"
matches = re.findall(regex, result.strip(), re.MULTILINE | re.DOTALL)
if not matches:
return result
# we have a match, try to parse it
result = matches[0][1]
return json.loads(result.strip()) return json.loads(result.strip())

View File

@@ -1,7 +1,6 @@
from reflector.llm.base import LLM from reflector.llm.base import LLM
from reflector.logger import logger from reflector.logger import logger
from reflector.settings import settings from reflector.settings import settings
import json
import httpx import httpx

View File

@@ -13,8 +13,8 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
BACKEND_DEFAULT = "whisper" BACKEND_DEFAULT = "whisper"
def __init__(self, backend=None, **kwargs): def __init__(self, backend=None, **kwargs):
super().__init__(**kwargs)
self.processor = self.BACKENDS[backend or self.BACKEND_DEFAULT]() self.processor = self.BACKENDS[backend or self.BACKEND_DEFAULT]()
super().__init__(**kwargs)
def connect(self, processor: Processor): def connect(self, processor: Processor):
self.processor.connect(processor) self.processor.connect(processor)

View File

@@ -32,7 +32,11 @@ class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor):
transcript.text += segment.text transcript.text += segment.text
for word in segment.words: for word in segment.words:
transcript.words.append( transcript.words.append(
Word(text=word.word, start=ts + word.start, end=ts + word.end) Word(
text=word.word,
start=round(ts + word.start, 3),
end=round(ts + word.end, 3),
)
) )
return transcript return transcript

View File

@@ -28,4 +28,3 @@ class TranscriptSummarizerProcessor(Processor):
return return
self.logger.info(f"Writing to {self.filename}") self.logger.info(f"Writing to {self.filename}")
await self.emit(self.filename) await self.emit(self.filename)

View File

@@ -1,4 +1,3 @@
from pathlib import Path
import av import av
from reflector.logger import logger from reflector.logger import logger
from reflector.processors import ( from reflector.processors import (
@@ -8,46 +7,37 @@ from reflector.processors import (
AudioTranscriptAutoProcessor, AudioTranscriptAutoProcessor,
TranscriptLinerProcessor, TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor, TranscriptTopicDetectorProcessor,
TranscriptSummarizerProcessor, # TranscriptSummarizerProcessor,
) )
import asyncio import asyncio
if __name__ == "__main__": async def process_audio_file(filename, event_callback):
import argparse async def on_transcript(data):
await event_callback("transcript", data)
parser = argparse.ArgumentParser() async def on_topic(data):
parser.add_argument("source", help="Source file (mp3, wav, mp4...)") await event_callback("topic", data)
args = parser.parse_args()
async def main(): async def on_summary(data):
async def on_transcript(transcript): await event_callback("summary", data)
print(f"Transcript: [{transcript.human_timestamp}]: {transcript.text}")
async def on_summary(summary):
print(f"Summary: {summary.title} - {summary.summary}")
async def on_final_summary(path):
print(f"Final Summary: {path}")
# transcription output # transcription output
result_fn = Path(args.source).with_suffix(".jsonl")
pipeline = Pipeline( pipeline = Pipeline(
AudioChunkerProcessor(), AudioChunkerProcessor(),
AudioMergeProcessor(), AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(), AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(callback=on_transcript), TranscriptLinerProcessor(callback=on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=on_summary), TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
TranscriptSummarizerProcessor.as_threaded( # TranscriptSummarizerProcessor.as_threaded(
filename=result_fn, callback=on_final_summary # callback=on_summary
), # ),
) )
pipeline.describe() pipeline.describe()
# start processing audio # start processing audio
logger.info(f"Opening {args.source}") logger.info(f"Opening {filename}")
container = av.open(args.source) container = av.open(filename)
try: try:
logger.info("Start pushing audio into the pipeline") logger.info("Start pushing audio into the pipeline")
for frame in container.decode(audio=0): for frame in container.decode(audio=0):
@@ -58,4 +48,20 @@ if __name__ == "__main__":
logger.info("All done !") logger.info("All done !")
asyncio.run(main())
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
args = parser.parse_args()
async def event_callback(event, data):
if event == "transcript":
print(f"Transcript[{data.human_timestamp}]: {data.text}")
elif event == "topic":
print(f"Topic: {data}")
elif event == "summary":
print(f"Summary: {data}")
asyncio.run(process_audio_file(args.source, event_callback))

View File

@@ -73,8 +73,8 @@ async def rtc_offer(params: RtcOffer, request: Request):
ctx.pipeline = Pipeline( ctx.pipeline = Pipeline(
AudioChunkerProcessor(), AudioChunkerProcessor(),
AudioMergeProcessor(), AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(), AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
TranscriptLinerProcessor(callback=on_transcript), TranscriptLinerProcessor(),
TranscriptTopicDetectorProcessor.as_threaded(callback=on_summary), TranscriptTopicDetectorProcessor.as_threaded(callback=on_summary),
# FinalSummaryProcessor.as_threaded( # FinalSummaryProcessor.as_threaded(
# filename=result_fn, callback=on_final_summary # filename=result_fn, callback=on_final_summary

View File

@@ -0,0 +1,45 @@
import pytest
from unittest.mock import patch
@pytest.mark.asyncio
async def test_basic_process(event_loop):
    """Run the audio pipeline on a recorded sample and count emitted events.

    The test calls ``process_audio_file`` directly on a WAV fixture (no
    server or RTC connection is involved), with the LLM swapped for a stub
    backend registered under the name ``"test"``. It then checks how many
    ``transcript`` and ``topic`` events were delivered to the callback.
    """
    from reflector.tools.process import process_audio_file
    from reflector.settings import settings
    from reflector.llm.base import LLM
    from pathlib import Path

    class LLMTest(LLM):
        """Stub LLM backend that returns a fixed title/summary payload."""

        async def _generate(self, prompt: str, **kwargs) -> str:
            # NOTE(review): annotated as ``str`` but a dict is returned;
            # confirm which form the caller of ``_generate`` expects.
            return {
                "title": "TITLE",
                "summary": "SUMMARY",
            }

    # Per-event counters; an event name missing from this dict makes the
    # callback raise KeyError, which surfaces unexpected event types.
    marks = {
        "transcript": 0,
        "topic": 0,
        # "summary": 0,
    }

    async def event_callback(event, data):
        print(f"{event}: {data}")
        marks[event] += 1

    # ``settings`` is a shared module-level object: remember the configured
    # backend and restore it afterwards so the stub selection does not leak
    # into tests that run later in the same process.
    previous_backend = settings.LLM_BACKEND
    settings.LLM_BACKEND = "test"
    try:
        LLM.register("test", LLMTest)

        # invoke the process and capture events
        path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
        await process_audio_file(path.as_posix(), event_callback)
    finally:
        settings.LLM_BACKEND = previous_backend

    print(marks)

    # validate the events
    assert marks["transcript"] == 5
    assert marks["topic"] == 4
    # assert marks["summary"] == 1