diff --git a/README.md b/README.md
index 627de235..b18264c0 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,7 @@ It also uses https://github.com/fief-dev for authentication, and Vercel for depl
- [OpenAPI Code Generation](#openapi-code-generation)
- [Back-End](#back-end)
- [Installation](#installation-1)
- - [Start the project](#start-the-project)
+ - [Start the API/Backend](#start-the-apibackend)
- [Using docker](#using-docker)
- [Using local GPT4All](#using-local-gpt4all)
- [Using local files](#using-local-files)
@@ -133,15 +133,15 @@ TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run
```
-### Start the project
+### Start the API/Backend
-Use:
+Start the API server:
```bash
poetry run python3 -m reflector.app
```
-And start the background worker
+Start the background worker:
```bash
celery -A reflector.worker.app worker --loglevel=info
@@ -153,6 +153,12 @@ Redis:
TODO
```
+For crontab (only healthcheck for now), start the celery beat (you don't need it on your local dev environment):
+
+```bash
+celery -A reflector.worker.app beat
+```
+
#### Using docker
Use:
diff --git a/server/gpu/modal/reflector_diarizer.py b/server/gpu/modal/reflector_diarizer.py
new file mode 100644
index 00000000..b1989a11
--- /dev/null
+++ b/server/gpu/modal/reflector_diarizer.py
@@ -0,0 +1,188 @@
+"""
+Reflector GPU backend - diarizer
+===================================
+"""
+
+import os
+
+import modal.gpu
+from modal import Image, Secret, Stub, asgi_app, method
+from pydantic import BaseModel
+
+PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.0"
+MODEL_DIR = "/root/diarization_models"
+
+stub = Stub(name="reflector-diarizer")
+
+
+def migrate_cache_llm():
+ """
+ XXX The cache for model files in Transformers v4.22.0 has been updated.
+ Migrating your old cache. This is a one-time only operation. You can
+ interrupt this and resume the migration later on by calling
+ `transformers.utils.move_cache()`.
+ """
+ from transformers.utils.hub import move_cache
+
+ print("Moving LLM cache")
+ move_cache(cache_dir=MODEL_DIR, new_cache_dir=MODEL_DIR)
+ print("LLM cache moved")
+
+
+def download_pyannote_audio():
+ from pyannote.audio import Pipeline
+ Pipeline.from_pretrained(
+ "pyannote/speaker-diarization-3.0",
+ cache_dir=MODEL_DIR,
+ use_auth_token="***REMOVED***"
+ )
+
+
+diarizer_image = (
+ Image.debian_slim(python_version="3.10.8")
+ .pip_install(
+ "pyannote.audio",
+ "requests",
+ "onnx",
+ "torchaudio",
+ "onnxruntime-gpu",
+ "torch==2.0.0",
+ "transformers==4.34.0",
+ "sentencepiece",
+ "protobuf",
+ "numpy",
+ "huggingface_hub",
+ "hf-transfer"
+ )
+ .run_function(migrate_cache_llm)
+ .run_function(download_pyannote_audio)
+ .env(
+ {
+ "LD_LIBRARY_PATH": (
+ "/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:"
+ "/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/"
+ )
+ }
+ )
+)
+
+
+@stub.cls(
+ gpu=modal.gpu.A100(memory=40),
+ timeout=60 * 30,
+ container_idle_timeout=60,
+ allow_concurrent_inputs=1,
+ image=diarizer_image,
+)
+class Diarizer:
+ def __enter__(self):
+ import torch
+ from pyannote.audio import Pipeline
+
+ self.use_gpu = torch.cuda.is_available()
+ self.device = "cuda" if self.use_gpu else "cpu"
+ self.diarization_pipeline = Pipeline.from_pretrained(
+ "pyannote/speaker-diarization-3.0",
+ cache_dir=MODEL_DIR
+ )
+ self.diarization_pipeline.to(torch.device(self.device))
+
+ @method()
+ def diarize(
+ self,
+ audio_data: str,
+ audio_suffix: str,
+ timestamp: float
+ ):
+ import tempfile
+
+ import torchaudio
+
+ with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
+ fp.write(audio_data)
+
+ print("Diarizing audio")
+ waveform, sample_rate = torchaudio.load(fp.name)
+ diarization = self.diarization_pipeline({"waveform": waveform, "sample_rate": sample_rate})
+
+ words = []
+ for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
+ words.append(
+ {
+ "start": round(timestamp + diarization_segment.start, 3),
+ "end": round(timestamp + diarization_segment.end, 3),
+ "speaker": int(speaker[-2:])
+ }
+ )
+ print("Diarization complete")
+ return {
+ "diarization": words
+ }
+
+# -------------------------------------------------------------------
+# Web API
+# -------------------------------------------------------------------
+
+
+@stub.function(
+ timeout=60 * 10,
+ container_idle_timeout=60 * 3,
+ allow_concurrent_inputs=40,
+ secrets=[
+ Secret.from_name("reflector-gpu"),
+ ],
+ image=diarizer_image
+)
+@asgi_app()
+def web():
+ import requests
+ from fastapi import Depends, FastAPI, HTTPException, status
+ from fastapi.security import OAuth2PasswordBearer
+
+ diarizerstub = Diarizer()
+
+ app = FastAPI()
+
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+ def apikey_auth(apikey: str = Depends(oauth2_scheme)):
+ if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid API key",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+
+ def validate_audio_file(audio_file_url: str):
+ # Check if the audio file exists
+ response = requests.head(audio_file_url, allow_redirects=True)
+ if response.status_code == 404:
+ raise HTTPException(
+ status_code=response.status_code,
+ detail="The audio file does not exist."
+ )
+
+ class DiarizationResponse(BaseModel):
+ result: dict
+
+ @app.post("/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)])
+ def diarize(
+ audio_file_url: str,
+ timestamp: float = 0.0
+ ) -> HTTPException | DiarizationResponse:
+ # Currently the uploaded files are in mp3 format
+ audio_suffix = "mp3"
+
+ print("Downloading audio file")
+ response = requests.get(audio_file_url, allow_redirects=True)
+ print("Audio file downloaded successfully")
+
+ func = diarizerstub.diarize.spawn(
+ audio_data=response.content,
+ audio_suffix=audio_suffix,
+ timestamp=timestamp
+ )
+ result = func.get()
+ return result
+
+ return app
diff --git a/server/migrations/versions/0fea6d96b096_add_share_mode.py b/server/migrations/versions/0fea6d96b096_add_share_mode.py
new file mode 100644
index 00000000..48746c3b
--- /dev/null
+++ b/server/migrations/versions/0fea6d96b096_add_share_mode.py
@@ -0,0 +1,33 @@
+"""add share_mode
+
+Revision ID: 0fea6d96b096
+Revises: f819277e5169
+Create Date: 2023-11-07 11:12:21.614198
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "0fea6d96b096"
+down_revision: Union[str, None] = "f819277e5169"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column(
+ "transcript",
+ sa.Column("share_mode", sa.String(), server_default="private", nullable=False),
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column("transcript", "share_mode")
+ # ### end Alembic commands ###
diff --git a/server/migrations/versions/125031f7cb78_participants.py b/server/migrations/versions/125031f7cb78_participants.py
new file mode 100644
index 00000000..c345b083
--- /dev/null
+++ b/server/migrations/versions/125031f7cb78_participants.py
@@ -0,0 +1,30 @@
+"""participants
+
+Revision ID: 125031f7cb78
+Revises: 0fea6d96b096
+Create Date: 2023-11-30 15:56:03.341466
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = '125031f7cb78'
+down_revision: Union[str, None] = '0fea6d96b096'
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column('transcript', sa.Column('participants', sa.JSON(), nullable=True))
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column('transcript', 'participants')
+ # ### end Alembic commands ###
diff --git a/server/migrations/versions/f819277e5169_audio_location.py b/server/migrations/versions/f819277e5169_audio_location.py
new file mode 100644
index 00000000..061abec4
--- /dev/null
+++ b/server/migrations/versions/f819277e5169_audio_location.py
@@ -0,0 +1,35 @@
+"""audio_location
+
+Revision ID: f819277e5169
+Revises: 4814901632bc
+Create Date: 2023-11-16 10:29:09.351664
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "f819277e5169"
+down_revision: Union[str, None] = "4814901632bc"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.add_column(
+ "transcript",
+ sa.Column(
+ "audio_location", sa.String(), server_default="local", nullable=False
+ ),
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_column("transcript", "audio_location")
+ # ### end Alembic commands ###
diff --git a/server/reflector/app.py b/server/reflector/app.py
index 5bfffeca..8f45efd5 100644
--- a/server/reflector/app.py
+++ b/server/reflector/app.py
@@ -13,6 +13,12 @@ from reflector.metrics import metrics_init
from reflector.settings import settings
from reflector.views.rtc_offer import router as rtc_offer_router
from reflector.views.transcripts import router as transcripts_router
+from reflector.views.transcripts_audio import router as transcripts_audio_router
+from reflector.views.transcripts_participants import (
+ router as transcripts_participants_router,
+)
+from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
+from reflector.views.transcripts_websocket import router as transcripts_websocket_router
from reflector.views.user import router as user_router
try:
@@ -60,6 +66,10 @@ metrics_init(app, instrumentator)
# register views
app.include_router(rtc_offer_router)
app.include_router(transcripts_router, prefix="/v1")
+app.include_router(transcripts_audio_router, prefix="/v1")
+app.include_router(transcripts_participants_router, prefix="/v1")
+app.include_router(transcripts_websocket_router, prefix="/v1")
+app.include_router(transcripts_webrtc_router, prefix="/v1")
app.include_router(user_router, prefix="/v1")
add_pagination(app)
diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index 6ac2e32a..970393d5 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -2,15 +2,16 @@ import json
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
-from typing import Any
+from typing import Any, Literal
from uuid import uuid4
import sqlalchemy
-from pydantic import BaseModel, Field
+from fastapi import HTTPException
+from pydantic import BaseModel, ConfigDict, Field
from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
-from reflector.utils.audio_waveform import get_audio_waveform
+from reflector.storage import Storage
transcripts = sqlalchemy.Table(
"transcript",
@@ -26,22 +27,42 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
sqlalchemy.Column("topics", sqlalchemy.JSON),
sqlalchemy.Column("events", sqlalchemy.JSON),
+ sqlalchemy.Column("participants", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
+ sqlalchemy.Column(
+ "audio_location",
+ sqlalchemy.String,
+ nullable=False,
+ server_default="local",
+ ),
# with user attached, optional
sqlalchemy.Column("user_id", sqlalchemy.String),
+ sqlalchemy.Column(
+ "share_mode",
+ sqlalchemy.String,
+ nullable=False,
+ server_default="private",
+ ),
)
-def generate_uuid4():
+def generate_uuid4() -> str:
return str(uuid4())
-def generate_transcript_name():
+def generate_transcript_name() -> str:
now = datetime.utcnow()
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
+def get_storage() -> Storage:
+ return Storage.get_instance(
+ name=settings.TRANSCRIPT_STORAGE_BACKEND,
+ settings_prefix="TRANSCRIPT_STORAGE_",
+ )
+
+
class AudioWaveform(BaseModel):
data: list[float]
@@ -79,11 +100,26 @@ class TranscriptFinalTitle(BaseModel):
title: str
+class TranscriptDuration(BaseModel):
+ duration: float
+
+
+class TranscriptWaveform(BaseModel):
+ waveform: list[float]
+
+
class TranscriptEvent(BaseModel):
event: str
data: dict
+class TranscriptParticipant(BaseModel):
+ model_config = ConfigDict(from_attributes=True)
+ id: str = Field(default_factory=generate_uuid4)
+ speaker: int | None
+ name: str
+
+
class Transcript(BaseModel):
id: str = Field(default_factory=generate_uuid4)
user_id: str | None = None
@@ -97,8 +133,11 @@ class Transcript(BaseModel):
long_summary: str | None = None
topics: list[TranscriptTopic] = []
events: list[TranscriptEvent] = []
+ participants: list[TranscriptParticipant] | None = []
source_language: str = "en"
target_language: str = "en"
+ share_mode: Literal["private", "semi-private", "public"] = "private"
+ audio_location: str = "local"
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump())
@@ -112,27 +151,33 @@ class Transcript(BaseModel):
else:
self.topics.append(topic)
+ def upsert_participant(self, participant: TranscriptParticipant):
+ index = next(
+ (i for i, p in enumerate(self.participants) if p.id == participant.id),
+ None,
+ )
+ if index is not None:
+ self.participants[index] = participant
+ else:
+ self.participants.append(participant)
+ return participant
+
+ def delete_participant(self, participant_id: str):
+ index = next(
+ (i for i, p in enumerate(self.participants) if p.id == participant_id),
+ None,
+ )
+ if index is not None:
+ del self.participants[index]
+
def events_dump(self, mode="json"):
return [event.model_dump(mode=mode) for event in self.events]
def topics_dump(self, mode="json"):
return [topic.model_dump(mode=mode) for topic in self.topics]
- def convert_audio_to_waveform(self, segments_count=256):
- fn = self.audio_waveform_filename
- if fn.exists():
- return
- waveform = get_audio_waveform(
- path=self.audio_mp3_filename, segments_count=segments_count
- )
- try:
- with open(fn, "w") as fd:
- json.dump(waveform, fd)
- except Exception:
- # remove file if anything happen during the write
- fn.unlink(missing_ok=True)
- raise
- return waveform
+ def participants_dump(self, mode="json"):
+ return [participant.model_dump(mode=mode) for participant in self.participants]
def unlink(self):
self.data_path.unlink(missing_ok=True)
@@ -141,6 +186,10 @@ class Transcript(BaseModel):
def data_path(self):
return Path(settings.DATA_DIR) / self.id
+ @property
+ def audio_wav_filename(self):
+ return self.data_path / "audio.wav"
+
@property
def audio_mp3_filename(self):
return self.data_path / "audio.mp3"
@@ -149,6 +198,10 @@ class Transcript(BaseModel):
def audio_waveform_filename(self):
return self.data_path / "audio.json"
+ @property
+ def storage_audio_path(self):
+ return f"{self.id}/audio.mp3"
+
@property
def audio_waveform(self):
try:
@@ -161,6 +214,40 @@ class Transcript(BaseModel):
return AudioWaveform(data=data)
+ async def get_audio_url(self) -> str:
+ if self.audio_location == "local":
+ return self._generate_local_audio_link()
+ elif self.audio_location == "storage":
+ return await self._generate_storage_audio_link()
+ raise Exception(f"Unknown audio location {self.audio_location}")
+
+ async def _generate_storage_audio_link(self) -> str:
+ return await get_storage().get_file_url(self.storage_audio_path)
+
+ def _generate_local_audio_link(self) -> str:
+ # we need to create an url to be used for diarization
+ # we can't use the audio_mp3_filename because it's not accessible
+ # from the diarization processor
+ from datetime import timedelta
+
+ from reflector.app import app
+ from reflector.views.transcripts import create_access_token
+
+ path = app.url_path_for(
+ "transcript_get_audio_mp3",
+ transcript_id=self.id,
+ )
+ url = f"{settings.BASE_URL}{path}"
+ if self.user_id:
+ # we pass token only if the user_id is set
+ # otherwise, the audio is public
+ token = create_access_token(
+ {"sub": self.user_id},
+ expires_delta=timedelta(minutes=15),
+ )
+ url += f"?token={token}"
+ return url
+
class TranscriptController:
async def get_all(
@@ -169,6 +256,7 @@ class TranscriptController:
order_by: str | None = None,
filter_empty: bool | None = False,
filter_recording: bool | None = False,
+ return_query: bool = False,
) -> list[Transcript]:
"""
Get all transcripts
@@ -195,6 +283,9 @@ class TranscriptController:
if filter_recording:
query = query.filter(transcripts.c.status != "recording")
+ if return_query:
+ return query
+
results = await database.fetch_all(query)
return results
@@ -210,6 +301,47 @@ class TranscriptController:
return None
return Transcript(**result)
+ async def get_by_id_for_http(
+ self,
+ transcript_id: str,
+ user_id: str | None,
+ ) -> Transcript:
+ """
+ Get a transcript by ID for HTTP request.
+
+ If not found, it will raise a 404 error.
+ If the user is not allowed to access the transcript, it will raise a 403 error.
+
+ This method checks the share mode of the transcript and the user_id
+ to determine if the user can access the transcript.
+ """
+ query = transcripts.select().where(transcripts.c.id == transcript_id)
+ result = await database.fetch_one(query)
+ if not result:
+ raise HTTPException(status_code=404, detail="Transcript not found")
+
+ # if the transcript is anonymous, share mode is not checked
+ transcript = Transcript(**result)
+ if transcript.user_id is None:
+ return transcript
+
+ if transcript.share_mode == "private":
+ # in private mode, only the owner can access the transcript
+ if transcript.user_id == user_id:
+ return transcript
+
+ elif transcript.share_mode == "semi-private":
+ # in semi-private mode, only the owner and the users with the link
+ # can access the transcript
+ if user_id is not None:
+ return transcript
+
+ elif transcript.share_mode == "public":
+ # in public mode, everyone can access the transcript
+ return transcript
+
+ raise HTTPException(status_code=403, detail="Transcript access denied")
+
async def add(
self,
name: str,
@@ -292,5 +424,45 @@ class TranscriptController:
transcript.upsert_topic(topic)
await self.update(transcript, {"topics": transcript.topics_dump()})
+ async def move_mp3_to_storage(self, transcript: Transcript):
+ """
+ Move mp3 file to storage
+ """
+
+ # store the audio on external storage
+ await get_storage().put_file(
+ transcript.storage_audio_path,
+ transcript.audio_mp3_filename.read_bytes(),
+ )
+
+ # indicate on the transcript that the audio is now on storage
+ await self.update(transcript, {"audio_location": "storage"})
+
+ # unlink the local file
+ transcript.audio_mp3_filename.unlink(missing_ok=True)
+
+ async def upsert_participant(
+ self,
+ transcript: Transcript,
+ participant: TranscriptParticipant,
+ ) -> TranscriptParticipant:
+ """
+ Add/update a participant to a transcript
+ """
+ result = transcript.upsert_participant(participant)
+ await self.update(transcript, {"participants": transcript.participants_dump()})
+ return result
+
+ async def delete_participant(
+ self,
+ transcript: Transcript,
+ participant_id: str,
+ ):
+ """
+ Delete a participant from a transcript
+ """
+ transcript.delete_participant(participant_id)
+ await self.update(transcript, {"participants": transcript.participants_dump()})
+
transcripts_controller = TranscriptController()
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index 316ecbcc..b182f421 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -12,20 +12,20 @@ It is directly linked to our data model.
"""
import asyncio
+import functools
from contextlib import asynccontextmanager
-from datetime import timedelta
-from pathlib import Path
-from celery import shared_task
+from celery import chord, group, shared_task
from pydantic import BaseModel
-from reflector.app import app
from reflector.db.transcripts import (
Transcript,
+ TranscriptDuration,
TranscriptFinalLongSummary,
TranscriptFinalShortSummary,
TranscriptFinalTitle,
TranscriptText,
TranscriptTopic,
+ TranscriptWaveform,
transcripts_controller,
)
from reflector.logger import logger
@@ -45,6 +45,7 @@ from reflector.processors import (
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
)
+from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
from reflector.processors.types import AudioDiarizationInput
from reflector.processors.types import (
TitleSummaryWithId as TitleSummaryWithIdProcessorType,
@@ -52,6 +53,22 @@ from reflector.processors.types import (
from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.settings import settings
from reflector.ws_manager import WebsocketManager, get_ws_manager
+from structlog import BoundLogger as Logger
+
+
+def asynctask(f):
+ @functools.wraps(f)
+ def wrapper(*args, **kwargs):
+ coro = f(*args, **kwargs)
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ loop = None
+ if loop and loop.is_running():
+ return loop.run_until_complete(coro)
+ return asyncio.run(coro)
+
+ return wrapper
def broadcast_to_sockets(func):
@@ -72,6 +89,26 @@ def broadcast_to_sockets(func):
return wrapper
+def get_transcript(func):
+ """
+ Decorator to fetch the transcript from the database from the first argument
+ """
+
+ async def wrapper(**kwargs):
+ transcript_id = kwargs.pop("transcript_id")
+ transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
+ if not transcript:
+ raise Exception("Transcript {transcript_id} not found")
+ tlogger = logger.bind(transcript_id=transcript.id)
+ try:
+ return await func(transcript=transcript, logger=tlogger, **kwargs)
+ except Exception as exc:
+ tlogger.error("Pipeline error", exc_info=exc)
+ raise
+
+ return wrapper
+
+
class StrValue(BaseModel):
value: str
@@ -96,6 +133,19 @@ class PipelineMainBase(PipelineRunner):
raise Exception("Transcript not found")
return result
+ def get_transcript_topics(self, transcript: Transcript) -> list[TranscriptTopic]:
+ return [
+ TitleSummaryWithIdProcessorType(
+ id=topic.id,
+ title=topic.title,
+ summary=topic.summary,
+ timestamp=topic.timestamp,
+ duration=topic.duration,
+ transcript=TranscriptProcessorType(words=topic.words),
+ )
+ for topic in transcript.topics
+ ]
+
@asynccontextmanager
async def transaction(self):
async with self._lock:
@@ -113,7 +163,7 @@ class PipelineMainBase(PipelineRunner):
"flush": "processing",
"error": "error",
}
- elif isinstance(self, PipelineMainDiarization):
+ elif isinstance(self, PipelineMainFinalSummaries):
status_mapping = {
"push": "processing",
"flush": "processing",
@@ -121,7 +171,8 @@ class PipelineMainBase(PipelineRunner):
"ended": "ended",
}
else:
- raise Exception(f"Runner {self.__class__} is missing status mapping")
+ # intermediate pipeline don't update status
+ return
# mutate to model status
status = status_mapping.get(status)
@@ -230,21 +281,39 @@ class PipelineMainBase(PipelineRunner):
data=final_short_summary,
)
- async def on_duration(self, duration: float):
+ @broadcast_to_sockets
+ async def on_duration(self, data):
async with self.transaction():
+ duration = TranscriptDuration(duration=data)
+
transcript = await self.get_transcript()
await transcripts_controller.update(
transcript,
{
- "duration": duration,
+ "duration": duration.duration,
},
)
+ return await transcripts_controller.append_event(
+ transcript=transcript, event="DURATION", data=duration
+ )
+
+ @broadcast_to_sockets
+ async def on_waveform(self, data):
+ async with self.transaction():
+ waveform = TranscriptWaveform(waveform=data)
+
+ transcript = await self.get_transcript()
+
+ return await transcripts_controller.append_event(
+ transcript=transcript, event="WAVEFORM", data=waveform
+ )
class PipelineMainLive(PipelineMainBase):
- audio_filename: Path | None = None
- source_language: str = "en"
- target_language: str = "en"
+ """
+ Main pipeline for live streaming, attach to RTC connection
+ Any long post process should be done in the post pipeline
+ """
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
@@ -254,7 +323,7 @@ class PipelineMainLive(PipelineMainBase):
processors = [
AudioFileWriterProcessor(
- path=transcript.audio_mp3_filename,
+ path=transcript.audio_wav_filename,
on_duration=self.on_duration,
),
AudioChunkerProcessor(),
@@ -263,17 +332,13 @@ class PipelineMainLive(PipelineMainBase):
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
- TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
]
pipeline = Pipeline(*processors)
pipeline.options = self
pipeline.set_pref("audio:source_language", transcript.source_language)
pipeline.set_pref("audio:target_language", transcript.target_language)
pipeline.logger.bind(transcript_id=transcript.id)
- pipeline.logger.info(
- "Pipeline main live created",
- transcript_id=self.transcript_id,
- )
+ pipeline.logger.info("Pipeline main live created")
return pipeline
@@ -281,26 +346,106 @@ class PipelineMainLive(PipelineMainBase):
# when the pipeline ends, connect to the post pipeline
logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
- task_pipeline_main_post.delay(transcript_id=self.transcript_id)
+ pipeline_post(transcript_id=self.transcript_id)
class PipelineMainDiarization(PipelineMainBase):
"""
- Diarization is a long time process, so we do it in a separate pipeline
- When done, adjust the short and final summary
+ Diarize the audio and update topics
"""
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
# add a customised logger to the context
self.prepare()
- processors = []
- if settings.DIARIZATION_ENABLED:
- processors += [
- AudioDiarizationAutoProcessor(callback=self.on_topic),
- ]
+ pipeline = Pipeline(
+ AudioDiarizationAutoProcessor(callback=self.on_topic),
+ )
+ pipeline.options = self
- processors += [
+ # now let's start the pipeline by pushing information to the
+ # first processor diarization processor
+ # XXX translation is lost when converting our data model to the processor model
+ transcript = await self.get_transcript()
+
+ # diarization works only if the file is uploaded to an external storage
+ if transcript.audio_location == "local":
+ pipeline.logger.info("Audio is local, skipping diarization")
+ return
+
+ topics = self.get_transcript_topics(transcript)
+ audio_url = await transcript.get_audio_url()
+ audio_diarization_input = AudioDiarizationInput(
+ audio_url=audio_url,
+ topics=topics,
+ )
+
+ # as tempting to use pipeline.push, prefer to use the runner
+ # to let the start just do one job.
+ pipeline.logger.bind(transcript_id=transcript.id)
+ pipeline.logger.info("Diarization pipeline created")
+ self.push(audio_diarization_input)
+ self.flush()
+
+ return pipeline
+
+
+class PipelineMainFromTopics(PipelineMainBase):
+ """
+ Pseudo class for generating a pipeline from topics
+ """
+
+ def get_processors(self) -> list:
+ raise NotImplementedError
+
+ async def create(self) -> Pipeline:
+ self.prepare()
+
+ # get transcript
+ self._transcript = transcript = await self.get_transcript()
+
+ # create pipeline
+ processors = self.get_processors()
+ pipeline = Pipeline(*processors)
+ pipeline.options = self
+ pipeline.logger.bind(transcript_id=transcript.id)
+ pipeline.logger.info(f"{self.__class__.__name__} pipeline created")
+
+ # push topics
+ topics = self.get_transcript_topics(transcript)
+ for topic in topics:
+ self.push(topic)
+
+ self.flush()
+
+ return pipeline
+
+
+class PipelineMainTitleAndShortSummary(PipelineMainFromTopics):
+ """
+ Generate title from the topics
+ """
+
+ def get_processors(self) -> list:
+ return [
+ BroadcastProcessor(
+ processors=[
+ TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
+ TranscriptFinalShortSummaryProcessor.as_threaded(
+ callback=self.on_short_summary
+ ),
+ ]
+ )
+ ]
+
+
+class PipelineMainFinalSummaries(PipelineMainFromTopics):
+ """
+ Generate summaries from the topics
+ """
+
+ def get_processors(self) -> list:
+ return [
BroadcastProcessor(
processors=[
TranscriptFinalLongSummaryProcessor.as_threaded(
@@ -312,65 +457,164 @@ class PipelineMainDiarization(PipelineMainBase):
]
),
]
- pipeline = Pipeline(*processors)
- pipeline.options = self
- # now let's start the pipeline by pushing information to the
- # first processor diarization processor
- # XXX translation is lost when converting our data model to the processor model
- transcript = await self.get_transcript()
- topics = [
- TitleSummaryWithIdProcessorType(
- id=topic.id,
- title=topic.title,
- summary=topic.summary,
- timestamp=topic.timestamp,
- duration=topic.duration,
- transcript=TranscriptProcessorType(words=topic.words),
- )
- for topic in transcript.topics
+
+class PipelineMainWaveform(PipelineMainFromTopics):
+ """
+ Generate waveform
+ """
+
+ def get_processors(self) -> list:
+ return [
+ AudioWaveformProcessor.as_threaded(
+ audio_path=self._transcript.audio_wav_filename,
+ waveform_path=self._transcript.audio_waveform_filename,
+ on_waveform=self.on_waveform,
+ ),
]
- # we need to create an url to be used for diarization
- # we can't use the audio_mp3_filename because it's not accessible
- # from the diarization processor
- from reflector.views.transcripts import create_access_token
- path = app.url_path_for(
- "transcript_get_audio_mp3",
- transcript_id=transcript.id,
- )
- url = f"{settings.BASE_URL}{path}"
- if transcript.user_id:
- # we pass token only if the user_id is set
- # otherwise, the audio is public
- token = create_access_token(
- {"sub": transcript.user_id},
- expires_delta=timedelta(minutes=15),
- )
- url += f"?token={token}"
- audio_diarization_input = AudioDiarizationInput(
- audio_url=url,
- topics=topics,
- )
+@get_transcript
+async def pipeline_waveform(transcript: Transcript, logger: Logger):
+ logger.info("Starting waveform")
+ runner = PipelineMainWaveform(transcript_id=transcript.id)
+ await runner.run()
+ logger.info("Waveform done")
- # as tempting to use pipeline.push, prefer to use the runner
- # to let the start just do one job.
- pipeline.logger.bind(transcript_id=transcript.id)
- pipeline.logger.info(
- "Pipeline main post created", transcript_id=self.transcript_id
- )
- self.push(audio_diarization_input)
- self.flush()
- return pipeline
+@get_transcript
+async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
+ logger.info("Starting convert to mp3")
+
+ # If the audio wav is not available, just skip
+ wav_filename = transcript.audio_wav_filename
+ if not wav_filename.exists():
+ logger.warning("Wav file not found, may be already converted")
+ return
+
+ # Convert to mp3
+ mp3_filename = transcript.audio_mp3_filename
+
+ import av
+
+ with av.open(wav_filename.as_posix()) as in_container:
+ in_stream = in_container.streams.audio[0]
+ with av.open(mp3_filename.as_posix(), "w") as out_container:
+ out_stream = out_container.add_stream("mp3")
+ for frame in in_container.decode(in_stream):
+ for packet in out_stream.encode(frame):
+ out_container.mux(packet)
+
+ # Delete the wav file
+ transcript.audio_wav_filename.unlink(missing_ok=True)
+
+ logger.info("Convert to mp3 done")
+
+
+@get_transcript
+async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
+ if not settings.TRANSCRIPT_STORAGE_BACKEND:
+ logger.info("No storage backend configured, skipping mp3 upload")
+ return
+
+ logger.info("Starting upload mp3")
+
+ # If the audio mp3 is not available, just skip
+ mp3_filename = transcript.audio_mp3_filename
+ if not mp3_filename.exists():
+ logger.warning("Mp3 file not found, may be already uploaded")
+ return
+
+ # Upload to external storage and delete the file
+ await transcripts_controller.move_mp3_to_storage(transcript)
+
+ logger.info("Upload mp3 done")
+
+
+@get_transcript
+async def pipeline_diarization(transcript: Transcript, logger: Logger):
+ logger.info("Starting diarization")
+ runner = PipelineMainDiarization(transcript_id=transcript.id)
+ await runner.run()
+ logger.info("Diarization done")
+
+
+@get_transcript
+async def pipeline_title_and_short_summary(transcript: Transcript, logger: Logger):
+ logger.info("Starting title and short summary")
+ runner = PipelineMainTitleAndShortSummary(transcript_id=transcript.id)
+ await runner.run()
+ logger.info("Title and short summary done")
+
+
+@get_transcript
+async def pipeline_summaries(transcript: Transcript, logger: Logger):
+ logger.info("Starting summaries")
+ runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
+ await runner.run()
+ logger.info("Summaries done")
+
+
+# ===================================================================
+# Celery tasks that can be called from the API
+# ===================================================================
@shared_task
-def task_pipeline_main_post(transcript_id: str):
- logger.info(
- "Starting main post pipeline",
- transcript_id=transcript_id,
+@asynctask
+async def task_pipeline_waveform(*, transcript_id: str):
+ await pipeline_waveform(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_convert_to_mp3(*, transcript_id: str):
+ await pipeline_convert_to_mp3(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_upload_mp3(*, transcript_id: str):
+ await pipeline_upload_mp3(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_diarization(*, transcript_id: str):
+ await pipeline_diarization(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_title_and_short_summary(*, transcript_id: str):
+ await pipeline_title_and_short_summary(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_final_summaries(*, transcript_id: str):
+ await pipeline_summaries(transcript_id=transcript_id)
+
+
+def pipeline_post(*, transcript_id: str):
+ """
+ Run the post pipeline
+ """
+ chain_mp3_and_diarize = (
+ task_pipeline_waveform.si(transcript_id=transcript_id)
+ | task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
+ | task_pipeline_upload_mp3.si(transcript_id=transcript_id)
+ | task_pipeline_diarization.si(transcript_id=transcript_id)
)
- runner = PipelineMainDiarization(transcript_id=transcript_id)
- runner.start_sync()
+ chain_title_preview = task_pipeline_title_and_short_summary.si(
+ transcript_id=transcript_id
+ )
+ chain_final_summaries = task_pipeline_final_summaries.si(
+ transcript_id=transcript_id
+ )
+
+ chain = chord(
+ group(chain_mp3_and_diarize, chain_title_preview),
+ chain_final_summaries,
+ )
+ chain.delay()
diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py
index a1e137a7..708a4265 100644
--- a/server/reflector/pipelines/runner.py
+++ b/server/reflector/pipelines/runner.py
@@ -106,6 +106,14 @@ class PipelineRunner(BaseModel):
if not self.pipeline:
self.pipeline = await self.create()
+ if not self.pipeline:
+ # no pipeline created in create, just finish it then.
+ await self._set_status("ended")
+ self._ev_done.set()
+ if self.on_ended:
+ await self.on_ended()
+ return
+
# start the loop
await self._set_status("started")
while not self._ev_done.is_set():
@@ -119,8 +127,7 @@ class PipelineRunner(BaseModel):
self._logger.exception("Runner error")
await self._set_status("error")
self._ev_done.set()
- if self.on_ended:
- await self.on_ended()
+ raise
async def cmd_push(self, data):
if self._is_first_push:
diff --git a/server/reflector/processors/audio_diarization.py b/server/reflector/processors/audio_diarization.py
index 82c6a553..69eab5b7 100644
--- a/server/reflector/processors/audio_diarization.py
+++ b/server/reflector/processors/audio_diarization.py
@@ -1,5 +1,5 @@
from reflector.processors.base import Processor
-from reflector.processors.types import AudioDiarizationInput, TitleSummary
+from reflector.processors.types import AudioDiarizationInput, TitleSummary, Word
class AudioDiarizationProcessor(Processor):
@@ -19,12 +19,12 @@ class AudioDiarizationProcessor(Processor):
# topics is a list[BaseModel] with an attribute words
# words is a list[BaseModel] with text, start and speaker attribute
- # mutate in place
- for topic in data.topics:
- for word in topic.transcript.words:
- for d in diarization:
- if d["start"] <= word.start <= d["end"]:
- word.speaker = d["speaker"]
+ # create a view of words based on topics
+ # the current algorithm is using words index, we cannot use a generator
+ words = list(self.iter_words_from_topics(data.topics))
+
+ # assign speaker to words (mutate the words list)
+ self.assign_speaker(words, diarization)
# emit them
for topic in data.topics:
@@ -32,3 +32,150 @@ class AudioDiarizationProcessor(Processor):
async def _diarize(self, data: AudioDiarizationInput):
raise NotImplementedError
+
+ def assign_speaker(self, words: list[Word], diarization: list[dict]):
+ self._diarization_remove_overlap(diarization)
+ self._diarization_remove_segment_without_words(words, diarization)
+ self._diarization_merge_same_speaker(words, diarization)
+ self._diarization_assign_speaker(words, diarization)
+
+ def iter_words_from_topics(self, topics: TitleSummary):
+ for topic in topics:
+ for word in topic.transcript.words:
+ yield word
+
+ def is_word_continuation(self, word_prev, word):
+ """
+ Return True if the word is a continuation of the previous word
+ by checking if the previous word is ending with a punctuation
+ or if the current word is starting with a capital letter
+ """
+ # is word_prev ending with a punctuation ?
+ if word_prev.text and word_prev.text[-1] in ".?!":
+ return False
+ elif word.text and word.text[0].isupper():
+ return False
+ return True
+
+ def _diarization_remove_overlap(self, diarization: list[dict]):
+ """
+ Remove overlap in diarization results
+
+ When using a diarization algorithm, it's possible to have overlapping segments
+ This function remove the overlap by keeping the longest segment
+
+ Warning: this function mutate the diarization list
+ """
+ # remove overlap by keeping the longest segment
+ diarization_idx = 0
+ while diarization_idx < len(diarization) - 1:
+ d = diarization[diarization_idx]
+ dnext = diarization[diarization_idx + 1]
+ if d["end"] > dnext["start"]:
+ # remove the shortest segment
+ if d["end"] - d["start"] > dnext["end"] - dnext["start"]:
+ # remove next segment
+ diarization.pop(diarization_idx + 1)
+ else:
+ # remove current segment
+ diarization.pop(diarization_idx)
+ else:
+ diarization_idx += 1
+
+ def _diarization_remove_segment_without_words(
+ self, words: list[Word], diarization: list[dict]
+ ):
+ """
+ Remove diarization segments without words
+
+ Warning: this function mutate the diarization list
+ """
+ # count the number of words for each diarization segment
+ diarization_count = []
+ for d in diarization:
+ start = d["start"]
+ end = d["end"]
+ count = 0
+ for word in words:
+ if start <= word.start < end:
+ count += 1
+ elif start < word.end <= end:
+ count += 1
+ diarization_count.append(count)
+
+ # remove diarization segments with no words
+ diarization_idx = 0
+ while diarization_idx < len(diarization):
+ if diarization_count[diarization_idx] == 0:
+ diarization.pop(diarization_idx)
+ diarization_count.pop(diarization_idx)
+ else:
+ diarization_idx += 1
+
+ def _diarization_merge_same_speaker(
+ self, words: list[Word], diarization: list[dict]
+ ):
+ """
+ Merge diarization contigous segments with the same speaker
+
+ Warning: this function mutate the diarization list
+ """
+ # merge segment with same speaker
+ diarization_idx = 0
+ while diarization_idx < len(diarization) - 1:
+ d = diarization[diarization_idx]
+ dnext = diarization[diarization_idx + 1]
+ if d["speaker"] == dnext["speaker"]:
+ diarization[diarization_idx]["end"] = dnext["end"]
+ diarization.pop(diarization_idx + 1)
+ else:
+ diarization_idx += 1
+
+ def _diarization_assign_speaker(self, words: list[Word], diarization: list[dict]):
+ """
+ Assign speaker to words based on diarization
+
+ Warning: this function mutate the words list
+ """
+
+ word_idx = 0
+ last_speaker = None
+ for d in diarization:
+ start = d["start"]
+ end = d["end"]
+ speaker = d["speaker"]
+
+ # diarization may start after the first set of words
+ # in this case, we assign the last speaker
+ for word in words[word_idx:]:
+ if word.start < start:
+ # speaker change, but what make sense for assigning the word ?
+ # If it's a new sentence, assign with the new speaker
+ # If it's a continuation, assign with the last speaker
+ is_continuation = False
+ if word_idx > 0 and word_idx < len(words) - 1:
+ is_continuation = self.is_word_continuation(
+ *words[word_idx - 1 : word_idx + 1]
+ )
+ if is_continuation:
+ word.speaker = last_speaker
+ else:
+ word.speaker = speaker
+ last_speaker = speaker
+ word_idx += 1
+ else:
+ break
+
+ # now continue to assign speaker until the word starts after the end
+ for word in words[word_idx:]:
+ if start <= word.start < end:
+ last_speaker = speaker
+ word.speaker = speaker
+ word_idx += 1
+ elif word.start > end:
+ break
+
+ # no more diarization available,
+ # assign last speaker to all words without speaker
+ for word in words[word_idx:]:
+ word.speaker = last_speaker
diff --git a/server/reflector/processors/audio_diarization_modal.py b/server/reflector/processors/audio_diarization_modal.py
index 53de2501..511b7f70 100644
--- a/server/reflector/processors/audio_diarization_modal.py
+++ b/server/reflector/processors/audio_diarization_modal.py
@@ -31,7 +31,7 @@ class AudioDiarizationModalProcessor(AudioDiarizationProcessor):
follow_redirects=True,
)
response.raise_for_status()
- return response.json()["text"]
+ return response.json()["diarization"]
AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor)
diff --git a/server/reflector/processors/audio_waveform_processor.py b/server/reflector/processors/audio_waveform_processor.py
new file mode 100644
index 00000000..f1a24ffd
--- /dev/null
+++ b/server/reflector/processors/audio_waveform_processor.py
@@ -0,0 +1,36 @@
+import json
+from pathlib import Path
+
+from reflector.processors.base import Processor
+from reflector.processors.types import TitleSummary
+from reflector.utils.audio_waveform import get_audio_waveform
+
+
+class AudioWaveformProcessor(Processor):
+ """
+ Write the waveform for the final audio
+ """
+
+ INPUT_TYPE = TitleSummary
+
+ def __init__(self, audio_path: Path | str, waveform_path: str, **kwargs):
+ super().__init__(**kwargs)
+ if isinstance(audio_path, str):
+ audio_path = Path(audio_path)
+ if audio_path.suffix not in (".mp3", ".wav"):
+ raise ValueError("Only mp3 and wav files are supported")
+ self.audio_path = audio_path
+ self.waveform_path = waveform_path
+
+ async def _flush(self):
+ self.waveform_path.parent.mkdir(parents=True, exist_ok=True)
+ self.logger.info("Waveform Processing Started")
+ waveform = get_audio_waveform(path=self.audio_path, segments_count=255)
+
+ with open(self.waveform_path, "w") as fd:
+ json.dump(waveform, fd)
+ self.logger.info("Waveform Processing Finished")
+ await self.emit(waveform, name="waveform")
+
+ async def _push(_self, _data):
+ return
diff --git a/server/reflector/settings.py b/server/reflector/settings.py
index 65412310..d0ddc91a 100644
--- a/server/reflector/settings.py
+++ b/server/reflector/settings.py
@@ -54,7 +54,7 @@ class Settings(BaseSettings):
TRANSCRIPT_MODAL_API_KEY: str | None = None
# Audio transcription storage
- TRANSCRIPT_STORAGE_BACKEND: str = "aws"
+ TRANSCRIPT_STORAGE_BACKEND: str | None = None
# Storage configuration for AWS
TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
@@ -62,9 +62,6 @@ class Settings(BaseSettings):
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
- # Transcript MP3 storage
- TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws"
-
# LLM
# available backend: openai, modal, oobabooga
LLM_BACKEND: str = "oobabooga"
@@ -131,5 +128,8 @@ class Settings(BaseSettings):
# Profiling
PROFILING: bool = False
+ # Healthcheck
+ HEALTHCHECK_URL: str | None = None
+
settings = Settings()
diff --git a/server/reflector/storage/base.py b/server/reflector/storage/base.py
index 5cdafdbf..a457ddf8 100644
--- a/server/reflector/storage/base.py
+++ b/server/reflector/storage/base.py
@@ -1,6 +1,7 @@
+import importlib
+
from pydantic import BaseModel
from reflector.settings import settings
-import importlib
class FileResult(BaseModel):
@@ -17,7 +18,7 @@ class Storage:
cls._registry[name] = kclass
@classmethod
- def get_instance(cls, name, settings_prefix=""):
+ def get_instance(cls, name: str, settings_prefix: str = ""):
if name not in cls._registry:
module_name = f"reflector.storage.storage_{name}"
importlib.import_module(module_name)
@@ -45,3 +46,9 @@ class Storage:
async def _delete_file(self, filename: str):
raise NotImplementedError
+
+ async def get_file_url(self, filename: str) -> str:
+ return await self._get_file_url(filename)
+
+ async def _get_file_url(self, filename: str) -> str:
+ raise NotImplementedError
diff --git a/server/reflector/storage/storage_aws.py b/server/reflector/storage/storage_aws.py
index 09a9c383..d2313293 100644
--- a/server/reflector/storage/storage_aws.py
+++ b/server/reflector/storage/storage_aws.py
@@ -1,6 +1,6 @@
import aioboto3
-from reflector.storage.base import Storage, FileResult
from reflector.logger import logger
+from reflector.storage.base import FileResult, Storage
class AwsStorage(Storage):
@@ -44,16 +44,18 @@ class AwsStorage(Storage):
Body=data,
)
+ async def _get_file_url(self, filename: str) -> FileResult:
+ bucket = self.aws_bucket_name
+ folder = self.aws_folder
+ s3filename = f"{folder}/{filename}" if folder else filename
+ async with self.session.client("s3") as client:
presigned_url = await client.generate_presigned_url(
"get_object",
Params={"Bucket": bucket, "Key": s3filename},
ExpiresIn=3600,
)
- return FileResult(
- filename=filename,
- url=presigned_url,
- )
+ return presigned_url
async def _delete_file(self, filename: str):
bucket = self.aws_bucket_name
diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py
index 5de9ced3..9e62192b 100644
--- a/server/reflector/views/transcripts.py
+++ b/server/reflector/views/transcripts.py
@@ -1,31 +1,19 @@
from datetime import datetime, timedelta
-from typing import Annotated, Optional
+from typing import Annotated, Literal, Optional
import reflector.auth as auth
-from fastapi import (
- APIRouter,
- Depends,
- HTTPException,
- Request,
- WebSocket,
- WebSocketDisconnect,
- status,
-)
-from fastapi_pagination import Page, paginate
+from fastapi import APIRouter, Depends, HTTPException
+from fastapi_pagination import Page
+from fastapi_pagination.ext.databases import paginate
from jose import jwt
from pydantic import BaseModel, Field
from reflector.db.transcripts import (
- AudioWaveform,
+ TranscriptParticipant,
TranscriptTopic,
transcripts_controller,
)
from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.settings import settings
-from reflector.ws_manager import get_ws_manager
-from starlette.concurrency import run_in_threadpool
-
-from ._range_requests_response import range_requests_response
-from .rtc_offer import RtcOffer, rtc_offer_base
router = APIRouter()
@@ -48,6 +36,7 @@ def create_access_token(data: dict, expires_delta: timedelta):
class GetTranscript(BaseModel):
id: str
+ user_id: str | None
name: str
status: str
locked: bool
@@ -56,8 +45,10 @@ class GetTranscript(BaseModel):
short_summary: str | None
long_summary: str | None
created_at: datetime
+ share_mode: str = Field("private")
source_language: str | None
target_language: str | None
+ participants: list[TranscriptParticipant] | None
class CreateTranscript(BaseModel):
@@ -72,6 +63,8 @@ class UpdateTranscript(BaseModel):
title: Optional[str] = Field(None)
short_summary: Optional[str] = Field(None)
long_summary: Optional[str] = Field(None)
+ share_mode: Optional[Literal["public", "semi-private", "private"]] = Field(None)
+ participants: Optional[list[TranscriptParticipant]] = Field(None)
class DeletionStatus(BaseModel):
@@ -82,12 +75,19 @@ class DeletionStatus(BaseModel):
async def transcripts_list(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
+ from reflector.db import database
+
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None
- return paginate(
- await transcripts_controller.get_all(user_id=user_id, order_by="-created_at")
+ return await paginate(
+ database,
+ await transcripts_controller.get_all(
+ user_id=user_id,
+ order_by="-created_at",
+ return_query=True,
+ ),
)
@@ -165,10 +165,9 @@ async def transcript_get(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
- transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
- if not transcript:
- raise HTTPException(status_code=404, detail="Transcript not found")
- return transcript
+ return await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
@@ -181,17 +180,7 @@ async def transcript_update(
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
- values = {}
- if info.name is not None:
- values["name"] = info.name
- if info.locked is not None:
- values["locked"] = info.locked
- if info.long_summary is not None:
- values["long_summary"] = info.long_summary
- if info.short_summary is not None:
- values["short_summary"] = info.short_summary
- if info.title is not None:
- values["title"] = info.title
+ values = info.dict(exclude_unset=True)
await transcripts_controller.update(transcript, values)
return transcript
@@ -209,63 +198,6 @@ async def transcript_delete(
return DeletionStatus(status="ok")
-@router.get("/transcripts/{transcript_id}/audio/mp3")
-@router.head("/transcripts/{transcript_id}/audio/mp3")
-async def transcript_get_audio_mp3(
- request: Request,
- transcript_id: str,
- user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
- token: str | None = None,
-):
- user_id = user["sub"] if user else None
- if not user_id and token:
- unauthorized_exception = HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail="Invalid or expired token",
- headers={"WWW-Authenticate": "Bearer"},
- )
- try:
- payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
- user_id: str = payload.get("sub")
- except jwt.JWTError:
- raise unauthorized_exception
-
- transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
- if not transcript:
- raise HTTPException(status_code=404, detail="Transcript not found")
-
- if not transcript.audio_mp3_filename.exists():
- raise HTTPException(status_code=404, detail="Audio not found")
-
- truncated_id = str(transcript.id).split("-")[0]
- filename = f"recording_{truncated_id}.mp3"
-
- return range_requests_response(
- request,
- transcript.audio_mp3_filename,
- content_type="audio/mpeg",
- content_disposition=f"attachment; filename={filename}",
- )
-
-
-@router.get("/transcripts/{transcript_id}/audio/waveform")
-async def transcript_get_audio_waveform(
- transcript_id: str,
- user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-) -> AudioWaveform:
- user_id = user["sub"] if user else None
- transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
- if not transcript:
- raise HTTPException(status_code=404, detail="Transcript not found")
-
- if not transcript.audio_mp3_filename.exists():
- raise HTTPException(status_code=404, detail="Audio not found")
-
- await run_in_threadpool(transcript.convert_audio_to_waveform)
-
- return transcript.audio_waveform
-
-
@router.get(
"/transcripts/{transcript_id}/topics",
response_model=list[GetTranscriptTopic],
@@ -275,92 +207,11 @@ async def transcript_get_topics(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
- transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
- if not transcript:
- raise HTTPException(status_code=404, detail="Transcript not found")
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
# convert to GetTranscriptTopic
return [
GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics
]
-
-
-# ==============================================================
-# Websocket
-# ==============================================================
-
-
-@router.get("/transcripts/{transcript_id}/events")
-async def transcript_get_websocket_events(transcript_id: str):
- pass
-
-
-@router.websocket("/transcripts/{transcript_id}/events")
-async def transcript_events_websocket(
- transcript_id: str,
- websocket: WebSocket,
- # user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-):
- # user_id = user["sub"] if user else None
- transcript = await transcripts_controller.get_by_id(transcript_id)
- if not transcript:
- raise HTTPException(status_code=404, detail="Transcript not found")
-
- # connect to websocket manager
- # use ts:transcript_id as room id
- room_id = f"ts:{transcript_id}"
- ws_manager = get_ws_manager()
- await ws_manager.add_user_to_room(room_id, websocket)
-
- try:
- # on first connection, send all events only to the current user
- for event in transcript.events:
- # for now, do not send TRANSCRIPT or STATUS options - theses are live event
- # not necessary to be sent to the client; but keep the rest
- name = event.event
- if name in ("TRANSCRIPT", "STATUS"):
- continue
- await websocket.send_json(event.model_dump(mode="json"))
-
- # XXX if transcript is final (locked=True and status=ended)
- # XXX send a final event to the client and close the connection
-
- # endless loop to wait for new events
- # we do not have command system now,
- while True:
- await websocket.receive()
- except (RuntimeError, WebSocketDisconnect):
- await ws_manager.remove_user_from_room(room_id, websocket)
-
-
-# ==============================================================
-# Web RTC
-# ==============================================================
-
-
-@router.post("/transcripts/{transcript_id}/record/webrtc")
-async def transcript_record_webrtc(
- transcript_id: str,
- params: RtcOffer,
- request: Request,
- user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-):
- user_id = user["sub"] if user else None
- transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
- if not transcript:
- raise HTTPException(status_code=404, detail="Transcript not found")
-
- if transcript.locked:
- raise HTTPException(status_code=400, detail="Transcript is locked")
-
- # create a pipeline runner
- from reflector.pipelines.main_live_pipeline import PipelineMainLive
-
- pipeline_runner = PipelineMainLive(transcript_id=transcript_id)
-
- # FIXME do not allow multiple recording at the same time
- return await rtc_offer_base(
- params,
- request,
- pipeline_runner=pipeline_runner,
- )
diff --git a/server/reflector/views/transcripts_audio.py b/server/reflector/views/transcripts_audio.py
new file mode 100644
index 00000000..a174d992
--- /dev/null
+++ b/server/reflector/views/transcripts_audio.py
@@ -0,0 +1,109 @@
+"""
+Transcripts audio related endpoints
+===================================
+
+"""
+from typing import Annotated, Optional
+
+import httpx
+import reflector.auth as auth
+from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
+from jose import jwt
+from reflector.db.transcripts import AudioWaveform, transcripts_controller
+from reflector.settings import settings
+from reflector.views.transcripts import ALGORITHM
+
+from ._range_requests_response import range_requests_response
+
+router = APIRouter()
+
+
+@router.get("/transcripts/{transcript_id}/audio/mp3")
+@router.head("/transcripts/{transcript_id}/audio/mp3")
+async def transcript_get_audio_mp3(
+ request: Request,
+ transcript_id: str,
+ user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+ token: str | None = None,
+):
+ user_id = user["sub"] if user else None
+ if not user_id and token:
+ unauthorized_exception = HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid or expired token",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+ try:
+ payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
+ user_id: str = payload.get("sub")
+ except jwt.JWTError:
+ raise unauthorized_exception
+
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
+
+ if transcript.audio_location == "storage":
+ # proxy S3 file, to prevent issue with CORS
+ url = await transcript.get_audio_url()
+ headers = {}
+
+ copy_headers = ["range", "accept-encoding"]
+ for header in copy_headers:
+ if header in request.headers:
+ headers[header] = request.headers[header]
+
+ async with httpx.AsyncClient() as client:
+ resp = await client.request(request.method, url, headers=headers)
+ return Response(
+ content=resp.content,
+ status_code=resp.status_code,
+ headers=resp.headers,
+ )
+
+ if transcript.audio_location == "storage":
+ # proxy S3 file, to prevent issue with CORS
+ url = await transcript.get_audio_url()
+ headers = {}
+
+ copy_headers = ["range", "accept-encoding"]
+ for header in copy_headers:
+ if header in request.headers:
+ headers[header] = request.headers[header]
+
+ async with httpx.AsyncClient() as client:
+ resp = await client.request(request.method, url, headers=headers)
+ return Response(
+ content=resp.content,
+ status_code=resp.status_code,
+ headers=resp.headers,
+ )
+
+ if not transcript.audio_mp3_filename.exists():
+ raise HTTPException(status_code=500, detail="Audio not found")
+
+ truncated_id = str(transcript.id).split("-")[0]
+ filename = f"recording_{truncated_id}.mp3"
+
+ return range_requests_response(
+ request,
+ transcript.audio_mp3_filename,
+ content_type="audio/mpeg",
+ content_disposition=f"attachment; filename={filename}",
+ )
+
+
+@router.get("/transcripts/{transcript_id}/audio/waveform")
+async def transcript_get_audio_waveform(
+ transcript_id: str,
+ user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+) -> AudioWaveform:
+ user_id = user["sub"] if user else None
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
+
+ if not transcript.audio_waveform_filename.exists():
+ raise HTTPException(status_code=404, detail="Audio not found")
+
+ return transcript.audio_waveform
diff --git a/server/reflector/views/transcripts_participants.py b/server/reflector/views/transcripts_participants.py
new file mode 100644
index 00000000..318d6018
--- /dev/null
+++ b/server/reflector/views/transcripts_participants.py
@@ -0,0 +1,142 @@
+"""
+Transcript participants API endpoints
+=====================================
+
+"""
+from typing import Annotated, Optional
+
+import reflector.auth as auth
+from fastapi import APIRouter, Depends, HTTPException
+from pydantic import BaseModel, ConfigDict, Field
+from reflector.db.transcripts import TranscriptParticipant, transcripts_controller
+from reflector.views.types import DeletionStatus
+
+router = APIRouter()
+
+
+class Participant(BaseModel):
+ model_config = ConfigDict(from_attributes=True)
+ id: str
+ speaker: int | None
+ name: str
+
+
+class CreateParticipant(BaseModel):
+ speaker: Optional[int] = Field(None)
+ name: str
+
+
+class UpdateParticipant(BaseModel):
+ speaker: Optional[int] = Field(None)
+ name: Optional[str] = Field(None)
+
+
+@router.get("/transcripts/{transcript_id}/participants")
+async def transcript_get_participants(
+ transcript_id: str,
+ user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+) -> list[Participant]:
+ user_id = user["sub"] if user else None
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
+
+ return [
+ Participant.model_validate(participant)
+ for participant in transcript.participants
+ ]
+
+
+@router.post("/transcripts/{transcript_id}/participants")
+async def transcript_add_participant(
+ transcript_id: str,
+ participant: CreateParticipant,
+ user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+) -> Participant:
+ user_id = user["sub"] if user else None
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
+
+ # ensure the speaker is unique
+ for p in transcript.participants:
+ if p.speaker == participant.speaker:
+ raise HTTPException(
+ status_code=400,
+ detail="Speaker already assigned",
+ )
+
+ obj = await transcripts_controller.upsert_participant(
+ transcript, TranscriptParticipant(**participant.dict())
+ )
+ return Participant.model_validate(obj)
+
+
+@router.get("/transcripts/{transcript_id}/participants/{participant_id}")
+async def transcript_get_participant(
+ transcript_id: str,
+ participant_id: str,
+ user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+) -> Participant:
+ user_id = user["sub"] if user else None
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
+
+ for p in transcript.participants:
+ if p.id == participant_id:
+ return Participant.model_validate(p)
+
+ raise HTTPException(status_code=404, detail="Participant not found")
+
+
+@router.patch("/transcripts/{transcript_id}/participants/{participant_id}")
+async def transcript_update_participant(
+ transcript_id: str,
+ participant_id: str,
+ participant: UpdateParticipant,
+ user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+) -> Participant:
+ user_id = user["sub"] if user else None
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
+
+ # ensure the speaker is unique
+ for p in transcript.participants:
+ if p.speaker == participant.speaker and p.id != participant_id:
+ raise HTTPException(
+ status_code=400,
+ detail="Speaker already assigned",
+ )
+
+ # find the participant
+ obj = None
+ for p in transcript.participants:
+ if p.id == participant_id:
+ obj = p
+ break
+
+ if not obj:
+ raise HTTPException(status_code=404, detail="Participant not found")
+
+ # update participant but just the fields that are set
+ fields = participant.dict(exclude_unset=True)
+ obj = obj.copy(update=fields)
+
+ await transcripts_controller.upsert_participant(transcript, obj)
+ return Participant.model_validate(obj)
+
+
+@router.delete("/transcripts/{transcript_id}/participants/{participant_id}")
+async def transcript_delete_participant(
+ transcript_id: str,
+ participant_id: str,
+ user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+) -> DeletionStatus:
+ user_id = user["sub"] if user else None
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
+ await transcripts_controller.delete_participant(transcript, participant_id)
+ return DeletionStatus(status="ok")
diff --git a/server/reflector/views/transcripts_webrtc.py b/server/reflector/views/transcripts_webrtc.py
new file mode 100644
index 00000000..af451411
--- /dev/null
+++ b/server/reflector/views/transcripts_webrtc.py
@@ -0,0 +1,37 @@
+from typing import Annotated, Optional
+
+import reflector.auth as auth
+from fastapi import APIRouter, Depends, HTTPException, Request
+from reflector.db.transcripts import transcripts_controller
+
+from .rtc_offer import RtcOffer, rtc_offer_base
+
+router = APIRouter()
+
+
+@router.post("/transcripts/{transcript_id}/record/webrtc")
+async def transcript_record_webrtc(
+ transcript_id: str,
+ params: RtcOffer,
+ request: Request,
+ user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+):
+ user_id = user["sub"] if user else None
+ transcript = await transcripts_controller.get_by_id_for_http(
+ transcript_id, user_id=user_id
+ )
+
+ if transcript.locked:
+ raise HTTPException(status_code=400, detail="Transcript is locked")
+
+ # create a pipeline runner
+ from reflector.pipelines.main_live_pipeline import PipelineMainLive
+
+ pipeline_runner = PipelineMainLive(transcript_id=transcript_id)
+
+ # FIXME do not allow multiple recording at the same time
+ return await rtc_offer_base(
+ params,
+ request,
+ pipeline_runner=pipeline_runner,
+ )
diff --git a/server/reflector/views/transcripts_websocket.py b/server/reflector/views/transcripts_websocket.py
new file mode 100644
index 00000000..65571aab
--- /dev/null
+++ b/server/reflector/views/transcripts_websocket.py
@@ -0,0 +1,53 @@
+"""
+Transcripts websocket API
+=========================
+
+"""
+from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
+from reflector.db.transcripts import transcripts_controller
+from reflector.ws_manager import get_ws_manager
+
+router = APIRouter()
+
+
+@router.get("/transcripts/{transcript_id}/events")
+async def transcript_get_websocket_events(transcript_id: str):
+ pass
+
+
+@router.websocket("/transcripts/{transcript_id}/events")
+async def transcript_events_websocket(
+ transcript_id: str,
+ websocket: WebSocket,
+ # user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+):
+ # user_id = user["sub"] if user else None
+ transcript = await transcripts_controller.get_by_id(transcript_id)
+ if not transcript:
+ raise HTTPException(status_code=404, detail="Transcript not found")
+
+ # connect to websocket manager
+ # use ts:transcript_id as room id
+ room_id = f"ts:{transcript_id}"
+ ws_manager = get_ws_manager()
+ await ws_manager.add_user_to_room(room_id, websocket)
+
+ try:
+ # on first connection, send all events only to the current user
+ for event in transcript.events:
+ # for now, do not send TRANSCRIPT or STATUS options - theses are live event
+ # not necessary to be sent to the client; but keep the rest
+ name = event.event
+ if name in ("TRANSCRIPT", "STATUS"):
+ continue
+ await websocket.send_json(event.model_dump(mode="json"))
+
+ # XXX if transcript is final (locked=True and status=ended)
+ # XXX send a final event to the client and close the connection
+
+ # endless loop to wait for new events
+ # we do not have command system now,
+ while True:
+ await websocket.receive()
+ except (RuntimeError, WebSocketDisconnect):
+ await ws_manager.remove_user_from_room(room_id, websocket)
diff --git a/server/reflector/views/types.py b/server/reflector/views/types.py
new file mode 100644
index 00000000..70361131
--- /dev/null
+++ b/server/reflector/views/types.py
@@ -0,0 +1,5 @@
+from pydantic import BaseModel
+
+
+class DeletionStatus(BaseModel):
+ status: str
diff --git a/server/reflector/worker/app.py b/server/reflector/worker/app.py
index e1000364..689623ce 100644
--- a/server/reflector/worker/app.py
+++ b/server/reflector/worker/app.py
@@ -1,6 +1,8 @@
+import structlog
from celery import Celery
from reflector.settings import settings
+logger = structlog.get_logger(__name__)
app = Celery(__name__)
app.conf.broker_url = settings.CELERY_BROKER_URL
app.conf.result_backend = settings.CELERY_RESULT_BACKEND
@@ -8,5 +10,18 @@ app.conf.broker_connection_retry_on_startup = True
app.autodiscover_tasks(
[
"reflector.pipelines.main_live_pipeline",
+ "reflector.worker.healthcheck",
]
)
+
+# crontab
+app.conf.beat_schedule = {}
+
+if settings.HEALTHCHECK_URL:
+ app.conf.beat_schedule["healthcheck_ping"] = {
+ "task": "reflector.worker.healthcheck.healthcheck_ping",
+ "schedule": 60.0 * 10,
+ }
+ logger.info("Healthcheck enabled", url=settings.HEALTHCHECK_URL)
+else:
+ logger.warning("Healthcheck disabled, no url configured")
diff --git a/server/reflector/worker/healthcheck.py b/server/reflector/worker/healthcheck.py
new file mode 100644
index 00000000..e4ce6bc3
--- /dev/null
+++ b/server/reflector/worker/healthcheck.py
@@ -0,0 +1,18 @@
+import httpx
+import structlog
+from celery import shared_task
+from reflector.settings import settings
+
+logger = structlog.get_logger(__name__)
+
+
+@shared_task
+def healthcheck_ping():
+ url = settings.HEALTHCHECK_URL
+ if not url:
+ return
+ try:
+ print("pinging healthcheck url", url)
+ httpx.get(url, timeout=10)
+ except Exception as e:
+ logger.error("healthcheck_ping", error=str(e))
diff --git a/server/runserver.sh b/server/runserver.sh
index b0c3f138..31cce123 100755
--- a/server/runserver.sh
+++ b/server/runserver.sh
@@ -9,6 +9,8 @@ if [ "${ENTRYPOINT}" = "server" ]; then
python -m reflector.app
elif [ "${ENTRYPOINT}" = "worker" ]; then
celery -A reflector.worker.app worker --loglevel=info
+elif [ "${ENTRYPOINT}" = "beat" ]; then
+ celery -A reflector.worker.app beat --loglevel=info
else
echo "Unknown command"
fi
diff --git a/server/tests/conftest.py b/server/tests/conftest.py
index aafca9fd..532ebff9 100644
--- a/server/tests/conftest.py
+++ b/server/tests/conftest.py
@@ -1,4 +1,5 @@
from unittest.mock import patch
+from tempfile import NamedTemporaryFile
import pytest
@@ -7,7 +8,6 @@ import pytest
@pytest.mark.asyncio
async def setup_database():
from reflector.settings import settings
- from tempfile import NamedTemporaryFile
with NamedTemporaryFile() as f:
settings.DATABASE_URL = f"sqlite:///{f.name}"
@@ -103,6 +103,25 @@ async def dummy_llm():
yield
+@pytest.fixture
+async def dummy_storage():
+ from reflector.storage.base import Storage
+
+ class DummyStorage(Storage):
+ async def _put_file(self, *args, **kwargs):
+ pass
+
+ async def _delete_file(self, *args, **kwargs):
+ pass
+
+ async def _get_file_url(self, *args, **kwargs):
+ return "http://fake_server/audio.mp3"
+
+ with patch("reflector.storage.base.Storage.get_instance") as mock_storage:
+ mock_storage.return_value = DummyStorage()
+ yield
+
+
@pytest.fixture
def nltk():
with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
@@ -133,4 +152,17 @@ def celery_enable_logging():
@pytest.fixture(scope="session")
def celery_config():
- return {"broker_url": "memory://", "result_backend": "rpc"}
+ with NamedTemporaryFile() as f:
+ yield {
+ "broker_url": "memory://",
+ "result_backend": f"db+sqlite:///{f.name}",
+ }
+
+
+@pytest.fixture(scope="session")
+def fake_mp3_upload():
+ with patch(
+ "reflector.db.transcripts.TranscriptController.move_mp3_to_storage"
+ ) as mock_move:
+ mock_move.return_value = True
+ yield
diff --git a/server/tests/test_processor_audio_diarization.py b/server/tests/test_processor_audio_diarization.py
new file mode 100644
index 00000000..00935a49
--- /dev/null
+++ b/server/tests/test_processor_audio_diarization.py
@@ -0,0 +1,140 @@
+import pytest
+from unittest import mock
+
+
+@pytest.mark.parametrize(
+ "name,diarization,expected",
+ [
+ [
+ "no overlap",
+ [
+ {"start": 0.0, "end": 1.0, "speaker": "A"},
+ {"start": 1.0, "end": 2.0, "speaker": "B"},
+ ],
+ ["A", "A", "B", "B"],
+ ],
+ [
+ "same speaker",
+ [
+ {"start": 0.0, "end": 1.0, "speaker": "A"},
+ {"start": 1.0, "end": 2.0, "speaker": "A"},
+ ],
+ ["A", "A", "A", "A"],
+ ],
+ [
+ # first segment is removed because it overlap
+ # with the second segment, and it is smaller
+ "overlap at 0.5s",
+ [
+ {"start": 0.0, "end": 1.0, "speaker": "A"},
+ {"start": 0.5, "end": 2.0, "speaker": "B"},
+ ],
+ ["B", "B", "B", "B"],
+ ],
+ [
+ "junk segment at 0.5s for 0.2s",
+ [
+ {"start": 0.0, "end": 1.0, "speaker": "A"},
+ {"start": 0.5, "end": 0.7, "speaker": "B"},
+ {"start": 1, "end": 2.0, "speaker": "B"},
+ ],
+ ["A", "A", "B", "B"],
+ ],
+ [
+ "start without diarization",
+ [
+ {"start": 0.5, "end": 1.0, "speaker": "A"},
+ {"start": 1.0, "end": 2.0, "speaker": "B"},
+ ],
+ ["A", "A", "B", "B"],
+ ],
+ [
+ "end missing diarization",
+ [
+ {"start": 0.0, "end": 1.0, "speaker": "A"},
+ {"start": 1.0, "end": 1.5, "speaker": "B"},
+ ],
+ ["A", "A", "B", "B"],
+ ],
+ [
+ "continuation of next speaker",
+ [
+ {"start": 0.0, "end": 0.9, "speaker": "A"},
+ {"start": 1.5, "end": 2.0, "speaker": "B"},
+ ],
+ ["A", "A", "B", "B"],
+ ],
+ [
+ "continuation of previous speaker",
+ [
+ {"start": 0.0, "end": 0.5, "speaker": "A"},
+ {"start": 1.0, "end": 2.0, "speaker": "B"},
+ ],
+ ["A", "A", "B", "B"],
+ ],
+ [
+ "segment without words",
+ [
+ {"start": 0.0, "end": 1.0, "speaker": "A"},
+ {"start": 1.0, "end": 2.0, "speaker": "B"},
+ {"start": 2.0, "end": 3.0, "speaker": "X"},
+ ],
+ ["A", "A", "B", "B"],
+ ],
+ ],
+)
+@pytest.mark.asyncio
+async def test_processors_audio_diarization(event_loop, name, diarization, expected):
+ from reflector.processors.audio_diarization import AudioDiarizationProcessor
+ from reflector.processors.types import (
+ TitleSummaryWithId,
+ Transcript,
+ Word,
+ AudioDiarizationInput,
+ )
+
+ # create fake topic
+ topics = [
+ TitleSummaryWithId(
+ id="1",
+ title="Title1",
+ summary="Summary1",
+ timestamp=0.0,
+ duration=1.0,
+ transcript=Transcript(
+ words=[
+ Word(text="Word1", start=0.0, end=0.5),
+ Word(text="word2.", start=0.5, end=1.0),
+ ]
+ ),
+ ),
+ TitleSummaryWithId(
+ id="2",
+ title="Title2",
+ summary="Summary2",
+ timestamp=0.0,
+ duration=1.0,
+ transcript=Transcript(
+ words=[
+ Word(text="Word3", start=1.0, end=1.5),
+ Word(text="word4.", start=1.5, end=2.0),
+ ]
+ ),
+ ),
+ ]
+
+ diarizer = AudioDiarizationProcessor()
+ with mock.patch.object(diarizer, "_diarize") as mock_diarize:
+ mock_diarize.return_value = diarization
+
+ data = AudioDiarizationInput(
+ audio_url="https://example.com/audio.mp3",
+ topics=topics,
+ )
+ await diarizer._push(data)
+
+ # check that the speaker has been assigned to the words
+ assert topics[0].transcript.words[0].speaker == expected[0]
+ assert topics[0].transcript.words[1].speaker == expected[1]
+ assert topics[1].transcript.words[0].speaker == expected[2]
+ assert topics[1].transcript.words[1].speaker == expected[3]
diff --git a/server/tests/test_transcripts_audio_download.py b/server/tests/test_transcripts_audio_download.py
index 69ae5f65..28f83fff 100644
--- a/server/tests/test_transcripts_audio_download.py
+++ b/server/tests/test_transcripts_audio_download.py
@@ -118,15 +118,3 @@ async def test_transcript_audio_download_range_with_seek(
assert response.status_code == 206
assert response.headers["content-type"] == content_type
assert response.headers["content-range"].startswith("bytes 100-")
-
-
-@pytest.mark.asyncio
-async def test_transcript_audio_download_waveform(fake_transcript):
- from reflector.app import app
-
- ac = AsyncClient(app=app, base_url="http://test/v1")
- response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform")
- assert response.status_code == 200
- assert response.headers["content-type"] == "application/json"
- assert isinstance(response.json()["data"], list)
- assert len(response.json()["data"]) >= 255
diff --git a/server/tests/test_transcripts_participants.py b/server/tests/test_transcripts_participants.py
new file mode 100644
index 00000000..b55b16a8
--- /dev/null
+++ b/server/tests/test_transcripts_participants.py
@@ -0,0 +1,164 @@
+import pytest
+from httpx import AsyncClient
+
+
+@pytest.mark.asyncio
+async def test_transcript_participants():
+ from reflector.app import app
+
+ async with AsyncClient(app=app, base_url="http://test/v1") as ac:
+ response = await ac.post("/transcripts", json={"name": "test"})
+ assert response.status_code == 200
+ assert response.json()["participants"] == []
+
+ # create a participant
+ transcript_id = response.json()["id"]
+ response = await ac.post(
+ f"/transcripts/{transcript_id}/participants", json={"name": "test"}
+ )
+ assert response.status_code == 200
+ assert response.json()["id"] is not None
+ assert response.json()["speaker"] is None
+ assert response.json()["name"] == "test"
+
+ # create another one with a speaker
+ response = await ac.post(
+ f"/transcripts/{transcript_id}/participants",
+ json={"name": "test2", "speaker": 1},
+ )
+ assert response.status_code == 200
+ assert response.json()["id"] is not None
+ assert response.json()["speaker"] == 1
+ assert response.json()["name"] == "test2"
+
+ # get all participants via transcript
+ response = await ac.get(f"/transcripts/{transcript_id}")
+ assert response.status_code == 200
+ assert len(response.json()["participants"]) == 2
+
+ # get participants via participants endpoint
+ response = await ac.get(f"/transcripts/{transcript_id}/participants")
+ assert response.status_code == 200
+ assert len(response.json()) == 2
+
+
+@pytest.mark.asyncio
+async def test_transcript_participants_same_speaker():
+ from reflector.app import app
+
+ async with AsyncClient(app=app, base_url="http://test/v1") as ac:
+ response = await ac.post("/transcripts", json={"name": "test"})
+ assert response.status_code == 200
+ assert response.json()["participants"] == []
+ transcript_id = response.json()["id"]
+
+ # create a participant
+ response = await ac.post(
+ f"/transcripts/{transcript_id}/participants",
+ json={"name": "test", "speaker": 1},
+ )
+ assert response.status_code == 200
+ assert response.json()["speaker"] == 1
+
+ # create another one with the same speaker
+ response = await ac.post(
+ f"/transcripts/{transcript_id}/participants",
+ json={"name": "test2", "speaker": 1},
+ )
+ assert response.status_code == 400
+
+
+@pytest.mark.asyncio
+async def test_transcript_participants_update_name():
+ from reflector.app import app
+
+ async with AsyncClient(app=app, base_url="http://test/v1") as ac:
+ response = await ac.post("/transcripts", json={"name": "test"})
+ assert response.status_code == 200
+ assert response.json()["participants"] == []
+ transcript_id = response.json()["id"]
+
+ # create a participant
+ response = await ac.post(
+ f"/transcripts/{transcript_id}/participants",
+ json={"name": "test", "speaker": 1},
+ )
+ assert response.status_code == 200
+ assert response.json()["speaker"] == 1
+
+ # update the participant
+ participant_id = response.json()["id"]
+ response = await ac.patch(
+ f"/transcripts/{transcript_id}/participants/{participant_id}",
+ json={"name": "test2"},
+ )
+ assert response.status_code == 200
+ assert response.json()["name"] == "test2"
+
+ # verify the participant was updated
+ response = await ac.get(
+ f"/transcripts/{transcript_id}/participants/{participant_id}"
+ )
+ assert response.status_code == 200
+ assert response.json()["name"] == "test2"
+
+ # verify the participant was updated in transcript
+ response = await ac.get(f"/transcripts/{transcript_id}")
+ assert response.status_code == 200
+ assert len(response.json()["participants"]) == 1
+ assert response.json()["participants"][0]["name"] == "test2"
+
+
+@pytest.mark.asyncio
+async def test_transcript_participants_update_speaker():
+ from reflector.app import app
+
+ async with AsyncClient(app=app, base_url="http://test/v1") as ac:
+ response = await ac.post("/transcripts", json={"name": "test"})
+ assert response.status_code == 200
+ assert response.json()["participants"] == []
+ transcript_id = response.json()["id"]
+
+ # create a participant
+ response = await ac.post(
+ f"/transcripts/{transcript_id}/participants",
+ json={"name": "test", "speaker": 1},
+ )
+ assert response.status_code == 200
+ participant1_id = response.json()["id"]
+
+ # create another participant
+ response = await ac.post(
+ f"/transcripts/{transcript_id}/participants",
+ json={"name": "test2", "speaker": 2},
+ )
+ assert response.status_code == 200
+ participant2_id = response.json()["id"]
+
+ # update the participant, refused as speaker is already taken
+ response = await ac.patch(
+ f"/transcripts/{transcript_id}/participants/{participant2_id}",
+ json={"speaker": 1},
+ )
+ assert response.status_code == 400
+
+ # delete the participant 1
+ response = await ac.delete(
+ f"/transcripts/{transcript_id}/participants/{participant1_id}"
+ )
+ assert response.status_code == 200
+
+ # update the participant 2 again, should be accepted now
+ response = await ac.patch(
+ f"/transcripts/{transcript_id}/participants/{participant2_id}",
+ json={"speaker": 1},
+ )
+ assert response.status_code == 200
+
+ # ensure participant2 name is still there
+ response = await ac.get(
+ f"/transcripts/{transcript_id}/participants/{participant2_id}"
+ )
+ assert response.status_code == 200
+ assert response.json()["name"] == "test2"
+ assert response.json()["speaker"] == 1
diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py
index cf2ea304..8502a0d9 100644
--- a/server/tests/test_transcripts_rtc_ws.py
+++ b/server/tests/test_transcripts_rtc_ws.py
@@ -66,6 +66,8 @@ async def test_transcript_rtc_and_websocket(
dummy_transcript,
dummy_processors,
dummy_diarization,
+ dummy_storage,
+ fake_mp3_upload,
ensure_casing,
appserver,
sentence_tokenize,
@@ -182,6 +184,16 @@ async def test_transcript_rtc_and_websocket(
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "LLM TITLE"
+ assert "WAVEFORM" in eventnames
+ ev = events[eventnames.index("WAVEFORM")]
+ assert isinstance(ev["data"]["waveform"], list)
+ assert len(ev["data"]["waveform"]) >= 250
+ waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform")
+ assert waveform_resp.status_code == 200
+ assert waveform_resp.headers["content-type"] == "application/json"
+ assert isinstance(waveform_resp.json()["data"], list)
+ assert len(waveform_resp.json()["data"]) >= 250
+
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
assert statuses.index("recording") < statuses.index("processing")
@@ -193,11 +205,12 @@ async def test_transcript_rtc_and_websocket(
# check on the latest response that the audio duration is > 0
assert resp.json()["duration"] > 0
+ assert "DURATION" in eventnames
# check that audio/mp3 is available
- resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
- assert resp.status_code == 200
- assert resp.headers["Content-Type"] == "audio/mpeg"
+ audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
+ assert audio_resp.status_code == 200
+ assert audio_resp.headers["Content-Type"] == "audio/mpeg"
@pytest.mark.usefixtures("celery_session_app")
@@ -209,6 +222,8 @@ async def test_transcript_rtc_and_websocket_and_fr(
dummy_transcript,
dummy_processors,
dummy_diarization,
+ dummy_storage,
+ fake_mp3_upload,
ensure_casing,
appserver,
sentence_tokenize,
diff --git a/www/app/(auth)/fiefWrapper.tsx b/www/app/(auth)/fiefWrapper.tsx
index 187fef7c..bb38f5ee 100644
--- a/www/app/(auth)/fiefWrapper.tsx
+++ b/www/app/(auth)/fiefWrapper.tsx
@@ -1,11 +1,18 @@
"use client";
import { FiefAuthProvider } from "@fief/fief/nextjs/react";
+import { createContext } from "react";
-export default function FiefWrapper({ children }) {
+export const CookieContext = createContext<{ hasAuthCookie: boolean }>({
+ hasAuthCookie: false,
+});
+
+export default function FiefWrapper({ children, hasAuthCookie }) {
return (
-
- Capture the signal, not the noise -
-+ Capture the signal, not the noise +
+Loading Transcript
+ ) : ( ++ There was an error generating the final summary, please + come back later +
+ )}- You can share this link with others. Anyone with the link will have - access to the page, including the full audio recording, for the next 7 - days. -
- ) : ( -- You can share this link with others. Anyone with the link will have - access to the page, including the full audio recording. -
+ {requireLogin && ( +This transcript is private and can only be accessed by you.
+ )} + {shareMode === "semi-private" && ( ++ This transcript is secure. Only authenticated users can access it. +
+ )} + {shareMode === "public" && ( +This transcript is public. Everyone can access it.
+ )} + + {isOwner && api && ( ++ Share this link to grant others access to this page. The link + includes the full audio recording and is valid for the next 7 + days. +
+ ) : ( ++ Share this link to allow others to view this page and listen to + the full audio recording. +
+ )} + > )}