processors: implement Pipeline, simplify usage

This commit is contained in:
2023-07-29 00:59:09 +02:00
parent 509840cb4c
commit 6f61863136

View File

@@ -8,6 +8,7 @@ from reflector.settings import settings
from reflector.logger import logger
import httpx
import asyncio
import json
from concurrent.futures import ThreadPoolExecutor
@@ -19,6 +20,9 @@ class AudioFile:
sample_width: int
timestamp: float = 0.0
def release(self):
self.path.unlink()
@dataclass
class Word:
@@ -69,11 +73,16 @@ class Processor:
INPUT_TYPE: type = None
OUTPUT_TYPE: type = None
def __init__(self):
def __init__(self, callback=None):
self._processors = []
self._callbacks = []
if callback:
self.on(callback)
def connect(self, processor: "Processor"):
"""
Connect this processor output to another processor
"""
if processor.INPUT_TYPE != self.OUTPUT_TYPE:
raise ValueError(
f"Processor {processor} input type {processor.INPUT_TYPE} "
@@ -82,27 +91,41 @@ class Processor:
self._processors.append(processor)
def disconnect(self, processor: "Processor"):
"""
Disconnect this processor data from another processor
"""
self._processors.remove(processor)
def on(self, callback):
"""
Register a callback to be called when data is emitted
"""
self._callbacks.append(callback)
def off(self, callback):
"""
Unregister a callback to be called when data is emitted
"""
self._callbacks.remove(callback)
async def emit(self, data):
for callback in self._callbacks:
if isinstance(data, AudioFile):
import pdb; pdb.set_trace()
await callback(data)
for processor in self._processors:
await processor.push(data)
async def push(self, data):
"""
Push data to this processor. `data` must be of type `INPUT_TYPE`
The function returns the output of type `OUTPUT_TYPE`
"""
# logger.debug(f"{self.__class__.__name__} push")
return await self._push(data)
async def flush(self):
"""
Flush data to this processor
"""
# logger.debug(f"{self.__class__.__name__} flush")
return await self._flush()
@@ -114,12 +137,24 @@ class Processor:
@classmethod
def as_threaded(cls, *args, **kwargs):
"""
Return a single threaded processor where output is guaranteed
to be in order
"""
return ThreadedProcessor(cls(*args, **kwargs), max_workers=1)
class ThreadedProcessor(Processor):
"""
A processor that runs in a separate thread
"""
def __init__(self, processor: Processor, max_workers=1):
super().__init__()
# FIXME: This is a hack to make sure that the processor is single threaded
# but if it is more than 1, then we need to make sure that the processor
# is emiting data in order
assert max_workers == 1
self.processor = processor
self.INPUT_TYPE = processor.INPUT_TYPE
self.OUTPUT_TYPE = processor.OUTPUT_TYPE
@@ -231,9 +266,12 @@ class AudioTranscriptProcessor(Processor):
OUTPUT_TYPE = Transcript
async def _push(self, data: AudioFile):
result = await self._transcript(data)
if result:
await self.emit(result)
try:
result = await self._transcript(data)
if result:
await self.emit(result)
finally:
data.release()
async def _transcript(self, data: AudioFile):
raise NotImplementedError
@@ -282,8 +320,8 @@ class TranscriptLineProcessor(Processor):
INPUT_TYPE = Transcript
OUTPUT_TYPE = Transcript
def __init__(self, max_text=1000):
super().__init__()
def __init__(self, max_text=1000, **kwargs):
super().__init__(**kwargs)
self.transcript = Transcript(words=[])
self.max_text = max_text
@@ -340,9 +378,58 @@ class TitleSummaryProcessor(Processor):
summary = TitleSummary(title=result.title, summary=result.description)
await self.emit(summary)
except httpx.ConnectError as e:
logger.error(f"Failed to call llm: {e}")
except Exception:
logger.exception("Failed to call llm")
return
class Pipeline(Processor):
"""
A pipeline of processors
"""
INPUT_TYPE = None
OUTPUT_TYPE = None
def __init__(self, *processors):
super().__init__()
self.processors = processors
for i in range(len(processors) - 1):
processors[i].connect(processors[i + 1])
self.INPUT_TYPE = processors[0].INPUT_TYPE
self.OUTPUT_TYPE = processors[-1].OUTPUT_TYPE
async def _push(self, data):
await self.processors[0].push(data)
async def _flush(self):
for processor in self.processors:
await processor.flush()
class FinalSummaryProcessor(Processor):
"""
Assemble all summary into a line-based json
"""
INPUT_TYPE = TitleSummary
OUTPUT_TYPE = Path
def __init__(self, filename: Path, **kwargs):
super().__init__(**kwargs)
self.filename = filename
async def _push(self, data: TitleSummary):
with open(self.filename, "a", encoding="utf8") as fd:
fd.write(json.dumps(data))
async def _flush(self):
logger.info(f"Writing to {self.filename}")
await self.emit(self.filename)
if __name__ == "__main__":
@@ -353,44 +440,40 @@ if __name__ == "__main__":
args = parser.parse_args()
async def main():
chunker = AudioChunkerProcessor()
# merge audio
merger = AudioMergeProcessor.as_threaded()
chunker.connect(merger)
# transcript audio
transcripter = AudioWhisperTranscriptProcessor()
merger.connect(transcripter)
# merge transcript and output lines
line_processor = TranscriptLineProcessor()
transcripter.connect(line_processor)
async def on_transcript(transcript):
print(f"Transcript: [{transcript.human_timestamp}]: {transcript.text}")
line_processor.on(on_transcript)
async def on_summary(summary):
print(f"Summary: {summary.title} - {summary.summary}")
# # title and summary
# title_summary = TitleSummaryProcessor.as_threaded()
# line_processor.connect(title_summary)
#
# async def on_summary(summary):
# print(f"Summary: title={summary.title} summary={summary.summary}")
#
# title_summary.on(on_summary)
async def on_final_summary(path):
print(f"Final Summary: {path}")
# transcription output
result_fn = Path(args.source).with_suffix(".jsonl")
pipeline = Pipeline(
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioWhisperTranscriptProcessor().as_threaded(),
TranscriptLineProcessor(callback=on_transcript),
TitleSummaryProcessor.as_threaded(callback=on_summary),
FinalSummaryProcessor.as_threaded(
filename=result_fn, callback=on_final_summary
),
)
# start processing audio
logger.info(f"Opening{args.source}")
container = av.open(args.source)
for frame in container.decode(audio=0):
await chunker.push(frame)
try:
logger.info("Start pushing audio into the pipeline")
for frame in container.decode(audio=0):
await pipeline.push(frame)
finally:
logger.info("Flushing the pipeline")
await pipeline.flush()
# audio done, flush everything
await chunker.flush()
await merger.flush()
await transcripter.flush()
await line_processor.flush()
# await title_summary.flush()
logger.info("All done !")
asyncio.run(main())