diff --git a/.github/workflows/test_server.yml b/.github/workflows/test_server.yml
index 2029a77c..2651f798 100644
--- a/.github/workflows/test_server.yml
+++ b/.github/workflows/test_server.yml
@@ -39,7 +39,7 @@ jobs:
           cd server
           poetry run python -m pytest -v tests
 
-  pep8:
+  formatting:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v3
@@ -53,6 +53,20 @@ jobs:
           cd server
           black --check reflector tests
 
+  linting:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python 3.x
+        uses: actions/setup-python@v4
+        with:
+          python-version: 3.11
+      - name: Validate formatting
+        run: |
+          pip install ruff
+          cd server
+          ruff reflector tests
+
   docker:
     runs-on: ubuntu-latest
     steps:
diff --git a/server/reflector/llm/base.py b/server/reflector/llm/base.py
index 55c0de5f..6b5dbdc9 100644
--- a/server/reflector/llm/base.py
+++ b/server/reflector/llm/base.py
@@ -2,6 +2,7 @@ from reflector.logger import logger
 from reflector.settings import settings
 import asyncio
 import json
+import re
 
 
 class LLM:
@@ -48,11 +49,15 @@ class LLM:
     def _parse_json(self, result: str) -> dict:
         result = result.strip()
         # try detecting code block if exist
-        if result.startswith("```json\n") and result.endswith("```"):
-            result = result[8:-3]
-        elif result.startswith("```\n") and result.endswith("```"):
-            result = result[4:-3]
-        print(">>>", result)
+        # starts with ```json\n, ends with ```
+        # or starts with ```\n, ends with ```
+        # or starts with \n```javascript\n, ends with ```
+
+        regex = r"```(json|javascript|)?(.*)```"
+        matches = re.findall(regex, result.strip(), re.MULTILINE | re.DOTALL)
+        if not matches:
+            return result
+
+        # we have a match, try to parse it
+        result = matches[0][1]
         return json.loads(result.strip())
-
-
diff --git a/server/reflector/llm/llm_openai.py b/server/reflector/llm/llm_openai.py
index d4c565d6..517902e9 100644
--- a/server/reflector/llm/llm_openai.py
+++ b/server/reflector/llm/llm_openai.py
@@ -1,7 +1,6 @@
 from reflector.llm.base import LLM
 from reflector.logger import logger
 from reflector.settings import settings
-import json
 import httpx
 
 
diff --git a/server/reflector/processors/audio_transcript_auto.py b/server/reflector/processors/audio_transcript_auto.py
index 0ece84f3..9b792009 100644
--- a/server/reflector/processors/audio_transcript_auto.py
+++ b/server/reflector/processors/audio_transcript_auto.py
@@ -13,8 +13,8 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
     BACKEND_DEFAULT = "whisper"
 
     def __init__(self, backend=None, **kwargs):
-        super().__init__(**kwargs)
         self.processor = self.BACKENDS[backend or self.BACKEND_DEFAULT]()
+        super().__init__(**kwargs)
 
     def connect(self, processor: Processor):
         self.processor.connect(processor)
diff --git a/server/reflector/processors/audio_transcript_whisper.py b/server/reflector/processors/audio_transcript_whisper.py
index 9a85e9bf..0b768543 100644
--- a/server/reflector/processors/audio_transcript_whisper.py
+++ b/server/reflector/processors/audio_transcript_whisper.py
@@ -32,7 +32,11 @@ class AudioTranscriptWhisperProcessor(AudioTranscriptProcessor):
             transcript.text += segment.text
             for word in segment.words:
                 transcript.words.append(
-                    Word(text=word.word, start=ts + word.start, end=ts + word.end)
+                    Word(
+                        text=word.word,
+                        start=round(ts + word.start, 3),
+                        end=round(ts + word.end, 3),
+                    )
                 )
 
         return transcript
diff --git a/server/reflector/processors/transcript_summarizer.py b/server/reflector/processors/transcript_summarizer.py
index 4e149602..e4e55e9e 100644
--- a/server/reflector/processors/transcript_summarizer.py
+++ b/server/reflector/processors/transcript_summarizer.py
@@ -28,4 +28,3 @@ class TranscriptSummarizerProcessor(Processor):
             return
         self.logger.info(f"Writing to {self.filename}")
         await self.emit(self.filename)
-
diff --git a/server/reflector/tools/process.py b/server/reflector/tools/process.py
index aefbc153..0c8611d8 100644
--- a/server/reflector/tools/process.py
+++ b/server/reflector/tools/process.py
@@ -1,4 +1,3 @@
-from pathlib import Path
 import av
 from reflector.logger import logger
 from reflector.processors import (
@@ -8,11 +7,48 @@ from reflector.processors import (
     AudioTranscriptAutoProcessor,
     TranscriptLinerProcessor,
     TranscriptTopicDetectorProcessor,
-    TranscriptSummarizerProcessor,
+    # TranscriptSummarizerProcessor,
 )
 import asyncio
 
 
+async def process_audio_file(filename, event_callback):
+    async def on_transcript(data):
+        await event_callback("transcript", data)
+
+    async def on_topic(data):
+        await event_callback("topic", data)
+
+    async def on_summary(data):
+        await event_callback("summary", data)
+
+    # transcription output
+    pipeline = Pipeline(
+        AudioChunkerProcessor(),
+        AudioMergeProcessor(),
+        AudioTranscriptAutoProcessor.as_threaded(),
+        TranscriptLinerProcessor(callback=on_transcript),
+        TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
+        # TranscriptSummarizerProcessor.as_threaded(
+        #     callback=on_summary
+        # ),
+    )
+    pipeline.describe()
+
+    # start processing audio
+    logger.info(f"Opening {filename}")
+    container = av.open(filename)
+    try:
+        logger.info("Start pushing audio into the pipeline")
+        for frame in container.decode(audio=0):
+            await pipeline.push(frame)
+    finally:
+        logger.info("Flushing the pipeline")
+        await pipeline.flush()
+
+    logger.info("All done !")
+
+
 if __name__ == "__main__":
     import argparse
 
@@ -20,42 +56,12 @@
     parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
     args = parser.parse_args()
 
-    async def main():
-        async def on_transcript(transcript):
-            print(f"Transcript: [{transcript.human_timestamp}]: {transcript.text}")
+    async def event_callback(event, data):
+        if event == "transcript":
+            print(f"Transcript[{data.human_timestamp}]: {data.text}")
+        elif event == "topic":
+            print(f"Topic: {data}")
+        elif event == "summary":
+            print(f"Summary: {data}")
 
-        async def on_summary(summary):
-            print(f"Summary: {summary.title} - {summary.summary}")
-
-        async def on_final_summary(path):
-            print(f"Final Summary: {path}")
-
-        # transcription output
-        result_fn = Path(args.source).with_suffix(".jsonl")
-
-        pipeline = Pipeline(
-            AudioChunkerProcessor(),
-            AudioMergeProcessor(),
-            AudioTranscriptAutoProcessor.as_threaded(),
-            TranscriptLinerProcessor(callback=on_transcript),
-            TranscriptTopicDetectorProcessor.as_threaded(callback=on_summary),
-            TranscriptSummarizerProcessor.as_threaded(
-                filename=result_fn, callback=on_final_summary
-            ),
-        )
-        pipeline.describe()
-
-        # start processing audio
-        logger.info(f"Opening {args.source}")
-        container = av.open(args.source)
-        try:
-            logger.info("Start pushing audio into the pipeline")
-            for frame in container.decode(audio=0):
-                await pipeline.push(frame)
-        finally:
-            logger.info("Flushing the pipeline")
-            await pipeline.flush()
-
-        logger.info("All done !")
-
-    asyncio.run(main())
+    asyncio.run(process_audio_file(args.source, event_callback))
diff --git a/server/reflector/views/rtc_offer.py b/server/reflector/views/rtc_offer.py
index 03f9decf..f462a37a 100644
--- a/server/reflector/views/rtc_offer.py
+++ b/server/reflector/views/rtc_offer.py
@@ -73,8 +73,8 @@ async def rtc_offer(params: RtcOffer, request: Request):
     ctx.pipeline = Pipeline(
         AudioChunkerProcessor(),
         AudioMergeProcessor(),
-        AudioTranscriptAutoProcessor.as_threaded(),
-        TranscriptLinerProcessor(callback=on_transcript),
+        AudioTranscriptAutoProcessor.as_threaded(callback=on_transcript),
+        TranscriptLinerProcessor(),
         TranscriptTopicDetectorProcessor.as_threaded(callback=on_summary),
         # FinalSummaryProcessor.as_threaded(
         #     filename=result_fn, callback=on_final_summary
diff --git a/server/tests/test_processors_pipeline.py b/server/tests/test_processors_pipeline.py
new file mode 100644
index 00000000..312f76d6
--- /dev/null
+++ b/server/tests/test_processors_pipeline.py
@@ -0,0 +1,45 @@
+import pytest
+from unittest.mock import patch
+
+
+@pytest.mark.asyncio
+async def test_basic_process(event_loop):
+    # goal is to start the server, and send rtc audio to it
+    # validate the events received
+    from reflector.tools.process import process_audio_file
+    from reflector.settings import settings
+    from reflector.llm.base import LLM
+    from pathlib import Path
+
+    # use an LLM test backend
+    settings.LLM_BACKEND = "test"
+
+    class LLMTest(LLM):
+        async def _generate(self, prompt: str, **kwargs) -> str:
+            return {
+                "title": "TITLE",
+                "summary": "SUMMARY",
+            }
+
+    LLM.register("test", LLMTest)
+
+    # event callback
+    marks = {
+        "transcript": 0,
+        "topic": 0,
+        # "summary": 0,
+    }
+
+    async def event_callback(event, data):
+        print(f"{event}: {data}")
+        marks[event] += 1
+
+    # invoke the process and capture events
+    path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
+    await process_audio_file(path.as_posix(), event_callback)
+    print(marks)
+
+    # validate the events
+    assert marks["transcript"] == 5
+    assert marks["topic"] == 4
+    # assert marks["summary"] == 1