Merge branch 'main' of github.com:Monadical-SAS/reflector into llm-modal

Gokul Mohanarangan
2023-09-05 12:39:59 +05:30
12 changed files with 429 additions and 88 deletions

View File

@@ -1,10 +1,10 @@
-from .base import Processor, ThreadedProcessor, Pipeline  # noqa: F401
-from .types import AudioFile, Transcript, Word, TitleSummary, FinalSummary  # noqa: F401
-from .audio_file_writer import AudioFileWriterProcessor  # noqa: F401
 from .audio_chunker import AudioChunkerProcessor  # noqa: F401
+from .audio_file_writer import AudioFileWriterProcessor  # noqa: F401
 from .audio_merge import AudioMergeProcessor  # noqa: F401
 from .audio_transcript import AudioTranscriptProcessor  # noqa: F401
 from .audio_transcript_auto import AudioTranscriptAutoProcessor  # noqa: F401
+from .base import Pipeline, PipelineEvent, Processor, ThreadedProcessor  # noqa: F401
+from .transcript_final_summary import TranscriptFinalSummaryProcessor  # noqa: F401
 from .transcript_liner import TranscriptLinerProcessor  # noqa: F401
 from .transcript_topic_detector import TranscriptTopicDetectorProcessor  # noqa: F401
-from .transcript_final_summary import TranscriptFinalSummaryProcessor  # noqa: F401
+from .types import AudioFile, FinalSummary, TitleSummary, Transcript, Word  # noqa: F401

View File

@@ -1,17 +1,25 @@
 import asyncio
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any
+from typing import Any, Union
 from uuid import uuid4

+from pydantic import BaseModel
+
 from reflector.logger import logger


+class PipelineEvent(BaseModel):
+    processor: str
+    uid: str
+    data: Any
+
+
 class Processor:
     INPUT_TYPE: type = None
     OUTPUT_TYPE: type = None
     WARMUP_EVENT: str = "WARMUP_EVENT"

     def __init__(self, callback=None, custom_logger=None):
+        self.name = self.__class__.__name__
         self._processors = []
         self._callbacks = []
         if callback:
@@ -67,6 +75,10 @@ class Processor:
         return default

     async def emit(self, data):
+        if self.pipeline:
+            await self.pipeline.emit(
+                PipelineEvent(processor=self.name, uid=self.uid, data=data)
+            )
         for callback in self._callbacks:
             await callback(data)
         for processor in self._processors:
@@ -183,11 +195,72 @@ class ThreadedProcessor(Processor):
     def on(self, callback):
         self.processor.on(callback)

+    def off(self, callback):
+        self.processor.off(callback)
+
     def describe(self, level=0):
         super().describe(level)
         self.processor.describe(level + 1)


+class BroadcastProcessor(Processor):
+    """
+    A processor that broadcasts data to multiple processors, in the order
+    they were passed to the constructor.
+
+    This processor does not guarantee that the output is in order.
+    It connects the output of every wrapped processor to the input of the
+    next processor, so the next processor must be able to accept the
+    different output types.
+    """
+
+    def __init__(self, processors: list[Processor]):
+        super().__init__()
+        self.processors = processors
+        self.INPUT_TYPE = processors[0].INPUT_TYPE
+        output_types = set([processor.OUTPUT_TYPE for processor in processors])
+        self.OUTPUT_TYPE = Union[tuple(output_types)]
+
+    def set_pipeline(self, pipeline: "Pipeline"):
+        super().set_pipeline(pipeline)
+        for processor in self.processors:
+            processor.set_pipeline(pipeline)
+
+    async def _warmup(self):
+        for processor in self.processors:
+            await processor.warmup()
+
+    async def _push(self, data):
+        for processor in self.processors:
+            await processor.push(data)
+
+    async def _flush(self):
+        for processor in self.processors:
+            await processor.flush()
+
+    def connect(self, processor: Processor):
+        for p in self.processors:
+            p.connect(processor)
+
+    def disconnect(self, processor: Processor):
+        for p in self.processors:
+            p.disconnect(processor)
+
+    def on(self, callback):
+        for processor in self.processors:
+            processor.on(callback)
+
+    def off(self, callback):
+        for processor in self.processors:
+            processor.off(callback)
+
+    def describe(self, level=0):
+        super().describe(level)
+        for processor in self.processors:
+            processor.describe(level + 1)
+
+
 class Pipeline(Processor):
     """
     A pipeline of processors
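With these changes, every processor's emit() is also forwarded to its pipeline as a typed PipelineEvent, so a single pipeline.on() callback can observe the whole run. A minimal sketch of that flow, assuming only the Processor/Pipeline interfaces visible in this diff (UppercaseProcessor is a hypothetical example processor):

import asyncio

from reflector.processors.base import Pipeline, PipelineEvent, Processor


class UppercaseProcessor(Processor):
    # hypothetical processor, used only for illustration
    INPUT_TYPE = str
    OUTPUT_TYPE = str

    async def _push(self, data):
        await self.emit(data.upper())


async def main():
    async def on_event(event: PipelineEvent):
        # every emit() arrives here tagged with the processor name
        print(event.processor, event.data)

    pipeline = Pipeline(UppercaseProcessor())
    pipeline.on(on_event)
    await pipeline.push("hello")
    await pipeline.flush()


asyncio.run(main())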

View File

@@ -1,43 +1,45 @@
+import asyncio
+
 import av
+
 from reflector.logger import logger
 from reflector.processors import (
-    Pipeline,
     AudioChunkerProcessor,
     AudioMergeProcessor,
     AudioTranscriptAutoProcessor,
+    Pipeline,
+    PipelineEvent,
+    TranscriptFinalSummaryProcessor,
     TranscriptLinerProcessor,
     TranscriptTopicDetectorProcessor,
-    TranscriptFinalSummaryProcessor,
 )
-import asyncio


-async def process_audio_file(filename, event_callback, only_transcript=False):
-    async def on_transcript(data):
-        await event_callback("transcript", data)
-
-    async def on_topic(data):
-        await event_callback("topic", data)
-
-    async def on_summary(data):
-        await event_callback("summary", data)
-
+async def process_audio_file(
+    filename,
+    event_callback,
+    only_transcript=False,
+    source_language="en",
+    target_language="en",
+):
     # build pipeline for audio processing
     processors = [
         AudioChunkerProcessor(),
         AudioMergeProcessor(),
         AudioTranscriptAutoProcessor.as_threaded(),
-        TranscriptLinerProcessor(callback=on_transcript),
+        TranscriptLinerProcessor(),
     ]
     if not only_transcript:
         processors += [
-            TranscriptTopicDetectorProcessor.as_threaded(callback=on_topic),
-            TranscriptFinalSummaryProcessor.as_threaded(callback=on_summary),
+            TranscriptTopicDetectorProcessor.as_threaded(),
+            TranscriptFinalSummaryProcessor.as_threaded(),
         ]

     # transcription output
     pipeline = Pipeline(*processors)
+    pipeline.set_pref("audio:source_language", source_language)
+    pipeline.set_pref("audio:target_language", target_language)
     pipeline.describe()
+    pipeline.on(event_callback)

     # start processing audio
     logger.info(f"Opening {filename}")
@@ -59,20 +61,35 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("source", help="Source file (mp3, wav, mp4...)")
     parser.add_argument("--only-transcript", "-t", action="store_true")
+    parser.add_argument("--source-language", default="en")
+    parser.add_argument("--target-language", default="en")
+    parser.add_argument("--output", "-o", help="Output file (output.jsonl)")
     args = parser.parse_args()

-    async def event_callback(event, data):
-        if event == "transcript":
-            print(f"Transcript[{data.human_timestamp}]: {data.text}")
-        elif event == "topic":
-            print(f"Topic[{data.human_timestamp}]: title={data.title}")
-            print(f"Topic[{data.human_timestamp}]: summary={data.summary}")
-        elif event == "summary":
-            print(f"Summary: duration={data.duration}")
-            print(f"Summary: summary={data.summary}")
+    output_fd = None
+    if args.output:
+        output_fd = open(args.output, "w")
+
+    async def event_callback(event: PipelineEvent):
+        processor = event.processor
+        # ignore some processors
+        if processor in ("AudioChunkerProcessor", "AudioMergeProcessor"):
+            return
+        logger.info(f"Event: {event}")
+        if output_fd:
+            output_fd.write(event.model_dump_json())
+            output_fd.write("\n")

     asyncio.run(
         process_audio_file(
-            args.source, event_callback, only_transcript=args.only_transcript
+            args.source,
+            event_callback,
+            only_transcript=args.only_transcript,
+            source_language=args.source_language,
+            target_language=args.target_language,
         )
     )
+
+    if output_fd:
+        output_fd.close()
+        logger.info(f"Output written to {args.output}")

View File

@@ -0,0 +1,110 @@
"""
# Run a pipeline of processor
This tools help to either create a pipeline from command line,
or read a yaml description of a pipeline and run it.
"""
import json
from reflector.logger import logger
from reflector.processors import Pipeline, PipelineEvent
def camel_to_snake(s):
return "".join(["_" + c.lower() if c.isupper() else c for c in s]).lstrip("_")
def snake_to_camel(s):
return "".join([c.capitalize() for c in s.split("_")])
def get_jsonl(filename, filter_processor_name=None):
logger.info(f"Opening {args.input}")
if filter_processor_name is not None:
filter_processor_name = snake_to_camel(filter_processor_name) + "Processor"
logger.info(f"Filtering on {filter_processor_name}")
with open(filename, encoding="utf8") as f:
for line in f:
data = json.loads(line)
if (
filter_processor_name is not None
and data["processor"] != filter_processor_name
):
continue
yield data
def get_processor(name):
import importlib
module_name = f"reflector.processors.{name}"
class_name = snake_to_camel(name) + "Processor"
module = importlib.import_module(module_name)
return getattr(module, class_name)
async def run_single_processor(args):
output_fd = None
if args.output:
output_fd = open(args.output, "w")
async def event_callback(event: PipelineEvent):
processor = event.processor
# ignore some processor
if processor in ("AudioChunkerProcessor", "AudioMergeProcessor"):
return
print(f"Event: {event}")
if output_fd:
output_fd.write(event.model_dump_json())
output_fd.write("\n")
processor = get_processor(args.processor)()
pipeline = Pipeline(processor)
pipeline.on(event_callback)
input_type = pipeline.INPUT_TYPE
logger.info(f"Converting to {input_type.__name__} type")
for data in get_jsonl(args.input, filter_processor_name=args.input_processor):
obj = input_type(**data["data"])
await pipeline.push(obj)
await pipeline.flush()
if output_fd:
output_fd.close()
logger.info(f"Output written to {args.output}")
if __name__ == "__main__":
import argparse
import asyncio
import sys
parser = argparse.ArgumentParser(description="Run a pipeline of processor")
parser.add_argument("--input", "-i", help="Input file (jsonl)")
parser.add_argument("--input-processor", "-f", help="Name of the processor to keep")
parser.add_argument("--output", "-o", help="Output file (jsonl)")
parser.add_argument("--pipeline", "-p", help="Pipeline description (yaml)")
parser.add_argument("--processor", help="Processor to run")
args = parser.parse_args()
if args.output and args.output == args.input:
parser.error("Input and output cannot be the same")
sys.exit(1)
if args.processor and args.pipeline:
parser.error("--processor and --pipeline are mutually exclusive")
sys.exit(1)
if not args.processor and not args.pipeline:
parser.error("You need to specify either --processor or --pipeline")
sys.exit(1)
if args.processor:
func = run_single_processor(args)
# elif args.pipeline:
# func = run_pipeline(args)
asyncio.run(func)
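The processor lookup relies purely on the naming convention: --processor transcript_liner resolves to module reflector.processors.transcript_liner and class TranscriptLinerProcessor. A standalone sketch of that round trip, copying the two helpers above so it runs without the package:

def snake_to_camel(s):
    return "".join([c.capitalize() for c in s.split("_")])


def camel_to_snake(s):
    return "".join(["_" + c.lower() if c.isupper() else c for c in s]).lstrip("_")


# "transcript_liner" -> class TranscriptLinerProcessor
# in module reflector.processors.transcript_liner
print(snake_to_camel("transcript_liner") + "Processor")  # TranscriptLinerProcessor
print(camel_to_snake("TranscriptLinerProcessor"))        # transcript_liner_processor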

View File

@@ -0,0 +1,46 @@
import pytest


@pytest.mark.asyncio
async def test_processor_broadcast():
    from reflector.processors.base import Processor, BroadcastProcessor, Pipeline

    class TestProcessor(Processor):
        INPUT_TYPE = str
        OUTPUT_TYPE = str

        def __init__(self, name, **kwargs):
            super().__init__(**kwargs)
            self.name = name

        async def _push(self, data):
            data = data + f":{self.name}"
            await self.emit(data)

    processors = [
        TestProcessor("A"),
        BroadcastProcessor(
            processors=[
                TestProcessor("B"),
                TestProcessor("C"),
            ],
        ),
    ]

    events = []

    async def on_event(event):
        events.append(event)

    pipeline = Pipeline(*processors)
    pipeline.on(on_event)
    await pipeline.push("test")
    await pipeline.flush()

    assert len(events) == 3
    assert events[0].processor == "A"
    assert events[0].data == "test:A"
    assert events[1].processor == "B"
    assert events[1].data == "test:A:B"
    assert events[2].processor == "C"
    assert events[2].data == "test:A:C"

View File

@@ -24,15 +24,12 @@ async def test_basic_process(event_loop):
LLM.register("test", LLMTest) LLM.register("test", LLMTest)
# event callback # event callback
marks = { marks = {}
"transcript": 0,
"topic": 0,
"summary": 0,
}
async def event_callback(event, data): async def event_callback(event):
print(f"{event}: {data}") if event.processor not in marks:
marks[event] += 1 marks[event.processor] = 0
marks[event.processor] += 1
# invoke the process and capture events # invoke the process and capture events
path = Path(__file__).parent / "records" / "test_mathieu_hello.wav" path = Path(__file__).parent / "records" / "test_mathieu_hello.wav"
@@ -40,6 +37,6 @@ async def test_basic_process(event_loop):
     print(marks)

     # validate the events
-    assert marks["transcript"] == 5
-    assert marks["topic"] == 1
-    assert marks["summary"] == 1
+    assert marks["TranscriptLinerProcessor"] == 5
+    assert marks["TranscriptTopicDetectorProcessor"] == 1
+    assert marks["TranscriptFinalSummaryProcessor"] == 1

View File

@@ -0,0 +1,31 @@
"use client";
import React, { createContext, useContext, useState } from "react";
interface ErrorContextProps {
error: Error | null;
setError: React.Dispatch<React.SetStateAction<Error | null>>;
}
const ErrorContext = createContext<ErrorContextProps | undefined>(undefined);
export const useError = () => {
const context = useContext(ErrorContext);
if (!context) {
throw new Error("useError must be used within an ErrorProvider");
}
return context;
};
interface ErrorProviderProps {
children: React.ReactNode;
}
export const ErrorProvider: React.FC<ErrorProviderProps> = ({ children }) => {
const [error, setError] = useState<Error | null>(null);
return (
<ErrorContext.Provider value={{ error, setError }}>
{children}
</ErrorContext.Provider>
);
};

View File

@@ -0,0 +1,34 @@
"use client";
import { useError } from "./errorContext";
import { useEffect, useState } from "react";
import * as Sentry from "@sentry/react";
const ErrorMessage: React.FC = () => {
const { error, setError } = useError();
const [isVisible, setIsVisible] = useState<boolean>(false);
useEffect(() => {
if (error) {
setIsVisible(true);
Sentry.captureException(error);
console.error("Error", error.message, error);
}
}, [error]);
if (!isVisible || !error) return null;
return (
<div
onClick={() => {
setIsVisible(false);
setError(null);
}}
className="max-w-xs z-50 fixed top-16 right-10 bg-red-100 border border-red-400 text-red-700 px-4 py-3 rounded transition-opacity duration-300 ease-out opacity-100 hover:opacity-75 cursor-pointer transform hover:scale-105"
role="alert"
>
<span className="block sm:inline">{error?.message}</span>
</div>
);
};
export default ErrorMessage;

View File

@@ -3,6 +3,8 @@ import { Roboto } from "next/font/google";
import { Metadata } from "next"; import { Metadata } from "next";
import FiefWrapper from "./(auth)/fiefWrapper"; import FiefWrapper from "./(auth)/fiefWrapper";
import UserInfo from "./(auth)/userInfo"; import UserInfo from "./(auth)/userInfo";
import { ErrorProvider } from "./(errors)/errorContext";
import ErrorMessage from "./(errors)/errorMessage";
const roboto = Roboto({ subsets: ["latin"], weight: "400" }); const roboto = Roboto({ subsets: ["latin"], weight: "400" });
@@ -55,19 +57,24 @@ export default function RootLayout({ children }) {
<html lang="en"> <html lang="en">
<body className={roboto.className + " flex flex-col min-h-screen"}> <body className={roboto.className + " flex flex-col min-h-screen"}>
<FiefWrapper> <FiefWrapper>
<div id="container"> <ErrorProvider>
<div className="flex flex-col items-center h-[100svh] bg-gradient-to-r from-[#8ec5fc30] to-[#e0c3fc42]"> <ErrorMessage />
<UserInfo /> <div id="container">
<div className="flex flex-col items-center h-[100svh] bg-gradient-to-r from-[#8ec5fc30] to-[#e0c3fc42]">
<UserInfo />
<div className="h-[13svh] flex flex-col justify-center items-center"> <div className="h-[13svh] flex flex-col justify-center items-center">
<h1 className="text-5xl font-bold text-blue-500">Reflector</h1> <h1 className="text-5xl font-bold text-blue-500">
<p className="text-gray-500"> Reflector
Capture The Signal, Not The Noise </h1>
</p> <p className="text-gray-500">
Capture The Signal, Not The Noise
</p>
</div>
{children}
</div> </div>
{children}
</div> </div>
</div> </ErrorProvider>
</FiefWrapper> </FiefWrapper>
</body> </body>
</html> </html>

View File

@@ -1,19 +1,18 @@
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { DefaultApi, V1TranscriptsCreateRequest } from "../api/apis/DefaultApi"; import { DefaultApi, V1TranscriptsCreateRequest } from "../api/apis/DefaultApi";
import { GetTranscript } from "../api"; import { GetTranscript } from "../api";
import getApi from "../lib/getApi"; import { useError } from "../(errors)/errorContext";
type UseTranscript = { type UseTranscript = {
response: GetTranscript | null; response: GetTranscript | null;
loading: boolean; loading: boolean;
error: string | null;
createTranscript: () => void; createTranscript: () => void;
}; };
const useTranscript = (api: DefaultApi): UseTranscript => { const useTranscript = (api: DefaultApi): UseTranscript => {
const [response, setResponse] = useState<GetTranscript | null>(null); const [response, setResponse] = useState<GetTranscript | null>(null);
const [loading, setLoading] = useState<boolean>(false); const [loading, setLoading] = useState<boolean>(false);
const [error, setError] = useState<string | null>(null); const { setError } = useError();
const createTranscript = () => { const createTranscript = () => {
setLoading(true); setLoading(true);
@@ -37,10 +36,7 @@ const useTranscript = (api: DefaultApi): UseTranscript => {
console.debug("New transcript created:", result); console.debug("New transcript created:", result);
}) })
.catch((err) => { .catch((err) => {
const errorString = err.response || err.message || "Unknown error"; setError(err);
setError(errorString);
setLoading(false);
console.error("Error creating transcript:", errorString);
}); });
}; };
@@ -48,7 +44,7 @@ const useTranscript = (api: DefaultApi): UseTranscript => {
     createTranscript();
   }, []);

-  return { response, loading, error, createTranscript };
+  return { response, loading, createTranscript };
 };

 export default useTranscript;

View File

@@ -4,7 +4,7 @@ import {
   DefaultApi,
   V1TranscriptRecordWebrtcRequest,
 } from "../api/apis/DefaultApi";
-import { Configuration } from "../api/runtime";
+import { useError } from "../(errors)/errorContext";

 const useWebRTC = (
   stream: MediaStream | null,
@@ -12,13 +12,25 @@ const useWebRTC = (
   api: DefaultApi,
 ): Peer => {
   const [peer, setPeer] = useState<Peer | null>(null);
+  const { setError } = useError();

   useEffect(() => {
     if (!stream || !transcriptId) {
       return;
     }

-    let p: Peer = new Peer({ initiator: true, stream: stream });
+    let p: Peer;
+    try {
+      p = new Peer({ initiator: true, stream: stream });
+    } catch (error) {
+      setError(error);
+      return;
+    }
+
+    p.on("error", (err) => {
+      setError(new Error(`WebRTC error: ${err}`));
+    });

     p.on("signal", (data: any) => {
       if ("sdp" in data) {
@@ -33,10 +45,14 @@ const useWebRTC = (
         api
           .v1TranscriptRecordWebrtc(requestParameters)
           .then((answer) => {
-            p.signal(answer);
+            try {
+              p.signal(answer);
+            } catch (error) {
+              setError(error);
+            }
           })
-          .catch((err) => {
-            console.error("WebRTC signaling error:", err);
+          .catch((error) => {
+            setError(error);
           });
       }
     });

View File

@@ -1,5 +1,6 @@
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { Topic, FinalSummary, Status } from "./webSocketTypes"; import { Topic, FinalSummary, Status } from "./webSocketTypes";
import { useError } from "../(errors)/errorContext";
type UseWebSockets = { type UseWebSockets = {
transcriptText: string; transcriptText: string;
@@ -15,6 +16,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
summary: "", summary: "",
}); });
const [status, setStatus] = useState<Status>({ value: "disconnected" }); const [status, setStatus] = useState<Status>({ value: "disconnected" });
const { setError } = useError();
useEffect(() => { useEffect(() => {
document.onkeyup = (e) => { document.onkeyup = (e) => {
@@ -77,41 +79,53 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
     ws.onmessage = (event) => {
       const message = JSON.parse(event.data);

-      switch (message.event) {
-        case "TRANSCRIPT":
-          if (message.data.text) {
-            setTranscriptText((message.data.text ?? "").trim());
-            console.debug("TRANSCRIPT event:", message.data);
-          }
-          break;
-
-        case "TOPIC":
-          setTopics((prevTopics) => [...prevTopics, message.data]);
-          console.debug("TOPIC event:", message.data);
-          break;
-
-        case "FINAL_SUMMARY":
-          if (message.data) {
-            setFinalSummary(message.data);
-            console.debug("FINAL_SUMMARY event:", message.data);
-          }
-          break;
-
-        case "STATUS":
-          setStatus(message.data);
-          break;
-
-        default:
-          console.error("Unknown event:", message.event);
+      try {
+        switch (message.event) {
+          case "TRANSCRIPT":
+            if (message.data.text) {
+              setTranscriptText((message.data.text ?? "").trim());
+              console.debug("TRANSCRIPT event:", message.data);
+            }
+            break;
+
+          case "TOPIC":
+            setTopics((prevTopics) => [...prevTopics, message.data]);
+            console.debug("TOPIC event:", message.data);
+            break;
+
+          case "FINAL_SUMMARY":
+            if (message.data) {
+              setFinalSummary(message.data);
+              console.debug("FINAL_SUMMARY event:", message.data);
+            }
+            break;
+
+          case "STATUS":
+            setStatus(message.data);
+            break;
+
+          default:
+            setError(
+              new Error(`Received unknown WebSocket event: ${message.event}`),
+            );
+        }
+      } catch (error) {
+        setError(error);
       }
     };

     ws.onerror = (error) => {
       console.error("WebSocket error:", error);
+      setError(new Error("A WebSocket error occurred."));
     };

-    ws.onclose = () => {
+    ws.onclose = (event) => {
       console.debug("WebSocket connection closed");
+      if (event.code !== 1000) {
+        setError(
+          new Error(`WebSocket closed unexpectedly with code: ${event.code}`),
+        );
+      }
     };

     return () => {