server: full diarization processor implementation based on gokul app

This commit is contained in:
2023-10-27 20:00:07 +02:00
committed by Mathieu Virbel
parent 07c4d080c2
commit d8a842f099
15 changed files with 186 additions and 110 deletions

View File

@@ -64,6 +64,9 @@ app.include_router(transcripts_router, prefix="/v1")
app.include_router(user_router, prefix="/v1") app.include_router(user_router, prefix="/v1")
add_pagination(app) add_pagination(app)
# prepare celery
from reflector.worker import app as celery_app # noqa
# simpler openapi id # simpler openapi id
def use_route_names_as_operation_ids(app: FastAPI) -> None: def use_route_names_as_operation_ids(app: FastAPI) -> None:

View File

@@ -13,10 +13,12 @@ It is directly linked to our data model.
import asyncio import asyncio
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path from pathlib import Path
from celery import shared_task from celery import shared_task
from pydantic import BaseModel from pydantic import BaseModel
from reflector.app import app
from reflector.db.transcripts import ( from reflector.db.transcripts import (
Transcript, Transcript,
TranscriptFinalLongSummary, TranscriptFinalLongSummary,
@@ -29,7 +31,7 @@ from reflector.db.transcripts import (
from reflector.pipelines.runner import PipelineRunner from reflector.pipelines.runner import PipelineRunner
from reflector.processors import ( from reflector.processors import (
AudioChunkerProcessor, AudioChunkerProcessor,
AudioDiarizationProcessor, AudioDiarizationAutoProcessor,
AudioFileWriterProcessor, AudioFileWriterProcessor,
AudioMergeProcessor, AudioMergeProcessor,
AudioTranscriptAutoProcessor, AudioTranscriptAutoProcessor,
@@ -45,6 +47,7 @@ from reflector.processors import (
from reflector.processors.types import AudioDiarizationInput from reflector.processors.types import AudioDiarizationInput
from reflector.processors.types import TitleSummary as TitleSummaryProcessorType from reflector.processors.types import TitleSummary as TitleSummaryProcessorType
from reflector.processors.types import Transcript as TranscriptProcessorType from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.settings import settings
from reflector.ws_manager import WebsocketManager, get_ws_manager from reflector.ws_manager import WebsocketManager, get_ws_manager
@@ -174,7 +177,7 @@ class PipelineMainBase(PipelineRunner):
async with self.transaction(): async with self.transaction():
transcript = await self.get_transcript() transcript = await self.get_transcript()
if not transcript.title: if not transcript.title:
transcripts_controller.update( await transcripts_controller.update(
transcript, transcript,
{ {
"title": final_title.title, "title": final_title.title,
@@ -238,19 +241,13 @@ class PipelineMainLive(PipelineMainBase):
AudioFileWriterProcessor(path=transcript.audio_mp3_filename), AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
AudioChunkerProcessor(), AudioChunkerProcessor(),
AudioMergeProcessor(), AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(), AudioTranscriptAutoProcessor.get_instance().as_threaded(),
TranscriptLinerProcessor(), TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript), TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic), TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
BroadcastProcessor( BroadcastProcessor(
processors=[ processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
TranscriptFinalLongSummaryProcessor.as_threaded(
callback=self.on_long_summary
),
TranscriptFinalShortSummaryProcessor.as_threaded(
callback=self.on_short_summary
),
] ]
), ),
] ]
@@ -277,7 +274,7 @@ class PipelineMainDiarization(PipelineMainBase):
# add a customised logger to the context # add a customised logger to the context
self.prepare() self.prepare()
processors = [ processors = [
AudioDiarizationProcessor(), AudioDiarizationAutoProcessor.get_instance(callback=self.on_topic),
BroadcastProcessor( BroadcastProcessor(
processors=[ processors=[
TranscriptFinalLongSummaryProcessor.as_threaded( TranscriptFinalLongSummaryProcessor.as_threaded(
@@ -307,8 +304,19 @@ class PipelineMainDiarization(PipelineMainBase):
for topic in transcript.topics for topic in transcript.topics
] ]
# we need to create an url to be used for diarization
# we can't use the audio_mp3_filename because it's not accessible
# from the diarization processor
from reflector.views.transcripts import create_access_token
token = create_access_token(
{"sub": transcript.user_id},
expires_delta=timedelta(minutes=15),
)
path = app.url_path_for("transcript_get_audio_mp3", transcript_id=transcript.id)
url = f"{settings.BASE_URL}{path}?token={token}"
audio_diarization_input = AudioDiarizationInput( audio_diarization_input = AudioDiarizationInput(
audio_filename=transcript.audio_mp3_filename, audio_url=url,
topics=topics, topics=topics,
) )

View File

@@ -55,7 +55,11 @@ class PipelineRunner(BaseModel):
""" """
Start the pipeline synchronously (for non-asyncio apps) Start the pipeline synchronously (for non-asyncio apps)
""" """
asyncio.run(self.run()) loop = asyncio.get_event_loop()
if not loop:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self.run())
def push(self, data): def push(self, data):
""" """

View File

@@ -1,5 +1,5 @@
from .audio_chunker import AudioChunkerProcessor # noqa: F401 from .audio_chunker import AudioChunkerProcessor # noqa: F401
from .audio_diarization import AudioDiarizationProcessor # noqa: F401 from .audio_diarization_auto import AudioDiarizationAutoProcessor # noqa: F401
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401 from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
from .audio_merge import AudioMergeProcessor # noqa: F401 from .audio_merge import AudioMergeProcessor # noqa: F401
from .audio_transcript import AudioTranscriptProcessor # noqa: F401 from .audio_transcript import AudioTranscriptProcessor # noqa: F401

View File

@@ -1,65 +0,0 @@
from reflector.processors.base import Processor
from reflector.processors.types import AudioDiarizationInput, TitleSummary
class AudioDiarizationProcessor(Processor):
    """Assign a speaker index to each transcript word, then re-emit topics.

    NOTE(review): the diarization segments below are a hard-coded mock,
    presumably captured from one sample recording — a real backend should
    supply this data. Fix applied: removed a leftover debug ``print``.
    """

    INPUT_TYPE = AudioDiarizationInput
    OUTPUT_TYPE = TitleSummary

    async def _push(self, data: AudioDiarizationInput):
        # Mock diarization data: segments with start/stop times (seconds)
        # and a detected speaker index. Segments may overlap; the matching
        # loop below lets the last overlapping segment win.
        diarization = [
            {"start": 0.0, "stop": 4.9, "speaker": 2},
            {"start": 5.6, "stop": 6.7, "speaker": 2},
            {"start": 7.3, "stop": 8.9, "speaker": 2},
            {"start": 7.3, "stop": 7.9, "speaker": 0},
            {"start": 9.4, "stop": 11.2, "speaker": 2},
            {"start": 9.7, "stop": 10.0, "speaker": 0},
            {"start": 10.0, "stop": 10.1, "speaker": 0},
            {"start": 11.7, "stop": 16.1, "speaker": 2},
            {"start": 11.8, "stop": 12.1, "speaker": 1},
            {"start": 16.4, "stop": 21.0, "speaker": 2},
            {"start": 21.1, "stop": 22.6, "speaker": 2},
            {"start": 24.7, "stop": 31.9, "speaker": 2},
            {"start": 32.0, "stop": 32.8, "speaker": 1},
            {"start": 33.4, "stop": 37.8, "speaker": 2},
            {"start": 37.9, "stop": 40.3, "speaker": 0},
            {"start": 39.2, "stop": 40.4, "speaker": 2},
            {"start": 40.7, "stop": 41.4, "speaker": 0},
            {"start": 41.6, "stop": 45.7, "speaker": 2},
            {"start": 46.4, "stop": 53.1, "speaker": 2},
            {"start": 53.6, "stop": 56.5, "speaker": 2},
            {"start": 54.9, "stop": 75.4, "speaker": 1},
            {"start": 57.3, "stop": 58.0, "speaker": 2},
            {"start": 65.7, "stop": 66.0, "speaker": 2},
            {"start": 75.8, "stop": 78.8, "speaker": 1},
            {"start": 79.0, "stop": 82.6, "speaker": 1},
            {"start": 83.2, "stop": 83.3, "speaker": 1},
            {"start": 84.5, "stop": 94.3, "speaker": 1},
            {"start": 95.1, "stop": 100.7, "speaker": 1},
            {"start": 100.7, "stop": 102.0, "speaker": 0},
            {"start": 100.7, "stop": 101.8, "speaker": 1},
            {"start": 102.0, "stop": 103.0, "speaker": 1},
            {"start": 103.0, "stop": 103.7, "speaker": 0},
            {"start": 103.7, "stop": 103.8, "speaker": 1},
            {"start": 103.8, "stop": 113.9, "speaker": 0},
            {"start": 114.7, "stop": 117.0, "speaker": 0},
            {"start": 117.0, "stop": 117.4, "speaker": 1},
        ]
        # Re-apply speakers to topics (if any).
        # topics is a list[BaseModel] with an attribute words;
        # words is a list[BaseModel] with text, start and speaker attributes.
        # Mutate in place: a word takes the speaker of the LAST segment
        # whose [start, stop] interval contains word.start.
        for topic in data.topics:
            for word in topic.transcript.words:
                for d in diarization:
                    if d["start"] <= word.start <= d["stop"]:
                        word.speaker = d["speaker"]
        # Emit updated topics downstream only after all were re-labelled
        for topic in data.topics:
            await self.emit(topic)

View File

@@ -0,0 +1,34 @@
import importlib
from reflector.processors.base import Processor
from reflector.settings import settings
class AudioDiarizationAutoProcessor(Processor):
    """Factory/registry for concrete diarization backends.

    Backends are selected by name (defaulting to
    ``settings.DIARIZATION_BACKEND``) and lazily imported from
    ``reflector.processors.audio_diarization_<name>``, which is expected to
    call :meth:`register` at import time.
    """

    # Maps backend name -> concrete processor class
    _registry = {}

    @classmethod
    def register(cls, name, kclass):
        # Called by each backend module at import time
        cls._registry[name] = kclass

    @classmethod
    def get_instance(cls, name: str | None = None, **kwargs):
        """Instantiate the backend ``name`` with settings-derived config.

        Settings keys matching ``DIARIZATION_<NAME>_YYY`` are forwarded to
        the backend constructor. Note: only the ``DIARIZATION_`` prefix is
        stripped, so e.g. ``DIARIZATION_MODAL_URL`` arrives as kwarg
        ``modal_url`` (NOT ``url`` / ``backend_url``) — backends must accept
        that naming. Explicit ``kwargs`` override settings-derived values.
        """
        if name is None:
            name = settings.DIARIZATION_BACKEND
        if name not in cls._registry:
            # Importing the module triggers its register() side effect;
            # a KeyError below means the module did not register itself.
            module_name = f"reflector.processors.audio_diarization_{name}"
            importlib.import_module(module_name)
        # gather specific configuration for the processor:
        # search `DIARIZATION_<NAME>_YYY`, push to constructor as `<name>_yyy`
        config = {}
        name_upper = name.upper()
        settings_prefix = "DIARIZATION_"
        config_prefix = f"{settings_prefix}{name_upper}_"
        # pydantic settings iterate as (key, value) pairs
        for key, value in settings:
            if key.startswith(config_prefix):
                config_name = key[len(settings_prefix) :].lower()
                config[config_name] = value
        # dict union: kwargs (right-hand side) win over settings config
        return cls._registry[name](**config | kwargs)

View File

@@ -0,0 +1,28 @@
from reflector.processors.base import Processor
from reflector.processors.types import AudioDiarizationInput, TitleSummary
class AudioDiarizationBaseProcessor(Processor):
    """Common speaker-assignment logic shared by diarization backends.

    Subclasses implement ``_diarize`` to return a list of segment dicts
    with ``start``/``end`` times (seconds) and a ``speaker`` index.
    """

    INPUT_TYPE = AudioDiarizationInput
    OUTPUT_TYPE = TitleSummary

    async def _push(self, data: AudioDiarizationInput):
        segments = await self._diarize(data)
        # Re-apply speakers to topics (if any), mutating in place.
        # topics is a list[BaseModel] with an attribute words;
        # words is a list[BaseModel] with text, start and speaker attributes.
        # Scanning the segments back-to-front and stopping at the first hit
        # gives the same result as "last overlapping segment wins".
        for topic in data.topics:
            for word in topic.transcript.words:
                for segment in reversed(segments):
                    if segment["start"] <= word.start <= segment["end"]:
                        word.speaker = segment["speaker"]
                        break
        # Only emit once every topic has been re-labelled
        for topic in data.topics:
            await self.emit(topic)

    async def _diarize(self, data: AudioDiarizationInput):
        """Return diarization segments for ``data`` — backend-specific."""
        raise NotImplementedError

View File

@@ -0,0 +1,36 @@
import httpx
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
from reflector.processors.audio_diarization_base import AudioDiarizationBaseProcessor
from reflector.processors.types import AudioDiarizationInput, TitleSummary
from reflector.settings import settings
class AudioDiarizationModalProcessor(AudioDiarizationBaseProcessor):
    """Diarization backend delegating to a remote Modal HTTP endpoint."""

    INPUT_TYPE = AudioDiarizationInput
    OUTPUT_TYPE = TitleSummary

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Endpoint and credentials come from global settings
        self.diarization_url = settings.DIARIZATION_URL + "/diarize"
        self.headers = {
            "Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
        }

    async def _diarize(self, data: AudioDiarizationInput):
        # The remote service downloads the audio itself from the given URL
        query = {
            "audio_file_url": data.audio_url,
            "timestamp": 0,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.diarization_url,
                headers=self.headers,
                params=query,
                # no timeout: diarizing long recordings can take a while
                timeout=None,
            )
            response.raise_for_status()
        # NOTE(review): payload is expected under the "text" key — confirm
        # against the Modal endpoint contract
        return response.json()["text"]


AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor)

View File

@@ -1,8 +1,6 @@
import importlib import importlib
from reflector.processors.audio_transcript import AudioTranscriptProcessor from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.base import Pipeline, Processor
from reflector.processors.types import AudioFile
from reflector.settings import settings from reflector.settings import settings
@@ -14,7 +12,9 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
cls._registry[name] = kclass cls._registry[name] = kclass
@classmethod @classmethod
def get_instance(cls, name): def get_instance(cls, name: str | None = None, **kwargs):
if name is None:
name = settings.TRANSCRIPT_BACKEND
if name not in cls._registry: if name not in cls._registry:
module_name = f"reflector.processors.audio_transcript_{name}" module_name = f"reflector.processors.audio_transcript_{name}"
importlib.import_module(module_name) importlib.import_module(module_name)
@@ -30,30 +30,4 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
config_name = key[len(settings_prefix) :].lower() config_name = key[len(settings_prefix) :].lower()
config[config_name] = value config[config_name] = value
return cls._registry[name](**config) return cls._registry[name](**config | kwargs)
def __init__(self, **kwargs):
self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
super().__init__(**kwargs)
def set_pipeline(self, pipeline: Pipeline):
super().set_pipeline(pipeline)
self.processor.set_pipeline(pipeline)
def connect(self, processor: Processor):
self.processor.connect(processor)
def disconnect(self, processor: Processor):
self.processor.disconnect(processor)
def on(self, callback):
self.processor.on(callback)
def off(self, callback):
self.processor.off(callback)
async def _push(self, data: AudioFile):
return await self.processor._push(data)
async def _flush(self):
return await self.processor._flush()

View File

@@ -385,5 +385,5 @@ class TranslationLanguages(BaseModel):
class AudioDiarizationInput(BaseModel): class AudioDiarizationInput(BaseModel):
audio_filename: Path audio_url: str
topics: list[TitleSummary] topics: list[TitleSummary]

View File

@@ -89,6 +89,10 @@ class Settings(BaseSettings):
# LLM Modal configuration # LLM Modal configuration
LLM_MODAL_API_KEY: str | None = None LLM_MODAL_API_KEY: str | None = None
# Diarization
DIARIZATION_BACKEND: str = "modal"
DIARIZATION_URL: str | None = None
# Sentry # Sentry
SENTRY_DSN: str | None = None SENTRY_DSN: str | None = None
@@ -121,5 +125,11 @@ class Settings(BaseSettings):
REDIS_HOST: str = "localhost" REDIS_HOST: str = "localhost"
REDIS_PORT: int = 6379 REDIS_PORT: int = 6379
# Secret key
SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21"
# Current hosting/domain
BASE_URL: str = "http://localhost:1250"
settings = Settings() settings = Settings()

View File

@@ -0,0 +1,14 @@
"""CLI helper: run (or enqueue) the post pipeline for one transcript.

With ``--delay`` the task is queued on the celery worker; otherwise it runs
inline in this process.
"""
import argparse

from reflector.app import celery_app  # noqa
from reflector.pipelines.main_live_pipeline import task_pipeline_main_post


def main() -> None:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("transcript_id", type=str)
    parser.add_argument("--delay", action="store_true")
    args = parser.parse_args()
    if args.delay:
        # enqueue on celery instead of executing synchronously
        task_pipeline_main_post.delay(args.transcript_id)
    else:
        task_pipeline_main_post(args.transcript_id)


# Guard so importing this module has no side effects (argparse used to run
# at import time, which would crash any importer without CLI args)
if __name__ == "__main__":
    main()

View File

@@ -1,4 +1,4 @@
from datetime import datetime from datetime import datetime, timedelta
from typing import Annotated, Optional from typing import Annotated, Optional
import reflector.auth as auth import reflector.auth as auth
@@ -9,8 +9,10 @@ from fastapi import (
Request, Request,
WebSocket, WebSocket,
WebSocketDisconnect, WebSocketDisconnect,
status,
) )
from fastapi_pagination import Page, paginate from fastapi_pagination import Page, paginate
from jose import jwt
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from reflector.db.transcripts import ( from reflector.db.transcripts import (
AudioWaveform, AudioWaveform,
@@ -27,6 +29,18 @@ from .rtc_offer import RtcOffer, rtc_offer_base
router = APIRouter() router = APIRouter()
ALGORITHM = "HS256"
DOWNLOAD_EXPIRE_MINUTES = 60
def create_access_token(data: dict, expires_delta: timedelta) -> str:
    """Return a signed JWT carrying ``data`` plus an ``exp`` claim.

    Used to mint short-lived tokens (e.g. for the mp3 download endpoint) so
    a non-interactive consumer can fetch a protected resource.
    """
    # local import keeps this fix self-contained in the function
    from datetime import timezone

    to_encode = data.copy()
    # aware UTC datetime: datetime.utcnow() is naive and deprecated (3.12+);
    # jose converts aware datetimes to the correct numeric exp claim
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt
# ============================================================== # ==============================================================
# Transcripts list # Transcripts list
# ============================================================== # ==============================================================
@@ -198,8 +212,21 @@ async def transcript_get_audio_mp3(
request: Request, request: Request,
transcript_id: str, transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
token: str | None = None,
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
if not user_id and token:
unauthorized_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get("sub")
except jwt.JWTError:
raise unauthorized_exception
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")

View File

@@ -4,6 +4,7 @@ from reflector.settings import settings
app = Celery(__name__) app = Celery(__name__)
app.conf.broker_url = settings.CELERY_BROKER_URL app.conf.broker_url = settings.CELERY_BROKER_URL
app.conf.result_backend = settings.CELERY_RESULT_BACKEND app.conf.result_backend = settings.CELERY_RESULT_BACKEND
app.conf.broker_connection_retry_on_startup = True
app.autodiscover_tasks( app.autodiscover_tasks(
[ [
"reflector.pipelines.main_live_pipeline", "reflector.pipelines.main_live_pipeline",

View File

@@ -102,6 +102,7 @@ async def test_transcript_rtc_and_websocket(
print("Test websocket: DISCONNECTED") print("Test websocket: DISCONNECTED")
websocket_task = asyncio.get_event_loop().create_task(websocket_task()) websocket_task = asyncio.get_event_loop().create_task(websocket_task())
print("Test websocket: TASK CREATED", websocket_task)
# create stream client # create stream client
import argparse import argparse
@@ -243,6 +244,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
print("Test websocket: DISCONNECTED") print("Test websocket: DISCONNECTED")
websocket_task = asyncio.get_event_loop().create_task(websocket_task()) websocket_task = asyncio.get_event_loop().create_task(websocket_task())
print("Test websocket: TASK CREATED", websocket_task)
# create stream client # create stream client
import argparse import argparse