mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 04:39:06 +00:00
server: add transcript verification on get_by_id
This commit is contained in:
@@ -36,3 +36,4 @@ repos:
|
|||||||
- id: isort
|
- id: isort
|
||||||
name: isort (python)
|
name: isort (python)
|
||||||
files: ^server/(gpu|evaluate|reflector)/
|
files: ^server/(gpu|evaluate|reflector)/
|
||||||
|
args: ["--profile", "black", "--filter-files"]
|
||||||
|
|||||||
@@ -1,27 +1,28 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
|
from typing import Annotated, Optional
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import av
|
||||||
|
import reflector.auth as auth
|
||||||
from fastapi import (
|
from fastapi import (
|
||||||
APIRouter,
|
APIRouter,
|
||||||
|
Depends,
|
||||||
HTTPException,
|
HTTPException,
|
||||||
Request,
|
Request,
|
||||||
WebSocket,
|
WebSocket,
|
||||||
WebSocketDisconnect,
|
WebSocketDisconnect,
|
||||||
Depends,
|
|
||||||
)
|
)
|
||||||
from fastapi.responses import FileResponse
|
from fastapi.responses import FileResponse
|
||||||
from starlette.concurrency import run_in_threadpool
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from uuid import uuid4
|
|
||||||
from datetime import datetime
|
|
||||||
from fastapi_pagination import Page, paginate
|
from fastapi_pagination import Page, paginate
|
||||||
from reflector.logger import logger
|
from pydantic import BaseModel, Field
|
||||||
from reflector.db import database, transcripts
|
from reflector.db import database, transcripts
|
||||||
|
from reflector.logger import logger
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
import reflector.auth as auth
|
from starlette.concurrency import run_in_threadpool
|
||||||
from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent
|
|
||||||
from typing import Annotated, Optional
|
|
||||||
from pathlib import Path
|
|
||||||
from tempfile import NamedTemporaryFile
|
|
||||||
import av
|
|
||||||
|
|
||||||
|
from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@@ -139,15 +140,13 @@ class TranscriptController:
|
|||||||
results = await database.fetch_all(query)
|
results = await database.fetch_all(query)
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def get_by_id(
|
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
|
||||||
self, transcript_id: str, user_id: str | None = None
|
|
||||||
) -> Transcript | None:
|
|
||||||
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||||
|
if "user_id" in kwargs:
|
||||||
|
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||||
result = await database.fetch_one(query)
|
result = await database.fetch_one(query)
|
||||||
if not result:
|
if not result:
|
||||||
return None
|
return None
|
||||||
if user_id is not None and result["user_id"] != user_id:
|
|
||||||
return None
|
|
||||||
return Transcript(**result)
|
return Transcript(**result)
|
||||||
|
|
||||||
async def add(self, name: str, user_id: str | None = None):
|
async def add(self, name: str, user_id: str | None = None):
|
||||||
@@ -169,7 +168,7 @@ class TranscriptController:
|
|||||||
async def remove_by_id(
|
async def remove_by_id(
|
||||||
self, transcript_id: str, user_id: str | None = None
|
self, transcript_id: str, user_id: str | None = None
|
||||||
) -> None:
|
) -> None:
|
||||||
transcript = await self.get_by_id(transcript_id)
|
transcript = await self.get_by_id(transcript_id, user_id=user_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
return
|
return
|
||||||
if user_id is not None and transcript.user_id != user_id:
|
if user_id is not None and transcript.user_id != user_id:
|
||||||
@@ -282,8 +281,12 @@ async def transcript_delete(
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/transcripts/{transcript_id}/audio")
|
@router.get("/transcripts/{transcript_id}/audio")
|
||||||
async def transcript_get_audio(transcript_id: str):
|
async def transcript_get_audio(
|
||||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
transcript_id: str,
|
||||||
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
):
|
||||||
|
user_id = user["sub"] if user else None
|
||||||
|
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
|
||||||
@@ -294,8 +297,12 @@ async def transcript_get_audio(transcript_id: str):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/transcripts/{transcript_id}/audio/mp3")
|
@router.get("/transcripts/{transcript_id}/audio/mp3")
|
||||||
async def transcript_get_audio_mp3(transcript_id: str):
|
async def transcript_get_audio_mp3(
|
||||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
transcript_id: str,
|
||||||
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
):
|
||||||
|
user_id = user["sub"] if user else None
|
||||||
|
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
|
||||||
@@ -308,8 +315,12 @@ async def transcript_get_audio_mp3(transcript_id: str):
|
|||||||
|
|
||||||
|
|
||||||
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
|
@router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic])
|
||||||
async def transcript_get_topics(transcript_id: str):
|
async def transcript_get_topics(
|
||||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
transcript_id: str,
|
||||||
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
|
):
|
||||||
|
user_id = user["sub"] if user else None
|
||||||
|
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
return transcript.topics
|
return transcript.topics
|
||||||
@@ -457,9 +468,13 @@ async def handle_rtc_event(event: PipelineEvent, args, data):
|
|||||||
|
|
||||||
@router.post("/transcripts/{transcript_id}/record/webrtc")
|
@router.post("/transcripts/{transcript_id}/record/webrtc")
|
||||||
async def transcript_record_webrtc(
|
async def transcript_record_webrtc(
|
||||||
transcript_id: str, params: RtcOffer, request: Request
|
transcript_id: str,
|
||||||
|
params: RtcOffer,
|
||||||
|
request: Request,
|
||||||
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
):
|
):
|
||||||
transcript = await transcripts_controller.get_by_id(transcript_id)
|
user_id = user["sub"] if user else None
|
||||||
|
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user