Merge pull request #464 from Monadical-SAS/fix-transcription-linking

Add recordings
This commit was committed via GitHub on 2025-03-11 15:36:10 +01:00.
7 changed files with 256 additions and 37 deletions

View File

@@ -0,0 +1,132 @@
"""Add recordings
Revision ID: d3ff3a39297f
Revises: b0e5f7876032
Create Date: 2025-03-10 14:38:53.504413
"""
import uuid
from datetime import datetime
from typing import Sequence, Union
import boto3
import sqlalchemy as sa
from alembic import op
from reflector.db.meetings import meetings
from reflector.db.recordings import Recording, recordings
from reflector.db.rooms import rooms
from reflector.db.transcripts import transcripts
from reflector.settings import settings
# revision identifiers, used by Alembic.
revision: str = "d3ff3a39297f"
down_revision: Union[str, None] = "b0e5f7876032"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def add_recordings_from_s3():
    """Backfill the ``recording`` table from .mp4 objects in the Whereby S3 bucket.

    Object keys are assumed to look like ``<36-char room uuid><sep><ISO timestamp>…``:
    the first 36 characters name the room (stored with a leading slash) and
    characters 37-56 carry the recording's ISO-8601 timestamp — TODO confirm
    against the upload naming convention.
    """
    bind = op.get_bind()

    s3 = boto3.client(
        "s3",
        region_name=settings.TRANSCRIPT_STORAGE_AWS_REGION,
        aws_access_key_id=settings.TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID,
        aws_secret_access_key=settings.TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY,
    )
    bucket_name = settings.AWS_WHEREBY_S3_BUCKET

    # Paginate so buckets holding more than one listing page are fully scanned.
    paginator = s3.get_paginator("list_objects_v2")
    pages = paginator.paginate(Bucket=bucket_name)
    for page in pages:
        if "Contents" not in page:
            continue
        for obj in page["Contents"]:
            object_key = obj["Key"]
            if not object_key.endswith(".mp4"):
                continue

            room_name = f"/{object_key[:36]}"
            recorded_at = datetime.fromisoformat(object_key[37:57])

            meeting = bind.execute(
                meetings.select().where(meetings.c.room_name == room_name)
            ).fetchone()

            recording = Recording(
                id=str(uuid.uuid4()),
                bucket_name=bucket_name,
                object_key=object_key,
                recorded_at=recorded_at,
                # BUGFIX: the room/meeting may no longer exist for an old S3
                # object; previously meeting["id"] raised TypeError on a miss.
                # meeting_id is nullable, so store None instead of crashing.
                meeting_id=meeting["id"] if meeting else None,
            )
            bind.execute(recordings.insert().values(recording.model_dump()))
def link_transcripts_to_recordings():
    """Point each meeting-bound transcript at the recording it came from.

    When a meeting has several recordings, the transcript is matched to the
    most recent recording whose ``recorded_at`` is not after the transcript's
    ``created_at``.
    """
    bind = op.get_bind()

    room_transcripts = bind.execute(
        transcripts.select()
        .where(transcripts.c.meeting_id.isnot(None))
        .order_by(transcripts.c.meeting_id, transcripts.c.created_at)
    ).fetchall()

    for transcript in room_transcripts:
        # Newest first, so the first match below is the latest eligible one.
        transcript_recordings = bind.execute(
            recordings.select()
            .where(
                recordings.c.meeting_id == transcript["meeting_id"],
            )
            .order_by(recordings.c.recorded_at.desc())
        ).fetchall()

        if len(transcript_recordings) == 1:
            matched_recording = transcript_recordings[0]
        else:
            # Latest recording not newer than the transcript itself; None when
            # the meeting has no recordings or all are newer.
            matched_recording = next(
                (
                    r
                    for r in transcript_recordings
                    if r["recorded_at"] <= transcript["created_at"]
                ),
                None,
            )

        # BUGFIX: the multi-recording branch previously dereferenced
        # matched_recording["id"] unconditionally, raising TypeError when no
        # recording predated the transcript. Leave such transcripts unlinked.
        if matched_recording is None:
            continue

        bind.execute(
            transcripts.update()
            .where(transcripts.c.id == transcript["id"])
            .values(recording_id=matched_recording["id"])
        )
def delete_recordings():
    """Wipe every backfilled row from the recording table (downgrade helper)."""
    op.get_bind().execute(recordings.delete())
def upgrade() -> None:
    """Apply the migration: constrain recordings, then backfill and link.

    NOTE(review): assumes the ``recording`` table already exists when the
    constraint is added (created earlier in this revision or by metadata) —
    confirm against the full migration file.
    """
    # Each S3 object may be registered at most once per bucket.
    with op.batch_alter_table("recording", schema=None) as batch_op:
        batch_op.create_unique_constraint(
            "uq_recording_object_key",
            ["bucket_name", "object_key"],
        )

    # New nullable link column, populated by the data migration below.
    op.add_column("transcript", sa.Column("recording_id", sa.String(), nullable=True))

    # Data migration: import existing S3 recordings, then attach transcripts.
    add_recordings_from_s3()
    link_transcripts_to_recordings()
def downgrade() -> None:
    """Revert the migration: drop the constraint, the column, and all rows."""
    with op.batch_alter_table("recording", schema=None) as batch_op:
        batch_op.drop_constraint("uq_recording_object_key", type_="unique")

    op.drop_column("transcript", "recording_id")
    # Remove the rows backfilled by add_recordings_from_s3().
    delete_recordings()

View File

@@ -8,6 +8,7 @@ metadata = sqlalchemy.MetaData()
# import models
import reflector.db.meetings # noqa
import reflector.db.recordings # noqa
import reflector.db.rooms # noqa
import reflector.db.transcripts # noqa

View File

@@ -0,0 +1,59 @@
from datetime import datetime
from typing import Literal
from uuid import uuid4
import sqlalchemy as sa
from pydantic import BaseModel, Field
from reflector.db import database, metadata
# Table of cloud recordings stored in S3, one row per recorded media object.
recordings = sa.Table(
    "recording",
    metadata,
    # Application-generated UUID4 string (see Recording.id), not DB-generated.
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("bucket_name", sa.String, nullable=False),
    # Full S3 key of the media object within bucket_name.
    sa.Column("object_key", sa.String, nullable=False),
    sa.Column("recorded_at", sa.DateTime, nullable=False),
    # Processing lifecycle; valid values are enumerated on the Recording model.
    sa.Column(
        "status",
        sa.String,
        nullable=False,
        server_default="pending",
    ),
    # Meeting this recording belongs to; nullable when no meeting is known.
    sa.Column("meeting_id", sa.String),
)
def generate_uuid4() -> str:
    """Return a fresh random UUID4 in its canonical string form."""
    return f"{uuid4()}"
class Recording(BaseModel):
    """A cloud recording stored in S3, mirroring the ``recording`` table."""

    # Application-side UUID4 primary key.
    id: str = Field(default_factory=generate_uuid4)
    bucket_name: str
    # Full S3 key of the media object within bucket_name.
    object_key: str
    recorded_at: datetime
    # Processing lifecycle; defaults match the table's server_default.
    status: Literal["pending", "processing", "completed", "failed"] = "pending"
    # Meeting this recording belongs to, when known.
    meeting_id: str | None = None
class RecordingController:
    """Async data-access layer for the ``recording`` table."""

    async def create(self, recording: Recording) -> Recording:
        """Insert *recording* and return it unchanged."""
        query = recordings.insert().values(**recording.model_dump())
        await database.execute(query)
        return recording

    async def get_by_id(self, id: str) -> Recording | None:
        """Return the recording with primary key *id*, or None if absent."""
        query = recordings.select().where(recordings.c.id == id)
        result = await database.fetch_one(query)
        return Recording(**result) if result else None

    async def get_by_object_key(self, bucket_name: str, object_key: str) -> Recording | None:
        """Return the recording for (*bucket_name*, *object_key*), or None.

        The pair is unique (see the uq_recording_object_key migration
        constraint), so at most one row can match.
        """
        query = recordings.select().where(
            recordings.c.bucket_name == bucket_name,
            recordings.c.object_key == object_key,
        )
        result = await database.fetch_one(query)
        return Recording(**result) if result else None


# Module-level singleton used by workers and pipelines.
recordings_controller = RecordingController()

View File

@@ -63,6 +63,7 @@ transcripts = sqlalchemy.Table(
"meeting_id",
sqlalchemy.String,
),
sqlalchemy.Column("recording_id", sqlalchemy.String, nullable=True),
sqlalchemy.Column("zulip_message_id", sqlalchemy.Integer, nullable=True),
sqlalchemy.Column(
"source_kind",
@@ -165,6 +166,7 @@ class Transcript(BaseModel):
audio_location: str = "local"
reviewed: bool = False
meeting_id: str | None = None
recording_id: str | None = None
zulip_message_id: int | None = None
source_kind: SourceKind
@@ -325,14 +327,17 @@ class TranscriptController:
- `search_term`: filter transcripts by search term
"""
from reflector.db.meetings import meetings
from reflector.db.recordings import recordings
from reflector.db.rooms import rooms
query = (
transcripts.select()
.join(meetings, transcripts.c.meeting_id == meetings.c.id, isouter=True)
.join(
recordings, transcripts.c.recording_id == recordings.c.id, isouter=True
)
.join(meetings, recordings.c.meeting_id == meetings.c.id, isouter=True)
.join(rooms, meetings.c.room_id == rooms.c.id, isouter=True)
)
if user_id:
query = query.where(
or_(transcripts.c.user_id == user_id, rooms.c.is_shared)
@@ -387,11 +392,13 @@ class TranscriptController:
return None
return Transcript(**result)
async def get_by_meeting_id(self, meeting_id: str, **kwargs) -> Transcript | None:
async def get_by_recording_id(
self, recording_id: str, **kwargs
) -> Transcript | None:
"""
Get a transcript by meeting_id
Get a transcript by recording_id
"""
query = transcripts.select().where(transcripts.c.meeting_id == meeting_id)
query = transcripts.select().where(transcripts.c.recording_id == recording_id)
if "user_id" in kwargs:
query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query)
@@ -447,7 +454,7 @@ class TranscriptController:
source_language: str = "en",
target_language: str = "en",
user_id: str | None = None,
meeting_id: str | None = None,
recording_id: str | None = None,
share_mode: str = "private",
):
"""
@@ -459,7 +466,7 @@ class TranscriptController:
source_language=source_language,
target_language=target_language,
user_id=user_id,
meeting_id=meeting_id,
recording_id=recording_id,
share_mode=share_mode,
)
query = transcripts.insert().values(**transcript.model_dump())
@@ -497,11 +504,11 @@ class TranscriptController:
query = transcripts.delete().where(transcripts.c.id == transcript_id)
await database.execute(query)
async def remove_by_meeting_id(self, meeting_id: str):
async def remove_by_recording_id(self, recording_id: str):
"""
Remove a transcript by meeting_id
Remove a transcript by recording_id
"""
query = transcripts.delete().where(transcripts.c.meeting_id == meeting_id)
query = transcripts.delete().where(transcripts.c.recording_id == recording_id)
await database.execute(query)
@asynccontextmanager

View File

@@ -18,6 +18,7 @@ from contextlib import asynccontextmanager
from celery import chord, group, shared_task
from pydantic import BaseModel
from reflector.db.meetings import meetings_controller
from reflector.db.recordings import recordings_controller
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import (
Transcript,
@@ -561,13 +562,22 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger):
async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
logger.info("Starting post to zulip")
if not transcript.meeting_id:
logger.info("Transcript has no meeting")
if not transcript.recording_id:
logger.info("Transcript has no recording")
return
meeting = await meetings_controller.get_by_id(transcript.meeting_id)
recording = await recordings_controller.get_by_id(transcript.recording_id)
if not recording:
logger.info("Recording not found")
return
if not recording.meeting_id:
logger.info("Recording has no meeting")
return
meeting = await meetings_controller.get_by_id(recording.meeting_id)
if not meeting:
logger.info("No meeting found for this transcript")
logger.info("No meeting found for this recording")
return
room = await rooms_controller.get_by_id(meeting.room_id)

View File

@@ -10,6 +10,7 @@ from celery import shared_task
from celery.utils.log import get_task_logger
from pydantic import ValidationError
from reflector.db.meetings import meetings_controller
from reflector.db.recordings import Recording, recordings_controller
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import SourceKind, transcripts_controller
from reflector.pipelines.main_live_pipeline import asynctask, task_pipeline_process
@@ -65,12 +66,25 @@ def process_messages():
async def process_recording(bucket_name: str, object_key: str):
logger.info("Processing recording: %s/%s", bucket_name, object_key)
# extract a guid from the object key
# extract a guid and a datetime from the object key
room_name = f"/{object_key[:36]}"
recorded_at = datetime.fromisoformat(object_key[37:57])
meeting = await meetings_controller.get_by_room_name(room_name)
room = await rooms_controller.get_by_id(meeting.room_id)
transcript = await transcripts_controller.get_by_meeting_id(meeting.id)
recording = await recordings_controller.get_by_object_key(bucket_name, object_key)
if not recording:
recording = await recordings_controller.create(
Recording(
bucket_name=bucket_name,
object_key=object_key,
recorded_at=recorded_at,
meeting_id=meeting.id,
)
)
transcript = await transcripts_controller.get_by_recording_id(recording.id)
if transcript:
await transcripts_controller.update(
transcript,
@@ -85,7 +99,7 @@ async def process_recording(bucket_name: str, object_key: str):
source_language="en",
target_language="en",
user_id=room.user_id,
meeting_id=meeting.id,
recording_id=recording.id,
share_mode="public",
)
@@ -155,10 +169,10 @@ async def reprocess_failed_recordings():
)
reprocessed_count = 0
try:
paginator = s3.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=settings.AWS_WHEREBY_S3_BUCKET)
bucket_name = settings.AWS_WHEREBY_S3_BUCKET
pages = paginator.paginate(Bucket=bucket_name)
for page in pages:
if "Contents" not in page:
@@ -170,33 +184,29 @@ async def reprocess_failed_recordings():
if not (object_key.endswith(".mp4")):
continue
room_name = f"/{object_key[:36]}"
meeting = await meetings_controller.get_by_room_name(room_name)
if not meeting:
logger.warning(f"No meeting found for recording: {object_key}")
continue
room = await rooms_controller.get_by_id(meeting.room_id)
if not room:
logger.warning(f"No room found for meeting: {meeting.id}")
recording = await recordings_controller.get_by_object_key(
bucket_name, object_key
)
if not recording:
logger.info(f"Queueing recording for processing: {object_key}")
process_recording.delay(bucket_name, object_key)
reprocessed_count += 1
continue
transcript = None
try:
transcript = await transcripts_controller.get_by_meeting_id(
meeting.id
transcript = await transcripts_controller.get_by_recording_id(
recording.id
)
except ValidationError:
await transcripts_controller.remove_by_meeting_id(meeting.id)
await transcripts_controller.remove_by_recording_id(recording.id)
logger.warning(
f"Removed invalid transcript for meeting: {meeting.id}"
f"Removed invalid transcript for recording: {recording.id}"
)
if transcript is None or transcript.status == "error":
logger.info(
f"Queueing recording for processing: {object_key}, meeting {meeting.id}"
)
process_recording.delay(settings.AWS_WHEREBY_S3_BUCKET, object_key)
logger.info(f"Queueing recording for processing: {object_key}")
process_recording.delay(bucket_name, object_key)
reprocessed_count += 1
except Exception as e:

View File

@@ -3,9 +3,9 @@
if [ -f "/venv/bin/activate" ]; then
source /venv/bin/activate
fi
alembic upgrade head
if [ "${ENTRYPOINT}" = "server" ]; then
alembic upgrade head
python -m reflector.app
elif [ "${ENTRYPOINT}" = "worker" ]; then
celery -A reflector.worker.app worker --loglevel=info