server: add transcript verification on get_by_id

This commit is contained in:
2023-08-18 12:39:19 +02:00
parent 2a3ad5657f
commit 0c93a39e33
2 changed files with 42 additions and 26 deletions

View File

@@ -1,27 +1,28 @@
from datetime import datetime
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Annotated, Optional
from uuid import uuid4
import av
import reflector.auth as auth
from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
WebSocket,
WebSocketDisconnect,
Depends,
)
from fastapi.responses import FileResponse
from starlette.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
from uuid import uuid4
from datetime import datetime
from fastapi_pagination import Page, paginate
from reflector.logger import logger
from pydantic import BaseModel, Field
from reflector.db import database, transcripts
from reflector.logger import logger
from reflector.settings import settings
import reflector.auth as auth
from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent
from typing import Annotated, Optional
from pathlib import Path
from tempfile import NamedTemporaryFile
import av
from starlette.concurrency import run_in_threadpool
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
router = APIRouter()
@@ -139,15 +140,13 @@ class TranscriptController:
results = await database.fetch_all(query)
return results
async def get_by_id(
self, transcript_id: str, user_id: str | None = None
) -> Transcript | None:
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
query = transcripts.select().where(transcripts.c.id == transcript_id)
if "user_id" in kwargs:
query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query)
if not result:
return None
if user_id is not None and result["user_id"] != user_id:
return None
return Transcript(**result)
async def add(self, name: str, user_id: str | None = None):
@@ -169,7 +168,7 @@ class TranscriptController:
async def remove_by_id(
self, transcript_id: str, user_id: str | None = None
) -> None:
transcript = await self.get_by_id(transcript_id)
transcript = await self.get_by_id(transcript_id, user_id=user_id)
if not transcript:
return
if user_id is not None and transcript.user_id != user_id:
@@ -282,8 +281,12 @@ async def transcript_delete(
@router.get("/transcripts/{transcript_id}/audio")
async def transcript_get_audio(transcript_id: str):
transcript = await transcripts_controller.get_by_id(transcript_id)
async def transcript_get_audio(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
@@ -294,8 +297,12 @@ async def transcript_get_audio(transcript_id: str):
@router.get("/transcripts/{transcript_id}/audio/mp3")
async def transcript_get_audio_mp3(transcript_id: str):
transcript = await transcripts_controller.get_by_id(transcript_id)
async def transcript_get_audio_mp3(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
@@ -308,8 +315,12 @@ async def transcript_get_audio_mp3(transcript_id: str):
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
async def transcript_get_topics(transcript_id: str):
transcript = await transcripts_controller.get_by_id(transcript_id)
async def transcript_get_topics(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
return transcript.topics
@@ -457,9 +468,13 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
@router.post("/transcripts/{transcript_id}/record/webrtc")
async def transcript_record_webrtc(
transcript_id: str, params: RtcOffer, request: Request
transcript_id: str,
params: RtcOffer,
request: Request,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
transcript = await transcripts_controller.get_by_id(transcript_id)
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")