Files
reflector/server/reflector/views/transcripts.py
2023-11-20 20:13:18 +01:00

364 lines
12 KiB
Python

from datetime import datetime, timedelta
from typing import Annotated, Optional
import reflector.auth as auth
from fastapi import (
APIRouter,
Depends,
HTTPException,
Request,
WebSocket,
WebSocketDisconnect,
status,
)
from fastapi_pagination import Page, paginate
from jose import jwt
from pydantic import BaseModel, Field
from reflector.db.transcripts import (
AudioWaveform,
TranscriptTopic,
transcripts_controller,
)
from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.settings import settings
from reflector.ws_manager import get_ws_manager
from ._range_requests_response import range_requests_response
from .rtc_offer import RtcOffer, rtc_offer_base
router = APIRouter()
ALGORITHM = "HS256"
DOWNLOAD_EXPIRE_MINUTES = 60
def create_access_token(data: dict, expires_delta: timedelta):
to_encode = data.copy()
expire = datetime.utcnow() + expires_delta
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
# ==============================================================
# Transcripts list
# ==============================================================
class GetTranscript(BaseModel):
id: str
name: str
status: str
locked: bool
duration: float
title: str | None
short_summary: str | None
long_summary: str | None
created_at: datetime
source_language: str | None
target_language: str | None
class CreateTranscript(BaseModel):
name: str
source_language: str = Field("en")
target_language: str = Field("en")
class UpdateTranscript(BaseModel):
name: Optional[str] = Field(None)
locked: Optional[bool] = Field(None)
title: Optional[str] = Field(None)
short_summary: Optional[str] = Field(None)
long_summary: Optional[str] = Field(None)
class DeletionStatus(BaseModel):
status: str
@router.get("/transcripts", response_model=Page[GetTranscript])
async def transcripts_list(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None
return paginate(
await transcripts_controller.get_all(user_id=user_id, order_by="-created_at")
)
@router.post("/transcripts", response_model=GetTranscript)
async def transcripts_create(
info: CreateTranscript,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
return await transcripts_controller.add(
info.name,
source_language=info.source_language,
target_language=info.target_language,
user_id=user_id,
)
# ==============================================================
# Single transcript
# ==============================================================
class GetTranscriptSegmentTopic(BaseModel):
text: str
start: float
speaker: int
class GetTranscriptTopic(BaseModel):
id: str
title: str
summary: str
timestamp: float
transcript: str
segments: list[GetTranscriptSegmentTopic] = []
@classmethod
def from_transcript_topic(cls, topic: TranscriptTopic):
if not topic.words:
# In previous version, words were missing
# Just output a segment with speaker 0
text = topic.transcript
segments = [
GetTranscriptSegmentTopic(
text=topic.transcript,
start=topic.timestamp,
speaker=0,
)
]
else:
# New versions include words
transcript = ProcessorTranscript(words=topic.words)
text = transcript.text
segments = [
GetTranscriptSegmentTopic(
text=segment.text,
start=segment.start,
speaker=segment.speaker,
)
for segment in transcript.as_segments()
]
return cls(
id=topic.id,
title=topic.title,
summary=topic.summary,
timestamp=topic.timestamp,
transcript=text,
segments=segments,
)
@router.get("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_get(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
return transcript
@router.patch("/transcripts/{transcript_id}", response_model=GetTranscript)
async def transcript_update(
transcript_id: str,
info: UpdateTranscript,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
values = {}
if info.name is not None:
values["name"] = info.name
if info.locked is not None:
values["locked"] = info.locked
if info.long_summary is not None:
values["long_summary"] = info.long_summary
if info.short_summary is not None:
values["short_summary"] = info.short_summary
if info.title is not None:
values["title"] = info.title
await transcripts_controller.update(transcript, values)
return transcript
@router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus)
async def transcript_delete(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
await transcripts_controller.remove_by_id(transcript.id, user_id=user_id)
return DeletionStatus(status="ok")
@router.get("/transcripts/{transcript_id}/audio/mp3")
@router.head("/transcripts/{transcript_id}/audio/mp3")
async def transcript_get_audio_mp3(
request: Request,
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
token: str | None = None,
):
user_id = user["sub"] if user else None
if not user_id and token:
unauthorized_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
user_id: str = payload.get("sub")
except jwt.JWTError:
raise unauthorized_exception
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
truncated_id = str(transcript.id).split("-")[0]
filename = f"recording_{truncated_id}.mp3"
return range_requests_response(
request,
transcript.audio_mp3_filename,
content_type="audio/mpeg",
content_disposition=f"attachment; filename={filename}",
)
@router.get("/transcripts/{transcript_id}/audio/waveform")
async def transcript_get_audio_waveform(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
) -> AudioWaveform:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if not transcript.audio_mp3_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
return transcript.audio_waveform
@router.get(
"/transcripts/{transcript_id}/topics",
response_model=list[GetTranscriptTopic],
)
async def transcript_get_topics(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
# convert to GetTranscriptTopic
return [
GetTranscriptTopic.from_transcript_topic(topic) for topic in transcript.topics
]
# ==============================================================
# Websocket
# ==============================================================
@router.get("/transcripts/{transcript_id}/events")
async def transcript_get_websocket_events(transcript_id: str):
pass
@router.websocket("/transcripts/{transcript_id}/events")
async def transcript_events_websocket(
transcript_id: str,
websocket: WebSocket,
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
# user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
# connect to websocket manager
# use ts:transcript_id as room id
room_id = f"ts:{transcript_id}"
ws_manager = get_ws_manager()
await ws_manager.add_user_to_room(room_id, websocket)
try:
# on first connection, send all events only to the current user
for event in transcript.events:
# for now, do not send TRANSCRIPT or STATUS options - theses are live event
# not necessary to be sent to the client; but keep the rest
name = event.event
if name in ("TRANSCRIPT", "STATUS"):
continue
await websocket.send_json(event.model_dump(mode="json"))
# XXX if transcript is final (locked=True and status=ended)
# XXX send a final event to the client and close the connection
# endless loop to wait for new events
# we do not have command system now,
while True:
await websocket.receive()
except (RuntimeError, WebSocketDisconnect):
await ws_manager.remove_user_from_room(room_id, websocket)
# ==============================================================
# Web RTC
# ==============================================================
@router.post("/transcripts/{transcript_id}/record/webrtc")
async def transcript_record_webrtc(
transcript_id: str,
params: RtcOffer,
request: Request,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(transcript_id, user_id=user_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if transcript.locked:
raise HTTPException(status_code=400, detail="Transcript is locked")
# create a pipeline runner
from reflector.pipelines.main_live_pipeline import PipelineMainLive
pipeline_runner = PipelineMainLive(transcript_id=transcript_id)
# FIXME do not allow multiple recording at the same time
return await rtc_offer_base(
params,
request,
pipeline_runner=pipeline_runner,
)