diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2a73b7c2..3dcbe202 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -36,3 +36,4 @@ repos: - id: isort name: isort (python) files: ^server/(gpu|evaluate|reflector)/ + args: ["--profile", "black", "--filter-files"] diff --git a/server/reflector/views/transcripts.py b/server/reflector/views/transcripts.py index 29db9ec7..9a5c7dfe 100644 --- a/server/reflector/views/transcripts.py +++ b/server/reflector/views/transcripts.py @@ -1,27 +1,28 @@ +from datetime import datetime +from pathlib import Path +from tempfile import NamedTemporaryFile +from typing import Annotated, Optional +from uuid import uuid4 + +import av +import reflector.auth as auth from fastapi import ( APIRouter, + Depends, HTTPException, Request, WebSocket, WebSocketDisconnect, - Depends, ) from fastapi.responses import FileResponse -from starlette.concurrency import run_in_threadpool -from pydantic import BaseModel, Field -from uuid import uuid4 -from datetime import datetime from fastapi_pagination import Page, paginate -from reflector.logger import logger +from pydantic import BaseModel, Field from reflector.db import database, transcripts +from reflector.logger import logger from reflector.settings import settings -import reflector.auth as auth -from .rtc_offer import rtc_offer_base, RtcOffer, PipelineEvent -from typing import Annotated, Optional -from pathlib import Path -from tempfile import NamedTemporaryFile -import av +from starlette.concurrency import run_in_threadpool +from .rtc_offer import PipelineEvent, RtcOffer, rtc_offer_base router = APIRouter() @@ -139,15 +140,13 @@ class TranscriptController: results = await database.fetch_all(query) return results - async def get_by_id( - self, transcript_id: str, user_id: str | None = None - ) -> Transcript | None: + async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None: query = transcripts.select().where(transcripts.c.id == transcript_id) + if "user_id" in kwargs: + query = query.where(transcripts.c.user_id == kwargs["user_id"]) result = await database.fetch_one(query) if not result: return None - if user_id is not None and result["user_id"] != user_id: - return None return Transcript(**result) async def add(self, name: str, user_id: str | None = None): @@ -169,7 +168,7 @@ class TranscriptController: async def remove_by_id( self, transcript_id: str, user_id: str | None = None ) -> None: - transcript = await self.get_by_id(transcript_id) + transcript = await self.get_by_id(transcript_id, user_id=user_id) if not transcript: return if user_id is not None and transcript.user_id != user_id: @@ -282,8 +281,12 @@ async def transcript_delete( @router.get("/transcripts/{transcript_id}/audio") -async def transcript_get_audio(transcript_id: str): - transcript = await transcripts_controller.get_by_id(transcript_id) +async def transcript_get_audio( + transcript_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +): + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") @@ -294,8 +297,12 @@ async def transcript_get_audio(transcript_id: str): @router.get("/transcripts/{transcript_id}/audio/mp3") -async def transcript_get_audio_mp3(transcript_id: str): - transcript = await transcripts_controller.get_by_id(transcript_id) +async def transcript_get_audio_mp3( + transcript_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +): + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") @@ -308,8 +315,12 @@ async def transcript_get_audio_mp3(transcript_id: str): @router.get("/transcripts/{transcript_id}/topics", response_model=list[TranscriptTopic]) -async def transcript_get_topics(transcript_id: str): - transcript = await transcripts_controller.get_by_id(transcript_id) +async def transcript_get_topics( + transcript_id: str, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], +): + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found") return transcript.topics @@ -457,9 +468,13 @@ async def handle_rtc_event(event: PipelineEvent, args, data): @router.post("/transcripts/{transcript_id}/record/webrtc") async def transcript_record_webrtc( - transcript_id: str, params: RtcOffer, request: Request + transcript_id: str, + params: RtcOffer, + request: Request, + user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)], ): - transcript = await transcripts_controller.get_by_id(transcript_id) + user_id = user["sub"] if user else None + transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id) if not transcript: raise HTTPException(status_code=404, detail="Transcript not found")