mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 20:29:06 +00:00
server: full diarization processor implementation based on gokul app
This commit is contained in:
@@ -64,6 +64,9 @@ app.include_router(transcripts_router, prefix="/v1")
|
|||||||
app.include_router(user_router, prefix="/v1")
|
app.include_router(user_router, prefix="/v1")
|
||||||
add_pagination(app)
|
add_pagination(app)
|
||||||
|
|
||||||
|
# prepare celery
|
||||||
|
from reflector.worker import app as celery_app # noqa
|
||||||
|
|
||||||
|
|
||||||
# simpler openapi id
|
# simpler openapi id
|
||||||
def use_route_names_as_operation_ids(app: FastAPI) -> None:
|
def use_route_names_as_operation_ids(app: FastAPI) -> None:
|
||||||
|
|||||||
@@ -13,10 +13,12 @@ It is directly linked to our data model.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from datetime import timedelta
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from reflector.app import app
|
||||||
from reflector.db.transcripts import (
|
from reflector.db.transcripts import (
|
||||||
Transcript,
|
Transcript,
|
||||||
TranscriptFinalLongSummary,
|
TranscriptFinalLongSummary,
|
||||||
@@ -29,7 +31,7 @@ from reflector.db.transcripts import (
|
|||||||
from reflector.pipelines.runner import PipelineRunner
|
from reflector.pipelines.runner import PipelineRunner
|
||||||
from reflector.processors import (
|
from reflector.processors import (
|
||||||
AudioChunkerProcessor,
|
AudioChunkerProcessor,
|
||||||
AudioDiarizationProcessor,
|
AudioDiarizationAutoProcessor,
|
||||||
AudioFileWriterProcessor,
|
AudioFileWriterProcessor,
|
||||||
AudioMergeProcessor,
|
AudioMergeProcessor,
|
||||||
AudioTranscriptAutoProcessor,
|
AudioTranscriptAutoProcessor,
|
||||||
@@ -45,6 +47,7 @@ from reflector.processors import (
|
|||||||
from reflector.processors.types import AudioDiarizationInput
|
from reflector.processors.types import AudioDiarizationInput
|
||||||
from reflector.processors.types import TitleSummary as TitleSummaryProcessorType
|
from reflector.processors.types import TitleSummary as TitleSummaryProcessorType
|
||||||
from reflector.processors.types import Transcript as TranscriptProcessorType
|
from reflector.processors.types import Transcript as TranscriptProcessorType
|
||||||
|
from reflector.settings import settings
|
||||||
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
||||||
|
|
||||||
|
|
||||||
@@ -174,7 +177,7 @@ class PipelineMainBase(PipelineRunner):
|
|||||||
async with self.transaction():
|
async with self.transaction():
|
||||||
transcript = await self.get_transcript()
|
transcript = await self.get_transcript()
|
||||||
if not transcript.title:
|
if not transcript.title:
|
||||||
transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"title": final_title.title,
|
"title": final_title.title,
|
||||||
@@ -238,19 +241,13 @@ class PipelineMainLive(PipelineMainBase):
|
|||||||
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
|
AudioFileWriterProcessor(path=transcript.audio_mp3_filename),
|
||||||
AudioChunkerProcessor(),
|
AudioChunkerProcessor(),
|
||||||
AudioMergeProcessor(),
|
AudioMergeProcessor(),
|
||||||
AudioTranscriptAutoProcessor.as_threaded(),
|
AudioTranscriptAutoProcessor.get_instance().as_threaded(),
|
||||||
TranscriptLinerProcessor(),
|
TranscriptLinerProcessor(),
|
||||||
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
|
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
|
||||||
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
|
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
|
||||||
BroadcastProcessor(
|
BroadcastProcessor(
|
||||||
processors=[
|
processors=[
|
||||||
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
|
TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
|
||||||
TranscriptFinalLongSummaryProcessor.as_threaded(
|
|
||||||
callback=self.on_long_summary
|
|
||||||
),
|
|
||||||
TranscriptFinalShortSummaryProcessor.as_threaded(
|
|
||||||
callback=self.on_short_summary
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
@@ -277,7 +274,7 @@ class PipelineMainDiarization(PipelineMainBase):
|
|||||||
# add a customised logger to the context
|
# add a customised logger to the context
|
||||||
self.prepare()
|
self.prepare()
|
||||||
processors = [
|
processors = [
|
||||||
AudioDiarizationProcessor(),
|
AudioDiarizationAutoProcessor.get_instance(callback=self.on_topic),
|
||||||
BroadcastProcessor(
|
BroadcastProcessor(
|
||||||
processors=[
|
processors=[
|
||||||
TranscriptFinalLongSummaryProcessor.as_threaded(
|
TranscriptFinalLongSummaryProcessor.as_threaded(
|
||||||
@@ -307,8 +304,19 @@ class PipelineMainDiarization(PipelineMainBase):
|
|||||||
for topic in transcript.topics
|
for topic in transcript.topics
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# we need to create an url to be used for diarization
|
||||||
|
# we can't use the audio_mp3_filename because it's not accessible
|
||||||
|
# from the diarization processor
|
||||||
|
from reflector.views.transcripts import create_access_token
|
||||||
|
|
||||||
|
token = create_access_token(
|
||||||
|
{"sub": transcript.user_id},
|
||||||
|
expires_delta=timedelta(minutes=15),
|
||||||
|
)
|
||||||
|
path = app.url_path_for("transcript_get_audio_mp3", transcript_id=transcript.id)
|
||||||
|
url = f"{settings.BASE_URL}{path}?token={token}"
|
||||||
audio_diarization_input = AudioDiarizationInput(
|
audio_diarization_input = AudioDiarizationInput(
|
||||||
audio_filename=transcript.audio_mp3_filename,
|
audio_url=url,
|
||||||
topics=topics,
|
topics=topics,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -55,7 +55,11 @@ class PipelineRunner(BaseModel):
|
|||||||
"""
|
"""
|
||||||
Start the pipeline synchronously (for non-asyncio apps)
|
Start the pipeline synchronously (for non-asyncio apps)
|
||||||
"""
|
"""
|
||||||
asyncio.run(self.run())
|
loop = asyncio.get_event_loop()
|
||||||
|
if not loop:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
loop.run_until_complete(self.run())
|
||||||
|
|
||||||
def push(self, data):
|
def push(self, data):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from .audio_chunker import AudioChunkerProcessor # noqa: F401
|
from .audio_chunker import AudioChunkerProcessor # noqa: F401
|
||||||
from .audio_diarization import AudioDiarizationProcessor # noqa: F401
|
from .audio_diarization_auto import AudioDiarizationAutoProcessor # noqa: F401
|
||||||
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
|
from .audio_file_writer import AudioFileWriterProcessor # noqa: F401
|
||||||
from .audio_merge import AudioMergeProcessor # noqa: F401
|
from .audio_merge import AudioMergeProcessor # noqa: F401
|
||||||
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
|
from .audio_transcript import AudioTranscriptProcessor # noqa: F401
|
||||||
|
|||||||
@@ -1,65 +0,0 @@
|
|||||||
from reflector.processors.base import Processor
|
|
||||||
from reflector.processors.types import AudioDiarizationInput, TitleSummary
|
|
||||||
|
|
||||||
|
|
||||||
class AudioDiarizationProcessor(Processor):
    """Stub diarization processor driven by a fixed segment table.

    The audio referenced by the input is never read: speakers are assigned
    from the hard-coded `diarization` list below, then every topic is
    re-emitted.
    """

    INPUT_TYPE = AudioDiarizationInput
    OUTPUT_TYPE = TitleSummary

    async def _push(self, data: AudioDiarizationInput):
        """Tag each word of `data.topics` with a speaker and emit the topics.

        Words are mutated in place. A word whose `start` falls inside several
        segments keeps the speaker of the last matching segment in the table.
        """
        import logging

        # Gather diarization data (static sample, bounds are inclusive)
        diarization = [
            {"start": 0.0, "stop": 4.9, "speaker": 2},
            {"start": 5.6, "stop": 6.7, "speaker": 2},
            {"start": 7.3, "stop": 8.9, "speaker": 2},
            {"start": 7.3, "stop": 7.9, "speaker": 0},
            {"start": 9.4, "stop": 11.2, "speaker": 2},
            {"start": 9.7, "stop": 10.0, "speaker": 0},
            {"start": 10.0, "stop": 10.1, "speaker": 0},
            {"start": 11.7, "stop": 16.1, "speaker": 2},
            {"start": 11.8, "stop": 12.1, "speaker": 1},
            {"start": 16.4, "stop": 21.0, "speaker": 2},
            {"start": 21.1, "stop": 22.6, "speaker": 2},
            {"start": 24.7, "stop": 31.9, "speaker": 2},
            {"start": 32.0, "stop": 32.8, "speaker": 1},
            {"start": 33.4, "stop": 37.8, "speaker": 2},
            {"start": 37.9, "stop": 40.3, "speaker": 0},
            {"start": 39.2, "stop": 40.4, "speaker": 2},
            {"start": 40.7, "stop": 41.4, "speaker": 0},
            {"start": 41.6, "stop": 45.7, "speaker": 2},
            {"start": 46.4, "stop": 53.1, "speaker": 2},
            {"start": 53.6, "stop": 56.5, "speaker": 2},
            {"start": 54.9, "stop": 75.4, "speaker": 1},
            {"start": 57.3, "stop": 58.0, "speaker": 2},
            {"start": 65.7, "stop": 66.0, "speaker": 2},
            {"start": 75.8, "stop": 78.8, "speaker": 1},
            {"start": 79.0, "stop": 82.6, "speaker": 1},
            {"start": 83.2, "stop": 83.3, "speaker": 1},
            {"start": 84.5, "stop": 94.3, "speaker": 1},
            {"start": 95.1, "stop": 100.7, "speaker": 1},
            {"start": 100.7, "stop": 102.0, "speaker": 0},
            {"start": 100.7, "stop": 101.8, "speaker": 1},
            {"start": 102.0, "stop": 103.0, "speaker": 1},
            {"start": 103.0, "stop": 103.7, "speaker": 0},
            {"start": 103.7, "stop": 103.8, "speaker": 1},
            {"start": 103.8, "stop": 113.9, "speaker": 0},
            {"start": 114.7, "stop": 117.0, "speaker": 0},
            {"start": 117.0, "stop": 117.4, "speaker": 1},
        ]

        # now reapply speaker to topics (if any)
        # topics is a list[BaseModel] with an attribute words
        # words is a list[BaseModel] with text, start and speaker attribute

        # debug-level trace instead of an unconditional print to stdout
        logging.getLogger(__name__).debug("diarization processor input: %r", data)

        # mutate in place; no early exit so that the last matching segment wins
        for topic in data.topics:
            for word in topic.transcript.words:
                for d in diarization:
                    if d["start"] <= word.start <= d["stop"]:
                        word.speaker = d["speaker"]

        # emit them
        for topic in data.topics:
            await self.emit(topic)
|
|
||||||
34
server/reflector/processors/audio_diarization_auto.py
Normal file
34
server/reflector/processors/audio_diarization_auto.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
import importlib
|
||||||
|
|
||||||
|
from reflector.processors.base import Processor
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class AudioDiarizationAutoProcessor(Processor):
    """Registry and factory for diarization backend processors.

    Backend classes are stored by name with `register`; `get_instance` looks
    the class up (importing the backend module on demand) and instantiates it.
    """

    # backend name -> processor class, filled in through `register`
    _registry = {}

    @classmethod
    def register(cls, name, kclass):
        """Record `kclass` as the processor class for backend `name`."""
        cls._registry[name] = kclass

    @classmethod
    def get_instance(cls, name: str | None = None, **kwargs):
        """Instantiate the diarization processor for backend `name`.

        Args:
            name: backend name; defaults to `settings.DIARIZATION_BACKEND`.
            **kwargs: extra constructor arguments; they override values
                collected from settings when keys collide.

        Returns:
            An instance of the class registered under `name`.

        Raises:
            ImportError: the backend module cannot be imported.
            KeyError: the backend module was imported but did not register
                a class under `name`.
        """
        if name is None:
            name = settings.DIARIZATION_BACKEND

        if name not in cls._registry:
            # importing the module is expected to trigger its `register` call
            module_name = f"reflector.processors.audio_diarization_{name}"
            importlib.import_module(module_name)
            if name not in cls._registry:
                raise KeyError(
                    f"diarization backend {name!r} is not registered: "
                    f"{module_name} did not call "
                    "AudioDiarizationAutoProcessor.register()"
                )

        # gather specific configuration for the processor:
        # every setting named `DIARIZATION_<NAME>_<OPTION>` (with <NAME> the
        # upper-cased backend name) is pushed to the constructor as
        # `<name>_<option>`
        settings_prefix = "DIARIZATION_"
        config_prefix = f"{settings_prefix}{name.upper()}_"
        config = {
            key[len(settings_prefix) :].lower(): value
            for key, value in settings
            if key.startswith(config_prefix)
        }

        return cls._registry[name](**config | kwargs)
|
||||||
28
server/reflector/processors/audio_diarization_base.py
Normal file
28
server/reflector/processors/audio_diarization_base.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from reflector.processors.base import Processor
|
||||||
|
from reflector.processors.types import AudioDiarizationInput, TitleSummary
|
||||||
|
|
||||||
|
|
||||||
|
class AudioDiarizationBaseProcessor(Processor):
    """Common logic for diarization processors.

    Subclasses supply the segments through `_diarize`; this class applies
    them to the words of every topic and re-emits the topics.
    """

    INPUT_TYPE = AudioDiarizationInput
    OUTPUT_TYPE = TitleSummary

    async def _push(self, data: AudioDiarizationInput):
        """Assign a speaker to each word of `data.topics`, then emit the topics.

        Each segment returned by `_diarize` is indexed with the keys
        "start", "end" and "speaker". Words are updated in place; when a
        word's `start` lies inside several segments (bounds inclusive) the
        last matching segment decides the speaker.
        """
        segments = await self._diarize(data)

        # walk every word of every topic, in topic order
        every_word = (
            word for topic in data.topics for word in topic.transcript.words
        )
        for word in every_word:
            for segment in segments:
                if not (segment["start"] <= word.start <= segment["end"]):
                    continue
                word.speaker = segment["speaker"]

        # forward the (now speaker-tagged) topics downstream
        for topic in data.topics:
            await self.emit(topic)

    async def _diarize(self, data: AudioDiarizationInput):
        """Return the diarization segments for `data`; backends must override."""
        raise NotImplementedError
|
||||||
36
server/reflector/processors/audio_diarization_modal.py
Normal file
36
server/reflector/processors/audio_diarization_modal.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
import httpx
|
||||||
|
from reflector.processors.audio_diarization_auto import AudioDiarizationAutoProcessor
|
||||||
|
from reflector.processors.audio_diarization_base import AudioDiarizationBaseProcessor
|
||||||
|
from reflector.processors.types import AudioDiarizationInput, TitleSummary
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class AudioDiarizationModalProcessor(AudioDiarizationBaseProcessor):
    """Diarization backend that delegates to a remote HTTP service.

    The service base URL comes from `settings.DIARIZATION_URL` and requests
    are authenticated with `settings.LLM_MODAL_API_KEY` as a bearer token.
    """

    INPUT_TYPE = AudioDiarizationInput
    OUTPUT_TYPE = TitleSummary

    def __init__(self, **kwargs):
        """Build the endpoint URL and auth headers from settings.

        Raises:
            ValueError: `settings.DIARIZATION_URL` is unset or empty.
        """
        super().__init__(**kwargs)
        # DIARIZATION_URL defaults to None: fail with an explicit message
        # instead of a TypeError from `None + "/diarize"`
        if not settings.DIARIZATION_URL:
            raise ValueError(
                "DIARIZATION_URL must be configured to use the modal "
                "diarization backend"
            )
        # tolerate a trailing slash in the configured base URL
        self.diarization_url = settings.DIARIZATION_URL.rstrip("/") + "/diarize"
        self.headers = {
            "Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
        }

    async def _diarize(self, data: AudioDiarizationInput):
        """Request diarization of `data.audio_url` and return the result.

        Returns:
            The value stored under the "text" key of the JSON response.

        Raises:
            httpx.HTTPStatusError: the service answered with an error status.
        """
        # Gather diarization data
        params = {
            "audio_file_url": data.audio_url,
            "timestamp": 0,
        }
        async with httpx.AsyncClient() as client:
            response = await client.post(
                self.diarization_url,
                headers=self.headers,
                params=params,
                # timeout=None disables httpx's client-side timeouts
                timeout=None,
            )
            response.raise_for_status()
            return response.json()["text"]


AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor)
|
||||||
@@ -1,8 +1,6 @@
|
|||||||
import importlib
|
import importlib
|
||||||
|
|
||||||
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
from reflector.processors.audio_transcript import AudioTranscriptProcessor
|
||||||
from reflector.processors.base import Pipeline, Processor
|
|
||||||
from reflector.processors.types import AudioFile
|
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
|
||||||
@@ -14,7 +12,9 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
|
|||||||
cls._registry[name] = kclass
|
cls._registry[name] = kclass
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls, name):
|
def get_instance(cls, name: str | None = None, **kwargs):
|
||||||
|
if name is None:
|
||||||
|
name = settings.TRANSCRIPT_BACKEND
|
||||||
if name not in cls._registry:
|
if name not in cls._registry:
|
||||||
module_name = f"reflector.processors.audio_transcript_{name}"
|
module_name = f"reflector.processors.audio_transcript_{name}"
|
||||||
importlib.import_module(module_name)
|
importlib.import_module(module_name)
|
||||||
@@ -30,30 +30,4 @@ class AudioTranscriptAutoProcessor(AudioTranscriptProcessor):
|
|||||||
config_name = key[len(settings_prefix) :].lower()
|
config_name = key[len(settings_prefix) :].lower()
|
||||||
config[config_name] = value
|
config[config_name] = value
|
||||||
|
|
||||||
return cls._registry[name](**config)
|
return cls._registry[name](**config | kwargs)
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
self.processor = self.get_instance(settings.TRANSCRIPT_BACKEND)
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
def set_pipeline(self, pipeline: Pipeline):
|
|
||||||
super().set_pipeline(pipeline)
|
|
||||||
self.processor.set_pipeline(pipeline)
|
|
||||||
|
|
||||||
def connect(self, processor: Processor):
|
|
||||||
self.processor.connect(processor)
|
|
||||||
|
|
||||||
def disconnect(self, processor: Processor):
|
|
||||||
self.processor.disconnect(processor)
|
|
||||||
|
|
||||||
def on(self, callback):
|
|
||||||
self.processor.on(callback)
|
|
||||||
|
|
||||||
def off(self, callback):
|
|
||||||
self.processor.off(callback)
|
|
||||||
|
|
||||||
async def _push(self, data: AudioFile):
|
|
||||||
return await self.processor._push(data)
|
|
||||||
|
|
||||||
async def _flush(self):
|
|
||||||
return await self.processor._flush()
|
|
||||||
|
|||||||
@@ -385,5 +385,5 @@ class TranslationLanguages(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class AudioDiarizationInput(BaseModel):
|
class AudioDiarizationInput(BaseModel):
|
||||||
audio_filename: Path
|
audio_url: str
|
||||||
topics: list[TitleSummary]
|
topics: list[TitleSummary]
|
||||||
|
|||||||
@@ -89,6 +89,10 @@ class Settings(BaseSettings):
|
|||||||
# LLM Modal configuration
|
# LLM Modal configuration
|
||||||
LLM_MODAL_API_KEY: str | None = None
|
LLM_MODAL_API_KEY: str | None = None
|
||||||
|
|
||||||
|
# Diarization
|
||||||
|
DIARIZATION_BACKEND: str = "modal"
|
||||||
|
DIARIZATION_URL: str | None = None
|
||||||
|
|
||||||
# Sentry
|
# Sentry
|
||||||
SENTRY_DSN: str | None = None
|
SENTRY_DSN: str | None = None
|
||||||
|
|
||||||
@@ -121,5 +125,11 @@ class Settings(BaseSettings):
|
|||||||
REDIS_HOST: str = "localhost"
|
REDIS_HOST: str = "localhost"
|
||||||
REDIS_PORT: int = 6379
|
REDIS_PORT: int = 6379
|
||||||
|
|
||||||
|
# Secret key
|
||||||
|
SECRET_KEY: str = "changeme-f02f86fd8b3e4fd892c6043e5a298e21"
|
||||||
|
|
||||||
|
# Current hosting/domain
|
||||||
|
BASE_URL: str = "http://localhost:1250"
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
14
server/reflector/tools/start_post_main_live_pipeline.py
Normal file
14
server/reflector/tools/start_post_main_live_pipeline.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
from reflector.app import celery_app # noqa
|
||||||
|
from reflector.pipelines.main_live_pipeline import task_pipeline_main_post
|
||||||
|
|
||||||
|
def main(argv=None):
    """Run the post-processing pipeline task for one transcript.

    Args:
        argv: command-line arguments to parse; `None` makes argparse read
            `sys.argv[1:]`.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("transcript_id", type=str)
    parser.add_argument(
        "--delay",
        action="store_true",
        help="dispatch with task.delay() rather than calling the task directly",
    )
    args = parser.parse_args(argv)

    if args.delay:
        # presumably queues the task for a Celery worker - verify against
        # the task definition
        task_pipeline_main_post.delay(args.transcript_id)
    else:
        task_pipeline_main_post(args.transcript_id)


# keep the module importable without parsing sys.argv or running the task
if __name__ == "__main__":
    main()
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime, timedelta
|
||||||
from typing import Annotated, Optional
|
from typing import Annotated, Optional
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
@@ -9,8 +9,10 @@ from fastapi import (
|
|||||||
Request,
|
Request,
|
||||||
WebSocket,
|
WebSocket,
|
||||||
WebSocketDisconnect,
|
WebSocketDisconnect,
|
||||||
|
status,
|
||||||
)
|
)
|
||||||
from fastapi_pagination import Page, paginate
|
from fastapi_pagination import Page, paginate
|
||||||
|
from jose import jwt
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from reflector.db.transcripts import (
|
from reflector.db.transcripts import (
|
||||||
AudioWaveform,
|
AudioWaveform,
|
||||||
@@ -27,6 +29,18 @@ from .rtc_offer import RtcOffer, rtc_offer_base
|
|||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
ALGORITHM = "HS256"
|
||||||
|
DOWNLOAD_EXPIRE_MINUTES = 60
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(data: dict, expires_delta: timedelta) -> str:
    """Return a signed JWT carrying `data` plus an expiry claim.

    Args:
        data: claims to embed in the token; the mapping is not modified.
        expires_delta: lifetime of the token, counted from now.

    Returns:
        The token encoded with `settings.SECRET_KEY` using `ALGORITHM`.
    """
    from datetime import timezone

    to_encode = data.copy()
    # timezone-aware "now": datetime.utcnow() is deprecated and returns a
    # naive value
    to_encode["exp"] = datetime.now(timezone.utc) + expires_delta
    return jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================
|
# ==============================================================
|
||||||
# Transcripts list
|
# Transcripts list
|
||||||
# ==============================================================
|
# ==============================================================
|
||||||
@@ -198,8 +212,21 @@ async def transcript_get_audio_mp3(
|
|||||||
request: Request,
|
request: Request,
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
token: str | None = None,
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
|
if not user_id and token:
|
||||||
|
unauthorized_exception = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Invalid or expired token",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
|
||||||
|
user_id: str = payload.get("sub")
|
||||||
|
except jwt.JWTError:
|
||||||
|
raise unauthorized_exception
|
||||||
|
|
||||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from reflector.settings import settings
|
|||||||
app = Celery(__name__)
|
app = Celery(__name__)
|
||||||
app.conf.broker_url = settings.CELERY_BROKER_URL
|
app.conf.broker_url = settings.CELERY_BROKER_URL
|
||||||
app.conf.result_backend = settings.CELERY_RESULT_BACKEND
|
app.conf.result_backend = settings.CELERY_RESULT_BACKEND
|
||||||
|
app.conf.broker_connection_retry_on_startup = True
|
||||||
app.autodiscover_tasks(
|
app.autodiscover_tasks(
|
||||||
[
|
[
|
||||||
"reflector.pipelines.main_live_pipeline",
|
"reflector.pipelines.main_live_pipeline",
|
||||||
|
|||||||
@@ -102,6 +102,7 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
print("Test websocket: DISCONNECTED")
|
print("Test websocket: DISCONNECTED")
|
||||||
|
|
||||||
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
|
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
|
||||||
|
print("Test websocket: TASK CREATED", websocket_task)
|
||||||
|
|
||||||
# create stream client
|
# create stream client
|
||||||
import argparse
|
import argparse
|
||||||
@@ -243,6 +244,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
|
|||||||
print("Test websocket: DISCONNECTED")
|
print("Test websocket: DISCONNECTED")
|
||||||
|
|
||||||
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
|
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
|
||||||
|
print("Test websocket: TASK CREATED", websocket_task)
|
||||||
|
|
||||||
# create stream client
|
# create stream client
|
||||||
import argparse
|
import argparse
|
||||||
|
|||||||
Reference in New Issue
Block a user