Merge branch 'main' into post-to-zulip

Koper committed 2023-12-04 21:06:10 +07:00
53 changed files with 2479 additions and 692 deletions

View File

@@ -23,7 +23,7 @@ It also uses https://github.com/fief-dev for authentication, and Vercel for deployment
 - [OpenAPI Code Generation](#openapi-code-generation)
 - [Back-End](#back-end)
 - [Installation](#installation-1)
-- [Start the project](#start-the-project)
+- [Start the API/Backend](#start-the-apibackend)
 - [Using docker](#using-docker)
 - [Using local GPT4All](#using-local-gpt4all)
 - [Using local files](#using-local-files)
@@ -133,15 +133,15 @@ TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
 ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run
 ```
-### Start the project
+### Start the API/Backend
-Use:
+Start the API server:
 ```bash
 poetry run python3 -m reflector.app
 ```
-And start the background worker
+Start the background worker:
 ```bash
 celery -A reflector.worker.app worker --loglevel=info
@@ -153,6 +153,12 @@ Redis:
 TODO
 ```
+
+For crontab (only the healthcheck for now), start celery beat (you don't need it in your local dev environment):
+
+```bash
+celery -A reflector.worker.app beat
+```
 #### Using docker
 Use:
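Note: the beat schedule itself is registered on the worker app; this commit only adds the `HEALTHCHECK_URL` setting. A minimal sketch of what such a schedule entry could look like; the task path and interval here are assumptions for illustration, not taken from this commit:

```python
from celery import Celery

app = Celery("reflector")
# Hypothetical schedule entry; task name and interval are illustrative only.
app.conf.beat_schedule = {
    "healthcheck": {
        "task": "reflector.worker.healthcheck",  # assumed task path
        "schedule": 300.0,  # run every 5 minutes
    },
}
```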

View File

@@ -0,0 +1,188 @@
"""
Reflector GPU backend - diarizer
===================================
"""
import os
import modal.gpu
from modal import Image, Secret, Stub, asgi_app, method
from pydantic import BaseModel
PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.0"
MODEL_DIR = "/root/diarization_models"
stub = Stub(name="reflector-diarizer")
def migrate_cache_llm():
"""
XXX The cache for model files in Transformers v4.22.0 has been updated.
Migrating your old cache. This is a one-time only operation. You can
interrupt this and resume the migration later on by calling
`transformers.utils.move_cache()`.
"""
from transformers.utils.hub import move_cache
print("Moving LLM cache")
move_cache(cache_dir=MODEL_DIR, new_cache_dir=MODEL_DIR)
print("LLM cache moved")
def download_pyannote_audio():
from pyannote.audio import Pipeline
Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.0",
cache_dir=MODEL_DIR,
use_auth_token="***REMOVED***"
)
diarizer_image = (
Image.debian_slim(python_version="3.10.8")
.pip_install(
"pyannote.audio",
"requests",
"onnx",
"torchaudio",
"onnxruntime-gpu",
"torch==2.0.0",
"transformers==4.34.0",
"sentencepiece",
"protobuf",
"numpy",
"huggingface_hub",
"hf-transfer"
)
.run_function(migrate_cache_llm)
.run_function(download_pyannote_audio)
.env(
{
"LD_LIBRARY_PATH": (
"/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:"
"/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/"
)
}
)
)
@stub.cls(
gpu=modal.gpu.A100(memory=40),
timeout=60 * 30,
container_idle_timeout=60,
allow_concurrent_inputs=1,
image=diarizer_image,
)
class Diarizer:
def __enter__(self):
import torch
from pyannote.audio import Pipeline
self.use_gpu = torch.cuda.is_available()
self.device = "cuda" if self.use_gpu else "cpu"
self.diarization_pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.0",
cache_dir=MODEL_DIR
)
self.diarization_pipeline.to(torch.device(self.device))
@method()
def diarize(
self,
audio_data: str,
audio_suffix: str,
timestamp: float
):
import tempfile
import torchaudio
with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp:
fp.write(audio_data)
print("Diarizing audio")
waveform, sample_rate = torchaudio.load(fp.name)
diarization = self.diarization_pipeline({"waveform": waveform, "sample_rate": sample_rate})
words = []
for diarization_segment, _, speaker in diarization.itertracks(yield_label=True):
words.append(
{
"start": round(timestamp + diarization_segment.start, 3),
"end": round(timestamp + diarization_segment.end, 3),
"speaker": int(speaker[-2:])
}
)
print("Diarization complete")
return {
"diarization": words
}
# -------------------------------------------------------------------
# Web API
# -------------------------------------------------------------------
@stub.function(
timeout=60 * 10,
container_idle_timeout=60 * 3,
allow_concurrent_inputs=40,
secrets=[
Secret.from_name("reflector-gpu"),
],
image=diarizer_image
)
@asgi_app()
def web():
import requests
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
diarizerstub = Diarizer()
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
def validate_audio_file(audio_file_url: str):
# Check if the audio file exists
response = requests.head(audio_file_url, allow_redirects=True)
if response.status_code == 404:
raise HTTPException(
status_code=response.status_code,
detail="The audio file does not exist."
)
class DiarizationResponse(BaseModel):
result: dict
@app.post("/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)])
def diarize(
audio_file_url: str,
timestamp: float = 0.0
) -> HTTPException | DiarizationResponse:
# Currently the uploaded files are in mp3 format
audio_suffix = "mp3"
print("Downloading audio file")
response = requests.get(audio_file_url, allow_redirects=True)
print("Audio file downloaded successfully")
func = diarizerstub.diarize.spawn(
audio_data=response.content,
audio_suffix=audio_suffix,
timestamp=timestamp
)
result = func.get()
return result
return app
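For reference, a client could call this endpoint roughly as follows; the deployment URL is a placeholder (Modal assigns the real one), and the bearer token must match the REFLECTOR_GPU_APIKEY secret:

```python
import requests

# Placeholder URL; Modal assigns the real one at deploy time.
BASE_URL = "https://example--reflector-diarizer-web.modal.run"

response = requests.post(
    f"{BASE_URL}/diarize",
    params={"audio_file_url": "https://example.com/audio.mp3", "timestamp": 0.0},
    headers={"Authorization": "Bearer <REFLECTOR_GPU_APIKEY>"},
    timeout=600,
)
response.raise_for_status()
for segment in response.json()["diarization"]:
    print(segment["start"], segment["end"], segment["speaker"])
```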

View File

@@ -0,0 +1,33 @@
"""add share_mode
Revision ID: 0fea6d96b096
Revises: f819277e5169
Create Date: 2023-11-07 11:12:21.614198
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "0fea6d96b096"
down_revision: Union[str, None] = "f819277e5169"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"transcript",
sa.Column("share_mode", sa.String(), server_default="private", nullable=False),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("transcript", "share_mode")
# ### end Alembic commands ###

View File

@@ -0,0 +1,30 @@
"""participants
Revision ID: 125031f7cb78
Revises: 0fea6d96b096
Create Date: 2023-11-30 15:56:03.341466
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '125031f7cb78'
down_revision: Union[str, None] = '0fea6d96b096'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('transcript', sa.Column('participants', sa.JSON(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('transcript', 'participants')
# ### end Alembic commands ###

View File

@@ -0,0 +1,35 @@
"""audio_location
Revision ID: f819277e5169
Revises: 4814901632bc
Create Date: 2023-11-16 10:29:09.351664
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = "f819277e5169"
down_revision: Union[str, None] = "4814901632bc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"transcript",
sa.Column(
"audio_location", sa.String(), server_default="local", nullable=False
),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("transcript", "audio_location")
# ### end Alembic commands ###

View File

@@ -13,6 +13,12 @@ from reflector.metrics import metrics_init
 from reflector.settings import settings
 from reflector.views.rtc_offer import router as rtc_offer_router
 from reflector.views.transcripts import router as transcripts_router
+from reflector.views.transcripts_audio import router as transcripts_audio_router
+from reflector.views.transcripts_participants import (
+    router as transcripts_participants_router,
+)
+from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
+from reflector.views.transcripts_websocket import router as transcripts_websocket_router
 from reflector.views.user import router as user_router

 try:
@@ -60,6 +66,10 @@ metrics_init(app, instrumentator)
 # register views
 app.include_router(rtc_offer_router)
 app.include_router(transcripts_router, prefix="/v1")
+app.include_router(transcripts_audio_router, prefix="/v1")
+app.include_router(transcripts_participants_router, prefix="/v1")
+app.include_router(transcripts_websocket_router, prefix="/v1")
+app.include_router(transcripts_webrtc_router, prefix="/v1")
 app.include_router(user_router, prefix="/v1")
 add_pagination(app)

View File

@@ -2,15 +2,16 @@ import json
 from contextlib import asynccontextmanager
 from datetime import datetime
 from pathlib import Path
-from typing import Any
+from typing import Any, Literal
 from uuid import uuid4

 import sqlalchemy
-from pydantic import BaseModel, Field
+from fastapi import HTTPException
+from pydantic import BaseModel, ConfigDict, Field
 from reflector.db import database, metadata
 from reflector.processors.types import Word as ProcessorWord
 from reflector.settings import settings
-from reflector.utils.audio_waveform import get_audio_waveform
+from reflector.storage import Storage

 transcripts = sqlalchemy.Table(
     "transcript",
@@ -26,22 +27,42 @@ transcripts = sqlalchemy.Table(
     sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True),
     sqlalchemy.Column("topics", sqlalchemy.JSON),
     sqlalchemy.Column("events", sqlalchemy.JSON),
+    sqlalchemy.Column("participants", sqlalchemy.JSON),
     sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True),
     sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
+    sqlalchemy.Column(
+        "audio_location",
+        sqlalchemy.String,
+        nullable=False,
+        server_default="local",
+    ),
     # with user attached, optional
     sqlalchemy.Column("user_id", sqlalchemy.String),
+    sqlalchemy.Column(
+        "share_mode",
+        sqlalchemy.String,
+        nullable=False,
+        server_default="private",
+    ),
 )


-def generate_uuid4():
+def generate_uuid4() -> str:
     return str(uuid4())


-def generate_transcript_name():
+def generate_transcript_name() -> str:
     now = datetime.utcnow()
     return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"


+def get_storage() -> Storage:
+    return Storage.get_instance(
+        name=settings.TRANSCRIPT_STORAGE_BACKEND,
+        settings_prefix="TRANSCRIPT_STORAGE_",
+    )
+
+
 class AudioWaveform(BaseModel):
     data: list[float]
@@ -79,11 +100,26 @@ class TranscriptFinalTitle(BaseModel):
     title: str


+class TranscriptDuration(BaseModel):
+    duration: float
+
+
+class TranscriptWaveform(BaseModel):
+    waveform: list[float]
+
+
 class TranscriptEvent(BaseModel):
     event: str
     data: dict


+class TranscriptParticipant(BaseModel):
+    model_config = ConfigDict(from_attributes=True)
+
+    id: str = Field(default_factory=generate_uuid4)
+    speaker: int | None
+    name: str
+
+
 class Transcript(BaseModel):
     id: str = Field(default_factory=generate_uuid4)
     user_id: str | None = None
@@ -97,8 +133,11 @@ class Transcript(BaseModel):
     long_summary: str | None = None
     topics: list[TranscriptTopic] = []
     events: list[TranscriptEvent] = []
+    participants: list[TranscriptParticipant] | None = []
     source_language: str = "en"
     target_language: str = "en"
+    share_mode: Literal["private", "semi-private", "public"] = "private"
+    audio_location: str = "local"

     def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
         ev = TranscriptEvent(event=event, data=data.model_dump())
@@ -112,27 +151,33 @@ class Transcript(BaseModel):
         else:
             self.topics.append(topic)

+    def upsert_participant(self, participant: TranscriptParticipant):
+        index = next(
+            (i for i, p in enumerate(self.participants) if p.id == participant.id),
+            None,
+        )
+        if index is not None:
+            self.participants[index] = participant
+        else:
+            self.participants.append(participant)
+        return participant
+
+    def delete_participant(self, participant_id: str):
+        index = next(
+            (i for i, p in enumerate(self.participants) if p.id == participant_id),
+            None,
+        )
+        if index is not None:
+            del self.participants[index]
+
     def events_dump(self, mode="json"):
         return [event.model_dump(mode=mode) for event in self.events]

     def topics_dump(self, mode="json"):
         return [topic.model_dump(mode=mode) for topic in self.topics]

-    def convert_audio_to_waveform(self, segments_count=256):
-        fn = self.audio_waveform_filename
-        if fn.exists():
-            return
-        waveform = get_audio_waveform(
-            path=self.audio_mp3_filename, segments_count=segments_count
-        )
-        try:
-            with open(fn, "w") as fd:
-                json.dump(waveform, fd)
-        except Exception:
-            # remove file if anything happen during the write
-            fn.unlink(missing_ok=True)
-            raise
-        return waveform
+    def participants_dump(self, mode="json"):
+        return [participant.model_dump(mode=mode) for participant in self.participants]

     def unlink(self):
         self.data_path.unlink(missing_ok=True)
@@ -141,6 +186,10 @@ class Transcript(BaseModel):
     def data_path(self):
         return Path(settings.DATA_DIR) / self.id

+    @property
+    def audio_wav_filename(self):
+        return self.data_path / "audio.wav"
+
     @property
     def audio_mp3_filename(self):
         return self.data_path / "audio.mp3"
@@ -149,6 +198,10 @@ class Transcript(BaseModel):
     def audio_waveform_filename(self):
         return self.data_path / "audio.json"

+    @property
+    def storage_audio_path(self):
+        return f"{self.id}/audio.mp3"
+
     @property
     def audio_waveform(self):
         try:
@@ -161,6 +214,40 @@ class Transcript(BaseModel):
         return AudioWaveform(data=data)

+    async def get_audio_url(self) -> str:
+        if self.audio_location == "local":
+            return self._generate_local_audio_link()
+        elif self.audio_location == "storage":
+            return await self._generate_storage_audio_link()
+        raise Exception(f"Unknown audio location {self.audio_location}")
+
+    async def _generate_storage_audio_link(self) -> str:
+        return await get_storage().get_file_url(self.storage_audio_path)
+
+    def _generate_local_audio_link(self) -> str:
+        # we need to create an url to be used for diarization
+        # we can't use the audio_mp3_filename because it's not accessible
+        # from the diarization processor
+        from datetime import timedelta
+
+        from reflector.app import app
+        from reflector.views.transcripts import create_access_token
+
+        path = app.url_path_for(
+            "transcript_get_audio_mp3",
+            transcript_id=self.id,
+        )
+        url = f"{settings.BASE_URL}{path}"
+        if self.user_id:
+            # we pass token only if the user_id is set
+            # otherwise, the audio is public
+            token = create_access_token(
+                {"sub": self.user_id},
+                expires_delta=timedelta(minutes=15),
+            )
+            url += f"?token={token}"
+        return url
+

 class TranscriptController:
     async def get_all(
@@ -169,6 +256,7 @@ class TranscriptController:
         order_by: str | None = None,
         filter_empty: bool | None = False,
         filter_recording: bool | None = False,
+        return_query: bool = False,
     ) -> list[Transcript]:
         """
         Get all transcripts
@@ -195,6 +283,9 @@ class TranscriptController:
         if filter_recording:
             query = query.filter(transcripts.c.status != "recording")

+        if return_query:
+            return query
+
         results = await database.fetch_all(query)
         return results
@@ -210,6 +301,47 @@ class TranscriptController:
             return None
         return Transcript(**result)

+    async def get_by_id_for_http(
+        self,
+        transcript_id: str,
+        user_id: str | None,
+    ) -> Transcript:
+        """
+        Get a transcript by ID for an HTTP request.
+
+        If not found, it will raise a 404 error.
+        If the user is not allowed to access the transcript, it will raise a 403 error.
+
+        This method checks the share mode of the transcript and the user_id
+        to determine if the user can access the transcript.
+        """
+        query = transcripts.select().where(transcripts.c.id == transcript_id)
+        result = await database.fetch_one(query)
+        if not result:
+            raise HTTPException(status_code=404, detail="Transcript not found")
+
+        # if the transcript is anonymous, share mode is not checked
+        transcript = Transcript(**result)
+        if transcript.user_id is None:
+            return transcript
+
+        if transcript.share_mode == "private":
+            # in private mode, only the owner can access the transcript
+            if transcript.user_id == user_id:
+                return transcript
+        elif transcript.share_mode == "semi-private":
+            # in semi-private mode, only the owner and the users with the link
+            # can access the transcript
+            if user_id is not None:
+                return transcript
+        elif transcript.share_mode == "public":
+            # in public mode, everyone can access the transcript
+            return transcript
+
+        raise HTTPException(status_code=403, detail="Transcript access denied")
+
     async def add(
         self,
         name: str,
@@ -292,5 +424,45 @@ class TranscriptController:
         transcript.upsert_topic(topic)
         await self.update(transcript, {"topics": transcript.topics_dump()})

+    async def move_mp3_to_storage(self, transcript: Transcript):
+        """
+        Move the mp3 file to external storage
+        """
+        # store the audio on external storage
+        await get_storage().put_file(
+            transcript.storage_audio_path,
+            transcript.audio_mp3_filename.read_bytes(),
+        )
+
+        # indicate on the transcript that the audio is now on storage
+        await self.update(transcript, {"audio_location": "storage"})
+
+        # unlink the local file
+        transcript.audio_mp3_filename.unlink(missing_ok=True)
+
+    async def upsert_participant(
+        self,
+        transcript: Transcript,
+        participant: TranscriptParticipant,
+    ) -> TranscriptParticipant:
+        """
+        Add/update a participant of a transcript
+        """
+        result = transcript.upsert_participant(participant)
+        await self.update(transcript, {"participants": transcript.participants_dump()})
+        return result
+
+    async def delete_participant(
+        self,
+        transcript: Transcript,
+        participant_id: str,
+    ):
+        """
+        Delete a participant from a transcript
+        """
+        transcript.delete_participant(participant_id)
+        await self.update(transcript, {"participants": transcript.participants_dump()})
+

 transcripts_controller = TranscriptController()
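To make the access rules in `get_by_id_for_http` concrete, here is an illustrative summary (not part of the commit) of who can fetch a transcript that has an owner; transcripts without a user_id skip these checks entirely:

```python
# Illustrative only: expected outcomes for an owned transcript.
cases = [
    # (share_mode, requesting user_id, expected result)
    ("private", "owner-id", "200: owner can read"),
    ("private", "other-id", "403: access denied"),
    ("semi-private", "other-id", "200: any authenticated user"),
    ("semi-private", None, "403: anonymous denied"),
    ("public", None, "200: everyone, even anonymous"),
]
for mode, user_id, expected in cases:
    print(f"{mode:12} user={user_id!s:9} -> {expected}")
```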

View File

@@ -12,20 +12,20 @@ It is directly linked to our data model.
 """
 import asyncio
+import functools
 from contextlib import asynccontextmanager
-from datetime import timedelta
-from pathlib import Path

-from celery import shared_task
+from celery import chord, group, shared_task
 from pydantic import BaseModel

-from reflector.app import app
 from reflector.db.transcripts import (
     Transcript,
+    TranscriptDuration,
     TranscriptFinalLongSummary,
     TranscriptFinalShortSummary,
     TranscriptFinalTitle,
     TranscriptText,
     TranscriptTopic,
+    TranscriptWaveform,
     transcripts_controller,
 )
 from reflector.logger import logger
@@ -45,6 +45,7 @@ from reflector.processors import (
     TranscriptTopicDetectorProcessor,
     TranscriptTranslatorProcessor,
 )
+from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
 from reflector.processors.types import AudioDiarizationInput
 from reflector.processors.types import (
     TitleSummaryWithId as TitleSummaryWithIdProcessorType,
@@ -52,6 +53,22 @@ from reflector.processors.types import (
 from reflector.processors.types import Transcript as TranscriptProcessorType
 from reflector.settings import settings
 from reflector.ws_manager import WebsocketManager, get_ws_manager
+from structlog import BoundLogger as Logger
+
+
+def asynctask(f):
+    @functools.wraps(f)
+    def wrapper(*args, **kwargs):
+        coro = f(*args, **kwargs)
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            loop = None
+        if loop and loop.is_running():
+            return loop.run_until_complete(coro)
+        return asyncio.run(coro)
+
+    return wrapper


 def broadcast_to_sockets(func):
@@ -72,6 +89,26 @@ def broadcast_to_sockets(func):
     return wrapper


+def get_transcript(func):
+    """
+    Decorator to fetch the transcript from the database from the first argument
+    """
+
+    async def wrapper(**kwargs):
+        transcript_id = kwargs.pop("transcript_id")
+        transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
+        if not transcript:
+            raise Exception(f"Transcript {transcript_id} not found")
+        tlogger = logger.bind(transcript_id=transcript.id)
+        try:
+            return await func(transcript=transcript, logger=tlogger, **kwargs)
+        except Exception as exc:
+            tlogger.error("Pipeline error", exc_info=exc)
+            raise
+
+    return wrapper
+
+
 class StrValue(BaseModel):
     value: str
@@ -96,6 +133,19 @@ class PipelineMainBase(PipelineRunner):
             raise Exception("Transcript not found")
         return result

+    def get_transcript_topics(self, transcript: Transcript) -> list[TranscriptTopic]:
+        return [
+            TitleSummaryWithIdProcessorType(
+                id=topic.id,
+                title=topic.title,
+                summary=topic.summary,
+                timestamp=topic.timestamp,
+                duration=topic.duration,
+                transcript=TranscriptProcessorType(words=topic.words),
+            )
+            for topic in transcript.topics
+        ]
+
     @asynccontextmanager
     async def transaction(self):
         async with self._lock:
@@ -113,7 +163,7 @@ class PipelineMainBase(PipelineRunner):
                 "flush": "processing",
                 "error": "error",
             }
-        elif isinstance(self, PipelineMainDiarization):
+        elif isinstance(self, PipelineMainFinalSummaries):
             status_mapping = {
                 "push": "processing",
                 "flush": "processing",
@@ -121,7 +171,8 @@ class PipelineMainBase(PipelineRunner):
                 "ended": "ended",
             }
         else:
-            raise Exception(f"Runner {self.__class__} is missing status mapping")
+            # intermediate pipelines don't update the status
+            return

         # mutate to model status
         status = status_mapping.get(status)
@@ -230,21 +281,39 @@ class PipelineMainBase(PipelineRunner):
             data=final_short_summary,
         )

-    async def on_duration(self, duration: float):
+    @broadcast_to_sockets
+    async def on_duration(self, data):
         async with self.transaction():
+            duration = TranscriptDuration(duration=data)
             transcript = await self.get_transcript()
             await transcripts_controller.update(
                 transcript,
                 {
-                    "duration": duration,
+                    "duration": duration.duration,
                 },
             )
+            return await transcripts_controller.append_event(
+                transcript=transcript, event="DURATION", data=duration
+            )
+
+    @broadcast_to_sockets
+    async def on_waveform(self, data):
+        async with self.transaction():
+            waveform = TranscriptWaveform(waveform=data)
+            transcript = await self.get_transcript()
+            return await transcripts_controller.append_event(
+                transcript=transcript, event="WAVEFORM", data=waveform
+            )


 class PipelineMainLive(PipelineMainBase):
-    audio_filename: Path | None = None
-    source_language: str = "en"
-    target_language: str = "en"
+    """
+    Main pipeline for live streaming, attached to the RTC connection.
+    Any long post-processing should be done in the post pipeline.
+    """

     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
@@ -254,7 +323,7 @@ class PipelineMainLive(PipelineMainBase):
         processors = [
             AudioFileWriterProcessor(
-                path=transcript.audio_mp3_filename,
+                path=transcript.audio_wav_filename,
                 on_duration=self.on_duration,
             ),
             AudioChunkerProcessor(),
@@ -263,17 +332,13 @@ class PipelineMainLive(PipelineMainBase):
             TranscriptLinerProcessor(),
             TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
             TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
-            TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
         ]
         pipeline = Pipeline(*processors)
         pipeline.options = self
         pipeline.set_pref("audio:source_language", transcript.source_language)
         pipeline.set_pref("audio:target_language", transcript.target_language)
         pipeline.logger.bind(transcript_id=transcript.id)
-        pipeline.logger.info(
-            "Pipeline main live created",
-            transcript_id=self.transcript_id,
-        )
+        pipeline.logger.info("Pipeline main live created")
         return pipeline
@@ -281,26 +346,106 @@ class PipelineMainLive(PipelineMainBase):
         # when the pipeline ends, connect to the post pipeline
         logger.info("Pipeline main live ended", transcript_id=self.transcript_id)
         logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id)
-        task_pipeline_main_post.delay(transcript_id=self.transcript_id)
+        pipeline_post(transcript_id=self.transcript_id)


 class PipelineMainDiarization(PipelineMainBase):
     """
-    Diarization is a long time process, so we do it in a separate pipeline
+    Diarize the audio and update the topics.
+    When done, adjust the short and final summary.
     """

     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
         # add a customised logger to the context
         self.prepare()
-        processors = []
-        if settings.DIARIZATION_ENABLED:
-            processors += [
-                AudioDiarizationAutoProcessor(callback=self.on_topic),
-            ]
-        processors += [
-            BroadcastProcessor(
-                processors=[
-                    TranscriptFinalLongSummaryProcessor.as_threaded(
+        pipeline = Pipeline(
+            AudioDiarizationAutoProcessor(callback=self.on_topic),
+        )
+        pipeline.options = self
+
+        # now let's start the pipeline by pushing information to the
+        # diarization processor
+        # XXX translation is lost when converting our data model to the processor model
+        transcript = await self.get_transcript()
+
+        # diarization works only if the file is uploaded to an external storage
+        if transcript.audio_location == "local":
+            pipeline.logger.info("Audio is local, skipping diarization")
+            return
+
+        topics = self.get_transcript_topics(transcript)
+        audio_url = await transcript.get_audio_url()
+        audio_diarization_input = AudioDiarizationInput(
+            audio_url=audio_url,
+            topics=topics,
+        )
+
+        # tempting as it is to use pipeline.push, prefer the runner
+        # so that start does just one job.
+        pipeline.logger.bind(transcript_id=transcript.id)
+        pipeline.logger.info("Diarization pipeline created")
+        self.push(audio_diarization_input)
+        self.flush()
+        return pipeline
+
+
+class PipelineMainFromTopics(PipelineMainBase):
+    """
+    Pseudo class for generating a pipeline from topics
+    """
+
+    def get_processors(self) -> list:
+        raise NotImplementedError
+
+    async def create(self) -> Pipeline:
+        self.prepare()
+
+        # get transcript
+        self._transcript = transcript = await self.get_transcript()
+
+        # create pipeline
+        processors = self.get_processors()
+        pipeline = Pipeline(*processors)
+        pipeline.options = self
+        pipeline.logger.bind(transcript_id=transcript.id)
+        pipeline.logger.info(f"{self.__class__.__name__} pipeline created")
+
+        # push topics
+        topics = self.get_transcript_topics(transcript)
+        for topic in topics:
+            self.push(topic)
+        self.flush()
+        return pipeline
+
+
+class PipelineMainTitleAndShortSummary(PipelineMainFromTopics):
+    """
+    Generate the title and short summary from the topics
+    """
+
+    def get_processors(self) -> list:
+        return [
+            BroadcastProcessor(
+                processors=[
+                    TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title),
+                    TranscriptFinalShortSummaryProcessor.as_threaded(
+                        callback=self.on_short_summary
+                    ),
+                ]
+            )
+        ]
+
+
+class PipelineMainFinalSummaries(PipelineMainFromTopics):
+    """
+    Generate summaries from the topics
+    """
+
+    def get_processors(self) -> list:
+        return [
+            BroadcastProcessor(
+                processors=[
+                    TranscriptFinalLongSummaryProcessor.as_threaded(
@@ -312,65 +457,164 @@ class PipelineMainDiarization(PipelineMainBase):
                 ]
             ),
         ]
-        pipeline = Pipeline(*processors)
-        pipeline.options = self
-
-        # now let's start the pipeline by pushing information to the
-        # first processor
-        # XXX translation is lost when converting our data model to the processor model
-        transcript = await self.get_transcript()
-        topics = [
-            TitleSummaryWithIdProcessorType(
-                id=topic.id,
-                title=topic.title,
-                summary=topic.summary,
-                timestamp=topic.timestamp,
-                duration=topic.duration,
-                transcript=TranscriptProcessorType(words=topic.words),
-            )
-            for topic in transcript.topics
-        ]
-
-        # we need to create an url to be used for diarization
-        # we can't use the audio_mp3_filename because it's not accessible
-        # from the diarization processor
-        from reflector.views.transcripts import create_access_token
-
-        path = app.url_path_for(
-            "transcript_get_audio_mp3",
-            transcript_id=transcript.id,
-        )
-        url = f"{settings.BASE_URL}{path}"
-        if transcript.user_id:
-            # we pass token only if the user_id is set
-            # otherwise, the audio is public
-            token = create_access_token(
-                {"sub": transcript.user_id},
-                expires_delta=timedelta(minutes=15),
-            )
-            url += f"?token={token}"
-
-        audio_diarization_input = AudioDiarizationInput(
-            audio_url=url,
-            topics=topics,
-        )
-
-        # as tempting to use pipeline.push, prefer to use the runner
-        # to let the start just do one job.
-        pipeline.logger.bind(transcript_id=transcript.id)
-        pipeline.logger.info(
-            "Pipeline main post created", transcript_id=self.transcript_id
-        )
-        self.push(audio_diarization_input)
-        self.flush()
-        return pipeline
+
+
+class PipelineMainWaveform(PipelineMainFromTopics):
+    """
+    Generate the waveform
+    """
+
+    def get_processors(self) -> list:
+        return [
+            AudioWaveformProcessor.as_threaded(
+                audio_path=self._transcript.audio_wav_filename,
+                waveform_path=self._transcript.audio_waveform_filename,
+                on_waveform=self.on_waveform,
+            ),
+        ]
+
+
+@get_transcript
+async def pipeline_waveform(transcript: Transcript, logger: Logger):
+    logger.info("Starting waveform")
+    runner = PipelineMainWaveform(transcript_id=transcript.id)
+    await runner.run()
+    logger.info("Waveform done")
+
+
+@get_transcript
+async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
+    logger.info("Starting convert to mp3")
+
+    # If the audio wav is not available, just skip
+    wav_filename = transcript.audio_wav_filename
+    if not wav_filename.exists():
+        logger.warning("Wav file not found, may be already converted")
+        return
+
+    # Convert to mp3
+    mp3_filename = transcript.audio_mp3_filename
+    import av
+
+    with av.open(wav_filename.as_posix()) as in_container:
+        in_stream = in_container.streams.audio[0]
+        with av.open(mp3_filename.as_posix(), "w") as out_container:
+            out_stream = out_container.add_stream("mp3")
+            for frame in in_container.decode(in_stream):
+                for packet in out_stream.encode(frame):
+                    out_container.mux(packet)

+    # Delete the wav file
+    transcript.audio_wav_filename.unlink(missing_ok=True)
+    logger.info("Convert to mp3 done")
+
+
+@get_transcript
+async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
+    if not settings.TRANSCRIPT_STORAGE_BACKEND:
+        logger.info("No storage backend configured, skipping mp3 upload")
+        return
+
+    logger.info("Starting upload mp3")
+
+    # If the audio mp3 is not available, just skip
+    mp3_filename = transcript.audio_mp3_filename
+    if not mp3_filename.exists():
+        logger.warning("Mp3 file not found, may be already uploaded")
+        return
+
+    # Upload to external storage and delete the local file
+    await transcripts_controller.move_mp3_to_storage(transcript)
+    logger.info("Upload mp3 done")
+
+
+@get_transcript
+async def pipeline_diarization(transcript: Transcript, logger: Logger):
+    logger.info("Starting diarization")
+    runner = PipelineMainDiarization(transcript_id=transcript.id)
+    await runner.run()
+    logger.info("Diarization done")
+
+
+@get_transcript
+async def pipeline_title_and_short_summary(transcript: Transcript, logger: Logger):
+    logger.info("Starting title and short summary")
+    runner = PipelineMainTitleAndShortSummary(transcript_id=transcript.id)
+    await runner.run()
+    logger.info("Title and short summary done")
+
+
+@get_transcript
+async def pipeline_summaries(transcript: Transcript, logger: Logger):
+    logger.info("Starting summaries")
+    runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
+    await runner.run()
+    logger.info("Summaries done")
+
+
+# ===================================================================
+# Celery tasks that can be called from the API
+# ===================================================================

 @shared_task
-def task_pipeline_main_post(transcript_id: str):
-    logger.info(
-        "Starting main post pipeline",
-        transcript_id=transcript_id,
-    )
-    runner = PipelineMainDiarization(transcript_id=transcript_id)
-    runner.start_sync()
+@asynctask
+async def task_pipeline_waveform(*, transcript_id: str):
+    await pipeline_waveform(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_convert_to_mp3(*, transcript_id: str):
+    await pipeline_convert_to_mp3(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_upload_mp3(*, transcript_id: str):
+    await pipeline_upload_mp3(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_diarization(*, transcript_id: str):
+    await pipeline_diarization(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_title_and_short_summary(*, transcript_id: str):
+    await pipeline_title_and_short_summary(transcript_id=transcript_id)
+
+
+@shared_task
+@asynctask
+async def task_pipeline_final_summaries(*, transcript_id: str):
+    await pipeline_summaries(transcript_id=transcript_id)
+
+
+def pipeline_post(*, transcript_id: str):
+    """
+    Run the post pipeline
+    """
+    chain_mp3_and_diarize = (
+        task_pipeline_waveform.si(transcript_id=transcript_id)
+        | task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
+        | task_pipeline_upload_mp3.si(transcript_id=transcript_id)
+        | task_pipeline_diarization.si(transcript_id=transcript_id)
+    )
+    chain_title_preview = task_pipeline_title_and_short_summary.si(
+        transcript_id=transcript_id
+    )
+    chain_final_summaries = task_pipeline_final_summaries.si(
+        transcript_id=transcript_id
+    )
+    chain = chord(
+        group(chain_mp3_and_diarize, chain_title_preview),
+        chain_final_summaries,
+    )
+    chain.delay()
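The orchestration above leans on two Celery primitives: group fans tasks out in parallel, and chord fires a callback once every member of the group has finished. A minimal standalone sketch of the same shape (task names are illustrative, not from this commit):

```python
from celery import Celery, chord, group

app = Celery("sketch")

@app.task
def prepare_audio():  # stands in for waveform -> mp3 -> upload -> diarize
    ...

@app.task
def title_preview():  # stands in for title + short summary
    ...

@app.task
def final_summaries():  # runs only after both branches above finish
    ...

# .si() makes an immutable signature, as in pipeline_post above.
chord(group(prepare_audio.si(), title_preview.si()), final_summaries.si()).delay()
```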

View File

@@ -106,6 +106,14 @@ class PipelineRunner(BaseModel):
         if not self.pipeline:
             self.pipeline = await self.create()

+        if not self.pipeline:
+            # no pipeline created in create, just finish it then.
+            await self._set_status("ended")
+            self._ev_done.set()
+            if self.on_ended:
+                await self.on_ended()
+            return
+
         # start the loop
         await self._set_status("started")
         while not self._ev_done.is_set():
@@ -119,8 +127,7 @@
             self._logger.exception("Runner error")
             await self._set_status("error")
             self._ev_done.set()
-            if self.on_ended:
-                await self.on_ended()
+            raise

     async def cmd_push(self, data):
         if self._is_first_push:
View File

@@ -1,5 +1,5 @@
from reflector.processors.base import Processor from reflector.processors.base import Processor
from reflector.processors.types import AudioDiarizationInput, TitleSummary from reflector.processors.types import AudioDiarizationInput, TitleSummary, Word
class AudioDiarizationProcessor(Processor): class AudioDiarizationProcessor(Processor):
@@ -19,12 +19,12 @@ class AudioDiarizationProcessor(Processor):
# topics is a list[BaseModel] with an attribute words # topics is a list[BaseModel] with an attribute words
# words is a list[BaseModel] with text, start and speaker attribute # words is a list[BaseModel] with text, start and speaker attribute
# mutate in place # create a view of words based on topics
for topic in data.topics: # the current algorithm is using words index, we cannot use a generator
for word in topic.transcript.words: words = list(self.iter_words_from_topics(data.topics))
for d in diarization:
if d["start"] <= word.start <= d["end"]: # assign speaker to words (mutate the words list)
word.speaker = d["speaker"] self.assign_speaker(words, diarization)
# emit them # emit them
for topic in data.topics: for topic in data.topics:
@@ -32,3 +32,150 @@ class AudioDiarizationProcessor(Processor):
async def _diarize(self, data: AudioDiarizationInput): async def _diarize(self, data: AudioDiarizationInput):
raise NotImplementedError raise NotImplementedError
def assign_speaker(self, words: list[Word], diarization: list[dict]):
self._diarization_remove_overlap(diarization)
self._diarization_remove_segment_without_words(words, diarization)
self._diarization_merge_same_speaker(words, diarization)
self._diarization_assign_speaker(words, diarization)
def iter_words_from_topics(self, topics: TitleSummary):
for topic in topics:
for word in topic.transcript.words:
yield word
def is_word_continuation(self, word_prev, word):
"""
Return True if the word is a continuation of the previous word
by checking if the previous word is ending with a punctuation
or if the current word is starting with a capital letter
"""
# is word_prev ending with a punctuation ?
if word_prev.text and word_prev.text[-1] in ".?!":
return False
elif word.text and word.text[0].isupper():
return False
return True
def _diarization_remove_overlap(self, diarization: list[dict]):
"""
Remove overlap in diarization results
When using a diarization algorithm, it's possible to have overlapping segments
This function remove the overlap by keeping the longest segment
Warning: this function mutate the diarization list
"""
# remove overlap by keeping the longest segment
diarization_idx = 0
while diarization_idx < len(diarization) - 1:
d = diarization[diarization_idx]
dnext = diarization[diarization_idx + 1]
if d["end"] > dnext["start"]:
# remove the shortest segment
if d["end"] - d["start"] > dnext["end"] - dnext["start"]:
# remove next segment
diarization.pop(diarization_idx + 1)
else:
# remove current segment
diarization.pop(diarization_idx)
else:
diarization_idx += 1
def _diarization_remove_segment_without_words(
self, words: list[Word], diarization: list[dict]
):
"""
Remove diarization segments without words
Warning: this function mutate the diarization list
"""
# count the number of words for each diarization segment
diarization_count = []
for d in diarization:
start = d["start"]
end = d["end"]
count = 0
for word in words:
if start <= word.start < end:
count += 1
elif start < word.end <= end:
count += 1
diarization_count.append(count)
# remove diarization segments with no words
diarization_idx = 0
while diarization_idx < len(diarization):
if diarization_count[diarization_idx] == 0:
diarization.pop(diarization_idx)
diarization_count.pop(diarization_idx)
else:
diarization_idx += 1
def _diarization_merge_same_speaker(
self, words: list[Word], diarization: list[dict]
):
"""
Merge diarization contigous segments with the same speaker
Warning: this function mutate the diarization list
"""
# merge segment with same speaker
diarization_idx = 0
while diarization_idx < len(diarization) - 1:
d = diarization[diarization_idx]
dnext = diarization[diarization_idx + 1]
if d["speaker"] == dnext["speaker"]:
diarization[diarization_idx]["end"] = dnext["end"]
diarization.pop(diarization_idx + 1)
else:
diarization_idx += 1
def _diarization_assign_speaker(self, words: list[Word], diarization: list[dict]):
"""
Assign speaker to words based on diarization
Warning: this function mutate the words list
"""
word_idx = 0
last_speaker = None
for d in diarization:
start = d["start"]
end = d["end"]
speaker = d["speaker"]
# diarization may start after the first set of words
# in this case, we assign the last speaker
for word in words[word_idx:]:
if word.start < start:
# speaker change, but what make sense for assigning the word ?
# If it's a new sentence, assign with the new speaker
# If it's a continuation, assign with the last speaker
is_continuation = False
if word_idx > 0 and word_idx < len(words) - 1:
is_continuation = self.is_word_continuation(
*words[word_idx - 1 : word_idx + 1]
)
if is_continuation:
word.speaker = last_speaker
else:
word.speaker = speaker
last_speaker = speaker
word_idx += 1
else:
break
# now continue to assign speaker until the word starts after the end
for word in words[word_idx:]:
if start <= word.start < end:
last_speaker = speaker
word.speaker = speaker
word_idx += 1
elif word.start > end:
break
# no more diarization available,
# assign last speaker to all words without speaker
for word in words[word_idx:]:
word.speaker = last_speaker
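A small worked example of the final assignment pass may help; this standalone sketch applies the basic interval rule (without the sentence-continuation refinement) to toy data:

```python
from dataclasses import dataclass

@dataclass
class Word:  # minimal stand-in for reflector.processors.types.Word
    text: str
    start: float
    end: float
    speaker: int | None = None

words = [Word("Hello.", 0.1, 0.5), Word("Hi", 2.1, 2.3), Word("there.", 2.4, 2.7)]
diarization = [
    {"start": 0.0, "end": 2.0, "speaker": 0},
    {"start": 2.0, "end": 4.0, "speaker": 1},
]

# Each word takes the speaker of the segment containing its start time.
for word in words:
    for d in diarization:
        if d["start"] <= word.start < d["end"]:
            word.speaker = d["speaker"]

print([(w.text, w.speaker) for w in words])
# [('Hello.', 0), ('Hi', 1), ('there.', 1)]
```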

View File

@@ -31,7 +31,7 @@ class AudioDiarizationModalProcessor(AudioDiarizationProcessor):
             follow_redirects=True,
         )
         response.raise_for_status()
-        return response.json()["text"]
+        return response.json()["diarization"]


 AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor)

View File

@@ -0,0 +1,36 @@
import json
from pathlib import Path

from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary
from reflector.utils.audio_waveform import get_audio_waveform


class AudioWaveformProcessor(Processor):
    """
    Write the waveform for the final audio
    """

    INPUT_TYPE = TitleSummary

    def __init__(self, audio_path: Path | str, waveform_path: Path | str, **kwargs):
        super().__init__(**kwargs)
        if isinstance(audio_path, str):
            audio_path = Path(audio_path)
        if isinstance(waveform_path, str):
            waveform_path = Path(waveform_path)
        if audio_path.suffix not in (".mp3", ".wav"):
            raise ValueError("Only mp3 and wav files are supported")
        self.audio_path = audio_path
        self.waveform_path = waveform_path

    async def _flush(self):
        self.waveform_path.parent.mkdir(parents=True, exist_ok=True)
        self.logger.info("Waveform Processing Started")
        waveform = get_audio_waveform(path=self.audio_path, segments_count=255)
        with open(self.waveform_path, "w") as fd:
            json.dump(waveform, fd)
        self.logger.info("Waveform Processing Finished")
        await self.emit(waveform, name="waveform")

    async def _push(self, data):
        # topics are only consumed; the waveform is produced at flush time
        return

View File

@@ -54,7 +54,7 @@ class Settings(BaseSettings):
     TRANSCRIPT_MODAL_API_KEY: str | None = None

     # Audio transcription storage
-    TRANSCRIPT_STORAGE_BACKEND: str = "aws"
+    TRANSCRIPT_STORAGE_BACKEND: str | None = None

     # Storage configuration for AWS
     TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket"
@@ -62,9 +62,6 @@ class Settings(BaseSettings):
     TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
     TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None

-    # Transcript MP3 storage
-    TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws"
-
     # LLM
     # available backend: openai, modal, oobabooga
     LLM_BACKEND: str = "oobabooga"
@@ -131,5 +128,8 @@ class Settings(BaseSettings):
     # Profiling
     PROFILING: bool = False

+    # Healthcheck
+    HEALTHCHECK_URL: str | None = None
+

 settings = Settings()

View File

@@ -1,6 +1,7 @@
+import importlib
+
 from pydantic import BaseModel
 from reflector.settings import settings

-import importlib
-

 class FileResult(BaseModel):
@@ -17,7 +18,7 @@ class Storage:
         cls._registry[name] = kclass

     @classmethod
-    def get_instance(cls, name, settings_prefix=""):
+    def get_instance(cls, name: str, settings_prefix: str = ""):
         if name not in cls._registry:
             module_name = f"reflector.storage.storage_{name}"
             importlib.import_module(module_name)
@@ -45,3 +46,9 @@ class Storage:
     async def _delete_file(self, filename: str):
         raise NotImplementedError
+
+    async def get_file_url(self, filename: str) -> str:
+        return await self._get_file_url(filename)
+
+    async def _get_file_url(self, filename: str) -> str:
+        raise NotImplementedError
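Putting the storage pieces together: a caller resolves a backend by name and asks it for a presigned URL, which is exactly what get_storage() in db/transcripts.py does. A sketch assuming the aws backend and its TRANSCRIPT_STORAGE_* settings are configured:

```python
import asyncio

from reflector.storage.base import Storage

async def main():
    # Imports reflector.storage.storage_aws on first use and
    # instantiates it from the TRANSCRIPT_STORAGE_* settings.
    storage = Storage.get_instance(name="aws", settings_prefix="TRANSCRIPT_STORAGE_")
    await storage.put_file("abc123/audio.mp3", b"...mp3 bytes...")
    url = await storage.get_file_url("abc123/audio.mp3")  # presigned, 1h expiry
    print(url)

asyncio.run(main())
```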

View File

@@ -1,6 +1,6 @@
 import aioboto3
-from reflector.storage.base import Storage, FileResult
 from reflector.logger import logger
+from reflector.storage.base import FileResult, Storage


 class AwsStorage(Storage):
@@ -44,16 +44,18 @@ class AwsStorage(Storage):
                 Body=data,
             )

+    async def _get_file_url(self, filename: str) -> str:
+        bucket = self.aws_bucket_name
+        folder = self.aws_folder
+        s3filename = f"{folder}/{filename}" if folder else filename
+        async with self.session.client("s3") as client:
             presigned_url = await client.generate_presigned_url(
                 "get_object",
                 Params={"Bucket": bucket, "Key": s3filename},
                 ExpiresIn=3600,
             )
-        return FileResult(
-            filename=filename,
-            url=presigned_url,
-        )
+        return presigned_url

     async def _delete_file(self, filename: str):
         bucket = self.aws_bucket_name

View File

@@ -1,31 +1,19 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Annotated, Optional from typing import Annotated, Literal, Optional
import reflector.auth as auth import reflector.auth as auth
from fastapi import ( from fastapi import APIRouter, Depends, HTTPException
APIRouter, from fastapi_pagination import Page
Depends, from fastapi_pagination.ext.databases import paginate
HTTPException,
Request,
WebSocket,
WebSocketDisconnect,
status,
)
from fastapi_pagination import Page, paginate
from jose import jwt from jose import jwt
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from reflector.db.transcripts import ( from reflector.db.transcripts import (
AudioWaveform, TranscriptParticipant,
TranscriptTopic, TranscriptTopic,
transcripts_controller, transcripts_controller,
) )
from reflector.processors.types import Transcript as ProcessorTranscript from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.settings import settings from reflector.settings import settings
from reflector.ws_manager import get_ws_manager
from starlette.concurrency import run_in_threadpool
from ._range_requests_response import range_requests_response
from .rtc_offer import RtcOffer, rtc_offer_base
router = APIRouter() router = APIRouter()
@@ -48,6 +36,7 @@ def create_access_token(data: dict, expires_delta: timedelta):
class GetTranscript(BaseModel): class GetTranscript(BaseModel):
id: str id: str
user_id: str | None
name: str name: str
status: str status: str
locked: bool locked: bool
@@ -56,8 +45,10 @@ class GetTranscript(BaseModel):
short_summary: str | None short_summary: str | None
long_summary: str | None long_summary: str | None
created_at: datetime created_at: datetime
share_mode: str = Field("private")
source_language: str | None source_language: str | None
target_language: str | None target_language: str | None
participants: list[TranscriptParticipant] | None
class CreateTranscript(BaseModel): class CreateTranscript(BaseModel):
@@ -72,6 +63,8 @@ class UpdateTranscript(BaseModel):
title: Optional[str] = Field(None) title: Optional[str] = Field(None)
short_summary: Optional[str] = Field(None) short_summary: Optional[str] = Field(None)
long_summary: Optional[str] = Field(None) long_summary: Optional[str] = Field(None)
share_mode: Optional[Literal["public", "semi-private", "private"]] = Field(None)
participants: Optional[list[TranscriptParticipant]] = Field(None)
class DeletionStatus(BaseModel): class DeletionStatus(BaseModel):
@@ -82,12 +75,19 @@ class DeletionStatus(BaseModel):
async def transcripts_list( async def transcripts_list(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
): ):
from reflector.db import database
if not user and not settings.PUBLIC_MODE: if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated") raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
return paginate( return await paginate(
await transcripts_controller.get_all(user_id=user_id, order_by="-created_at") database,
await transcripts_controller.get_all(
user_id=user_id,
order_by="-created_at",
return_query=True,
),
) )
@@ -165,10 +165,9 @@ async def transcript_get(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
): ):
user_id = user["sub"] if user else None user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) return await transcripts_controller.get_by_id_for_http(
if not transcript: transcript_id, user_id=user_id
raise HTTPException(status_code=404, detail="Transcript not found") )
return transcript
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript) @router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
@@ -181,17 +180,7 @@ async def transcript_update(
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript: if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found") raise HTTPException(status_code=404, detail="Transcript not found")
values = {} values = info.dict(exclude_unset=True)
if info.name is not None:
values["name"] = info.name
if info.locked is not None:
values["locked"] = info.locked
if info.long_summary is not None:
values["long_summary"] = info.long_summary
if info.short_summary is not None:
values["short_summary"] = info.short_summary
if info.title is not None:
values["title"] = info.title
await transcripts_controller.update(transcript, values) await transcripts_controller.update(transcript, values)
return transcript return transcript
@@ -209,63 +198,6 @@ async def transcript_delete(
return DeletionStatus(status="ok") return DeletionStatus(status="ok")
@router.get("/transcripts/{transcript_id}/audio/mp3")
@router.head("/transcripts/{transcript_id}/audio/mp3")
async def transcript_get_audio_mp3(
request: Request,
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
token: str | None = None,
):
user_id = user["sub"] if user else None
if not user_id and token:
unauthorized_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get("sub")
except jwt.JWTError:
raise unauthorized_exception
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
truncated_id = str(transcript.id).split("-")[0]
filename = f"recording_{truncated_id}.mp3"
return range_requests_response(
request,
transcript.audio_mp3_filename,
content_type="audio/mpeg",
content_disposition=f"attachment; filename={filename}",
)
@router.get("/transcripts/{transcript_id}/audio/waveform")
async def transcript_get_audio_waveform(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> AudioWaveform:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
await run_in_threadpool(transcript.convert_audio_to_waveform)
return transcript.audio_waveform
@router.get(
    "/transcripts/{transcript_id}/topics",
    response_model=list[GetTranscriptTopic],
@@ -275,92 +207,11 @@ async def transcript_get_topics(
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
 ):
     user_id = user["sub"] if user else None
-    transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
-    if not transcript:
-        raise HTTPException(status_code=404, detail="Transcript not found")
+    transcript = await transcripts_controller.get_by_id_for_http(
+        transcript_id, user_id=user_id
+    )
     # convert to GetTranscriptTopic
     return [
         GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics
     ]
# ==============================================================
# Websocket
# ==============================================================
@router.get("/transcripts/{transcript_id}/events")
async def transcript_get_websocket_events(transcript_id: str):
pass
@router.websocket("/transcripts/{transcript_id}/events")
async def transcript_events_websocket(
transcript_id: str,
websocket: WebSocket,
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
# user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
# connect to websocket manager
# use ts:transcript_id as room id
room_id = f"ts:{transcript_id}"
ws_manager = get_ws_manager()
await ws_manager.add_user_to_room(room_id, websocket)
try:
# on first connection, send all events only to the current user
for event in transcript.events:
            # for now, do not send TRANSCRIPT or STATUS events - these are live
            # events that need not be replayed to the client; but keep the rest
name = event.event
if name in ("TRANSCRIPT", "STATUS"):
continue
await websocket.send_json(event.model_dump(mode="json"))
# XXX if transcript is final (locked=True and status=ended)
# XXX send a final event to the client and close the connection
# endless loop to wait for new events
        # there is no command system yet, so just wait
while True:
await websocket.receive()
except (RuntimeError, WebSocketDisconnect):
await ws_manager.remove_user_from_room(room_id, websocket)
# ==============================================================
# Web RTC
# ==============================================================
@router.post("/transcripts/{transcript_id}/record/webrtc")
async def transcript_record_webrtc(
transcript_id: str,
params: RtcOffer,
request: Request,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if transcript.locked:
raise HTTPException(status_code=400, detail="Transcript is locked")
# create a pipeline runner
from reflector.pipelines.main_live_pipeline import PipelineMainLive
pipeline_runner = PipelineMainLive(transcript_id=transcript_id)
    # FIXME do not allow multiple recordings at the same time
return await rtc_offer_base(
params,
request,
pipeline_runner=pipeline_runner,
)
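Throughout this refactor, the `get_by_id` / `if not transcript` / `raise` triplet is replaced by `get_by_id_for_http`, which presumably raises the 404 itself. A sketch of what such a helper would look like (the controller implementation is not part of this diff):

```python
from fastapi import HTTPException


class TranscriptsController:
    async def get_by_id_for_http(self, transcript_id: str, user_id: str | None):
        # delegate to the existing lookup, then translate "missing or not
        # owned" into the HTTP error every endpoint was raising by hand
        transcript = await self.get_by_id(transcript_id, user_id=user_id)
        if not transcript:
            raise HTTPException(status_code=404, detail="Transcript not found")
        return transcript
```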

View File

@@ -0,0 +1,109 @@
"""
Transcripts audio related endpoints
===================================
"""
from typing import Annotated, Optional
import httpx
import reflector.auth as auth
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from jose import jwt
from reflector.db.transcripts import AudioWaveform, transcripts_controller
from reflector.settings import settings
from reflector.views.transcripts import ALGORITHM
from ._range_requests_response import range_requests_response
router = APIRouter()
@router.get("/transcripts/{transcript_id}/audio/mp3")
@router.head("/transcripts/{transcript_id}/audio/mp3")
async def transcript_get_audio_mp3(
request: Request,
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
token: str | None = None,
):
user_id = user["sub"] if user else None
if not user_id and token:
unauthorized_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get("sub")
except jwt.JWTError:
raise unauthorized_exception
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if transcript.audio_location == "storage":
        # proxy the S3 file to avoid CORS issues
url = await transcript.get_audio_url()
headers = {}
copy_headers = ["range", "accept-encoding"]
for header in copy_headers:
if header in request.headers:
headers[header] = request.headers[header]
async with httpx.AsyncClient() as client:
resp = await client.request(request.method, url, headers=headers)
return Response(
content=resp.content,
status_code=resp.status_code,
headers=resp.headers,
)
if not transcript.audio_mp3_filename.exists():
        raise HTTPException(status_code=404, detail="Audio not found")
truncated_id = str(transcript.id).split("-")[0]
filename = f"recording_{truncated_id}.mp3"
return range_requests_response(
request,
transcript.audio_mp3_filename,
content_type="audio/mpeg",
content_disposition=f"attachment; filename={filename}",
)
@router.get("/transcripts/{transcript_id}/audio/waveform")
async def transcript_get_audio_waveform(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> AudioWaveform:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if not transcript.audio_waveform_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
return transcript.audio_waveform
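For reference, a rough client-side sketch of how the mp3 endpoint is meant to be consumed: an optional `token` query parameter for unauthenticated access, and a standard `Range` header that is either honored locally via `range_requests_response` or forwarded to storage. The token payload shape is inferred from the `jwt.decode` call above; the signing algorithm is assumed to be HS256, and the URL is illustrative:

```python
import httpx
from jose import jwt

SECRET_KEY = "..."  # must match settings.SECRET_KEY
# ALGORITHM is imported from reflector.views.transcripts; HS256 assumed here
token = jwt.encode({"sub": "user-id"}, SECRET_KEY, algorithm="HS256")

resp = httpx.get(
    "http://localhost:8000/v1/transcripts/SOME_ID/audio/mp3",
    params={"token": token},
    headers={"Range": "bytes=0-1023"},  # ask for the first KiB only
)
assert resp.status_code in (200, 206)  # 206 when the range is honored
```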

View File

@@ -0,0 +1,142 @@
"""
Transcript participants API endpoints
=====================================
"""
from typing import Annotated, Optional
import reflector.auth as auth
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, Field
from reflector.db.transcripts import TranscriptParticipant, transcripts_controller
from reflector.views.types import DeletionStatus
router = APIRouter()
class Participant(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
speaker: int | None
name: str
class CreateParticipant(BaseModel):
speaker: Optional[int] = Field(None)
name: str
class UpdateParticipant(BaseModel):
speaker: Optional[int] = Field(None)
name: Optional[str] = Field(None)
@router.get("/transcripts/{transcript_id}/participants")
async def transcript_get_participants(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> list[Participant]:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
return [
Participant.model_validate(participant)
for participant in transcript.participants
]
@router.post("/transcripts/{transcript_id}/participants")
async def transcript_add_participant(
transcript_id: str,
participant: CreateParticipant,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> Participant:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
# ensure the speaker is unique
for p in transcript.participants:
if p.speaker == participant.speaker:
raise HTTPException(
status_code=400,
detail="Speaker already assigned",
)
obj = await transcripts_controller.upsert_participant(
transcript, TranscriptParticipant(**participant.dict())
)
return Participant.model_validate(obj)
@router.get("/transcripts/{transcript_id}/participants/{participant_id}")
async def transcript_get_participant(
transcript_id: str,
participant_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> Participant:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
for p in transcript.participants:
if p.id == participant_id:
return Participant.model_validate(p)
raise HTTPException(status_code=404, detail="Participant not found")
@router.patch("/transcripts/{transcript_id}/participants/{participant_id}")
async def transcript_update_participant(
transcript_id: str,
participant_id: str,
participant: UpdateParticipant,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> Participant:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
# ensure the speaker is unique
for p in transcript.participants:
if p.speaker == participant.speaker and p.id != participant_id:
raise HTTPException(
status_code=400,
detail="Speaker already assigned",
)
# find the participant
obj = None
for p in transcript.participants:
if p.id == participant_id:
obj = p
break
if not obj:
raise HTTPException(status_code=404, detail="Participant not found")
# update participant but just the fields that are set
fields = participant.dict(exclude_unset=True)
obj = obj.copy(update=fields)
await transcripts_controller.upsert_participant(transcript, obj)
return Participant.model_validate(obj)
@router.delete("/transcripts/{transcript_id}/participants/{participant_id}")
async def transcript_delete_participant(
transcript_id: str,
participant_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> DeletionStatus:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
await transcripts_controller.delete_participant(transcript, participant_id)
return DeletionStatus(status="ok")
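A quick usage sketch of the participant endpoints above, showing the unique-speaker constraint (base URL illustrative, and assuming auth is optional as in the dev setup):

```python
import httpx

base = "http://localhost:8000/v1/transcripts/SOME_ID/participants"
with httpx.Client() as client:
    alice = client.post(base, json={"name": "Alice", "speaker": 1}).json()
    # a second participant cannot claim speaker 1...
    assert client.post(base, json={"name": "Bob", "speaker": 1}).status_code == 400
    # ...until the first assignment is removed
    client.delete(f"{base}/{alice['id']}")
    assert client.post(base, json={"name": "Bob", "speaker": 1}).status_code == 200
```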

View File

@@ -0,0 +1,37 @@
from typing import Annotated, Optional
import reflector.auth as auth
from fastapi import APIRouter, Depends, HTTPException, Request
from reflector.db.transcripts import transcripts_controller
from .rtc_offer import RtcOffer, rtc_offer_base
router = APIRouter()
@router.post("/transcripts/{transcript_id}/record/webrtc")
async def transcript_record_webrtc(
transcript_id: str,
params: RtcOffer,
request: Request,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if transcript.locked:
raise HTTPException(status_code=400, detail="Transcript is locked")
# create a pipeline runner
from reflector.pipelines.main_live_pipeline import PipelineMainLive
pipeline_runner = PipelineMainLive(transcript_id=transcript_id)
    # FIXME do not allow multiple recordings at the same time
return await rtc_offer_base(
params,
request,
pipeline_runner=pipeline_runner,
)
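The request body is an SDP offer; the exact `RtcOffer` fields live in `rtc_offer.py`, which is not shown in this diff, so the field names below are assumptions:

```python
import httpx

# hypothetical offer; in practice it comes from the browser's
# RTCPeerConnection.createOffer()
offer = {"sdp": "v=0\r\n...", "type": "offer"}

resp = httpx.post(
    "http://localhost:8000/v1/transcripts/SOME_ID/record/webrtc",
    json=offer,
)
resp.raise_for_status()  # a 400 here means the transcript is locked
answer = resp.json()     # SDP answer to hand back to the peer connection
```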

View File

@@ -0,0 +1,53 @@
"""
Transcripts websocket API
=========================
"""
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
from reflector.db.transcripts import transcripts_controller
from reflector.ws_manager import get_ws_manager
router = APIRouter()
@router.get("/transcripts/{transcript_id}/events")
async def transcript_get_websocket_events(transcript_id: str):
pass
@router.websocket("/transcripts/{transcript_id}/events")
async def transcript_events_websocket(
transcript_id: str,
websocket: WebSocket,
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
# user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
# connect to websocket manager
# use ts:transcript_id as room id
room_id = f"ts:{transcript_id}"
ws_manager = get_ws_manager()
await ws_manager.add_user_to_room(room_id, websocket)
try:
# on first connection, send all events only to the current user
for event in transcript.events:
            # for now, do not send TRANSCRIPT or STATUS events - these are live
            # events that need not be replayed to the client; but keep the rest
name = event.event
if name in ("TRANSCRIPT", "STATUS"):
continue
await websocket.send_json(event.model_dump(mode="json"))
# XXX if transcript is final (locked=True and status=ended)
# XXX send a final event to the client and close the connection
# endless loop to wait for new events
        # there is no command system yet, so just wait
while True:
await websocket.receive()
except (RuntimeError, WebSocketDisconnect):
await ws_manager.remove_user_from_room(room_id, websocket)
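A minimal listener sketch for the events socket above, using the third-party `websockets` package; the path prefix is assumed from the tests' `http://test/v1` base URL, and the event names come from the filtering logic in the handler:

```python
import asyncio
import json

import websockets  # pip install websockets


async def listen(transcript_id: str):
    url = f"ws://localhost:8000/v1/transcripts/{transcript_id}/events"
    async with websockets.connect(url) as ws:
        # the server replays past events (minus TRANSCRIPT/STATUS),
        # then streams new ones as they happen
        async for message in ws:
            event = json.loads(message)
            print(event["event"], event.get("data"))


asyncio.run(listen("SOME_ID"))
```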

View File

@@ -0,0 +1,5 @@
from pydantic import BaseModel
class DeletionStatus(BaseModel):
status: str

View File

@@ -1,6 +1,8 @@
+import structlog
 from celery import Celery

 from reflector.settings import settings

+logger = structlog.get_logger(__name__)
+
 app = Celery(__name__)
 app.conf.broker_url = settings.CELERY_BROKER_URL
 app.conf.result_backend = settings.CELERY_RESULT_BACKEND
@@ -8,5 +10,18 @@ app.conf.broker_connection_retry_on_startup = True
 app.autodiscover_tasks(
     [
         "reflector.pipelines.main_live_pipeline",
+        "reflector.worker.healthcheck",
     ]
 )
+
+# crontab
+app.conf.beat_schedule = {}
+if settings.HEALTHCHECK_URL:
+    app.conf.beat_schedule["healthcheck_ping"] = {
+        "task": "reflector.worker.healthcheck.healthcheck_ping",
+        "schedule": 60.0 * 10,
+    }
+    logger.info("Healthcheck enabled", url=settings.HEALTHCHECK_URL)
+else:
+    logger.warning("Healthcheck disabled, no url configured")
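The schedule above is a plain interval (600 seconds). For reference, the same entry could use a crontab expression instead; both forms are standard Celery, and this is only a sketch, not part of the diff:

```python
from celery.schedules import crontab

app.conf.beat_schedule["healthcheck_ping"] = {
    "task": "reflector.worker.healthcheck.healthcheck_ping",
    "schedule": crontab(minute="*/10"),  # every 10 minutes, same cadence
}
```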

View File

@@ -0,0 +1,18 @@
import httpx
import structlog
from celery import shared_task
from reflector.settings import settings
logger = structlog.get_logger(__name__)
@shared_task
def healthcheck_ping():
url = settings.HEALTHCHECK_URL
if not url:
return
try:
print("pinging healthcheck url", url)
httpx.get(url, timeout=10)
except Exception as e:
logger.error("healthcheck_ping", error=str(e))
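Since the task is a plain function under `@shared_task`, it can be smoke-tested synchronously without a worker or broker; the URL below is illustrative:

```python
from reflector.settings import settings
from reflector.worker.healthcheck import healthcheck_ping

settings.HEALTHCHECK_URL = "https://hc-ping.com/your-uuid"  # illustrative
healthcheck_ping()  # runs inline; .delay() would route it through the broker
```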

View File

@@ -9,6 +9,8 @@ if [ "${ENTRYPOINT}" = "server" ]; then
     python -m reflector.app
 elif [ "${ENTRYPOINT}" = "worker" ]; then
     celery -A reflector.worker.app worker --loglevel=info
+elif [ "${ENTRYPOINT}" = "beat" ]; then
+    celery -A reflector.worker.app beat --loglevel=info
 else
     echo "Unknown command"
 fi

View File

@@ -1,4 +1,5 @@
 from unittest.mock import patch
+from tempfile import NamedTemporaryFile

 import pytest

@@ -7,7 +8,6 @@ import pytest
 @pytest.mark.asyncio
 async def setup_database():
     from reflector.settings import settings
-    from tempfile import NamedTemporaryFile

     with NamedTemporaryFile() as f:
         settings.DATABASE_URL = f"sqlite:///{f.name}"
@@ -103,6 +103,25 @@ async def dummy_llm():
         yield


+@pytest.fixture
+async def dummy_storage():
+    from reflector.storage.base import Storage
+
+    class DummyStorage(Storage):
+        async def _put_file(self, *args, **kwargs):
+            pass
+
+        async def _delete_file(self, *args, **kwargs):
+            pass
+
+        async def _get_file_url(self, *args, **kwargs):
+            return "http://fake_server/audio.mp3"
+
+    with patch("reflector.storage.base.Storage.get_instance") as mock_storage:
+        mock_storage.return_value = DummyStorage()
+        yield
+
+
 @pytest.fixture
 def nltk():
     with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
@@ -133,4 +152,17 @@ def celery_enable_logging():
 @pytest.fixture(scope="session")
 def celery_config():
-    return {"broker_url": "memory://", "result_backend": "rpc"}
+    with NamedTemporaryFile() as f:
+        yield {
+            "broker_url": "memory://",
+            "result_backend": f"db+sqlite:///{f.name}",
+        }
+
+
+@pytest.fixture(scope="session")
+def fake_mp3_upload():
+    with patch(
+        "reflector.db.transcripts.TranscriptController.move_mp3_to_storage"
+    ) as mock_move:
+        mock_move.return_value = True
+        yield

View File

@@ -0,0 +1,140 @@
import pytest
from unittest import mock
@pytest.mark.parametrize(
"name,diarization,expected",
[
[
"no overlap",
[
{"start": 0.0, "end": 1.0, "speaker": "A"},
{"start": 1.0, "end": 2.0, "speaker": "B"},
],
["A", "A", "B", "B"],
],
[
"same speaker",
[
{"start": 0.0, "end": 1.0, "speaker": "A"},
{"start": 1.0, "end": 2.0, "speaker": "A"},
],
["A", "A", "A", "A"],
],
[
            # first segment is removed because it overlaps
            # with the second segment and is shorter
"overlap at 0.5s",
[
{"start": 0.0, "end": 1.0, "speaker": "A"},
{"start": 0.5, "end": 2.0, "speaker": "B"},
],
["B", "B", "B", "B"],
],
[
"junk segment at 0.5s for 0.2s",
[
{"start": 0.0, "end": 1.0, "speaker": "A"},
{"start": 0.5, "end": 0.7, "speaker": "B"},
{"start": 1, "end": 2.0, "speaker": "B"},
],
["A", "A", "B", "B"],
],
[
"start without diarization",
[
{"start": 0.5, "end": 1.0, "speaker": "A"},
{"start": 1.0, "end": 2.0, "speaker": "B"},
],
["A", "A", "B", "B"],
],
[
"end missing diarization",
[
{"start": 0.0, "end": 1.0, "speaker": "A"},
{"start": 1.0, "end": 1.5, "speaker": "B"},
],
["A", "A", "B", "B"],
],
[
"continuation of next speaker",
[
{"start": 0.0, "end": 0.9, "speaker": "A"},
{"start": 1.5, "end": 2.0, "speaker": "B"},
],
["A", "A", "B", "B"],
],
[
"continuation of previous speaker",
[
{"start": 0.0, "end": 0.5, "speaker": "A"},
{"start": 1.0, "end": 2.0, "speaker": "B"},
],
["A", "A", "B", "B"],
],
[
"segment without words",
[
{"start": 0.0, "end": 1.0, "speaker": "A"},
{"start": 1.0, "end": 2.0, "speaker": "B"},
{"start": 2.0, "end": 3.0, "speaker": "X"},
],
["A", "A", "B", "B"],
],
],
)
@pytest.mark.asyncio
async def test_processors_audio_diarization(event_loop, name, diarization, expected):
from reflector.processors.audio_diarization import AudioDiarizationProcessor
from reflector.processors.types import (
TitleSummaryWithId,
Transcript,
Word,
AudioDiarizationInput,
)
# create fake topic
topics = [
TitleSummaryWithId(
id="1",
title="Title1",
summary="Summary1",
timestamp=0.0,
duration=1.0,
transcript=Transcript(
words=[
Word(text="Word1", start=0.0, end=0.5),
Word(text="word2.", start=0.5, end=1.0),
]
),
),
TitleSummaryWithId(
id="2",
title="Title2",
summary="Summary2",
timestamp=0.0,
duration=1.0,
transcript=Transcript(
words=[
Word(text="Word3", start=1.0, end=1.5),
Word(text="word4.", start=1.5, end=2.0),
]
),
),
]
diarizer = AudioDiarizationProcessor()
with mock.patch.object(diarizer, "_diarize") as mock_diarize:
mock_diarize.return_value = diarization
data = AudioDiarizationInput(
audio_url="https://example.com/audio.mp3",
topics=topics,
)
await diarizer._push(data)
# check that the speaker has been assigned to the words
assert topics[0].transcript.words[0].speaker == expected[0]
assert topics[0].transcript.words[1].speaker == expected[1]
assert topics[1].transcript.words[0].speaker == expected[2]
assert topics[1].transcript.words[1].speaker == expected[3]
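The expectations above pin down how diarization segments map onto words: when segments overlap, the shorter one loses (so sub-second junk segments disappear), and words outside every segment inherit the nearest speaker. A naive sketch consistent with these cases follows; the real logic lives in `AudioDiarizationProcessor` and is not shown in this diff:

```python
def resolve_overlaps(segments):
    # when two segments overlap, keep the longer one and drop the shorter
    # (covers both the "overlap at 0.5s" and the junk-segment cases above)
    kept = []
    for seg in sorted(segments, key=lambda s: s["end"] - s["start"], reverse=True):
        if not any(seg["start"] < k["end"] and k["start"] < seg["end"] for k in kept):
            kept.append(seg)
    return sorted(kept, key=lambda s: s["start"])


def assign_speakers(words, segments):
    # words: objects with .start/.end/.speaker; segments: dicts as in the test
    segments = resolve_overlaps(segments)
    for word in words:
        midpoint = (word.start + word.end) / 2
        covering = [s for s in segments if s["start"] <= midpoint <= s["end"]]
        if covering:
            word.speaker = covering[0]["speaker"]
        else:
            # words outside every segment inherit the nearest boundary's speaker
            nearest = min(
                segments,
                key=lambda s: min(abs(s["start"] - midpoint), abs(s["end"] - midpoint)),
            )
            word.speaker = nearest["speaker"]
    return words
```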

View File

@@ -118,15 +118,3 @@ async def test_transcript_audio_download_range_with_seek(
     assert response.status_code == 206
     assert response.headers["content-type"] == content_type
     assert response.headers["content-range"].startswith("bytes 100-")
-
-
-@pytest.mark.asyncio
-async def test_transcript_audio_download_waveform(fake_transcript):
-    from reflector.app import app
-
-    ac = AsyncClient(app=app, base_url="http://test/v1")
-    response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform")
-    assert response.status_code == 200
-    assert response.headers["content-type"] == "application/json"
-    assert isinstance(response.json()["data"], list)
-    assert len(response.json()["data"]) >= 255

View File

@@ -0,0 +1,164 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_participants():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
# create a participant
transcript_id = response.json()["id"]
response = await ac.post(
f"/transcripts/{transcript_id}/participants", json={"name": "test"}
)
assert response.status_code == 200
assert response.json()["id"] is not None
assert response.json()["speaker"] is None
assert response.json()["name"] == "test"
# create another one with a speaker
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["id"] is not None
assert response.json()["speaker"] == 1
assert response.json()["name"] == "test2"
# get all participants via transcript
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
assert len(response.json()["participants"]) == 2
# get participants via participants endpoint
response = await ac.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
assert len(response.json()) == 2
@pytest.mark.asyncio
async def test_transcript_participants_same_speaker():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
# create a participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["speaker"] == 1
# create another one with the same speaker
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 1},
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_transcript_participants_update_name():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
# create a participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["speaker"] == 1
# update the participant
participant_id = response.json()["id"]
response = await ac.patch(
f"/transcripts/{transcript_id}/participants/{participant_id}",
json={"name": "test2"},
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
# verify the participant was updated
response = await ac.get(
f"/transcripts/{transcript_id}/participants/{participant_id}"
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
# verify the participant was updated in transcript
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
assert len(response.json()["participants"]) == 1
assert response.json()["participants"][0]["name"] == "test2"
@pytest.mark.asyncio
async def test_transcript_participants_update_speaker():
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
# create a participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
participant1_id = response.json()["id"]
# create another participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 2},
)
assert response.status_code == 200
participant2_id = response.json()["id"]
# update the participant, refused as speaker is already taken
response = await ac.patch(
f"/transcripts/{transcript_id}/participants/{participant2_id}",
json={"speaker": 1},
)
assert response.status_code == 400
# delete the participant 1
response = await ac.delete(
f"/transcripts/{transcript_id}/participants/{participant1_id}"
)
assert response.status_code == 200
# update the participant 2 again, should be accepted now
response = await ac.patch(
f"/transcripts/{transcript_id}/participants/{participant2_id}",
json={"speaker": 1},
)
assert response.status_code == 200
# ensure participant2 name is still there
response = await ac.get(
f"/transcripts/{transcript_id}/participants/{participant2_id}"
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
assert response.json()["speaker"] == 1

View File

@@ -66,6 +66,8 @@ async def test_transcript_rtc_and_websocket(
     dummy_transcript,
     dummy_processors,
     dummy_diarization,
+    dummy_storage,
+    fake_mp3_upload,
     ensure_casing,
     appserver,
     sentence_tokenize,
@@ -182,6 +184,16 @@ async def test_transcript_rtc_and_websocket(
     ev = events[eventnames.index("FINAL_TITLE")]
     assert ev["data"]["title"] == "LLM TITLE"

+    assert "WAVEFORM" in eventnames
+    ev = events[eventnames.index("WAVEFORM")]
+    assert isinstance(ev["data"]["waveform"], list)
+    assert len(ev["data"]["waveform"]) >= 250
+
+    waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform")
+    assert waveform_resp.status_code == 200
+    assert waveform_resp.headers["content-type"] == "application/json"
+    assert isinstance(waveform_resp.json()["data"], list)
+    assert len(waveform_resp.json()["data"]) >= 250
+
     # check status order
     statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]
     assert statuses.index("recording") < statuses.index("processing")
@@ -193,11 +205,12 @@ async def test_transcript_rtc_and_websocket(
     # check on the latest response that the audio duration is > 0
     assert resp.json()["duration"] > 0
+    assert "DURATION" in eventnames

     # check that audio/mp3 is available
-    resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
-    assert resp.status_code == 200
-    assert resp.headers["Content-Type"] == "audio/mpeg"
+    audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
+    assert audio_resp.status_code == 200
+    assert audio_resp.headers["Content-Type"] == "audio/mpeg"


 @pytest.mark.usefixtures("celery_session_app")
@@ -209,6 +222,8 @@ async def test_transcript_rtc_and_websocket_and_fr(
     dummy_transcript,
     dummy_processors,
     dummy_diarization,
+    dummy_storage,
+    fake_mp3_upload,
     ensure_casing,
     appserver,
     sentence_tokenize,

View File

@@ -1,11 +1,18 @@
 "use client";

 import { FiefAuthProvider } from "@fief/fief/nextjs/react";
+import { createContext } from "react";

-export default function FiefWrapper({ children }) {
+export const CookieContext = createContext<{ hasAuthCookie: boolean }>({
+  hasAuthCookie: false,
+});
+
+export default function FiefWrapper({ children, hasAuthCookie }) {
   return (
-    <FiefAuthProvider currentUserPath="/api/current-user">
-      {children}
-    </FiefAuthProvider>
+    <CookieContext.Provider value={{ hasAuthCookie }}>
+      <FiefAuthProvider currentUserPath="/api/current-user">
+        {children}
+      </FiefAuthProvider>
+    </CookieContext.Provider>
   );
 }

View File

@@ -11,6 +11,9 @@ import About from "../(aboutAndPrivacy)/about";
 import Privacy from "../(aboutAndPrivacy)/privacy";
 import { DomainContextProvider } from "./domainContext";
 import { getConfig } from "../lib/edgeConfig";
+import { ErrorBoundary } from "@sentry/nextjs";
+import { cookies } from "next/dist/client/components/headers";
+import { SESSION_COOKIE_NAME } from "../lib/fief";

 const poppins = Poppins({ subsets: ["latin"], weight: ["200", "400", "600"] });

@@ -70,86 +73,89 @@ type LayoutProps = {
 export default async function RootLayout({ children, params }: LayoutProps) {
   const config = await getConfig(params.domain);
   const { requireLogin, privacy, browse } = config.features;
+  const hasAuthCookie = !!cookies().get(SESSION_COOKIE_NAME);

   return (
     <html lang="en">
       <body className={poppins.className + " h-screen relative"}>
-        <FiefWrapper>
+        <FiefWrapper hasAuthCookie={hasAuthCookie}>
           <DomainContextProvider config={config}>
+            <ErrorBoundary fallback={<p>"something went really wrong"</p>}>
              <ErrorProvider>
                <ErrorMessage />
                <div
                  id="container"
                  className="items-center h-[100svh] w-[100svw] p-2 md:p-4 grid grid-rows-layout gap-2 md:gap-4"
                >
                  <header className="flex justify-between items-center w-full">
                    {/* Logo on the left */}
                    <Link
                      href="/"
                      className="flex outline-blue-300 md:outline-none focus-visible:underline underline-offset-2 decoration-[.5px] decoration-gray-500"
                    >
                      <Image
                        src="/reach.png"
                        width={16}
                        height={16}
                        className="h-10 w-auto"
                        alt="Reflector"
                      />
                      <div className="hidden flex-col ml-2 md:block">
                        <h1 className="text-[38px] font-bold tracking-wide leading-tight">
                          Reflector
                        </h1>
                        <p className="text-gray-500 text-xs tracking-tighter">
                          Capture the signal, not the noise
                        </p>
                      </div>
                    </Link>
                    <div>
                      {/* Text link on the right */}
                      <Link
                        href="/transcripts/new"
                        className="hover:underline focus-within:underline underline-offset-2 decoration-[.5px] font-light px-2"
                      >
                        Create
                      </Link>
                      {browse ? (
                        <>
                          &nbsp;·&nbsp;
                          <Link
                            href="/browse"
                            className="hover:underline focus-within:underline underline-offset-2 decoration-[.5px] font-light px-2"
                            prefetch={false}
                          >
                            Browse
                          </Link>
                        </>
                      ) : (
                        <></>
                      )}
                      &nbsp;·&nbsp;
                      <About buttonText="About" />
                      {privacy ? (
                        <>
                          &nbsp;·&nbsp;
                          <Privacy buttonText="Privacy" />
                        </>
                      ) : (
                        <></>
                      )}
                      {requireLogin ? (
                        <>
                          &nbsp;·&nbsp;
                          <UserInfo />
                        </>
                      ) : (
                        <></>
                      )}
                    </div>
                  </header>
                  {children}
                </div>
              </ErrorProvider>
+            </ErrorBoundary>
          </DomainContextProvider>
        </FiefWrapper>
      </body>

View File

@@ -5,15 +5,19 @@ import useTopics from "../useTopics";
 import useWaveform from "../useWaveform";
 import useMp3 from "../useMp3";
 import { TopicList } from "../topicList";
-import Recorder from "../recorder";
 import { Topic } from "../webSocketTypes";
-import React, { useState } from "react";
+import React, { useEffect, useState } from "react";
 import "../../../styles/button.css";
 import FinalSummary from "../finalSummary";
 import ShareLink from "../shareLink";
 import QRCode from "react-qr-code";
 import TranscriptTitle from "../transcriptTitle";
 import ShareModal from "./shareModal";
+import Player from "../player";
+import WaveformLoading from "../waveformLoading";
+import { useRouter } from "next/navigation";
+import { faSpinner } from "@fortawesome/free-solid-svg-icons";
+import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";

 type TranscriptDetails = {
   params: {
@@ -21,26 +25,28 @@ type TranscriptDetails = {
   };
 };

-const protectedPath = true;
-
 export default function TranscriptDetails(details: TranscriptDetails) {
   const transcriptId = details.params.transcriptId;
+  const router = useRouter();

-  const transcript = useTranscript(protectedPath, transcriptId);
-  const topics = useTopics(protectedPath, transcriptId);
-  const waveform = useWaveform(protectedPath, transcriptId);
+  const transcript = useTranscript(transcriptId);
+  const topics = useTopics(transcriptId);
+  const waveform = useWaveform(transcriptId);
   const useActiveTopic = useState<Topic | null>(null);
-  const mp3 = useMp3(protectedPath, transcriptId);
+  const mp3 = useMp3(transcriptId);
   const [showModal, setShowModal] = useState(false);

-  if (transcript?.error /** || topics?.error || waveform?.error **/) {
-    return (
-      <Modal
-        title="Transcription Not Found"
-        text="A trascription with this ID does not exist."
-      />
-    );
-  }
+  useEffect(() => {
+    const statusToRedirect = ["idle", "recording", "processing"];
+    if (statusToRedirect.includes(transcript.response?.status)) {
+      const newUrl = "/transcripts/" + details.params.transcriptId + "/record";
+      // Shallow redirection does not work on NextJS 13
+      // https://github.com/vercel/next.js/discussions/48110
+      // https://github.com/vercel/next.js/discussions/49540
+      router.push(newUrl, undefined);
+      // history.replaceState({}, "", newUrl);
+    }
+  }, [transcript.response?.status]);

   const fullTranscript =
     topics.topics
@@ -90,79 +96,102 @@ export default function TranscriptDetails(details: TranscriptDetails) {
   **Next Meeting:**
   Scheduled for December 5, 2023, to review progress and finalize the new product launch details.
   `;
+  }

+  if (transcript.error || topics?.error) {
+    return (
+      <Modal
+        title="Transcription Not Found"
+        text="A transcription with this ID does not exist."
+      />
+    );
+  }
+
+  if (!transcriptId || transcript?.loading || topics?.loading) {
+    return <Modal title="Loading" text={"Loading transcript..."} />;
+  }
+
   return (
     <>
-      {!transcriptId || transcript?.loading || topics?.loading ? (
-        <Modal title="Loading" text={"Loading transcript..."} />
-      ) : (
-        <>
       <ShareModal
         transcript={transcript.response}
         topics={topics ? topics.topics : null}
         show={showModal}
         setShow={(v) => setShowModal(v)}
         title={transcript?.response?.title}
         summary={transcript?.response?.longSummary}
         date={transcript?.response?.createdAt}
         url={window.location.href}
       />
       <div className="flex flex-col">
         {transcript?.response?.title && (
           <TranscriptTitle
-            protectedPath={protectedPath}
             title={transcript.response.title}
             transcriptId={transcript.response.id}
           />
         )}
-        {!waveform?.loading && (
-          <Recorder
-            topics={topics?.topics || []}
-            useActiveTopic={useActiveTopic}
-            waveform={waveform?.waveform}
-            isPastMeeting={true}
-            transcriptId={transcript?.response?.id}
-            media={mp3?.media}
-            mediaDuration={transcript?.response?.duration}
-          />
-        )}
+        {waveform.waveform && mp3.media ? (
+          <Player
+            topics={topics?.topics || []}
+            useActiveTopic={useActiveTopic}
+            waveform={waveform.waveform.data}
+            media={mp3.media}
+            mediaDuration={transcript.response.duration}
+          />
+        ) : waveform.error ? (
+          <div>"error loading this recording"</div>
+        ) : (
+          <WaveformLoading />
+        )}
       </div>
       <div className="grid grid-cols-1 lg:grid-cols-2 grid-rows-2 lg:grid-rows-1 gap-2 lg:gap-4 h-full">
         <TopicList
-          topics={topics?.topics || []}
+          topics={topics.topics || []}
          useActiveTopic={useActiveTopic}
          autoscroll={false}
        />
        <div className="w-full h-full grid grid-rows-layout-one grid-cols-1 gap-2 lg:gap-4">
          <section className=" bg-blue-400/20 rounded-lg md:rounded-xl p-2 md:px-4 h-full">
-            {transcript?.response?.longSummary && (
+            {transcript.response.longSummary ? (
               <FinalSummary
-                protectedPath={protectedPath}
                 fullTranscript={fullTranscript}
-                summary={transcript?.response?.longSummary}
-                transcriptId={transcript?.response?.id}
+                summary={transcript.response.longSummary}
+                transcriptId={transcript.response.id}
                 openZulipModal={() => setShowModal(true)}
               />
-            )}
+            ) : (
+              <div className="flex flex-col h-full justify-center content-center">
+                {transcript.response.status == "processing" ? (
+                  <p>Loading Transcript</p>
+                ) : (
+                  <p>
+                    There was an error generating the final summary, please
+                    come back later
+                  </p>
+                )}
+              </div>
+            )}
          </section>
          <section className="flex items-center">
            <div className="mr-4 hidden md:block h-auto">
              <QRCode
                value={`${location.origin}/transcripts/${details.params.transcriptId}`}
                level="L"
                size={98}
              />
            </div>
            <div className="flex-grow max-w-full">
-              <ShareLink />
+              <ShareLink
+                transcriptId={transcript?.response?.id}
+                userId={transcript?.response?.userId}
+                shareMode={transcript?.response?.shareMode}
+              />
            </div>
          </section>
        </div>
      </div>
-        </>
-      )}
    </>
  );
}

View File

@@ -8,12 +8,15 @@ import { useWebSockets } from "../../useWebSockets";
 import useAudioDevice from "../../useAudioDevice";
 import "../../../../styles/button.css";
 import { Topic } from "../../webSocketTypes";
-import getApi from "../../../../lib/getApi";
 import LiveTrancription from "../../liveTranscription";
 import DisconnectedIndicator from "../../disconnectedIndicator";
 import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
 import { faGear } from "@fortawesome/free-solid-svg-icons";
 import { lockWakeState, releaseWakeState } from "../../../../lib/wakeLock";
+import { useRouter } from "next/navigation";
+import Player from "../../player";
+import useMp3 from "../../useMp3";
+import WaveformLoading from "../../waveformLoading";

 type TranscriptDetails = {
   params: {
@@ -36,14 +39,18 @@ const TranscriptRecord = (details: TranscriptDetails) => {
     }
   }, []);

-  const transcript = useTranscript(true, details.params.transcriptId);
-  const webRTC = useWebRTC(stream, details.params.transcriptId, true);
+  const transcript = useTranscript(details.params.transcriptId);
+  const webRTC = useWebRTC(stream, details.params.transcriptId);
   const webSockets = useWebSockets(details.params.transcriptId);
   const { audioDevices, getAudioStream } = useAudioDevice();
-  const [hasRecorded, setHasRecorded] = useState(false);
+  const [recordedTime, setRecordedTime] = useState(0);
+  const [startTime, setStartTime] = useState(0);
   const [transcriptStarted, setTranscriptStarted] = useState(false);
+  let mp3 = useMp3(details.params.transcriptId, true);
+  const router = useRouter();

   useEffect(() => {
     if (!transcriptStarted && webSockets.transcriptText.length !== 0)
@@ -51,15 +58,27 @@ const TranscriptRecord = (details: TranscriptDetails) => {
   }, [webSockets.transcriptText]);

   useEffect(() => {
-    if (transcript?.response?.longSummary) {
-      const newUrl = `/transcripts/${transcript.response.id}`;
+    const statusToRedirect = ["ended", "error"];
+    //TODO if has no topic and is error, get back to new
+    if (
+      statusToRedirect.includes(transcript.response?.status) ||
+      statusToRedirect.includes(webSockets.status.value)
+    ) {
+      const newUrl = "/transcripts/" + details.params.transcriptId;
       // Shallow redirection does not work on NextJS 13
       // https://github.com/vercel/next.js/discussions/48110
       // https://github.com/vercel/next.js/discussions/49540
-      // router.push(newUrl, undefined, { shallow: true });
-      history.replaceState({}, "", newUrl);
+      router.replace(newUrl);
+      // history.replaceState({}, "", newUrl);
     }
-  });
+  }, [webSockets.status.value, transcript.response?.status]);
+
+  useEffect(() => {
+    if (webSockets.duration) {
+      mp3.getNow();
+    }
+  }, [webSockets.duration]);

   useEffect(() => {
     lockWakeState();
@@ -70,19 +89,31 @@ const TranscriptRecord = (details: TranscriptDetails) => {

   return (
     <>
-      <Recorder
-        setStream={setStream}
-        onStop={() => {
-          setStream(null);
-          setHasRecorded(true);
-          webRTC?.send(JSON.stringify({ cmd: "STOP" }));
-        }}
-        topics={webSockets.topics}
-        getAudioStream={getAudioStream}
-        useActiveTopic={useActiveTopic}
-        isPastMeeting={false}
-        audioDevices={audioDevices}
-      />
+      {webSockets.waveform && webSockets.duration && mp3?.media ? (
+        <Player
+          topics={webSockets.topics || []}
+          useActiveTopic={useActiveTopic}
+          waveform={webSockets.waveform}
+          media={mp3.media}
+          mediaDuration={webSockets.duration}
+        />
+      ) : recordedTime ? (
+        <WaveformLoading />
+      ) : (
+        <Recorder
+          setStream={setStream}
+          onStop={() => {
+            setStream(null);
+            setRecordedTime(Date.now() - startTime);
+            webRTC?.send(JSON.stringify({ cmd: "STOP" }));
+          }}
+          onRecord={() => {
+            setStartTime(Date.now());
+          }}
+          getAudioStream={getAudioStream}
+          audioDevices={audioDevices}
+        />
+      )}

       <div className="grid grid-cols-1 lg:grid-cols-2 grid-rows-mobile-inner lg:grid-rows-1 gap-2 lg:gap-4 h-full">
         <TopicList
@@ -94,7 +125,7 @@ const TranscriptRecord = (details: TranscriptDetails) => {
         <section
           className={`w-full h-full bg-blue-400/20 rounded-lg md:rounded-xl p-2 md:px-4`}
         >
-          {!hasRecorded ? (
+          {!recordedTime ? (
             <>
               {transcriptStarted && (
                 <h2 className="md:text-lg font-bold">Transcription</h2>
@@ -128,6 +159,7 @@ const TranscriptRecord = (details: TranscriptDetails) => {
                 couple of minutes. Please do not navigate away from the page
                 during this time.
               </p>
+              {/* NTH If login required remove last sentence */}
             </div>
           )}
         </section>

View File

@@ -19,7 +19,7 @@ const useCreateTranscript = (): CreateTranscript => {
   const [loading, setLoading] = useState<boolean>(false);
   const [error, setErrorState] = useState<Error | null>(null);
   const { setError } = useError();
-  const api = getApi(true);
+  const api = getApi();

   const create = (params: V1TranscriptsCreateRequest["createTranscript"]) => {
     if (loading || !api) return;

View File

@@ -5,7 +5,6 @@ import "../../styles/markdown.css";
 import getApi from "../../lib/getApi";

 type FinalSummaryProps = {
-  protectedPath: boolean;
   summary: string;
   fullTranscript: string;
   transcriptId: string;
@@ -19,7 +18,7 @@ export default function FinalSummary(props: FinalSummaryProps) {
   const [isEditMode, setIsEditMode] = useState(false);
   const [preEditSummary, setPreEditSummary] = useState(props.summary);
   const [editedSummary, setEditedSummary] = useState(props.summary);
-  const api = getApi(props.protectedPath);
+  const api = getApi();

   const updateSummary = async (newSummary: string, transcriptId: string) => {
     if (!api) return;
@@ -88,7 +87,7 @@ export default function FinalSummary(props: FinalSummaryProps) {
       <div
         className={
           (isEditMode ? "overflow-y-none" : "overflow-y-auto") +
-          " h-auto max-h-full flex flex-col h-full"
+          " max-h-full flex flex-col h-full"
         }
       >
         <div className="flex flex-row flex-wrap-reverse justify-between items-center">

View File

@@ -0,0 +1,166 @@
import React, { useRef, useEffect, useState } from "react";
import WaveSurfer from "wavesurfer.js";
import CustomRegionsPlugin from "../../lib/custom-plugins/regions";
import { formatTime } from "../../lib/time";
import { Topic } from "./webSocketTypes";
import { AudioWaveform } from "../../api";
import { waveSurferStyles } from "../../styles/recorder";
type PlayerProps = {
topics: Topic[];
useActiveTopic: [
Topic | null,
React.Dispatch<React.SetStateAction<Topic | null>>,
];
waveform: AudioWaveform["data"];
media: HTMLMediaElement;
mediaDuration: number;
};
export default function Player(props: PlayerProps) {
const waveformRef = useRef<HTMLDivElement>(null);
const [wavesurfer, setWavesurfer] = useState<WaveSurfer | null>(null);
const [isPlaying, setIsPlaying] = useState<boolean>(false);
const [currentTime, setCurrentTime] = useState<number>(0);
const [waveRegions, setWaveRegions] = useState<CustomRegionsPlugin | null>(
null,
);
const [activeTopic, setActiveTopic] = props.useActiveTopic;
const topicsRef = useRef(props.topics);
// Waveform setup
useEffect(() => {
if (waveformRef.current) {
// XXX duration is required to prevent recomputing peaks from audio
// However, the current waveform returns only the peaks, and no duration
// And the backend does not save duration properly.
// So at the moment, we deduct the duration from the topics.
// This is not ideal, but it works for now.
const _wavesurfer = WaveSurfer.create({
container: waveformRef.current,
peaks: props.waveform,
hideScrollbar: true,
autoCenter: true,
barWidth: 2,
height: "auto",
duration: props.mediaDuration,
...waveSurferStyles.player,
});
// styling
const wsWrapper = _wavesurfer.getWrapper();
wsWrapper.style.cursor = waveSurferStyles.playerStyle.cursor;
wsWrapper.style.backgroundColor =
waveSurferStyles.playerStyle.backgroundColor;
wsWrapper.style.borderRadius = waveSurferStyles.playerStyle.borderRadius;
_wavesurfer.on("play", () => {
setIsPlaying(true);
});
_wavesurfer.on("pause", () => {
setIsPlaying(false);
});
_wavesurfer.on("timeupdate", setCurrentTime);
setWaveRegions(_wavesurfer.registerPlugin(CustomRegionsPlugin.create()));
_wavesurfer.toggleInteraction(true);
_wavesurfer.setMediaElement(props.media);
setWavesurfer(_wavesurfer);
return () => {
_wavesurfer.destroy();
setIsPlaying(false);
setCurrentTime(0);
};
}
}, []);
useEffect(() => {
if (!wavesurfer) return;
if (!props.media) return;
wavesurfer.setMediaElement(props.media);
}, [props.media, wavesurfer]);
useEffect(() => {
topicsRef.current = props.topics;
renderMarkers();
}, [props.topics, waveRegions]);
const renderMarkers = () => {
if (!waveRegions) return;
waveRegions.clearRegions();
for (let topic of topicsRef.current) {
const content = document.createElement("div");
content.setAttribute("style", waveSurferStyles.marker);
content.onmouseover = () => {
content.style.backgroundColor =
waveSurferStyles.markerHover.backgroundColor;
content.style.zIndex = "999";
content.style.width = "300px";
};
content.onmouseout = () => {
content.setAttribute("style", waveSurferStyles.marker);
};
content.textContent = topic.title;
const region = waveRegions.addRegion({
start: topic.timestamp,
content,
color: "f00",
drag: false,
});
region.on("click", (e) => {
e.stopPropagation();
setActiveTopic(topic);
wavesurfer?.setTime(region.start);
});
}
};
useEffect(() => {
if (activeTopic) {
wavesurfer?.setTime(activeTopic.timestamp);
}
}, [activeTopic]);
const handlePlayClick = () => {
wavesurfer?.playPause();
};
const timeLabel = () => {
if (props.mediaDuration)
return `${formatTime(currentTime)}/${formatTime(props.mediaDuration)}`;
return "";
};
return (
<div className="flex items-center w-full relative">
<div className="flex-grow items-end relative">
<div
ref={waveformRef}
className="flex-grow rounded-lg md:rounded-xl h-20"
></div>
<div className="absolute right-2 bottom-0">{timeLabel()}</div>
</div>
<button
className={`${
isPlaying
? "bg-orange-400 hover:bg-orange-500 focus-visible:bg-orange-500"
: "bg-green-400 hover:bg-green-500 focus-visible:bg-green-500"
        } text-white ml-2 md:ml-4 md:h-[78px] md:min-w-[100px] text-lg`}
id="play-btn"
onClick={handlePlayClick}
>
{isPlaying ? "Pause" : "Play"}
</button>
</div>
);
}

View File

@@ -6,31 +6,19 @@ import CustomRegionsPlugin from "../../lib/custom-plugins/regions";
 import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
 import { faMicrophone } from "@fortawesome/free-solid-svg-icons";
-import { faDownload } from "@fortawesome/free-solid-svg-icons";
 import { formatTime } from "../../lib/time";
-import { Topic } from "./webSocketTypes";
-import { AudioWaveform } from "../../api";
 import AudioInputsDropdown from "./audioInputsDropdown";
 import { Option } from "react-dropdown";
 import { waveSurferStyles } from "../../styles/recorder";
 import { useError } from "../../(errors)/errorContext";

 type RecorderProps = {
-  setStream?: React.Dispatch<React.SetStateAction<MediaStream | null>>;
-  onStop?: () => void;
-  topics: Topic[];
-  getAudioStream?: (deviceId) => Promise<MediaStream | null>;
-  audioDevices?: Option[];
-  useActiveTopic: [
-    Topic | null,
-    React.Dispatch<React.SetStateAction<Topic | null>>,
-  ];
-  waveform?: AudioWaveform | null;
-  isPastMeeting: boolean;
-  transcriptId?: string | null;
-  media?: HTMLMediaElement | null;
-  mediaDuration?: number | null;
+  setStream: React.Dispatch<React.SetStateAction<MediaStream | null>>;
+  onStop: () => void;
+  onRecord?: () => void;
+  getAudioStream: (deviceId) => Promise<MediaStream | null>;
+  audioDevices: Option[];
 };

 export default function Recorder(props: RecorderProps) {
@@ -38,7 +26,7 @@ export default function Recorder(props: RecorderProps) {
   const [wavesurfer, setWavesurfer] = useState<WaveSurfer | null>(null);
   const [record, setRecord] = useState<RecordPlugin | null>(null);
   const [isRecording, setIsRecording] = useState<boolean>(false);
-  const [hasRecorded, setHasRecorded] = useState<boolean>(props.isPastMeeting);
+  const [hasRecorded, setHasRecorded] = useState<boolean>(false);
   const [isPlaying, setIsPlaying] = useState<boolean>(false);
   const [currentTime, setCurrentTime] = useState<number>(0);
   const [timeInterval, setTimeInterval] = useState<number | null>(null);
@@ -48,8 +36,6 @@ export default function Recorder(props: RecorderProps) {
   );
   const [deviceId, setDeviceId] = useState<string | null>(null);
   const [recordStarted, setRecordStarted] = useState(false);
-  const [activeTopic, setActiveTopic] = props.useActiveTopic;
-  const topicsRef = useRef(props.topics);
   const [showDevices, setShowDevices] = useState(false);
   const { setError } = useError();
@@ -73,8 +59,6 @@ export default function Recorder(props: RecorderProps) {
         if (!record.isRecording()) return;
         handleRecClick();
         break;
-      case "^":
-        throw new Error("Unhandled Exception thrown by '^' shortcut");
       case "(":
         location.href = "/login";
         break;
@@ -104,27 +88,18 @@ export default function Recorder(props: RecorderProps) {
   // Waveform setup
   useEffect(() => {
     if (waveformRef.current) {
-      // XXX duration is required to prevent recomputing peaks from audio
-      // However, the current waveform returns only the peaks, and no duration
-      // And the backend does not save duration properly.
-      // So at the moment, we deduct the duration from the topics.
-      // This is not ideal, but it works for now.
       const _wavesurfer = WaveSurfer.create({
         container: waveformRef.current,
-        peaks: props.waveform?.data,
         hideScrollbar: true,
         autoCenter: true,
         barWidth: 2,
         height: "auto",
-        duration: props.mediaDuration || 1,
         ...waveSurferStyles.player,
       });

-      if (!props.transcriptId) {
-        const _wshack: any = _wavesurfer;
-        _wshack.renderer.renderSingleCanvas = () => {};
-      }
+      const _wshack: any = _wavesurfer;
+      _wshack.renderer.renderSingleCanvas = () => {};

       // styling
       const wsWrapper = _wavesurfer.getWrapper();
@@ -144,12 +119,6 @@ export default function Recorder(props: RecorderProps) {
       setRecord(_wavesurfer.registerPlugin(RecordPlugin.create()));
       setWaveRegions(_wavesurfer.registerPlugin(CustomRegionsPlugin.create()));

-      if (props.isPastMeeting) _wavesurfer.toggleInteraction(true);
-
-      if (props.media) {
-        _wavesurfer.setMediaElement(props.media);
-      }
-
       setWavesurfer(_wavesurfer);

       return () => {
@@ -161,58 +130,6 @@ export default function Recorder(props: RecorderProps) {
     }
   }, []);

-  useEffect(() => {
-    if (!wavesurfer) return;
-    if (!props.media) return;
-    wavesurfer.setMediaElement(props.media);
-  }, [props.media, wavesurfer]);
-
-  useEffect(() => {
-    topicsRef.current = props.topics;
-    if (!isRecording) renderMarkers();
-  }, [props.topics, waveRegions]);
-
-  const renderMarkers = () => {
-    if (!waveRegions) return;
-    waveRegions.clearRegions();
-    for (let topic of topicsRef.current) {
-      const content = document.createElement("div");
-      content.setAttribute("style", waveSurferStyles.marker);
-      content.onmouseover = () => {
-        content.style.backgroundColor =
-          waveSurferStyles.markerHover.backgroundColor;
-        content.style.zIndex = "999";
-        content.style.width = "300px";
-      };
-      content.onmouseout = () => {
-        content.setAttribute("style", waveSurferStyles.marker);
-      };
-      content.textContent = topic.title;
-      const region = waveRegions.addRegion({
-        start: topic.timestamp,
-        content,
-        color: "f00",
-        drag: false,
-      });
-      region.on("click", (e) => {
-        e.stopPropagation();
-        setActiveTopic(topic);
-        wavesurfer?.setTime(region.start);
-      });
-    }
-  };
-
-  useEffect(() => {
-    if (!record) return;
-    return record.on("stopRecording", () => {
-      renderMarkers();
-    });
-  }, [record]);
-
   useEffect(() => {
     if (isRecording) {
       const interval = window.setInterval(() => {
@@ -229,12 +146,6 @@ export default function Recorder(props: RecorderProps) {
     }
   }, [isRecording]);

-  useEffect(() => {
-    if (activeTopic) {
-      wavesurfer?.setTime(activeTopic.timestamp);
-    }
-  }, [activeTopic]);
-
   const handleRecClick = async () => {
     if (!record) return console.log("no record");
@@ -249,10 +160,10 @@ export default function Recorder(props: RecorderProps) {
       setScreenMediaStream(null);
       setDestinationStream(null);
     } else {
+      if (props.onRecord) props.onRecord();
       const stream = await getCurrentStream();
       if (props.setStream) props.setStream(stream);
-      waveRegions?.clearRegions();

       if (stream) {
         await record.startRecording(stream);
         setIsRecording(true);
@@ -320,7 +231,6 @@ export default function Recorder(props: RecorderProps) {
   if (!record) return;
if (!destinationStream) return; if (!destinationStream) return;
if (props.setStream) props.setStream(destinationStream); if (props.setStream) props.setStream(destinationStream);
waveRegions?.clearRegions();
if (destinationStream) { if (destinationStream) {
record.startRecording(destinationStream); record.startRecording(destinationStream);
setIsRecording(true); setIsRecording(true);
@@ -379,23 +289,9 @@ export default function Recorder(props: RecorderProps) {
} text-white ml-2 md:ml:4 md:h-[78px] md:min-w-[100px] text-lg`} } text-white ml-2 md:ml:4 md:h-[78px] md:min-w-[100px] text-lg`}
id="play-btn" id="play-btn"
onClick={handlePlayClick} onClick={handlePlayClick}
disabled={isRecording}
> >
{isPlaying ? "Pause" : "Play"} {isPlaying ? "Pause" : "Play"}
</button> </button>
{props.transcriptId && (
<a
title="Download recording"
className="text-center cursor-pointer text-blue-400 hover:text-blue-700 ml-2 md:ml:4 p-2 rounded-lg outline-blue-400"
download={`recording-${
props.transcriptId?.split("-")[0] || "0000"
}`}
href={`${process.env.NEXT_PUBLIC_API_URL}/v1/transcripts/${props.transcriptId}/audio/mp3`}
>
<FontAwesomeIcon icon={faDownload} className="h-5 w-auto" />
</a>
)}
</> </>
)} )}
{!hasRecorded && ( {!hasRecorded && (
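The removed setup above reflects the XXX note: WaveSurfer needs an explicit duration to render precomputed peaks without re-decoding the audio, and since the backend does not persist one, the old code fell back to the topic timestamps. A minimal sketch of that pattern, assuming `Topic` carries a `timestamp` in seconds and `AudioWaveform.data` holds a single channel of peaks (wavesurfer.js v7 options):

```typescript
import WaveSurfer from "wavesurfer.js";

// Hypothetical shapes mirroring the props removed above.
type Topic = { title: string; timestamp: number };
type AudioWaveform = { data: number[] };

function createPlayerWaveform(
  container: HTMLElement,
  waveform: AudioWaveform | null,
  topics: Topic[],
  mediaDuration?: number | null,
): WaveSurfer {
  // Without a duration, WaveSurfer re-decodes the audio to compute peaks,
  // so fall back to the last topic timestamp when none was saved.
  const duration =
    mediaDuration ?? Math.max(1, ...topics.map((t) => t.timestamp));
  return WaveSurfer.create({
    container,
    peaks: waveform ? [waveform.data] : undefined,
    duration,
    hideScrollbar: true,
    autoCenter: true,
    barWidth: 2,
  });
}
```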
View File
@@ -1,15 +1,39 @@
import React, { useState, useRef, useEffect, use } from "react"; import React, { useState, useRef, useEffect, use } from "react";
import { featureEnabled } from "../domainContext"; import { featureEnabled } from "../domainContext";
import getApi from "../../lib/getApi";
import { useFiefUserinfo } from "@fief/fief/nextjs/react";
import SelectSearch from "react-select-search";
import "react-select-search/style.css";
import "../../styles/button.css";
import "../../styles/form.scss";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
import { faSpinner } from "@fortawesome/free-solid-svg-icons";
const ShareLink = () => { type ShareLinkProps = {
transcriptId: string;
userId: string | null;
shareMode: string;
};
const ShareLink = (props: ShareLinkProps) => {
const [isCopied, setIsCopied] = useState(false); const [isCopied, setIsCopied] = useState(false);
const inputRef = useRef<HTMLInputElement>(null); const inputRef = useRef<HTMLInputElement>(null);
const [currentUrl, setCurrentUrl] = useState<string>(""); const [currentUrl, setCurrentUrl] = useState<string>("");
const requireLogin = featureEnabled("requireLogin");
const [isOwner, setIsOwner] = useState(false);
const [shareMode, setShareMode] = useState(props.shareMode);
const [shareLoading, setShareLoading] = useState(false);
const userinfo = useFiefUserinfo();
const api = getApi();
useEffect(() => { useEffect(() => {
setCurrentUrl(window.location.href); setCurrentUrl(window.location.href);
}, []); }, []);
useEffect(() => {
setIsOwner(!!(requireLogin && userinfo?.sub === props.userId));
}, [userinfo, props.userId]);
const handleCopyClick = () => { const handleCopyClick = () => {
if (inputRef.current) { if (inputRef.current) {
let text_to_copy = inputRef.current.value; let text_to_copy = inputRef.current.value;
@@ -23,6 +47,18 @@ const ShareLink = () => {
} }
}; };
const updateShareMode = async (selectedShareMode: string) => {
if (!api) return;
setShareLoading(true);
const updatedTranscript = await api.v1TranscriptUpdate({
transcriptId: props.transcriptId,
updateTranscript: {
shareMode: selectedShareMode,
},
});
setShareMode(updatedTranscript.shareMode);
setShareLoading(false);
};
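The new handler persists the mode before reflecting it locally; since the generated client rejects on non-2xx responses, a `try`/`finally` keeps the spinner from sticking if the update fails. A hedged drop-in variant (the error handling is an addition, not part of this commit):

```typescript
const updateShareMode = async (selectedShareMode: string) => {
  if (!api) return;
  setShareLoading(true);
  try {
    const updatedTranscript = await api.v1TranscriptUpdate({
      transcriptId: props.transcriptId,
      updateTranscript: { shareMode: selectedShareMode },
    });
    setShareMode(updatedTranscript.shareMode);
  } catch (e) {
    // Keep the previous mode on failure; wiring this into the error
    // context is an option if visible feedback is wanted.
    console.error("share mode update failed", e);
  } finally {
    setShareLoading(false);
  }
};
```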
const privacyEnabled = featureEnabled("privacy"); const privacyEnabled = featureEnabled("privacy");
return ( return (
@@ -30,17 +66,60 @@ const ShareLink = () => {
className="p-2 md:p-4 rounded" className="p-2 md:p-4 rounded"
style={{ background: "rgba(96, 165, 250, 0.2)" }} style={{ background: "rgba(96, 165, 250, 0.2)" }}
> >
{privacyEnabled ? ( {requireLogin && (
<p className="text-sm mb-2"> <div className="text-sm mb-2">
You can share this link with others. Anyone with the link will have {shareMode === "private" && (
access to the page, including the full audio recording, for the next 7 <p>This transcript is private and can only be accessed by you.</p>
days. )}
</p> {shareMode === "semi-private" && (
) : ( <p>
<p className="text-sm mb-2"> This transcript is secure. Only authenticated users can access it.
You can share this link with others. Anyone with the link will have </p>
access to the page, including the full audio recording. )}
</p> {shareMode === "public" && (
<p>This transcript is public. Everyone can access it.</p>
)}
{isOwner && api && (
<div className="relative">
<SelectSearch
className="select-search--top select-search"
options={[
{ name: "Private", value: "private" },
{ name: "Secure", value: "semi-private" },
{ name: "Public", value: "public" },
]}
value={shareMode}
onChange={updateShareMode}
closeOnSelect={true}
/>
{shareLoading && (
<div className="h-4 w-4 absolute top-1/3 right-3 z-10">
<FontAwesomeIcon
icon={faSpinner}
className="animate-spin-slow text-gray-600 flex-grow rounded-lg md:rounded-xl h-4 w-4"
/>
</div>
)}
</div>
)}
</div>
)}
{!requireLogin && (
<>
{privacyEnabled ? (
<p className="text-sm mb-2">
Share this link to grant others access to this page. The link
includes the full audio recording and is valid for the next 7
days.
</p>
) : (
<p className="text-sm mb-2">
Share this link to allow others to view this page and listen to
the full audio recording.
</p>
)}
</>
)} )}
<div className="flex items-center"> <div className="flex items-center">
<input <input
View File
@@ -2,7 +2,6 @@ import { useState } from "react";
import getApi from "../../lib/getApi"; import getApi from "../../lib/getApi";
type TranscriptTitle = { type TranscriptTitle = {
protectedPath: boolean;
title: string; title: string;
transcriptId: string; transcriptId: string;
}; };
@@ -11,7 +10,7 @@ const TranscriptTitle = (props: TranscriptTitle) => {
const [displayedTitle, setDisplayedTitle] = useState(props.title); const [displayedTitle, setDisplayedTitle] = useState(props.title);
const [preEditTitle, setPreEditTitle] = useState(props.title); const [preEditTitle, setPreEditTitle] = useState(props.title);
const [isEditing, setIsEditing] = useState(false); const [isEditing, setIsEditing] = useState(false);
const api = getApi(props.protectedPath); const api = getApi();
const updateTitle = async (newTitle: string, transcriptId: string) => { const updateTitle = async (newTitle: string, transcriptId: string) => {
if (!api) return; if (!api) return;
View File
@@ -1,49 +1,48 @@
import { useContext, useEffect, useState } from "react"; import { useContext, useEffect, useState } from "react";
import { useError } from "../../(errors)/errorContext";
import { DomainContext } from "../domainContext"; import { DomainContext } from "../domainContext";
import getApi from "../../lib/getApi"; import getApi from "../../lib/getApi";
import { useFiefAccessTokenInfo } from "@fief/fief/build/esm/nextjs/react"; import { useFiefAccessTokenInfo } from "@fief/fief/build/esm/nextjs/react";
import { shouldShowError } from "../../lib/errorUtils";
type Mp3Response = { export type Mp3Response = {
url: string | null;
media: HTMLMediaElement | null; media: HTMLMediaElement | null;
loading: boolean; loading: boolean;
error: Error | null; getNow: () => void;
}; };
const useMp3 = (protectedPath: boolean, id: string): Mp3Response => { const useMp3 = (id: string, waiting?: boolean): Mp3Response => {
const [url, setUrl] = useState<string | null>(null);
const [media, setMedia] = useState<HTMLMediaElement | null>(null); const [media, setMedia] = useState<HTMLMediaElement | null>(null);
const [later, setLater] = useState(waiting);
const [loading, setLoading] = useState<boolean>(false); const [loading, setLoading] = useState<boolean>(false);
const [error, setErrorState] = useState<Error | null>(null); const api = getApi();
const { setError } = useError();
const api = getApi(protectedPath);
const { api_url } = useContext(DomainContext); const { api_url } = useContext(DomainContext);
const accessTokenInfo = useFiefAccessTokenInfo(); const accessTokenInfo = useFiefAccessTokenInfo();
const [serviceWorkerReady, setServiceWorkerReady] = useState(false); const [serviceWorker, setServiceWorker] =
useState<ServiceWorkerRegistration | null>(null);
useEffect(() => { useEffect(() => {
if ("serviceWorker" in navigator) { if ("serviceWorker" in navigator) {
navigator.serviceWorker.register("/service-worker.js").then(() => { navigator.serviceWorker.register("/service-worker.js").then((worker) => {
setServiceWorkerReady(true); setServiceWorker(worker);
}); });
} }
return () => {
serviceWorker?.unregister();
};
}, []); }, []);
useEffect(() => { useEffect(() => {
if (!navigator.serviceWorker) return; if (!navigator.serviceWorker) return;
if (!navigator.serviceWorker.controller) return; if (!navigator.serviceWorker.controller) return;
if (!serviceWorkerReady) return; if (!serviceWorker) return;
// Send the token to the service worker // Send the token to the service worker
navigator.serviceWorker.controller.postMessage({ navigator.serviceWorker.controller.postMessage({
type: "SET_AUTH_TOKEN", type: "SET_AUTH_TOKEN",
token: accessTokenInfo?.access_token, token: accessTokenInfo?.access_token,
}); });
}, [navigator.serviceWorker, serviceWorkerReady, accessTokenInfo]); }, [navigator.serviceWorker, !serviceWorker, accessTokenInfo]);
const getMp3 = (id: string) => { useEffect(() => {
if (!id || !api) return; if (!id || !api || later) return;
// create an audio element and set the source // create an audio element and set the source
setLoading(true); setLoading(true);
@@ -53,13 +52,13 @@ const useMp3 = (protectedPath: boolean, id: string): Mp3Response => {
audioElement.preload = "auto"; audioElement.preload = "auto";
setMedia(audioElement); setMedia(audioElement);
setLoading(false); setLoading(false);
}, [id, api, later]);
const getNow = () => {
setLater(false);
}; };
useEffect(() => { return { media, loading, getNow };
getMp3(id);
}, [id, api]);
return { url, media, loading, error };
}; };
export default useMp3; export default useMp3;
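useMp3 posts the Fief access token to a service worker so audio requests served through `<audio>` elements can carry it; the worker itself is not part of this diff. A minimal sketch of the counterpart it implies — the `SET_AUTH_TOKEN` message type comes from the hook above, while the path filter and header name are assumptions:

```typescript
// service-worker.ts (hypothetical; built to /service-worker.js)
let authToken: string | undefined;

self.addEventListener("message", (event: any) => {
  if (event.data?.type === "SET_AUTH_TOKEN") {
    authToken = event.data.token;
  }
});

self.addEventListener("fetch", (event: any) => {
  const url = new URL(event.request.url);
  // Only decorate authenticated audio requests (filter is an assumption).
  if (!authToken || !url.pathname.includes("/audio/")) return;
  const headers = new Headers(event.request.headers);
  headers.set("Authorization", `Bearer ${authToken}`);
  event.respondWith(fetch(new Request(event.request, { headers })));
});
```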
View File
@@ -14,12 +14,12 @@ type TranscriptTopics = {
error: Error | null; error: Error | null;
}; };
const useTopics = (protectedPath, id: string): TranscriptTopics => { const useTopics = (id: string): TranscriptTopics => {
const [topics, setTopics] = useState<Topic[] | null>(null); const [topics, setTopics] = useState<Topic[] | null>(null);
const [loading, setLoading] = useState<boolean>(false); const [loading, setLoading] = useState<boolean>(false);
const [error, setErrorState] = useState<Error | null>(null); const [error, setErrorState] = useState<Error | null>(null);
const { setError } = useError(); const { setError } = useError();
const api = getApi(protectedPath); const api = getApi();
useEffect(() => { useEffect(() => {
if (!id || !api) return; if (!id || !api) return;
View File
@@ -5,21 +5,32 @@ import { useError } from "../../(errors)/errorContext";
import getApi from "../../lib/getApi"; import getApi from "../../lib/getApi";
import { shouldShowError } from "../../lib/errorUtils"; import { shouldShowError } from "../../lib/errorUtils";
type Transcript = { type ErrorTranscript = {
response: GetTranscript | null; error: Error;
loading: boolean; loading: false;
error: Error | null; response: any;
};
type LoadingTranscript = {
response: any;
loading: true;
error: false;
};
type SuccessTranscript = {
response: GetTranscript;
loading: false;
error: null;
}; };
const useTranscript = ( const useTranscript = (
protectedPath: boolean,
id: string | null, id: string | null,
): Transcript => { ): ErrorTranscript | LoadingTranscript | SuccessTranscript => {
const [response, setResponse] = useState<GetTranscript | null>(null); const [response, setResponse] = useState<GetTranscript | null>(null);
const [loading, setLoading] = useState<boolean>(true); const [loading, setLoading] = useState<boolean>(true);
const [error, setErrorState] = useState<Error | null>(null); const [error, setErrorState] = useState<Error | null>(null);
const { setError } = useError(); const { setError } = useError();
const api = getApi(protectedPath); const api = getApi();
useEffect(() => { useEffect(() => {
if (!id || !api) return; if (!id || !api) return;
@@ -46,7 +57,10 @@ const useTranscript = (
}); });
}, [id, !api]); }, [id, !api]);
return { response, loading, error }; return { response, loading, error } as
| ErrorTranscript
| LoadingTranscript
| SuccessTranscript;
}; };
export default useTranscript; export default useTranscript;
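Returning a discriminated union instead of three nullable fields lets callers narrow on `loading` and `error` rather than null-checking `response`. A usage sketch inside a component (field names taken from the generated model):

```typescript
const transcript = useTranscript(transcriptId);

if (transcript.loading) return <Spinner />;
if (transcript.error) return null; // already reported via the error context

// TypeScript now knows transcript.response is a GetTranscript.
const { name, shareMode } = transcript.response;
```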
View File
@@ -15,7 +15,7 @@ const useTranscriptList = (page: number): TranscriptList => {
const [loading, setLoading] = useState<boolean>(true); const [loading, setLoading] = useState<boolean>(true);
const [error, setErrorState] = useState<Error | null>(null); const [error, setErrorState] = useState<Error | null>(null);
const { setError } = useError(); const { setError } = useError();
const api = getApi(true); const api = getApi();
useEffect(() => { useEffect(() => {
if (!api) return; if (!api) return;
View File
@@ -1,8 +1,5 @@
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { import { V1TranscriptGetAudioWaveformRequest } from "../../api/apis/DefaultApi";
DefaultApi,
V1TranscriptGetAudioWaveformRequest,
} from "../../api/apis/DefaultApi";
import { AudioWaveform } from "../../api"; import { AudioWaveform } from "../../api";
import { useError } from "../../(errors)/errorContext"; import { useError } from "../../(errors)/errorContext";
import getApi from "../../lib/getApi"; import getApi from "../../lib/getApi";
@@ -14,12 +11,12 @@ type AudioWaveFormResponse = {
error: Error | null; error: Error | null;
}; };
const useWaveform = (protectedPath, id: string): AudioWaveFormResponse => { const useWaveform = (id: string): AudioWaveFormResponse => {
const [waveform, setWaveform] = useState<AudioWaveform | null>(null); const [waveform, setWaveform] = useState<AudioWaveform | null>(null);
const [loading, setLoading] = useState<boolean>(true); const [loading, setLoading] = useState<boolean>(true);
const [error, setErrorState] = useState<Error | null>(null); const [error, setErrorState] = useState<Error | null>(null);
const { setError } = useError(); const { setError } = useError();
const api = getApi(protectedPath); const api = getApi();
useEffect(() => { useEffect(() => {
if (!id || !api) return; if (!id || !api) return;
View File
@@ -10,11 +10,10 @@ import getApi from "../../lib/getApi";
const useWebRTC = ( const useWebRTC = (
stream: MediaStream | null, stream: MediaStream | null,
transcriptId: string | null, transcriptId: string | null,
protectedPath,
): Peer => { ): Peer => {
const [peer, setPeer] = useState<Peer | null>(null); const [peer, setPeer] = useState<Peer | null>(null);
const { setError } = useError(); const { setError } = useError();
const api = getApi(protectedPath); const api = getApi();
useEffect(() => { useEffect(() => {
if (!stream || !transcriptId) { if (!stream || !transcriptId) {
View File
@@ -1,30 +1,35 @@
import { useContext, useEffect, useState } from "react"; import { useContext, useEffect, useState } from "react";
import { Topic, FinalSummary, Status } from "./webSocketTypes"; import { Topic, FinalSummary, Status } from "./webSocketTypes";
import { useError } from "../../(errors)/errorContext"; import { useError } from "../../(errors)/errorContext";
import { useRouter } from "next/navigation";
import { DomainContext } from "../domainContext"; import { DomainContext } from "../domainContext";
import { AudioWaveform } from "../../api";
type UseWebSockets = { export type UseWebSockets = {
transcriptText: string; transcriptText: string;
translateText: string; translateText: string;
title: string;
topics: Topic[]; topics: Topic[];
finalSummary: FinalSummary; finalSummary: FinalSummary;
status: Status; status: Status;
waveform: AudioWaveform["data"] | null;
duration: number | null;
}; };
export const useWebSockets = (transcriptId: string | null): UseWebSockets => { export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
const [transcriptText, setTranscriptText] = useState<string>(""); const [transcriptText, setTranscriptText] = useState<string>("");
const [translateText, setTranslateText] = useState<string>(""); const [translateText, setTranslateText] = useState<string>("");
const [title, setTitle] = useState<string>("");
const [textQueue, setTextQueue] = useState<string[]>([]); const [textQueue, setTextQueue] = useState<string[]>([]);
const [translationQueue, setTranslationQueue] = useState<string[]>([]); const [translationQueue, setTranslationQueue] = useState<string[]>([]);
const [isProcessing, setIsProcessing] = useState(false); const [isProcessing, setIsProcessing] = useState(false);
const [topics, setTopics] = useState<Topic[]>([]); const [topics, setTopics] = useState<Topic[]>([]);
const [waveform, setWaveForm] = useState<AudioWaveform | null>(null);
const [duration, setDuration] = useState<number | null>(null);
const [finalSummary, setFinalSummary] = useState<FinalSummary>({ const [finalSummary, setFinalSummary] = useState<FinalSummary>({
summary: "", summary: "",
}); });
const [status, setStatus] = useState<Status>({ value: "initial" }); const [status, setStatus] = useState<Status>({ value: "initial" });
const { setError } = useError(); const { setError } = useError();
const router = useRouter();
const { websocket_url } = useContext(DomainContext); const { websocket_url } = useContext(DomainContext);
@@ -294,7 +299,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
if (!transcriptId) return; if (!transcriptId) return;
const url = `${websocket_url}/v1/transcripts/${transcriptId}/events`; const url = `${websocket_url}/v1/transcripts/${transcriptId}/events`;
const ws = new WebSocket(url); let ws = new WebSocket(url);
ws.onopen = () => { ws.onopen = () => {
console.debug("WebSocket connection opened"); console.debug("WebSocket connection opened");
@@ -343,24 +348,39 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
case "FINAL_TITLE": case "FINAL_TITLE":
console.debug("FINAL_TITLE event:", message.data); console.debug("FINAL_TITLE event:", message.data);
if (message.data) {
setTitle(message.data.title);
}
break;
case "WAVEFORM":
console.debug(
"WAVEFORM event length:",
message.data.waveform.length,
);
if (message.data) {
setWaveForm(message.data.waveform);
}
break;
case "DURATION":
console.debug("DURATION event:", message.data);
if (message.data) {
setDuration(message.data.duration);
}
break; break;
case "STATUS": case "STATUS":
console.log("STATUS event:", message.data); console.log("STATUS event:", message.data);
if (message.data.value === "ended") {
const newUrl = "/transcripts/" + transcriptId;
router.push(newUrl);
console.debug("FINAL_LONG_SUMMARY event:", message.data);
}
if (message.data.value === "error") { if (message.data.value === "error") {
const newUrl = "/transcripts/" + transcriptId;
router.push(newUrl);
setError( setError(
Error("Websocket error status"), Error("Websocket error status"),
"There was an error processing this meeting.", "There was an error processing this meeting.",
); );
} }
setStatus(message.data); setStatus(message.data);
if (message.data.value === "ended") {
ws.close();
}
break; break;
default: default:
@@ -382,13 +402,19 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
console.debug("WebSocket connection closed"); console.debug("WebSocket connection closed");
switch (event.code) { switch (event.code) {
case 1000: // Normal Closure: case 1000: // Normal Closure:
case 1001: // Going Away: case 1005: // Closure by client FF
case 1005:
break;
default: default:
setError( setError(
new Error(`WebSocket closed unexpectedly with code: ${event.code}`), new Error(`WebSocket closed unexpectedly with code: ${event.code}`),
"Disconnected",
); );
console.log(
"Socket is closed. Reconnect will be attempted in 1 second.",
event.reason,
);
setTimeout(function () {
ws = new WebSocket(url);
}, 1000);
} }
}; };
@@ -397,5 +423,14 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
}; };
}, [transcriptId]); }, [transcriptId]);
return { transcriptText, translateText, topics, finalSummary, status }; return {
transcriptText,
translateText,
topics,
finalSummary,
title,
status,
waveform,
duration,
};
}; };
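One caveat with the retry above: `ws = new WebSocket(url)` opens a fresh socket but does not re-attach the `onopen`/`onmessage`/`onclose` handlers defined on the first one. A common pattern (a sketch, not the code in this commit) wraps the setup in a `connect()` that the retry timer re-runs:

```typescript
function openEventsSocket(url: string, onEvent: (e: MessageEvent) => void) {
  let ws: WebSocket;

  const connect = () => {
    ws = new WebSocket(url);
    ws.onmessage = onEvent;
    ws.onclose = (event) => {
      // 1000: normal closure; 1005: no status, e.g. client close in Firefox
      if (event.code === 1000 || event.code === 1005) return;
      // Re-running connect() re-attaches every handler on the new socket.
      setTimeout(connect, 1000);
    };
  };

  connect();
  return () => ws.close(1000);
}
```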
View File
@@ -0,0 +1,11 @@
import { faSpinner } from "@fortawesome/free-solid-svg-icons";
import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
export default () => (
<div className="flex flex-grow items-center justify-center h-20">
<FontAwesomeIcon
icon={faSpinner}
className="animate-spin-slow text-gray-600 flex-grow rounded-lg md:rounded-xl h-10 w-10"
/>
</div>
);
View File
@@ -25,6 +25,12 @@ export interface GetTranscript {
* @memberof GetTranscript * @memberof GetTranscript
*/ */
id: any | null; id: any | null;
/**
*
* @type {any}
* @memberof GetTranscript
*/
userId: any | null;
/** /**
* *
* @type {any} * @type {any}
@@ -73,6 +79,12 @@ export interface GetTranscript {
* @memberof GetTranscript * @memberof GetTranscript
*/ */
createdAt: any | null; createdAt: any | null;
/**
*
* @type {any}
* @memberof GetTranscript
*/
shareMode?: any | null;
/** /**
* *
* @type {any} * @type {any}
@@ -93,6 +105,7 @@ export interface GetTranscript {
export function instanceOfGetTranscript(value: object): boolean { export function instanceOfGetTranscript(value: object): boolean {
let isInstance = true; let isInstance = true;
isInstance = isInstance && "id" in value; isInstance = isInstance && "id" in value;
isInstance = isInstance && "userId" in value;
isInstance = isInstance && "name" in value; isInstance = isInstance && "name" in value;
isInstance = isInstance && "status" in value; isInstance = isInstance && "status" in value;
isInstance = isInstance && "locked" in value; isInstance = isInstance && "locked" in value;
@@ -120,6 +133,7 @@ export function GetTranscriptFromJSONTyped(
} }
return { return {
id: json["id"], id: json["id"],
userId: json["user_id"],
name: json["name"], name: json["name"],
status: json["status"], status: json["status"],
locked: json["locked"], locked: json["locked"],
@@ -128,6 +142,7 @@ export function GetTranscriptFromJSONTyped(
shortSummary: json["short_summary"], shortSummary: json["short_summary"],
longSummary: json["long_summary"], longSummary: json["long_summary"],
createdAt: json["created_at"], createdAt: json["created_at"],
shareMode: !exists(json, "share_mode") ? undefined : json["share_mode"],
sourceLanguage: json["source_language"], sourceLanguage: json["source_language"],
targetLanguage: json["target_language"], targetLanguage: json["target_language"],
}; };
@@ -142,6 +157,7 @@ export function GetTranscriptToJSON(value?: GetTranscript | null): any {
} }
return { return {
id: value.id, id: value.id,
user_id: value.userId,
name: value.name, name: value.name,
status: value.status, status: value.status,
locked: value.locked, locked: value.locked,
@@ -150,6 +166,7 @@ export function GetTranscriptToJSON(value?: GetTranscript | null): any {
short_summary: value.shortSummary, short_summary: value.shortSummary,
long_summary: value.longSummary, long_summary: value.longSummary,
created_at: value.createdAt, created_at: value.createdAt,
share_mode: value.shareMode,
source_language: value.sourceLanguage, source_language: value.sourceLanguage,
target_language: value.targetLanguage, target_language: value.targetLanguage,
}; };
View File
@@ -49,6 +49,12 @@ export interface UpdateTranscript {
* @memberof UpdateTranscript * @memberof UpdateTranscript
*/ */
longSummary?: any | null; longSummary?: any | null;
/**
*
* @type {any}
* @memberof UpdateTranscript
*/
shareMode?: any | null;
} }
/** /**
@@ -81,6 +87,7 @@ export function UpdateTranscriptFromJSONTyped(
longSummary: !exists(json, "long_summary") longSummary: !exists(json, "long_summary")
? undefined ? undefined
: json["long_summary"], : json["long_summary"],
shareMode: !exists(json, "share_mode") ? undefined : json["share_mode"],
}; };
} }
@@ -97,5 +104,6 @@ export function UpdateTranscriptToJSON(value?: UpdateTranscript | null): any {
title: value.title, title: value.title,
short_summary: value.shortSummary, short_summary: value.shortSummary,
long_summary: value.longSummary, long_summary: value.longSummary,
share_mode: value.shareMode,
}; };
} }
View File
@@ -1,5 +1,8 @@
function shouldShowError(error: Error | null | undefined) { function shouldShowError(error: Error | null | undefined) {
if (error?.name == "ResponseError" && error["response"].status == 404) if (
error?.name == "ResponseError" &&
(error["response"].status == 404 || error["response"].status == 403)
)
return false; return false;
if (error?.name == "FetchError") return false; if (error?.name == "FetchError") return false;
return true; return true;
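With 403 added to the suppressed statuses, the data hooks keep reporting unexpected failures while staying quiet on access-denied, not-found, and plain network errors. A call-site sketch (the `v1TranscriptGet` name is assumed from the generator's naming pattern):

```typescript
import { shouldShowError } from "./errorUtils";

async function loadTranscript(api: DefaultApi, transcriptId: string) {
  try {
    return await api.v1TranscriptGet({ transcriptId });
  } catch (e) {
    if (shouldShowError(e as Error)) throw e; // surface unexpected errors
    return null; // 403/404/network: treat as "nothing to show"
  }
}
```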
View File
@@ -66,10 +66,6 @@ export const getFiefAuthMiddleware = async (url) => {
matcher: "/transcripts", matcher: "/transcripts",
parameters: {}, parameters: {},
}, },
{
matcher: "/transcripts/((?!new).*)",
parameters: {},
},
{ {
matcher: "/browse", matcher: "/browse",
parameters: {}, parameters: {},
View File
@@ -4,17 +4,19 @@ import { DefaultApi } from "../api/apis/DefaultApi";
import { useFiefAccessTokenInfo } from "@fief/fief/nextjs/react"; import { useFiefAccessTokenInfo } from "@fief/fief/nextjs/react";
import { useContext, useEffect, useState } from "react"; import { useContext, useEffect, useState } from "react";
import { DomainContext, featureEnabled } from "../[domain]/domainContext"; import { DomainContext, featureEnabled } from "../[domain]/domainContext";
import { CookieContext } from "../(auth)/fiefWrapper";
export default function getApi(protectedPath: boolean): DefaultApi | undefined { export default function getApi(): DefaultApi | undefined {
const accessTokenInfo = useFiefAccessTokenInfo(); const accessTokenInfo = useFiefAccessTokenInfo();
const api_url = useContext(DomainContext).api_url; const api_url = useContext(DomainContext).api_url;
const requireLogin = featureEnabled("requireLogin"); const requireLogin = featureEnabled("requireLogin");
const [api, setApi] = useState<DefaultApi>(); const [api, setApi] = useState<DefaultApi>();
const { hasAuthCookie } = useContext(CookieContext);
if (!api_url) throw new Error("no API URL"); if (!api_url) throw new Error("no API URL");
useEffect(() => { useEffect(() => {
if (protectedPath && requireLogin && !accessTokenInfo) { if (hasAuthCookie && requireLogin && !accessTokenInfo) {
return; return;
} }
@@ -25,7 +27,7 @@ export default function getApi(protectedPath: boolean): DefaultApi | undefined {
: undefined, : undefined,
}); });
setApi(new DefaultApi(apiConfiguration)); setApi(new DefaultApi(apiConfiguration));
}, [!accessTokenInfo, protectedPath]); }, [!accessTokenInfo, hasAuthCookie]);
return api; return api;
} }
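For reference, a sketch of the effect the hunk abbreviates, assuming the generated `Configuration` takes `basePath` and `accessToken` (the standard shape for openapi-generator's typescript-fetch output):

```typescript
useEffect(() => {
  // Wait for the token before building an authenticated client.
  if (hasAuthCookie && requireLogin && !accessTokenInfo) return;

  const apiConfiguration = new Configuration({
    basePath: api_url,
    accessToken: requireLogin ? accessTokenInfo?.access_token : undefined,
  });
  setApi(new DefaultApi(apiConfiguration));
}, [!accessTokenInfo, hasAuthCookie]);
```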
View File
@@ -35,3 +35,8 @@ body.is-light-mode .input-container {
max-width: 100%; max-width: 100%;
width: auto; width: auto;
} }
body .select-search-container .select-search--top.select-search-select {
top: auto;
bottom: 46px;
}