www/server: introduce share mode

This commit is contained in:
2023-11-07 12:39:48 +01:00
parent 6282583d92
commit 226b92c347
8 changed files with 228 additions and 34 deletions

View File

@@ -1,5 +1,5 @@
from datetime import datetime, timedelta
from typing import Annotated, Optional
from typing import Annotated, Literal, Optional
import reflector.auth as auth
from fastapi import (
@@ -11,7 +11,8 @@ from fastapi import (
WebSocketDisconnect,
status,
)
from fastapi_pagination import Page, paginate
from fastapi_pagination import Page
from fastapi_pagination.ext.databases import paginate
from jose import jwt
from pydantic import BaseModel, Field
from reflector.db.transcripts import (
@@ -48,6 +49,7 @@ def create_access_token(data: dict, expires_delta: timedelta):
class GetTranscript(BaseModel):
id: str
user_id: str | None
name: str
status: str
locked: bool
@@ -56,6 +58,7 @@ class GetTranscript(BaseModel):
short_summary: str | None
long_summary: str | None
created_at: datetime
share_mode: str = Field("private")
source_language: str | None
target_language: str | None
@@ -72,6 +75,7 @@ class UpdateTranscript(BaseModel):
title: Optional[str] = Field(None)
short_summary: Optional[str] = Field(None)
long_summary: Optional[str] = Field(None)
share_mode: Optional[Literal["public", "semi-private", "private"]] = Field(None)
class DeletionStatus(BaseModel):
@@ -82,12 +86,19 @@ class DeletionStatus(BaseModel):
async def transcripts_list(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
from reflector.db import database
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None
return paginate(
await transcripts_controller.get_all(user_id=user_id, order_by="-created_at")
return await paginate(
database,
await transcripts_controller.get_all(
user_id=user_id,
order_by="-created_at",
return_query=True,
),
)
@@ -165,10 +176,9 @@ async def transcript_get(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
return transcript
return await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
@@ -192,6 +202,8 @@ async def transcript_update(
values["short_summary"] = info.short_summary
if info.title is not None:
values["title"] = info.title
if info.share_mode is not None:
values["share_mode"] = info.share_mode
await transcripts_controller.update(transcript, values)
return transcript
@@ -229,12 +241,12 @@ async def transcript_get_audio_mp3(
except jwt.JWTError:
raise unauthorized_exception
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
raise HTTPException(status_code=500, detail="Audio not found")
truncated_id = str(transcript.id).split("-")[0]
filename = f"recording_{truncated_id}.mp3"
@@ -253,12 +265,12 @@ async def transcript_get_audio_waveform(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> AudioWaveform:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
raise HTTPException(status_code=500, detail="Audio not found")
await run_in_threadpool(transcript.convert_audio_to_waveform)
@@ -274,9 +286,9 @@ async def transcript_get_topics(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
# convert to GetTranscriptTopic
return [
@@ -345,9 +357,9 @@ async def transcript_record_webrtc(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if transcript.locked:
raise HTTPException(status_code=400, detail="Transcript is locked")