server: full diarization processor implementation based on gokul app

This commit is contained in:
2023-10-27 20:00:07 +02:00
committed by Mathieu Virbel
parent 07c4d080c2
commit d8a842f099
15 changed files with 186 additions and 110 deletions

View File

@@ -64,6 +64,9 @@ app.include_router(transcripts_router, prefix="/v1")
app.include_router(user_router, prefix="/v1")
add_pagination(app)
# prepare celery
from reflector.worker import app as celery_app # noqa
# simpler openapi id
def use_route_names_as_operation_ids(app: FastAPI) -> None:

View File

@@ -13,10 +13,12 @@ It is directly linked to our data model.
import asyncio
from contextlib import asynccontextmanager
from datetime import timedelta
from pathlib import Path
from celery import shared_task
from pydantic import BaseModel
from reflector.app import app
from reflector.db.transcripts import (
Transcript,
TranscriptFinalLongSummary,
@@ -29,7 +31,7 @@ from reflector.db.transcripts import (
from reflector.pipelines.runner import PipelineRunner
from reflector.processors import (
AudioChunkerProcessor,
AudioDiarizationProcessor,
AudioDiarizationAutoProcessor,
AudioFileWriterProcessor,
AudioMergeProcessor,
AudioTranscriptAutoProcessor,
@@ -45,6 +47,7 @@ from reflector.processors import (
from reflector.processors.types import AudioDiarizationInput
from reflector.processors.types import TitleSummary as TitleSummaryProcessorType
from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.settings import settings
from reflector.ws_manager import WebsocketManager, get_ws_manager
@@ -174,7 +177,7 @@ class PipelineMainBase(PipelineRunner):
async with self.transaction():
transcript = await self.get_transcript()
if not transcript.title:
transcripts_controller.update(
await transcripts_controller.update(
transcript,
{
"title": final_title.title,
@@ -238,19 +241,13 @@ class PipelineMainLive(PipelineMainBase):
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
AudioChunkerProcessor(),
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
AudioTranscriptAutoProcessor.get_instance().as_threaded(),
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
BroadcastProcessor(
processors=[
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
TranscriptFinalLongSummaryProcessor.as_threaded(
callback=self.on_long_summary
),
TranscriptFinalShortSummaryProcessor.as_threaded(
callback=self.on_short_summary
),
]
),
]
@@ -277,7 +274,7 @@ class PipelineMainDiarization(PipelineMainBase):
# add a customised logger to the context
self.prepare()
processors = [
AudioDiarizationProcessor(),
AudioDiarizationAutoProcessor.get_instance(callback=self.on_topic),
BroadcastProcessor(
processors=[
TranscriptFinalLongSummaryProcessor.as_threaded(
@@ -307,8 +304,19 @@ class PipelineMainDiarization(PipelineMainBase):
for topic in transcript.topics
]
# we need to create an url to be used for diarization
# we can't use the audio_mp3_filename because it's not accessible
# from the diarization processor
from reflector.views.transcripts import create_access_token
token = create_access_token(
{"sub": transcript.user_id},
expires_delta=timedelta(minutes=15),
)
path = app.url_path_for("transcript_get_audio_mp3", transcript_id=transcript.id)
url = f"{settings.BASE_URL}{path}?token={token}"
audio_diarization_input = AudioDiarizationInput(
audio_filename=transcript.audio_mp3_filename,
audio_url=url,
topics=topics,
)

View File

@@ -55,7 +55,11 @@ class PipelineRunner(BaseModel):
"""
Start the pipeline synchronously (for non-asyncio apps)
"""
asyncio.run(self.run())
loop = asyncio.get_event_loop()
if not loop:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self.run())
def push(self, data):
"""

View File

@@ -1,5 +1,5 @@
from .audio_chunker import AudioChunkerProcessor # noqa: F401
from .audio_diarization import AudioDiarizationProcessor # noqa: F401
from .audio_diarization_auto import AudioDiarizationAutoProcessor # noqa: F401
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
from .audio_merge import AudioMergeProcessor # noqa: F401
from .audio_transcript import AudioTranscriptProcessor # noqa: F401

View File

@@ -1,65 +0,0 @@
from reflector.processors.base import Processor
from reflector.processors.types import AudioDiarizationInput, TitleSummary
class AudioDiarizationProcessor(Processor):
    """Stub diarization processor with a hard-coded diarization result.

    Applies a fixed list of speaker segments to the words of each incoming
    topic (mutating them in place), then re-emits the topics. Placeholder
    until a real diarization backend is wired in.
    """

    INPUT_TYPE = AudioDiarizationInput
    OUTPUT_TYPE = TitleSummary

    async def _push(self, data: AudioDiarizationInput):
        # Gather diarization data (hard-coded sample; start/stop in seconds)
        diarization = [
            {"start": 0.0, "stop": 4.9, "speaker": 2},
            {"start": 5.6, "stop": 6.7, "speaker": 2},
            {"start": 7.3, "stop": 8.9, "speaker": 2},
            {"start": 7.3, "stop": 7.9, "speaker": 0},
            {"start": 9.4, "stop": 11.2, "speaker": 2},
            {"start": 9.7, "stop": 10.0, "speaker": 0},
            {"start": 10.0, "stop": 10.1, "speaker": 0},
            {"start": 11.7, "stop": 16.1, "speaker": 2},
            {"start": 11.8, "stop": 12.1, "speaker": 1},
            {"start": 16.4, "stop": 21.0, "speaker": 2},
            {"start": 21.1, "stop": 22.6, "speaker": 2},
            {"start": 24.7, "stop": 31.9, "speaker": 2},
            {"start": 32.0, "stop": 32.8, "speaker": 1},
            {"start": 33.4, "stop": 37.8, "speaker": 2},
            {"start": 37.9, "stop": 40.3, "speaker": 0},
            {"start": 39.2, "stop": 40.4, "speaker": 2},
            {"start": 40.7, "stop": 41.4, "speaker": 0},
            {"start": 41.6, "stop": 45.7, "speaker": 2},
            {"start": 46.4, "stop": 53.1, "speaker": 2},
            {"start": 53.6, "stop": 56.5, "speaker": 2},
            {"start": 54.9, "stop": 75.4, "speaker": 1},
            {"start": 57.3, "stop": 58.0, "speaker": 2},
            {"start": 65.7, "stop": 66.0, "speaker": 2},
            {"start": 75.8, "stop": 78.8, "speaker": 1},
            {"start": 79.0, "stop": 82.6, "speaker": 1},
            {"start": 83.2, "stop": 83.3, "speaker": 1},
            {"start": 84.5, "stop": 94.3, "speaker": 1},
            {"start": 95.1, "stop": 100.7, "speaker": 1},
            {"start": 100.7, "stop": 102.0, "speaker": 0},
            {"start": 100.7, "stop": 101.8, "speaker": 1},
            {"start": 102.0, "stop": 103.0, "speaker": 1},
            {"start": 103.0, "stop": 103.7, "speaker": 0},
            {"start": 103.7, "stop": 103.8, "speaker": 1},
            {"start": 103.8, "stop": 113.9, "speaker": 0},
            {"start": 114.7, "stop": 117.0, "speaker": 0},
            {"start": 117.0, "stop": 117.4, "speaker": 1},
        ]
        # Reapply speaker to topics (if any).
        # topics is a list[BaseModel] with an attribute `words`;
        # words is a list[BaseModel] with text, start and speaker attributes.
        # Segments may overlap and there is deliberately no `break`: the
        # last segment containing the word's start time wins.
        # mutate in place
        for topic in data.topics:
            for word in topic.transcript.words:
                for d in diarization:
                    if d["start"] <= word.start <= d["stop"]:
                        word.speaker = d["speaker"]
        # emit the (now speaker-annotated) topics downstream
        for topic in data.topics:
            await self.emit(topic)

View File

@@ -0,0 +1,34 @@
import importlib
from reflector.processors.base import Processor
from reflector.settings import settings
class AudioDiarizationAutoProcessor(Processor):
    """Factory/registry selecting a diarization backend at runtime.

    Concrete backends register themselves under a short name (e.g. "modal");
    ``get_instance`` lazily imports the backend module, collects its
    ``DIARIZATION_<NAME>_*`` settings and builds the processor.
    """

    _registry = {}

    @classmethod
    def register(cls, name, kclass):
        """Associate a backend name with its processor class."""
        cls._registry[name] = kclass

    @classmethod
    def get_instance(cls, name: str | None = None, **kwargs):
        """Instantiate the backend *name* (default: settings.DIARIZATION_BACKEND)."""
        if name is None:
            name = settings.DIARIZATION_BACKEND
        if name not in cls._registry:
            # importing the backend module triggers its register() call
            importlib.import_module(f"reflector.processors.audio_diarization_{name}")
        # gather specific configuration for the processor:
        # `DIARIZATION_<NAME>_YYY` settings become `<name>_yyy` kwargs
        settings_prefix = "DIARIZATION_"
        backend_prefix = f"{settings_prefix}{name.upper()}_"
        config = {
            key[len(settings_prefix) :].lower(): value
            for key, value in settings
            if key.startswith(backend_prefix)
        }
        # explicit kwargs take precedence over settings-derived config
        return cls._registry[name](**config | kwargs)

View File

@@ -0,0 +1,28 @@
from reflector.processors.base import Processor
from reflector.processors.types import AudioDiarizationInput, TitleSummary
class AudioDiarizationBaseProcessor(Processor):
    """Shared logic for diarization backends.

    Fetches speaker segments via the backend-specific ``_diarize``, stamps
    them onto the words of every topic (mutating in place), then re-emits
    the topics downstream.
    """

    INPUT_TYPE = AudioDiarizationInput
    OUTPUT_TYPE = TitleSummary

    async def _push(self, data: AudioDiarizationInput):
        segments = await self._diarize(data)
        # Reapply speaker to topics (if any): topics carry a transcript whose
        # words have text/start/speaker attributes. Segments may overlap and
        # there is deliberately no `break` — the last segment containing the
        # word's start time wins.
        for topic in data.topics:
            for word in topic.transcript.words:
                for segment in segments:
                    if segment["start"] <= word.start <= segment["end"]:
                        word.speaker = segment["speaker"]
        # all topics are annotated before any is emitted
        for topic in data.topics:
            await self.emit(topic)

    async def _diarize(self, data: AudioDiarizationInput):
        """Return diarization segments as dicts with start/end/speaker keys."""
        raise NotImplementedError

View File

@@ -0,0 +1,36 @@
import httpx
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
from reflector.processors.audio_diarization_base import AudioDiarizationBaseProcessor
from reflector.processors.types import AudioDiarizationInput, TitleSummary
from reflector.settings import settings
class AudioDiarizationModalProcessor(AudioDiarizationBaseProcessor):
    """Diarization backend calling a Modal-hosted HTTP service.

    POSTs the publicly reachable audio URL to ``{DIARIZATION_URL}/diarize``
    and returns the service's diarization segments.
    """

    INPUT_TYPE = AudioDiarizationInput
    OUTPUT_TYPE = TitleSummary

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # DIARIZATION_URL defaults to None; fail early with a clear message
        # instead of a confusing TypeError on the concatenation below
        if not settings.DIARIZATION_URL:
            raise ValueError(
                "DIARIZATION_URL must be set to use the modal diarization backend"
            )
        self.diarization_url = settings.DIARIZATION_URL + "/diarize"
        self.headers = {
            "Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
        }

    async def _diarize(self, data: AudioDiarizationInput):
        # Gather diarization data from the remote service
        params = {
            "audio_file_url": data.audio_url,
            "timestamp": 0,
        }
        async with httpx.AsyncClient() as client:
            # timeout=None: diarizing long recordings can take a long time
            response = await client.post(
                self.diarization_url,
                headers=self.headers,
                params=params,
                timeout=None,
            )
            response.raise_for_status()
            # the service wraps its result payload under the "text" key
            return response.json()["text"]


AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor)

View File

@@ -1,8 +1,6 @@
import importlib
from reflector.processors.audio_transcript import AudioTranscriptProcessor
from reflector.processors.base import Pipeline, Processor
from reflector.processors.types import AudioFile
from reflector.settings import settings
@@ -14,7 +12,9 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
cls._registry[name] = kclass
@classmethod
def get_instance(cls, name):
def get_instance(cls, name: str | None = None, **kwargs):
if name is None:
name = settings.TRANSCRIPT_BACKEND
if name not in cls._registry:
module_name = f"reflector.processors.audio_transcript_{name}"
importlib.import_module(module_name)
@@ -30,30 +30,4 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
config_name = key[len(settings_prefix) :].lower()
config[config_name] = value
return cls._registry[name](**config)
def __init__(self, **kwargs):
self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
super().__init__(**kwargs)
def set_pipeline(self, pipeline: Pipeline):
super().set_pipeline(pipeline)
self.processor.set_pipeline(pipeline)
def connect(self, processor: Processor):
self.processor.connect(processor)
def disconnect(self, processor: Processor):
self.processor.disconnect(processor)
def on(self, callback):
self.processor.on(callback)
def off(self, callback):
self.processor.off(callback)
async def _push(self, data: AudioFile):
return await self.processor._push(data)
async def _flush(self):
return await self.processor._flush()
return cls._registry[name](**config | kwargs)

View File

@@ -385,5 +385,5 @@ class TranslationLanguages(BaseModel):
class AudioDiarizationInput(BaseModel):
audio_filename: Path
audio_url: str
topics: list[TitleSummary]

View File

@@ -89,6 +89,10 @@ class Settings(BaseSettings):
# LLM Modal configuration
LLM_MODAL_API_KEY: str | None = None
# Diarization
DIARIZATION_BACKEND: str = "modal"
DIARIZATION_URL: str | None = None
# Sentry
SENTRY_DSN: str | None = None
@@ -121,5 +125,11 @@ class Settings(BaseSettings):
REDIS_HOST: str = "localhost"
REDIS_PORT: int = 6379
# Secret key
SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21"
# Current hosting/domain
BASE_URL: str = "http://localhost:1250"
settings = Settings()

View File

@@ -0,0 +1,14 @@
import argparse

# Importing celery_app is a deliberate side effect: it configures the Celery
# application so that .delay() below can reach the broker.
from reflector.app import celery_app  # noqa
from reflector.pipelines.main_live_pipeline import task_pipeline_main_post

# CLI helper: re-run the post-processing pipeline (diarization/summaries)
# for an existing transcript, either inline or queued on a Celery worker.
parser = argparse.ArgumentParser()
parser.add_argument("transcript_id", type=str)
parser.add_argument("--delay", action="store_true")
args = parser.parse_args()

if args.delay:
    # enqueue on the Celery worker
    task_pipeline_main_post.delay(args.transcript_id)
else:
    # run synchronously in this process
    task_pipeline_main_post(args.transcript_id)

View File

@@ -1,4 +1,4 @@
from datetime import datetime
from datetime import datetime, timedelta, timezone
from typing import Annotated, Optional
import reflector.auth as auth
@@ -9,8 +9,10 @@ from fastapi import (
Request,
WebSocket,
WebSocketDisconnect,
status,
)
from fastapi_pagination import Page, paginate
from jose import jwt
from pydantic import BaseModel, Field
from reflector.db.transcripts import (
AudioWaveform,
@@ -27,6 +29,18 @@ from .rtc_offer import RtcOffer, rtc_offer_base
router = APIRouter()
ALGORITHM = "HS256"
DOWNLOAD_EXPIRE_MINUTES = 60
def create_access_token(data: dict, expires_delta: timedelta) -> str:
    """Return a signed JWT carrying *data* plus an ``exp`` claim.

    Used to grant temporary access (e.g. to the mp3 download endpoint)
    without an authenticated session; the token expires after
    *expires_delta*.
    """
    to_encode = data.copy()
    # use an aware UTC datetime: datetime.utcnow() is naive and deprecated
    expire = datetime.now(timezone.utc) + expires_delta
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt
# ==============================================================
# Transcripts list
# ==============================================================
@@ -198,8 +212,21 @@ async def transcript_get_audio_mp3(
request: Request,
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
token: str | None = None,
):
user_id = user["sub"] if user else None
if not user_id and token:
unauthorized_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get("sub")
except jwt.JWTError:
raise unauthorized_exception
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")

View File

@@ -4,6 +4,7 @@ from reflector.settings import settings
app = Celery(__name__)
app.conf.broker_url = settings.CELERY_BROKER_URL
app.conf.result_backend = settings.CELERY_RESULT_BACKEND
app.conf.broker_connection_retry_on_startup = True
app.autodiscover_tasks(
[
"reflector.pipelines.main_live_pipeline",

View File

@@ -102,6 +102,7 @@ async def test_transcript_rtc_and_websocket(
print("Test websocket: DISCONNECTED")
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
print("Test websocket: TASK CREATED", websocket_task)
# create stream client
import argparse
@@ -243,6 +244,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
print("Test websocket: DISCONNECTED")
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
print("Test websocket: TASK CREATED", websocket_task)
# create stream client
import argparse