diff --git a/server/migrations/versions/d3ff3a39297f_add_recordings.py b/server/migrations/versions/d3ff3a39297f_add_recordings.py
new file mode 100644
index 00000000..92ceb29c
--- /dev/null
+++ b/server/migrations/versions/d3ff3a39297f_add_recordings.py
@@ -0,0 +1,154 @@
+"""Add recordings
+
+Revision ID: d3ff3a39297f
+Revises: b0e5f7876032
+Create Date: 2025-03-10 14:38:53.504413
+
+"""
+
+import uuid
+from datetime import datetime
+from typing import Sequence, Union
+
+import boto3
+import sqlalchemy as sa
+from alembic import op
+from reflector.db.meetings import meetings
+from reflector.db.recordings import Recording, recordings
+from reflector.db.rooms import rooms
+from reflector.db.transcripts import transcripts
+from reflector.settings import settings
+
+# revision identifiers, used by Alembic.
+revision: str = "d3ff3a39297f"
+down_revision: Union[str, None] = "b0e5f7876032"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def add_recordings_from_s3():
+    """Insert one ``recording`` row per ``.mp4`` object in the Whereby bucket.
+
+    The room name and the recording time are sliced out of the object key.
+    Objects whose room name matches no meeting are skipped.
+    """
+    bind = op.get_bind()
+
+    s3 = boto3.client(
+        "s3",
+        region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
+        aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
+        aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
+    )
+
+    bucket_name = settings.AWS_WHEREBY_S3_BUCKET
+    paginator = s3.get_paginator("list_objects_v2")
+    pages = paginator.paginate(Bucket=bucket_name)
+
+    for page in pages:
+        if "Contents" not in page:
+            continue
+
+        for obj in page["Contents"]:
+            object_key = obj["Key"]
+
+            if not (object_key.endswith(".mp4")):
+                continue
+
+            # key[:36] is used as the room name, key[37:57] as an ISO timestamp
+            room_name = f"/{object_key[:36]}"
+            recorded_at = datetime.fromisoformat(object_key[37:57])
+
+            meeting = bind.execute(
+                meetings.select().where(meetings.c.room_name == room_name)
+            ).fetchone()
+            if meeting is None:
+                # fetchone() returns None when nothing matches; indexing it
+                # would abort the whole migration on one orphaned object.
+                continue
+
+            recording = Recording(
+                id=str(uuid.uuid4()),
+                bucket_name=bucket_name,
+                object_key=object_key,
+                recorded_at=recorded_at,
+                meeting_id=meeting["id"],
+            )
+            bind.execute(recordings.insert().values(recording.model_dump()))
+
+
+def link_transcripts_to_recordings():
+    """Set ``transcript.recording_id`` for transcripts that have a meeting.
+
+    A meeting's only recording is used as-is. With several recordings, the
+    latest one recorded at or before the transcript's creation time is used;
+    if there is none, the transcript is left unlinked.
+    """
+    bind = op.get_bind()
+
+    room_transcripts = bind.execute(
+        transcripts.select()
+        .where(transcripts.c.meeting_id.isnot(None))
+        .order_by(transcripts.c.meeting_id, transcripts.c.created_at)
+    ).fetchall()
+
+    for transcript in room_transcripts:
+        transcript_recordings = bind.execute(
+            recordings.select()
+            .where(
+                recordings.c.meeting_id == transcript["meeting_id"],
+            )
+            .order_by(recordings.c.recorded_at.desc())
+        ).fetchall()
+
+        if len(transcript_recordings) == 1:
+            matched_recording = transcript_recordings[0]
+        else:
+            # Rows are newest-first, so the first hit is the latest
+            # recording not newer than the transcript. The None default
+            # also covers a meeting with no recordings at all.
+            matched_recording = next(
+                (
+                    r
+                    for r in transcript_recordings
+                    if r["recorded_at"] <= transcript["created_at"]
+                ),
+                None,
+            )
+
+        if matched_recording is None:
+            continue
+
+        bind.execute(
+            transcripts.update()
+            .where(transcripts.c.id == transcript["id"])
+            .values(recording_id=matched_recording["id"])
+        )
+
+
+def delete_recordings():
+    """Delete every row of the ``recording`` table."""
+    bind = op.get_bind()
+    bind.execute(recordings.delete())
+
+
+def upgrade() -> None:
+    """Add the unique constraint and ``recording_id``, then backfill data."""
+    with op.batch_alter_table("recording", schema=None) as batch_op:
+        batch_op.create_unique_constraint(
+            "uq_recording_object_key",
+            ["bucket_name", "object_key"],
+        )
+    op.add_column("transcript", sa.Column("recording_id", sa.String(), nullable=True))
+
+    add_recordings_from_s3()
+    link_transcripts_to_recordings()
+
+
+def downgrade() -> None:
+    """Drop the constraint and column, then empty the ``recording`` table."""
+    with op.batch_alter_table("recording", schema=None) as batch_op:
+        batch_op.drop_constraint("uq_recording_object_key", type_="unique")
+    op.drop_column("transcript", "recording_id")
+
+    delete_recordings()
diff --git a/server/reflector/db/__init__.py b/server/reflector/db/__init__.py
index 3378d0c0..32111fd3 100644
--- a/server/reflector/db/__init__.py
+++ b/server/reflector/db/__init__.py
@@ -8,6 +8,7 @@ metadata = sqlalchemy.MetaData()
 
 # import models
 import
reflector.db.meetings  # noqa
+import reflector.db.recordings  # noqa
 import reflector.db.rooms  # noqa
 import reflector.db.transcripts  # noqa
 
diff --git a/server/reflector/db/recordings.py b/server/reflector/db/recordings.py
new file mode 100644
index 00000000..254e612a
--- /dev/null
+++ b/server/reflector/db/recordings.py
@@ -0,0 +1,74 @@
+from datetime import datetime
+from typing import Literal
+from uuid import uuid4
+
+import sqlalchemy as sa
+from pydantic import BaseModel, Field
+from reflector.db import database, metadata
+
+recordings = sa.Table(
+    "recording",
+    metadata,
+    sa.Column("id", sa.String, primary_key=True),
+    sa.Column("bucket_name", sa.String, nullable=False),
+    sa.Column("object_key", sa.String, nullable=False),
+    sa.Column("recorded_at", sa.DateTime, nullable=False),
+    sa.Column(
+        "status",
+        sa.String,
+        nullable=False,
+        server_default="pending",
+    ),
+    sa.Column("meeting_id", sa.String),
+    # Same constraint as the one created by migration d3ff3a39297f, declared
+    # here so the table metadata and the migrated schema do not drift apart.
+    sa.UniqueConstraint(
+        "bucket_name", "object_key", name="uq_recording_object_key"
+    ),
+)
+
+
+def generate_uuid4() -> str:
+    """Return a random UUID4 as a string."""
+    return str(uuid4())
+
+
+class Recording(BaseModel):
+    """One row of the ``recording`` table."""
+
+    id: str = Field(default_factory=generate_uuid4)
+    bucket_name: str
+    object_key: str
+    recorded_at: datetime
+    status: Literal["pending", "processing", "completed", "failed"] = "pending"
+    meeting_id: str | None = None
+
+
+class RecordingController:
+    """Queries against the ``recording`` table."""
+
+    async def create(self, recording: Recording) -> Recording:
+        """Insert *recording* and return it unchanged."""
+        query = recordings.insert().values(**recording.model_dump())
+        await database.execute(query)
+        return recording
+
+    async def get_by_id(self, id: str) -> Recording | None:
+        """Return the recording with this id, or None when no row matches."""
+        query = recordings.select().where(recordings.c.id == id)
+        result = await database.fetch_one(query)
+        return Recording(**result) if result else None
+
+    async def get_by_object_key(
+        self, bucket_name: str, object_key: str
+    ) -> Recording | None:
+        """Return the recording for this bucket and key, or None."""
+        query = recordings.select().where(
+            recordings.c.bucket_name == bucket_name,
+            recordings.c.object_key == object_key,
+        )
+        result = await database.fetch_one(query)
+        return Recording(**result) if result else None
+
+
+recordings_controller = RecordingController()
diff
--git a/server/reflector/db/transcripts.py b/server/reflector/db/transcripts.py
index 26ed9386..b9ffe0d2 100644
--- a/server/reflector/db/transcripts.py
+++ b/server/reflector/db/transcripts.py
@@ -63,6 +63,7 @@ transcripts = sqlalchemy.Table(
         "meeting_id",
         sqlalchemy.String,
     ),
+    sqlalchemy.Column("recording_id", sqlalchemy.String, nullable=True),  # joined against recording.id in queries
     sqlalchemy.Column("zulip_message_id", sqlalchemy.Integer, nullable=True),
     sqlalchemy.Column(
         "source_kind",
@@ -165,6 +166,7 @@ class Transcript(BaseModel):
     audio_location: str = "local"
     reviewed: bool = False
     meeting_id: str | None = None
+    recording_id: str | None = None  # None when the transcript has no recording
     zulip_message_id: int | None = None
     source_kind: SourceKind
 
@@ -325,14 +327,17 @@ class TranscriptController:
         - `search_term`: filter transcripts by search term
         """
         from reflector.db.meetings import meetings
+        from reflector.db.recordings import recordings
         from reflector.db.rooms import rooms
 
         query = (
             transcripts.select()
-            .join(meetings, transcripts.c.meeting_id == meetings.c.id, isouter=True)
+            .join(
+                recordings, transcripts.c.recording_id == recordings.c.id, isouter=True
+            )
+            .join(meetings, recordings.c.meeting_id == meetings.c.id, isouter=True)  # meeting is reached through the recording, not transcripts.meeting_id
             .join(rooms, meetings.c.room_id == rooms.c.id, isouter=True)
         )
-
         if user_id:
             query = query.where(
                 or_(transcripts.c.user_id == user_id, rooms.c.is_shared)
@@ -387,11 +392,13 @@ class TranscriptController:
             return None
         return Transcript(**result)
 
-    async def get_by_meeting_id(self, meeting_id: str, **kwargs) -> Transcript | None:
+    async def get_by_recording_id(
+        self, recording_id: str, **kwargs
+    ) -> Transcript | None:
         """
-        Get a transcript by meeting_id
+        Get a transcript by recording_id, narrowed to kwargs["user_id"] when given
         """
-        query = transcripts.select().where(transcripts.c.meeting_id == meeting_id)
+        query = transcripts.select().where(transcripts.c.recording_id == recording_id)
         if "user_id" in kwargs:
             query = query.where(transcripts.c.user_id == kwargs["user_id"])
         result = await database.fetch_one(query)
@@ -447,7 +454,7 @@ class
TranscriptController:
         source_language: str = "en",
         target_language: str = "en",
         user_id: str | None = None,
-        meeting_id: str | None = None,
+        recording_id: str | None = None,
         share_mode: str = "private",
     ):
         """
@@ -459,7 +466,7 @@ class TranscriptController:
             source_language=source_language,
             target_language=target_language,
             user_id=user_id,
-            meeting_id=meeting_id,
+            recording_id=recording_id,
             share_mode=share_mode,
         )
         query = transcripts.insert().values(**transcript.model_dump())
@@ -497,11 +504,11 @@ class TranscriptController:
         query = transcripts.delete().where(transcripts.c.id == transcript_id)
         await database.execute(query)
 
-    async def remove_by_meeting_id(self, meeting_id: str):
+    async def remove_by_recording_id(self, recording_id: str):
         """
-        Remove a transcript by meeting_id
+        Remove a transcript by recording_id
         """
-        query = transcripts.delete().where(transcripts.c.meeting_id == meeting_id)
+        query = transcripts.delete().where(transcripts.c.recording_id == recording_id)
         await database.execute(query)
 
     @asynccontextmanager
diff --git a/server/reflector/pipelines/main_live_pipeline.py b/server/reflector/pipelines/main_live_pipeline.py
index 8ce8d6bd..b42bcbad 100644
--- a/server/reflector/pipelines/main_live_pipeline.py
+++ b/server/reflector/pipelines/main_live_pipeline.py
@@ -18,6 +18,7 @@ from contextlib import asynccontextmanager
 from celery import chord, group, shared_task
 from pydantic import BaseModel
 from reflector.db.meetings import meetings_controller
+from reflector.db.recordings import recordings_controller
 from reflector.db.rooms import rooms_controller
 from reflector.db.transcripts import (
     Transcript,
@@ -561,13 +562,22 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger):
 async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
     logger.info("Starting post to zulip")
 
-    if not transcript.meeting_id:
-        logger.info("Transcript has no meeting")
+    if not transcript.recording_id:
+        logger.info("Transcript has no recording")
         return
 
-    meeting = await meetings_controller.get_by_id(transcript.meeting_id)
+    recording = await recordings_controller.get_by_id(transcript.recording_id)
+    if not recording:
+        logger.info("Recording not found")
+        return
+
+    if not recording.meeting_id:
+        logger.info("Recording has no meeting")
+        return
+
+    meeting = await meetings_controller.get_by_id(recording.meeting_id)
     if not meeting:
-        logger.info("No meeting found for this transcript")
+        logger.info("No meeting found for this recording")
         return
 
     room = await rooms_controller.get_by_id(meeting.room_id)
diff --git a/server/reflector/worker/process.py b/server/reflector/worker/process.py
index eb9eb4a3..85e249e5 100644
--- a/server/reflector/worker/process.py
+++ b/server/reflector/worker/process.py
@@ -10,6 +10,7 @@ from celery import shared_task
 from celery.utils.log import get_task_logger
 from pydantic import ValidationError
 from reflector.db.meetings import meetings_controller
+from reflector.db.recordings import Recording, recordings_controller
 from reflector.db.rooms import rooms_controller
 from reflector.db.transcripts import SourceKind, transcripts_controller
 from reflector.pipelines.main_live_pipeline import asynctask, task_pipeline_process
@@ -65,12 +66,33 @@ def process_messages():
 async def process_recording(bucket_name: str, object_key: str):
     logger.info("Processing recording: %s/%s", bucket_name, object_key)
 
-    # extract a guid from the object key
+    # extract a guid and a datetime from the object key
     room_name = f"/{object_key[:36]}"
+    recorded_at = datetime.fromisoformat(object_key[37:57])
+
     meeting = await meetings_controller.get_by_room_name(room_name)
+    if not meeting:
+        # reprocess_failed_recordings queues keys without checking for a meeting
+        logger.warning("No meeting found for recording: %s", object_key)
+        return
+
     room = await rooms_controller.get_by_id(meeting.room_id)
+    if not room:
+        logger.warning("No room found for meeting: %s", meeting.id)
+        return
 
-    transcript = await transcripts_controller.get_by_meeting_id(meeting.id)
+    recording = await recordings_controller.get_by_object_key(bucket_name, object_key)
+    if not recording:
+        recording = await recordings_controller.create(
+            Recording(
+                bucket_name=bucket_name,
+                object_key=object_key,
+                recorded_at=recorded_at,
+                meeting_id=meeting.id,
+            )
+        )
+
+    transcript = await transcripts_controller.get_by_recording_id(recording.id)
     if transcript:
         await transcripts_controller.update(
             transcript,
@@ -85,7 +107,7 @@ async def process_recording(bucket_name: str, object_key: str):
             source_language="en",
             target_language="en",
             user_id=room.user_id,
-            meeting_id=meeting.id,
+            recording_id=recording.id,
             share_mode="public",
         )
 
@@ -155,10 +177,10 @@ async def reprocess_failed_recordings():
     )
 
     reprocessed_count = 0
-
     try:
         paginator = s3.get_paginator("list_objects_v2")
-        pages = paginator.paginate(Bucket=settings.AWS_WHEREBY_S3_BUCKET)
+        bucket_name = settings.AWS_WHEREBY_S3_BUCKET
+        pages = paginator.paginate(Bucket=bucket_name)
 
         for page in pages:
             if "Contents" not in page:
@@ -170,33 +192,29 @@ async def reprocess_failed_recordings():
                 if not (object_key.endswith(".mp4")):
                     continue
 
-                room_name = f"/{object_key[:36]}"
-                meeting = await meetings_controller.get_by_room_name(room_name)
-                if not meeting:
-                    logger.warning(f"No meeting found for recording: {object_key}")
-                    continue
-
-                room = await rooms_controller.get_by_id(meeting.room_id)
-                if not room:
-                    logger.warning(f"No room found for meeting: {meeting.id}")
+                recording = await recordings_controller.get_by_object_key(
+                    bucket_name, object_key
+                )
+                if not recording:
+                    logger.info(f"Queueing recording for processing: {object_key}")
+                    process_recording.delay(bucket_name, object_key)
+                    reprocessed_count += 1
                     continue
 
                 transcript = None
                 try:
-                    transcript = await transcripts_controller.get_by_meeting_id(
-                        meeting.id
+                    transcript = await transcripts_controller.get_by_recording_id(
+                        recording.id
                     )
                 except ValidationError:
-                    await transcripts_controller.remove_by_meeting_id(meeting.id)
+                    await transcripts_controller.remove_by_recording_id(recording.id)
                     logger.warning(
-                        f"Removed invalid transcript for meeting: {meeting.id}"
+                        f"Removed invalid transcript for recording: {recording.id}"
                     )
 
                 if transcript is None or transcript.status == "error":
-                    logger.info(
-                        f"Queueing recording for processing: {object_key}, meeting {meeting.id}"
-                    )
-                    process_recording.delay(settings.AWS_WHEREBY_S3_BUCKET, object_key)
+                    logger.info(f"Queueing recording for processing: {object_key}")
+                    process_recording.delay(bucket_name, object_key)
                     reprocessed_count += 1
 
     except Exception as e:
diff --git a/server/runserver.sh b/server/runserver.sh
index 31cce123..7b4cf141 100755
--- a/server/runserver.sh
+++ b/server/runserver.sh
@@ -3,9 +3,9 @@ if [ -f "/venv/bin/activate" ]; then
     source /venv/bin/activate
 fi
 
-alembic upgrade head
 
 if [ "${ENTRYPOINT}" = "server" ]; then
+    alembic upgrade head
     python -m reflector.app
 elif [ "${ENTRYPOINT}" = "worker" ]; then
     celery -A reflector.worker.app worker --loglevel=info