Add transcript source kind

This commit is contained in:
2024-10-04 16:38:29 +02:00
parent ebb32ee613
commit 39d02ab265
9 changed files with 159 additions and 13 deletions

View File

@@ -0,0 +1,48 @@
"""Add transcript source kind
Revision ID: 74b2b0236931
Revises: 0925da921477
Create Date: 2024-10-04 14:19:23.625447
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "74b2b0236931"
down_revision: Union[str, None] = "0925da921477"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"transcript",
sa.Column(
"source_kind",
sa.Enum("ROOM", "LIVE", "FILE", name="sourcekind"),
nullable=True,
),
)
op.execute(
"UPDATE transcript SET source_kind = 'room' WHERE meeting_id IS NOT NULL"
)
op.execute("UPDATE transcript SET source_kind = 'live' WHERE meeting_id IS NULL")
with op.batch_alter_table("transcript", schema=None) as batch_op:
batch_op.alter_column("source_kind", nullable=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("transcript", "source_kind")
# ### end Alembic commands ###

View File

@@ -1,3 +1,4 @@
import enum
import json
import os
import shutil
@@ -14,8 +15,16 @@ from reflector.db import database, metadata
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.storage import Storage
from sqlalchemy import Enum
from sqlalchemy.sql import false
class SourceKind(enum.StrEnum):
ROOM = enum.auto()
LIVE = enum.auto()
FILE = enum.auto()
transcripts = sqlalchemy.Table(
"transcript",
metadata,
@@ -55,6 +64,11 @@ transcripts = sqlalchemy.Table(
sqlalchemy.String,
),
sqlalchemy.Column("zulip_message_id", sqlalchemy.Integer, nullable=True),
sqlalchemy.Column(
"source_kind",
Enum(SourceKind, values_callable=lambda obj: [e.value for e in obj]),
nullable=False,
),
)
@@ -152,6 +166,7 @@ class Transcript(BaseModel):
reviewed: bool = False
meeting_id: str | None = None
zulip_message_id: int | None = None
source_kind: SourceKind
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
ev = TranscriptEvent(event=event, data=data.model_dump())
@@ -291,6 +306,7 @@ class TranscriptController:
order_by: str | None = None,
filter_empty: bool | None = False,
filter_recording: bool | None = False,
source_kind: SourceKind | None = None,
room_id: str | None = None,
search_term: str | None = None,
return_query: bool = False,
@@ -320,6 +336,9 @@ class TranscriptController:
if user_id:
query = query.where(transcripts.c.user_id == user_id)
if source_kind:
query = query.where(transcripts.c.source_kind == source_kind)
if room_id:
query = query.where(rooms.c.id == room_id)
@@ -422,6 +441,7 @@ class TranscriptController:
async def add(
self,
name: str,
source_kind: SourceKind,
source_language: str = "en",
target_language: str = "en",
user_id: str | None = None,
@@ -433,6 +453,7 @@ class TranscriptController:
"""
transcript = Transcript(
name=name,
source_kind=source_kind,
source_language=source_language,
target_language=target_language,
user_id=user_id,

View File

@@ -9,6 +9,7 @@ from jose import jwt
from pydantic import BaseModel, Field
from reflector.db.migrate_user import migrate_user
from reflector.db.transcripts import (
SourceKind,
TranscriptParticipant,
TranscriptTopic,
transcripts_controller,
@@ -61,6 +62,7 @@ class GetTranscript(BaseModel):
meeting_id: str | None
room_id: str | None
room_name: str | None
source_kind: SourceKind
class CreateTranscript(BaseModel):
@@ -89,6 +91,7 @@ async def transcripts_list(
room_id: str | None,
search_term: str | None,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
source_kind: SourceKind | None = None,
):
from reflector.db import database
@@ -105,6 +108,7 @@ async def transcripts_list(
database,
await transcripts_controller.get_all(
user_id=user_id,
source_kind=SourceKind(source_kind) if source_kind else None,
room_id=room_id,
search_term=search_term,
order_by="-created_at",
@@ -121,6 +125,7 @@ async def transcripts_create(
user_id = user["sub"] if user else None
return await transcripts_controller.add(
info.name,
source_kind=SourceKind.LIVE,
source_language=info.source_language,
target_language=info.target_language,
user_id=user_id,

View File

@@ -8,7 +8,7 @@ import structlog
from celery import shared_task
from celery.utils.log import get_task_logger
from reflector.db.meetings import meetings_controller
from reflector.db.transcripts import transcripts_controller
from reflector.db.transcripts import SourceKind, transcripts_controller
from reflector.pipelines.main_live_pipeline import asynctask, task_pipeline_process
from reflector.settings import settings
@@ -66,6 +66,7 @@ async def process_recording(bucket_name: str, object_key: str):
meeting = await meetings_controller.get_by_room_name(room_name)
transcript = await transcripts_controller.add(
"",
source_kind=SourceKind.ROOM
source_language="en",
target_language="en",
user_id=meeting.user_id,