mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
processors: implement Pipeline, simplify usage
This commit is contained in:
@@ -8,6 +8,7 @@ from reflector.settings import settings
|
||||
from reflector.logger import logger
|
||||
import httpx
|
||||
import asyncio
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
|
||||
@@ -19,6 +20,9 @@ class AudioFile:
|
||||
sample_width: int
|
||||
timestamp: float = 0.0
|
||||
|
||||
def release(self):
|
||||
self.path.unlink()
|
||||
|
||||
|
||||
@dataclass
|
||||
class Word:
|
||||
@@ -69,11 +73,16 @@ class Processor:
|
||||
INPUT_TYPE: type = None
|
||||
OUTPUT_TYPE: type = None
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, callback=None):
|
||||
self._processors = []
|
||||
self._callbacks = []
|
||||
if callback:
|
||||
self.on(callback)
|
||||
|
||||
def connect(self, processor: "Processor"):
|
||||
"""
|
||||
Connect this processor output to another processor
|
||||
"""
|
||||
if processor.INPUT_TYPE != self.OUTPUT_TYPE:
|
||||
raise ValueError(
|
||||
f"Processor {processor} input type {processor.INPUT_TYPE} "
|
||||
@@ -82,27 +91,41 @@ class Processor:
|
||||
self._processors.append(processor)
|
||||
|
||||
def disconnect(self, processor: "Processor"):
|
||||
"""
|
||||
Disconnect this processor data from another processor
|
||||
"""
|
||||
self._processors.remove(processor)
|
||||
|
||||
def on(self, callback):
|
||||
"""
|
||||
Register a callback to be called when data is emitted
|
||||
"""
|
||||
self._callbacks.append(callback)
|
||||
|
||||
def off(self, callback):
|
||||
"""
|
||||
Unregister a callback to be called when data is emitted
|
||||
"""
|
||||
self._callbacks.remove(callback)
|
||||
|
||||
async def emit(self, data):
|
||||
for callback in self._callbacks:
|
||||
if isinstance(data, AudioFile):
|
||||
import pdb; pdb.set_trace()
|
||||
await callback(data)
|
||||
for processor in self._processors:
|
||||
await processor.push(data)
|
||||
|
||||
async def push(self, data):
|
||||
"""
|
||||
Push data to this processor. `data` must be of type `INPUT_TYPE`
|
||||
The function returns the output of type `OUTPUT_TYPE`
|
||||
"""
|
||||
# logger.debug(f"{self.__class__.__name__} push")
|
||||
return await self._push(data)
|
||||
|
||||
async def flush(self):
|
||||
"""
|
||||
Flush data to this processor
|
||||
"""
|
||||
# logger.debug(f"{self.__class__.__name__} flush")
|
||||
return await self._flush()
|
||||
|
||||
@@ -114,12 +137,24 @@ class Processor:
|
||||
|
||||
@classmethod
|
||||
def as_threaded(cls, *args, **kwargs):
|
||||
"""
|
||||
Return a single threaded processor where output is guaranteed
|
||||
to be in order
|
||||
"""
|
||||
return ThreadedProcessor(cls(*args, **kwargs), max_workers=1)
|
||||
|
||||
|
||||
class ThreadedProcessor(Processor):
|
||||
"""
|
||||
A processor that runs in a separate thread
|
||||
"""
|
||||
|
||||
def __init__(self, processor: Processor, max_workers=1):
|
||||
super().__init__()
|
||||
# FIXME: This is a hack to make sure that the processor is single threaded
|
||||
# but if it is more than 1, then we need to make sure that the processor
|
||||
# is emiting data in order
|
||||
assert max_workers == 1
|
||||
self.processor = processor
|
||||
self.INPUT_TYPE = processor.INPUT_TYPE
|
||||
self.OUTPUT_TYPE = processor.OUTPUT_TYPE
|
||||
@@ -231,9 +266,12 @@ class AudioTranscriptProcessor(Processor):
|
||||
OUTPUT_TYPE = Transcript
|
||||
|
||||
async def _push(self, data: AudioFile):
|
||||
result = await self._transcript(data)
|
||||
if result:
|
||||
await self.emit(result)
|
||||
try:
|
||||
result = await self._transcript(data)
|
||||
if result:
|
||||
await self.emit(result)
|
||||
finally:
|
||||
data.release()
|
||||
|
||||
async def _transcript(self, data: AudioFile):
|
||||
raise NotImplementedError
|
||||
@@ -282,8 +320,8 @@ class TranscriptLineProcessor(Processor):
|
||||
INPUT_TYPE = Transcript
|
||||
OUTPUT_TYPE = Transcript
|
||||
|
||||
def __init__(self, max_text=1000):
|
||||
super().__init__()
|
||||
def __init__(self, max_text=1000, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.transcript = Transcript(words=[])
|
||||
self.max_text = max_text
|
||||
|
||||
@@ -340,9 +378,58 @@ class TitleSummaryProcessor(Processor):
|
||||
summary = TitleSummary(title=result.title, summary=result.description)
|
||||
await self.emit(summary)
|
||||
|
||||
except httpx.ConnectError as e:
|
||||
logger.error(f"Failed to call llm: {e}")
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to call llm")
|
||||
return
|
||||
|
||||
|
||||
class Pipeline(Processor):
|
||||
"""
|
||||
A pipeline of processors
|
||||
"""
|
||||
|
||||
INPUT_TYPE = None
|
||||
OUTPUT_TYPE = None
|
||||
|
||||
def __init__(self, *processors):
|
||||
super().__init__()
|
||||
self.processors = processors
|
||||
|
||||
for i in range(len(processors) - 1):
|
||||
processors[i].connect(processors[i + 1])
|
||||
|
||||
self.INPUT_TYPE = processors[0].INPUT_TYPE
|
||||
self.OUTPUT_TYPE = processors[-1].OUTPUT_TYPE
|
||||
|
||||
async def _push(self, data):
|
||||
await self.processors[0].push(data)
|
||||
|
||||
async def _flush(self):
|
||||
for processor in self.processors:
|
||||
await processor.flush()
|
||||
|
||||
|
||||
class FinalSummaryProcessor(Processor):
|
||||
"""
|
||||
Assemble all summary into a line-based json
|
||||
"""
|
||||
|
||||
INPUT_TYPE = TitleSummary
|
||||
OUTPUT_TYPE = Path
|
||||
|
||||
def __init__(self, filename: Path, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.filename = filename
|
||||
|
||||
async def _push(self, data: TitleSummary):
|
||||
with open(self.filename, "a", encoding="utf8") as fd:
|
||||
fd.write(json.dumps(data))
|
||||
|
||||
async def _flush(self):
|
||||
logger.info(f"Writing to {self.filename}")
|
||||
await self.emit(self.filename)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -353,44 +440,40 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
async def main():
|
||||
chunker = AudioChunkerProcessor()
|
||||
|
||||
# merge audio
|
||||
merger = AudioMergeProcessor.as_threaded()
|
||||
chunker.connect(merger)
|
||||
|
||||
# transcript audio
|
||||
transcripter = AudioWhisperTranscriptProcessor()
|
||||
merger.connect(transcripter)
|
||||
|
||||
# merge transcript and output lines
|
||||
line_processor = TranscriptLineProcessor()
|
||||
transcripter.connect(line_processor)
|
||||
|
||||
async def on_transcript(transcript):
|
||||
print(f"Transcript: [{transcript.human_timestamp}]: {transcript.text}")
|
||||
|
||||
line_processor.on(on_transcript)
|
||||
async def on_summary(summary):
|
||||
print(f"Summary: {summary.title} - {summary.summary}")
|
||||
|
||||
# # title and summary
|
||||
# title_summary = TitleSummaryProcessor.as_threaded()
|
||||
# line_processor.connect(title_summary)
|
||||
#
|
||||
# async def on_summary(summary):
|
||||
# print(f"Summary: title={summary.title} summary={summary.summary}")
|
||||
#
|
||||
# title_summary.on(on_summary)
|
||||
async def on_final_summary(path):
|
||||
print(f"Final Summary: {path}")
|
||||
|
||||
# transcription output
|
||||
result_fn = Path(args.source).with_suffix(".jsonl")
|
||||
|
||||
pipeline = Pipeline(
|
||||
AudioChunkerProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioWhisperTranscriptProcessor().as_threaded(),
|
||||
TranscriptLineProcessor(callback=on_transcript),
|
||||
TitleSummaryProcessor.as_threaded(callback=on_summary),
|
||||
FinalSummaryProcessor.as_threaded(
|
||||
filename=result_fn, callback=on_final_summary
|
||||
),
|
||||
)
|
||||
|
||||
# start processing audio
|
||||
logger.info(f"Opening{args.source}")
|
||||
container = av.open(args.source)
|
||||
for frame in container.decode(audio=0):
|
||||
await chunker.push(frame)
|
||||
try:
|
||||
logger.info("Start pushing audio into the pipeline")
|
||||
for frame in container.decode(audio=0):
|
||||
await pipeline.push(frame)
|
||||
finally:
|
||||
logger.info("Flushing the pipeline")
|
||||
await pipeline.flush()
|
||||
|
||||
# audio done, flush everything
|
||||
await chunker.flush()
|
||||
await merger.flush()
|
||||
await transcripter.flush()
|
||||
await line_processor.flush()
|
||||
# await title_summary.flush()
|
||||
logger.info("All done !")
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Reference in New Issue
Block a user