mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: full diarization processor implementation based on gokul app
This commit is contained in:
@@ -64,6 +64,9 @@ app.include_router(transcripts_router, prefix="/v1")
|
||||
app.include_router(user_router, prefix="/v1")
|
||||
add_pagination(app)
|
||||
|
||||
# prepare celery
|
||||
from reflector.worker import app as celery_app # noqa
|
||||
|
||||
|
||||
# simpler openapi id
|
||||
def use_route_names_as_operation_ids(app: FastAPI) -> None:
|
||||
|
||||
@@ -13,10 +13,12 @@ It is directly linked to our data model.
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
|
||||
from celery import shared_task
|
||||
from pydantic import BaseModel
|
||||
from reflector.app import app
|
||||
from reflector.db.transcripts import (
|
||||
Transcript,
|
||||
TranscriptFinalLongSummary,
|
||||
@@ -29,7 +31,7 @@ from reflector.db.transcripts import (
|
||||
from reflector.pipelines.runner import PipelineRunner
|
||||
from reflector.processors import (
|
||||
AudioChunkerProcessor,
|
||||
AudioDiarizationProcessor,
|
||||
AudioDiarizationAutoProcessor,
|
||||
AudioFileWriterProcessor,
|
||||
AudioMergeProcessor,
|
||||
AudioTranscriptAutoProcessor,
|
||||
@@ -45,6 +47,7 @@ from reflector.processors import (
|
||||
from reflector.processors.types import AudioDiarizationInput
|
||||
from reflector.processors.types import TitleSummary as TitleSummaryProcessorType
|
||||
from reflector.processors.types import Transcript as TranscriptProcessorType
|
||||
from reflector.settings import settings
|
||||
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
||||
|
||||
|
||||
@@ -174,7 +177,7 @@ class PipelineMainBase(PipelineRunner):
|
||||
async with self.transaction():
|
||||
transcript = await self.get_transcript()
|
||||
if not transcript.title:
|
||||
transcripts_controller.update(
|
||||
await transcripts_controller.update(
|
||||
transcript,
|
||||
{
|
||||
"title": final_title.title,
|
||||
@@ -238,19 +241,13 @@ class PipelineMainLive(PipelineMainBase):
|
||||
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
|
||||
AudioChunkerProcessor(),
|
||||
AudioMergeProcessor(),
|
||||
AudioTranscriptAutoProcessor.as_threaded(),
|
||||
AudioTranscriptAutoProcessor.get_instance().as_threaded(),
|
||||
TranscriptLinerProcessor(),
|
||||
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
|
||||
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
|
||||
TranscriptFinalLongSummaryProcessor.as_threaded(
|
||||
callback=self.on_long_summary
|
||||
),
|
||||
TranscriptFinalShortSummaryProcessor.as_threaded(
|
||||
callback=self.on_short_summary
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
@@ -277,7 +274,7 @@ class PipelineMainDiarization(PipelineMainBase):
|
||||
# add a customised logger to the context
|
||||
self.prepare()
|
||||
processors = [
|
||||
AudioDiarizationProcessor(),
|
||||
AudioDiarizationAutoProcessor.get_instance(callback=self.on_topic),
|
||||
BroadcastProcessor(
|
||||
processors=[
|
||||
TranscriptFinalLongSummaryProcessor.as_threaded(
|
||||
@@ -307,8 +304,19 @@ class PipelineMainDiarization(PipelineMainBase):
|
||||
for topic in transcript.topics
|
||||
]
|
||||
|
||||
# we need to create an url to be used for diarization
|
||||
# we can't use the audio_mp3_filename because it's not accessible
|
||||
# from the diarization processor
|
||||
from reflector.views.transcripts import create_access_token
|
||||
|
||||
token = create_access_token(
|
||||
{"sub": transcript.user_id},
|
||||
expires_delta=timedelta(minutes=15),
|
||||
)
|
||||
path = app.url_path_for("transcript_get_audio_mp3", transcript_id=transcript.id)
|
||||
url = f"{settings.BASE_URL}{path}?token={token}"
|
||||
audio_diarization_input = AudioDiarizationInput(
|
||||
audio_filename=transcript.audio_mp3_filename,
|
||||
audio_url=url,
|
||||
topics=topics,
|
||||
)
|
||||
|
||||
|
||||
@@ -55,7 +55,11 @@ class PipelineRunner(BaseModel):
|
||||
"""
|
||||
Start the pipeline synchronously (for non-asyncio apps)
|
||||
"""
|
||||
asyncio.run(self.run())
|
||||
loop = asyncio.get_event_loop()
|
||||
if not loop:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
loop.run_until_complete(self.run())
|
||||
|
||||
def push(self, data):
|
||||
"""
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from .audio_chunker import AudioChunkerProcessor # noqa: F401
|
||||
from .audio_diarization import AudioDiarizationProcessor # noqa: F401
|
||||
from .audio_diarization_auto import AudioDiarizationAutoProcessor # noqa: F401
|
||||
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
|
||||
from .audio_merge import AudioMergeProcessor # noqa: F401
|
||||
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import AudioDiarizationInput, TitleSummary
|
||||
|
||||
|
||||
class AudioDiarizationProcessor(Processor):
|
||||
INPUT_TYPE = AudioDiarizationInput
|
||||
OUTPUT_TYPE = TitleSummary
|
||||
|
||||
async def _push(self, data: AudioDiarizationInput):
|
||||
# Gather diarization data
|
||||
diarization = [
|
||||
{"start": 0.0, "stop": 4.9, "speaker": 2},
|
||||
{"start": 5.6, "stop": 6.7, "speaker": 2},
|
||||
{"start": 7.3, "stop": 8.9, "speaker": 2},
|
||||
{"start": 7.3, "stop": 7.9, "speaker": 0},
|
||||
{"start": 9.4, "stop": 11.2, "speaker": 2},
|
||||
{"start": 9.7, "stop": 10.0, "speaker": 0},
|
||||
{"start": 10.0, "stop": 10.1, "speaker": 0},
|
||||
{"start": 11.7, "stop": 16.1, "speaker": 2},
|
||||
{"start": 11.8, "stop": 12.1, "speaker": 1},
|
||||
{"start": 16.4, "stop": 21.0, "speaker": 2},
|
||||
{"start": 21.1, "stop": 22.6, "speaker": 2},
|
||||
{"start": 24.7, "stop": 31.9, "speaker": 2},
|
||||
{"start": 32.0, "stop": 32.8, "speaker": 1},
|
||||
{"start": 33.4, "stop": 37.8, "speaker": 2},
|
||||
{"start": 37.9, "stop": 40.3, "speaker": 0},
|
||||
{"start": 39.2, "stop": 40.4, "speaker": 2},
|
||||
{"start": 40.7, "stop": 41.4, "speaker": 0},
|
||||
{"start": 41.6, "stop": 45.7, "speaker": 2},
|
||||
{"start": 46.4, "stop": 53.1, "speaker": 2},
|
||||
{"start": 53.6, "stop": 56.5, "speaker": 2},
|
||||
{"start": 54.9, "stop": 75.4, "speaker": 1},
|
||||
{"start": 57.3, "stop": 58.0, "speaker": 2},
|
||||
{"start": 65.7, "stop": 66.0, "speaker": 2},
|
||||
{"start": 75.8, "stop": 78.8, "speaker": 1},
|
||||
{"start": 79.0, "stop": 82.6, "speaker": 1},
|
||||
{"start": 83.2, "stop": 83.3, "speaker": 1},
|
||||
{"start": 84.5, "stop": 94.3, "speaker": 1},
|
||||
{"start": 95.1, "stop": 100.7, "speaker": 1},
|
||||
{"start": 100.7, "stop": 102.0, "speaker": 0},
|
||||
{"start": 100.7, "stop": 101.8, "speaker": 1},
|
||||
{"start": 102.0, "stop": 103.0, "speaker": 1},
|
||||
{"start": 103.0, "stop": 103.7, "speaker": 0},
|
||||
{"start": 103.7, "stop": 103.8, "speaker": 1},
|
||||
{"start": 103.8, "stop": 113.9, "speaker": 0},
|
||||
{"start": 114.7, "stop": 117.0, "speaker": 0},
|
||||
{"start": 117.0, "stop": 117.4, "speaker": 1},
|
||||
]
|
||||
|
||||
# now reapply speaker to topics (if any)
|
||||
# topics is a list[BaseModel] with an attribute words
|
||||
# words is a list[BaseModel] with text, start and speaker attribute
|
||||
|
||||
print("IN DIARIZATION PROCESSOR", data)
|
||||
|
||||
# mutate in place
|
||||
for topic in data.topics:
|
||||
for word in topic.transcript.words:
|
||||
for d in diarization:
|
||||
if d["start"] <= word.start <= d["stop"]:
|
||||
word.speaker = d["speaker"]
|
||||
|
||||
# emit them
|
||||
for topic in data.topics:
|
||||
await self.emit(topic)
|
||||
34
server/reflector/processors/audio_diarization_auto.py
Normal file
34
server/reflector/processors/audio_diarization_auto.py
Normal file
@@ -0,0 +1,34 @@
|
||||
import importlib
|
||||
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioDiarizationAutoProcessor(Processor):
|
||||
_registry = {}
|
||||
|
||||
@classmethod
|
||||
def register(cls, name, kclass):
|
||||
cls._registry[name] = kclass
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, name: str | None = None, **kwargs):
|
||||
if name is None:
|
||||
name = settings.DIARIZATION_BACKEND
|
||||
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.processors.audio_diarization_{name}"
|
||||
importlib.import_module(module_name)
|
||||
|
||||
# gather specific configuration for the processor
|
||||
# search `DIARIZATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
|
||||
config = {}
|
||||
name_upper = name.upper()
|
||||
settings_prefix = "DIARIZATION_"
|
||||
config_prefix = f"{settings_prefix}{name_upper}_"
|
||||
for key, value in settings:
|
||||
if key.startswith(config_prefix):
|
||||
config_name = key[len(settings_prefix) :].lower()
|
||||
config[config_name] = value
|
||||
|
||||
return cls._registry[name](**config | kwargs)
|
||||
28
server/reflector/processors/audio_diarization_base.py
Normal file
28
server/reflector/processors/audio_diarization_base.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from reflector.processors.base import Processor
|
||||
from reflector.processors.types import AudioDiarizationInput, TitleSummary
|
||||
|
||||
|
||||
class AudioDiarizationBaseProcessor(Processor):
|
||||
INPUT_TYPE = AudioDiarizationInput
|
||||
OUTPUT_TYPE = TitleSummary
|
||||
|
||||
async def _push(self, data: AudioDiarizationInput):
|
||||
diarization = await self._diarize(data)
|
||||
|
||||
# now reapply speaker to topics (if any)
|
||||
# topics is a list[BaseModel] with an attribute words
|
||||
# words is a list[BaseModel] with text, start and speaker attribute
|
||||
|
||||
# mutate in place
|
||||
for topic in data.topics:
|
||||
for word in topic.transcript.words:
|
||||
for d in diarization:
|
||||
if d["start"] <= word.start <= d["end"]:
|
||||
word.speaker = d["speaker"]
|
||||
|
||||
# emit them
|
||||
for topic in data.topics:
|
||||
await self.emit(topic)
|
||||
|
||||
async def _diarize(self, data: AudioDiarizationInput):
|
||||
raise NotImplementedError
|
||||
36
server/reflector/processors/audio_diarization_modal.py
Normal file
36
server/reflector/processors/audio_diarization_modal.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import httpx
|
||||
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
|
||||
from reflector.processors.audio_diarization_base import AudioDiarizationBaseProcessor
|
||||
from reflector.processors.types import AudioDiarizationInput, TitleSummary
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
class AudioDiarizationModalProcessor(AudioDiarizationBaseProcessor):
|
||||
INPUT_TYPE = AudioDiarizationInput
|
||||
OUTPUT_TYPE = TitleSummary
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.diarization_url = settings.DIARIZATION_URL + "/diarize"
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
|
||||
}
|
||||
|
||||
async def _diarize(self, data: AudioDiarizationInput):
|
||||
# Gather diarization data
|
||||
params = {
|
||||
"audio_file_url": data.audio_url,
|
||||
"timestamp": 0,
|
||||
}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self.diarization_url,
|
||||
headers=self.headers,
|
||||
params=params,
|
||||
timeout=None,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()["text"]
|
||||
|
||||
|
||||
AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor)
|
||||
@@ -1,8 +1,6 @@
|
||||
import importlib
|
||||
|
||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||
from reflector.processors.base import Pipeline, Processor
|
||||
from reflector.processors.types import AudioFile
|
||||
from reflector.settings import settings
|
||||
|
||||
|
||||
@@ -14,7 +12,9 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
|
||||
cls._registry[name] = kclass
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, name):
|
||||
def get_instance(cls, name: str | None = None, **kwargs):
|
||||
if name is None:
|
||||
name = settings.TRANSCRIPT_BACKEND
|
||||
if name not in cls._registry:
|
||||
module_name = f"reflector.processors.audio_transcript_{name}"
|
||||
importlib.import_module(module_name)
|
||||
@@ -30,30 +30,4 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
|
||||
config_name = key[len(settings_prefix) :].lower()
|
||||
config[config_name] = value
|
||||
|
||||
return cls._registry[name](**config)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def set_pipeline(self, pipeline: Pipeline):
|
||||
super().set_pipeline(pipeline)
|
||||
self.processor.set_pipeline(pipeline)
|
||||
|
||||
def connect(self, processor: Processor):
|
||||
self.processor.connect(processor)
|
||||
|
||||
def disconnect(self, processor: Processor):
|
||||
self.processor.disconnect(processor)
|
||||
|
||||
def on(self, callback):
|
||||
self.processor.on(callback)
|
||||
|
||||
def off(self, callback):
|
||||
self.processor.off(callback)
|
||||
|
||||
async def _push(self, data: AudioFile):
|
||||
return await self.processor._push(data)
|
||||
|
||||
async def _flush(self):
|
||||
return await self.processor._flush()
|
||||
return cls._registry[name](**config | kwargs)
|
||||
|
||||
@@ -385,5 +385,5 @@ class TranslationLanguages(BaseModel):
|
||||
|
||||
|
||||
class AudioDiarizationInput(BaseModel):
|
||||
audio_filename: Path
|
||||
audio_url: str
|
||||
topics: list[TitleSummary]
|
||||
|
||||
@@ -89,6 +89,10 @@ class Settings(BaseSettings):
|
||||
# LLM Modal configuration
|
||||
LLM_MODAL_API_KEY: str | None = None
|
||||
|
||||
# Diarization
|
||||
DIARIZATION_BACKEND: str = "modal"
|
||||
DIARIZATION_URL: str | None = None
|
||||
|
||||
# Sentry
|
||||
SENTRY_DSN: str | None = None
|
||||
|
||||
@@ -121,5 +125,11 @@ class Settings(BaseSettings):
|
||||
REDIS_HOST: str = "localhost"
|
||||
REDIS_PORT: int = 6379
|
||||
|
||||
# Secret key
|
||||
SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21"
|
||||
|
||||
# Current hosting/domain
|
||||
BASE_URL: str = "http://localhost:1250"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
14
server/reflector/tools/start_post_main_live_pipeline.py
Normal file
14
server/reflector/tools/start_post_main_live_pipeline.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import argparse
|
||||
|
||||
from reflector.app import celery_app # noqa
|
||||
from reflector.pipelines.main_live_pipeline import task_pipeline_main_post
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("transcript_id", type=str)
|
||||
parser.add_argument("--delay", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.delay:
|
||||
task_pipeline_main_post.delay(args.transcript_id)
|
||||
else:
|
||||
task_pipeline_main_post(args.transcript_id)
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated, Optional
|
||||
|
||||
import reflector.auth as auth
|
||||
@@ -9,8 +9,10 @@ from fastapi import (
|
||||
Request,
|
||||
WebSocket,
|
||||
WebSocketDisconnect,
|
||||
status,
|
||||
)
|
||||
from fastapi_pagination import Page, paginate
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field
|
||||
from reflector.db.transcripts import (
|
||||
AudioWaveform,
|
||||
@@ -27,6 +29,18 @@ from .rtc_offer import RtcOffer, rtc_offer_base
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
ALGORITHM = "HS256"
|
||||
DOWNLOAD_EXPIRE_MINUTES = 60
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: timedelta):
|
||||
to_encode = data.copy()
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
# ==============================================================
|
||||
# Transcripts list
|
||||
# ==============================================================
|
||||
@@ -198,8 +212,21 @@ async def transcript_get_audio_mp3(
|
||||
request: Request,
|
||||
transcript_id: str,
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
token: str | None = None,
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
if not user_id and token:
|
||||
unauthorized_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired token",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
|
||||
user_id: str = payload.get("sub")
|
||||
except jwt.JWTError:
|
||||
raise unauthorized_exception
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
@@ -4,6 +4,7 @@ from reflector.settings import settings
|
||||
app = Celery(__name__)
|
||||
app.conf.broker_url = settings.CELERY_BROKER_URL
|
||||
app.conf.result_backend = settings.CELERY_RESULT_BACKEND
|
||||
app.conf.broker_connection_retry_on_startup = True
|
||||
app.autodiscover_tasks(
|
||||
[
|
||||
"reflector.pipelines.main_live_pipeline",
|
||||
|
||||
@@ -102,6 +102,7 @@ async def test_transcript_rtc_and_websocket(
|
||||
print("Test websocket: DISCONNECTED")
|
||||
|
||||
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
|
||||
print("Test websocket: TASK CREATED", websocket_task)
|
||||
|
||||
# create stream client
|
||||
import argparse
|
||||
@@ -243,6 +244,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
||||
print("Test websocket: DISCONNECTED")
|
||||
|
||||
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
|
||||
print("Test websocket: TASK CREATED", websocket_task)
|
||||
|
||||
# create stream client
|
||||
import argparse
|
||||
|
||||
Reference in New Issue
Block a user