server: add BroadcastProcessor tests

This commit is contained in:
2023-08-31 11:16:27 +02:00
committed by Mathieu Virbel
parent 9ed26030a5
commit 600f2ca370
2 changed files with 52 additions and 2 deletions

View File

@@ -1,6 +1,6 @@
import asyncio import asyncio
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Any from typing import Any, Union
from uuid import uuid4 from uuid import uuid4
from pydantic import BaseModel from pydantic import BaseModel
@@ -211,12 +211,16 @@ class BroadcastProcessor(Processor):
This processor does not guarantee that the output is in order. This processor does not guarantee that the output is in order.
This processor connect all the output of the processors to the input of This processor connect all the output of the processors to the input of
the next processor. the next processor; so the next processor must be able to accept different
types of input.
""" """
def __init__(self, processors: Processor): def __init__(self, processors: Processor):
super().__init__() super().__init__()
self.processors = processors self.processors = processors
self.INPUT_TYPE = processors[0].INPUT_TYPE
output_types = set([processor.OUTPUT_TYPE for processor in processors])
self.OUTPUT_TYPE = Union[tuple(output_types)]
def set_pipeline(self, pipeline: "Pipeline"): def set_pipeline(self, pipeline: "Pipeline"):
super().set_pipeline(pipeline) super().set_pipeline(pipeline)

View File

@@ -0,0 +1,46 @@
import pytest
@pytest.mark.asyncio
async def test_processor_broadcast():
from reflector.processors.base import Processor, BroadcastProcessor, Pipeline
class TestProcessor(Processor):
INPUT_TYPE = str
OUTPUT_TYPE = str
def __init__(self, name, **kwargs):
super().__init__(**kwargs)
self.name = name
async def _push(self, data):
data = data + f":{self.name}"
await self.emit(data)
processors = [
TestProcessor("A"),
BroadcastProcessor(
processors=[
TestProcessor("B"),
TestProcessor("C"),
],
),
]
events = []
async def on_event(event):
events.append(event)
pipeline = Pipeline(*processors)
pipeline.on(on_event)
await pipeline.push("test")
await pipeline.flush()
assert len(events) == 3
assert events[0].processor == "A"
assert events[0].data == "test:A"
assert events[1].processor == "B"
assert events[1].data == "test:A:B"
assert events[2].processor == "C"
assert events[2].data == "test:A:C"