mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-20 12:19:06 +00:00
www/server: introduce share mode
This commit is contained in:
30
server/migrations/versions/0fea6d96b096_add_share_mode.py
Normal file
30
server/migrations/versions/0fea6d96b096_add_share_mode.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""add share_mode
|
||||
|
||||
Revision ID: 0fea6d96b096
|
||||
Revises: 38a927dcb099
|
||||
Create Date: 2023-11-07 11:12:21.614198
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '0fea6d96b096'
|
||||
down_revision: Union[str, None] = '38a927dcb099'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('transcript', sa.Column('share_mode', sa.String(), server_default='private', nullable=False))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('transcript', 'share_mode')
|
||||
# ### end Alembic commands ###
|
||||
@@ -2,10 +2,11 @@ import json
|
||||
from contextlib import asynccontextmanager
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from reflector.db import database, metadata
|
||||
from reflector.processors.types import Word as ProcessorWord
|
||||
@@ -30,6 +31,12 @@ transcripts = sqlalchemy.Table(
|
||||
sqlalchemy.Column("target_language", sqlalchemy.String, nullable=True),
|
||||
# with user attached, optional
|
||||
sqlalchemy.Column("user_id", sqlalchemy.String),
|
||||
sqlalchemy.Column(
|
||||
"share_mode",
|
||||
sqlalchemy.String,
|
||||
nullable=False,
|
||||
server_default="private",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -99,6 +106,7 @@ class Transcript(BaseModel):
|
||||
events: list[TranscriptEvent] = []
|
||||
source_language: str = "en"
|
||||
target_language: str = "en"
|
||||
share_mode: Literal["private", "semi-private", "public"] = "private"
|
||||
|
||||
def add_event(self, event: str, data: BaseModel) -> TranscriptEvent:
|
||||
ev = TranscriptEvent(event=event, data=data.model_dump())
|
||||
@@ -169,6 +177,7 @@ class TranscriptController:
|
||||
order_by: str | None = None,
|
||||
filter_empty: bool | None = False,
|
||||
filter_recording: bool | None = False,
|
||||
return_query: bool = False,
|
||||
) -> list[Transcript]:
|
||||
"""
|
||||
Get all transcripts
|
||||
@@ -195,6 +204,9 @@ class TranscriptController:
|
||||
if filter_recording:
|
||||
query = query.filter(transcripts.c.status != "recording")
|
||||
|
||||
if return_query:
|
||||
return query
|
||||
|
||||
results = await database.fetch_all(query)
|
||||
return results
|
||||
|
||||
@@ -210,6 +222,47 @@ class TranscriptController:
|
||||
return None
|
||||
return Transcript(**result)
|
||||
|
||||
async def get_by_id_for_http(
|
||||
self,
|
||||
transcript_id: str,
|
||||
user_id: str | None,
|
||||
) -> Transcript:
|
||||
"""
|
||||
Get a transcript by ID for HTTP request.
|
||||
|
||||
If not found, it will raise a 404 error.
|
||||
If the user is not allowed to access the transcript, it will raise a 403 error.
|
||||
|
||||
This method checks the share mode of the transcript and the user_id
|
||||
to determine if the user can access the transcript.
|
||||
"""
|
||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||
result = await database.fetch_one(query)
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
|
||||
# if the transcript is anonymous, share mode is not checked
|
||||
transcript = Transcript(**result)
|
||||
if transcript.user_id is None:
|
||||
return transcript
|
||||
|
||||
if transcript.share_mode == "private":
|
||||
# in private mode, only the owner can access the transcript
|
||||
if transcript.user_id == user_id:
|
||||
return transcript
|
||||
|
||||
elif transcript.share_mode == "semi-private":
|
||||
# in semi-private mode, only the owner and the users with the link
|
||||
# can access the transcript
|
||||
if user_id is not None:
|
||||
return transcript
|
||||
|
||||
elif transcript.share_mode == "public":
|
||||
# in public mode, everyone can access the transcript
|
||||
return transcript
|
||||
|
||||
raise HTTPException(status_code=403, detail="Transcript access denied")
|
||||
|
||||
async def add(
|
||||
self,
|
||||
name: str,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated, Optional
|
||||
from typing import Annotated, Literal, Optional
|
||||
|
||||
import reflector.auth as auth
|
||||
from fastapi import (
|
||||
@@ -11,7 +11,8 @@ from fastapi import (
|
||||
WebSocketDisconnect,
|
||||
status,
|
||||
)
|
||||
from fastapi_pagination import Page, paginate
|
||||
from fastapi_pagination import Page
|
||||
from fastapi_pagination.ext.databases import paginate
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field
|
||||
from reflector.db.transcripts import (
|
||||
@@ -48,6 +49,7 @@ def create_access_token(data: dict, expires_delta: timedelta):
|
||||
|
||||
class GetTranscript(BaseModel):
|
||||
id: str
|
||||
user_id: str | None
|
||||
name: str
|
||||
status: str
|
||||
locked: bool
|
||||
@@ -56,6 +58,7 @@ class GetTranscript(BaseModel):
|
||||
short_summary: str | None
|
||||
long_summary: str | None
|
||||
created_at: datetime
|
||||
share_mode: str = Field("private")
|
||||
source_language: str | None
|
||||
target_language: str | None
|
||||
|
||||
@@ -72,6 +75,7 @@ class UpdateTranscript(BaseModel):
|
||||
title: Optional[str] = Field(None)
|
||||
short_summary: Optional[str] = Field(None)
|
||||
long_summary: Optional[str] = Field(None)
|
||||
share_mode: Optional[Literal["public", "semi-private", "private"]] = Field(None)
|
||||
|
||||
|
||||
class DeletionStatus(BaseModel):
|
||||
@@ -82,12 +86,19 @@ class DeletionStatus(BaseModel):
|
||||
async def transcripts_list(
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
from reflector.db import database
|
||||
|
||||
if not user and not settings.PUBLIC_MODE:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
user_id = user["sub"] if user else None
|
||||
return paginate(
|
||||
await transcripts_controller.get_all(user_id=user_id, order_by="-created_at")
|
||||
return await paginate(
|
||||
database,
|
||||
await transcripts_controller.get_all(
|
||||
user_id=user_id,
|
||||
order_by="-created_at",
|
||||
return_query=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -165,10 +176,9 @@ async def transcript_get(
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
return transcript
|
||||
return await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
|
||||
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
|
||||
@@ -192,6 +202,8 @@ async def transcript_update(
|
||||
values["short_summary"] = info.short_summary
|
||||
if info.title is not None:
|
||||
values["title"] = info.title
|
||||
if info.share_mode is not None:
|
||||
values["share_mode"] = info.share_mode
|
||||
await transcripts_controller.update(transcript, values)
|
||||
return transcript
|
||||
|
||||
@@ -229,12 +241,12 @@ async def transcript_get_audio_mp3(
|
||||
except jwt.JWTError:
|
||||
raise unauthorized_exception
|
||||
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if not transcript.audio_mp3_filename.exists():
|
||||
raise HTTPException(status_code=404, detail="Audio not found")
|
||||
raise HTTPException(status_code=500, detail="Audio not found")
|
||||
|
||||
truncated_id = str(transcript.id).split("-")[0]
|
||||
filename = f"recording_{truncated_id}.mp3"
|
||||
@@ -253,12 +265,12 @@ async def transcript_get_audio_waveform(
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
) -> AudioWaveform:
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if not transcript.audio_mp3_filename.exists():
|
||||
raise HTTPException(status_code=404, detail="Audio not found")
|
||||
raise HTTPException(status_code=500, detail="Audio not found")
|
||||
|
||||
await run_in_threadpool(transcript.convert_audio_to_waveform)
|
||||
|
||||
@@ -274,9 +286,9 @@ async def transcript_get_topics(
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
# convert to GetTranscriptTopic
|
||||
return [
|
||||
@@ -345,9 +357,9 @@ async def transcript_record_webrtc(
|
||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||
):
|
||||
user_id = user["sub"] if user else None
|
||||
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||
if not transcript:
|
||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||
transcript = await transcripts_controller.get_by_id_for_http(
|
||||
transcript_id, user_id=user_id
|
||||
)
|
||||
|
||||
if transcript.locked:
|
||||
raise HTTPException(status_code=400, detail="Transcript is locked")
|
||||
|
||||
Reference in New Issue
Block a user