diff --git a/README.md b/README.md
index 627de235..b18264c0 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,7 @@ It also uses https://github.com/fief-dev for authentication, and Vercel for depl
 - [OpenAPI Code Generation](#openapi-code-generation)
 - [Back-End](#back-end)
   - [Installation](#installation-1)
-  - [Start the project](#start-the-project)
+  - [Start the API/Backend](#start-the-apibackend)
   - [Using docker](#using-docker)
   - [Using local GPT4All](#using-local-gpt4all)
   - [Using local files](#using-local-files)
@@ -133,15 +133,15 @@ TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
 ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run
 ```
 
-### Start the project
+### Start the API/Backend
 
-Use:
+Start the API server:
 
 ```bash
 poetry run python3 -m reflector.app
 ```
 
-And start the background worker
+Start the background worker:
 
 ```bash
 celery -A reflector.worker.app worker --loglevel=info
@@ -153,6 +153,12 @@ Redis:
 TODO
 ```
 
+For scheduled tasks (currently only the healthcheck), start Celery beat (not needed in a local development environment):
+
+```bash
+celery -A reflector.worker.app beat
+```
+
 #### Using docker
 
 Use:
diff --git a/server/gpu/modal/reflector_diarizer.py b/server/gpu/modal/reflector_diarizer.py
new file mode 100644
index 00000000..b1989a11
--- /dev/null
+++ b/server/gpu/modal/reflector_diarizer.py
@@ -0,0 +1,190 @@
+"""
+Reflector GPU backend - diarizer
+===================================
+"""
+
+import os
+
+import modal.gpu
+from modal import Image, Secret, Stub, asgi_app, method
+from pydantic import BaseModel
+
+PYANNOTE_MODEL_NAME: str = "pyannote/speaker-diarization-3.0"
+MODEL_DIR = "/root/diarization_models"
+
+stub = Stub(name="reflector-diarizer")
+
+
+def migrate_cache_llm():
+    """
+    XXX The cache for model files in Transformers v4.22.0 has been updated.
+    Migrating your old cache. This is a one-time only operation. You can
+    interrupt this and resume the migration later on by calling
+    `transformers.utils.move_cache()`.
+ """ + from transformers.utils.hub import move_cache + + print("Moving LLM cache") + move_cache(cache_dir=MODEL_DIR, new_cache_dir=MODEL_DIR) + print("LLM cache moved") + + +def download_pyannote_audio(): + from pyannote.audio import Pipeline + Pipeline.from_pretrained( + "pyannote/speaker-diarization-3.0", + cache_dir=MODEL_DIR, + use_auth_token="***REMOVED***" + ) + + +diarizer_image = ( + Image.debian_slim(python_version="3.10.8") + .pip_install( + "pyannote.audio", + "requests", + "onnx", + "torchaudio", + "onnxruntime-gpu", + "torch==2.0.0", + "transformers==4.34.0", + "sentencepiece", + "protobuf", + "numpy", + "huggingface_hub", + "hf-transfer" + ) + .run_function(migrate_cache_llm) + .run_function(download_pyannote_audio) + .env( + { + "LD_LIBRARY_PATH": ( + "/usr/local/lib/python3.10/site-packages/nvidia/cudnn/lib/:" + "/opt/conda/lib/python3.10/site-packages/nvidia/cublas/lib/" + ) + } + ) +) + + +@stub.cls( + gpu=modal.gpu.A100(memory=40), + timeout=60 * 30, + container_idle_timeout=60, + allow_concurrent_inputs=1, + image=diarizer_image, +) +class Diarizer: + def __enter__(self): + import torch + from pyannote.audio import Pipeline + + self.use_gpu = torch.cuda.is_available() + self.device = "cuda" if self.use_gpu else "cpu" + self.diarization_pipeline = Pipeline.from_pretrained( + "pyannote/speaker-diarization-3.0", + cache_dir=MODEL_DIR + ) + self.diarization_pipeline.to(torch.device(self.device)) + + @method() + def diarize( + self, + audio_data: str, + audio_suffix: str, + timestamp: float + ): + import tempfile + + import torchaudio + + with tempfile.NamedTemporaryFile("wb+", suffix=f".{audio_suffix}") as fp: + fp.write(audio_data) + + print("Diarizing audio") + waveform, sample_rate = torchaudio.load(fp.name) + diarization = self.diarization_pipeline({"waveform": waveform, "sample_rate": sample_rate}) + + words = [] + for diarization_segment, _, speaker in diarization.itertracks(yield_label=True): + words.append( + { + "start": round(timestamp + diarization_segment.start, 3), + "end": round(timestamp + diarization_segment.end, 3), + "speaker": int(speaker[-2:]) + } + ) + print("Diarization complete") + return { + "diarization": words + } + +# ------------------------------------------------------------------- +# Web API +# ------------------------------------------------------------------- + + +@stub.function( + timeout=60 * 10, + container_idle_timeout=60 * 3, + allow_concurrent_inputs=40, + secrets=[ + Secret.from_name("reflector-gpu"), + ], + image=diarizer_image +) +@asgi_app() +def web(): + import requests + from fastapi import Depends, FastAPI, HTTPException, status + from fastapi.security import OAuth2PasswordBearer + + diarizerstub = Diarizer() + + app = FastAPI() + + oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + + def apikey_auth(apikey: str = Depends(oauth2_scheme)): + if apikey != os.environ["REFLECTOR_GPU_APIKEY"]: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid API key", + headers={"WWW-Authenticate": "Bearer"}, + ) + + def validate_audio_file(audio_file_url: str): + # Check if the audio file exists + response = requests.head(audio_file_url, allow_redirects=True) + if response.status_code == 404: + raise HTTPException( + status_code=response.status_code, + detail="The audio file does not exist." 
+            )
+
+    class DiarizationResponse(BaseModel):
+        diarization: list[dict]
+
+    @app.post("/diarize", dependencies=[Depends(apikey_auth), Depends(validate_audio_file)])
+    def diarize(
+        audio_file_url: str,
+        timestamp: float = 0.0
+    ) -> DiarizationResponse:
+        # Currently the uploaded files are in mp3 format
+        audio_suffix = "mp3"
+
+        print("Downloading audio file")
+        response = requests.get(audio_file_url, allow_redirects=True)
+        print("Audio file downloaded successfully")
+
+        func = diarizerstub.diarize.spawn(
+            audio_data=response.content,
+            audio_suffix=audio_suffix,
+            timestamp=timestamp
+        )
+        result = func.get()
+        return result
+
+    return app
diff --git a/server/migrations/versions/0fea6d96b096_add_share_mode.py b/server/migrations/versions/0fea6d96b096_add_share_mode.py
new file mode 100644
index 00000000..48746c3b
--- /dev/null
+++ b/server/migrations/versions/0fea6d96b096_add_share_mode.py
@@ -0,0 +1,33 @@
+"""add share_mode
+
+Revision ID: 0fea6d96b096
+Revises: f819277e5169
+Create Date: 2023-11-07 11:12:21.614198
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "0fea6d96b096"
+down_revision: Union[str, None] = "f819277e5169"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column(
+        "transcript",
+        sa.Column("share_mode", sa.String(), server_default="private", nullable=False),
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column("transcript", "share_mode")
+    # ### end Alembic commands ###
diff --git a/server/migrations/versions/125031f7cb78_participants.py b/server/migrations/versions/125031f7cb78_participants.py
new file mode 100644
index 00000000..c345b083
--- /dev/null
+++ b/server/migrations/versions/125031f7cb78_participants.py
@@ -0,0 +1,30 @@
+"""participants
+
+Revision ID: 125031f7cb78
+Revises: 0fea6d96b096
+Create Date: 2023-11-30 15:56:03.341466
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "125031f7cb78"
+down_revision: Union[str, None] = "0fea6d96b096"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column("transcript", sa.Column("participants", sa.JSON(), nullable=True))
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column("transcript", "participants")
+    # ### end Alembic commands ###
diff --git a/server/migrations/versions/f819277e5169_audio_location.py b/server/migrations/versions/f819277e5169_audio_location.py
new file mode 100644
index 00000000..061abec4
--- /dev/null
+++ b/server/migrations/versions/f819277e5169_audio_location.py
@@ -0,0 +1,35 @@
+"""audio_location
+
+Revision ID: f819277e5169
+Revises: 4814901632bc
+Create Date: 2023-11-16 10:29:09.351664
+
+"""
+from typing import Sequence, Union
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision: str = "f819277e5169" +down_revision: Union[str, None] = "4814901632bc" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "transcript", + sa.Column( + "audio_location", sa.String(), server_default="local", nullable=False + ), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("transcript", "audio_location") + # ### end Alembic commands ### diff --git a/server/reflector/app.py b/server/reflector/app.py index 5bfffeca..8f45efd5 100644 --- a/server/reflector/app.py +++ b/server/reflector/app.py @@ -13,6 +13,12 @@ from reflector.metrics import metrics_init from reflector.settings import settings from reflector.views.rtc_offer import router as rtc_offer_router from reflector.views.transcripts import router as transcripts_router +from reflector.views.transcripts_audio import router as transcripts_audio_router +from reflector.views.transcripts_participants import ( + router as transcripts_participants_router, +) +from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router +from reflector.views.transcripts_websocket import router as transcripts_websocket_router from reflector.views.user import router as user_router try: @@ -60,6 +66,10 @@ metrics_init(app, instrumentator) # register views app.include_router(rtc_offer_router) app.include_router(transcripts_router, prefix="/v1") +app.include_router(transcripts_audio_router, prefix="/v1") +app.include_router(transcripts_participants_router, prefix="/v1") +app.include_router(transcripts_websocket_router, prefix="/v1") +app.include_router(transcripts_webrtc_router, prefix="/v1") app.include_router(user_router, prefix="/v1") add_pagination(app) diff --git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py index 6ac2e32a..970393d5 100644 --- a/server/reflector/db/transcripts.py +++ b/server/reflector/db/transcripts.py @@ -2,15 +2,16 @@ import json from contextlib import asynccontextmanager from datetime import datetime from pathlib import Path -from typing import Any +from typing import Any, Literal from uuid import uuid4 import sqlalchemy -from pydantic import BaseModel, Field +from fastapi import HTTPException +from pydantic import BaseModel, ConfigDict, Field from reflector.db import database, metadata from reflector.processors.types import Word as ProcessorWord from reflector.settings import settings -from reflector.utils.audio_waveform import get_audio_waveform +from reflector.storage import Storage transcripts = sqlalchemy.Table( "transcript", @@ -26,22 +27,42 @@ transcripts = sqlalchemy.Table( sqlalchemy.Column("long_summary", sqlalchemy.String, nullable=True), sqlalchemy.Column("topics", sqlalchemy.JSON), sqlalchemy.Column("events", sqlalchemy.JSON), + sqlalchemy.Column("participants", sqlalchemy.JSON), sqlalchemy.Column("source_language", sqlalchemy.String, nullable=True), sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True), + sqlalchemy.Column( + "audio_location", + sqlalchemy.String, + nullable=False, + server_default="local", + ), # with user attached, optional sqlalchemy.Column("user_id", sqlalchemy.String), + sqlalchemy.Column( + "share_mode", + sqlalchemy.String, + nullable=False, + server_default="private", + ), ) -def generate_uuid4(): +def generate_uuid4() -> str: return str(uuid4()) -def 
generate_transcript_name(): +def generate_transcript_name() -> str: now = datetime.utcnow() return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}" +def get_storage() -> Storage: + return Storage.get_instance( + name=settings.TRANSCRIPT_STORAGE_BACKEND, + settings_prefix="TRANSCRIPT_STORAGE_", + ) + + class AudioWaveform(BaseModel): data: list[float] @@ -79,11 +100,26 @@ class TranscriptFinalTitle(BaseModel): title: str +class TranscriptDuration(BaseModel): + duration: float + + +class TranscriptWaveform(BaseModel): + waveform: list[float] + + class TranscriptEvent(BaseModel): event: str data: dict +class TranscriptParticipant(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str = Field(default_factory=generate_uuid4) + speaker: int | None + name: str + + class Transcript(BaseModel): id: str = Field(default_factory=generate_uuid4) user_id: str | None = None @@ -97,8 +133,11 @@ class Transcript(BaseModel): long_summary: str | None = None topics: list[TranscriptTopic] = [] events: list[TranscriptEvent] = [] + participants: list[TranscriptParticipant] | None = [] source_language: str = "en" target_language: str = "en" + share_mode: Literal["private", "semi-private", "public"] = "private" + audio_location: str = "local" def add_event(self, event: str, data: BaseModel) -> TranscriptEvent: ev = TranscriptEvent(event=event, data=data.model_dump()) @@ -112,27 +151,33 @@ class Transcript(BaseModel): else: self.topics.append(topic) + def upsert_participant(self, participant: TranscriptParticipant): + index = next( + (i for i, p in enumerate(self.participants) if p.id == participant.id), + None, + ) + if index is not None: + self.participants[index] = participant + else: + self.participants.append(participant) + return participant + + def delete_participant(self, participant_id: str): + index = next( + (i for i, p in enumerate(self.participants) if p.id == participant_id), + None, + ) + if index is not None: + del self.participants[index] + def events_dump(self, mode="json"): return [event.model_dump(mode=mode) for event in self.events] def topics_dump(self, mode="json"): return [topic.model_dump(mode=mode) for topic in self.topics] - def convert_audio_to_waveform(self, segments_count=256): - fn = self.audio_waveform_filename - if fn.exists(): - return - waveform = get_audio_waveform( - path=self.audio_mp3_filename, segments_count=segments_count - ) - try: - with open(fn, "w") as fd: - json.dump(waveform, fd) - except Exception: - # remove file if anything happen during the write - fn.unlink(missing_ok=True) - raise - return waveform + def participants_dump(self, mode="json"): + return [participant.model_dump(mode=mode) for participant in self.participants] def unlink(self): self.data_path.unlink(missing_ok=True) @@ -141,6 +186,10 @@ class Transcript(BaseModel): def data_path(self): return Path(settings.DATA_DIR) / self.id + @property + def audio_wav_filename(self): + return self.data_path / "audio.wav" + @property def audio_mp3_filename(self): return self.data_path / "audio.mp3" @@ -149,6 +198,10 @@ class Transcript(BaseModel): def audio_waveform_filename(self): return self.data_path / "audio.json" + @property + def storage_audio_path(self): + return f"{self.id}/audio.mp3" + @property def audio_waveform(self): try: @@ -161,6 +214,40 @@ class Transcript(BaseModel): return AudioWaveform(data=data) + async def get_audio_url(self) -> str: + if self.audio_location == "local": + return self._generate_local_audio_link() + elif self.audio_location == "storage": + return await 
self._generate_storage_audio_link() + raise Exception(f"Unknown audio location {self.audio_location}") + + async def _generate_storage_audio_link(self) -> str: + return await get_storage().get_file_url(self.storage_audio_path) + + def _generate_local_audio_link(self) -> str: + # we need to create an url to be used for diarization + # we can't use the audio_mp3_filename because it's not accessible + # from the diarization processor + from datetime import timedelta + + from reflector.app import app + from reflector.views.transcripts import create_access_token + + path = app.url_path_for( + "transcript_get_audio_mp3", + transcript_id=self.id, + ) + url = f"{settings.BASE_URL}{path}" + if self.user_id: + # we pass token only if the user_id is set + # otherwise, the audio is public + token = create_access_token( + {"sub": self.user_id}, + expires_delta=timedelta(minutes=15), + ) + url += f"?token={token}" + return url + class TranscriptController: async def get_all( @@ -169,6 +256,7 @@ class TranscriptController: order_by: str | None = None, filter_empty: bool | None = False, filter_recording: bool | None = False, + return_query: bool = False, ) -> list[Transcript]: """ Get all transcripts @@ -195,6 +283,9 @@ class TranscriptController: if filter_recording: query = query.filter(transcripts.c.status != "recording") + if return_query: + return query + results = await database.fetch_all(query) return results @@ -210,6 +301,47 @@ class TranscriptController: return None return Transcript(**result) + async def get_by_id_for_http( + self, + transcript_id: str, + user_id: str | None, + ) -> Transcript: + """ + Get a transcript by ID for HTTP request. + + If not found, it will raise a 404 error. + If the user is not allowed to access the transcript, it will raise a 403 error. + + This method checks the share mode of the transcript and the user_id + to determine if the user can access the transcript. 
+ """ + query = transcripts.select().where(transcripts.c.id == transcript_id) + result = await database.fetch_one(query) + if not result: + raise HTTPException(status_code=404, detail="Transcript not found") + + # if the transcript is anonymous, share mode is not checked + transcript = Transcript(**result) + if transcript.user_id is None: + return transcript + + if transcript.share_mode == "private": + # in private mode, only the owner can access the transcript + if transcript.user_id == user_id: + return transcript + + elif transcript.share_mode == "semi-private": + # in semi-private mode, only the owner and the users with the link + # can access the transcript + if user_id is not None: + return transcript + + elif transcript.share_mode == "public": + # in public mode, everyone can access the transcript + return transcript + + raise HTTPException(status_code=403, detail="Transcript access denied") + async def add( self, name: str, @@ -292,5 +424,45 @@ class TranscriptController: transcript.upsert_topic(topic) await self.update(transcript, {"topics": transcript.topics_dump()}) + async def move_mp3_to_storage(self, transcript: Transcript): + """ + Move mp3 file to storage + """ + + # store the audio on external storage + await get_storage().put_file( + transcript.storage_audio_path, + transcript.audio_mp3_filename.read_bytes(), + ) + + # indicate on the transcript that the audio is now on storage + await self.update(transcript, {"audio_location": "storage"}) + + # unlink the local file + transcript.audio_mp3_filename.unlink(missing_ok=True) + + async def upsert_participant( + self, + transcript: Transcript, + participant: TranscriptParticipant, + ) -> TranscriptParticipant: + """ + Add/update a participant to a transcript + """ + result = transcript.upsert_participant(participant) + await self.update(transcript, {"participants": transcript.participants_dump()}) + return result + + async def delete_participant( + self, + transcript: Transcript, + participant_id: str, + ): + """ + Delete a participant from a transcript + """ + transcript.delete_participant(participant_id) + await self.update(transcript, {"participants": transcript.participants_dump()}) + transcripts_controller = TranscriptController() diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py index 316ecbcc..b182f421 100644 --- a/server/reflector/pipelines/main_live_pipeline.py +++ b/server/reflector/pipelines/main_live_pipeline.py @@ -12,20 +12,20 @@ It is directly linked to our data model. 
""" import asyncio +import functools from contextlib import asynccontextmanager -from datetime import timedelta -from pathlib import Path -from celery import shared_task +from celery import chord, group, shared_task from pydantic import BaseModel -from reflector.app import app from reflector.db.transcripts import ( Transcript, + TranscriptDuration, TranscriptFinalLongSummary, TranscriptFinalShortSummary, TranscriptFinalTitle, TranscriptText, TranscriptTopic, + TranscriptWaveform, transcripts_controller, ) from reflector.logger import logger @@ -45,6 +45,7 @@ from reflector.processors import ( TranscriptTopicDetectorProcessor, TranscriptTranslatorProcessor, ) +from reflector.processors.audio_waveform_processor import AudioWaveformProcessor from reflector.processors.types import AudioDiarizationInput from reflector.processors.types import ( TitleSummaryWithId as TitleSummaryWithIdProcessorType, @@ -52,6 +53,22 @@ from reflector.processors.types import ( from reflector.processors.types import Transcript as TranscriptProcessorType from reflector.settings import settings from reflector.ws_manager import WebsocketManager, get_ws_manager +from structlog import BoundLogger as Logger + + +def asynctask(f): + @functools.wraps(f) + def wrapper(*args, **kwargs): + coro = f(*args, **kwargs) + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + if loop and loop.is_running(): + return loop.run_until_complete(coro) + return asyncio.run(coro) + + return wrapper def broadcast_to_sockets(func): @@ -72,6 +89,26 @@ def broadcast_to_sockets(func): return wrapper +def get_transcript(func): + """ + Decorator to fetch the transcript from the database from the first argument + """ + + async def wrapper(**kwargs): + transcript_id = kwargs.pop("transcript_id") + transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id) + if not transcript: + raise Exception("Transcript {transcript_id} not found") + tlogger = logger.bind(transcript_id=transcript.id) + try: + return await func(transcript=transcript, logger=tlogger, **kwargs) + except Exception as exc: + tlogger.error("Pipeline error", exc_info=exc) + raise + + return wrapper + + class StrValue(BaseModel): value: str @@ -96,6 +133,19 @@ class PipelineMainBase(PipelineRunner): raise Exception("Transcript not found") return result + def get_transcript_topics(self, transcript: Transcript) -> list[TranscriptTopic]: + return [ + TitleSummaryWithIdProcessorType( + id=topic.id, + title=topic.title, + summary=topic.summary, + timestamp=topic.timestamp, + duration=topic.duration, + transcript=TranscriptProcessorType(words=topic.words), + ) + for topic in transcript.topics + ] + @asynccontextmanager async def transaction(self): async with self._lock: @@ -113,7 +163,7 @@ class PipelineMainBase(PipelineRunner): "flush": "processing", "error": "error", } - elif isinstance(self, PipelineMainDiarization): + elif isinstance(self, PipelineMainFinalSummaries): status_mapping = { "push": "processing", "flush": "processing", @@ -121,7 +171,8 @@ class PipelineMainBase(PipelineRunner): "ended": "ended", } else: - raise Exception(f"Runner {self.__class__} is missing status mapping") + # intermediate pipeline don't update status + return # mutate to model status status = status_mapping.get(status) @@ -230,21 +281,39 @@ class PipelineMainBase(PipelineRunner): data=final_short_summary, ) - async def on_duration(self, duration: float): + @broadcast_to_sockets + async def on_duration(self, data): async with self.transaction(): + duration = 
TranscriptDuration(duration=data) + transcript = await self.get_transcript() await transcripts_controller.update( transcript, { - "duration": duration, + "duration": duration.duration, }, ) + return await transcripts_controller.append_event( + transcript=transcript, event="DURATION", data=duration + ) + + @broadcast_to_sockets + async def on_waveform(self, data): + async with self.transaction(): + waveform = TranscriptWaveform(waveform=data) + + transcript = await self.get_transcript() + + return await transcripts_controller.append_event( + transcript=transcript, event="WAVEFORM", data=waveform + ) class PipelineMainLive(PipelineMainBase): - audio_filename: Path | None = None - source_language: str = "en" - target_language: str = "en" + """ + Main pipeline for live streaming, attach to RTC connection + Any long post process should be done in the post pipeline + """ async def create(self) -> Pipeline: # create a context for the whole rtc transaction @@ -254,7 +323,7 @@ class PipelineMainLive(PipelineMainBase): processors = [ AudioFileWriterProcessor( - path=transcript.audio_mp3_filename, + path=transcript.audio_wav_filename, on_duration=self.on_duration, ), AudioChunkerProcessor(), @@ -263,17 +332,13 @@ class PipelineMainLive(PipelineMainBase): TranscriptLinerProcessor(), TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript), TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic), - TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), ] pipeline = Pipeline(*processors) pipeline.options = self pipeline.set_pref("audio:source_language", transcript.source_language) pipeline.set_pref("audio:target_language", transcript.target_language) pipeline.logger.bind(transcript_id=transcript.id) - pipeline.logger.info( - "Pipeline main live created", - transcript_id=self.transcript_id, - ) + pipeline.logger.info("Pipeline main live created") return pipeline @@ -281,26 +346,106 @@ class PipelineMainLive(PipelineMainBase): # when the pipeline ends, connect to the post pipeline logger.info("Pipeline main live ended", transcript_id=self.transcript_id) logger.info("Scheduling pipeline main post", transcript_id=self.transcript_id) - task_pipeline_main_post.delay(transcript_id=self.transcript_id) + pipeline_post(transcript_id=self.transcript_id) class PipelineMainDiarization(PipelineMainBase): """ - Diarization is a long time process, so we do it in a separate pipeline - When done, adjust the short and final summary + Diarize the audio and update topics """ async def create(self) -> Pipeline: # create a context for the whole rtc transaction # add a customised logger to the context self.prepare() - processors = [] - if settings.DIARIZATION_ENABLED: - processors += [ - AudioDiarizationAutoProcessor(callback=self.on_topic), - ] + pipeline = Pipeline( + AudioDiarizationAutoProcessor(callback=self.on_topic), + ) + pipeline.options = self - processors += [ + # now let's start the pipeline by pushing information to the + # first processor diarization processor + # XXX translation is lost when converting our data model to the processor model + transcript = await self.get_transcript() + + # diarization works only if the file is uploaded to an external storage + if transcript.audio_location == "local": + pipeline.logger.info("Audio is local, skipping diarization") + return + + topics = self.get_transcript_topics(transcript) + audio_url = await transcript.get_audio_url() + audio_diarization_input = AudioDiarizationInput( + audio_url=audio_url, + topics=topics, + ) + + # as 
tempting to use pipeline.push, prefer to use the runner + # to let the start just do one job. + pipeline.logger.bind(transcript_id=transcript.id) + pipeline.logger.info("Diarization pipeline created") + self.push(audio_diarization_input) + self.flush() + + return pipeline + + +class PipelineMainFromTopics(PipelineMainBase): + """ + Pseudo class for generating a pipeline from topics + """ + + def get_processors(self) -> list: + raise NotImplementedError + + async def create(self) -> Pipeline: + self.prepare() + + # get transcript + self._transcript = transcript = await self.get_transcript() + + # create pipeline + processors = self.get_processors() + pipeline = Pipeline(*processors) + pipeline.options = self + pipeline.logger.bind(transcript_id=transcript.id) + pipeline.logger.info(f"{self.__class__.__name__} pipeline created") + + # push topics + topics = self.get_transcript_topics(transcript) + for topic in topics: + self.push(topic) + + self.flush() + + return pipeline + + +class PipelineMainTitleAndShortSummary(PipelineMainFromTopics): + """ + Generate title from the topics + """ + + def get_processors(self) -> list: + return [ + BroadcastProcessor( + processors=[ + TranscriptFinalTitleProcessor.as_threaded(callback=self.on_title), + TranscriptFinalShortSummaryProcessor.as_threaded( + callback=self.on_short_summary + ), + ] + ) + ] + + +class PipelineMainFinalSummaries(PipelineMainFromTopics): + """ + Generate summaries from the topics + """ + + def get_processors(self) -> list: + return [ BroadcastProcessor( processors=[ TranscriptFinalLongSummaryProcessor.as_threaded( @@ -312,65 +457,164 @@ class PipelineMainDiarization(PipelineMainBase): ] ), ] - pipeline = Pipeline(*processors) - pipeline.options = self - # now let's start the pipeline by pushing information to the - # first processor diarization processor - # XXX translation is lost when converting our data model to the processor model - transcript = await self.get_transcript() - topics = [ - TitleSummaryWithIdProcessorType( - id=topic.id, - title=topic.title, - summary=topic.summary, - timestamp=topic.timestamp, - duration=topic.duration, - transcript=TranscriptProcessorType(words=topic.words), - ) - for topic in transcript.topics + +class PipelineMainWaveform(PipelineMainFromTopics): + """ + Generate waveform + """ + + def get_processors(self) -> list: + return [ + AudioWaveformProcessor.as_threaded( + audio_path=self._transcript.audio_wav_filename, + waveform_path=self._transcript.audio_waveform_filename, + on_waveform=self.on_waveform, + ), ] - # we need to create an url to be used for diarization - # we can't use the audio_mp3_filename because it's not accessible - # from the diarization processor - from reflector.views.transcripts import create_access_token - path = app.url_path_for( - "transcript_get_audio_mp3", - transcript_id=transcript.id, - ) - url = f"{settings.BASE_URL}{path}" - if transcript.user_id: - # we pass token only if the user_id is set - # otherwise, the audio is public - token = create_access_token( - {"sub": transcript.user_id}, - expires_delta=timedelta(minutes=15), - ) - url += f"?token={token}" - audio_diarization_input = AudioDiarizationInput( - audio_url=url, - topics=topics, - ) +@get_transcript +async def pipeline_waveform(transcript: Transcript, logger: Logger): + logger.info("Starting waveform") + runner = PipelineMainWaveform(transcript_id=transcript.id) + await runner.run() + logger.info("Waveform done") - # as tempting to use pipeline.push, prefer to use the runner - # to let the start just do 
one job. - pipeline.logger.bind(transcript_id=transcript.id) - pipeline.logger.info( - "Pipeline main post created", transcript_id=self.transcript_id - ) - self.push(audio_diarization_input) - self.flush() - return pipeline +@get_transcript +async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger): + logger.info("Starting convert to mp3") + + # If the audio wav is not available, just skip + wav_filename = transcript.audio_wav_filename + if not wav_filename.exists(): + logger.warning("Wav file not found, may be already converted") + return + + # Convert to mp3 + mp3_filename = transcript.audio_mp3_filename + + import av + + with av.open(wav_filename.as_posix()) as in_container: + in_stream = in_container.streams.audio[0] + with av.open(mp3_filename.as_posix(), "w") as out_container: + out_stream = out_container.add_stream("mp3") + for frame in in_container.decode(in_stream): + for packet in out_stream.encode(frame): + out_container.mux(packet) + + # Delete the wav file + transcript.audio_wav_filename.unlink(missing_ok=True) + + logger.info("Convert to mp3 done") + + +@get_transcript +async def pipeline_upload_mp3(transcript: Transcript, logger: Logger): + if not settings.TRANSCRIPT_STORAGE_BACKEND: + logger.info("No storage backend configured, skipping mp3 upload") + return + + logger.info("Starting upload mp3") + + # If the audio mp3 is not available, just skip + mp3_filename = transcript.audio_mp3_filename + if not mp3_filename.exists(): + logger.warning("Mp3 file not found, may be already uploaded") + return + + # Upload to external storage and delete the file + await transcripts_controller.move_mp3_to_storage(transcript) + + logger.info("Upload mp3 done") + + +@get_transcript +async def pipeline_diarization(transcript: Transcript, logger: Logger): + logger.info("Starting diarization") + runner = PipelineMainDiarization(transcript_id=transcript.id) + await runner.run() + logger.info("Diarization done") + + +@get_transcript +async def pipeline_title_and_short_summary(transcript: Transcript, logger: Logger): + logger.info("Starting title and short summary") + runner = PipelineMainTitleAndShortSummary(transcript_id=transcript.id) + await runner.run() + logger.info("Title and short summary done") + + +@get_transcript +async def pipeline_summaries(transcript: Transcript, logger: Logger): + logger.info("Starting summaries") + runner = PipelineMainFinalSummaries(transcript_id=transcript.id) + await runner.run() + logger.info("Summaries done") + + +# =================================================================== +# Celery tasks that can be called from the API +# =================================================================== @shared_task -def task_pipeline_main_post(transcript_id: str): - logger.info( - "Starting main post pipeline", - transcript_id=transcript_id, +@asynctask +async def task_pipeline_waveform(*, transcript_id: str): + await pipeline_waveform(transcript_id=transcript_id) + + +@shared_task +@asynctask +async def task_pipeline_convert_to_mp3(*, transcript_id: str): + await pipeline_convert_to_mp3(transcript_id=transcript_id) + + +@shared_task +@asynctask +async def task_pipeline_upload_mp3(*, transcript_id: str): + await pipeline_upload_mp3(transcript_id=transcript_id) + + +@shared_task +@asynctask +async def task_pipeline_diarization(*, transcript_id: str): + await pipeline_diarization(transcript_id=transcript_id) + + +@shared_task +@asynctask +async def task_pipeline_title_and_short_summary(*, transcript_id: str): + await 
pipeline_title_and_short_summary(transcript_id=transcript_id) + + +@shared_task +@asynctask +async def task_pipeline_final_summaries(*, transcript_id: str): + await pipeline_summaries(transcript_id=transcript_id) + + +def pipeline_post(*, transcript_id: str): + """ + Run the post pipeline + """ + chain_mp3_and_diarize = ( + task_pipeline_waveform.si(transcript_id=transcript_id) + | task_pipeline_convert_to_mp3.si(transcript_id=transcript_id) + | task_pipeline_upload_mp3.si(transcript_id=transcript_id) + | task_pipeline_diarization.si(transcript_id=transcript_id) ) - runner = PipelineMainDiarization(transcript_id=transcript_id) - runner.start_sync() + chain_title_preview = task_pipeline_title_and_short_summary.si( + transcript_id=transcript_id + ) + chain_final_summaries = task_pipeline_final_summaries.si( + transcript_id=transcript_id + ) + + chain = chord( + group(chain_mp3_and_diarize, chain_title_preview), + chain_final_summaries, + ) + chain.delay() diff --git a/server/reflector/pipelines/runner.py b/server/reflector/pipelines/runner.py index a1e137a7..708a4265 100644 --- a/server/reflector/pipelines/runner.py +++ b/server/reflector/pipelines/runner.py @@ -106,6 +106,14 @@ class PipelineRunner(BaseModel): if not self.pipeline: self.pipeline = await self.create() + if not self.pipeline: + # no pipeline created in create, just finish it then. + await self._set_status("ended") + self._ev_done.set() + if self.on_ended: + await self.on_ended() + return + # start the loop await self._set_status("started") while not self._ev_done.is_set(): @@ -119,8 +127,7 @@ class PipelineRunner(BaseModel): self._logger.exception("Runner error") await self._set_status("error") self._ev_done.set() - if self.on_ended: - await self.on_ended() + raise async def cmd_push(self, data): if self._is_first_push: diff --git a/server/reflector/processors/audio_diarization.py b/server/reflector/processors/audio_diarization.py index 82c6a553..69eab5b7 100644 --- a/server/reflector/processors/audio_diarization.py +++ b/server/reflector/processors/audio_diarization.py @@ -1,5 +1,5 @@ from reflector.processors.base import Processor -from reflector.processors.types import AudioDiarizationInput, TitleSummary +from reflector.processors.types import AudioDiarizationInput, TitleSummary, Word class AudioDiarizationProcessor(Processor): @@ -19,12 +19,12 @@ class AudioDiarizationProcessor(Processor): # topics is a list[BaseModel] with an attribute words # words is a list[BaseModel] with text, start and speaker attribute - # mutate in place - for topic in data.topics: - for word in topic.transcript.words: - for d in diarization: - if d["start"] <= word.start <= d["end"]: - word.speaker = d["speaker"] + # create a view of words based on topics + # the current algorithm is using words index, we cannot use a generator + words = list(self.iter_words_from_topics(data.topics)) + + # assign speaker to words (mutate the words list) + self.assign_speaker(words, diarization) # emit them for topic in data.topics: @@ -32,3 +32,150 @@ class AudioDiarizationProcessor(Processor): async def _diarize(self, data: AudioDiarizationInput): raise NotImplementedError + + def assign_speaker(self, words: list[Word], diarization: list[dict]): + self._diarization_remove_overlap(diarization) + self._diarization_remove_segment_without_words(words, diarization) + self._diarization_merge_same_speaker(words, diarization) + self._diarization_assign_speaker(words, diarization) + + def iter_words_from_topics(self, topics: TitleSummary): + for topic in topics: + for 
word in topic.transcript.words:
+                yield word
+
+    def is_word_continuation(self, word_prev, word):
+        """
+        Return True if the word is a continuation of the previous word:
+        it is not a continuation if the previous word ends with punctuation
+        or if the current word starts with a capital letter
+        """
+        # does word_prev end with punctuation?
+        if word_prev.text and word_prev.text[-1] in ".?!":
+            return False
+        elif word.text and word.text[0].isupper():
+            return False
+        return True
+
+    def _diarization_remove_overlap(self, diarization: list[dict]):
+        """
+        Remove overlap in diarization results
+
+        When using a diarization algorithm, it's possible to have overlapping segments.
+        This function removes the overlap by keeping the longest segment.
+
+        Warning: this function mutates the diarization list
+        """
+        # remove overlap by keeping the longest segment
+        diarization_idx = 0
+        while diarization_idx < len(diarization) - 1:
+            d = diarization[diarization_idx]
+            dnext = diarization[diarization_idx + 1]
+            if d["end"] > dnext["start"]:
+                # remove the shortest segment
+                if d["end"] - d["start"] > dnext["end"] - dnext["start"]:
+                    # remove next segment
+                    diarization.pop(diarization_idx + 1)
+                else:
+                    # remove current segment
+                    diarization.pop(diarization_idx)
+            else:
+                diarization_idx += 1
+
+    def _diarization_remove_segment_without_words(
+        self, words: list[Word], diarization: list[dict]
+    ):
+        """
+        Remove diarization segments without words
+
+        Warning: this function mutates the diarization list
+        """
+        # count the number of words for each diarization segment
+        diarization_count = []
+        for d in diarization:
+            start = d["start"]
+            end = d["end"]
+            count = 0
+            for word in words:
+                if start <= word.start < end:
+                    count += 1
+                elif start < word.end <= end:
+                    count += 1
+            diarization_count.append(count)
+
+        # remove diarization segments with no words
+        diarization_idx = 0
+        while diarization_idx < len(diarization):
+            if diarization_count[diarization_idx] == 0:
+                diarization.pop(diarization_idx)
+                diarization_count.pop(diarization_idx)
+            else:
+                diarization_idx += 1
+
+    def _diarization_merge_same_speaker(
+        self, words: list[Word], diarization: list[dict]
+    ):
+        """
+        Merge contiguous diarization segments with the same speaker
+
+        Warning: this function mutates the diarization list
+        """
+        # merge segments with the same speaker
+        diarization_idx = 0
+        while diarization_idx < len(diarization) - 1:
+            d = diarization[diarization_idx]
+            dnext = diarization[diarization_idx + 1]
+            if d["speaker"] == dnext["speaker"]:
+                diarization[diarization_idx]["end"] = dnext["end"]
+                diarization.pop(diarization_idx + 1)
+            else:
+                diarization_idx += 1
+
+    def _diarization_assign_speaker(self, words: list[Word], diarization: list[dict]):
+        """
+        Assign speaker to words based on diarization
+
+        Warning: this function mutates the words list
+        """
+
+        word_idx = 0
+        last_speaker = None
+        for d in diarization:
+            start = d["start"]
+            end = d["end"]
+            speaker = d["speaker"]
+
+            # diarization may start after the first set of words
+            # in this case, we assign the last speaker
+            for word in words[word_idx:]:
+                if word.start < start:
+                    # speaker change, but which speaker makes sense for this word?
+ # If it's a new sentence, assign with the new speaker + # If it's a continuation, assign with the last speaker + is_continuation = False + if word_idx > 0 and word_idx < len(words) - 1: + is_continuation = self.is_word_continuation( + *words[word_idx - 1 : word_idx + 1] + ) + if is_continuation: + word.speaker = last_speaker + else: + word.speaker = speaker + last_speaker = speaker + word_idx += 1 + else: + break + + # now continue to assign speaker until the word starts after the end + for word in words[word_idx:]: + if start <= word.start < end: + last_speaker = speaker + word.speaker = speaker + word_idx += 1 + elif word.start > end: + break + + # no more diarization available, + # assign last speaker to all words without speaker + for word in words[word_idx:]: + word.speaker = last_speaker diff --git a/server/reflector/processors/audio_diarization_modal.py b/server/reflector/processors/audio_diarization_modal.py index 53de2501..511b7f70 100644 --- a/server/reflector/processors/audio_diarization_modal.py +++ b/server/reflector/processors/audio_diarization_modal.py @@ -31,7 +31,7 @@ class AudioDiarizationModalProcessor(AudioDiarizationProcessor): follow_redirects=True, ) response.raise_for_status() - return response.json()["text"] + return response.json()["diarization"] AudioDiarizationAutoProcessor.register("modal", AudioDiarizationModalProcessor) diff --git a/server/reflector/processors/audio_waveform_processor.py b/server/reflector/processors/audio_waveform_processor.py new file mode 100644 index 00000000..f1a24ffd --- /dev/null +++ b/server/reflector/processors/audio_waveform_processor.py @@ -0,0 +1,36 @@ +import json +from pathlib import Path + +from reflector.processors.base import Processor +from reflector.processors.types import TitleSummary +from reflector.utils.audio_waveform import get_audio_waveform + + +class AudioWaveformProcessor(Processor): + """ + Write the waveform for the final audio + """ + + INPUT_TYPE = TitleSummary + + def __init__(self, audio_path: Path | str, waveform_path: str, **kwargs): + super().__init__(**kwargs) + if isinstance(audio_path, str): + audio_path = Path(audio_path) + if audio_path.suffix not in (".mp3", ".wav"): + raise ValueError("Only mp3 and wav files are supported") + self.audio_path = audio_path + self.waveform_path = waveform_path + + async def _flush(self): + self.waveform_path.parent.mkdir(parents=True, exist_ok=True) + self.logger.info("Waveform Processing Started") + waveform = get_audio_waveform(path=self.audio_path, segments_count=255) + + with open(self.waveform_path, "w") as fd: + json.dump(waveform, fd) + self.logger.info("Waveform Processing Finished") + await self.emit(waveform, name="waveform") + + async def _push(_self, _data): + return diff --git a/server/reflector/settings.py b/server/reflector/settings.py index 65412310..d0ddc91a 100644 --- a/server/reflector/settings.py +++ b/server/reflector/settings.py @@ -54,7 +54,7 @@ class Settings(BaseSettings): TRANSCRIPT_MODAL_API_KEY: str | None = None # Audio transcription storage - TRANSCRIPT_STORAGE_BACKEND: str = "aws" + TRANSCRIPT_STORAGE_BACKEND: str | None = None # Storage configuration for AWS TRANSCRIPT_STORAGE_AWS_BUCKET_NAME: str = "reflector-bucket" @@ -62,9 +62,6 @@ class Settings(BaseSettings): TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None - # Transcript MP3 storage - TRANSCRIPT_MP3_STORAGE_BACKEND: str = "aws" - # LLM # available backend: openai, modal, oobabooga LLM_BACKEND: str = "oobabooga" 
@@ -131,5 +128,8 @@ class Settings(BaseSettings): # Profiling PROFILING: bool = False + # Healthcheck + HEALTHCHECK_URL: str | None = None + settings = Settings() diff --git a/server/reflector/storage/base.py b/server/reflector/storage/base.py index 5cdafdbf..a457ddf8 100644 --- a/server/reflector/storage/base.py +++ b/server/reflector/storage/base.py @@ -1,6 +1,7 @@ +import importlib + from pydantic import BaseModel from reflector.settings import settings -import importlib class FileResult(BaseModel): @@ -17,7 +18,7 @@ class Storage: cls._registry[name] = kclass @classmethod - def get_instance(cls, name, settings_prefix=""): + def get_instance(cls, name: str, settings_prefix: str = ""): if name not in cls._registry: module_name = f"reflector.storage.storage_{name}" importlib.import_module(module_name) @@ -45,3 +46,9 @@ class Storage: async def _delete_file(self, filename: str): raise NotImplementedError + + async def get_file_url(self, filename: str) -> str: + return await self._get_file_url(filename) + + async def _get_file_url(self, filename: str) -> str: + raise NotImplementedError diff --git a/server/reflector/storage/storage_aws.py b/server/reflector/storage/storage_aws.py index 09a9c383..d2313293 100644 --- a/server/reflector/storage/storage_aws.py +++ b/server/reflector/storage/storage_aws.py @@ -1,6 +1,6 @@ import aioboto3 -from reflector.storage.base import Storage, FileResult from reflector.logger import logger +from reflector.storage.base import FileResult, Storage class AwsStorage(Storage): @@ -44,16 +44,18 @@ class AwsStorage(Storage): Body=data, ) + async def _get_file_url(self, filename: str) -> FileResult: + bucket = self.aws_bucket_name + folder = self.aws_folder + s3filename = f"{folder}/{filename}" if folder else filename + async with self.session.client("s3") as client: presigned_url = await client.generate_presigned_url( "get_object", Params={"Bucket": bucket, "Key": s3filename}, ExpiresIn=3600, ) - return FileResult( - filename=filename, - url=presigned_url, - ) + return presigned_url async def _delete_file(self, filename: str): bucket = self.aws_bucket_name diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 5de9ced3..9e62192b 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -1,31 +1,19 @@ from datetime import datetime, timedelta -from typing import Annotated, Optional +from typing import Annotated, Literal, Optional import reflector.auth as auth -from fastapi import ( - APIRouter, - Depends, - HTTPException, - Request, - WebSocket, - WebSocketDisconnect, - status, -) -from fastapi_pagination import Page, paginate +from fastapi import APIRouter, Depends, HTTPException +from fastapi_pagination import Page +from fastapi_pagination.ext.databases import paginate from jose import jwt from pydantic import BaseModel, Field from reflector.db.transcripts import ( - AudioWaveform, + TranscriptParticipant, TranscriptTopic, transcripts_controller, ) from reflector.processors.types import Transcript as ProcessorTranscript from reflector.settings import settings -from reflector.ws_manager import get_ws_manager -from starlette.concurrency import run_in_threadpool - -from ._range_requests_response import range_requests_response -from .rtc_offer import RtcOffer, rtc_offer_base router = APIRouter() @@ -48,6 +36,7 @@ def create_access_token(data: dict, expires_delta: timedelta): class GetTranscript(BaseModel): id: str + user_id: str | None name: str status: str locked: bool @@ -56,8 +45,10 @@ 
class GetTranscript(BaseModel): short_summary: str | None long_summary: str | None created_at: datetime + share_mode: str = Field("private") source_language: str | None target_language: str | None + participants: list[TranscriptParticipant] | None class CreateTranscript(BaseModel): @@ -72,6 +63,8 @@ class UpdateTranscript(BaseModel): title: Optional[str] = Field(None) short_summary: Optional[str] = Field(None) long_summary: Optional[str] = Field(None) + share_mode: Optional[Literal["public", "semi-private", "private"]] = Field(None) + participants: Optional[list[TranscriptParticipant]] = Field(None) class DeletionStatus(BaseModel): @@ -82,12 +75,19 @@ class DeletionStatus(BaseModel): async def transcripts_list( user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], ): + from reflector.db import database + if not user and not settings.PUBLIC_MODE: raise HTTPException(status_code=401, detail="Not authenticated") user_id = user["sub"] if user else None - return paginate( - await transcripts_controller.get_all(user_id=user_id, order_by="-created_at") + return await paginate( + database, + await transcripts_controller.get_all( + user_id=user_id, + order_by="-created_at", + return_query=True, + ), ) @@ -165,10 +165,9 @@ async def transcript_get( user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], ): user_id = user["sub"] if user else None - transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) - if not transcript: - raise HTTPException(status_code=404, detail="Transcript not found") - return transcript + return await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) @router.patch("/transcripts/{transcript_id}", response_model=GetTranscript) @@ -181,17 +180,7 @@ async def transcript_update( transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") - values = {} - if info.name is not None: - values["name"] = info.name - if info.locked is not None: - values["locked"] = info.locked - if info.long_summary is not None: - values["long_summary"] = info.long_summary - if info.short_summary is not None: - values["short_summary"] = info.short_summary - if info.title is not None: - values["title"] = info.title + values = info.dict(exclude_unset=True) await transcripts_controller.update(transcript, values) return transcript @@ -209,63 +198,6 @@ async def transcript_delete( return DeletionStatus(status="ok") -@router.get("/transcripts/{transcript_id}/audio/mp3") -@router.head("/transcripts/{transcript_id}/audio/mp3") -async def transcript_get_audio_mp3( - request: Request, - transcript_id: str, - user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], - token: str | None = None, -): - user_id = user["sub"] if user else None - if not user_id and token: - unauthorized_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Invalid or expired token", - headers={"WWW-Authenticate": "Bearer"}, - ) - try: - payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM]) - user_id: str = payload.get("sub") - except jwt.JWTError: - raise unauthorized_exception - - transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) - if not transcript: - raise HTTPException(status_code=404, detail="Transcript not found") - - if not transcript.audio_mp3_filename.exists(): - raise HTTPException(status_code=404, detail="Audio not 
found") - - truncated_id = str(transcript.id).split("-")[0] - filename = f"recording_{truncated_id}.mp3" - - return range_requests_response( - request, - transcript.audio_mp3_filename, - content_type="audio/mpeg", - content_disposition=f"attachment; filename={filename}", - ) - - -@router.get("/transcripts/{transcript_id}/audio/waveform") -async def transcript_get_audio_waveform( - transcript_id: str, - user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], -) -> AudioWaveform: - user_id = user["sub"] if user else None - transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) - if not transcript: - raise HTTPException(status_code=404, detail="Transcript not found") - - if not transcript.audio_mp3_filename.exists(): - raise HTTPException(status_code=404, detail="Audio not found") - - await run_in_threadpool(transcript.convert_audio_to_waveform) - - return transcript.audio_waveform - - @router.get( "/transcripts/{transcript_id}/topics", response_model=list[GetTranscriptTopic], @@ -275,92 +207,11 @@ async def transcript_get_topics( user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], ): user_id = user["sub"] if user else None - transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) - if not transcript: - raise HTTPException(status_code=404, detail="Transcript not found") + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) # convert to GetTranscriptTopic return [ GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics ] - - -# ============================================================== -# Websocket -# ============================================================== - - -@router.get("/transcripts/{transcript_id}/events") -async def transcript_get_websocket_events(transcript_id: str): - pass - - -@router.websocket("/transcripts/{transcript_id}/events") -async def transcript_events_websocket( - transcript_id: str, - websocket: WebSocket, - # user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], -): - # user_id = user["sub"] if user else None - transcript = await transcripts_controller.get_by_id(transcript_id) - if not transcript: - raise HTTPException(status_code=404, detail="Transcript not found") - - # connect to websocket manager - # use ts:transcript_id as room id - room_id = f"ts:{transcript_id}" - ws_manager = get_ws_manager() - await ws_manager.add_user_to_room(room_id, websocket) - - try: - # on first connection, send all events only to the current user - for event in transcript.events: - # for now, do not send TRANSCRIPT or STATUS options - theses are live event - # not necessary to be sent to the client; but keep the rest - name = event.event - if name in ("TRANSCRIPT", "STATUS"): - continue - await websocket.send_json(event.model_dump(mode="json")) - - # XXX if transcript is final (locked=True and status=ended) - # XXX send a final event to the client and close the connection - - # endless loop to wait for new events - # we do not have command system now, - while True: - await websocket.receive() - except (RuntimeError, WebSocketDisconnect): - await ws_manager.remove_user_from_room(room_id, websocket) - - -# ============================================================== -# Web RTC -# ============================================================== - - -@router.post("/transcripts/{transcript_id}/record/webrtc") -async def transcript_record_webrtc( - transcript_id: str, - params: RtcOffer, 
-    request: Request,
-    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-):
-    user_id = user["sub"] if user else None
-    transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
-    if not transcript:
-        raise HTTPException(status_code=404, detail="Transcript not found")
-
-    if transcript.locked:
-        raise HTTPException(status_code=400, detail="Transcript is locked")
-
-    # create a pipeline runner
-    from reflector.pipelines.main_live_pipeline import PipelineMainLive
-
-    pipeline_runner = PipelineMainLive(transcript_id=transcript_id)
-
-    # FIXME do not allow multiple recording at the same time
-    return await rtc_offer_base(
-        params,
-        request,
-        pipeline_runner=pipeline_runner,
-    )
diff --git a/server/reflector/views/transcripts_audio.py b/server/reflector/views/transcripts_audio.py
new file mode 100644
index 00000000..a174d992
--- /dev/null
+++ b/server/reflector/views/transcripts_audio.py
@@ -0,0 +1,91 @@
+"""
+Transcripts audio related endpoints
+===================================
+
+"""
+from typing import Annotated, Optional
+
+import httpx
+import reflector.auth as auth
+from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
+from jose import jwt
+from reflector.db.transcripts import AudioWaveform, transcripts_controller
+from reflector.settings import settings
+from reflector.views.transcripts import ALGORITHM
+
+from ._range_requests_response import range_requests_response
+
+router = APIRouter()
+
+
+@router.get("/transcripts/{transcript_id}/audio/mp3")
+@router.head("/transcripts/{transcript_id}/audio/mp3")
+async def transcript_get_audio_mp3(
+    request: Request,
+    transcript_id: str,
+    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+    token: str | None = None,
+):
+    user_id = user["sub"] if user else None
+    if not user_id and token:
+        unauthorized_exception = HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail="Invalid or expired token",
+            headers={"WWW-Authenticate": "Bearer"},
+        )
+        try:
+            payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
+            user_id: str = payload.get("sub")
+        except jwt.JWTError:
+            raise unauthorized_exception
+
+    transcript = await transcripts_controller.get_by_id_for_http(
+        transcript_id, user_id=user_id
+    )
+
+    if transcript.audio_location == "storage":
+        # proxy S3 file, to prevent issue with CORS
+        url = await transcript.get_audio_url()
+        headers = {}
+
+        copy_headers = ["range", "accept-encoding"]
+        for header in copy_headers:
+            if header in request.headers:
+                headers[header] = request.headers[header]
+
+        async with httpx.AsyncClient() as client:
+            resp = await client.request(request.method, url, headers=headers)
+            return Response(
+                content=resp.content,
+                status_code=resp.status_code,
+                headers=resp.headers,
+            )
+
+    if not transcript.audio_mp3_filename.exists():
+        raise HTTPException(status_code=500, detail="Audio not found")
+
+    truncated_id = str(transcript.id).split("-")[0]
+    filename = 
f"recording_{truncated_id}.mp3" + + return range_requests_response( + request, + transcript.audio_mp3_filename, + content_type="audio/mpeg", + content_disposition=f"attachment; filename={filename}", + ) + + +@router.get("/transcripts/{transcript_id}/audio/waveform") +async def transcript_get_audio_waveform( + transcript_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> AudioWaveform: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + if not transcript.audio_waveform_filename.exists(): + raise HTTPException(status_code=404, detail="Audio not found") + + return transcript.audio_waveform diff --git a/server/reflector/views/transcripts_participants.py b/server/reflector/views/transcripts_participants.py new file mode 100644 index 00000000..318d6018 --- /dev/null +++ b/server/reflector/views/transcripts_participants.py @@ -0,0 +1,142 @@ +""" +Transcript participants API endpoints +===================================== + +""" +from typing import Annotated, Optional + +import reflector.auth as auth +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel, ConfigDict, Field +from reflector.db.transcripts import TranscriptParticipant, transcripts_controller +from reflector.views.types import DeletionStatus + +router = APIRouter() + + +class Participant(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: str + speaker: int | None + name: str + + +class CreateParticipant(BaseModel): + speaker: Optional[int] = Field(None) + name: str + + +class UpdateParticipant(BaseModel): + speaker: Optional[int] = Field(None) + name: Optional[str] = Field(None) + + +@router.get("/transcripts/{transcript_id}/participants") +async def transcript_get_participants( + transcript_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> list[Participant]: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + return [ + Participant.model_validate(participant) + for participant in transcript.participants + ] + + +@router.post("/transcripts/{transcript_id}/participants") +async def transcript_add_participant( + transcript_id: str, + participant: CreateParticipant, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> Participant: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + # ensure the speaker is unique + for p in transcript.participants: + if p.speaker == participant.speaker: + raise HTTPException( + status_code=400, + detail="Speaker already assigned", + ) + + obj = await transcripts_controller.upsert_participant( + transcript, TranscriptParticipant(**participant.dict()) + ) + return Participant.model_validate(obj) + + +@router.get("/transcripts/{transcript_id}/participants/{participant_id}") +async def transcript_get_participant( + transcript_id: str, + participant_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> Participant: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + for p in transcript.participants: + if p.id == participant_id: + return Participant.model_validate(p) + + raise HTTPException(status_code=404, 
detail="Participant not found") + + +@router.patch("/transcripts/{transcript_id}/participants/{participant_id}") +async def transcript_update_participant( + transcript_id: str, + participant_id: str, + participant: UpdateParticipant, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> Participant: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + # ensure the speaker is unique + for p in transcript.participants: + if p.speaker == participant.speaker and p.id != participant_id: + raise HTTPException( + status_code=400, + detail="Speaker already assigned", + ) + + # find the participant + obj = None + for p in transcript.participants: + if p.id == participant_id: + obj = p + break + + if not obj: + raise HTTPException(status_code=404, detail="Participant not found") + + # update participant but just the fields that are set + fields = participant.dict(exclude_unset=True) + obj = obj.copy(update=fields) + + await transcripts_controller.upsert_participant(transcript, obj) + return Participant.model_validate(obj) + + +@router.delete("/transcripts/{transcript_id}/participants/{participant_id}") +async def transcript_delete_participant( + transcript_id: str, + participant_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +) -> DeletionStatus: + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + await transcripts_controller.delete_participant(transcript, participant_id) + return DeletionStatus(status="ok") diff --git a/server/reflector/views/transcripts_webrtc.py b/server/reflector/views/transcripts_webrtc.py new file mode 100644 index 00000000..af451411 --- /dev/null +++ b/server/reflector/views/transcripts_webrtc.py @@ -0,0 +1,37 @@ +from typing import Annotated, Optional + +import reflector.auth as auth +from fastapi import APIRouter, Depends, HTTPException, Request +from reflector.db.transcripts import transcripts_controller + +from .rtc_offer import RtcOffer, rtc_offer_base + +router = APIRouter() + + +@router.post("/transcripts/{transcript_id}/record/webrtc") +async def transcript_record_webrtc( + transcript_id: str, + params: RtcOffer, + request: Request, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +): + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id_for_http( + transcript_id, user_id=user_id + ) + + if transcript.locked: + raise HTTPException(status_code=400, detail="Transcript is locked") + + # create a pipeline runner + from reflector.pipelines.main_live_pipeline import PipelineMainLive + + pipeline_runner = PipelineMainLive(transcript_id=transcript_id) + + # FIXME do not allow multiple recording at the same time + return await rtc_offer_base( + params, + request, + pipeline_runner=pipeline_runner, + ) diff --git a/server/reflector/views/transcripts_websocket.py b/server/reflector/views/transcripts_websocket.py new file mode 100644 index 00000000..65571aab --- /dev/null +++ b/server/reflector/views/transcripts_websocket.py @@ -0,0 +1,53 @@ +""" +Transcripts websocket API +========================= + +""" +from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect +from reflector.db.transcripts import transcripts_controller +from reflector.ws_manager import get_ws_manager + +router = APIRouter() + + 
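+# Example client for this event stream, as a sketch: the `websockets` package
+# and the host/port below are assumptions, not part of this module; point the
+# URL at your deployment's websocket endpoint (the /v1 prefix matches the one
+# used by the frontend's useWebSockets hook):
+#
+#   import asyncio
+#   import json
+#   import websockets
+#
+#   WEBSOCKET_URL = "ws://localhost:1250"  # assumption, deployment-specific
+#
+#   async def listen(transcript_id: str) -> None:
+#       uri = f"{WEBSOCKET_URL}/v1/transcripts/{transcript_id}/events"
+#       async with websockets.connect(uri) as ws:
+#           while True:
+#               print(json.loads(await ws.recv()))
+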
+@router.get("/transcripts/{transcript_id}/events")
+async def transcript_get_websocket_events(transcript_id: str):
+    pass
+
+
+@router.websocket("/transcripts/{transcript_id}/events")
+async def transcript_events_websocket(
+    transcript_id: str,
+    websocket: WebSocket,
+    # user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+):
+    # user_id = user["sub"] if user else None
+    transcript = await transcripts_controller.get_by_id(transcript_id)
+    if not transcript:
+        raise HTTPException(status_code=404, detail="Transcript not found")
+
+    # connect to the websocket manager, using ts:transcript_id as the room id
+    room_id = f"ts:{transcript_id}"
+    ws_manager = get_ws_manager()
+    await ws_manager.add_user_to_room(room_id, websocket)
+
+    try:
+        # on first connection, send all events only to the current user
+        for event in transcript.events:
+            # for now, do not send TRANSCRIPT or STATUS events - these are live
+            # events that do not need to be replayed to the client; keep the rest
+            name = event.event
+            if name in ("TRANSCRIPT", "STATUS"):
+                continue
+            await websocket.send_json(event.model_dump(mode="json"))
+
+        # XXX if transcript is final (locked=True and status=ended)
+        # XXX send a final event to the client and close the connection
+
+        # endless loop to wait for new events;
+        # there is no command system yet
+        while True:
+            await websocket.receive()
+    except (RuntimeError, WebSocketDisconnect):
+        await ws_manager.remove_user_from_room(room_id, websocket)
diff --git a/server/reflector/views/types.py b/server/reflector/views/types.py
new file mode 100644
index 00000000..70361131
--- /dev/null
+++ b/server/reflector/views/types.py
@@ -0,0 +1,5 @@
+from pydantic import BaseModel
+
+
+class DeletionStatus(BaseModel):
+    status: str
diff --git a/server/reflector/worker/app.py b/server/reflector/worker/app.py
index e1000364..689623ce 100644
--- a/server/reflector/worker/app.py
+++ b/server/reflector/worker/app.py
@@ -1,6 +1,8 @@
+import structlog
 from celery import Celery
 from reflector.settings import settings
 
+logger = structlog.get_logger(__name__)
 app = Celery(__name__)
 app.conf.broker_url = settings.CELERY_BROKER_URL
 app.conf.result_backend = settings.CELERY_RESULT_BACKEND
@@ -8,5 +10,18 @@ app.conf.broker_connection_retry_on_startup = True
 app.autodiscover_tasks(
     [
         "reflector.pipelines.main_live_pipeline",
+        "reflector.worker.healthcheck",
     ]
 )
+
+# crontab
+app.conf.beat_schedule = {}
+
+if settings.HEALTHCHECK_URL:
+    app.conf.beat_schedule["healthcheck_ping"] = {
+        "task": "reflector.worker.healthcheck.healthcheck_ping",
+        "schedule": 60.0 * 10,  # every 10 minutes
+    }
+    logger.info("Healthcheck enabled", url=settings.HEALTHCHECK_URL)
+else:
+    logger.warning("Healthcheck disabled, no url configured")
diff --git a/server/reflector/worker/healthcheck.py b/server/reflector/worker/healthcheck.py
new file mode 100644
index 00000000..e4ce6bc3
--- /dev/null
+++ b/server/reflector/worker/healthcheck.py
@@ -0,0 +1,18 @@
+import httpx
+import structlog
+from celery import shared_task
+from reflector.settings import settings
+
+logger = structlog.get_logger(__name__)
+
+
+@shared_task
+def healthcheck_ping():
+    url = settings.HEALTHCHECK_URL
+    if not url:
+        return
+    try:
+        logger.info("pinging healthcheck url", url=url)
+        httpx.get(url, timeout=10)
+    except Exception as e:
+        logger.error("healthcheck_ping", error=str(e))
diff --git a/server/runserver.sh b/server/runserver.sh
index b0c3f138..31cce123 100755
--- a/server/runserver.sh
+++ b/server/runserver.sh
@@ -9,6 +9,8 @@ if [ "${ENTRYPOINT}" = "server" ]; then
python -m reflector.app elif [ "${ENTRYPOINT}" = "worker" ]; then celery -A reflector.worker.app worker --loglevel=info +elif [ "${ENTRYPOINT}" = "beat" ]; then + celery -A reflector.worker.app beat --loglevel=info else echo "Unknown command" fi diff --git a/server/tests/conftest.py b/server/tests/conftest.py index aafca9fd..532ebff9 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -1,4 +1,5 @@ from unittest.mock import patch +from tempfile import NamedTemporaryFile import pytest @@ -7,7 +8,6 @@ import pytest @pytest.mark.asyncio async def setup_database(): from reflector.settings import settings - from tempfile import NamedTemporaryFile with NamedTemporaryFile() as f: settings.DATABASE_URL = f"sqlite:///{f.name}" @@ -103,6 +103,25 @@ async def dummy_llm(): yield +@pytest.fixture +async def dummy_storage(): + from reflector.storage.base import Storage + + class DummyStorage(Storage): + async def _put_file(self, *args, **kwargs): + pass + + async def _delete_file(self, *args, **kwargs): + pass + + async def _get_file_url(self, *args, **kwargs): + return "http://fake_server/audio.mp3" + + with patch("reflector.storage.base.Storage.get_instance") as mock_storage: + mock_storage.return_value = DummyStorage() + yield + + @pytest.fixture def nltk(): with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk: @@ -133,4 +152,17 @@ def celery_enable_logging(): @pytest.fixture(scope="session") def celery_config(): - return {"broker_url": "memory://", "result_backend": "rpc"} + with NamedTemporaryFile() as f: + yield { + "broker_url": "memory://", + "result_backend": f"db+sqlite:///{f.name}", + } + + +@pytest.fixture(scope="session") +def fake_mp3_upload(): + with patch( + "reflector.db.transcripts.TranscriptController.move_mp3_to_storage" + ) as mock_move: + mock_move.return_value = True + yield diff --git a/server/tests/test_processor_audio_diarization.py b/server/tests/test_processor_audio_diarization.py new file mode 100644 index 00000000..00935a49 --- /dev/null +++ b/server/tests/test_processor_audio_diarization.py @@ -0,0 +1,140 @@ +import pytest +from unittest import mock + + +@pytest.mark.parametrize( + "name,diarization,expected", + [ + [ + "no overlap", + [ + {"start": 0.0, "end": 1.0, "speaker": "A"}, + {"start": 1.0, "end": 2.0, "speaker": "B"}, + ], + ["A", "A", "B", "B"], + ], + [ + "same speaker", + [ + {"start": 0.0, "end": 1.0, "speaker": "A"}, + {"start": 1.0, "end": 2.0, "speaker": "A"}, + ], + ["A", "A", "A", "A"], + ], + [ + # first segment is removed because it overlap + # with the second segment, and it is smaller + "overlap at 0.5s", + [ + {"start": 0.0, "end": 1.0, "speaker": "A"}, + {"start": 0.5, "end": 2.0, "speaker": "B"}, + ], + ["B", "B", "B", "B"], + ], + [ + "junk segment at 0.5s for 0.2s", + [ + {"start": 0.0, "end": 1.0, "speaker": "A"}, + {"start": 0.5, "end": 0.7, "speaker": "B"}, + {"start": 1, "end": 2.0, "speaker": "B"}, + ], + ["A", "A", "B", "B"], + ], + [ + "start without diarization", + [ + {"start": 0.5, "end": 1.0, "speaker": "A"}, + {"start": 1.0, "end": 2.0, "speaker": "B"}, + ], + ["A", "A", "B", "B"], + ], + [ + "end missing diarization", + [ + {"start": 0.0, "end": 1.0, "speaker": "A"}, + {"start": 1.0, "end": 1.5, "speaker": "B"}, + ], + ["A", "A", "B", "B"], + ], + [ + "continuation of next speaker", + [ + {"start": 0.0, "end": 0.9, "speaker": "A"}, + {"start": 1.5, "end": 2.0, "speaker": "B"}, + ], + ["A", "A", "B", "B"], + ], + [ + "continuation of previous speaker", + [ + {"start": 0.0, "end": 0.5, "speaker": "A"}, + 
{"start": 1.0, "end": 2.0, "speaker": "B"}, + ], + ["A", "A", "B", "B"], + ], + [ + "segment without words", + [ + {"start": 0.0, "end": 1.0, "speaker": "A"}, + {"start": 1.0, "end": 2.0, "speaker": "B"}, + {"start": 2.0, "end": 3.0, "speaker": "X"}, + ], + ["A", "A", "B", "B"], + ], + ], +) +@pytest.mark.asyncio +async def test_processors_audio_diarization(event_loop, name, diarization, expected): + from reflector.processors.audio_diarization import AudioDiarizationProcessor + from reflector.processors.types import ( + TitleSummaryWithId, + Transcript, + Word, + AudioDiarizationInput, + ) + + # create fake topic + topics = [ + TitleSummaryWithId( + id="1", + title="Title1", + summary="Summary1", + timestamp=0.0, + duration=1.0, + transcript=Transcript( + words=[ + Word(text="Word1", start=0.0, end=0.5), + Word(text="word2.", start=0.5, end=1.0), + ] + ), + ), + TitleSummaryWithId( + id="2", + title="Title2", + summary="Summary2", + timestamp=0.0, + duration=1.0, + transcript=Transcript( + words=[ + Word(text="Word3", start=1.0, end=1.5), + Word(text="word4.", start=1.5, end=2.0), + ] + ), + ), + ] + + diarizer = AudioDiarizationProcessor() + with mock.patch.object(diarizer, "_diarize") as mock_diarize: + mock_diarize.return_value = diarization + + data = AudioDiarizationInput( + audio_url="https://example.com/audio.mp3", + topics=topics, + ) + await diarizer._push(data) + + # check that the speaker has been assigned to the words + assert topics[0].transcript.words[0].speaker == expected[0] + assert topics[0].transcript.words[1].speaker == expected[1] + assert topics[1].transcript.words[0].speaker == expected[2] + assert topics[1].transcript.words[1].speaker == expected[3] diff --git a/server/tests/test_transcripts_audio_download.py b/server/tests/test_transcripts_audio_download.py index 69ae5f65..28f83fff 100644 --- a/server/tests/test_transcripts_audio_download.py +++ b/server/tests/test_transcripts_audio_download.py @@ -118,15 +118,3 @@ async def test_transcript_audio_download_range_with_seek( assert response.status_code == 206 assert response.headers["content-type"] == content_type assert response.headers["content-range"].startswith("bytes 100-") - - -@pytest.mark.asyncio -async def test_transcript_audio_download_waveform(fake_transcript): - from reflector.app import app - - ac = AsyncClient(app=app, base_url="http://test/v1") - response = await ac.get(f"/transcripts/{fake_transcript.id}/audio/waveform") - assert response.status_code == 200 - assert response.headers["content-type"] == "application/json" - assert isinstance(response.json()["data"], list) - assert len(response.json()["data"]) >= 255 diff --git a/server/tests/test_transcripts_participants.py b/server/tests/test_transcripts_participants.py new file mode 100644 index 00000000..b55b16a8 --- /dev/null +++ b/server/tests/test_transcripts_participants.py @@ -0,0 +1,164 @@ +import pytest +from httpx import AsyncClient + + +@pytest.mark.asyncio +async def test_transcript_participants(): + from reflector.app import app + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + response = await ac.post("/transcripts", json={"name": "test"}) + assert response.status_code == 200 + assert response.json()["participants"] == [] + + # create a participant + transcript_id = response.json()["id"] + response = await ac.post( + f"/transcripts/{transcript_id}/participants", json={"name": "test"} + ) + assert response.status_code == 200 + assert response.json()["id"] is not None + assert response.json()["speaker"] is None + assert 
response.json()["name"] == "test" + + # create another one with a speaker + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test2", "speaker": 1}, + ) + assert response.status_code == 200 + assert response.json()["id"] is not None + assert response.json()["speaker"] == 1 + assert response.json()["name"] == "test2" + + # get all participants via transcript + response = await ac.get(f"/transcripts/{transcript_id}") + assert response.status_code == 200 + assert len(response.json()["participants"]) == 2 + + # get participants via participants endpoint + response = await ac.get(f"/transcripts/{transcript_id}/participants") + assert response.status_code == 200 + assert len(response.json()) == 2 + + +@pytest.mark.asyncio +async def test_transcript_participants_same_speaker(): + from reflector.app import app + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + response = await ac.post("/transcripts", json={"name": "test"}) + assert response.status_code == 200 + assert response.json()["participants"] == [] + transcript_id = response.json()["id"] + + # create a participant + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test", "speaker": 1}, + ) + assert response.status_code == 200 + assert response.json()["speaker"] == 1 + + # create another one with the same speaker + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test2", "speaker": 1}, + ) + assert response.status_code == 400 + + +@pytest.mark.asyncio +async def test_transcript_participants_update_name(): + from reflector.app import app + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + response = await ac.post("/transcripts", json={"name": "test"}) + assert response.status_code == 200 + assert response.json()["participants"] == [] + transcript_id = response.json()["id"] + + # create a participant + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test", "speaker": 1}, + ) + assert response.status_code == 200 + assert response.json()["speaker"] == 1 + + # update the participant + participant_id = response.json()["id"] + response = await ac.patch( + f"/transcripts/{transcript_id}/participants/{participant_id}", + json={"name": "test2"}, + ) + assert response.status_code == 200 + assert response.json()["name"] == "test2" + + # verify the participant was updated + response = await ac.get( + f"/transcripts/{transcript_id}/participants/{participant_id}" + ) + assert response.status_code == 200 + assert response.json()["name"] == "test2" + + # verify the participant was updated in transcript + response = await ac.get(f"/transcripts/{transcript_id}") + assert response.status_code == 200 + assert len(response.json()["participants"]) == 1 + assert response.json()["participants"][0]["name"] == "test2" + + +@pytest.mark.asyncio +async def test_transcript_participants_update_speaker(): + from reflector.app import app + + async with AsyncClient(app=app, base_url="http://test/v1") as ac: + response = await ac.post("/transcripts", json={"name": "test"}) + assert response.status_code == 200 + assert response.json()["participants"] == [] + transcript_id = response.json()["id"] + + # create a participant + response = await ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test", "speaker": 1}, + ) + assert response.status_code == 200 + participant1_id = response.json()["id"] + + # create another participant + response = await 
ac.post( + f"/transcripts/{transcript_id}/participants", + json={"name": "test2", "speaker": 2}, + ) + assert response.status_code == 200 + participant2_id = response.json()["id"] + + # update the participant, refused as speaker is already taken + response = await ac.patch( + f"/transcripts/{transcript_id}/participants/{participant2_id}", + json={"speaker": 1}, + ) + assert response.status_code == 400 + + # delete the participant 1 + response = await ac.delete( + f"/transcripts/{transcript_id}/participants/{participant1_id}" + ) + assert response.status_code == 200 + + # update the participant 2 again, should be accepted now + response = await ac.patch( + f"/transcripts/{transcript_id}/participants/{participant2_id}", + json={"speaker": 1}, + ) + assert response.status_code == 200 + + # ensure participant2 name is still there + response = await ac.get( + f"/transcripts/{transcript_id}/participants/{participant2_id}" + ) + assert response.status_code == 200 + assert response.json()["name"] == "test2" + assert response.json()["speaker"] == 1 diff --git a/server/tests/test_transcripts_rtc_ws.py b/server/tests/test_transcripts_rtc_ws.py index cf2ea304..8502a0d9 100644 --- a/server/tests/test_transcripts_rtc_ws.py +++ b/server/tests/test_transcripts_rtc_ws.py @@ -66,6 +66,8 @@ async def test_transcript_rtc_and_websocket( dummy_transcript, dummy_processors, dummy_diarization, + dummy_storage, + fake_mp3_upload, ensure_casing, appserver, sentence_tokenize, @@ -182,6 +184,16 @@ async def test_transcript_rtc_and_websocket( ev = events[eventnames.index("FINAL_TITLE")] assert ev["data"]["title"] == "LLM TITLE" + assert "WAVEFORM" in eventnames + ev = events[eventnames.index("WAVEFORM")] + assert isinstance(ev["data"]["waveform"], list) + assert len(ev["data"]["waveform"]) >= 250 + waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform") + assert waveform_resp.status_code == 200 + assert waveform_resp.headers["content-type"] == "application/json" + assert isinstance(waveform_resp.json()["data"], list) + assert len(waveform_resp.json()["data"]) >= 250 + # check status order statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"] assert statuses.index("recording") < statuses.index("processing") @@ -193,11 +205,12 @@ async def test_transcript_rtc_and_websocket( # check on the latest response that the audio duration is > 0 assert resp.json()["duration"] > 0 + assert "DURATION" in eventnames # check that audio/mp3 is available - resp = await ac.get(f"/transcripts/{tid}/audio/mp3") - assert resp.status_code == 200 - assert resp.headers["Content-Type"] == "audio/mpeg" + audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3") + assert audio_resp.status_code == 200 + assert audio_resp.headers["Content-Type"] == "audio/mpeg" @pytest.mark.usefixtures("celery_session_app") @@ -209,6 +222,8 @@ async def test_transcript_rtc_and_websocket_and_fr( dummy_transcript, dummy_processors, dummy_diarization, + dummy_storage, + fake_mp3_upload, ensure_casing, appserver, sentence_tokenize, diff --git a/www/app/(auth)/fiefWrapper.tsx b/www/app/(auth)/fiefWrapper.tsx index 187fef7c..bb38f5ee 100644 --- a/www/app/(auth)/fiefWrapper.tsx +++ b/www/app/(auth)/fiefWrapper.tsx @@ -1,11 +1,18 @@ "use client"; import { FiefAuthProvider } from "@fief/fief/nextjs/react"; +import { createContext } from "react"; -export default function FiefWrapper({ children }) { +export const CookieContext = createContext<{ hasAuthCookie: boolean }>({ + hasAuthCookie: false, +}); + +export default function FiefWrapper({ 
children, hasAuthCookie }) { return ( - - {children} - + + + {children} + + ); } diff --git a/www/app/[domain]/layout.tsx b/www/app/[domain]/layout.tsx index dbe5ed11..73cc4841 100644 --- a/www/app/[domain]/layout.tsx +++ b/www/app/[domain]/layout.tsx @@ -11,6 +11,9 @@ import About from "../(aboutAndPrivacy)/about"; import Privacy from "../(aboutAndPrivacy)/privacy"; import { DomainContextProvider } from "./domainContext"; import { getConfig } from "../lib/edgeConfig"; +import { ErrorBoundary } from "@sentry/nextjs"; +import { cookies } from "next/dist/client/components/headers"; +import { SESSION_COOKIE_NAME } from "../lib/fief"; const poppins = Poppins({ subsets: ["latin"], weight: ["200", "400", "600"] }); @@ -70,86 +73,89 @@ type LayoutProps = { export default async function RootLayout({ children, params }: LayoutProps) { const config = await getConfig(params.domain); const { requireLogin, privacy, browse } = config.features; + const hasAuthCookie = !!cookies().get(SESSION_COOKIE_NAME); return ( - + - - -
-
- {/* Logo on the left */} - - Reflector -
-

- Reflector -

-

- Capture the signal, not the noise -

-
- -
- {/* Text link on the right */}
+ "Something went really wrong"

}> + + +
+
+ {/* Logo on the left */} - Create + Reflector +
+

+ Reflector +

+

+ Capture the signal, not the noise +

+
- {browse ? ( - <> -  ·  - - Browse - - - ) : ( - <> - )} -  ·  - - {privacy ? ( - <> -  ·  - - - ) : ( - <> - )} - {requireLogin ? ( - <> -  ·  - - - ) : ( - <> - )} -
-
+
+ {/* Text link on the right */} + + Create + + {browse ? ( + <> +  ·  + + Browse + + + ) : ( + <> + )} +  ·  + + {privacy ? ( + <> +  ·  + + + ) : ( + <> + )} + {requireLogin ? ( + <> +  ·  + + + ) : ( + <> + )} +
+ - {children} -
-
+ {children} + + +
diff --git a/www/app/[domain]/transcripts/[transcriptId]/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/page.tsx index 50d3002c..94ff61c3 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/page.tsx @@ -5,15 +5,19 @@ import useTopics from "../useTopics"; import useWaveform from "../useWaveform"; import useMp3 from "../useMp3"; import { TopicList } from "../topicList"; -import Recorder from "../recorder"; import { Topic } from "../webSocketTypes"; -import React, { useState } from "react"; +import React, { useEffect, useState } from "react"; import "../../../styles/button.css"; import FinalSummary from "../finalSummary"; import ShareLink from "../shareLink"; import QRCode from "react-qr-code"; import TranscriptTitle from "../transcriptTitle"; import ShareModal from "./shareModal"; +import Player from "../player"; +import WaveformLoading from "../waveformLoading"; +import { useRouter } from "next/navigation"; +import { faSpinner } from "@fortawesome/free-solid-svg-icons"; +import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; type TranscriptDetails = { params: { @@ -21,26 +25,28 @@ type TranscriptDetails = { }; }; -const protectedPath = true; - export default function TranscriptDetails(details: TranscriptDetails) { const transcriptId = details.params.transcriptId; + const router = useRouter(); - const transcript = useTranscript(protectedPath, transcriptId); - const topics = useTopics(protectedPath, transcriptId); - const waveform = useWaveform(protectedPath, transcriptId); + const transcript = useTranscript(transcriptId); + const topics = useTopics(transcriptId); + const waveform = useWaveform(transcriptId); const useActiveTopic = useState(null); - const mp3 = useMp3(protectedPath, transcriptId); + const mp3 = useMp3(transcriptId); const [showModal, setShowModal] = useState(false); - if (transcript?.error /** || topics?.error || waveform?.error **/) { - return ( - - ); - } + useEffect(() => { + const statusToRedirect = ["idle", "recording", "processing"]; + if (statusToRedirect.includes(transcript.response?.status)) { + const newUrl = "/transcripts/" + details.params.transcriptId + "/record"; + // Shallow redirection does not work on NextJS 13 + // https://github.com/vercel/next.js/discussions/48110 + // https://github.com/vercel/next.js/discussions/49540 + router.push(newUrl, undefined); + // history.replaceState({}, "", newUrl); + } + }, [transcript.response?.status]); const fullTranscript = topics.topics @@ -90,79 +96,102 @@ export default function TranscriptDetails(details: TranscriptDetails) { **Next Meeting:** Scheduled for December 5, 2023, to review progress and finalize the new product launch details. `; - } - return ( - <> - {!transcriptId || transcript?.loading || topics?.loading ? ( - - ) : ( - <> - setShowModal(v)} - title={transcript?.response?.title} - summary={transcript?.response?.longSummary} - date={transcript?.response?.createdAt} - url={window.location.href} - /> -
- {transcript?.response?.title && ( - - )} - {!waveform?.loading && ( - - )} -
-
- + ); + } + + if (!transcriptId || transcript?.loading || topics?.loading) { + return ; + } + + return ( + <> + setShowModal(v)} + title={transcript?.response?.title} + summary={transcript?.response?.longSummary} + date={transcript?.response?.createdAt} + url={window.location.href} + /> +
+ {transcript?.response?.title && ( + + )} + {waveform.waveform && mp3.media ? ( + -
-
- {transcript?.response?.longSummary && ( - setShowModal(true)} - /> - )} -
+ ) : waveform.error ? ( +
"error loading this recording"
+ ) : ( + + )} +
+
+ -
-
- +
+
+ {transcript.response.longSummary ? ( + setShowModal(true)} + /> + ) : ( +
+ {transcript.response.status == "processing" ? ( +

Loading Transcript

+ ) : ( +

There was an error generating the final summary. Please
come back later

+ )}
-
- -
-
-
+ )} +
+ +
+
+ +
+
+ +
+
- - )} - - ); +
+ + ); + } } diff --git a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx index 41a2d053..8615a4b1 100644 --- a/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx +++ b/www/app/[domain]/transcripts/[transcriptId]/record/page.tsx @@ -8,12 +8,15 @@ import { useWebSockets } from "../../useWebSockets"; import useAudioDevice from "../../useAudioDevice"; import "../../../../styles/button.css"; import { Topic } from "../../webSocketTypes"; -import getApi from "../../../../lib/getApi"; import LiveTrancription from "../../liveTranscription"; import DisconnectedIndicator from "../../disconnectedIndicator"; import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; import { faGear } from "@fortawesome/free-solid-svg-icons"; import { lockWakeState, releaseWakeState } from "../../../../lib/wakeLock"; +import { useRouter } from "next/navigation"; +import Player from "../../player"; +import useMp3 from "../../useMp3"; +import WaveformLoading from "../../waveformLoading"; type TranscriptDetails = { params: { @@ -36,14 +39,18 @@ const TranscriptRecord = (details: TranscriptDetails) => { } }, []); - const transcript = useTranscript(true, details.params.transcriptId); - const webRTC = useWebRTC(stream, details.params.transcriptId, true); + const transcript = useTranscript(details.params.transcriptId); + const webRTC = useWebRTC(stream, details.params.transcriptId); const webSockets = useWebSockets(details.params.transcriptId); const { audioDevices, getAudioStream } = useAudioDevice(); - const [hasRecorded, setHasRecorded] = useState(false); + const [recordedTime, setRecordedTime] = useState(0); + const [startTime, setStartTime] = useState(0); const [transcriptStarted, setTranscriptStarted] = useState(false); + let mp3 = useMp3(details.params.transcriptId, true); + + const router = useRouter(); useEffect(() => { if (!transcriptStarted && webSockets.transcriptText.length !== 0) @@ -51,15 +58,27 @@ const TranscriptRecord = (details: TranscriptDetails) => { }, [webSockets.transcriptText]); useEffect(() => { - if (transcript?.response?.longSummary) { - const newUrl = `/transcripts/${transcript.response.id}`; + const statusToRedirect = ["ended", "error"]; + + //TODO if has no topic and is error, get back to new + if ( + statusToRedirect.includes(transcript.response?.status) || + statusToRedirect.includes(webSockets.status.value) + ) { + const newUrl = "/transcripts/" + details.params.transcriptId; // Shallow redirection does not work on NextJS 13 // https://github.com/vercel/next.js/discussions/48110 // https://github.com/vercel/next.js/discussions/49540 - // router.push(newUrl, undefined, { shallow: true }); - history.replaceState({}, "", newUrl); + router.replace(newUrl); + // history.replaceState({}, "", newUrl); + } // history.replaceState({}, "", newUrl); + }, [webSockets.status.value, transcript.response?.status]); + + useEffect(() => { + if (webSockets.duration) { + mp3.getNow(); } - }); + }, [webSockets.duration]); useEffect(() => { lockWakeState(); @@ -70,19 +89,31 @@ const TranscriptRecord = (details: TranscriptDetails) => { return ( <> - { - setStream(null); - setHasRecorded(true); - webRTC?.send(JSON.stringify({ cmd: "STOP" })); - }} - topics={webSockets.topics} - getAudioStream={getAudioStream} - useActiveTopic={useActiveTopic} - isPastMeeting={false} - audioDevices={audioDevices} - /> + {webSockets.waveform && webSockets.duration && mp3?.media ? ( + + ) : recordedTime ? 
( + + ) : ( + { + setStream(null); + setRecordedTime(Date.now() - startTime); + webRTC?.send(JSON.stringify({ cmd: "STOP" })); + }} + onRecord={() => { + setStartTime(Date.now()); + }} + getAudioStream={getAudioStream} + audioDevices={audioDevices} + /> + )}
{
- {!hasRecorded ? ( + {!recordedTime ? ( <> {transcriptStarted && (

Transcription

@@ -128,6 +159,7 @@ const TranscriptRecord = (details: TranscriptDetails) => { couple of minutes. Please do not navigate away from the page during this time.

+ {/* NTH If login required remove last sentence */}
)} diff --git a/www/app/[domain]/transcripts/createTranscript.ts b/www/app/[domain]/transcripts/createTranscript.ts index 0d96b8db..9ad1abe0 100644 --- a/www/app/[domain]/transcripts/createTranscript.ts +++ b/www/app/[domain]/transcripts/createTranscript.ts @@ -19,7 +19,7 @@ const useCreateTranscript = (): CreateTranscript => { const [loading, setLoading] = useState(false); const [error, setErrorState] = useState(null); const { setError } = useError(); - const api = getApi(true); + const api = getApi(); const create = (params: V1TranscriptsCreateRequest["createTranscript"]) => { if (loading || !api) return; diff --git a/www/app/[domain]/transcripts/finalSummary.tsx b/www/app/[domain]/transcripts/finalSummary.tsx index c79ea763..5ac1e21c 100644 --- a/www/app/[domain]/transcripts/finalSummary.tsx +++ b/www/app/[domain]/transcripts/finalSummary.tsx @@ -5,7 +5,6 @@ import "../../styles/markdown.css"; import getApi from "../../lib/getApi"; type FinalSummaryProps = { - protectedPath: boolean; summary: string; fullTranscript: string; transcriptId: string; @@ -19,7 +18,7 @@ export default function FinalSummary(props: FinalSummaryProps) { const [isEditMode, setIsEditMode] = useState(false); const [preEditSummary, setPreEditSummary] = useState(props.summary); const [editedSummary, setEditedSummary] = useState(props.summary); - const api = getApi(props.protectedPath); + const api = getApi(); const updateSummary = async (newSummary: string, transcriptId: string) => { if (!api) return; @@ -88,7 +87,7 @@ export default function FinalSummary(props: FinalSummaryProps) {
diff --git a/www/app/[domain]/transcripts/player.tsx b/www/app/[domain]/transcripts/player.tsx new file mode 100644 index 00000000..02151a68 --- /dev/null +++ b/www/app/[domain]/transcripts/player.tsx @@ -0,0 +1,166 @@ +import React, { useRef, useEffect, useState } from "react"; + +import WaveSurfer from "wavesurfer.js"; +import CustomRegionsPlugin from "../../lib/custom-plugins/regions"; + +import { formatTime } from "../../lib/time"; +import { Topic } from "./webSocketTypes"; +import { AudioWaveform } from "../../api"; +import { waveSurferStyles } from "../../styles/recorder"; + +type PlayerProps = { + topics: Topic[]; + useActiveTopic: [ + Topic | null, + React.Dispatch>, + ]; + waveform: AudioWaveform["data"]; + media: HTMLMediaElement; + mediaDuration: number; +}; + +export default function Player(props: PlayerProps) { + const waveformRef = useRef(null); + const [wavesurfer, setWavesurfer] = useState(null); + const [isPlaying, setIsPlaying] = useState(false); + const [currentTime, setCurrentTime] = useState(0); + const [waveRegions, setWaveRegions] = useState( + null, + ); + const [activeTopic, setActiveTopic] = props.useActiveTopic; + const topicsRef = useRef(props.topics); + // Waveform setup + useEffect(() => { + if (waveformRef.current) { + // XXX duration is required to prevent recomputing peaks from audio + // However, the current waveform returns only the peaks, and no duration + // And the backend does not save duration properly. + // So at the moment, we deduct the duration from the topics. + // This is not ideal, but it works for now. + const _wavesurfer = WaveSurfer.create({ + container: waveformRef.current, + peaks: props.waveform, + hideScrollbar: true, + autoCenter: true, + barWidth: 2, + height: "auto", + duration: props.mediaDuration, + + ...waveSurferStyles.player, + }); + + // styling + const wsWrapper = _wavesurfer.getWrapper(); + wsWrapper.style.cursor = waveSurferStyles.playerStyle.cursor; + wsWrapper.style.backgroundColor = + waveSurferStyles.playerStyle.backgroundColor; + wsWrapper.style.borderRadius = waveSurferStyles.playerStyle.borderRadius; + + _wavesurfer.on("play", () => { + setIsPlaying(true); + }); + _wavesurfer.on("pause", () => { + setIsPlaying(false); + }); + _wavesurfer.on("timeupdate", setCurrentTime); + + setWaveRegions(_wavesurfer.registerPlugin(CustomRegionsPlugin.create())); + + _wavesurfer.toggleInteraction(true); + + _wavesurfer.setMediaElement(props.media); + + setWavesurfer(_wavesurfer); + + return () => { + _wavesurfer.destroy(); + setIsPlaying(false); + setCurrentTime(0); + }; + } + }, []); + + useEffect(() => { + if (!wavesurfer) return; + if (!props.media) return; + wavesurfer.setMediaElement(props.media); + }, [props.media, wavesurfer]); + + useEffect(() => { + topicsRef.current = props.topics; + renderMarkers(); + }, [props.topics, waveRegions]); + + const renderMarkers = () => { + if (!waveRegions) return; + + waveRegions.clearRegions(); + + for (let topic of topicsRef.current) { + const content = document.createElement("div"); + content.setAttribute("style", waveSurferStyles.marker); + content.onmouseover = () => { + content.style.backgroundColor = + waveSurferStyles.markerHover.backgroundColor; + content.style.zIndex = "999"; + content.style.width = "300px"; + }; + content.onmouseout = () => { + content.setAttribute("style", waveSurferStyles.marker); + }; + content.textContent = topic.title; + + const region = waveRegions.addRegion({ + start: topic.timestamp, + content, + color: "f00", + drag: false, + }); + region.on("click", (e) 
=> { + e.stopPropagation(); + setActiveTopic(topic); + wavesurfer?.setTime(region.start); + }); + } + }; + + useEffect(() => { + if (activeTopic) { + wavesurfer?.setTime(activeTopic.timestamp); + } + }, [activeTopic]); + + const handlePlayClick = () => { + wavesurfer?.playPause(); + }; + + const timeLabel = () => { + if (props.mediaDuration) + return `${formatTime(currentTime)}/${formatTime(props.mediaDuration)}`; + return ""; + }; + + return ( +
+
+
+
{timeLabel()}
+
+ + +
+ ); +} diff --git a/www/app/[domain]/transcripts/recorder.tsx b/www/app/[domain]/transcripts/recorder.tsx index 8db32ff7..e7c016a7 100644 --- a/www/app/[domain]/transcripts/recorder.tsx +++ b/www/app/[domain]/transcripts/recorder.tsx @@ -6,31 +6,19 @@ import CustomRegionsPlugin from "../../lib/custom-plugins/regions"; import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; import { faMicrophone } from "@fortawesome/free-solid-svg-icons"; -import { faDownload } from "@fortawesome/free-solid-svg-icons"; import { formatTime } from "../../lib/time"; -import { Topic } from "./webSocketTypes"; -import { AudioWaveform } from "../../api"; import AudioInputsDropdown from "./audioInputsDropdown"; import { Option } from "react-dropdown"; import { waveSurferStyles } from "../../styles/recorder"; import { useError } from "../../(errors)/errorContext"; type RecorderProps = { - setStream?: React.Dispatch>; - onStop?: () => void; - topics: Topic[]; - getAudioStream?: (deviceId) => Promise; - audioDevices?: Option[]; - useActiveTopic: [ - Topic | null, - React.Dispatch>, - ]; - waveform?: AudioWaveform | null; - isPastMeeting: boolean; - transcriptId?: string | null; - media?: HTMLMediaElement | null; - mediaDuration?: number | null; + setStream: React.Dispatch>; + onStop: () => void; + onRecord?: () => void; + getAudioStream: (deviceId) => Promise; + audioDevices: Option[]; }; export default function Recorder(props: RecorderProps) { @@ -38,7 +26,7 @@ export default function Recorder(props: RecorderProps) { const [wavesurfer, setWavesurfer] = useState(null); const [record, setRecord] = useState(null); const [isRecording, setIsRecording] = useState(false); - const [hasRecorded, setHasRecorded] = useState(props.isPastMeeting); + const [hasRecorded, setHasRecorded] = useState(false); const [isPlaying, setIsPlaying] = useState(false); const [currentTime, setCurrentTime] = useState(0); const [timeInterval, setTimeInterval] = useState(null); @@ -48,8 +36,6 @@ export default function Recorder(props: RecorderProps) { ); const [deviceId, setDeviceId] = useState(null); const [recordStarted, setRecordStarted] = useState(false); - const [activeTopic, setActiveTopic] = props.useActiveTopic; - const topicsRef = useRef(props.topics); const [showDevices, setShowDevices] = useState(false); const { setError } = useError(); @@ -73,8 +59,6 @@ export default function Recorder(props: RecorderProps) { if (!record.isRecording()) return; handleRecClick(); break; - case "^": - throw new Error("Unhandled Exception thrown by '^' shortcut"); case "(": location.href = "/login"; break; @@ -104,27 +88,18 @@ export default function Recorder(props: RecorderProps) { // Waveform setup useEffect(() => { if (waveformRef.current) { - // XXX duration is required to prevent recomputing peaks from audio - // However, the current waveform returns only the peaks, and no duration - // And the backend does not save duration properly. - // So at the moment, we deduct the duration from the topics. - // This is not ideal, but it works for now. 
const _wavesurfer = WaveSurfer.create({ container: waveformRef.current, - peaks: props.waveform?.data, hideScrollbar: true, autoCenter: true, barWidth: 2, height: "auto", - duration: props.mediaDuration || 1, ...waveSurferStyles.player, }); - if (!props.transcriptId) { - const _wshack: any = _wavesurfer; - _wshack.renderer.renderSingleCanvas = () => {}; - } + const _wshack: any = _wavesurfer; + _wshack.renderer.renderSingleCanvas = () => {}; // styling const wsWrapper = _wavesurfer.getWrapper(); @@ -144,12 +119,6 @@ export default function Recorder(props: RecorderProps) { setRecord(_wavesurfer.registerPlugin(RecordPlugin.create())); setWaveRegions(_wavesurfer.registerPlugin(CustomRegionsPlugin.create())); - if (props.isPastMeeting) _wavesurfer.toggleInteraction(true); - - if (props.media) { - _wavesurfer.setMediaElement(props.media); - } - setWavesurfer(_wavesurfer); return () => { @@ -161,58 +130,6 @@ export default function Recorder(props: RecorderProps) { } }, []); - useEffect(() => { - if (!wavesurfer) return; - if (!props.media) return; - wavesurfer.setMediaElement(props.media); - }, [props.media, wavesurfer]); - - useEffect(() => { - topicsRef.current = props.topics; - if (!isRecording) renderMarkers(); - }, [props.topics, waveRegions]); - - const renderMarkers = () => { - if (!waveRegions) return; - - waveRegions.clearRegions(); - - for (let topic of topicsRef.current) { - const content = document.createElement("div"); - content.setAttribute("style", waveSurferStyles.marker); - content.onmouseover = () => { - content.style.backgroundColor = - waveSurferStyles.markerHover.backgroundColor; - content.style.zIndex = "999"; - content.style.width = "300px"; - }; - content.onmouseout = () => { - content.setAttribute("style", waveSurferStyles.marker); - }; - content.textContent = topic.title; - - const region = waveRegions.addRegion({ - start: topic.timestamp, - content, - color: "f00", - drag: false, - }); - region.on("click", (e) => { - e.stopPropagation(); - setActiveTopic(topic); - wavesurfer?.setTime(region.start); - }); - } - }; - - useEffect(() => { - if (!record) return; - - return record.on("stopRecording", () => { - renderMarkers(); - }); - }, [record]); - useEffect(() => { if (isRecording) { const interval = window.setInterval(() => { @@ -229,12 +146,6 @@ export default function Recorder(props: RecorderProps) { } }, [isRecording]); - useEffect(() => { - if (activeTopic) { - wavesurfer?.setTime(activeTopic.timestamp); - } - }, [activeTopic]); - const handleRecClick = async () => { if (!record) return console.log("no record"); @@ -249,10 +160,10 @@ export default function Recorder(props: RecorderProps) { setScreenMediaStream(null); setDestinationStream(null); } else { + if (props.onRecord) props.onRecord(); const stream = await getCurrentStream(); if (props.setStream) props.setStream(stream); - waveRegions?.clearRegions(); if (stream) { await record.startRecording(stream); setIsRecording(true); @@ -320,7 +231,6 @@ export default function Recorder(props: RecorderProps) { if (!record) return; if (!destinationStream) return; if (props.setStream) props.setStream(destinationStream); - waveRegions?.clearRegions(); if (destinationStream) { record.startRecording(destinationStream); setIsRecording(true); @@ -379,23 +289,9 @@ export default function Recorder(props: RecorderProps) { } text-white ml-2 md:ml:4 md:h-[78px] md:min-w-[100px] text-lg`} id="play-btn" onClick={handlePlayClick} - disabled={isRecording} > {isPlaying ? 
"Pause" : "Play"} - - {props.transcriptId && ( - - - - )} )} {!hasRecorded && ( diff --git a/www/app/[domain]/transcripts/shareLink.tsx b/www/app/[domain]/transcripts/shareLink.tsx index 49163a5b..dd66d6cb 100644 --- a/www/app/[domain]/transcripts/shareLink.tsx +++ b/www/app/[domain]/transcripts/shareLink.tsx @@ -1,15 +1,39 @@ import React, { useState, useRef, useEffect, use } from "react"; import { featureEnabled } from "../domainContext"; +import getApi from "../../lib/getApi"; +import { useFiefUserinfo } from "@fief/fief/nextjs/react"; +import SelectSearch from "react-select-search"; +import "react-select-search/style.css"; +import "../../styles/button.css"; +import "../../styles/form.scss"; +import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; +import { faSpinner } from "@fortawesome/free-solid-svg-icons"; -const ShareLink = () => { +type ShareLinkProps = { + transcriptId: string; + userId: string | null; + shareMode: string; +}; + +const ShareLink = (props: ShareLinkProps) => { const [isCopied, setIsCopied] = useState(false); const inputRef = useRef(null); const [currentUrl, setCurrentUrl] = useState(""); + const requireLogin = featureEnabled("requireLogin"); + const [isOwner, setIsOwner] = useState(false); + const [shareMode, setShareMode] = useState(props.shareMode); + const [shareLoading, setShareLoading] = useState(false); + const userinfo = useFiefUserinfo(); + const api = getApi(); useEffect(() => { setCurrentUrl(window.location.href); }, []); + useEffect(() => { + setIsOwner(!!(requireLogin && userinfo?.sub === props.userId)); + }, [userinfo, props.userId]); + const handleCopyClick = () => { if (inputRef.current) { let text_to_copy = inputRef.current.value; @@ -23,6 +47,18 @@ const ShareLink = () => { } }; + const updateShareMode = async (selectedShareMode: string) => { + if (!api) return; + setShareLoading(true); + const updatedTranscript = await api.v1TranscriptUpdate({ + transcriptId: props.transcriptId, + updateTranscript: { + shareMode: selectedShareMode, + }, + }); + setShareMode(updatedTranscript.shareMode); + setShareLoading(false); + }; const privacyEnabled = featureEnabled("privacy"); return ( @@ -30,17 +66,60 @@ const ShareLink = () => { className="p-2 md:p-4 rounded" style={{ background: "rgba(96, 165, 250, 0.2)" }} > - {privacyEnabled ? ( -

- You can share this link with others. Anyone with the link will have - access to the page, including the full audio recording, for the next 7 - days. -

- ) : ( -

- You can share this link with others. Anyone with the link will have - access to the page, including the full audio recording. -

+ {requireLogin && ( +
+ {shareMode === "private" && ( +

This transcript is private and can only be accessed by you.

+ )} + {shareMode === "semi-private" && ( +

+ This transcript is secure. Only authenticated users can access it. +

+ )} + {shareMode === "public" && ( +

This transcript is public. Everyone can access it.

+ )} + + {isOwner && api && ( +
+ + {shareLoading && ( +
+ +
+ )} +
+ )} +
+ )} + {!requireLogin && ( + <> + {privacyEnabled ? ( +

+ Share this link to grant others access to this page. The link + includes the full audio recording and is valid for the next 7 + days. +

+ ) : ( +

+ Share this link to allow others to view this page and listen to + the full audio recording. +

+ )} + )}
  const [displayedTitle, setDisplayedTitle] = useState(props.title);
  const [preEditTitle, setPreEditTitle] = useState(props.title);
  const [isEditing, setIsEditing] = useState(false);
-  const api = getApi(props.protectedPath);
+  const api = getApi();
 
   const updateTitle = async (newTitle: string, transcriptId: string) => {
     if (!api) return;
diff --git a/www/app/[domain]/transcripts/useMp3.ts b/www/app/[domain]/transcripts/useMp3.ts
index 570a6a25..363a4190 100644
--- a/www/app/[domain]/transcripts/useMp3.ts
+++ b/www/app/[domain]/transcripts/useMp3.ts
@@ -1,49 +1,48 @@
 import { useContext, useEffect, useState } from "react";
-import { useError } from "../../(errors)/errorContext";
 import { DomainContext } from "../domainContext";
 import getApi from "../../lib/getApi";
 import { useFiefAccessTokenInfo } from "@fief/fief/build/esm/nextjs/react";
-import { shouldShowError } from "../../lib/errorUtils";
 
-type Mp3Response = {
-  url: string | null;
+export type Mp3Response = {
   media: HTMLMediaElement | null;
   loading: boolean;
-  error: Error | null;
+  getNow: () => void;
 };
 
-const useMp3 = (protectedPath: boolean, id: string): Mp3Response => {
-  const [url, setUrl] = useState(null);
+const useMp3 = (id: string, waiting?: boolean): Mp3Response => {
   const [media, setMedia] = useState(null);
+  const [later, setLater] = useState(waiting);
   const [loading, setLoading] = useState(false);
-  const [error, setErrorState] = useState(null);
-  const { setError } = useError();
-  const api = getApi(protectedPath);
+  const api = getApi();
   const { api_url } = useContext(DomainContext);
   const accessTokenInfo = useFiefAccessTokenInfo();
-  const [serviceWorkerReady, setServiceWorkerReady] = useState(false);
+  const [serviceWorker, setServiceWorker] =
+    useState(null);
 
   useEffect(() => {
     if ("serviceWorker" in navigator) {
-      navigator.serviceWorker.register("/service-worker.js").then(() => {
-        setServiceWorkerReady(true);
+      navigator.serviceWorker.register("/service-worker.js").then((worker) => {
+        setServiceWorker(worker);
       });
     }
+    return () => {
+      serviceWorker?.unregister();
+    };
   }, []);
 
   useEffect(() => {
     if (!navigator.serviceWorker) return;
     if (!navigator.serviceWorker.controller) return;
-    if (!serviceWorkerReady) return;
+    if (!serviceWorker) return;
 
     // Send the token to the service worker
     navigator.serviceWorker.controller.postMessage({
       type: "SET_AUTH_TOKEN",
       token: accessTokenInfo?.access_token,
     });
-  }, [navigator.serviceWorker, serviceWorkerReady, accessTokenInfo]);
+  }, [navigator.serviceWorker, !serviceWorker, accessTokenInfo]);
 
-  const getMp3 = (id: string) => {
-    if (!id || !api) return;
+  useEffect(() => {
+    if (!id || !api || later) return;
 
     // create an audio element and set the source
     setLoading(true);
@@ -53,13 +52,13 @@ const useMp3 = (protectedPath: boolean, id: string): Mp3Response => {
     audioElement.preload = "auto";
     setMedia(audioElement);
     setLoading(false);
+  }, [id, api, later]);
+
+  const getNow = () => {
+    setLater(false);
+  };
 
-  useEffect(() => {
-    getMp3(id);
-  }, [id, api]);
-
-  return { url, media, loading, error };
+  return { media, loading, getNow };
 };
 
 export default useMp3;
diff --git a/www/app/[domain]/transcripts/useTopics.ts b/www/app/[domain]/transcripts/useTopics.ts
index 01053019..de4097b3 100644
--- a/www/app/[domain]/transcripts/useTopics.ts
+++ b/www/app/[domain]/transcripts/useTopics.ts
@@ -14,12 +14,12 @@ type TranscriptTopics = {
   error: Error | null;
 };
 
-const useTopics = (protectedPath, id: string): TranscriptTopics => {
+const useTopics = (id: string): TranscriptTopics => {
const [topics, setTopics] = useState(null); const [loading, setLoading] = useState(false); const [error, setErrorState] = useState(null); const { setError } = useError(); - const api = getApi(protectedPath); + const api = getApi(); useEffect(() => { if (!id || !api) return; diff --git a/www/app/[domain]/transcripts/useTranscript.ts b/www/app/[domain]/transcripts/useTranscript.ts index af60cd3b..91700d7a 100644 --- a/www/app/[domain]/transcripts/useTranscript.ts +++ b/www/app/[domain]/transcripts/useTranscript.ts @@ -5,21 +5,32 @@ import { useError } from "../../(errors)/errorContext"; import getApi from "../../lib/getApi"; import { shouldShowError } from "../../lib/errorUtils"; -type Transcript = { - response: GetTranscript | null; - loading: boolean; - error: Error | null; +type ErrorTranscript = { + error: Error; + loading: false; + response: any; +}; + +type LoadingTranscript = { + response: any; + loading: true; + error: false; +}; + +type SuccessTranscript = { + response: GetTranscript; + loading: false; + error: null; }; const useTranscript = ( - protectedPath: boolean, id: string | null, -): Transcript => { +): ErrorTranscript | LoadingTranscript | SuccessTranscript => { const [response, setResponse] = useState(null); const [loading, setLoading] = useState(true); const [error, setErrorState] = useState(null); const { setError } = useError(); - const api = getApi(protectedPath); + const api = getApi(); useEffect(() => { if (!id || !api) return; @@ -46,7 +57,10 @@ const useTranscript = ( }); }, [id, !api]); - return { response, loading, error }; + return { response, loading, error } as + | ErrorTranscript + | LoadingTranscript + | SuccessTranscript; }; export default useTranscript; diff --git a/www/app/[domain]/transcripts/useTranscriptList.ts b/www/app/[domain]/transcripts/useTranscriptList.ts index cc8f4701..7b5abb37 100644 --- a/www/app/[domain]/transcripts/useTranscriptList.ts +++ b/www/app/[domain]/transcripts/useTranscriptList.ts @@ -15,7 +15,7 @@ const useTranscriptList = (page: number): TranscriptList => { const [loading, setLoading] = useState(true); const [error, setErrorState] = useState(null); const { setError } = useError(); - const api = getApi(true); + const api = getApi(); useEffect(() => { if (!api) return; diff --git a/www/app/[domain]/transcripts/useWaveform.ts b/www/app/[domain]/transcripts/useWaveform.ts index 4073b711..f80ad78c 100644 --- a/www/app/[domain]/transcripts/useWaveform.ts +++ b/www/app/[domain]/transcripts/useWaveform.ts @@ -1,8 +1,5 @@ import { useEffect, useState } from "react"; -import { - DefaultApi, - V1TranscriptGetAudioWaveformRequest, -} from "../../api/apis/DefaultApi"; +import { V1TranscriptGetAudioWaveformRequest } from "../../api/apis/DefaultApi"; import { AudioWaveform } from "../../api"; import { useError } from "../../(errors)/errorContext"; import getApi from "../../lib/getApi"; @@ -14,12 +11,12 @@ type AudioWaveFormResponse = { error: Error | null; }; -const useWaveform = (protectedPath, id: string): AudioWaveFormResponse => { +const useWaveform = (id: string): AudioWaveFormResponse => { const [waveform, setWaveform] = useState(null); const [loading, setLoading] = useState(true); const [error, setErrorState] = useState(null); const { setError } = useError(); - const api = getApi(protectedPath); + const api = getApi(); useEffect(() => { if (!id || !api) return; diff --git a/www/app/[domain]/transcripts/useWebRTC.ts b/www/app/[domain]/transcripts/useWebRTC.ts index f4421e4d..edd3bef0 100644 --- 
--- a/www/app/[domain]/transcripts/useWebRTC.ts
+++ b/www/app/[domain]/transcripts/useWebRTC.ts
@@ -10,11 +10,10 @@ import getApi from "../../lib/getApi";
 const useWebRTC = (
   stream: MediaStream | null,
   transcriptId: string | null,
-  protectedPath,
 ): Peer => {
   const [peer, setPeer] = useState(null);
   const { setError } = useError();
-  const api = getApi(protectedPath);
+  const api = getApi();
 
   useEffect(() => {
     if (!stream || !transcriptId) {
diff --git a/www/app/[domain]/transcripts/useWebSockets.ts b/www/app/[domain]/transcripts/useWebSockets.ts
index bcf6b163..1e59781c 100644
--- a/www/app/[domain]/transcripts/useWebSockets.ts
+++ b/www/app/[domain]/transcripts/useWebSockets.ts
@@ -1,30 +1,35 @@
 import { useContext, useEffect, useState } from "react";
 import { Topic, FinalSummary, Status } from "./webSocketTypes";
 import { useError } from "../../(errors)/errorContext";
-import { useRouter } from "next/navigation";
 import { DomainContext } from "../domainContext";
+import { AudioWaveform } from "../../api";
 
-type UseWebSockets = {
+export type UseWebSockets = {
   transcriptText: string;
   translateText: string;
+  title: string;
   topics: Topic[];
   finalSummary: FinalSummary;
   status: Status;
+  waveform: AudioWaveform["data"] | null;
+  duration: number | null;
 };
 
 export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
   const [transcriptText, setTranscriptText] = useState("");
   const [translateText, setTranslateText] = useState("");
+  const [title, setTitle] = useState("");
   const [textQueue, setTextQueue] = useState([]);
   const [translationQueue, setTranslationQueue] = useState([]);
   const [isProcessing, setIsProcessing] = useState(false);
   const [topics, setTopics] = useState([]);
+  const [waveform, setWaveForm] = useState(null);
+  const [duration, setDuration] = useState(null);
   const [finalSummary, setFinalSummary] = useState({
     summary: "",
   });
   const [status, setStatus] = useState({ value: "initial" });
   const { setError } = useError();
-  const router = useRouter();
 
   const { websocket_url } = useContext(DomainContext);
 
@@ -294,7 +299,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
     if (!transcriptId) return;
 
     const url = `${websocket_url}/v1/transcripts/${transcriptId}/events`;
-    const ws = new WebSocket(url);
+    let ws = new WebSocket(url);
 
     ws.onopen = () => {
       console.debug("WebSocket connection opened");
@@ -343,24 +348,39 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
 
         case "FINAL_TITLE":
           console.debug("FINAL_TITLE event:", message.data);
+          if (message.data) {
+            setTitle(message.data.title);
+          }
+          break;
+
+        case "WAVEFORM":
+          console.debug(
+            "WAVEFORM event length:",
+            message.data.waveform.length,
+          );
+          if (message.data) {
+            setWaveForm(message.data.waveform);
+          }
+          break;
+        case "DURATION":
+          console.debug("DURATION event:", message.data);
+          if (message.data) {
+            setDuration(message.data.duration);
+          }
           break;
 
         case "STATUS":
           console.log("STATUS event:", message.data);
-          if (message.data.value === "ended") {
-            const newUrl = "/transcripts/" + transcriptId;
-            router.push(newUrl);
-            console.debug("FINAL_LONG_SUMMARY event:", message.data);
-          }
           if (message.data.value === "error") {
-            const newUrl = "/transcripts/" + transcriptId;
-            router.push(newUrl);
             setError(
               Error("Websocket error status"),
               "There was an error processing this meeting.",
             );
           }
           setStatus(message.data);
+          if (message.data.value === "ended") {
+            ws.close();
+          }
           break;
 
         default:
@@ -382,13 +402,19 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
       console.debug("WebSocket connection closed");
       switch (event.code) {
         case 1000: // Normal Closure:
-        case 1001: // Going Away:
-        case 1005:
-          break;
+        case 1005: // Closure by client FF
         default:
           setError(
             new Error(`WebSocket closed unexpectedly with code: ${event.code}`),
+            "Disconnected",
           );
+          console.log(
+            "Socket is closed. Reconnect will be attempted in 1 second.",
+            event.reason,
+          );
+          setTimeout(function () {
+            ws = new WebSocket(url);
+          }, 1000);
       }
     };
 
@@ -397,5 +423,14 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
     };
   }, [transcriptId]);
 
-  return { transcriptText, translateText, topics, finalSummary, status };
+  return {
+    transcriptText,
+    translateText,
+    topics,
+    finalSummary,
+    title,
+    status,
+    waveform,
+    duration,
+  };
 };
diff --git a/www/app/[domain]/transcripts/waveformLoading.tsx b/www/app/[domain]/transcripts/waveformLoading.tsx
new file mode 100644
index 00000000..56540927
--- /dev/null
+++ b/www/app/[domain]/transcripts/waveformLoading.tsx
@@ -0,0 +1,11 @@
+import { faSpinner } from "@fortawesome/free-solid-svg-icons";
+import { FontAwesomeIcon } from "@fortawesome/react-fontawesome";
+
+export default () => (
+  <div>
+    <FontAwesomeIcon
+      icon={faSpinner}
+      spin
+    />
+  </div>
+); diff --git a/www/app/api/models/GetTranscript.ts b/www/app/api/models/GetTranscript.ts index 06e9c8ad..3c03c689 100644 --- a/www/app/api/models/GetTranscript.ts +++ b/www/app/api/models/GetTranscript.ts @@ -25,6 +25,12 @@ export interface GetTranscript { * @memberof GetTranscript */ id: any | null; + /** + * + * @type {any} + * @memberof GetTranscript + */ + userId: any | null; /** * * @type {any} @@ -73,6 +79,12 @@ export interface GetTranscript { * @memberof GetTranscript */ createdAt: any | null; + /** + * + * @type {any} + * @memberof GetTranscript + */ + shareMode?: any | null; /** * * @type {any} @@ -93,6 +105,7 @@ export interface GetTranscript { export function instanceOfGetTranscript(value: object): boolean { let isInstance = true; isInstance = isInstance && "id" in value; + isInstance = isInstance && "userId" in value; isInstance = isInstance && "name" in value; isInstance = isInstance && "status" in value; isInstance = isInstance && "locked" in value; @@ -120,6 +133,7 @@ export function GetTranscriptFromJSONTyped( } return { id: json["id"], + userId: json["user_id"], name: json["name"], status: json["status"], locked: json["locked"], @@ -128,6 +142,7 @@ export function GetTranscriptFromJSONTyped( shortSummary: json["short_summary"], longSummary: json["long_summary"], createdAt: json["created_at"], + shareMode: !exists(json, "share_mode") ? undefined : json["share_mode"], sourceLanguage: json["source_language"], targetLanguage: json["target_language"], }; @@ -142,6 +157,7 @@ export function GetTranscriptToJSON(value?: GetTranscript | null): any { } return { id: value.id, + user_id: value.userId, name: value.name, status: value.status, locked: value.locked, @@ -150,6 +166,7 @@ export function GetTranscriptToJSON(value?: GetTranscript | null): any { short_summary: value.shortSummary, long_summary: value.longSummary, created_at: value.createdAt, + share_mode: value.shareMode, source_language: value.sourceLanguage, target_language: value.targetLanguage, }; diff --git a/www/app/api/models/UpdateTranscript.ts b/www/app/api/models/UpdateTranscript.ts index d22df8b0..a710af69 100644 --- a/www/app/api/models/UpdateTranscript.ts +++ b/www/app/api/models/UpdateTranscript.ts @@ -49,6 +49,12 @@ export interface UpdateTranscript { * @memberof UpdateTranscript */ longSummary?: any | null; + /** + * + * @type {any} + * @memberof UpdateTranscript + */ + shareMode?: any | null; } /** @@ -81,6 +87,7 @@ export function UpdateTranscriptFromJSONTyped( longSummary: !exists(json, "long_summary") ? undefined : json["long_summary"], + shareMode: !exists(json, "share_mode") ? 
   };
 }
 
@@ -97,5 +104,6 @@ export function UpdateTranscriptToJSON(value?: UpdateTranscript | null): any {
     title: value.title,
     short_summary: value.shortSummary,
     long_summary: value.longSummary,
+    share_mode: value.shareMode,
   };
 }
diff --git a/www/app/lib/errorUtils.ts b/www/app/lib/errorUtils.ts
index 81a39b5d..e9e5300d 100644
--- a/www/app/lib/errorUtils.ts
+++ b/www/app/lib/errorUtils.ts
@@ -1,5 +1,8 @@
 function shouldShowError(error: Error | null | undefined) {
-  if (error?.name == "ResponseError" && error["response"].status == 404)
+  if (
+    error?.name == "ResponseError" &&
+    (error["response"].status == 404 || error["response"].status == 403)
+  )
     return false;
   if (error?.name == "FetchError") return false;
   return true;
diff --git a/www/app/lib/fief.ts b/www/app/lib/fief.ts
index 02db67f5..3af5c30f 100644
--- a/www/app/lib/fief.ts
+++ b/www/app/lib/fief.ts
@@ -66,10 +66,6 @@ export const getFiefAuthMiddleware = async (url) => {
       matcher: "/transcripts",
       parameters: {},
     },
-    {
-      matcher: "/transcripts/((?!new).*)",
-      parameters: {},
-    },
     {
       matcher: "/browse",
       parameters: {},
diff --git a/www/app/lib/getApi.ts b/www/app/lib/getApi.ts
index 7392cc90..e1ece2a9 100644
--- a/www/app/lib/getApi.ts
+++ b/www/app/lib/getApi.ts
@@ -4,17 +4,19 @@ import { DefaultApi } from "../api/apis/DefaultApi";
 import { useFiefAccessTokenInfo } from "@fief/fief/nextjs/react";
 import { useContext, useEffect, useState } from "react";
 import { DomainContext, featureEnabled } from "../[domain]/domainContext";
+import { CookieContext } from "../(auth)/fiefWrapper";
 
-export default function getApi(protectedPath: boolean): DefaultApi | undefined {
+export default function getApi(): DefaultApi | undefined {
   const accessTokenInfo = useFiefAccessTokenInfo();
   const api_url = useContext(DomainContext).api_url;
   const requireLogin = featureEnabled("requireLogin");
   const [api, setApi] = useState();
+  const { hasAuthCookie } = useContext(CookieContext);
 
   if (!api_url) throw new Error("no API URL");
 
   useEffect(() => {
-    if (protectedPath && requireLogin && !accessTokenInfo) {
+    if (hasAuthCookie && requireLogin && !accessTokenInfo) {
       return;
     }
 
@@ -25,7 +27,7 @@ export default function getApi(protectedPath: boolean): DefaultApi | undefined {
       : undefined,
     });
     setApi(new DefaultApi(apiConfiguration));
-  }, [!accessTokenInfo, protectedPath]);
+  }, [!accessTokenInfo, hasAuthCookie]);
 
   return api;
 }
diff --git a/www/app/styles/form.scss b/www/app/styles/form.scss
index 90eb4a83..da81f1db 100644
--- a/www/app/styles/form.scss
+++ b/www/app/styles/form.scss
@@ -35,3 +35,8 @@ body.is-light-mode .input-container {
   max-width: 100%;
   width: auto;
 }
+
+body .select-search-container .select-search--top.select-search-select {
+  top: auto;
+  bottom: 46px;
+}
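
Below the patch, a short consumer sketch may help reviewers see how the refactored hooks compose. This is a hypothetical example, not part of the diff: the `TranscriptView` component, its props, and the placeholder markup are illustrative assumptions; only `useTranscript`, `useMp3`, and their new signatures come from the changes above.

```tsx
// Hypothetical consumer of the refactored hooks (not included in this patch).
import useTranscript from "./useTranscript";
import useMp3 from "./useMp3";

export default function TranscriptView({ id }: { id: string }) {
  // useTranscript now returns ErrorTranscript | LoadingTranscript | SuccessTranscript.
  const transcript = useTranscript(id);

  // waiting=true defers the MP3 fetch; getNow() starts it on demand.
  const { media, loading, getNow } = useMp3(id, true);

  if (transcript.loading) return <p>Loading…</p>;
  if (transcript.error) return <p>Could not load this transcript.</p>;

  // At this point TypeScript has narrowed transcript to SuccessTranscript,
  // so transcript.response is typed as GetTranscript instead of any.
  return (
    <div>
      <h1>{transcript.response.name}</h1>
      <button onClick={getNow} disabled={loading || media !== null}>
        Load audio
      </button>
    </div>
  );
}
```

The discriminated-union return type is the main payoff of the `useTranscript` change: `loading` and `error` become mutually exclusive at the type level, so callers no longer have to defensively handle a "loaded but null response" state.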