tests: rework tests and fix bugs along the way

This commit is contained in:
Mathieu Virbel
2023-08-01 16:05:48 +02:00
parent bc55cfdea3
commit 1f8e4200fd
9 changed files with 126 additions and 54 deletions

View File

@@ -2,6 +2,7 @@ from reflector.logger import logger
from reflector.settings import settings
import asyncio
import json
import re
class LLM:
@@ -48,11 +49,15 @@ class LLM:
def _parse_json(self, result: str) -> dict:
result = result.strip()
# try detecting code block if exist
if result.startswith("```json\n") and result.endswith("```"):
result = result[8:-3]
elif result.startswith("```\n") and result.endswith("```"):
result = result[4:-3]
print(">>>", result)
# starts with ```json\n, ends with ```
# or starts with ```\n, ends with ```
# or starts with \n```javascript\n, ends with ```
regex = r"```(json|javascript|)?(.*)```"
matches = re.findall(regex, result.strip(), re.MULTILINE | re.DOTALL)
if not matches:
return result
# we have a match, try to parse it
result = matches[0][1]
return json.loads(result.strip())

View File

@@ -1,7 +1,6 @@
from reflector.llm.base import LLM
from reflector.logger import logger
from reflector.settings import settings
import json
import httpx

View File

@@ -13,8 +13,8 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
BACKEND_DEFAULT = "whisper"
def __init__(self, backend=None, **kwargs):
    """Create the backend processor and initialise the base class.

    Args:
        backend: Key into ``self.BACKENDS``; when falsy,
            ``self.BACKEND_DEFAULT`` is used instead.
        **kwargs: Forwarded unchanged to the parent initialiser.

    Raises:
        KeyError: If the selected backend is not in ``self.BACKENDS``.
    """
    # NOTE(review): the backend is instantiated before the parent
    # initialiser runs, and the parent initialiser is called exactly once.
    # Presumably the parent set-up may reach self.processor -- confirm
    # against the base class.
    self.processor = self.BACKENDS[backend or self.BACKEND_DEFAULT]()
    super().__init__(**kwargs)
def connect(self, processor: Processor):
self.processor.connect(processor)

View File

@@ -32,7 +32,11 @@ class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor):
transcript.text += segment.text
for word in segment.words:
transcript.words.append(
Word(text=word.word, start=ts + word.start, end=ts + word.end)
Word(
text=word.word,
start=round(ts + word.start, 3),
end=round(ts + word.end, 3),
)
)
return transcript

View File

@@ -28,4 +28,3 @@ class TranscriptSummarizerProcessor(Processor):
return
self.logger.info(f"Writing to {self.filename}")
await self.emit(self.filename)

View File

@@ -1,4 +1,3 @@
from pathlib import Path
import av
from reflector.logger import logger
from reflector.processors import (
@@ -8,11 +7,48 @@ from reflector.processors import (
AudioTranscriptAutoProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
TranscriptSummarizerProcessor,
# TranscriptSummarizerProcessor,
)
import asyncio
async def process_audio_file(filename, event_callback):
    """Push the audio of a media file through the processing pipeline.

    Frames decoded from audio stream 0 of ``filename`` are pushed into the
    pipeline one by one; the pipeline is flushed afterwards, even if
    decoding or pushing fails.

    Args:
        filename: Path of the media file, passed as-is to ``av.open``.
        event_callback: Async callable awaited as
            ``event_callback(event, data)``. ``event`` is ``"transcript"``
            or ``"topic"``; ``"summary"`` is wired but its processor is
            currently disabled.
    """

    async def on_transcript(data):
        await event_callback("transcript", data)

    async def on_topic(data):
        await event_callback("topic", data)

    # Kept for when the summarizer stage below is re-enabled.
    async def on_summary(data):
        await event_callback("summary", data)

    # transcription output
    pipeline = Pipeline(
        AudioChunkerProcessor(),
        AudioMergeProcessor(),
        AudioTranscriptAutoProcessor.as_threaded(),
        TranscriptLinerProcessor(callback=on_transcript),
        TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
        # TranscriptSummarizerProcessor.as_threaded(
        #     callback=on_summary
        # ),
    )
    pipeline.describe()

    # start processing audio
    logger.info(f"Opening {filename}")
    # The context manager closes the container on every exit path.
    with av.open(filename) as container:
        try:
            logger.info("Start pushing audio into the pipeline")
            for frame in container.decode(audio=0):
                await pipeline.push(frame)
        finally:
            # Flush while the container is still open so buffered audio
            # is processed even when decoding stops early.
            logger.info("Flushing the pipeline")
            await pipeline.flush()

    logger.info("All done !")
if __name__ == "__main__":
import argparse
@@ -20,42 +56,12 @@ if __name__ == "__main__":
parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
args = parser.parse_args()
async def main():
async def on_transcript(transcript):
print(f"Transcript: [{transcript.human_timestamp}]: {transcript.text}")
async def event_callback(event, data):
if event == "transcript":
print(f"Transcript[{data.human_timestamp}]: {data.text}")
elif event == "topic":
print(f"Topic: {data}")
elif event == "summary":
print(f"Summary: {data}")
async def on_summary(summary):
print(f"Summary: {summary.title} - {summary.summary}")
async def on_final_summary(path):
print(f"Final Summary: {path}")
# transcription output
result_fn = Path(args.source).with_suffix(".jsonl")
pipeline = Pipeline(
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(callback=on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=on_summary),
TranscriptSummarizerProcessor.as_threaded(
filename=result_fn, callback=on_final_summary
),
)
pipeline.describe()
# start processing audio
logger.info(f"Opening {args.source}")
container = av.open(args.source)
try:
logger.info("Start pushing audio into the pipeline")
for frame in container.decode(audio=0):
await pipeline.push(frame)
finally:
logger.info("Flushing the pipeline")
await pipeline.flush()
logger.info("All done !")
asyncio.run(main())
asyncio.run(process_audio_file(args.source, event_callback))

View File

@@ -73,8 +73,8 @@ async def rtc_offer(params: RtcOffer, request: Request):
ctx.pipeline = Pipeline(
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(callback=on_transcript),
AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
TranscriptLinerProcessor(),
TranscriptTopicDetectorProcessor.as_threaded(callback=on_summary),
# FinalSummaryProcessor.as_threaded(
# filename=result_fn, callback=on_final_summary

View File

@@ -0,0 +1,45 @@
import pytest
from unittest.mock import patch
@pytest.mark.asyncio
async def test_basic_process(event_loop):
    """Process a recorded wav file with a stub LLM and count the events.

    The file is run through ``process_audio_file`` and the number of
    ``transcript`` and ``topic`` events reported to the callback is
    checked against the expected counts for that recording.
    """
    from reflector.tools.process import process_audio_file
    from reflector.settings import settings
    from reflector.llm.base import LLM
    from pathlib import Path

    class LLMTest(LLM):
        """Stub backend returning a fixed title/summary payload."""

        async def _generate(self, prompt: str, **kwargs) -> dict:
            return {
                "title": "TITLE",
                "summary": "SUMMARY",
            }

    LLM.register("test", LLMTest)

    # event callback: count how many times each event is emitted
    marks = {
        "transcript": 0,
        "topic": 0,
        # "summary": 0,
    }

    async def event_callback(event, data):
        print(f"{event}: {data}")
        marks[event] += 1

    path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"

    # Use the stub LLM backend for this test only; restore the previous
    # value afterwards so the global settings do not leak into other tests.
    previous_backend = settings.LLM_BACKEND
    settings.LLM_BACKEND = "test"
    try:
        # invoke the process and capture events
        await process_audio_file(path.as_posix(), event_callback)
    finally:
        settings.LLM_BACKEND = previous_backend
    print(marks)

    # validate the events
    assert marks["transcript"] == 5
    assert marks["topic"] == 4
    # assert marks["summary"] == 1