From 600f2ca370bf61bf75e146e76fb0189fb10409c2 Mon Sep 17 00:00:00 2001 From: Mathieu Virbel Date: Thu, 31 Aug 2023 11:16:27 +0200 Subject: [PATCH] server: add BroadcastProcessor tests --- server/reflector/processors/base.py | 8 +++- server/tests/test_processors_broadcast.py | 46 +++++++++++++++++++++++ 2 files changed, 52 insertions(+), 2 deletions(-) create mode 100644 server/tests/test_processors_broadcast.py diff --git a/server/reflector/processors/base.py b/server/reflector/processors/base.py index 219bc3c4..0f93b14b 100644 --- a/server/reflector/processors/base.py +++ b/server/reflector/processors/base.py @@ -1,6 +1,6 @@ import asyncio from concurrent.futures import ThreadPoolExecutor -from typing import Any +from typing import Any, Union from uuid import uuid4 from pydantic import BaseModel @@ -211,12 +211,16 @@ class BroadcastProcessor(Processor): This processor does not guarantee that the output is in order. This processor connect all the output of the processors to the input of - the next processor. + the next processor; so the next processor must be able to accept different + types of input. """ def __init__(self, processors: Processor): super().__init__() self.processors = processors + self.INPUT_TYPE = processors[0].INPUT_TYPE + output_types = set([processor.OUTPUT_TYPE for processor in processors]) + self.OUTPUT_TYPE = Union[tuple(output_types)] def set_pipeline(self, pipeline: "Pipeline"): super().set_pipeline(pipeline) diff --git a/server/tests/test_processors_broadcast.py b/server/tests/test_processors_broadcast.py new file mode 100644 index 00000000..fcddf31c --- /dev/null +++ b/server/tests/test_processors_broadcast.py @@ -0,0 +1,46 @@ +import pytest + + +@pytest.mark.asyncio +async def test_processor_broadcast(): + from reflector.processors.base import Processor, BroadcastProcessor, Pipeline + + class TestProcessor(Processor): + INPUT_TYPE = str + OUTPUT_TYPE = str + + def __init__(self, name, **kwargs): + super().__init__(**kwargs) + self.name = name + + async def _push(self, data): + data = data + f":{self.name}" + await self.emit(data) + + processors = [ + TestProcessor("A"), + BroadcastProcessor( + processors=[ + TestProcessor("B"), + TestProcessor("C"), + ], + ), + ] + + events = [] + + async def on_event(event): + events.append(event) + + pipeline = Pipeline(*processors) + pipeline.on(on_event) + await pipeline.push("test") + await pipeline.flush() + + assert len(events) == 3 + assert events[0].processor == "A" + assert events[0].data == "test:A" + assert events[1].processor == "B" + assert events[1].data == "test:A:B" + assert events[2].processor == "C" + assert events[2].data == "test:A:C"