Compare commits


20 Commits

Author SHA1 Message Date
34a3f5618c chore(main): release 0.17.0 (#717) 2025-11-12 21:25:59 -05:00
Igor Monadical
1473fd82dc feat: daily.co support as alternative to whereby (#691)
* llm instructions

* vibe dailyco

* vibe dailyco

* doc update (vibe)

* dont show recording ui on call

* stub processor (vibe)

* stub processor (vibe) self-review

* stub processor (vibe) self-review

* chore(main): release 0.14.0 (#670)

* Add multitrack pipeline

* Mixdown audio tracks

* Mixdown with pyav filter graph

* Trigger multitrack processing for daily recordings

* apply platform from envs in priority: non-dry

* Use explicit track keys for processing

* Align tracks of a multitrack recording

* Generate waveforms for the mixed audio

* Emit multitrack pipeline events

* Fix multitrack pipeline track alignment

* dailyco docs

* Enable multitrack reprocessing

* modal temp files uniform names, cleanup. remove llm temporary docs

* docs cleanup

* dont proceed with raw recordings if any of the downloads fail

* dry transcription pipelines

* remove is_multitrack

* comments

* explicit dailyco room name

* docs

* remove stub data/method

* frontend daily/whereby code self-review (no-mistake)

* frontend daily/whereby code self-review (no-mistakes)

* frontend daily/whereby code self-review (no-mistakes)

* consent cleanup for multitrack (no-mistakes)

* llm fun

* remove extra comments

* fix tests

* merge migrations

* Store participant names

* Get participants by meeting session id

* pop back main branch migration

* s3 paddington (no-mistakes)

* comment

* pr comments

* pr comments

* pr comments

* platform / meeting cleanup

* Use participant names in summary generation

* platform assignment to meeting at controller level

* pr comment

* room platform properly default none

* room platform properly default none

* restore migration lost

* streaming WIP

* extract storage / use common storage / proper env vars for storage

* fix mocks tests

* remove fall back

* streaming for multifile

* central storage abstraction (no-mistakes)

* remove dead code / vars

* Set participant user id for authenticated users

* whereby recording name parsing fix

* whereby recording name parsing fix

* more file stream

* storage dry + tests

* remove homemade boto3 streaming and use proper boto

* update migration guide

* webhook creation script - print uuid

---------

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
Co-authored-by: Mathieu Virbel <mat@meltingrocks.com>
Co-authored-by: Sergey Mankovsky <sergey@monadical.com>
2025-11-12 21:21:16 -05:00
372202b0e1 feat: add API key management UI (#716)
* feat: add API key management UI

- Created settings page for users to create, view, and delete API keys
- Added Settings link to app navigation header
- Fixed delete operation return value handling in backend to properly handle asyncpg's None response

* feat: replace browser confirm with dialog for API key deletion

- Added Chakra UI Dialog component for better UX when confirming API key deletion
- Implemented proper focus management with cancelRef for accessibility
- Replaced native browser confirm() with controlled dialog state

* style: format API keys page with consistent line breaks

* feat: auto-select API key text for easier copying

- Added automatic text selection after API key creation to streamline the copy workflow
- Applied className to Code component for DOM targeting

* feat: improve API keys page layout and responsiveness

- Reduced max width from 1200px to 800px for better readability
- Added explicit width constraint to ensure consistent sizing across viewports

* refactor: remove redundant comments from API keys page
2025-11-10 18:25:08 -05:00
Igor Monadical
d20aac66c4 ui search pagination 2+page re-search fix (#714)
Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
2025-11-10 14:18:41 -05:00
dc4b737daa chore(main): release 0.16.0 (#711) 2025-10-24 16:18:49 -06:00
Igor Monadical
0baff7abf7 transcript ui copy button placement (#712)
Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
2025-10-24 16:52:02 -04:00
Igor Monadical
962c40e2b6 feat: search date filter (#710)
* search date filter

* search date filter

* search date filter

* search date filter

* pr comment

---------

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
2025-10-23 20:16:43 -04:00
Igor Monadical
3c4b9f2103 chore: error reporting and naming (#708)
* chore: error reporting and naming

* chore: error reporting and naming

---------

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
2025-10-22 13:45:08 -04:00
Igor Monadical
c6c035aacf removal of email-verified from /me (#707)
Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
2025-10-21 14:49:33 -04:00
c086b91445 chore(main): release 0.15.0 (#706) 2025-10-21 08:30:22 -06:00
Igor Monadical
9a258abc02 feat: api tokens (#705)
* feat: api tokens (vibe)

* self-review

* remove token terminology + pr comments (vibe)

* return email_verified

---------

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
2025-10-20 12:55:25 -04:00
af86c47f1d chore(main): release 0.14.0 (#670) 2025-10-08 14:57:31 -06:00
5f6910e513 feat: Add calendar event data to transcript webhook payload (#689)
* feat: add calendar event data to transcript webhook payload and implement get_by_id method

* Update server/reflector/worker/webhook.py

Co-authored-by: pr-agent-monadical[bot] <198624643+pr-agent-monadical[bot]@users.noreply.github.com>

* Update server/reflector/worker/webhook.py

Co-authored-by: pr-agent-monadical[bot] <198624643+pr-agent-monadical[bot]@users.noreply.github.com>

* style: format conditional time fields with line breaks for better readability

* docs: add calendar event fields to transcript.completed webhook payload schema

---------

Co-authored-by: pr-agent-monadical[bot] <198624643+pr-agent-monadical[bot]@users.noreply.github.com>
2025-10-08 11:11:57 -05:00
9a71af145e fix: update transcript list on reprocess (#676)
* Update transcript list on reprocess

* Fix transcript create

* Fix multiple sockets issue

* Pass token in sec websocket protocol

* userEvent parse example

* transcript list invalidation non-abstraction

* Emit only relevant events to the user room

* Add ws close code const

* Refactor user websocket endpoint

* Refactor user events provider

---------

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
2025-10-07 19:11:30 +02:00
eef6dc3903 fix: upgrade nemo toolkit (#678) 2025-10-07 16:45:02 +02:00
Igor Monadical
1dee255fed parakeet endpoint doc (#679)
Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
2025-10-07 10:41:01 -04:00
5d98754305 fix: security review (#656)
* Add security review doc

* Add tests to reproduce security issues

* Fix security issues

* Fix tests

* Set auth auth backend for tests

* Fix ics api tests

* Fix transcript mutate check

* Update frontend env var names

* Remove permissions doc
2025-09-29 23:07:49 +02:00
Igor Monadical
969bd84fcc feat: container build for www / github (#672)
Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
2025-09-24 12:27:45 -04:00
Igor Monadical
36608849ec fix: restore feature boolean logic (#671)
Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
2025-09-24 11:57:49 -04:00
Igor Monadical
5bf64b5a41 feat: docker-compose for production frontend (#664)
* docker-compose for production frontend

* fix: Remove external Redis port mapping for Coolify compatibility

Redis should only be accessible within the internal Docker network in Coolify deployments to avoid port conflicts with other applications.

* fix: Remove external port mapping for web service in Coolify

Coolify handles port exposure through its proxy (Traefik), so services should not expose ports directly in the docker-compose file.

* server side client envs

* missing vars

* nextjs experimental

* fix claude 'fix'

* remove build env vars compose

* docker

* remove ports for coolify

* review

* cleanup

---------

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
2025-09-24 11:15:27 -04:00
164 changed files with 11114 additions and 4515 deletions


@@ -1,4 +1,4 @@
-name: Deploy to Amazon ECS
+name: Build container/push to container registry
on: [workflow_dispatch]

.github/workflows/docker-frontend.yml (new file)

@@ -0,0 +1,57 @@
name: Build and Push Frontend Docker Image
on:
push:
branches:
- main
paths:
- 'www/**'
- '.github/workflows/docker-frontend.yml'
workflow_dispatch:
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}-frontend
jobs:
build-and-push:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Log in to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=ref,event=branch
type=sha,prefix={{branch}}-
type=raw,value=latest,enable={{is_default_branch}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build and push Docker image
uses: docker/build-push-action@v5
with:
context: ./www
file: ./www/Dockerfile
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
platforms: linux/amd64,linux/arm64


@@ -1,5 +1,44 @@
# Changelog
## [0.17.0](https://github.com/Monadical-SAS/reflector/compare/v0.16.0...v0.17.0) (2025-11-13)
### Features
* add API key management UI ([#716](https://github.com/Monadical-SAS/reflector/issues/716)) ([372202b](https://github.com/Monadical-SAS/reflector/commit/372202b0e1a86823900b0aa77be1bfbc2893d8a1))
* daily.co support as alternative to whereby ([#691](https://github.com/Monadical-SAS/reflector/issues/691)) ([1473fd8](https://github.com/Monadical-SAS/reflector/commit/1473fd82dc472c394cbaa2987212ad662a74bcac))
## [0.16.0](https://github.com/Monadical-SAS/reflector/compare/v0.15.0...v0.16.0) (2025-10-24)
### Features
* search date filter ([#710](https://github.com/Monadical-SAS/reflector/issues/710)) ([962c40e](https://github.com/Monadical-SAS/reflector/commit/962c40e2b6428ac42fd10aea926782d7a6f3f902))
## [0.15.0](https://github.com/Monadical-SAS/reflector/compare/v0.14.0...v0.15.0) (2025-10-20)
### Features
* api tokens ([#705](https://github.com/Monadical-SAS/reflector/issues/705)) ([9a258ab](https://github.com/Monadical-SAS/reflector/commit/9a258abc0209b0ac3799532a507ea6a9125d703a))
## [0.14.0](https://github.com/Monadical-SAS/reflector/compare/v0.13.1...v0.14.0) (2025-10-08)
### Features
* Add calendar event data to transcript webhook payload ([#689](https://github.com/Monadical-SAS/reflector/issues/689)) ([5f6910e](https://github.com/Monadical-SAS/reflector/commit/5f6910e5131b7f28f86c9ecdcc57fed8412ee3cd))
* container build for www / github ([#672](https://github.com/Monadical-SAS/reflector/issues/672)) ([969bd84](https://github.com/Monadical-SAS/reflector/commit/969bd84fcc14851d1a101412a0ba115f1b7cde82))
* docker-compose for production frontend ([#664](https://github.com/Monadical-SAS/reflector/issues/664)) ([5bf64b5](https://github.com/Monadical-SAS/reflector/commit/5bf64b5a41f64535e22849b4bb11734d4dbb4aae))
### Bug Fixes
* restore feature boolean logic ([#671](https://github.com/Monadical-SAS/reflector/issues/671)) ([3660884](https://github.com/Monadical-SAS/reflector/commit/36608849ec64e953e3be456172502762e3c33df9))
* security review ([#656](https://github.com/Monadical-SAS/reflector/issues/656)) ([5d98754](https://github.com/Monadical-SAS/reflector/commit/5d98754305c6c540dd194dda268544f6d88bfaf8))
* update transcript list on reprocess ([#676](https://github.com/Monadical-SAS/reflector/issues/676)) ([9a71af1](https://github.com/Monadical-SAS/reflector/commit/9a71af145ee9b833078c78d0c684590ab12e9f0e))
* upgrade nemo toolkit ([#678](https://github.com/Monadical-SAS/reflector/issues/678)) ([eef6dc3](https://github.com/Monadical-SAS/reflector/commit/eef6dc39037329b65804297786d852dddb0557f9))
## [0.13.1](https://github.com/Monadical-SAS/reflector/compare/v0.13.0...v0.13.1) (2025-09-22)


@@ -151,7 +151,7 @@ All endpoints prefixed `/v1/`:
**Frontend** (`www/.env`):
- `NEXTAUTH_URL`, `NEXTAUTH_SECRET` - Authentication configuration
-- `NEXT_PUBLIC_REFLECTOR_API_URL` - Backend API endpoint
+- `REFLECTOR_API_URL` - Backend API endpoint
- `REFLECTOR_DOMAIN_CONFIG` - Feature flags and domain settings
## Testing Strategy


@@ -168,6 +168,13 @@ You can manually process an audio file by calling the process tool:
uv run python -m reflector.tools.process path/to/audio.wav
```
## Build-time env variables
Next.js projects typically rely on NEXT_PUBLIC_-prefixed build-time variables. We don't use those because we need to serve a customizable prebuilt Docker container.
Instead, all variables are runtime. Variables needed by the frontend are served to the frontend app at initial render.
This also means there is no static prebuild and no static js/html files to serve.
## Feature Flags
@@ -177,24 +184,24 @@ Reflector uses environment variable-based feature flags to control application f
| Feature Flag | Environment Variable |
|-------------|---------------------|
-| `requireLogin` | `NEXT_PUBLIC_FEATURE_REQUIRE_LOGIN` |
+| `requireLogin` | `FEATURE_REQUIRE_LOGIN` |
-| `privacy` | `NEXT_PUBLIC_FEATURE_PRIVACY` |
+| `privacy` | `FEATURE_PRIVACY` |
-| `browse` | `NEXT_PUBLIC_FEATURE_BROWSE` |
+| `browse` | `FEATURE_BROWSE` |
-| `sendToZulip` | `NEXT_PUBLIC_FEATURE_SEND_TO_ZULIP` |
+| `sendToZulip` | `FEATURE_SEND_TO_ZULIP` |
-| `rooms` | `NEXT_PUBLIC_FEATURE_ROOMS` |
+| `rooms` | `FEATURE_ROOMS` |
### Setting Feature Flags
-Feature flags are controlled via environment variables using the pattern `NEXT_PUBLIC_FEATURE_{FEATURE_NAME}` where `{FEATURE_NAME}` is the SCREAMING_SNAKE_CASE version of the feature name.
+Feature flags are controlled via environment variables using the pattern `FEATURE_{FEATURE_NAME}` where `{FEATURE_NAME}` is the SCREAMING_SNAKE_CASE version of the feature name.
**Examples:**
```bash
# Enable user authentication requirement
-NEXT_PUBLIC_FEATURE_REQUIRE_LOGIN=true
+FEATURE_REQUIRE_LOGIN=true
# Disable browse functionality
-NEXT_PUBLIC_FEATURE_BROWSE=false
+FEATURE_BROWSE=false
# Enable Zulip integration
-NEXT_PUBLIC_FEATURE_SEND_TO_ZULIP=true
+FEATURE_SEND_TO_ZULIP=true
```
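The naming convention above (camelCase feature name → `FEATURE_` env var) can be sketched as a small helper. This is illustrative only — the function name is not from the codebase, and the real lookup lives in the frontend config code:

```python
import re

def feature_env_var(feature_name: str) -> str:
    """Map a camelCase feature flag name to its environment variable.

    Hypothetical helper: inserts an underscore before each capital letter,
    then upper-cases, giving the SCREAMING_SNAKE_CASE form.
    """
    screaming = re.sub(r"(?<!^)(?=[A-Z])", "_", feature_name).upper()
    return f"FEATURE_{screaming}"
```

For example, `feature_env_var("sendToZulip")` yields `FEATURE_SEND_TO_ZULIP`, matching the table above.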

docker-compose.prod.yml (new file)

@@ -0,0 +1,39 @@
# Production Docker Compose configuration for Frontend
# Usage: docker compose -f docker-compose.prod.yml up -d
services:
web:
build:
context: ./www
dockerfile: Dockerfile
image: reflector-frontend:latest
environment:
- KV_URL=${KV_URL:-redis://redis:6379}
- SITE_URL=${SITE_URL}
- API_URL=${API_URL}
- WEBSOCKET_URL=${WEBSOCKET_URL}
- NEXTAUTH_URL=${NEXTAUTH_URL:-http://localhost:3000}
- NEXTAUTH_SECRET=${NEXTAUTH_SECRET:-changeme-in-production}
- AUTHENTIK_ISSUER=${AUTHENTIK_ISSUER}
- AUTHENTIK_CLIENT_ID=${AUTHENTIK_CLIENT_ID}
- AUTHENTIK_CLIENT_SECRET=${AUTHENTIK_CLIENT_SECRET}
- AUTHENTIK_REFRESH_TOKEN_URL=${AUTHENTIK_REFRESH_TOKEN_URL}
- SENTRY_DSN=${SENTRY_DSN}
- SENTRY_IGNORE_API_RESOLUTION_ERROR=${SENTRY_IGNORE_API_RESOLUTION_ERROR:-1}
depends_on:
- redis
restart: unless-stopped
redis:
image: redis:7.2-alpine
restart: unless-stopped
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 30s
timeout: 3s
retries: 3
volumes:
- redis_data:/data
volumes:
redis_data:


@@ -39,7 +39,7 @@ services:
ports:
- 6379:6379
web:
-image: node:18
+image: node:22-alpine
ports:
- "3000:3000"
command: sh -c "corepack enable && pnpm install && pnpm dev"
@@ -50,6 +50,8 @@ services:
- /app/node_modules
env_file:
- ./www/.env.local
+environment:
+- NODE_ENV=development
postgres:
image: postgres:17


@@ -77,7 +77,7 @@ image = (
.pip_install(
"hf_transfer==0.1.9",
"huggingface_hub[hf-xet]==0.31.2",
-"nemo_toolkit[asr]==2.3.0",
+"nemo_toolkit[asr]==2.5.0",
"cuda-python==12.8.0",
"fastapi==0.115.12",
"numpy<2",


@@ -1,3 +1,29 @@
## API Key Management
### Finding Your User ID
```bash
# Get your OAuth sub (user ID) - requires authentication
curl -H "Authorization: Bearer <your_jwt>" http://localhost:1250/v1/me
# Returns: {"sub": "your-oauth-sub-here", "email": "...", ...}
```
### Creating API Keys
```bash
curl -X POST http://localhost:1250/v1/user/api-keys \
-H "Authorization: Bearer <your_jwt>" \
-H "Content-Type: application/json" \
-d '{"name": "My API Key"}'
```
### Using API Keys
```bash
# Use X-API-Key header instead of Authorization
curl -H "X-API-Key: <your_api_key>" http://localhost:1250/v1/transcripts
```
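The same `X-API-Key` call can be made from Python using only the standard library. A minimal sketch — the endpoint and port are taken from the curl examples above, and the helper name is illustrative:

```python
from urllib.request import Request

def api_request(path: str, api_key: str,
                base_url: str = "http://localhost:1250") -> Request:
    """Build a request authenticated with an API key instead of a JWT.

    The X-API-Key header replaces the usual Authorization: Bearer header.
    """
    return Request(base_url + path, headers={"X-API-Key": api_key})
```

Pass the resulting `Request` to `urllib.request.urlopen` to execute it.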
## AWS S3/SQS usage clarification
Whereby.com uploads recordings directly to our S3 bucket when meetings end. Whereby.com uploads recordings directly to our S3 bucket when meetings end.


@@ -1,118 +0,0 @@
# AsyncIO Event Loop Analysis for test_attendee_parsing_bug.py
## Problem Summary
The test passes but encounters an error during teardown where asyncpg tries to use a different/closed event loop, resulting in:
- `RuntimeError: Task got Future attached to a different loop`
- `RuntimeError: Event loop is closed`
## Root Cause Analysis
### 1. Multiple Event Loop Creation Points
The test environment creates event loops at different scopes:
1. **Session-scoped loop** (conftest.py:27-34):
- Created once per test session
- Used by session-scoped fixtures
- Closed after all tests complete
2. **Function-scoped loop** (pytest-asyncio default):
- Created for each async test function
- This is the loop that runs the actual test
- Closed immediately after test completes
3. **AsyncPG internal loop**:
- AsyncPG connections store a reference to the loop they were created with
- Used for connection lifecycle management
### 2. Event Loop Lifecycle Mismatch
The issue occurs because:
1. **Session fixture creates database connection** on session-scoped loop
2. **Test runs** on function-scoped loop (different from session loop)
3. **During teardown**, the session fixture tries to rollback/close using the original session loop
4. **AsyncPG connection** still references the function-scoped loop which is now closed
5. **Conflict**: SQLAlchemy tries to use session loop, but asyncpg Future is attached to the closed function loop
### 3. Configuration Issues
Current pytest configuration:
- `asyncio_mode = "auto"` in pyproject.toml
- `asyncio_default_fixture_loop_scope=session` (shown in test output)
- `asyncio_default_test_loop_scope=function` (shown in test output)
This mismatch between fixture loop scope (session) and test loop scope (function) causes the problem.
## Solutions
### Option 1: Align Loop Scopes (Recommended)
Change pytest-asyncio configuration to use consistent loop scopes:
```python
# pyproject.toml
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function" # Change from session to function
```
### Option 2: Use Function-Scoped Database Fixture
Change the `session` fixture scope from session to function:
```python
@pytest_asyncio.fixture # Remove scope="session"
async def session(setup_database):
# ... existing code ...
```
### Option 3: Explicit Loop Management
Ensure all async operations use the same loop:
```python
@pytest_asyncio.fixture
async def session(setup_database, event_loop):
# Force using the current event loop
engine = create_async_engine(
settings.DATABASE_URL,
echo=False,
poolclass=NullPool,
connect_args={"loop": event_loop} # Pass explicit loop
)
# ... rest of fixture ...
```
### Option 4: Upgrade pytest-asyncio
The current version (1.1.0) has known issues with loop management. Consider upgrading to the latest version which has better loop scope handling.
## Immediate Workaround
For the test to run cleanly without the teardown error, you can:
1. Add explicit cleanup in the test:
```python
@pytest.mark.asyncio
async def test_attendee_parsing_bug(session):
# ... existing test code ...
# Explicit cleanup before fixture teardown
await session.commit() # or await session.close()
```
2. Or suppress the teardown error (not recommended for production):
```python
@pytest.fixture
async def session(setup_database):
# ... existing setup ...
try:
yield session
await session.rollback()
except RuntimeError as e:
if "Event loop is closed" not in str(e):
raise
finally:
await session.close()
```
## Recommendation
The cleanest solution is to align the loop scopes by setting both fixture and test loop scopes to "function" scope. This ensures each test gets its own clean event loop and avoids cross-contamination between tests.


@@ -0,0 +1,234 @@
# Reflector Architecture: Whereby + Daily.co Recording Storage
## System Overview
```mermaid
graph TB
subgraph "Actors"
APP[Our App<br/>Reflector]
WHEREBY[Whereby Service<br/>External]
DAILY[Daily.co Service<br/>External]
end
subgraph "AWS S3 Buckets"
TRANSCRIPT_BUCKET[Transcript Bucket<br/>reflector-transcripts<br/>Output: Processed MP3s]
WHEREBY_BUCKET[Whereby Bucket<br/>reflector-whereby-recordings<br/>Input: Raw MP4s]
DAILY_BUCKET[Daily.co Bucket<br/>reflector-dailyco-recordings<br/>Input: Raw WebM tracks]
end
subgraph "AWS Infrastructure"
SQS[SQS Queue<br/>Whereby notifications]
end
subgraph "Database"
DB[(PostgreSQL<br/>Recordings, Transcripts, Meetings)]
end
APP -->|Write processed| TRANSCRIPT_BUCKET
APP -->|Read/Delete| WHEREBY_BUCKET
APP -->|Read/Delete| DAILY_BUCKET
APP -->|Poll| SQS
APP -->|Store metadata| DB
WHEREBY -->|Write recordings| WHEREBY_BUCKET
WHEREBY_BUCKET -->|S3 Event| SQS
WHEREBY -->|Participant webhooks<br/>room.client.joined/left| APP
DAILY -->|Write recordings| DAILY_BUCKET
DAILY -->|Recording webhook<br/>recording.ready-to-download| APP
```
**Note on Webhook vs S3 Event for Recording Processing:**
- **Whereby**: Uses S3 Events → SQS for recording availability (S3 as source of truth, no race conditions)
- **Daily.co**: Uses webhooks for recording availability (more immediate, built-in reliability)
- **Both**: Use webhooks for participant tracking (real-time updates)
## Credentials & Permissions
```mermaid
graph LR
subgraph "Master Credentials"
MASTER[TRANSCRIPT_STORAGE_AWS_*<br/>Access Key ID + Secret]
end
subgraph "Whereby Upload Credentials"
WHEREBY_CREDS[AWS_WHEREBY_ACCESS_KEY_*<br/>Access Key ID + Secret]
end
subgraph "Daily.co Upload Role"
DAILY_ROLE[DAILY_STORAGE_AWS_ROLE_ARN<br/>IAM Role ARN]
end
subgraph "Our App Uses"
MASTER -->|Read/Write/Delete| TRANSCRIPT_BUCKET[Transcript Bucket]
MASTER -->|Read/Delete| WHEREBY_BUCKET[Whereby Bucket]
MASTER -->|Read/Delete| DAILY_BUCKET[Daily.co Bucket]
MASTER -->|Poll/Delete| SQS[SQS Queue]
end
subgraph "We Give To Services"
WHEREBY_CREDS -->|Passed in API call| WHEREBY_SERVICE[Whereby Service]
WHEREBY_SERVICE -->|Write Only| WHEREBY_BUCKET
DAILY_ROLE -->|Passed in API call| DAILY_SERVICE[Daily.co Service]
DAILY_SERVICE -->|Assume Role| DAILY_ROLE
DAILY_SERVICE -->|Write Only| DAILY_BUCKET
end
```
# Video Platform Recording Integration
This document explains how Reflector receives and identifies multitrack audio recordings from different video platforms.
## Platform Comparison
| Platform | Delivery Method | Track Identification |
|----------|----------------|---------------------|
| **Daily.co** | Webhook | Explicit track list in payload |
| **Whereby** | SQS (S3 notifications) | Single file per notification |
---
## Daily.co (Webhook-based)
Daily.co uses **webhooks** to notify Reflector when recordings are ready.
### How It Works
1. **Daily.co sends webhook** when recording is ready
- Event type: `recording.ready-to-download`
- Endpoint: `/v1/daily/webhook` (`reflector/views/daily.py:46-102`)
2. **Webhook payload explicitly includes track list**:
```json
{
"recording_id": "7443ee0a-dab1-40eb-b316-33d6c0d5ff88",
"room_name": "daily-20251020193458",
"tracks": [
{
"type": "audio",
"s3Key": "monadical/daily-20251020193458/1760988935484-52f7f48b-fbab-431f-9a50-87b9abfc8255-cam-audio-1760988935922",
"size": 831843
},
{
"type": "audio",
"s3Key": "monadical/daily-20251020193458/1760988935484-a37c35e3-6f8e-4274-a482-e9d0f102a732-cam-audio-1760988943823",
"size": 408438
},
{
"type": "video",
"s3Key": "monadical/daily-20251020193458/...-video.webm",
"size": 30000000
}
]
}
```
3. **System extracts audio tracks** (`daily.py:211`):
```python
track_keys = [t.s3Key for t in tracks if t.type == "audio"]
```
4. **Triggers multitrack processing** (`daily.py:213-218`):
```python
process_multitrack_recording.delay(
bucket_name=bucket_name, # reflector-dailyco-local
room_name=room_name, # daily-20251020193458
recording_id=recording_id, # 7443ee0a-dab1-40eb-b316-33d6c0d5ff88
track_keys=track_keys # Only audio s3Keys
)
```
### Key Advantage: No Ambiguity
Even though multiple meetings may share the same S3 bucket/folder (`monadical/`), **there's no ambiguity** because:
- Each webhook payload contains the exact `s3Key` list for that specific `recording_id`
- No need to scan folders or guess which files belong together
- Each track's s3Key includes the room timestamp subfolder (e.g., `daily-20251020193458/`)
The room name includes timestamp (`daily-20251020193458`) to keep recordings organized, but **the webhook's explicit track list is what prevents mixing files from different meetings**.
### Track Timeline Extraction
Daily.co provides timing information in two places:
**1. PyAV WebM Metadata (current approach)**:
```python
# Read from WebM container stream metadata
stream.start_time = 8.130s # Meeting-relative timing
```
**2. Filename Timestamps (alternative approach, commit 3bae9076)**:
```
Filename format: {recording_start_ts}-{uuid}-cam-audio-{track_start_ts}.webm
Example: 1760988935484-52f7f48b-fbab-431f-9a50-87b9abfc8255-cam-audio-1760988935922.webm
Parse timestamps:
- recording_start_ts: 1760988935484 (Unix ms)
- track_start_ts: 1760988935922 (Unix ms)
- offset: (1760988935922 - 1760988935484) / 1000 = 0.438s
```
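The filename parsing described above can be sketched as a small helper. The filename layout is assumed from the samples shown here; the function name is illustrative:

```python
def parse_track_offset(filename: str) -> float:
    """Compute a track's start offset in seconds from a Daily.co track filename.

    Assumed layout: {recording_start_ts}-{uuid}-cam-audio-{track_start_ts}.webm
    with both timestamps in Unix milliseconds.
    """
    # Strip any path prefix and the .webm extension
    stem = filename.rsplit("/", 1)[-1].removesuffix(".webm")
    # recording_start_ts is everything before the first dash
    recording_start_ms = int(stem.split("-", 1)[0])
    # track_start_ts is everything after the last dash
    track_start_ms = int(stem.rsplit("-", 1)[-1])
    return (track_start_ms - recording_start_ms) / 1000.0
```

For the sample filename above, this yields the 0.438 s offset computed in the text.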
**Time Difference (PyAV vs Filename)**:
```
Track 0:
Filename offset: 438ms
PyAV metadata: 229ms
Difference: 209ms
Track 1:
Filename offset: 8339ms
PyAV metadata: 8130ms
Difference: 209ms
```
**Consistent 209ms delta** suggests network/encoding delay between file upload initiation (filename) and actual audio stream start (metadata).
**Current implementation uses PyAV metadata** because:
- More accurate (represents when audio actually started)
- Padding BEFORE transcription produces correct Whisper timestamps automatically
- No manual offset adjustment needed during transcript merge
### Why Re-encoding During Padding
Padding coincidentally involves re-encoding, which is important for Daily.co + Whisper:
**Problem:** Daily.co skips frames in recordings when microphone is muted or paused
- WebM containers have gaps where audio frames should be
- Whisper doesn't understand these gaps and produces incorrect timestamps
- Example: 5s of audio with 2s muted → file has frames only for 3s, Whisper thinks duration is 3s
**Solution:** Re-encoding via PyAV filter graph (`adelay` + `aresample`)
- Restores missing frames as silence
- Produces continuous audio stream without gaps
- Whisper now sees correct duration and produces accurate timestamps
**Why combined with padding:**
- Already re-encoding for padding (adding initial silence)
- More performant to do both operations in single PyAV pipeline
- Padded values needed for mixdown anyway (creating final MP3)
Implementation: `main_multitrack_pipeline.py:_apply_audio_padding_streaming()`
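As a rough sketch of the padding step: given a per-track offset, the silence prefix can be expressed as an ffmpeg-style `adelay` filter chained with `aresample`. This helper only builds the filter description string — it is an assumption-laden illustration, not the actual PyAV filter-graph code in `_apply_audio_padding_streaming()`:

```python
def build_padding_filter(offset_seconds: float, channels: int = 2) -> str:
    """Build an ffmpeg-style filter string that pads a track with initial silence.

    adelay expects one delay value per channel, in milliseconds;
    aresample=async=1 regenerates timestamps so the re-encoded stream is
    continuous (restoring gaps left by muted/paused frames as silence).
    """
    delay_ms = round(offset_seconds * 1000)
    delays = "|".join(str(delay_ms) for _ in range(channels))
    return f"adelay={delays},aresample=async=1"
```

Re-encoding through such a filter is what gives Whisper a gap-free stream with correct duration.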
---
## Whereby (SQS-based)
Whereby uses **AWS SQS** (via S3 notifications) to notify Reflector when files are uploaded.
### How It Works
1. **Whereby uploads recording** to S3
2. **S3 sends notification** to SQS queue (one notification per file)
3. **Reflector polls SQS queue** (`worker/process.py:process_messages()`)
4. **System processes single file** (`worker/process.py:process_recording()`)
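The SQS side of the steps above can be sketched as follows, assuming the standard S3 event notification message shape (the real polling loop lives in `worker/process.py`; this helper is illustrative):

```python
import json

def extract_uploaded_objects(sqs_message_body: str) -> list[tuple[str, str]]:
    """Extract (bucket, key) pairs from an S3 event notification delivered via SQS.

    Standard S3 event shape: Records[].s3.bucket.name / Records[].s3.object.key.
    Note: S3 URL-encodes object keys in event payloads; a production handler
    would decode them with urllib.parse.unquote_plus.
    """
    event = json.loads(sqs_message_body)
    return [
        (r["s3"]["bucket"]["name"], r["s3"]["object"]["key"])
        for r in event.get("Records", [])
        if r.get("eventName", "").startswith("ObjectCreated")
    ]
```

Each notification carries exactly one object, which is why Whereby processing works one file at a time.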
### Key Difference from Daily.co
**Whereby (SQS):** System receives S3 notification "file X was created" - only knows about one file at a time, would need to scan folder to find related files
**Daily.co (Webhook):** Daily explicitly tells system which files belong together in the webhook payload
---


@@ -14,7 +14,7 @@ Webhooks are configured at the room level with two fields:
### `transcript.completed`
-Triggered when a transcript has been fully processed, including transcription, diarization, summarization, and topic detection.
+Triggered when a transcript has been fully processed, including transcription, diarization, summarization, topic detection and calendar event integration.
### `test`
@@ -128,6 +128,27 @@ This event includes a convenient URL for accessing the transcript:
"room": {
"id": "room-789",
"name": "Product Team Room"
},
"calendar_event": {
"id": "calendar-event-123",
"ics_uid": "event-123",
"title": "Q3 Product Planning Meeting",
"start_time": "2025-08-27T12:00:00Z",
"end_time": "2025-08-27T12:30:00Z",
"description": "Team discussed Q3 product roadmap, prioritizing mobile app features and API improvements.",
"location": "Conference Room 1",
"attendees": [
{
"id": "participant-1",
"name": "John Doe",
"speaker": "Speaker 1"
},
{
"id": "participant-2",
"name": "Jane Smith",
"speaker": "Speaker 2"
}
]
}
}
```


@@ -27,7 +27,7 @@ AUTH_JWT_AUDIENCE=
#TRANSCRIPT_MODAL_API_KEY=xxxxx
TRANSCRIPT_BACKEND=modal
-TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-web.modal.run
+TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-parakeet-web.modal.run
TRANSCRIPT_MODAL_API_KEY=
## =======================================================
@@ -71,3 +71,30 @@ DIARIZATION_URL=https://monadical-sas--reflector-diarizer-web.modal.run
## Sentry DSN configuration
#SENTRY_DSN=
## =======================================================
## Video Platform Configuration
## =======================================================
## Whereby
#WHEREBY_API_KEY=your-whereby-api-key
#WHEREBY_WEBHOOK_SECRET=your-whereby-webhook-secret
#WHEREBY_STORAGE_AWS_ACCESS_KEY_ID=your-aws-key
#WHEREBY_STORAGE_AWS_SECRET_ACCESS_KEY=your-aws-secret
#AWS_PROCESS_RECORDING_QUEUE_URL=https://sqs.us-west-2.amazonaws.com/...
## Daily.co
#DAILY_API_KEY=your-daily-api-key
#DAILY_WEBHOOK_SECRET=your-daily-webhook-secret
#DAILY_SUBDOMAIN=your-subdomain
#DAILY_WEBHOOK_UUID= # Auto-populated by recreate_daily_webhook.py script
#DAILYCO_STORAGE_AWS_ROLE_ARN=... # IAM role ARN for Daily.co S3 access
#DAILYCO_STORAGE_AWS_BUCKET_NAME=reflector-dailyco
#DAILYCO_STORAGE_AWS_REGION=us-west-2
## Whereby (optional separate bucket)
#WHEREBY_STORAGE_AWS_BUCKET_NAME=reflector-whereby
#WHEREBY_STORAGE_AWS_REGION=us-east-1
## Platform Configuration
#DEFAULT_VIDEO_PLATFORM=whereby # Default platform for new rooms


@@ -3,7 +3,7 @@ from logging.config import fileConfig
 from alembic import context
 from sqlalchemy import engine_from_config, pool
-from reflector.db.base import metadata
+from reflector.db import metadata
 from reflector.settings import settings

 # this is the Alembic Config object, which provides


@@ -0,0 +1,50 @@
"""add_platform_support

Revision ID: 1e49625677e4
Revises: 9e3f7b2a4c8e
Create Date: 2025-10-08 13:17:29.943612

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "1e49625677e4"
down_revision: Union[str, None] = "9e3f7b2a4c8e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Add platform field with default 'whereby' for backward compatibility."""
    with op.batch_alter_table("room", schema=None) as batch_op:
        batch_op.add_column(
            sa.Column(
                "platform",
                sa.String(),
                nullable=True,
                server_default=None,
            )
        )

    with op.batch_alter_table("meeting", schema=None) as batch_op:
        batch_op.add_column(
            sa.Column(
                "platform",
                sa.String(),
                nullable=False,
                server_default="whereby",
            )
        )


def downgrade() -> None:
    """Remove platform field."""
    with op.batch_alter_table("meeting", schema=None) as batch_op:
        batch_op.drop_column("platform")

    with op.batch_alter_table("room", schema=None) as batch_op:
        batch_op.drop_column("platform")


@@ -28,7 +28,7 @@ def upgrade() -> None:
     transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))

     # Select all rows from the transcript table
-    results = bind.execute(select(transcript.c.id, transcript.c.topics))
+    results = bind.execute(select([transcript.c.id, transcript.c.topics]))
     for row in results:
         transcript_id = row["id"]
@@ -58,7 +58,7 @@ def downgrade() -> None:
     transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))

     # Select all rows from the transcript table
-    results = bind.execute(select(transcript.c.id, transcript.c.topics))
+    results = bind.execute(select([transcript.c.id, transcript.c.topics]))
     for row in results:
         transcript_id = row["id"]


@@ -36,7 +36,9 @@ def upgrade() -> None:
     # select only the one with duration = 0
     results = bind.execute(
-        select(transcript.c.id, transcript.c.duration).where(transcript.c.duration == 0)
+        select([transcript.c.id, transcript.c.duration]).where(
+            transcript.c.duration == 0
+        )
     )
     data_dir = Path(settings.DATA_DIR)


@@ -28,7 +28,7 @@ def upgrade() -> None:
     transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))

     # Select all rows from the transcript table
-    results = bind.execute(select(transcript.c.id, transcript.c.topics))
+    results = bind.execute(select([transcript.c.id, transcript.c.topics]))
     for row in results:
         transcript_id = row["id"]
@@ -58,7 +58,7 @@ def downgrade() -> None:
     transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))

     # Select all rows from the transcript table
-    results = bind.execute(select(transcript.c.id, transcript.c.topics))
+    results = bind.execute(select([transcript.c.id, transcript.c.topics]))
     for row in results:
         transcript_id = row["id"]


@@ -0,0 +1,38 @@
"""add user api keys

Revision ID: 9e3f7b2a4c8e
Revises: dc035ff72fd5
Create Date: 2025-10-17 00:00:00.000000

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "9e3f7b2a4c8e"
down_revision: Union[str, None] = "dc035ff72fd5"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    op.create_table(
        "user_api_key",
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("user_id", sa.String(), nullable=False),
        sa.Column("key_hash", sa.String(), nullable=False),
        sa.Column("name", sa.String(), nullable=True),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )
    with op.batch_alter_table("user_api_key", schema=None) as batch_op:
        batch_op.create_index("idx_user_api_key_hash", ["key_hash"], unique=True)
        batch_op.create_index("idx_user_api_key_user_id", ["user_id"], unique=False)


def downgrade() -> None:
    op.drop_table("user_api_key")


@@ -0,0 +1,28 @@
"""add_track_keys

Revision ID: f8294b31f022
Revises: 1e49625677e4
Create Date: 2025-10-27 18:52:17.589167

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "f8294b31f022"
down_revision: Union[str, None] = "1e49625677e4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    with op.batch_alter_table("recording", schema=None) as batch_op:
        batch_op.add_column(sa.Column("track_keys", sa.JSON(), nullable=True))


def downgrade() -> None:
    with op.batch_alter_table("recording", schema=None) as batch_op:
        batch_op.drop_column("track_keys")


@@ -19,8 +19,8 @@ dependencies = [
     "sentry-sdk[fastapi]>=1.29.2",
     "httpx>=0.24.1",
     "fastapi-pagination>=0.12.6",
-    "sqlalchemy>=2.0.0",
-    "asyncpg>=0.29.0",
+    "databases[aiosqlite, asyncpg]>=0.7.0",
+    "sqlalchemy<1.5",
     "alembic>=1.11.3",
     "nltk>=3.8.1",
     "prometheus-fastapi-instrumentator>=6.1.0",
@@ -46,7 +46,6 @@ dev = [
     "black>=24.1.1",
     "stamina>=23.1.0",
     "pyinstrument>=4.6.1",
-    "pytest-async-sqlalchemy>=0.2.0",
 ]
 tests = [
     "pytest-cov>=4.1.0",
@@ -112,15 +111,13 @@ source = ["reflector"]
 [tool.pytest_env]
 ENVIRONMENT = "pytest"
-DATABASE_URL = "postgresql+asyncpg://test_user:test_password@localhost:15432/reflector_test"
-AUTH_BACKEND = "jwt"
+DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_test"

 [tool.pytest.ini_options]
 addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
 testpaths = ["tests"]
 asyncio_mode = "auto"
-asyncio_debug = true
-asyncio_default_fixture_loop_scope = "session"
-asyncio_default_test_loop_scope = "session"
 markers = [
     "model_api: tests for the unified model-serving HTTP API (backend- and hardware-agnostic)",
 ]


@@ -12,6 +12,7 @@ from reflector.events import subscribers_shutdown, subscribers_startup
 from reflector.logger import logger
 from reflector.metrics import metrics_init
 from reflector.settings import settings
+from reflector.views.daily import router as daily_router
 from reflector.views.meetings import router as meetings_router
 from reflector.views.rooms import router as rooms_router
 from reflector.views.rtc_offer import router as rtc_offer_router
@@ -26,6 +27,8 @@ from reflector.views.transcripts_upload import router as transcripts_upload_rout
 from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
 from reflector.views.transcripts_websocket import router as transcripts_websocket_router
 from reflector.views.user import router as user_router
+from reflector.views.user_api_keys import router as user_api_keys_router
+from reflector.views.user_websocket import router as user_ws_router
 from reflector.views.whereby import router as whereby_router
 from reflector.views.zulip import router as zulip_router
@@ -65,6 +68,12 @@ app.add_middleware(
     allow_headers=["*"],
 )
+
+
+@app.get("/health")
+async def health():
+    return {"status": "healthy"}
+
 # metrics
 instrumentator = Instrumentator(
     excluded_handlers=["/docs", "/metrics"],
@@ -84,8 +93,11 @@ app.include_router(transcripts_websocket_router, prefix="/v1")
 app.include_router(transcripts_webrtc_router, prefix="/v1")
 app.include_router(transcripts_process_router, prefix="/v1")
 app.include_router(user_router, prefix="/v1")
+app.include_router(user_api_keys_router, prefix="/v1")
+app.include_router(user_ws_router, prefix="/v1")
 app.include_router(zulip_router, prefix="/v1")
 app.include_router(whereby_router, prefix="/v1")
+app.include_router(daily_router, prefix="/v1/daily")
 add_pagination(app)

 # prepare celery


@@ -1,14 +1,21 @@
 import asyncio
 import functools
+
+from reflector.db import get_database


 def asynctask(f):
     @functools.wraps(f)
     def wrapper(*args, **kwargs):
-        async def run_async():
-            return await f(*args, **kwargs)
+        async def run_with_db():
+            database = get_database()
+            await database.connect()
+            try:
+                return await f(*args, **kwargs)
+            finally:
+                await database.disconnect()

-        coro = run_async()
+        coro = run_with_db()
         try:
             loop = asyncio.get_running_loop()
         except RuntimeError:
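The connect/run/disconnect wrapper in the hunk above can be sketched in isolation. This is a minimal stand-alone version, assuming only the connect/disconnect shape of the database object (the `FakeDatabase` stub and `asyncio.run` simplification are illustrative, not reflector's actual code, which also handles an already-running event loop):

```python
import asyncio
import functools


class FakeDatabase:
    """Stand-in for the real database object (assumed connect/disconnect API)."""

    def __init__(self):
        self.connected = False

    async def connect(self):
        self.connected = True

    async def disconnect(self):
        self.connected = False


db = FakeDatabase()


def asynctask(f):
    """Wrap an async function so it runs with a connected database."""

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        async def run_with_db():
            await db.connect()
            try:
                return await f(*args, **kwargs)
            finally:
                # disconnect even if the task body raises
                await db.disconnect()

        return asyncio.run(run_with_db())

    return wrapper


@asynctask
async def job(x):
    assert db.connected  # the connection is open while the body runs
    return x * 2


result = job(21)  # → 42, and db is disconnected afterwards
```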


@@ -1,14 +1,16 @@
-from typing import Annotated, Optional
+from typing import Annotated, List, Optional

 from fastapi import Depends, HTTPException
-from fastapi.security import OAuth2PasswordBearer
+from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
 from jose import JWTError, jwt
 from pydantic import BaseModel

+from reflector.db.user_api_keys import user_api_keys_controller
 from reflector.logger import logger
 from reflector.settings import settings

 oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
+api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

 jwt_public_key = open(f"reflector/auth/jwt/keys/{settings.AUTH_JWT_PUBLIC_KEY}").read()
 jwt_algorithm = settings.AUTH_JWT_ALGORITHM
@@ -26,7 +28,7 @@ class JWTException(Exception):
 class UserInfo(BaseModel):
     sub: str
-    email: str
+    email: Optional[str] = None

     def __getitem__(self, key):
         return getattr(self, key)
@@ -58,34 +60,53 @@ def authenticated(token: Annotated[str, Depends(oauth2_scheme)]):
     return None

-def current_user(
-    token: Annotated[Optional[str], Depends(oauth2_scheme)],
-    jwtauth: JWTAuth = Depends(),
-):
-    if token is None:
-        raise HTTPException(status_code=401, detail="Not authenticated")
-    try:
-        payload = jwtauth.verify_token(token)
-        sub = payload["sub"]
-        email = payload["email"]
-        return UserInfo(sub=sub, email=email)
-    except JWTError as e:
-        logger.error(f"JWT error: {e}")
-        raise HTTPException(status_code=401, detail="Invalid authentication")
-
-
-def current_user_optional(
-    token: Annotated[Optional[str], Depends(oauth2_scheme)],
-    jwtauth: JWTAuth = Depends(),
-):
-    # we accept no token, but if one is provided, it must be a valid one.
-    if token is None:
-        return None
-    try:
-        payload = jwtauth.verify_token(token)
-        sub = payload["sub"]
-        email = payload["email"]
-        return UserInfo(sub=sub, email=email)
-    except JWTError as e:
-        logger.error(f"JWT error: {e}")
-        raise HTTPException(status_code=401, detail="Invalid authentication")
+async def _authenticate_user(
+    jwt_token: Optional[str],
+    api_key: Optional[str],
+    jwtauth: JWTAuth,
+) -> UserInfo | None:
+    user_infos: List[UserInfo] = []
+
+    if api_key:
+        user_api_key = await user_api_keys_controller.verify_key(api_key)
+        if user_api_key:
+            user_infos.append(UserInfo(sub=user_api_key.user_id, email=None))
+
+    if jwt_token:
+        try:
+            payload = jwtauth.verify_token(jwt_token)
+            sub = payload["sub"]
+            email = payload["email"]
+            user_infos.append(UserInfo(sub=sub, email=email))
+        except JWTError as e:
+            logger.error(f"JWT error: {e}")
+            raise HTTPException(status_code=401, detail="Invalid authentication")
+
+    if len(user_infos) == 0:
+        return None
+
+    if len(set([x.sub for x in user_infos])) > 1:
+        raise JWTException(
+            status_code=401,
+            detail="Invalid authentication: more than one user provided",
+        )
+
+    return user_infos[0]
+
+
+async def current_user(
+    jwt_token: Annotated[Optional[str], Depends(oauth2_scheme)],
+    api_key: Annotated[Optional[str], Depends(api_key_header)],
+    jwtauth: JWTAuth = Depends(),
+):
+    user = await _authenticate_user(jwt_token, api_key, jwtauth)
+    if user is None:
+        raise HTTPException(status_code=401, detail="Not authenticated")
+    return user
+
+
+async def current_user_optional(
+    jwt_token: Annotated[Optional[str], Depends(oauth2_scheme)],
+    api_key: Annotated[Optional[str], Depends(api_key_header)],
+    jwtauth: JWTAuth = Depends(),
+):
+    return await _authenticate_user(jwt_token, api_key, jwtauth)
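The credential-resolution logic in the hunk above (collect an identity from each presented credential, reject the request if they name different users, fall back to anonymous) can be demonstrated without FastAPI or a database. Everything here is a hypothetical stand-in for illustration: the in-memory `API_KEYS` and `JWTS` stores replace the real key controller and JWT verifier.

```python
from typing import Optional


class AuthError(Exception):
    pass


# Hypothetical credential stores, replacing the real verifiers.
API_KEYS = {"key-abc": "user-1"}
JWTS = {"jwt-xyz": {"sub": "user-1", "email": "user1@example.com"}}


def authenticate(jwt_token: Optional[str], api_key: Optional[str]) -> Optional[dict]:
    """Collect one identity per presented credential; reject conflicting users."""
    identities = []
    if api_key and api_key in API_KEYS:
        identities.append({"sub": API_KEYS[api_key], "email": None})
    if jwt_token:
        payload = JWTS.get(jwt_token)
        if payload is None:
            raise AuthError("Invalid authentication")
        identities.append({"sub": payload["sub"], "email": payload["email"]})
    if not identities:
        # anonymous; a strict endpoint turns this into a 401
        return None
    if len({i["sub"] for i in identities}) > 1:
        raise AuthError("more than one user provided")
    return identities[0]


user = authenticate("jwt-xyz", "key-abc")  # both credentials agree → user-1
```

Presenting both a JWT and an API key is accepted only when they resolve to the same `sub`, mirroring the "more than one user provided" guard above.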


@@ -1,69 +1,49 @@
-from typing import AsyncGenerator
-
-from sqlalchemy.ext.asyncio import (
-    AsyncEngine,
-    AsyncSession,
-    async_sessionmaker,
-    create_async_engine,
-)
-
-from reflector.db.base import Base as Base
-from reflector.db.base import metadata as metadata
+import contextvars
+from typing import Optional
+
+import databases
+import sqlalchemy
+
 from reflector.events import subscribers_shutdown, subscribers_startup
 from reflector.settings import settings

-_engine: AsyncEngine | None = None
-_session_factory: async_sessionmaker[AsyncSession] | None = None
-
-
-def get_engine() -> AsyncEngine:
-    global _engine
-    if _engine is None:
-        _engine = create_async_engine(
-            settings.DATABASE_URL,
-            echo=False,
-            pool_pre_ping=True,
-        )
-    return _engine
-
-
-def get_session_factory() -> async_sessionmaker[AsyncSession]:
-    global _session_factory
-    if _session_factory is None:
-        _session_factory = async_sessionmaker(
-            get_engine(),
-            class_=AsyncSession,
-            expire_on_commit=False,
-        )
-    return _session_factory
-
-
-async def _get_session() -> AsyncGenerator[AsyncSession, None]:
-    # necessary implementation to ease mocking on pytest
-    async with get_session_factory()() as session:
-        yield session
-
-
-async def get_session() -> AsyncGenerator[AsyncSession, None]:
-    async for session in _get_session():
-        yield session
-
-
-# import models
+metadata = sqlalchemy.MetaData()
+
+_database_context: contextvars.ContextVar[Optional[databases.Database]] = (
+    contextvars.ContextVar("database", default=None)
+)
+
+
+def get_database() -> databases.Database:
+    """Get database instance for current asyncio context"""
+    db = _database_context.get()
+    if db is None:
+        db = databases.Database(settings.DATABASE_URL)
+        _database_context.set(db)
+    return db
+
+
 import reflector.db.calendar_events  # noqa
 import reflector.db.meetings  # noqa
 import reflector.db.recordings  # noqa
 import reflector.db.rooms  # noqa
 import reflector.db.transcripts  # noqa
+import reflector.db.user_api_keys  # noqa
+
+kwargs = {}
+if "postgres" not in settings.DATABASE_URL:
+    raise Exception("Only postgres database is supported in reflector")
+engine = sqlalchemy.create_engine(settings.DATABASE_URL, **kwargs)


 @subscribers_startup.append
 async def database_connect(_):
-    get_engine()
+    database = get_database()
+    await database.connect()


 @subscribers_shutdown.append
 async def database_disconnect(_):
-    global _engine
-    if _engine:
-        await _engine.dispose()
-        _engine = None
+    database = get_database()
+    await database.disconnect()
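The `get_database()` helper introduced above relies on `contextvars` to lazily create one database object per asyncio context and reuse it on subsequent calls. The pattern can be sketched with the standard library alone; the `Database` class below is a stand-in for `databases.Database`, not the real library:

```python
import asyncio
import contextvars
from typing import Optional


class Database:
    """Minimal stand-in for databases.Database (illustrative only)."""

    def __init__(self, url: str):
        self.url = url


_database_context: contextvars.ContextVar[Optional[Database]] = contextvars.ContextVar(
    "database", default=None
)


def get_database() -> Database:
    # Lazily create one Database per context, then reuse it.
    db = _database_context.get()
    if db is None:
        db = Database("postgresql://localhost/demo")
        _database_context.set(db)
    return db


async def main():
    first = get_database()
    second = get_database()
    return first is second  # same instance within one context


same_instance = asyncio.run(main())
```

Because each asyncio task can carry its own context, this avoids sharing one connection object across unrelated event loops, which is what the global-engine version above struggled with.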


@@ -1,237 +0,0 @@
from datetime import datetime
from typing import Optional

import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSONB, TSVECTOR
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(AsyncAttrs, DeclarativeBase):
    pass


class TranscriptModel(Base):
    __tablename__ = "transcript"

    id: Mapped[str] = mapped_column(sa.String, primary_key=True)
    name: Mapped[Optional[str]] = mapped_column(sa.String)
    status: Mapped[Optional[str]] = mapped_column(sa.String)
    locked: Mapped[Optional[bool]] = mapped_column(sa.Boolean)
    duration: Mapped[Optional[float]] = mapped_column(sa.Float)
    created_at: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
    title: Mapped[Optional[str]] = mapped_column(sa.String)
    short_summary: Mapped[Optional[str]] = mapped_column(sa.String)
    long_summary: Mapped[Optional[str]] = mapped_column(sa.String)
    topics: Mapped[Optional[list]] = mapped_column(sa.JSON)
    events: Mapped[Optional[list]] = mapped_column(sa.JSON)
    participants: Mapped[Optional[list]] = mapped_column(sa.JSON)
    source_language: Mapped[Optional[str]] = mapped_column(sa.String)
    target_language: Mapped[Optional[str]] = mapped_column(sa.String)
    reviewed: Mapped[bool] = mapped_column(
        sa.Boolean, nullable=False, server_default=sa.text("false")
    )
    audio_location: Mapped[str] = mapped_column(
        sa.String, nullable=False, server_default="local"
    )
    user_id: Mapped[Optional[str]] = mapped_column(sa.String)
    share_mode: Mapped[str] = mapped_column(
        sa.String, nullable=False, server_default="private"
    )
    meeting_id: Mapped[Optional[str]] = mapped_column(sa.String)
    recording_id: Mapped[Optional[str]] = mapped_column(sa.String)
    zulip_message_id: Mapped[Optional[int]] = mapped_column(sa.Integer)
    source_kind: Mapped[str] = mapped_column(
        sa.String, nullable=False
    )  # Enum will be handled separately
    audio_deleted: Mapped[Optional[bool]] = mapped_column(sa.Boolean)
    room_id: Mapped[Optional[str]] = mapped_column(sa.String)
    webvtt: Mapped[Optional[str]] = mapped_column(sa.Text)

    __table_args__ = (
        sa.Index("idx_transcript_recording_id", "recording_id"),
        sa.Index("idx_transcript_user_id", "user_id"),
        sa.Index("idx_transcript_created_at", "created_at"),
        sa.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
        sa.Index("idx_transcript_room_id", "room_id"),
        sa.Index("idx_transcript_source_kind", "source_kind"),
        sa.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
    )


TranscriptModel.search_vector_en = sa.Column(
    "search_vector_en",
    TSVECTOR,
    sa.Computed(
        "setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
        "setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
        "setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
        persisted=True,
    ),
)


class RoomModel(Base):
    __tablename__ = "room"

    id: Mapped[str] = mapped_column(sa.String, primary_key=True)
    name: Mapped[str] = mapped_column(sa.String, nullable=False, unique=True)
    user_id: Mapped[str] = mapped_column(sa.String, nullable=False)
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime(timezone=True), nullable=False
    )
    zulip_auto_post: Mapped[bool] = mapped_column(
        sa.Boolean, nullable=False, server_default=sa.text("false")
    )
    zulip_stream: Mapped[Optional[str]] = mapped_column(sa.String)
    zulip_topic: Mapped[Optional[str]] = mapped_column(sa.String)
    is_locked: Mapped[bool] = mapped_column(
        sa.Boolean, nullable=False, server_default=sa.text("false")
    )
    room_mode: Mapped[str] = mapped_column(
        sa.String, nullable=False, server_default="normal"
    )
    recording_type: Mapped[str] = mapped_column(
        sa.String, nullable=False, server_default="cloud"
    )
    recording_trigger: Mapped[str] = mapped_column(
        sa.String, nullable=False, server_default="automatic-2nd-participant"
    )
    is_shared: Mapped[bool] = mapped_column(
        sa.Boolean, nullable=False, server_default=sa.text("false")
    )
    webhook_url: Mapped[Optional[str]] = mapped_column(sa.String)
    webhook_secret: Mapped[Optional[str]] = mapped_column(sa.String)
    ics_url: Mapped[Optional[str]] = mapped_column(sa.Text)
    ics_fetch_interval: Mapped[Optional[int]] = mapped_column(
        sa.Integer, server_default=sa.text("300")
    )
    ics_enabled: Mapped[bool] = mapped_column(
        sa.Boolean, nullable=False, server_default=sa.text("false")
    )
    ics_last_sync: Mapped[Optional[datetime]] = mapped_column(
        sa.DateTime(timezone=True)
    )
    ics_last_etag: Mapped[Optional[str]] = mapped_column(sa.Text)

    __table_args__ = (
        sa.Index("idx_room_is_shared", "is_shared"),
        sa.Index("idx_room_ics_enabled", "ics_enabled"),
    )


class MeetingModel(Base):
    __tablename__ = "meeting"

    id: Mapped[str] = mapped_column(sa.String, primary_key=True)
    room_name: Mapped[Optional[str]] = mapped_column(sa.String)
    room_url: Mapped[Optional[str]] = mapped_column(sa.String)
    host_room_url: Mapped[Optional[str]] = mapped_column(sa.String)
    start_date: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
    end_date: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
    room_id: Mapped[Optional[str]] = mapped_column(
        sa.String, sa.ForeignKey("room.id", ondelete="CASCADE")
    )
    is_locked: Mapped[bool] = mapped_column(
        sa.Boolean, nullable=False, server_default=sa.text("false")
    )
    room_mode: Mapped[str] = mapped_column(
        sa.String, nullable=False, server_default="normal"
    )
    recording_type: Mapped[str] = mapped_column(
        sa.String, nullable=False, server_default="cloud"
    )
    recording_trigger: Mapped[str] = mapped_column(
        sa.String, nullable=False, server_default="automatic-2nd-participant"
    )
    num_clients: Mapped[int] = mapped_column(
        sa.Integer, nullable=False, server_default=sa.text("0")
    )
    is_active: Mapped[bool] = mapped_column(
        sa.Boolean, nullable=False, server_default=sa.text("true")
    )
    calendar_event_id: Mapped[Optional[str]] = mapped_column(
        sa.String,
        sa.ForeignKey(
            "calendar_event.id",
            ondelete="SET NULL",
            name="fk_meeting_calendar_event_id",
        ),
    )
    calendar_metadata: Mapped[Optional[dict]] = mapped_column(JSONB)

    __table_args__ = (
        sa.Index("idx_meeting_room_id", "room_id"),
        sa.Index("idx_meeting_calendar_event", "calendar_event_id"),
    )


class MeetingConsentModel(Base):
    __tablename__ = "meeting_consent"

    id: Mapped[str] = mapped_column(sa.String, primary_key=True)
    meeting_id: Mapped[str] = mapped_column(
        sa.String, sa.ForeignKey("meeting.id", ondelete="CASCADE"), nullable=False
    )
    user_id: Mapped[Optional[str]] = mapped_column(sa.String)
    consent_given: Mapped[bool] = mapped_column(sa.Boolean, nullable=False)
    consent_timestamp: Mapped[datetime] = mapped_column(
        sa.DateTime(timezone=True), nullable=False
    )


class RecordingModel(Base):
    __tablename__ = "recording"

    id: Mapped[str] = mapped_column(sa.String, primary_key=True)
    meeting_id: Mapped[str] = mapped_column(
        sa.String, sa.ForeignKey("meeting.id", ondelete="CASCADE"), nullable=False
    )
    url: Mapped[str] = mapped_column(sa.String, nullable=False)
    object_key: Mapped[str] = mapped_column(sa.String, nullable=False)
    duration: Mapped[Optional[float]] = mapped_column(sa.Float)
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime(timezone=True), nullable=False
    )

    __table_args__ = (sa.Index("idx_recording_meeting_id", "meeting_id"),)


class CalendarEventModel(Base):
    __tablename__ = "calendar_event"

    id: Mapped[str] = mapped_column(sa.String, primary_key=True)
    room_id: Mapped[str] = mapped_column(
        sa.String, sa.ForeignKey("room.id", ondelete="CASCADE"), nullable=False
    )
    ics_uid: Mapped[str] = mapped_column(sa.Text, nullable=False)
    title: Mapped[Optional[str]] = mapped_column(sa.Text)
    description: Mapped[Optional[str]] = mapped_column(sa.Text)
    start_time: Mapped[datetime] = mapped_column(
        sa.DateTime(timezone=True), nullable=False
    )
    end_time: Mapped[datetime] = mapped_column(
        sa.DateTime(timezone=True), nullable=False
    )
    attendees: Mapped[Optional[dict]] = mapped_column(JSONB)
    location: Mapped[Optional[str]] = mapped_column(sa.Text)
    ics_raw_data: Mapped[Optional[str]] = mapped_column(sa.Text)
    last_synced: Mapped[datetime] = mapped_column(
        sa.DateTime(timezone=True), nullable=False
    )
    is_deleted: Mapped[bool] = mapped_column(
        sa.Boolean, nullable=False, server_default=sa.text("false")
    )
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime(timezone=True), nullable=False
    )
    updated_at: Mapped[datetime] = mapped_column(
        sa.DateTime(timezone=True), nullable=False
    )

    __table_args__ = (
        sa.Index("idx_calendar_event_room_start", "room_id", "start_time"),
    )


metadata = Base.metadata


@@ -2,17 +2,45 @@ from datetime import datetime, timedelta, timezone
from typing import Any from typing import Any
import sqlalchemy as sa import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, Field
from sqlalchemy import delete, select, update from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.db.base import CalendarEventModel from reflector.db import get_database, metadata
from reflector.utils import generate_uuid4 from reflector.utils import generate_uuid4
calendar_events = sa.Table(
"calendar_event",
metadata,
sa.Column("id", sa.String, primary_key=True),
sa.Column(
"room_id",
sa.String,
sa.ForeignKey("room.id", ondelete="CASCADE", name="fk_calendar_event_room_id"),
nullable=False,
),
sa.Column("ics_uid", sa.Text, nullable=False),
sa.Column("title", sa.Text),
sa.Column("description", sa.Text),
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("end_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("attendees", JSONB),
sa.Column("location", sa.Text),
sa.Column("ics_raw_data", sa.Text),
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=False),
sa.Column("is_deleted", sa.Boolean, nullable=False, server_default=sa.false()),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.UniqueConstraint("room_id", "ics_uid", name="uq_room_calendar_event"),
sa.Index("idx_calendar_event_room_start", "room_id", "start_time"),
sa.Index(
"idx_calendar_event_deleted",
"is_deleted",
postgresql_where=sa.text("NOT is_deleted"),
),
)
class CalendarEvent(BaseModel): class CalendarEvent(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str = Field(default_factory=generate_uuid4) id: str = Field(default_factory=generate_uuid4)
room_id: str room_id: str
ics_uid: str ics_uid: str
@@ -30,157 +58,129 @@ class CalendarEvent(BaseModel):
class CalendarEventController: class CalendarEventController:
async def get_upcoming_events(
self,
session: AsyncSession,
room_id: str,
current_time: datetime,
buffer_minutes: int = 15,
) -> list[CalendarEvent]:
buffer_time = current_time + timedelta(minutes=buffer_minutes)
query = (
select(CalendarEventModel)
.where(
sa.and_(
CalendarEventModel.room_id == room_id,
CalendarEventModel.start_time <= buffer_time,
CalendarEventModel.end_time > current_time,
)
)
.order_by(CalendarEventModel.start_time)
)
result = await session.execute(query)
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
async def get_by_id(
self, session: AsyncSession, event_id: str
) -> CalendarEvent | None:
query = select(CalendarEventModel).where(CalendarEventModel.id == event_id)
result = await session.execute(query)
row = result.scalar_one_or_none()
if not row:
return None
return CalendarEvent.model_validate(row)
async def get_by_ics_uid(
self, session: AsyncSession, room_id: str, ics_uid: str
) -> CalendarEvent | None:
query = select(CalendarEventModel).where(
sa.and_(
CalendarEventModel.room_id == room_id,
CalendarEventModel.ics_uid == ics_uid,
)
)
result = await session.execute(query)
row = result.scalar_one_or_none()
if not row:
return None
return CalendarEvent.model_validate(row)
async def upsert(
self, session: AsyncSession, event: CalendarEvent
) -> CalendarEvent:
existing = await self.get_by_ics_uid(session, event.room_id, event.ics_uid)
if existing:
event.updated_at = datetime.now(timezone.utc)
query = (
update(CalendarEventModel)
.where(CalendarEventModel.id == existing.id)
.values(**event.model_dump(exclude={"id"}))
)
await session.execute(query)
await session.commit()
return event
else:
new_event = CalendarEventModel(**event.model_dump())
session.add(new_event)
await session.commit()
return event
async def delete_old_events(
self, session: AsyncSession, room_id: str, cutoff_date: datetime
) -> int:
query = delete(CalendarEventModel).where(
sa.and_(
CalendarEventModel.room_id == room_id,
CalendarEventModel.end_time < cutoff_date,
)
)
result = await session.execute(query)
await session.commit()
return result.rowcount
async def delete_events_not_in_list(
self, session: AsyncSession, room_id: str, keep_ics_uids: list[str]
) -> int:
if not keep_ics_uids:
query = delete(CalendarEventModel).where(
CalendarEventModel.room_id == room_id
)
else:
query = delete(CalendarEventModel).where(
sa.and_(
CalendarEventModel.room_id == room_id,
CalendarEventModel.ics_uid.notin_(keep_ics_uids),
)
)
result = await session.execute(query)
await session.commit()
return result.rowcount
-    async def get_by_room(
-        self, session: AsyncSession, room_id: str, include_deleted: bool = True
-    ) -> list[CalendarEvent]:
-        query = select(CalendarEventModel).where(CalendarEventModel.room_id == room_id)
-        if not include_deleted:
-            query = query.where(CalendarEventModel.is_deleted == False)
-        result = await session.execute(query)
-        return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
+    async def get_by_room(
+        self,
+        room_id: str,
+        include_deleted: bool = False,
+        start_after: datetime | None = None,
+        end_before: datetime | None = None,
+    ) -> list[CalendarEvent]:
+        query = calendar_events.select().where(calendar_events.c.room_id == room_id)
+        if not include_deleted:
+            query = query.where(calendar_events.c.is_deleted == False)
+        if start_after:
+            query = query.where(calendar_events.c.start_time >= start_after)
+        if end_before:
+            query = query.where(calendar_events.c.end_time <= end_before)
+        query = query.order_by(calendar_events.c.start_time.asc())
+        results = await get_database().fetch_all(query)
+        return [CalendarEvent(**result) for result in results]

-    async def get_upcoming(
-        self, session: AsyncSession, room_id: str, minutes_ahead: int = 120
-    ) -> list[CalendarEvent]:
-        now = datetime.now(timezone.utc)
-        buffer_time = now + timedelta(minutes=minutes_ahead)
-        query = (
-            select(CalendarEventModel)
-            .where(
-                sa.and_(
-                    CalendarEventModel.room_id == room_id,
-                    CalendarEventModel.start_time <= buffer_time,
-                    CalendarEventModel.end_time > now,
-                    CalendarEventModel.is_deleted == False,
-                )
-            )
-            .order_by(CalendarEventModel.start_time)
-        )
-        result = await session.execute(query)
-        return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
+    async def get_upcoming(
+        self, room_id: str, minutes_ahead: int = 120
+    ) -> list[CalendarEvent]:
+        """Get upcoming events for a room within the specified minutes, including currently happening events."""
+        now = datetime.now(timezone.utc)
+        future_time = now + timedelta(minutes=minutes_ahead)
+        query = (
+            calendar_events.select()
+            .where(
+                sa.and_(
+                    calendar_events.c.room_id == room_id,
+                    calendar_events.c.is_deleted == False,
+                    calendar_events.c.start_time <= future_time,
+                    calendar_events.c.end_time >= now,
+                )
+            )
+            .order_by(calendar_events.c.start_time.asc())
+        )
+        results = await get_database().fetch_all(query)
+        return [CalendarEvent(**result) for result in results]
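The `get_upcoming` query treats an event as "upcoming" when it overlaps the window `[now, now + minutes_ahead]`, which also catches meetings already in progress. A minimal sketch of that overlap predicate (the function name and sample times are illustrative, not part of the codebase):

```python
from datetime import datetime, timedelta, timezone


def overlaps_window(
    start_time: datetime, end_time: datetime, now: datetime, minutes_ahead: int = 120
) -> bool:
    """An event overlaps [now, now + minutes_ahead] if it starts before the
    window closes and ends at or after the window opens -- which also
    includes meetings happening right now."""
    future_time = now + timedelta(minutes=minutes_ahead)
    return start_time <= future_time and end_time >= now


now = datetime(2025, 1, 1, 12, 0, tzinfo=timezone.utc)
# Already in progress -> included
print(overlaps_window(now - timedelta(hours=1), now + timedelta(minutes=30), now))  # True
# Starts in three hours -> outside the 120-minute window
print(overlaps_window(now + timedelta(hours=3), now + timedelta(hours=4), now))  # False
```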
+    async def get_by_id(self, event_id: str) -> CalendarEvent | None:
+        query = calendar_events.select().where(calendar_events.c.id == event_id)
+        result = await get_database().fetch_one(query)
+        return CalendarEvent(**result) if result else None
+
+    async def get_by_ics_uid(self, room_id: str, ics_uid: str) -> CalendarEvent | None:
+        query = calendar_events.select().where(
+            sa.and_(
+                calendar_events.c.room_id == room_id,
+                calendar_events.c.ics_uid == ics_uid,
+            )
+        )
+        result = await get_database().fetch_one(query)
+        return CalendarEvent(**result) if result else None
+
+    async def upsert(self, event: CalendarEvent) -> CalendarEvent:
+        existing = await self.get_by_ics_uid(event.room_id, event.ics_uid)
+        if existing:
+            event.id = existing.id
+            event.created_at = existing.created_at
+            event.updated_at = datetime.now(timezone.utc)
+            query = (
+                calendar_events.update()
+                .where(calendar_events.c.id == existing.id)
+                .values(**event.model_dump())
+            )
+        else:
+            query = calendar_events.insert().values(**event.model_dump())
+        await get_database().execute(query)
+        return event
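The `upsert` above is a read-then-write upsert keyed on `(room_id, ics_uid)` rather than a database-native ON CONFLICT. A self-contained sketch of the same branch-on-lookup pattern with SQLAlchemy Core (the miniature table and helper are hypothetical, for illustration only):

```python
import sqlalchemy as sa

metadata = sa.MetaData()

# Hypothetical miniature of the calendar_event table (illustration only).
calendar_events = sa.Table(
    "calendar_event",
    metadata,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("room_id", sa.String, nullable=False),
    sa.Column("ics_uid", sa.String, nullable=False),
    sa.Column("title", sa.String),
)


def upsert_query(existing_id, values):
    """Pick UPDATE or INSERT depending on whether a prior lookup
    (by room_id + ics_uid) found a row -- mirroring the controller above."""
    if existing_id is not None:
        return (
            calendar_events.update()
            .where(calendar_events.c.id == existing_id)
            .values(**values)
        )
    return calendar_events.insert().values(**values)


print(str(upsert_query("evt-1", {"title": "standup"})).startswith("UPDATE"))  # True
print(str(upsert_query(None, {"id": "e2", "room_id": "r", "ics_uid": "u"})).startswith("INSERT"))  # True
```

As the controller's own comment notes for consents, this check-then-write style is not transactional; it is acceptable here because each `(room_id, ics_uid)` pair is synced by a single worker.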
     async def soft_delete_missing(
-        self, session: AsyncSession, room_id: str, current_ics_uids: list[str]
+        self, room_id: str, current_ics_uids: list[str]
     ) -> int:
-        query = (
-            update(CalendarEventModel)
-            .where(
-                sa.and_(
-                    CalendarEventModel.room_id == room_id,
-                    CalendarEventModel.ics_uid.notin_(current_ics_uids)
-                    if current_ics_uids
-                    else True,
-                    CalendarEventModel.end_time > datetime.now(timezone.utc),
-                )
-            )
-            .values(is_deleted=True)
-        )
-        result = await session.execute(query)
-        await session.commit()
-        return result.rowcount
+        """Soft delete future events that are no longer in the calendar."""
+        now = datetime.now(timezone.utc)
+
+        select_query = calendar_events.select().where(
+            sa.and_(
+                calendar_events.c.room_id == room_id,
+                calendar_events.c.start_time > now,
+                calendar_events.c.is_deleted == False,
+                calendar_events.c.ics_uid.notin_(current_ics_uids)
+                if current_ics_uids
+                else True,
+            )
+        )
+
+        to_delete = await get_database().fetch_all(select_query)
+        delete_count = len(to_delete)
+
+        if delete_count > 0:
+            update_query = (
+                calendar_events.update()
+                .where(
+                    sa.and_(
+                        calendar_events.c.room_id == room_id,
+                        calendar_events.c.start_time > now,
+                        calendar_events.c.is_deleted == False,
+                        calendar_events.c.ics_uid.notin_(current_ics_uids)
+                        if current_ics_uids
+                        else True,
+                    )
+                )
+                .values(is_deleted=True, updated_at=now)
+            )
+            await get_database().execute(update_query)
+
+        return delete_count
+
+    async def delete_by_room(self, room_id: str) -> int:
+        query = calendar_events.delete().where(calendar_events.c.room_id == room_id)
+        result = await get_database().execute(query)
+        return result.rowcount


@@ -2,18 +2,89 @@ from datetime import datetime
 from typing import Any, Literal

 import sqlalchemy as sa
-from pydantic import BaseModel, ConfigDict, Field
-from sqlalchemy import select, update
-from sqlalchemy.ext.asyncio import AsyncSession
-
-from reflector.db.base import MeetingConsentModel, MeetingModel
+from pydantic import BaseModel, Field
+from sqlalchemy.dialects.postgresql import JSONB
+
+from reflector.db import get_database, metadata
 from reflector.db.rooms import Room
+from reflector.schemas.platform import WHEREBY_PLATFORM, Platform
 from reflector.utils import generate_uuid4
+from reflector.utils.string import assert_equal
+from reflector.video_platforms.factory import get_platform
+meetings = sa.Table(
+    "meeting",
+    metadata,
+    sa.Column("id", sa.String, primary_key=True),
+    sa.Column("room_name", sa.String),
+    sa.Column("room_url", sa.String),
+    sa.Column("host_room_url", sa.String),
+    sa.Column("start_date", sa.DateTime(timezone=True)),
+    sa.Column("end_date", sa.DateTime(timezone=True)),
+    sa.Column(
+        "room_id",
+        sa.String,
+        sa.ForeignKey("room.id", ondelete="CASCADE"),
+        nullable=True,
+    ),
+    sa.Column("is_locked", sa.Boolean, nullable=False, server_default=sa.false()),
+    sa.Column("room_mode", sa.String, nullable=False, server_default="normal"),
+    sa.Column("recording_type", sa.String, nullable=False, server_default="cloud"),
+    sa.Column(
+        "recording_trigger",
+        sa.String,
+        nullable=False,
+        server_default="automatic-2nd-participant",
+    ),
+    sa.Column(
+        "num_clients",
+        sa.Integer,
+        nullable=False,
+        server_default=sa.text("0"),
+    ),
+    sa.Column(
+        "is_active",
+        sa.Boolean,
+        nullable=False,
+        server_default=sa.true(),
+    ),
+    sa.Column(
+        "calendar_event_id",
+        sa.String,
+        sa.ForeignKey(
+            "calendar_event.id",
+            ondelete="SET NULL",
+            name="fk_meeting_calendar_event_id",
+        ),
+    ),
+    sa.Column("calendar_metadata", JSONB),
+    sa.Column(
+        "platform",
+        sa.String,
+        nullable=False,
+        server_default=assert_equal(WHEREBY_PLATFORM, "whereby"),
+    ),
+    sa.Index("idx_meeting_room_id", "room_id"),
+    sa.Index("idx_meeting_calendar_event", "calendar_event_id"),
+)
+
+meeting_consent = sa.Table(
+    "meeting_consent",
+    metadata,
+    sa.Column("id", sa.String, primary_key=True),
+    sa.Column(
+        "meeting_id",
+        sa.String,
+        sa.ForeignKey("meeting.id", ondelete="CASCADE"),
+        nullable=False,
+    ),
+    sa.Column("user_id", sa.String),
+    sa.Column("consent_given", sa.Boolean, nullable=False),
+    sa.Column("consent_timestamp", sa.DateTime(timezone=True), nullable=False),
+)
 class MeetingConsent(BaseModel):
-    model_config = ConfigDict(from_attributes=True)
-
     id: str = Field(default_factory=generate_uuid4)
     meeting_id: str
     user_id: str | None = None
@@ -22,8 +93,6 @@ class MeetingConsent(BaseModel):

 class Meeting(BaseModel):
-    model_config = ConfigDict(from_attributes=True)
-
     id: str
     room_name: str
     room_url: str
@@ -34,19 +103,19 @@ class Meeting(BaseModel):
     is_locked: bool = False
     room_mode: Literal["normal", "group"] = "normal"
     recording_type: Literal["none", "local", "cloud"] = "cloud"
-    recording_trigger: Literal[
+    recording_trigger: Literal[  # whereby-specific
         "none", "prompt", "automatic", "automatic-2nd-participant"
     ] = "automatic-2nd-participant"
     num_clients: int = 0
     is_active: bool = True
     calendar_event_id: str | None = None
     calendar_metadata: dict[str, Any] | None = None
+    platform: Platform = WHEREBY_PLATFORM


 class MeetingController:
     async def create(
         self,
-        session: AsyncSession,
         id: str,
         room_name: str,
         room_url: str,
@@ -71,20 +140,19 @@ class MeetingController:
             recording_trigger=room.recording_trigger,
             calendar_event_id=calendar_event_id,
             calendar_metadata=calendar_metadata,
+            platform=get_platform(room.platform),
         )
-        new_meeting = MeetingModel(**meeting.model_dump())
-        session.add(new_meeting)
-        await session.commit()
+        query = meetings.insert().values(**meeting.model_dump())
+        await get_database().execute(query)
         return meeting

-    async def get_all_active(self, session: AsyncSession) -> list[Meeting]:
-        query = select(MeetingModel).where(MeetingModel.is_active)
-        result = await session.execute(query)
-        return [Meeting.model_validate(row) for row in result.scalars().all()]
+    async def get_all_active(self) -> list[Meeting]:
+        query = meetings.select().where(meetings.c.is_active)
+        results = await get_database().fetch_all(query)
+        return [Meeting(**result) for result in results]

     async def get_by_room_name(
         self,
-        session: AsyncSession,
         room_name: str,
     ) -> Meeting | None:
         """
@@ -92,178 +160,182 @@ class MeetingController:
         For backward compatibility, returns the most recent meeting.
         """
         query = (
-            select(MeetingModel)
-            .where(MeetingModel.room_name == room_name)
-            .order_by(MeetingModel.end_date.desc())
+            meetings.select()
+            .where(meetings.c.room_name == room_name)
+            .order_by(meetings.c.end_date.desc())
         )
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
+        result = await get_database().fetch_one(query)
+        if not result:
             return None
-        return Meeting.model_validate(row)
+        return Meeting(**result)

-    async def get_active(
-        self, session: AsyncSession, room: Room, current_time: datetime
-    ) -> Meeting | None:
+    async def get_active(self, room: Room, current_time: datetime) -> Meeting | None:
         """
         Get latest active meeting for a room.
         For backward compatibility, returns the most recent active meeting.
         """
+        end_date = getattr(meetings.c, "end_date")
         query = (
-            select(MeetingModel)
+            meetings.select()
             .where(
                 sa.and_(
-                    MeetingModel.room_id == room.id,
-                    MeetingModel.end_date > current_time,
-                    MeetingModel.is_active,
+                    meetings.c.room_id == room.id,
+                    meetings.c.end_date > current_time,
+                    meetings.c.is_active,
                 )
             )
-            .order_by(MeetingModel.end_date.desc())
+            .order_by(end_date.desc())
         )
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
+        result = await get_database().fetch_one(query)
+        if not result:
             return None
-        return Meeting.model_validate(row)
+        return Meeting(**result)

     async def get_all_active_for_room(
-        self, session: AsyncSession, room: Room, current_time: datetime
+        self, room: Room, current_time: datetime
     ) -> list[Meeting]:
+        end_date = getattr(meetings.c, "end_date")
         query = (
-            select(MeetingModel)
+            meetings.select()
             .where(
                 sa.and_(
-                    MeetingModel.room_id == room.id,
-                    MeetingModel.end_date > current_time,
-                    MeetingModel.is_active,
+                    meetings.c.room_id == room.id,
+                    meetings.c.end_date > current_time,
+                    meetings.c.is_active,
                 )
             )
-            .order_by(MeetingModel.end_date.desc())
+            .order_by(end_date.desc())
         )
-        result = await session.execute(query)
-        return [Meeting.model_validate(row) for row in result.scalars().all()]
+        results = await get_database().fetch_all(query)
+        return [Meeting(**result) for result in results]

     async def get_active_by_calendar_event(
-        self,
-        session: AsyncSession,
-        room: Room,
-        calendar_event_id: str,
-        current_time: datetime,
+        self, room: Room, calendar_event_id: str, current_time: datetime
     ) -> Meeting | None:
         """
         Get active meeting for a specific calendar event.
         """
-        query = select(MeetingModel).where(
+        query = meetings.select().where(
             sa.and_(
-                MeetingModel.room_id == room.id,
-                MeetingModel.calendar_event_id == calendar_event_id,
-                MeetingModel.end_date > current_time,
-                MeetingModel.is_active,
+                meetings.c.room_id == room.id,
+                meetings.c.calendar_event_id == calendar_event_id,
+                meetings.c.end_date > current_time,
+                meetings.c.is_active,
             )
         )
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
+        result = await get_database().fetch_one(query)
+        if not result:
             return None
-        return Meeting.model_validate(row)
+        return Meeting(**result)

     async def get_by_id(
-        self, session: AsyncSession, meeting_id: str, **kwargs
+        self, meeting_id: str, room: Room | None = None
     ) -> Meeting | None:
-        query = select(MeetingModel).where(MeetingModel.id == meeting_id)
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
+        query = meetings.select().where(meetings.c.id == meeting_id)
+        if room:
+            query = query.where(meetings.c.room_id == room.id)
+        result = await get_database().fetch_one(query)
+        if not result:
             return None
-        return Meeting.model_validate(row)
+        return Meeting(**result)

     async def get_by_calendar_event(
-        self, session: AsyncSession, calendar_event_id: str
+        self, calendar_event_id: str, room: Room
     ) -> Meeting | None:
-        query = select(MeetingModel).where(
-            MeetingModel.calendar_event_id == calendar_event_id
+        query = meetings.select().where(
+            meetings.c.calendar_event_id == calendar_event_id
         )
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
+        if room:
+            query = query.where(meetings.c.room_id == room.id)
+        result = await get_database().fetch_one(query)
+        if not result:
             return None
-        return Meeting.model_validate(row)
+        return Meeting(**result)

-    async def update_meeting(self, session: AsyncSession, meeting_id: str, **kwargs):
-        query = (
-            update(MeetingModel).where(MeetingModel.id == meeting_id).values(**kwargs)
-        )
-        await session.execute(query)
-        await session.commit()
+    async def update_meeting(self, meeting_id: str, **kwargs):
+        query = meetings.update().where(meetings.c.id == meeting_id).values(**kwargs)
+        await get_database().execute(query)
+
+    async def increment_num_clients(self, meeting_id: str) -> None:
+        """Atomically increment participant count."""
+        query = (
+            meetings.update()
+            .where(meetings.c.id == meeting_id)
+            .values(num_clients=meetings.c.num_clients + 1)
+        )
+        await get_database().execute(query)
+
+    async def decrement_num_clients(self, meeting_id: str) -> None:
+        """Atomically decrement participant count (min 0)."""
+        query = (
+            meetings.update()
+            .where(meetings.c.id == meeting_id)
+            .values(
+                num_clients=sa.case(
+                    (meetings.c.num_clients > 0, meetings.c.num_clients - 1), else_=0
+                )
+            )
+        )
+        await get_database().execute(query)
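Both counter methods push the arithmetic into the UPDATE statement itself, so concurrent joins and leaves cannot race a Python-side read-modify-write; the decrement clamps at zero with a SQL CASE expression. A standalone sketch of the two queries (the miniature table is illustrative):

```python
import sqlalchemy as sa

metadata = sa.MetaData()

# Hypothetical miniature of the meeting table (illustration only).
meetings = sa.Table(
    "meeting",
    metadata,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("num_clients", sa.Integer, nullable=False),
)

# Increment happens inside the UPDATE, so two concurrent joins both land.
inc = (
    meetings.update()
    .where(meetings.c.id == "m1")
    .values(num_clients=meetings.c.num_clients + 1)
)

# Decrement clamps at zero via SQL CASE instead of a Python-side read.
dec = (
    meetings.update()
    .where(meetings.c.id == "m1")
    .values(
        num_clients=sa.case(
            (meetings.c.num_clients > 0, meetings.c.num_clients - 1), else_=0
        )
    )
)

print("CASE WHEN" in str(dec))  # True
```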
 class MeetingConsentController:
-    async def get_by_meeting_id(
-        self, session: AsyncSession, meeting_id: str
-    ) -> list[MeetingConsent]:
-        query = select(MeetingConsentModel).where(
-            MeetingConsentModel.meeting_id == meeting_id
+    async def get_by_meeting_id(self, meeting_id: str) -> list[MeetingConsent]:
+        query = meeting_consent.select().where(
+            meeting_consent.c.meeting_id == meeting_id
         )
-        result = await session.execute(query)
-        return [MeetingConsent.model_validate(row) for row in result.scalars().all()]
+        results = await get_database().fetch_all(query)
+        return [MeetingConsent(**result) for result in results]

     async def get_by_meeting_and_user(
-        self, session: AsyncSession, meeting_id: str, user_id: str
+        self, meeting_id: str, user_id: str
     ) -> MeetingConsent | None:
         """Get existing consent for a specific user and meeting"""
-        query = select(MeetingConsentModel).where(
-            sa.and_(
-                MeetingConsentModel.meeting_id == meeting_id,
-                MeetingConsentModel.user_id == user_id,
-            )
+        query = meeting_consent.select().where(
+            meeting_consent.c.meeting_id == meeting_id,
+            meeting_consent.c.user_id == user_id,
         )
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if row is None:
+        result = await get_database().fetch_one(query)
+        if result is None:
             return None
-        return MeetingConsent.model_validate(row)
+        return MeetingConsent(**result)

-    async def upsert(
-        self, session: AsyncSession, consent: MeetingConsent
-    ) -> MeetingConsent:
+    async def upsert(self, consent: MeetingConsent) -> MeetingConsent:
         if consent.user_id:
             # For authenticated users, check if consent already exists
             # not transactional but we're ok with that; the consents ain't deleted anyways
             existing = await self.get_by_meeting_and_user(
-                session, consent.meeting_id, consent.user_id
+                consent.meeting_id, consent.user_id
             )
             if existing:
                 query = (
-                    update(MeetingConsentModel)
-                    .where(MeetingConsentModel.id == existing.id)
+                    meeting_consent.update()
+                    .where(meeting_consent.c.id == existing.id)
                     .values(
                         consent_given=consent.consent_given,
                         consent_timestamp=consent.consent_timestamp,
                     )
                 )
-                await session.execute(query)
-                await session.commit()
+                await get_database().execute(query)
                 existing.consent_given = consent.consent_given
                 existing.consent_timestamp = consent.consent_timestamp
                 return existing

-        new_consent = MeetingConsentModel(**consent.model_dump())
-        session.add(new_consent)
-        await session.commit()
+        query = meeting_consent.insert().values(**consent.model_dump())
+        await get_database().execute(query)
         return consent

-    async def has_any_denial(self, session: AsyncSession, meeting_id: str) -> bool:
+    async def has_any_denial(self, meeting_id: str) -> bool:
         """Check if any participant denied consent for this meeting"""
-        query = select(MeetingConsentModel).where(
-            sa.and_(
-                MeetingConsentModel.meeting_id == meeting_id,
-                MeetingConsentModel.consent_given.is_(False),
-            )
+        query = meeting_consent.select().where(
+            meeting_consent.c.meeting_id == meeting_id,
+            meeting_consent.c.consent_given.is_(False),
         )
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        return row is not None
+        result = await get_database().fetch_one(query)
+        return result is not None


 meetings_controller = MeetingController()


@@ -1,79 +1,65 @@
-from datetime import datetime, timezone
+from datetime import datetime
+from typing import Literal

-from pydantic import BaseModel, ConfigDict, Field
-from sqlalchemy import delete, select
-from sqlalchemy.ext.asyncio import AsyncSession
+import sqlalchemy as sa
+from pydantic import BaseModel, Field

-from reflector.db.base import RecordingModel
+from reflector.db import get_database, metadata
 from reflector.utils import generate_uuid4

+recordings = sa.Table(
+    "recording",
+    metadata,
+    sa.Column("id", sa.String, primary_key=True),
+    sa.Column("bucket_name", sa.String, nullable=False),
+    sa.Column("object_key", sa.String, nullable=False),
+    sa.Column("recorded_at", sa.DateTime(timezone=True), nullable=False),
+    sa.Column(
+        "status",
+        sa.String,
+        nullable=False,
+        server_default="pending",
+    ),
+    sa.Column("meeting_id", sa.String),
+    sa.Column("track_keys", sa.JSON, nullable=True),
+    sa.Index("idx_recording_meeting_id", "meeting_id"),
+)


 class Recording(BaseModel):
-    model_config = ConfigDict(from_attributes=True)
-
     id: str = Field(default_factory=generate_uuid4)
-    meeting_id: str
-    url: str
+    bucket_name: str
+    # for single-track
     object_key: str
-    duration: float | None = None
-    created_at: datetime
+    recorded_at: datetime
+    status: Literal["pending", "processing", "completed", "failed"] = "pending"
+    meeting_id: str | None = None
+    # for multitrack reprocessing
+    track_keys: list[str] | None = None


 class RecordingController:
-    async def create(
-        self,
-        session: AsyncSession,
-        meeting_id: str,
-        url: str,
-        object_key: str,
-        duration: float | None = None,
-        created_at: datetime | None = None,
-    ):
-        if created_at is None:
-            created_at = datetime.now(timezone.utc)
-        recording = Recording(
-            meeting_id=meeting_id,
-            url=url,
-            object_key=object_key,
-            duration=duration,
-            created_at=created_at,
-        )
-        new_recording = RecordingModel(**recording.model_dump())
-        session.add(new_recording)
-        await session.commit()
+    async def create(self, recording: Recording):
+        query = recordings.insert().values(**recording.model_dump())
+        await get_database().execute(query)
         return recording

-    async def get_by_id(
-        self, session: AsyncSession, recording_id: str
-    ) -> Recording | None:
-        """
-        Get a recording by id
-        """
-        query = select(RecordingModel).where(RecordingModel.id == recording_id)
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
-            return None
-        return Recording.model_validate(row)
+    async def get_by_id(self, id: str) -> Recording:
+        query = recordings.select().where(recordings.c.id == id)
+        result = await get_database().fetch_one(query)
+        return Recording(**result) if result else None

-    async def get_by_meeting_id(
-        self, session: AsyncSession, meeting_id: str
-    ) -> list[Recording]:
-        """
-        Get all recordings for a meeting
-        """
-        query = select(RecordingModel).where(RecordingModel.meeting_id == meeting_id)
-        result = await session.execute(query)
-        return [Recording.model_validate(row) for row in result.scalars().all()]
+    async def get_by_object_key(self, bucket_name: str, object_key: str) -> Recording:
+        query = recordings.select().where(
+            recordings.c.bucket_name == bucket_name,
+            recordings.c.object_key == object_key,
+        )
+        result = await get_database().fetch_one(query)
+        return Recording(**result) if result else None

-    async def remove_by_id(self, session: AsyncSession, recording_id: str) -> None:
-        """
-        Remove a recording by id
-        """
-        query = delete(RecordingModel).where(RecordingModel.id == recording_id)
-        await session.execute(query)
-        await session.commit()
+    async def remove_by_id(self, id: str) -> None:
+        query = recordings.delete().where(recordings.c.id == id)
+        await get_database().execute(query)


 recordings_controller = RecordingController()


@@ -3,19 +3,66 @@ from datetime import datetime, timezone
 from sqlite3 import IntegrityError
 from typing import Literal

+import sqlalchemy
 from fastapi import HTTPException
-from pydantic import BaseModel, ConfigDict, Field
-from sqlalchemy import delete, select, update
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.sql import or_
+from pydantic import BaseModel, Field
+from sqlalchemy.sql import false, or_

-from reflector.db.base import RoomModel
+from reflector.db import get_database, metadata
+from reflector.schemas.platform import Platform
 from reflector.utils import generate_uuid4

+rooms = sqlalchemy.Table(
+    "room",
+    metadata,
+    sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
+    sqlalchemy.Column("name", sqlalchemy.String, nullable=False, unique=True),
+    sqlalchemy.Column("user_id", sqlalchemy.String, nullable=False),
+    sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True), nullable=False),
+    sqlalchemy.Column(
+        "zulip_auto_post", sqlalchemy.Boolean, nullable=False, server_default=false()
+    ),
+    sqlalchemy.Column("zulip_stream", sqlalchemy.String),
+    sqlalchemy.Column("zulip_topic", sqlalchemy.String),
+    sqlalchemy.Column(
+        "is_locked", sqlalchemy.Boolean, nullable=False, server_default=false()
+    ),
+    sqlalchemy.Column(
+        "room_mode", sqlalchemy.String, nullable=False, server_default="normal"
+    ),
+    sqlalchemy.Column(
+        "recording_type", sqlalchemy.String, nullable=False, server_default="cloud"
+    ),
+    sqlalchemy.Column(
+        "recording_trigger",
+        sqlalchemy.String,
+        nullable=False,
+        server_default="automatic-2nd-participant",
+    ),
+    sqlalchemy.Column(
+        "is_shared", sqlalchemy.Boolean, nullable=False, server_default=false()
+    ),
+    sqlalchemy.Column("webhook_url", sqlalchemy.String, nullable=True),
+    sqlalchemy.Column("webhook_secret", sqlalchemy.String, nullable=True),
+    sqlalchemy.Column("ics_url", sqlalchemy.Text),
+    sqlalchemy.Column("ics_fetch_interval", sqlalchemy.Integer, server_default="300"),
+    sqlalchemy.Column(
+        "ics_enabled", sqlalchemy.Boolean, nullable=False, server_default=false()
+    ),
+    sqlalchemy.Column("ics_last_sync", sqlalchemy.DateTime(timezone=True)),
+    sqlalchemy.Column("ics_last_etag", sqlalchemy.Text),
+    sqlalchemy.Column(
+        "platform",
+        sqlalchemy.String,
+        nullable=True,
+        server_default=None,
+    ),
+    sqlalchemy.Index("idx_room_is_shared", "is_shared"),
+    sqlalchemy.Index("idx_room_ics_enabled", "ics_enabled"),
+)


 class Room(BaseModel):
-    model_config = ConfigDict(from_attributes=True)
-
     id: str = Field(default_factory=generate_uuid4)
     name: str
     user_id: str
@@ -26,7 +73,7 @@ class Room(BaseModel):
     is_locked: bool = False
     room_mode: Literal["normal", "group"] = "normal"
     recording_type: Literal["none", "local", "cloud"] = "cloud"
-    recording_trigger: Literal[
+    recording_trigger: Literal[  # whereby-specific
         "none", "prompt", "automatic", "automatic-2nd-participant"
     ] = "automatic-2nd-participant"
     is_shared: bool = False
@@ -37,12 +84,12 @@ class Room(BaseModel):
     ics_enabled: bool = False
     ics_last_sync: datetime | None = None
     ics_last_etag: str | None = None
+    platform: Platform | None = None


 class RoomController:
     async def get_all(
         self,
-        session: AsyncSession,
         user_id: str | None = None,
         order_by: str | None = None,
         return_query: bool = False,
@@ -56,14 +103,14 @@ class RoomController:
         Parameters:
         - `order_by`: field to order by, e.g. "-created_at"
         """
-        query = select(RoomModel)
+        query = rooms.select()
         if user_id is not None:
-            query = query.where(or_(RoomModel.user_id == user_id, RoomModel.is_shared))
+            query = query.where(or_(rooms.c.user_id == user_id, rooms.c.is_shared))
         else:
-            query = query.where(RoomModel.is_shared)
+            query = query.where(rooms.c.is_shared)
         if order_by is not None:
-            field = getattr(RoomModel, order_by[1:])
+            field = getattr(rooms.c, order_by[1:])
             if order_by.startswith("-"):
                 field = field.desc()
             query = query.order_by(field)
@@ -71,12 +118,11 @@ class RoomController:
         if return_query:
             return query

-        result = await session.execute(query)
-        return [Room.model_validate(row) for row in result.scalars().all()]
+        results = await get_database().fetch_all(query)
+        return results

     async def add(
         self,
-        session: AsyncSession,
         name: str,
         user_id: str,
         zulip_auto_post: bool,
@@ -92,6 +138,7 @@ class RoomController:
         ics_url: str | None = None,
         ics_fetch_interval: int = 300,
         ics_enabled: bool = False,
+        platform: Platform | None = None,
     ):
         """
         Add a new room
@@ -115,28 +162,25 @@ class RoomController:
             ics_url=ics_url,
             ics_fetch_interval=ics_fetch_interval,
             ics_enabled=ics_enabled,
+            platform=platform,
         )
-        new_room = RoomModel(**room.model_dump())
-        session.add(new_room)
+        query = rooms.insert().values(**room.model_dump())
         try:
-            await session.flush()
+            await get_database().execute(query)
         except IntegrityError:
             raise HTTPException(status_code=400, detail="Room name is not unique")
         return room

-    async def update(
-        self, session: AsyncSession, room: Room, values: dict, mutate=True
-    ):
+    async def update(self, room: Room, values: dict, mutate=True):
         """
         Update a room fields with key/values in values
         """
         if values.get("webhook_url") and not values.get("webhook_secret"):
             values["webhook_secret"] = secrets.token_urlsafe(32)
-        query = update(RoomModel).where(RoomModel.id == room.id).values(**values)
+        query = rooms.update().where(rooms.c.id == room.id).values(**values)
         try:
-            await session.execute(query)
-            await session.flush()
+            await get_database().execute(query)
         except IntegrityError:
             raise HTTPException(status_code=400, detail="Room name is not unique")
@@ -144,79 +188,67 @@ class RoomController:
         for key, value in values.items():
             setattr(room, key, value)

-    async def get_by_id(
-        self, session: AsyncSession, room_id: str, **kwargs
-    ) -> Room | None:
+    async def get_by_id(self, room_id: str, **kwargs) -> Room | None:
         """
         Get a room by id
         """
-        query = select(RoomModel).where(RoomModel.id == room_id)
+        query = rooms.select().where(rooms.c.id == room_id)
         if "user_id" in kwargs:
-            query = query.where(RoomModel.user_id == kwargs["user_id"])
-        result = await session.execute(query)
-        row = result.scalars().first()
-        if not row:
+            query = query.where(rooms.c.user_id == kwargs["user_id"])
+        result = await get_database().fetch_one(query)
+        if not result:
             return None
-        return Room.model_validate(row)
+        return Room(**result)

-    async def get_by_name(
-        self, session: AsyncSession, room_name: str, **kwargs
-    ) -> Room | None:
+    async def get_by_name(self, room_name: str, **kwargs) -> Room | None:
         """
         Get a room by name
         """
-        query = select(RoomModel).where(RoomModel.name == room_name)
+        query = rooms.select().where(rooms.c.name == room_name)
         if "user_id" in kwargs:
-            query = query.where(RoomModel.user_id == kwargs["user_id"])
-        result = await session.execute(query)
-        row = result.scalars().first()
-        if not row:
+            query = query.where(rooms.c.user_id == kwargs["user_id"])
+        result = await get_database().fetch_one(query)
+        if not result:
             return None
-        return Room.model_validate(row)
+        return Room(**result)

-    async def get_by_id_for_http(
-        self, session: AsyncSession, meeting_id: str, user_id: str | None
-    ) -> Room:
+    async def get_by_id_for_http(self, meeting_id: str, user_id: str | None) -> Room:
         """
         Get a room by ID for HTTP request.
         If not found, it will raise a 404 error.
         """
-        query = select(RoomModel).where(RoomModel.id == meeting_id)
-        result = await session.execute(query)
-        row = result.scalars().first()
-        if not row:
+        query = rooms.select().where(rooms.c.id == meeting_id)
+        result = await get_database().fetch_one(query)
+        if not result:
             raise HTTPException(status_code=404, detail="Room not found")
-        room = Room.model_validate(row)
+        room = Room(**result)
         return room

-    async def get_ics_enabled(self, session: AsyncSession) -> list[Room]:
-        query = select(RoomModel).where(
-            RoomModel.ics_enabled == True, RoomModel.ics_url != None
+    async def get_ics_enabled(self) -> list[Room]:
+        query = rooms.select().where(
+            rooms.c.ics_enabled == True, rooms.c.ics_url != None
         )
-        result = await session.execute(query)
-        results = result.scalars().all()
-        return [Room(**row.__dict__) for row in results]
+        results = await get_database().fetch_all(query)
+        return [Room(**result) for result in results]

     async def remove_by_id(
         self,
-        session: AsyncSession,
         room_id: str,
         user_id: str | None = None,
     ) -> None:
         """
         Remove a room by id
         """
-        room = await self.get_by_id(session, room_id, user_id=user_id)
+        room = await self.get_by_id(room_id, user_id=user_id)
         if not room:
             return
         if user_id is not None and room.user_id != user_id:
             return
-        query = delete(RoomModel).where(RoomModel.id == room_id)
-        await session.execute(query)
-        await session.flush()
+        query = rooms.delete().where(rooms.c.id == room_id)
+        await get_database().execute(query)


 rooms_controller = RoomController()

View File

@@ -8,6 +8,7 @@ from typing import Annotated, Any, Dict, Iterator
 import sqlalchemy
 import webvtt
+from databases.interfaces import Record as DbRecord
 from fastapi import HTTPException
 from pydantic import (
     BaseModel,
@@ -19,10 +20,11 @@ from pydantic import (
     constr,
     field_serializer,
 )
-from sqlalchemy.ext.asyncio import AsyncSession

-from reflector.db.base import RoomModel, TranscriptModel
-from reflector.db.transcripts import SourceKind, TranscriptStatus
+from reflector.db import get_database
+from reflector.db.rooms import rooms
+from reflector.db.transcripts import SourceKind, TranscriptStatus, transcripts
+from reflector.db.utils import is_postgresql
 from reflector.logger import logger
 from reflector.utils.string import NonEmptyString, try_parse_non_empty_string
@@ -133,6 +135,8 @@ class SearchParameters(BaseModel):
     user_id: str | None = None
     room_id: str | None = None
     source_kind: SourceKind | None = None
+    from_datetime: datetime | None = None
+    to_datetime: datetime | None = None


 class SearchResultDB(BaseModel):
@@ -329,30 +333,36 @@ class SearchController:
     @classmethod
     async def search_transcripts(
-        cls, session: AsyncSession, params: SearchParameters
+        cls, params: SearchParameters
     ) -> tuple[list[SearchResult], int]:
         """
         Full-text search for transcripts using PostgreSQL tsvector.

         Returns (results, total_count).
         """
+        if not is_postgresql():
+            logger.warning(
+                "Full-text search requires PostgreSQL. Returning empty results."
+            )
+            return [], 0
+
         base_columns = [
-            TranscriptModel.id,
-            TranscriptModel.title,
-            TranscriptModel.created_at,
-            TranscriptModel.duration,
-            TranscriptModel.status,
-            TranscriptModel.user_id,
-            TranscriptModel.room_id,
-            TranscriptModel.source_kind,
-            TranscriptModel.webvtt,
-            TranscriptModel.long_summary,
+            transcripts.c.id,
+            transcripts.c.title,
+            transcripts.c.created_at,
+            transcripts.c.duration,
+            transcripts.c.status,
+            transcripts.c.user_id,
+            transcripts.c.room_id,
+            transcripts.c.source_kind,
+            transcripts.c.webvtt,
+            transcripts.c.long_summary,
             sqlalchemy.case(
                 (
-                    TranscriptModel.room_id.isnot(None) & RoomModel.id.is_(None),
+                    transcripts.c.room_id.isnot(None) & rooms.c.id.is_(None),
                     "Deleted Room",
                 ),
-                else_=RoomModel.name,
+                else_=rooms.c.name,
             ).label("room_name"),
         ]

         search_query = None
@@ -361,7 +371,7 @@ class SearchController:
                 "english", params.query_text
             )
             rank_column = sqlalchemy.func.ts_rank(
-                TranscriptModel.search_vector_en,
+                transcripts.c.search_vector_en,
                 search_query,
                 32,  # normalization flag: rank/(rank+1) for 0-1 range
             ).label("rank")
@@ -369,51 +379,55 @@ class SearchController:
             rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")

         columns = base_columns + [rank_column]

-        base_query = (
-            sqlalchemy.select(*columns)
-            .select_from(TranscriptModel)
-            .outerjoin(RoomModel, TranscriptModel.room_id == RoomModel.id)
+        base_query = sqlalchemy.select(columns).select_from(
+            transcripts.join(rooms, transcripts.c.room_id == rooms.c.id, isouter=True)
         )

         if params.query_text is not None:
             # because already initialized based on params.query_text presence above
             assert search_query is not None
             base_query = base_query.where(
-                TranscriptModel.search_vector_en.op("@@")(search_query)
+                transcripts.c.search_vector_en.op("@@")(search_query)
             )

         if params.user_id:
             base_query = base_query.where(
                 sqlalchemy.or_(
-                    TranscriptModel.user_id == params.user_id, RoomModel.is_shared
+                    transcripts.c.user_id == params.user_id, rooms.c.is_shared
                 )
             )
         else:
-            base_query = base_query.where(RoomModel.is_shared)
+            base_query = base_query.where(rooms.c.is_shared)

         if params.room_id:
-            base_query = base_query.where(TranscriptModel.room_id == params.room_id)
+            base_query = base_query.where(transcripts.c.room_id == params.room_id)

         if params.source_kind:
             base_query = base_query.where(
-                TranscriptModel.source_kind == params.source_kind
+                transcripts.c.source_kind == params.source_kind
             )

+        if params.from_datetime:
+            base_query = base_query.where(
+                transcripts.c.created_at >= params.from_datetime
+            )
+
+        if params.to_datetime:
+            base_query = base_query.where(
+                transcripts.c.created_at <= params.to_datetime
+            )
+
         if params.query_text is not None:
             order_by = sqlalchemy.desc(sqlalchemy.text("rank"))
         else:
-            order_by = sqlalchemy.desc(TranscriptModel.created_at)
+            order_by = sqlalchemy.desc(transcripts.c.created_at)

         query = base_query.order_by(order_by).limit(params.limit).offset(params.offset)
-        result = await session.execute(query)
-        rs = result.mappings().all()
+        rs = await get_database().fetch_all(query)

-        count_query = sqlalchemy.select(sqlalchemy.func.count()).select_from(
+        count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
             base_query.alias("search_results")
         )
-        count_result = await session.execute(count_query)
-        total = count_result.scalar()
+        total = await get_database().fetch_val(count_query)

-        def _process_result(r: dict) -> SearchResult:
+        def _process_result(r: DbRecord) -> SearchResult:
             r_dict: Dict[str, Any] = dict(r)
             webvtt_raw: str | None = r_dict.pop("webvtt", None)
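
The `32` flag passed to `ts_rank` above is PostgreSQL's rank-normalization option that maps each raw rank through rank/(rank+1), bounding scores to the [0, 1) range regardless of document length or term frequency. A quick stdlib sketch of that mapping (illustrative only, not code from this repo):

```python
def normalize(rank: float) -> float:
    # ts_rank normalization flag 32: rank / (rank + 1),
    # squeezing any non-negative raw rank into [0, 1)
    return rank / (rank + 1)

# higher raw ranks approach 1.0 asymptotically but never reach it
print([round(normalize(r), 3) for r in (0.0, 0.5, 4.0, 100.0)])
# → [0.0, 0.333, 0.8, 0.99]
```

This keeps ranks comparable across queries, which matters when the frontend renders them as a relevance percentage.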

View File

@@ -7,18 +7,21 @@ from datetime import datetime, timedelta, timezone
 from pathlib import Path
 from typing import Any, Literal

+import sqlalchemy
 from fastapi import HTTPException
 from pydantic import BaseModel, ConfigDict, Field, field_serializer
-from sqlalchemy import delete, insert, select, update
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.sql import or_
+from sqlalchemy import Enum
+from sqlalchemy.dialects.postgresql import TSVECTOR
+from sqlalchemy.sql import false, or_

-from reflector.db.base import RoomModel, TranscriptModel
+from reflector.db import get_database, metadata
 from reflector.db.recordings import recordings_controller
+from reflector.db.rooms import rooms
+from reflector.db.utils import is_postgresql
 from reflector.logger import logger
 from reflector.processors.types import Word as ProcessorWord
 from reflector.settings import settings
-from reflector.storage import get_recordings_storage, get_transcripts_storage
+from reflector.storage import get_transcripts_storage
 from reflector.utils import generate_uuid4
 from reflector.utils.webvtt import topics_to_webvtt
@@ -29,6 +32,91 @@ class SourceKind(enum.StrEnum):
     FILE = enum.auto()


+transcripts = sqlalchemy.Table(
+    "transcript",
+    metadata,
+    sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
+    sqlalchemy.Column("name", sqlalchemy.String),
+    sqlalchemy.Column("status", sqlalchemy.String),
+    sqlalchemy.Column("locked", sqlalchemy.Boolean),
+    sqlalchemy.Column("duration", sqlalchemy.Float),
+    sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True)),
+    sqlalchemy.Column("title", sqlalchemy.String),
+    sqlalchemy.Column("short_summary", sqlalchemy.String),
+    sqlalchemy.Column("long_summary", sqlalchemy.String),
+    sqlalchemy.Column("topics", sqlalchemy.JSON),
+    sqlalchemy.Column("events", sqlalchemy.JSON),
+    sqlalchemy.Column("participants", sqlalchemy.JSON),
+    sqlalchemy.Column("source_language", sqlalchemy.String),
+    sqlalchemy.Column("target_language", sqlalchemy.String),
+    sqlalchemy.Column(
+        "reviewed", sqlalchemy.Boolean, nullable=False, server_default=false()
+    ),
+    sqlalchemy.Column(
+        "audio_location",
+        sqlalchemy.String,
+        nullable=False,
+        server_default="local",
+    ),
+    # with user attached, optional
+    sqlalchemy.Column("user_id", sqlalchemy.String),
+    sqlalchemy.Column(
+        "share_mode",
+        sqlalchemy.String,
+        nullable=False,
+        server_default="private",
+    ),
+    sqlalchemy.Column(
+        "meeting_id",
+        sqlalchemy.String,
+    ),
+    sqlalchemy.Column("recording_id", sqlalchemy.String),
+    sqlalchemy.Column("zulip_message_id", sqlalchemy.Integer),
+    sqlalchemy.Column(
+        "source_kind",
+        Enum(SourceKind, values_callable=lambda obj: [e.value for e in obj]),
+        nullable=False,
+    ),
+    # indicative field: whether associated audio is deleted
+    # the main "audio deleted" is the presence of the audio itself / consents not-given
+    # same field could've been in recording/meeting, and it's maybe even ok to dupe it at need
+    sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
+    sqlalchemy.Column("room_id", sqlalchemy.String),
+    sqlalchemy.Column("webvtt", sqlalchemy.Text),
+    sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
+    sqlalchemy.Index("idx_transcript_user_id", "user_id"),
+    sqlalchemy.Index("idx_transcript_created_at", "created_at"),
+    sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
+    sqlalchemy.Index("idx_transcript_room_id", "room_id"),
+    sqlalchemy.Index("idx_transcript_source_kind", "source_kind"),
+    sqlalchemy.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
+)
+
+# Add PostgreSQL-specific full-text search column
+# This matches the migration in migrations/versions/116b2f287eab_add_full_text_search.py
+if is_postgresql():
+    transcripts.append_column(
+        sqlalchemy.Column(
+            "search_vector_en",
+            TSVECTOR,
+            sqlalchemy.Computed(
+                "setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
+                "setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
+                "setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
+                persisted=True,
+            ),
+        )
+    )
+    # Add GIN index for the search vector
+    transcripts.append_constraint(
+        sqlalchemy.Index(
+            "idx_transcript_search_vector_en",
+            "search_vector_en",
+            postgresql_using="gin",
+        )
+    )


 def generate_transcript_name() -> str:
     now = datetime.now(timezone.utc)
     return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
@@ -98,13 +186,12 @@ class TranscriptParticipant(BaseModel):
     id: str = Field(default_factory=generate_uuid4)
     speaker: int | None
     name: str
+    user_id: str | None = None


 class Transcript(BaseModel):
     """Full transcript model with all fields."""

-    model_config = ConfigDict(from_attributes=True)
-
     id: str = Field(default_factory=generate_uuid4)
     user_id: str | None = None
     name: str = Field(default_factory=generate_transcript_name)
@@ -273,7 +360,6 @@ class Transcript(BaseModel):
 class TranscriptController:
     async def get_all(
         self,
-        session: AsyncSession,
         user_id: str | None = None,
         order_by: str | None = None,
         filter_empty: bool | None = False,
@@ -298,114 +384,102 @@ class TranscriptController:
         - `search_term`: filter transcripts by search term
         """
-        query = select(TranscriptModel).join(
-            RoomModel, TranscriptModel.room_id == RoomModel.id, isouter=True
+        query = transcripts.select().join(
+            rooms, transcripts.c.room_id == rooms.c.id, isouter=True
         )
         if user_id:
             query = query.where(
-                or_(TranscriptModel.user_id == user_id, RoomModel.is_shared)
+                or_(transcripts.c.user_id == user_id, rooms.c.is_shared)
             )
         else:
-            query = query.where(RoomModel.is_shared)
+            query = query.where(rooms.c.is_shared)
         if source_kind:
-            query = query.where(TranscriptModel.source_kind == source_kind)
+            query = query.where(transcripts.c.source_kind == source_kind)
         if room_id:
-            query = query.where(TranscriptModel.room_id == room_id)
+            query = query.where(transcripts.c.room_id == room_id)
         if search_term:
-            query = query.where(TranscriptModel.title.ilike(f"%{search_term}%"))
+            query = query.where(transcripts.c.title.ilike(f"%{search_term}%"))

         # Exclude heavy JSON columns from list queries
-        # Get all ORM column attributes except excluded ones
         transcript_columns = [
-            getattr(TranscriptModel, col.name)
-            for col in TranscriptModel.__table__.c
-            if col.name not in exclude_columns
+            col for col in transcripts.c if col.name not in exclude_columns
         ]
         query = query.with_only_columns(
-            *transcript_columns,
-            RoomModel.name.label("room_name"),
+            transcript_columns
+            + [
+                rooms.c.name.label("room_name"),
+            ]
         )
         if order_by is not None:
-            field = getattr(TranscriptModel, order_by[1:])
+            field = getattr(transcripts.c, order_by[1:])
             if order_by.startswith("-"):
                 field = field.desc()
             query = query.order_by(field)
         if filter_empty:
-            query = query.filter(TranscriptModel.status != "idle")
+            query = query.filter(transcripts.c.status != "idle")
         if filter_recording:
-            query = query.filter(TranscriptModel.status != "recording")
+            query = query.filter(transcripts.c.status != "recording")
         # print(query.compile(compile_kwargs={"literal_binds": True}))
         if return_query:
             return query
-        result = await session.execute(query)
-        return [dict(row) for row in result.mappings().all()]
+        results = await get_database().fetch_all(query)
+        return results

-    async def get_by_id(
-        self, session: AsyncSession, transcript_id: str, **kwargs
-    ) -> Transcript | None:
+    async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
         """
         Get a transcript by id
         """
-        query = select(TranscriptModel).where(TranscriptModel.id == transcript_id)
+        query = transcripts.select().where(transcripts.c.id == transcript_id)
         if "user_id" in kwargs:
-            query = query.where(TranscriptModel.user_id == kwargs["user_id"])
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
+            query = query.where(transcripts.c.user_id == kwargs["user_id"])
+        result = await get_database().fetch_one(query)
+        if not result:
             return None
-        return Transcript.model_validate(row)
+        return Transcript(**result)

     async def get_by_recording_id(
-        self, session: AsyncSession, recording_id: str, **kwargs
+        self, recording_id: str, **kwargs
     ) -> Transcript | None:
         """
         Get a transcript by recording_id
         """
-        query = select(TranscriptModel).where(
-            TranscriptModel.recording_id == recording_id
-        )
+        query = transcripts.select().where(transcripts.c.recording_id == recording_id)
         if "user_id" in kwargs:
-            query = query.where(TranscriptModel.user_id == kwargs["user_id"])
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
+            query = query.where(transcripts.c.user_id == kwargs["user_id"])
+        result = await get_database().fetch_one(query)
+        if not result:
             return None
-        return Transcript.model_validate(row)
+        return Transcript(**result)

-    async def get_by_room_id(
-        self, session: AsyncSession, room_id: str, **kwargs
-    ) -> list[Transcript]:
+    async def get_by_room_id(self, room_id: str, **kwargs) -> list[Transcript]:
         """
         Get transcripts by room_id (direct access without joins)
         """
-        query = select(TranscriptModel).where(TranscriptModel.room_id == room_id)
+        query = transcripts.select().where(transcripts.c.room_id == room_id)
         if "user_id" in kwargs:
-            query = query.where(TranscriptModel.user_id == kwargs["user_id"])
+            query = query.where(transcripts.c.user_id == kwargs["user_id"])
         if "order_by" in kwargs:
             order_by = kwargs["order_by"]
-            field = getattr(TranscriptModel, order_by[1:])
+            field = getattr(transcripts.c, order_by[1:])
             if order_by.startswith("-"):
                 field = field.desc()
             query = query.order_by(field)
-        results = await session.execute(query)
-        return [
-            Transcript.model_validate(dict(row)) for row in results.mappings().all()
-        ]
+        results = await get_database().fetch_all(query)
+        return [Transcript(**result) for result in results]

     async def get_by_id_for_http(
         self,
-        session: AsyncSession,
         transcript_id: str,
         user_id: str | None,
     ) -> Transcript:
@@ -418,14 +492,13 @@ class TranscriptController:
         This method checks the share mode of the transcript and the user_id
         to determine if the user can access the transcript.
         """
-        query = select(TranscriptModel).where(TranscriptModel.id == transcript_id)
-        result = await session.execute(query)
-        row = result.scalar_one_or_none()
-        if not row:
+        query = transcripts.select().where(transcripts.c.id == transcript_id)
+        result = await get_database().fetch_one(query)
+        if not result:
             raise HTTPException(status_code=404, detail="Transcript not found")

         # if the transcript is anonymous, share mode is not checked
-        transcript = Transcript.model_validate(row)
+        transcript = Transcript(**result)
         if transcript.user_id is None:
             return transcript
@@ -448,7 +521,6 @@ class TranscriptController:
     async def add(
         self,
-        session: AsyncSession,
         name: str,
         source_kind: SourceKind,
         source_language: str = "en",
@@ -473,15 +545,14 @@ class TranscriptController:
             meeting_id=meeting_id,
             room_id=room_id,
         )
-        query = insert(TranscriptModel).values(**transcript.model_dump())
-        await session.execute(query)
-        await session.commit()
+        query = transcripts.insert().values(**transcript.model_dump())
+        await get_database().execute(query)
         return transcript

     # TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
     # using mutate=True is discouraged
     async def update(
-        self, session: AsyncSession, transcript: Transcript, values: dict, mutate=False
+        self, transcript: Transcript, values: dict, mutate=False
     ) -> Transcript:
         """
         Update a transcript fields with key/values in values.
@@ -490,12 +561,11 @@ class TranscriptController:
         values = TranscriptController._handle_topics_update(values)
         query = (
-            update(TranscriptModel)
-            .where(TranscriptModel.id == transcript.id)
+            transcripts.update()
+            .where(transcripts.c.id == transcript.id)
             .values(**values)
         )
-        await session.execute(query)
-        await session.commit()
+        await get_database().execute(query)
         if mutate:
             for key, value in values.items():
                 setattr(transcript, key, value)
@@ -524,14 +594,13 @@ class TranscriptController:
     async def remove_by_id(
         self,
-        session: AsyncSession,
         transcript_id: str,
         user_id: str | None = None,
     ) -> None:
         """
         Remove a transcript by id
         """
-        transcript = await self.get_by_id(session, transcript_id)
+        transcript = await self.get_by_id(transcript_id)
         if not transcript:
             return
         if user_id is not None and transcript.user_id != user_id:
@@ -551,51 +620,59 @@ class TranscriptController:
         if transcript.recording_id:
             try:
                 recording = await recordings_controller.get_by_id(
-                    session, transcript.recording_id
+                    transcript.recording_id
                 )
                 if recording:
                     try:
-                        await get_recordings_storage().delete_file(recording.object_key)
+                        await get_transcripts_storage().delete_file(
+                            recording.object_key, bucket=recording.bucket_name
+                        )
                     except Exception as e:
                         logger.warning(
                             "Failed to delete recording object from S3",
                             exc_info=e,
                             recording_id=transcript.recording_id,
                         )
-                await recordings_controller.remove_by_id(
-                    session, transcript.recording_id
-                )
+                await recordings_controller.remove_by_id(transcript.recording_id)
             except Exception as e:
                 logger.warning(
                     "Failed to delete recording row",
                     exc_info=e,
                     recording_id=transcript.recording_id,
                 )

-        query = delete(TranscriptModel).where(TranscriptModel.id == transcript_id)
-        await session.execute(query)
-        await session.commit()
+        query = transcripts.delete().where(transcripts.c.id == transcript_id)
+        await get_database().execute(query)

-    async def remove_by_recording_id(self, session: AsyncSession, recording_id: str):
+    async def remove_by_recording_id(self, recording_id: str):
         """
         Remove a transcript by recording_id
         """
-        query = delete(TranscriptModel).where(
-            TranscriptModel.recording_id == recording_id
-        )
-        await session.execute(query)
-        await session.commit()
+        query = transcripts.delete().where(transcripts.c.recording_id == recording_id)
+        await get_database().execute(query)

+    @staticmethod
+    def user_can_mutate(transcript: Transcript, user_id: str | None) -> bool:
+        """
+        Returns True if the given user is allowed to modify the transcript.
+
+        Policy:
+        - Anonymous transcripts (user_id is None) cannot be modified via API
+        - Only the owner (matching user_id) can modify their transcript
+        """
+        if transcript.user_id is None:
+            return False
+        return user_id and transcript.user_id == user_id
+
     @asynccontextmanager
-    async def transaction(self, session: AsyncSession):
+    async def transaction(self):
         """
         A context manager for database transaction
         """
-        async with session.begin():
+        async with get_database().transaction(isolation="serializable"):
             yield

     async def append_event(
         self,
-        session: AsyncSession,
         transcript: Transcript,
         event: str,
         data: Any,
@@ -604,12 +681,11 @@ class TranscriptController:
         Append an event to a transcript
         """
         resp = transcript.add_event(event=event, data=data)
-        await self.update(session, transcript, {"events": transcript.events_dump()})
+        await self.update(transcript, {"events": transcript.events_dump()})
         return resp

     async def upsert_topic(
         self,
-        session: AsyncSession,
         transcript: Transcript,
         topic: TranscriptTopic,
     ) -> TranscriptEvent:
@@ -617,9 +693,9 @@ class TranscriptController:
         Upsert topics to a transcript
         """
         transcript.upsert_topic(topic)
-        await self.update(session, transcript, {"topics": transcript.topics_dump()})
+        await self.update(transcript, {"topics": transcript.topics_dump()})

-    async def move_mp3_to_storage(self, session: AsyncSession, transcript: Transcript):
+    async def move_mp3_to_storage(self, transcript: Transcript):
         """
         Move mp3 file to storage
         """
@@ -643,28 +719,25 @@ class TranscriptController:
         # indicate on the transcript that the audio is now on storage
         # mutates transcript argument
-        await self.update(
-            session, transcript, {"audio_location": "storage"}, mutate=True
-        )
+        await self.update(transcript, {"audio_location": "storage"}, mutate=True)

         # unlink the local file
         transcript.audio_mp3_filename.unlink(missing_ok=True)

-    async def download_mp3_from_storage(
-        self, session: AsyncSession, transcript: Transcript
-    ):
+    async def download_mp3_from_storage(self, transcript: Transcript):
         """
         Download audio from storage
         """
-        transcript.audio_mp3_filename.write_bytes(
-            await get_transcripts_storage().get_file(
-                transcript.storage_audio_path,
-            )
-        )
+        storage = get_transcripts_storage()
+        try:
+            with open(transcript.audio_mp3_filename, "wb") as f:
+                await storage.stream_to_fileobj(transcript.storage_audio_path, f)
+        except Exception:
+            transcript.audio_mp3_filename.unlink(missing_ok=True)
+            raise

     async def upsert_participant(
         self,
-        session: AsyncSession,
         transcript: Transcript,
         participant: TranscriptParticipant,
     ) -> TranscriptParticipant:
@@ -672,14 +745,11 @@ class TranscriptController:
         Add/update a participant to a transcript
         """
         result = transcript.upsert_participant(participant)
-        await self.update(
-            session, transcript, {"participants": transcript.participants_dump()}
-        )
+        await self.update(transcript, {"participants": transcript.participants_dump()})
         return result

     async def delete_participant(
         self,
-        session: AsyncSession,
         transcript: Transcript,
         participant_id: str,
     ):
@@ -687,31 +757,28 @@ class TranscriptController:
         Delete a participant from a transcript
         """
         transcript.delete_participant(participant_id)
-        await self.update(
-            session, transcript, {"participants": transcript.participants_dump()}
-        )
+        await self.update(transcript, {"participants": transcript.participants_dump()})

     async def set_status(
-        self, session: AsyncSession, transcript_id: str, status: TranscriptStatus
+        self, transcript_id: str, status: TranscriptStatus
     ) -> TranscriptEvent | None:
         """
         Update the status of a transcript
         Will add an event STATUS + update the status field of transcript
         """
-        async with self.transaction(session):
-            transcript = await self.get_by_id(session, transcript_id)
+        async with self.transaction():
+            transcript = await self.get_by_id(transcript_id)
             if not transcript:
                 raise Exception(f"Transcript {transcript_id} not found")
             if transcript.status == status:
                 return
             resp = await self.append_event(
-                session,
                 transcript=transcript,
                 event="STATUS",
                 data=StrValue(value=status),
             )
-            await self.update(session, transcript, {"status": status})
+            await self.update(transcript, {"status": status})
             return resp

View File

@@ -0,0 +1,91 @@
import hmac
import secrets
from datetime import datetime, timezone
from hashlib import sha256

import sqlalchemy
from pydantic import BaseModel, Field

from reflector.db import get_database, metadata
from reflector.settings import settings
from reflector.utils import generate_uuid4
from reflector.utils.string import NonEmptyString

user_api_keys = sqlalchemy.Table(
    "user_api_key",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
    sqlalchemy.Column("user_id", sqlalchemy.String, nullable=False),
    sqlalchemy.Column("key_hash", sqlalchemy.String, nullable=False),
    sqlalchemy.Column("name", sqlalchemy.String, nullable=True),
    sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True), nullable=False),
    sqlalchemy.Index("idx_user_api_key_hash", "key_hash", unique=True),
    sqlalchemy.Index("idx_user_api_key_user_id", "user_id"),
)


class UserApiKey(BaseModel):
    id: NonEmptyString = Field(default_factory=generate_uuid4)
    user_id: NonEmptyString
    key_hash: NonEmptyString
    name: NonEmptyString | None = None
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))


class UserApiKeyController:
    @staticmethod
    def generate_key() -> NonEmptyString:
        return secrets.token_urlsafe(48)

    @staticmethod
    def hash_key(key: NonEmptyString) -> str:
        return hmac.new(
            settings.SECRET_KEY.encode(), key.encode(), digestmod=sha256
        ).hexdigest()

    @classmethod
    async def create_key(
        cls,
        user_id: NonEmptyString,
        name: NonEmptyString | None = None,
    ) -> tuple[UserApiKey, NonEmptyString]:
        plaintext = cls.generate_key()
        api_key = UserApiKey(
            user_id=user_id,
            key_hash=cls.hash_key(plaintext),
            name=name,
        )
        query = user_api_keys.insert().values(**api_key.model_dump())
        await get_database().execute(query)
        return api_key, plaintext

    @classmethod
    async def verify_key(cls, plaintext_key: NonEmptyString) -> UserApiKey | None:
        key_hash = cls.hash_key(plaintext_key)
        query = user_api_keys.select().where(
            user_api_keys.c.key_hash == key_hash,
        )
        result = await get_database().fetch_one(query)
        return UserApiKey(**result) if result else None

    @staticmethod
    async def list_by_user_id(user_id: NonEmptyString) -> list[UserApiKey]:
        query = (
            user_api_keys.select()
            .where(user_api_keys.c.user_id == user_id)
            .order_by(user_api_keys.c.created_at.desc())
        )
        results = await get_database().fetch_all(query)
        return [UserApiKey(**r) for r in results]

    @staticmethod
    async def delete_key(key_id: NonEmptyString, user_id: NonEmptyString) -> bool:
        query = user_api_keys.delete().where(
            (user_api_keys.c.id == key_id) & (user_api_keys.c.user_id == user_id)
        )
        result = await get_database().execute(query)
        # asyncpg returns None for DELETE; treat no exception as success
        return result is None or result > 0


user_api_keys_controller = UserApiKeyController()
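Editor's note: the module above never stores the plaintext API key; only an HMAC-SHA256 digest keyed with the server secret is persisted, so a leaked database row cannot be replayed as a key. A minimal standalone sketch of that hash-then-compare flow (the `SECRET_KEY` value and the stored digest below are stand-ins, not the project's real settings):

```python
import hmac
from hashlib import sha256

SECRET_KEY = "server-side-secret"  # stand-in for settings.SECRET_KEY


def hash_key(key: str) -> str:
    # Same construction as UserApiKeyController.hash_key:
    # HMAC-SHA256 over the plaintext key, keyed with the server secret.
    return hmac.new(SECRET_KEY.encode(), key.encode(), digestmod=sha256).hexdigest()


# At creation time only the digest is stored; the plaintext is returned once.
stored_digest = hash_key("user-provided-api-key")


def verify(presented: str) -> bool:
    # Re-hash the presented key and compare digests; compare_digest
    # avoids timing leaks on the comparison.
    return hmac.compare_digest(hash_key(presented), stored_digest)


assert verify("user-provided-api-key")
assert not verify("wrong-key")
```

The unique index on `key_hash` then makes the lookup in `verify_key` a single indexed equality query, with no plaintext ever touching the database.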


@@ -0,0 +1,9 @@
"""Database utility functions."""

from reflector.db import get_database


def is_postgresql() -> bool:
    return get_database().url.scheme and get_database().url.scheme.startswith(
        "postgresql"
    )


@@ -0,0 +1 @@
"""Pipeline modules for audio processing."""


@@ -13,10 +13,8 @@ from pathlib import Path
 import av
 import structlog
 from celery import chain, shared_task
-from sqlalchemy.ext.asyncio import AsyncSession
 from reflector.asynctask import asynctask
-from reflector.db import get_session_factory
 from reflector.db.rooms import rooms_controller
 from reflector.db.transcripts import (
     SourceKind,
@@ -25,23 +23,18 @@ from reflector.db.transcripts import (
     transcripts_controller,
 )
 from reflector.logger import logger
+from reflector.pipelines import topic_processing
 from reflector.pipelines.main_live_pipeline import (
     PipelineMainBase,
     broadcast_to_sockets,
     task_cleanup_consent,
     task_pipeline_post_to_zulip,
 )
-from reflector.processors import (
-    AudioFileWriterProcessor,
-    TranscriptFinalSummaryProcessor,
-    TranscriptFinalTitleProcessor,
-    TranscriptTopicDetectorProcessor,
-)
+from reflector.pipelines.transcription_helpers import transcribe_file_with_processor
+from reflector.processors import AudioFileWriterProcessor
 from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
 from reflector.processors.file_diarization import FileDiarizationInput
 from reflector.processors.file_diarization_auto import FileDiarizationAutoProcessor
-from reflector.processors.file_transcript import FileTranscriptInput
-from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
 from reflector.processors.transcript_diarization_assembler import (
     TranscriptDiarizationAssemblerInput,
     TranscriptDiarizationAssemblerProcessor,
@@ -55,23 +48,9 @@ from reflector.processors.types import (
 )
 from reflector.settings import settings
 from reflector.storage import get_transcripts_storage
-from reflector.worker.session_decorator import with_session
 from reflector.worker.webhook import send_transcript_webhook


-class EmptyPipeline:
-    """Empty pipeline for processors that need a pipeline reference"""
-
-    def __init__(self, logger: structlog.BoundLogger):
-        self.logger = logger
-
-    def get_pref(self, k, d=None):
-        return d
-
-    async def emit(self, event):
-        pass
-
-
 class PipelineMainFile(PipelineMainBase):
     """
     Optimized file processing pipeline.
@@ -84,7 +63,7 @@ class PipelineMainFile(PipelineMainBase):
     def __init__(self, transcript_id: str):
         super().__init__(transcript_id=transcript_id)
         self.logger = logger.bind(transcript_id=self.transcript_id)
-        self.empty_pipeline = EmptyPipeline(logger=self.logger)
+        self.empty_pipeline = topic_processing.EmptyPipeline(logger=self.logger)

     def _handle_gather_exceptions(self, results: list, operation: str) -> None:
         """Handle exceptions from asyncio.gather with return_exceptions=True"""
@@ -100,23 +79,17 @@ class PipelineMainFile(PipelineMainBase):
     @broadcast_to_sockets
     async def set_status(self, transcript_id: str, status: TranscriptStatus):
         async with self.lock_transaction():
-            async with get_session_factory()() as session:
-                return await transcripts_controller.set_status(
-                    session, transcript_id, status
-                )
+            return await transcripts_controller.set_status(transcript_id, status)

     async def process(self, file_path: Path):
         """Main entry point for file processing"""
         self.logger.info(f"Starting file pipeline for {file_path}")
-        async with get_session_factory()() as session:
-            transcript = await transcripts_controller.get_by_id(
-                session, self.transcript_id
-            )
+        transcript = await self.get_transcript()

         # Clear transcript as we're going to regenerate everything
+        async with self.transaction():
             await transcripts_controller.update(
-                session,
                 transcript,
                 {
                     "events": [],
@@ -132,7 +105,6 @@ class PipelineMainFile(PipelineMainBase):
         # Run parallel processing
         await self.run_parallel_processing(
-            session,
             audio_path,
             audio_url,
             transcript.source_language,
@@ -141,8 +113,7 @@ class PipelineMainFile(PipelineMainBase):
         self.logger.info("File pipeline complete")
-        async with get_session_factory()() as session:
-            await transcripts_controller.set_status(session, transcript.id, "ended")
+        await self.set_status(transcript.id, "ended")

     async def extract_and_write_audio(
         self, file_path: Path, transcript: Transcript
@@ -204,7 +175,6 @@ class PipelineMainFile(PipelineMainBase):
     async def run_parallel_processing(
         self,
-        session,
         audio_path: Path,
         audio_url: str,
         source_language: str,
@@ -218,7 +188,7 @@ class PipelineMainFile(PipelineMainBase):
         # Phase 1: Parallel processing of independent tasks
         transcription_task = self.transcribe_file(audio_url, source_language)
         diarization_task = self.diarize_file(audio_url)
-        waveform_task = self.generate_waveform(session, audio_path)
+        waveform_task = self.generate_waveform(audio_path)

         results = await asyncio.gather(
             transcription_task, diarization_task, waveform_task, return_exceptions=True
@@ -266,7 +236,7 @@ class PipelineMainFile(PipelineMainBase):
         )
         results = await asyncio.gather(
             self.generate_title(topics),
-            self.generate_summaries(session, topics),
+            self.generate_summaries(topics),
             return_exceptions=True,
         )
@@ -274,24 +244,7 @@ class PipelineMainFile(PipelineMainBase):
     async def transcribe_file(self, audio_url: str, language: str) -> TranscriptType:
         """Transcribe complete file"""
-        processor = FileTranscriptAutoProcessor()
-        input_data = FileTranscriptInput(audio_url=audio_url, language=language)
-
-        # Store result for retrieval
-        result: TranscriptType | None = None
-
-        async def capture_result(transcript):
-            nonlocal result
-            result = transcript
-
-        processor.on(capture_result)
-        await processor.push(input_data)
-        await processor.flush()
-
-        if not result:
-            raise ValueError("No transcript captured")
-        return result
+        return await transcribe_file_with_processor(audio_url, language)

     async def diarize_file(self, audio_url: str) -> list[DiarizationSegment] | None:
         """Get diarization for file"""
@@ -318,9 +271,9 @@ class PipelineMainFile(PipelineMainBase):
             self.logger.error(f"Diarization failed: {e}")
             return None

-    async def generate_waveform(self, session: AsyncSession, audio_path: Path):
+    async def generate_waveform(self, audio_path: Path):
         """Generate and save waveform"""
-        transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
+        transcript = await self.get_transcript()
         processor = AudioWaveformProcessor(
             audio_path=audio_path,
@@ -334,76 +287,43 @@ class PipelineMainFile(PipelineMainBase):
     async def detect_topics(
         self, transcript: TranscriptType, target_language: str
     ) -> list[TitleSummary]:
-        """Detect topics from complete transcript"""
-        chunk_size = 300
-        topics: list[TitleSummary] = []
-
-        async def on_topic(topic: TitleSummary):
-            topics.append(topic)
-            return await self.on_topic(topic)
-
-        topic_detector = TranscriptTopicDetectorProcessor(callback=on_topic)
-        topic_detector.set_pipeline(self.empty_pipeline)
-        for i in range(0, len(transcript.words), chunk_size):
-            chunk_words = transcript.words[i : i + chunk_size]
-            if not chunk_words:
-                continue
-            chunk_transcript = TranscriptType(
-                words=chunk_words, translation=transcript.translation
-            )
-            await topic_detector.push(chunk_transcript)
-        await topic_detector.flush()
-        return topics
+        return await topic_processing.detect_topics(
+            transcript,
+            target_language,
+            on_topic_callback=self.on_topic,
+            empty_pipeline=self.empty_pipeline,
+        )

     async def generate_title(self, topics: list[TitleSummary]):
-        """Generate title from topics"""
-        if not topics:
-            self.logger.warning("No topics for title generation")
-            return
-
-        processor = TranscriptFinalTitleProcessor(callback=self.on_title)
-        processor.set_pipeline(self.empty_pipeline)
-        for topic in topics:
-            await processor.push(topic)
-        await processor.flush()
-
-    async def generate_summaries(self, session, topics: list[TitleSummary]):
-        """Generate long and short summaries from topics"""
-        if not topics:
-            self.logger.warning("No topics for summary generation")
-            return
-
-        transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
-        processor = TranscriptFinalSummaryProcessor(
-            transcript=transcript,
-            callback=self.on_long_summary,
-            on_short_summary=self.on_short_summary,
-        )
-        processor.set_pipeline(self.empty_pipeline)
-        for topic in topics:
-            await processor.push(topic)
-        await processor.flush()
+        return await topic_processing.generate_title(
+            topics,
+            on_title_callback=self.on_title,
+            empty_pipeline=self.empty_pipeline,
+            logger=self.logger,
+        )
+
+    async def generate_summaries(self, topics: list[TitleSummary]):
+        transcript = await self.get_transcript()
+        return await topic_processing.generate_summaries(
+            topics,
+            transcript,
+            on_long_summary_callback=self.on_long_summary,
+            on_short_summary_callback=self.on_short_summary,
+            empty_pipeline=self.empty_pipeline,
+            logger=self.logger,
+        )


 @shared_task
 @asynctask
-@with_session
-async def task_send_webhook_if_needed(session, *, transcript_id: str):
+async def task_send_webhook_if_needed(*, transcript_id: str):
     """Send webhook if this is a room recording with webhook configured"""
-    transcript = await transcripts_controller.get_by_id(session, transcript_id)
+    transcript = await transcripts_controller.get_by_id(transcript_id)
     if not transcript:
         return

     if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
-        room = await rooms_controller.get_by_id(session, transcript.room_id)
+        room = await rooms_controller.get_by_id(transcript.room_id)
         if room and room.webhook_url:
             logger.info(
                 "Dispatching webhook",
@@ -418,10 +338,10 @@ async def task_send_webhook_if_needed(session, *, transcript_id: str):
 @shared_task
 @asynctask
-@with_session
-async def task_pipeline_file_process(session, *, transcript_id: str):
+async def task_pipeline_file_process(*, transcript_id: str):
     """Celery task for file pipeline processing"""
-    transcript = await transcripts_controller.get_by_id(session, transcript_id)
+    transcript = await transcripts_controller.get_by_id(transcript_id)
     if not transcript:
         raise Exception(f"Transcript {transcript_id} not found")
@@ -439,7 +359,12 @@ async def task_pipeline_file_process(session, *, transcript_id: str):
         await pipeline.process(audio_file)
-    except Exception:
+    except Exception as e:
+        logger.error(
+            f"File pipeline failed for transcript {transcript_id}: {type(e).__name__}: {str(e)}",
+            exc_info=True,
+            transcript_id=transcript_id,
+        )
         await pipeline.set_status(transcript_id, "error")
         raise


@@ -17,14 +17,11 @@ from contextlib import asynccontextmanager
 from typing import Generic

 import av
-import boto3
 from celery import chord, current_task, group, shared_task
 from pydantic import BaseModel
-from sqlalchemy.ext.asyncio import AsyncSession
 from structlog import BoundLogger as Logger

 from reflector.asynctask import asynctask
-from reflector.db import get_session_factory
 from reflector.db.meetings import meeting_consent_controller, meetings_controller
 from reflector.db.recordings import recordings_controller
 from reflector.db.rooms import rooms_controller
@@ -64,7 +61,6 @@ from reflector.processors.types import (
 from reflector.processors.types import Transcript as TranscriptProcessorType
 from reflector.settings import settings
 from reflector.storage import get_transcripts_storage
-from reflector.worker.session_decorator import with_session_and_transcript
 from reflector.ws_manager import WebsocketManager, get_ws_manager
 from reflector.zulip import (
     get_zulip_message,
@@ -88,6 +84,20 @@ def broadcast_to_sockets(func):
                 message=resp.model_dump(mode="json"),
             )

+            transcript = await transcripts_controller.get_by_id(self.transcript_id)
+            if transcript and transcript.user_id:
+                # Emit only relevant events to the user room to avoid noisy updates.
+                # Allowed: STATUS, FINAL_TITLE, DURATION. All are prefixed with TRANSCRIPT_
+                allowed_user_events = {"STATUS", "FINAL_TITLE", "DURATION"}
+                if resp.event in allowed_user_events:
+                    await self.ws_manager.send_json(
+                        room_id=f"user:{transcript.user_id}",
+                        message={
+                            "event": f"TRANSCRIPT_{resp.event}",
+                            "data": {"id": self.transcript_id, **resp.data},
+                        },
+                    )
+
     return wrapper
@@ -99,10 +109,9 @@ def get_transcript(func):
     @functools.wraps(func)
     async def wrapper(**kwargs):
         transcript_id = kwargs.pop("transcript_id")
-        async with get_session_factory()() as session:
-            transcript = await transcripts_controller.get_by_id(session, transcript_id)
+        transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
         if not transcript:
-            raise Exception(f"Transcript {transcript_id} not found")
+            raise Exception("Transcript {transcript_id} not found")

         # Enhanced logger with Celery task context
         tlogger = logger.bind(transcript_id=transcript.id)
@@ -143,9 +152,11 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
             self._ws_manager = get_ws_manager()
         return self._ws_manager

-    async def get_transcript(self, session: AsyncSession) -> Transcript:
+    async def get_transcript(self) -> Transcript:
         # fetch the transcript
-        result = await transcripts_controller.get_by_id(session, self.transcript_id)
+        result = await transcripts_controller.get_by_id(
+            transcript_id=self.transcript_id
+        )
         if not result:
             raise Exception("Transcript not found")
         return result
@@ -177,8 +188,8 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
     @asynccontextmanager
     async def transaction(self):
         async with self.lock_transaction():
-            async with get_session_factory()() as session:
-                yield session
+            async with transcripts_controller.transaction():
+                yield

     @broadcast_to_sockets
     async def on_status(self, status):
@@ -209,17 +220,13 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
         # when the status of the pipeline changes, update the transcript
         async with self._lock:
-            async with get_session_factory()() as session:
-                return await transcripts_controller.set_status(
-                    session, self.transcript_id, status
-                )
+            return await transcripts_controller.set_status(self.transcript_id, status)

     @broadcast_to_sockets
     async def on_transcript(self, data):
-        async with self.transaction() as session:
-            transcript = await self.get_transcript(session)
+        async with self.transaction():
+            transcript = await self.get_transcript()
             return await transcripts_controller.append_event(
-                session,
                 transcript=transcript,
                 event="TRANSCRIPT",
                 data=TranscriptText(text=data.text, translation=data.translation),
@@ -236,11 +243,10 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
         )
         if isinstance(data, TitleSummaryWithIdProcessorType):
             topic.id = data.id
-        async with self.transaction() as session:
-            transcript = await self.get_transcript(session)
-            await transcripts_controller.upsert_topic(session, transcript, topic)
+        async with self.transaction():
+            transcript = await self.get_transcript()
+            await transcripts_controller.upsert_topic(transcript, topic)
             return await transcripts_controller.append_event(
-                session,
                 transcript=transcript,
                 event="TOPIC",
                 data=topic,
@@ -249,18 +255,16 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
     @broadcast_to_sockets
     async def on_title(self, data):
         final_title = TranscriptFinalTitle(title=data.title)
-        async with self.transaction() as session:
-            transcript = await self.get_transcript(session)
+        async with self.transaction():
+            transcript = await self.get_transcript()
             if not transcript.title:
                 await transcripts_controller.update(
-                    session,
                     transcript,
                     {
                         "title": final_title.title,
                     },
                 )
             return await transcripts_controller.append_event(
-                session,
                 transcript=transcript,
                 event="FINAL_TITLE",
                 data=final_title,
@@ -269,17 +273,15 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
     @broadcast_to_sockets
     async def on_long_summary(self, data):
         final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
-        async with self.transaction() as session:
-            transcript = await self.get_transcript(session)
+        async with self.transaction():
+            transcript = await self.get_transcript()
             await transcripts_controller.update(
-                session,
                 transcript,
                 {
                     "long_summary": final_long_summary.long_summary,
                 },
             )
             return await transcripts_controller.append_event(
-                session,
                 transcript=transcript,
                 event="FINAL_LONG_SUMMARY",
                 data=final_long_summary,
@@ -290,17 +292,15 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
     @broadcast_to_sockets
     async def on_short_summary(self, data):
         final_short_summary = TranscriptFinalShortSummary(
             short_summary=data.short_summary
         )
-        async with self.transaction() as session:
-            transcript = await self.get_transcript(session)
+        async with self.transaction():
+            transcript = await self.get_transcript()
             await transcripts_controller.update(
-                session,
                 transcript,
                 {
                     "short_summary": final_short_summary.short_summary,
                 },
             )
             return await transcripts_controller.append_event(
-                session,
                 transcript=transcript,
                 event="FINAL_SHORT_SUMMARY",
                 data=final_short_summary,
@@ -308,30 +308,29 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
     @broadcast_to_sockets
     async def on_duration(self, data):
-        async with self.transaction() as session:
+        async with self.transaction():
             duration = TranscriptDuration(duration=data)
-            transcript = await self.get_transcript(session)
+            transcript = await self.get_transcript()
             await transcripts_controller.update(
-                session,
                 transcript,
                 {
                     "duration": duration.duration,
                 },
             )
             return await transcripts_controller.append_event(
-                session, transcript=transcript, event="DURATION", data=duration
+                transcript=transcript, event="DURATION", data=duration
             )

     @broadcast_to_sockets
     async def on_waveform(self, data):
-        async with self.transaction() as session:
+        async with self.transaction():
             waveform = TranscriptWaveform(waveform=data)
-            transcript = await self.get_transcript(session)
+            transcript = await self.get_transcript()
             return await transcripts_controller.append_event(
-                session, transcript=transcript, event="WAVEFORM", data=waveform
+                transcript=transcript, event="WAVEFORM", data=waveform
             )
@@ -344,8 +343,7 @@ class PipelineMainLive(PipelineMainBase):
     async def create(self) -> Pipeline:
         # create a context for the whole rtc transaction
         # add a customised logger to the context
-        async with get_session_factory()() as session:
-            transcript = await self.get_transcript(session)
+        transcript = await self.get_transcript()

         processors = [
             AudioFileWriterProcessor(
@@ -393,8 +391,7 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
         # now let's start the pipeline by pushing information to the
         # first processor diarization processor
         # XXX translation is lost when converting our data model to the processor model
-        async with get_session_factory()() as session:
-            transcript = await self.get_transcript(session)
+        transcript = await self.get_transcript()

         # diarization works only if the file is uploaded to an external storage
         if transcript.audio_location == "local":
@@ -427,8 +424,7 @@ class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
     async def create(self) -> Pipeline:
         # get transcript
-        async with get_session_factory()() as session:
-            self._transcript = transcript = await self.get_transcript(session)
+        self._transcript = transcript = await self.get_transcript()

         # create pipeline
         processors = self.get_processors()
@@ -533,7 +529,8 @@ async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
     logger.info("Convert to mp3 done")

-async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
+@get_transcript
+async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
     if not settings.TRANSCRIPT_STORAGE_BACKEND:
         logger.info("No storage backend configured, skipping mp3 upload")
         return
@@ -551,7 +548,7 @@ async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
         return

     # Upload to external storage and delete the file
-    await transcripts_controller.move_mp3_to_storage(session, transcript)
+    await transcripts_controller.move_mp3_to_storage(transcript)
     logger.info("Upload mp3 done")
@@ -580,27 +577,25 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger):
     logger.info("Summaries done")

-async def cleanup_consent(session, transcript: Transcript, logger: Logger):
+@get_transcript
+async def cleanup_consent(transcript: Transcript, logger: Logger):
     logger.info("Starting consent cleanup")

     consent_denied = False
     recording = None
-    meeting = None
     try:
         if transcript.recording_id:
-            recording = await recordings_controller.get_by_id(
-                session, transcript.recording_id
-            )
+            recording = await recordings_controller.get_by_id(transcript.recording_id)
             if recording and recording.meeting_id:
-                meeting = await meetings_controller.get_by_id(
-                    session, recording.meeting_id
-                )
+                meeting = await meetings_controller.get_by_id(recording.meeting_id)
                 if meeting:
                     consent_denied = await meeting_consent_controller.has_any_denial(
-                        session, meeting.id
+                        meeting.id
                     )
     except Exception as e:
-        logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
-        consent_denied = True
+        logger.error(f"Failed to fetch consent: {e}", exc_info=e)
+        raise

     if not consent_denied:
         logger.info("Consent approved, keeping all files")
@@ -608,25 +603,24 @@ async def cleanup_consent(session, transcript: Transcript, logger: Logger):
     logger.info("Consent denied, cleaning up all related audio files")

-    if recording and recording.bucket_name and recording.object_key:
-        s3_whereby = boto3.client(
-            "s3",
-            aws_access_key_id=settings.AWS_WHEREBY_ACCESS_KEY_ID,
-            aws_secret_access_key=settings.AWS_WHEREBY_ACCESS_KEY_SECRET,
-        )
-        try:
-            s3_whereby.delete_object(
-                Bucket=recording.bucket_name, Key=recording.object_key
-            )
-            logger.info(
-                f"Deleted original Whereby recording: {recording.bucket_name}/{recording.object_key}"
-            )
-        except Exception as e:
-            logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
-
-    # non-transactional, files marked for deletion not actually deleted is possible
-    await transcripts_controller.update(session, transcript, {"audio_deleted": True})
-
-    # 2. Delete processed audio from transcript storage S3 bucket
+    deletion_errors = []
+    if recording and recording.bucket_name:
+        keys_to_delete = []
+        if recording.track_keys:
+            keys_to_delete = recording.track_keys
+        elif recording.object_key:
+            keys_to_delete = [recording.object_key]
+
+        master_storage = get_transcripts_storage()
+        for key in keys_to_delete:
+            try:
+                await master_storage.delete_file(key, bucket=recording.bucket_name)
+                logger.info(f"Deleted recording file: {recording.bucket_name}/{key}")
+            except Exception as e:
+                error_msg = f"Failed to delete {key}: {e}"
+                logger.error(error_msg, exc_info=e)
+                deletion_errors.append(error_msg)
+
     if transcript.audio_location == "storage":
         storage = get_transcripts_storage()
         try:
@@ -635,28 +629,39 @@ async def cleanup_consent(session, transcript: Transcript, logger: Logger):
                 f"Deleted processed audio from storage: {transcript.storage_audio_path}"
             )
         except Exception as e:
-            logger.error(f"Failed to delete processed audio: {e}", exc_info=e)
+            error_msg = f"Failed to delete processed audio: {e}"
+            logger.error(error_msg, exc_info=e)
+            deletion_errors.append(error_msg)

-    # 3. Delete local audio files
     try:
         if hasattr(transcript, "audio_mp3_filename") and transcript.audio_mp3_filename:
             transcript.audio_mp3_filename.unlink(missing_ok=True)
         if hasattr(transcript, "audio_wav_filename") and transcript.audio_wav_filename:
             transcript.audio_wav_filename.unlink(missing_ok=True)
     except Exception as e:
-        logger.error(f"Failed to delete local audio files: {e}", exc_info=e)
+        error_msg = f"Failed to delete local audio files: {e}"
+        logger.error(error_msg, exc_info=e)
+        deletion_errors.append(error_msg)

-    logger.info("Consent cleanup done")
+    if deletion_errors:
+        logger.warning(
+            f"Consent cleanup completed with {len(deletion_errors)} errors",
+            errors=deletion_errors,
+        )
+    else:
+        await transcripts_controller.update(transcript, {"audio_deleted": True})
+        logger.info("Consent cleanup done - all audio deleted")

-async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger):
+@get_transcript
+async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
     logger.info("Starting post to zulip")

     if not transcript.recording_id:
         logger.info("Transcript has no recording")
         return
-    recording = await recordings_controller.get_by_id(session, transcript.recording_id)
+    recording = await recordings_controller.get_by_id(transcript.recording_id)
     if not recording:
         logger.info("Recording not found")
         return
@@ -665,12 +670,12 @@ async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger
logger.info("Recording has no meeting") logger.info("Recording has no meeting")
return return
meeting = await meetings_controller.get_by_id(session, recording.meeting_id) meeting = await meetings_controller.get_by_id(recording.meeting_id)
if not meeting: if not meeting:
logger.info("No meeting found for this recording") logger.info("No meeting found for this recording")
return return
room = await rooms_controller.get_by_id(session, meeting.room_id) room = await rooms_controller.get_by_id(meeting.room_id)
if not room: if not room:
logger.error(f"Missing room for a meeting {meeting.id}") logger.error(f"Missing room for a meeting {meeting.id}")
return return
@@ -696,7 +701,7 @@ async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger
room.zulip_stream, room.zulip_topic, message room.zulip_stream, room.zulip_topic, message
) )
await transcripts_controller.update( await transcripts_controller.update(
session, transcript, {"zulip_message_id": response["id"]} transcript, {"zulip_message_id": response["id"]}
) )
logger.info("Posted to zulip") logger.info("Posted to zulip")
@@ -727,11 +732,8 @@ async def task_pipeline_convert_to_mp3(*, transcript_id: str):
@shared_task @shared_task
@asynctask @asynctask
@with_session_and_transcript async def task_pipeline_upload_mp3(*, transcript_id: str):
async def task_pipeline_upload_mp3( await pipeline_upload_mp3(transcript_id=transcript_id)
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_upload_mp3(session, transcript=transcript, logger=logger)
@shared_task @shared_task
@@ -754,20 +756,14 @@ async def task_pipeline_final_summaries(*, transcript_id: str):
@shared_task @shared_task
@asynctask @asynctask
@with_session_and_transcript async def task_cleanup_consent(*, transcript_id: str):
async def task_cleanup_consent( await cleanup_consent(transcript_id=transcript_id)
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await cleanup_consent(session, transcript=transcript, logger=logger)
@shared_task @shared_task
@asynctask @asynctask
@with_session_and_transcript async def task_pipeline_post_to_zulip(*, transcript_id: str):
async def task_pipeline_post_to_zulip( await pipeline_post_to_zulip(transcript_id=transcript_id)
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
def pipeline_post(*, transcript_id: str): def pipeline_post(*, transcript_id: str):
@@ -799,16 +795,14 @@ def pipeline_post(*, transcript_id: str):
async def pipeline_process(transcript: Transcript, logger: Logger): async def pipeline_process(transcript: Transcript, logger: Logger):
try: try:
if transcript.audio_location == "storage": if transcript.audio_location == "storage":
async with get_session_factory()() as session: await transcripts_controller.download_mp3_from_storage(transcript)
await transcripts_controller.download_mp3_from_storage(transcript) transcript.audio_waveform_filename.unlink(missing_ok=True)
transcript.audio_waveform_filename.unlink(missing_ok=True) await transcripts_controller.update(
await transcripts_controller.update( transcript,
session, {
transcript, "topics": [],
{ },
"topics": [], )
},
)
# open audio # open audio
audio_filename = next(transcript.data_path.glob("upload.*"), None) audio_filename = next(transcript.data_path.glob("upload.*"), None)
@@ -840,14 +834,12 @@ async def pipeline_process(transcript: Transcript, logger: Logger):
except Exception as exc: except Exception as exc:
logger.error("Pipeline error", exc_info=exc) logger.error("Pipeline error", exc_info=exc)
async with get_session_factory()() as session: await transcripts_controller.update(
await transcripts_controller.update( transcript,
session, {
transcript, "status": "error",
{ },
"status": "error", )
},
)
raise raise
logger.info("Pipeline ended") logger.info("Pipeline ended")
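The cleanup changes above replace fail-fast logging with a collect-errors pattern: every deletion is attempted, failures are accumulated, and `audio_deleted` is only set when nothing failed. A minimal standalone sketch of that pattern (names are illustrative, not part of this diff):

```python
import asyncio


async def delete_all(keys, delete_file):
    """Best-effort deletion: try every key, collect error messages instead of raising."""
    errors = []
    for key in keys:
        try:
            await delete_file(key)
        except Exception as e:
            errors.append(f"Failed to delete {key}: {e}")
    return errors


async def demo():
    async def flaky_delete(key):
        if key == "bad":
            raise RuntimeError("boom")

    # One failing key does not stop the remaining deletions
    return await delete_all(["a", "bad", "b"], flaky_delete)
```

The caller then branches on the collected errors, exactly as `cleanup_consent` does with `deletion_errors`.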

View File

@@ -0,0 +1,694 @@
import asyncio
import math
import tempfile
from fractions import Fraction
from pathlib import Path
import av
from av.audio.resampler import AudioResampler
from celery import chain, shared_task
from reflector.asynctask import asynctask
from reflector.db.transcripts import (
TranscriptStatus,
TranscriptWaveform,
transcripts_controller,
)
from reflector.logger import logger
from reflector.pipelines import topic_processing
from reflector.pipelines.main_file_pipeline import task_send_webhook_if_needed
from reflector.pipelines.main_live_pipeline import (
PipelineMainBase,
broadcast_to_sockets,
task_cleanup_consent,
task_pipeline_post_to_zulip,
)
from reflector.pipelines.transcription_helpers import transcribe_file_with_processor
from reflector.processors import AudioFileWriterProcessor
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
from reflector.processors.types import TitleSummary
from reflector.processors.types import Transcript as TranscriptType
from reflector.storage import Storage, get_transcripts_storage
from reflector.utils.string import NonEmptyString
# Audio encoding constants
OPUS_STANDARD_SAMPLE_RATE = 48000
OPUS_DEFAULT_BIT_RATE = 128000
# Storage operation constants
PRESIGNED_URL_EXPIRATION_SECONDS = 7200 # 2 hours
class PipelineMainMultitrack(PipelineMainBase):
def __init__(self, transcript_id: str):
super().__init__(transcript_id=transcript_id)
self.logger = logger.bind(transcript_id=self.transcript_id)
self.empty_pipeline = topic_processing.EmptyPipeline(logger=self.logger)
async def pad_track_for_transcription(
self,
track_url: NonEmptyString,
track_idx: int,
storage: Storage,
) -> NonEmptyString:
"""
Pad a single track with silence based on stream metadata start_time.
Downloads from S3 presigned URL, processes via PyAV using tempfile, uploads to S3.
Returns presigned URL of padded track (or original URL if no padding needed).
Memory usage:
- Pattern: fixed_overhead(2-5MB) for PyAV codec/filters
- PyAV streams input efficiently (no full download, verified)
- Output written to tempfile (disk-based, not memory)
- Upload streams from file handle (boto3 chunks, typically 5-10MB)
Daily.co raw-tracks timing - Two approaches:
CURRENT APPROACH (PyAV metadata):
The WebM stream.start_time field encodes MEETING-RELATIVE timing:
- t=0: When Daily.co recording started (first participant joined)
- start_time=8.13s: This participant's track began 8.13s after recording started
- Purpose: Enables track alignment without external manifest files
This is NOT:
- Stream-internal offset (first packet timestamp relative to stream start)
- Absolute/wall-clock time
- Recording duration
ALTERNATIVE APPROACH (filename parsing):
Daily.co filenames contain Unix timestamps (milliseconds):
Format: {recording_start_ts}-{participant_id}-cam-audio-{track_start_ts}.webm
Example: 1760988935484-52f7f48b-fbab-431f-9a50-87b9abfc8255-cam-audio-1760988935922.webm
Can calculate offset: (track_start_ts - recording_start_ts) / 1000
- Track 0: (1760988935922 - 1760988935484) / 1000 = 0.438s
- Track 1: (1760988943823 - 1760988935484) / 1000 = 8.339s
TIME DIFFERENCE: PyAV metadata vs filename timestamps differ by ~209ms:
- Track 0: filename=438ms, metadata=229ms (diff: 209ms)
- Track 1: filename=8339ms, metadata=8130ms (diff: 209ms)
Consistent delta suggests network/encoding delay. PyAV metadata is ground truth
(represents when audio stream actually started vs when file upload initiated).
Example with 2 participants:
Track A: start_time=0.2s → Joined 200ms after recording began
Track B: start_time=8.1s → Joined 8.1 seconds later
After padding:
Track A: [0.2s silence] + [speech...]
Track B: [8.1s silence] + [speech...]
Whisper transcription timestamps are now synchronized:
Track A word at 5.0s → happened at meeting t=5.0s
Track B word at 10.0s → happened at meeting t=10.0s
Merging just sorts by timestamp - no offset calculation needed.
Padding also entails re-encoding, which matters when we work with Daily.co + Whisper:
Daily.co returns recordings with skipped frames (e.g. while the microphone is muted),
and Whisper does not account for those gaps, causing timestamp issues in transcription.
Re-encoding restores those frames. We do padding and re-encoding together because it's
convenient and more performant: we need the padded audio for the mixed MP3 anyway.
"""
transcript = await self.get_transcript()
try:
# PyAV streams input from S3 URL efficiently (2-5MB fixed overhead for codec/filters)
with av.open(track_url) as in_container:
start_time_seconds = self._extract_stream_start_time_from_container(
in_container, track_idx
)
if start_time_seconds <= 0:
self.logger.info(
f"Track {track_idx} requires no padding (start_time={start_time_seconds}s)",
track_idx=track_idx,
)
return track_url
# Use tempfile instead of BytesIO for better memory efficiency
# Reduces peak memory usage during encoding/upload
with tempfile.NamedTemporaryFile(
suffix=".webm", delete=False
) as temp_file:
temp_path = temp_file.name
try:
self._apply_audio_padding_to_file(
in_container, temp_path, start_time_seconds, track_idx
)
storage_path = (
f"file_pipeline/{transcript.id}/tracks/padded_{track_idx}.webm"
)
# Upload using file handle for streaming
with open(temp_path, "rb") as padded_file:
await storage.put_file(storage_path, padded_file)
finally:
# Clean up temp file
Path(temp_path).unlink(missing_ok=True)
padded_url = await storage.get_file_url(
storage_path,
operation="get_object",
expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
)
self.logger.info(
f"Successfully padded track {track_idx}",
track_idx=track_idx,
start_time_seconds=start_time_seconds,
padded_url=padded_url,
)
return padded_url
except Exception as e:
self.logger.error(
f"Failed to process track {track_idx}",
track_idx=track_idx,
url=track_url,
error=str(e),
exc_info=True,
)
raise Exception(
f"Track {track_idx} padding failed - transcript would have incorrect timestamps"
) from e
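As an aside, the filename-based alternative described in the docstring above can be sketched as follows. This is a hypothetical helper, not part of this file; the filename format and example values come straight from the docstring:

```python
import re

# Daily.co raw-tracks filenames embed Unix timestamps in milliseconds:
#   {recording_start_ts}-{participant_id}-cam-audio-{track_start_ts}.webm
FILENAME_RE = re.compile(r"^(\d+)-.+-cam-audio-(\d+)\.webm$")


def offset_from_filename(filename: str) -> float:
    """Meeting-relative track start offset in seconds, parsed from the filename."""
    match = FILENAME_RE.match(filename)
    if match is None:
        raise ValueError(f"Unrecognized raw-tracks filename: {filename}")
    recording_start_ms, track_start_ms = map(int, match.groups())
    return (track_start_ms - recording_start_ms) / 1000
```

Per the docstring, this differs from the PyAV metadata value by a consistent ~209ms, which is why the code prefers stream metadata.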
def _extract_stream_start_time_from_container(
self, container, track_idx: int
) -> float:
"""
Extract meeting-relative start time from WebM stream metadata.
Uses PyAV to read stream.start_time from WebM container.
More accurate than filename timestamps by ~209ms due to network/encoding delays.
"""
start_time_seconds = 0.0
try:
audio_streams = [s for s in container.streams if s.type == "audio"]
stream = audio_streams[0] if audio_streams else container.streams[0]
# 1) Try stream-level start_time (most reliable for Daily.co tracks)
if stream.start_time is not None and stream.time_base is not None:
start_time_seconds = float(stream.start_time * stream.time_base)
# 2) Fallback to container-level start_time (in av.time_base units)
if (start_time_seconds <= 0) and (container.start_time is not None):
start_time_seconds = float(container.start_time * av.time_base)
# 3) Fallback to first packet DTS in stream.time_base
if start_time_seconds <= 0:
for packet in container.demux(stream):
if packet.dts is not None:
start_time_seconds = float(packet.dts * stream.time_base)
break
except Exception as e:
self.logger.warning(
"PyAV metadata read failed; assuming 0 start_time",
track_idx=track_idx,
error=str(e),
)
start_time_seconds = 0.0
self.logger.info(
f"Track {track_idx} stream metadata: start_time={start_time_seconds:.3f}s",
track_idx=track_idx,
)
return start_time_seconds
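For reference, the unit conversion used in fallback 1 above: `stream.start_time` is expressed in `time_base` ticks, so seconds come from multiplying by the time_base fraction. A minimal sketch:

```python
from fractions import Fraction


def ticks_to_seconds(start_time_ticks: int, time_base: Fraction) -> float:
    """Convert a PyAV-style start_time (in time_base ticks) to seconds."""
    return float(start_time_ticks * time_base)
```

With a 1/1000 time_base, a start_time of 8130 ticks is 8.13s, matching the meeting-relative offset in the docstring example.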
def _apply_audio_padding_to_file(
self,
in_container,
output_path: str,
start_time_seconds: float,
track_idx: int,
) -> None:
"""Apply silence padding to audio track using PyAV filter graph, writing to file"""
delay_ms = math.floor(start_time_seconds * 1000)
self.logger.info(
f"Padding track {track_idx} with {delay_ms}ms delay using PyAV",
track_idx=track_idx,
delay_ms=delay_ms,
)
try:
with av.open(output_path, "w", format="webm") as out_container:
in_stream = next(
(s for s in in_container.streams if s.type == "audio"), None
)
if in_stream is None:
raise Exception("No audio stream in input")
out_stream = out_container.add_stream(
"libopus", rate=OPUS_STANDARD_SAMPLE_RATE
)
out_stream.bit_rate = OPUS_DEFAULT_BIT_RATE
graph = av.filter.Graph()
abuf_args = (
f"time_base=1/{OPUS_STANDARD_SAMPLE_RATE}:"
f"sample_rate={OPUS_STANDARD_SAMPLE_RATE}:"
f"sample_fmt=s16:"
f"channel_layout=stereo"
)
src = graph.add("abuffer", args=abuf_args, name="src")
aresample_f = graph.add("aresample", args="async=1", name="ares")
# adelay requires one delay value per channel separated by '|'
delays_arg = f"{delay_ms}|{delay_ms}"
adelay_f = graph.add(
"adelay", args=f"delays={delays_arg}:all=1", name="delay"
)
sink = graph.add("abuffersink", name="sink")
src.link_to(aresample_f)
aresample_f.link_to(adelay_f)
adelay_f.link_to(sink)
graph.configure()
resampler = AudioResampler(
format="s16", layout="stereo", rate=OPUS_STANDARD_SAMPLE_RATE
)
# Decode -> resample -> push through graph -> encode Opus
for frame in in_container.decode(in_stream):
out_frames = resampler.resample(frame) or []
for rframe in out_frames:
rframe.sample_rate = OPUS_STANDARD_SAMPLE_RATE
rframe.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
src.push(rframe)
while True:
try:
f_out = sink.pull()
except Exception:
break
f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE
f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
for packet in out_stream.encode(f_out):
out_container.mux(packet)
src.push(None)
while True:
try:
f_out = sink.pull()
except Exception:
break
f_out.sample_rate = OPUS_STANDARD_SAMPLE_RATE
f_out.time_base = Fraction(1, OPUS_STANDARD_SAMPLE_RATE)
for packet in out_stream.encode(f_out):
out_container.mux(packet)
for packet in out_stream.encode(None):
out_container.mux(packet)
except Exception as e:
self.logger.error(
"PyAV padding failed for track",
track_idx=track_idx,
delay_ms=delay_ms,
error=str(e),
exc_info=True,
)
raise
async def mixdown_tracks(
self,
track_urls: list[str],
writer: AudioFileWriterProcessor,
offsets_seconds: list[float] | None = None,
) -> None:
"""Multi-track mixdown using PyAV filter graph (amix), reading from S3 presigned URLs"""
target_sample_rate: int | None = None
for url in track_urls:
if not url:
continue
container = None
try:
container = av.open(url)
for frame in container.decode(audio=0):
target_sample_rate = frame.sample_rate
break
except Exception:
continue
finally:
if container is not None:
container.close()
if target_sample_rate:
break
if not target_sample_rate:
self.logger.error("Mixdown failed - no decodable audio frames found")
raise Exception("Mixdown failed: No decodable audio frames in any track")
# Build PyAV filter graph:
# N abuffer (s32/stereo)
# -> optional adelay per input (for alignment)
# -> amix (s32)
# -> aformat(s16)
# -> sink
graph = av.filter.Graph()
inputs = []
valid_track_urls = [url for url in track_urls if url]
input_offsets_seconds = None
if offsets_seconds is not None:
input_offsets_seconds = [
offsets_seconds[i] for i, url in enumerate(track_urls) if url
]
for idx, url in enumerate(valid_track_urls):
args = (
f"time_base=1/{target_sample_rate}:"
f"sample_rate={target_sample_rate}:"
f"sample_fmt=s32:"
f"channel_layout=stereo"
)
in_ctx = graph.add("abuffer", args=args, name=f"in{idx}")
inputs.append(in_ctx)
if not inputs:
self.logger.error("Mixdown failed - no valid inputs for graph")
raise Exception("Mixdown failed: No valid inputs for filter graph")
mixer = graph.add("amix", args=f"inputs={len(inputs)}:normalize=0", name="mix")
fmt = graph.add(
"aformat",
args=(
f"sample_fmts=s32:channel_layouts=stereo:sample_rates={target_sample_rate}"
),
name="fmt",
)
sink = graph.add("abuffersink", name="out")
# Optional per-input delay before mixing
delays_ms: list[int] = []
if input_offsets_seconds is not None:
base = min(input_offsets_seconds) if input_offsets_seconds else 0.0
delays_ms = [
max(0, int(round((o - base) * 1000))) for o in input_offsets_seconds
]
else:
delays_ms = [0 for _ in inputs]
for idx, in_ctx in enumerate(inputs):
delay_ms = delays_ms[idx] if idx < len(delays_ms) else 0
if delay_ms > 0:
# adelay requires one value per channel; use same for stereo
adelay = graph.add(
"adelay",
args=f"delays={delay_ms}|{delay_ms}:all=1",
name=f"delay{idx}",
)
in_ctx.link_to(adelay)
adelay.link_to(mixer, 0, idx)
else:
in_ctx.link_to(mixer, 0, idx)
mixer.link_to(fmt)
fmt.link_to(sink)
graph.configure()
containers = []
try:
# Open all containers with cleanup guaranteed
for i, url in enumerate(valid_track_urls):
try:
c = av.open(url)
containers.append(c)
except Exception as e:
self.logger.warning(
"Mixdown: failed to open container from URL",
input=i,
url=url,
error=str(e),
)
if not containers:
self.logger.error("Mixdown failed - no valid containers opened")
raise Exception("Mixdown failed: Could not open any track containers")
decoders = [c.decode(audio=0) for c in containers]
active = [True] * len(decoders)
resamplers = [
AudioResampler(format="s32", layout="stereo", rate=target_sample_rate)
for _ in decoders
]
while any(active):
for i, (dec, is_active) in enumerate(zip(decoders, active)):
if not is_active:
continue
try:
frame = next(dec)
except StopIteration:
active[i] = False
continue
if frame.sample_rate != target_sample_rate:
continue
out_frames = resamplers[i].resample(frame) or []
for rf in out_frames:
rf.sample_rate = target_sample_rate
rf.time_base = Fraction(1, target_sample_rate)
inputs[i].push(rf)
while True:
try:
mixed = sink.pull()
except Exception:
break
mixed.sample_rate = target_sample_rate
mixed.time_base = Fraction(1, target_sample_rate)
await writer.push(mixed)
for in_ctx in inputs:
in_ctx.push(None)
while True:
try:
mixed = sink.pull()
except Exception:
break
mixed.sample_rate = target_sample_rate
mixed.time_base = Fraction(1, target_sample_rate)
await writer.push(mixed)
finally:
# Cleanup all containers, even if processing failed
for c in containers:
if c is not None:
try:
c.close()
except Exception:
pass # Best effort cleanup
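The per-input delay computation above (normalize offsets against the earliest track, clamp negatives to zero, round to whole milliseconds for `adelay`) can be isolated into a pure function:

```python
def compute_adelay_ms(offsets_seconds: list[float]) -> list[int]:
    """Per-track adelay values: the earliest track gets 0, later tracks a positive ms delay."""
    if not offsets_seconds:
        return []
    base = min(offsets_seconds)
    return [max(0, int(round((o - base) * 1000))) for o in offsets_seconds]
```

Subtracting the minimum keeps the mix from starting with dead air shared by every track.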
@broadcast_to_sockets
async def set_status(self, transcript_id: str, status: TranscriptStatus):
async with self.lock_transaction():
return await transcripts_controller.set_status(transcript_id, status)
async def on_waveform(self, data):
async with self.transaction():
waveform = TranscriptWaveform(waveform=data)
transcript = await self.get_transcript()
return await transcripts_controller.append_event(
transcript=transcript, event="WAVEFORM", data=waveform
)
async def process(self, bucket_name: str, track_keys: list[str]):
transcript = await self.get_transcript()
async with self.transaction():
await transcripts_controller.update(
transcript,
{
"events": [],
"topics": [],
},
)
source_storage = get_transcripts_storage()
transcript_storage = source_storage
track_urls: list[str] = []
for key in track_keys:
url = await source_storage.get_file_url(
key,
operation="get_object",
expires_in=PRESIGNED_URL_EXPIRATION_SECONDS,
bucket=bucket_name,
)
track_urls.append(url)
self.logger.info(
f"Generated presigned URL for track from {bucket_name}",
key=key,
)
created_padded_files = set()
padded_track_urls: list[str] = []
for idx, url in enumerate(track_urls):
padded_url = await self.pad_track_for_transcription(
url, idx, transcript_storage
)
padded_track_urls.append(padded_url)
if padded_url != url:
storage_path = f"file_pipeline/{transcript.id}/tracks/padded_{idx}.webm"
created_padded_files.add(storage_path)
self.logger.info(f"Track {idx} processed, padded URL: {padded_url}")
transcript.data_path.mkdir(parents=True, exist_ok=True)
mp3_writer = AudioFileWriterProcessor(
path=str(transcript.audio_mp3_filename),
on_duration=self.on_duration,
)
await self.mixdown_tracks(padded_track_urls, mp3_writer, offsets_seconds=None)
await mp3_writer.flush()
if not transcript.audio_mp3_filename.exists():
raise Exception(
"Mixdown failed - no MP3 file generated. Cannot proceed without playable audio."
)
storage_path = f"{transcript.id}/audio.mp3"
# Use file handle streaming to avoid loading entire MP3 into memory
mp3_size = transcript.audio_mp3_filename.stat().st_size
with open(transcript.audio_mp3_filename, "rb") as mp3_file:
await transcript_storage.put_file(storage_path, mp3_file)
mp3_url = await transcript_storage.get_file_url(storage_path)
await transcripts_controller.update(transcript, {"audio_location": "storage"})
self.logger.info(
"Uploaded mixed audio to storage",
storage_path=storage_path,
size=mp3_size,
url=mp3_url,
)
self.logger.info("Generating waveform from mixed audio")
waveform_processor = AudioWaveformProcessor(
audio_path=transcript.audio_mp3_filename,
waveform_path=transcript.audio_waveform_filename,
on_waveform=self.on_waveform,
)
waveform_processor.set_pipeline(self.empty_pipeline)
await waveform_processor.flush()
self.logger.info("Waveform generated successfully")
speaker_transcripts: list[TranscriptType] = []
for idx, padded_url in enumerate(padded_track_urls):
if not padded_url:
continue
t = await self.transcribe_file(padded_url, transcript.source_language)
if not t.words:
continue
for w in t.words:
w.speaker = idx
speaker_transcripts.append(t)
self.logger.info(
f"Track {idx} transcribed successfully with {len(t.words)} words",
track_idx=idx,
)
valid_track_count = len([url for url in padded_track_urls if url])
if valid_track_count > 0 and len(speaker_transcripts) != valid_track_count:
raise Exception(
f"Only {len(speaker_transcripts)}/{valid_track_count} tracks transcribed successfully. "
f"All tracks must succeed to avoid incomplete transcripts."
)
if not speaker_transcripts:
raise Exception("No valid track transcriptions")
self.logger.info(f"Cleaning up {len(created_padded_files)} temporary S3 files")
cleanup_tasks = []
for storage_path in created_padded_files:
cleanup_tasks.append(transcript_storage.delete_file(storage_path))
if cleanup_tasks:
cleanup_results = await asyncio.gather(
*cleanup_tasks, return_exceptions=True
)
for storage_path, result in zip(created_padded_files, cleanup_results):
if isinstance(result, Exception):
self.logger.warning(
"Failed to cleanup temporary padded track",
storage_path=storage_path,
error=str(result),
)
merged_words = []
for t in speaker_transcripts:
merged_words.extend(t.words)
merged_words.sort(
key=lambda w: w.start if hasattr(w, "start") and w.start is not None else 0
)
merged_transcript = TranscriptType(words=merged_words, translation=None)
await self.on_transcript(merged_transcript)
topics = await self.detect_topics(merged_transcript, transcript.target_language)
await asyncio.gather(
self.generate_title(topics),
self.generate_summaries(topics),
return_exceptions=False,
)
await self.set_status(transcript.id, "ended")
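The merge step above relies entirely on the earlier padding: each per-speaker transcript is already on the meeting clock, so merging reduces to concatenate-and-sort. A self-contained sketch (`Word` is a stand-in for the real word type):

```python
from dataclasses import dataclass


@dataclass
class Word:
    text: str
    start: float
    speaker: int


def merge_speaker_words(tracks):
    """Concatenate per-speaker word lists and sort by start time."""
    merged = [w for track in tracks for w in track]
    merged.sort(key=lambda w: w.start if w.start is not None else 0)
    return merged
```

No per-track offset arithmetic is needed at this point, which is the payoff of `pad_track_for_transcription`.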
async def transcribe_file(self, audio_url: str, language: str) -> TranscriptType:
return await transcribe_file_with_processor(audio_url, language)
async def detect_topics(
self, transcript: TranscriptType, target_language: str
) -> list[TitleSummary]:
return await topic_processing.detect_topics(
transcript,
target_language,
on_topic_callback=self.on_topic,
empty_pipeline=self.empty_pipeline,
)
async def generate_title(self, topics: list[TitleSummary]):
return await topic_processing.generate_title(
topics,
on_title_callback=self.on_title,
empty_pipeline=self.empty_pipeline,
logger=self.logger,
)
async def generate_summaries(self, topics: list[TitleSummary]):
transcript = await self.get_transcript()
return await topic_processing.generate_summaries(
topics,
transcript,
on_long_summary_callback=self.on_long_summary,
on_short_summary_callback=self.on_short_summary,
empty_pipeline=self.empty_pipeline,
logger=self.logger,
)
@shared_task
@asynctask
async def task_pipeline_multitrack_process(
*, transcript_id: str, bucket_name: str, track_keys: list[str]
):
pipeline = PipelineMainMultitrack(transcript_id=transcript_id)
try:
await pipeline.set_status(transcript_id, "processing")
await pipeline.process(bucket_name, track_keys)
except Exception:
await pipeline.set_status(transcript_id, "error")
raise
post_chain = chain(
task_cleanup_consent.si(transcript_id=transcript_id),
task_pipeline_post_to_zulip.si(transcript_id=transcript_id),
task_send_webhook_if_needed.si(transcript_id=transcript_id),
)
post_chain.delay()

View File

@@ -0,0 +1,109 @@
"""
Topic processing utilities
==========================
Shared topic detection, title generation, and summarization logic
used across file and multitrack pipelines.
"""
from typing import Callable
import structlog
from reflector.db.transcripts import Transcript
from reflector.processors import (
TranscriptFinalSummaryProcessor,
TranscriptFinalTitleProcessor,
TranscriptTopicDetectorProcessor,
)
from reflector.processors.types import TitleSummary
from reflector.processors.types import Transcript as TranscriptType
class EmptyPipeline:
def __init__(self, logger: structlog.BoundLogger):
self.logger = logger
def get_pref(self, k, d=None):
return d
async def emit(self, event):
pass
async def detect_topics(
transcript: TranscriptType,
target_language: str,
*,
on_topic_callback: Callable,
empty_pipeline: EmptyPipeline,
) -> list[TitleSummary]:
chunk_size = 300
topics: list[TitleSummary] = []
async def on_topic(topic: TitleSummary):
topics.append(topic)
return await on_topic_callback(topic)
topic_detector = TranscriptTopicDetectorProcessor(callback=on_topic)
topic_detector.set_pipeline(empty_pipeline)
for i in range(0, len(transcript.words), chunk_size):
chunk_words = transcript.words[i : i + chunk_size]
if not chunk_words:
continue
chunk_transcript = TranscriptType(
words=chunk_words, translation=transcript.translation
)
await topic_detector.push(chunk_transcript)
await topic_detector.flush()
return topics
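The loop above walks the transcript in fixed 300-word windows; the chunking itself reduces to:

```python
def chunk_words(words: list, chunk_size: int = 300) -> list[list]:
    """Split a word list into consecutive chunks of at most chunk_size items."""
    return [words[i : i + chunk_size] for i in range(0, len(words), chunk_size)]
```

Each chunk is then wrapped in a fresh `TranscriptType` and pushed to the topic detector.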
async def generate_title(
topics: list[TitleSummary],
*,
on_title_callback: Callable,
empty_pipeline: EmptyPipeline,
logger: structlog.BoundLogger,
):
if not topics:
logger.warning("No topics for title generation")
return
processor = TranscriptFinalTitleProcessor(callback=on_title_callback)
processor.set_pipeline(empty_pipeline)
for topic in topics:
await processor.push(topic)
await processor.flush()
async def generate_summaries(
topics: list[TitleSummary],
transcript: Transcript,
*,
on_long_summary_callback: Callable,
on_short_summary_callback: Callable,
empty_pipeline: EmptyPipeline,
logger: structlog.BoundLogger,
):
if not topics:
logger.warning("No topics for summary generation")
return
processor = TranscriptFinalSummaryProcessor(
transcript=transcript,
callback=on_long_summary_callback,
on_short_summary=on_short_summary_callback,
)
processor.set_pipeline(empty_pipeline)
for topic in topics:
await processor.push(topic)
await processor.flush()

View File

@@ -0,0 +1,34 @@
from reflector.processors.file_transcript import FileTranscriptInput
from reflector.processors.file_transcript_auto import FileTranscriptAutoProcessor
from reflector.processors.types import Transcript as TranscriptType
async def transcribe_file_with_processor(
audio_url: str,
language: str,
processor_name: str | None = None,
) -> TranscriptType:
processor = (
FileTranscriptAutoProcessor(name=processor_name)
if processor_name
else FileTranscriptAutoProcessor()
)
input_data = FileTranscriptInput(audio_url=audio_url, language=language)
result: TranscriptType | None = None
async def capture_result(transcript):
nonlocal result
result = transcript
processor.on(capture_result)
await processor.push(input_data)
await processor.flush()
if not result:
processor_label = processor_name or "default"
raise ValueError(
f"No transcript captured from {processor_label} processor for audio: {audio_url}"
)
return result
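The helper above uses a callback-capture idiom: the processor reports its result through a registered callback, and the caller smuggles it out via `nonlocal`. A toy version of the same pattern (`TinyProcessor` is illustrative only, not the real processor API):

```python
import asyncio


class TinyProcessor:
    """Illustrative stand-in for a processor that emits results via callbacks."""

    def __init__(self):
        self._callbacks = []
        self._result = None

    def on(self, callback):
        self._callbacks.append(callback)

    async def push(self, data):
        self._result = data.upper()  # pretend this is transcription work

    async def flush(self):
        for callback in self._callbacks:
            await callback(self._result)


async def run_capture() -> str:
    result = None

    async def capture(value):
        nonlocal result
        result = value

    processor = TinyProcessor()
    processor.on(capture)
    await processor.push("hello")
    await processor.flush()
    return result
```

As in `transcribe_file_with_processor`, a `None` result after `flush()` signals that the processor never produced output.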

View File

@@ -56,6 +56,16 @@ class FileTranscriptModalProcessor(FileTranscriptProcessor):
                 },
                 follow_redirects=True,
             )
+            if response.status_code != 200:
+                error_body = response.text
+                self.logger.error(
+                    "Modal API error",
+                    audio_url=data.audio_url,
+                    status_code=response.status_code,
+                    error_body=error_body,
+                )
             response.raise_for_status()
             result = response.json()

View File

@@ -165,6 +165,7 @@ class SummaryBuilder:
         self.llm: LLM = llm
         self.model_name: str = llm.model_name
         self.logger = logger or structlog.get_logger()
+        self.participant_instructions: str | None = None

         if filename:
             self.read_transcript_from_file(filename)
@@ -191,14 +192,61 @@
         self, prompt: str, output_cls: Type[T], tone_name: str | None = None
     ) -> T:
         """Generic function to get structured output from LLM for non-function-calling models."""
+        # Add participant instructions to the prompt if available
+        enhanced_prompt = self._enhance_prompt_with_participants(prompt)
         return await self.llm.get_structured_response(
-            prompt, [self.transcript], output_cls, tone_name=tone_name
+            enhanced_prompt, [self.transcript], output_cls, tone_name=tone_name
         )

+    async def _get_response(
+        self, prompt: str, texts: list[str], tone_name: str | None = None
+    ) -> str:
+        """Get text response with automatic participant instructions injection."""
+        enhanced_prompt = self._enhance_prompt_with_participants(prompt)
+        return await self.llm.get_response(enhanced_prompt, texts, tone_name=tone_name)
+
+    def _enhance_prompt_with_participants(self, prompt: str) -> str:
+        """Add participant instructions to any prompt if participants are known."""
+        if self.participant_instructions:
+            self.logger.debug("Adding participant instructions to prompt")
+            return f"{prompt}\n\n{self.participant_instructions}"
+        return prompt

     # ----------------------------------------------------------------------------
     # Participants
     # ----------------------------------------------------------------------------
+    def set_known_participants(self, participants: list[str]) -> None:
+        """
+        Set known participants directly without LLM identification.
+        This is used when participants are already identified and stored.
+        They are appended at the end of the transcript, providing more context for the assistant.
+        """
+        if not participants:
+            self.logger.warning("No participants provided")
+            return
+        self.logger.info(
+            "Using known participants",
+            participants=participants,
+        )
+        participants_md = self.format_list_md(participants)
+        self.transcript += f"\n\n# Participants\n\n{participants_md}"
+        # Set instructions that will be automatically added to all prompts
+        participants_list = ", ".join(participants)
+        self.participant_instructions = dedent(
+            f"""
+            # IMPORTANT: Participant Names
+            The following participants are identified in this conversation: {participants_list}
+            You MUST use these specific participant names when referring to people in your response.
+            Do NOT use generic terms like "a participant", "someone", "attendee", "Speaker 1", "Speaker 2", etc.
+            Always refer to people by their actual names (e.g., "John suggested..." not "A participant suggested...").
+            """
+        ).strip()

     async def identify_participants(self) -> None:
         """
         From a transcript, try to identify the participants using TreeSummarize with structured output.
@@ -232,6 +280,19 @@
         if unique_participants:
             participants_md = self.format_list_md(unique_participants)
             self.transcript += f"\n\n# Participants\n\n{participants_md}"
+            # Set instructions that will be automatically added to all prompts
+            participants_list = ", ".join(unique_participants)
+            self.participant_instructions = dedent(
+                f"""
+                # IMPORTANT: Participant Names
+                The following participants are identified in this conversation: {participants_list}
+                You MUST use these specific participant names when referring to people in your response.
+                Do NOT use generic terms like "a participant", "someone", "attendee", "Speaker 1", "Speaker 2", etc.
+                Always refer to people by their actual names (e.g., "John suggested..." not "A participant suggested...").
+                """
+            ).strip()
         else:
             self.logger.warning("No participants identified in the transcript")
@@ -318,13 +379,13 @@
         for subject in self.subjects:
             detailed_prompt = DETAILED_SUBJECT_PROMPT_TEMPLATE.format(subject=subject)
-            detailed_response = await self.llm.get_response(
+            detailed_response = await self._get_response(
                 detailed_prompt, [self.transcript], tone_name="Topic assistant"
) )
paragraph_prompt = PARAGRAPH_SUMMARY_PROMPT paragraph_prompt = PARAGRAPH_SUMMARY_PROMPT
paragraph_response = await self.llm.get_response( paragraph_response = await self._get_response(
paragraph_prompt, [str(detailed_response)], tone_name="Topic summarizer" paragraph_prompt, [str(detailed_response)], tone_name="Topic summarizer"
) )
@@ -345,7 +406,7 @@ class SummaryBuilder:
recap_prompt = RECAP_PROMPT recap_prompt = RECAP_PROMPT
recap_response = await self.llm.get_response( recap_response = await self._get_response(
recap_prompt, [summaries_text], tone_name="Recap summarizer" recap_prompt, [summaries_text], tone_name="Recap summarizer"
) )


@@ -26,7 +26,25 @@ class TranscriptFinalSummaryProcessor(Processor):
     async def get_summary_builder(self, text) -> SummaryBuilder:
         builder = SummaryBuilder(self.llm, logger=self.logger)
         builder.set_transcript(text)
-        await builder.identify_participants()
+
+        # Use known participants if available, otherwise identify them
+        if self.transcript and self.transcript.participants:
+            # Extract participant names from the stored participants
+            participant_names = [p.name for p in self.transcript.participants if p.name]
+            if participant_names:
+                self.logger.info(
+                    f"Using {len(participant_names)} known participants from transcript"
+                )
+                builder.set_known_participants(participant_names)
+            else:
+                self.logger.info(
+                    "Participants field exists but is empty, identifying participants"
+                )
+                await builder.identify_participants()
+        else:
+            self.logger.info("No participants stored, identifying participants")
+            await builder.identify_participants()
+
         await builder.generate_summary()
         return builder
@@ -49,18 +67,30 @@ class TranscriptFinalSummaryProcessor(Processor):
         speakermap = {}
         if self.transcript:
             speakermap = {
-                participant["speaker"]: participant["name"]
-                for participant in self.transcript.participants
+                p.speaker: p.name
+                for p in (self.transcript.participants or [])
+                if p.speaker is not None and p.name
             }
+            self.logger.info(
+                f"Built speaker map with {len(speakermap)} participants",
+                speakermap=speakermap,
+            )

         # build the transcript as a single string
-        # XXX: unsure if the participants name as replaced directly in speaker ?
+        # Replace speaker IDs with actual participant names if available
         text_transcript = []
+        unique_speakers = set()
         for topic in self.chunks:
             for segment in topic.transcript.as_segments():
                 name = speakermap.get(segment.speaker, f"Speaker {segment.speaker}")
+                unique_speakers.add((segment.speaker, name))
                 text_transcript.append(f"{name}: {segment.text}")
+
+        self.logger.info(
+            f"Built transcript with {len(unique_speakers)} unique speakers",
+            speakers=list(unique_speakers),
+        )
         text_transcript = "\n".join(text_transcript)
         last_chunk = self.chunks[-1]
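The speaker map above is deliberately defensive: participants without a `speaker` id or a name are skipped, and any id missing from the map falls back to a generic label. A small sketch of that lookup (hypothetical helper, same `dict.get` fallback as the hunk):

```python
def render_segment(speakermap: dict[int, str], speaker: int, text: str) -> str:
    # Fall back to a generic label when the speaker has no stored name
    name = speakermap.get(speaker, f"Speaker {speaker}")
    return f"{name}: {text}"


speakermap = {0: "Alice", 1: "Bob"}
print(render_segment(speakermap, 0, "hello"))   # Alice: hello
print(render_segment(speakermap, 2, "hi all"))  # Speaker 2: hi all
```

The fallback keeps the rendered transcript usable even when diarization found more speakers than the stored participant list covers.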


@@ -1,6 +1,6 @@
from textwrap import dedent from textwrap import dedent
from pydantic import BaseModel, Field from pydantic import AliasChoices, BaseModel, Field
from reflector.llm import LLM from reflector.llm import LLM
from reflector.processors.base import Processor from reflector.processors.base import Processor
@@ -34,8 +34,14 @@ TOPIC_PROMPT = dedent(
class TopicResponse(BaseModel): class TopicResponse(BaseModel):
"""Structured response for topic detection""" """Structured response for topic detection"""
title: str = Field(description="A descriptive title for the topic being discussed") title: str = Field(
summary: str = Field(description="A concise 1-2 sentence summary of the discussion") description="A descriptive title for the topic being discussed",
validation_alias=AliasChoices("title", "Title"),
)
summary: str = Field(
description="A concise 1-2 sentence summary of the discussion",
validation_alias=AliasChoices("summary", "Summary"),
)
class TranscriptTopicDetectorProcessor(Processor): class TranscriptTopicDetectorProcessor(Processor):
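`AliasChoices` makes the structured-output parsing tolerant of models that emit `"Title"` instead of `"title"`: pydantic tries each alias in order during validation. The effect, sketched in plain Python (hypothetical helper; the real work happens inside pydantic's validator):

```python
def pick(data: dict, *aliases: str):
    """Return the value for the first alias present in data, mimicking
    what AliasChoices does during model validation."""
    for alias in aliases:
        if alias in data:
            return data[alias]
    raise KeyError(aliases)


# An LLM response with inconsistent key capitalization
raw = {"Title": "Roadmap", "summary": "Planning discussion"}
title = pick(raw, "title", "Title")      # accepts either casing
summary = pick(raw, "summary", "Summary")
```

Declaring the aliases on the `Field` keeps the canonical attribute names (`title`, `summary`) unchanged for the rest of the code.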

View File

@@ -0,0 +1,5 @@
from typing import Literal
Platform = Literal["whereby", "daily"]
WHEREBY_PLATFORM: Platform = "whereby"
DAILY_PLATFORM: Platform = "daily"
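A `Literal` alias like `Platform` gives type checkers the closed set of values, and the same set can be reused for runtime validation via `typing.get_args`. A sketch (the `parse_platform` helper is hypothetical, not part of the diff):

```python
from typing import Literal, get_args

Platform = Literal["whereby", "daily"]


def parse_platform(value: str) -> Platform:
    # Literal carries its allowed values, so runtime checks can reuse them
    allowed = get_args(Platform)
    if value not in allowed:
        raise ValueError(f"unknown platform {value!r}, expected one of {allowed}")
    return value  # type: ignore[return-value]


print(parse_platform("daily"))  # daily
```

This keeps the allowed platforms defined in exactly one place for both mypy and webhook/input validation.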


@@ -55,7 +55,6 @@ import httpx
 import pytz
 import structlog
 from icalendar import Calendar, Event
-from sqlalchemy.ext.asyncio import AsyncSession

 from reflector.db.calendar_events import CalendarEvent, calendar_events_controller
 from reflector.db.rooms import Room, rooms_controller
@@ -295,7 +294,7 @@ class ICSSyncService:
     def __init__(self):
         self.fetch_service = ICSFetchService()

-    async def sync_room_calendar(self, session: AsyncSession, room: Room) -> SyncResult:
+    async def sync_room_calendar(self, room: Room) -> SyncResult:
         async with RedisAsyncLock(
             f"ics_sync_room:{room.id}", skip_if_locked=True
         ) as lock:
@@ -306,11 +305,9 @@ class ICSSyncService:
                     "reason": "Sync already in progress",
                 }
-            return await self._sync_room_calendar(session, room)
+            return await self._sync_room_calendar(room)

-    async def _sync_room_calendar(
-        self, session: AsyncSession, room: Room
-    ) -> SyncResult:
+    async def _sync_room_calendar(self, room: Room) -> SyncResult:
         if not room.ics_enabled or not room.ics_url:
             return {"status": SyncStatus.SKIPPED, "reason": "ICS not configured"}
@@ -343,11 +340,10 @@ class ICSSyncService:
         events, total_events = self.fetch_service.extract_room_events(
             calendar, room.name, room_url
         )
-        sync_result = await self._sync_events_to_database(session, room.id, events)
+        sync_result = await self._sync_events_to_database(room.id, events)

         # Update room sync metadata
         await rooms_controller.update(
-            session,
             room,
             {
                 "ics_last_sync": datetime.now(timezone.utc),
@@ -376,7 +372,7 @@ class ICSSyncService:
         return time_since_sync.total_seconds() >= room.ics_fetch_interval

     async def _sync_events_to_database(
-        self, session: AsyncSession, room_id: str, events: list[EventData]
+        self, room_id: str, events: list[EventData]
     ) -> SyncStats:
         created = 0
         updated = 0
@@ -386,7 +382,7 @@ class ICSSyncService:
         for event_data in events:
             calendar_event = CalendarEvent(room_id=room_id, **event_data)
             existing = await calendar_events_controller.get_by_ics_uid(
-                session, room_id, event_data["ics_uid"]
+                room_id, event_data["ics_uid"]
             )

             if existing:
@@ -394,12 +390,12 @@ class ICSSyncService:
             else:
                 created += 1
-            await calendar_events_controller.upsert(session, calendar_event)
+            await calendar_events_controller.upsert(calendar_event)
             current_ics_uids.append(event_data["ics_uid"])

         # Soft delete events that are no longer in calendar
         deleted = await calendar_events_controller.soft_delete_missing(
-            session, room_id, current_ics_uids
+            room_id, current_ics_uids
         )

         return {


@@ -1,6 +1,7 @@
from pydantic.types import PositiveInt from pydantic.types import PositiveInt
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from reflector.schemas.platform import WHEREBY_PLATFORM, Platform
from reflector.utils.string import NonEmptyString from reflector.utils.string import NonEmptyString
@@ -47,14 +48,17 @@ class Settings(BaseSettings):
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
# Recording storage # Platform-specific recording storage (follows {PREFIX}_STORAGE_AWS_{CREDENTIAL} pattern)
RECORDING_STORAGE_BACKEND: str | None = None # Whereby storage configuration
WHEREBY_STORAGE_AWS_BUCKET_NAME: str | None = None
WHEREBY_STORAGE_AWS_REGION: str | None = None
WHEREBY_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
WHEREBY_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
# Recording storage configuration for AWS # Daily.co storage configuration
RECORDING_STORAGE_AWS_BUCKET_NAME: str = "recording-bucket" DAILYCO_STORAGE_AWS_BUCKET_NAME: str | None = None
RECORDING_STORAGE_AWS_REGION: str = "us-east-1" DAILYCO_STORAGE_AWS_REGION: str | None = None
RECORDING_STORAGE_AWS_ACCESS_KEY_ID: str | None = None DAILYCO_STORAGE_AWS_ROLE_ARN: str | None = None
RECORDING_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
# Translate into the target language # Translate into the target language
TRANSLATION_BACKEND: str = "passthrough" TRANSLATION_BACKEND: str = "passthrough"
@@ -124,11 +128,20 @@ class Settings(BaseSettings):
WHEREBY_API_URL: str = "https://api.whereby.dev/v1" WHEREBY_API_URL: str = "https://api.whereby.dev/v1"
WHEREBY_API_KEY: NonEmptyString | None = None WHEREBY_API_KEY: NonEmptyString | None = None
WHEREBY_WEBHOOK_SECRET: str | None = None WHEREBY_WEBHOOK_SECRET: str | None = None
AWS_WHEREBY_ACCESS_KEY_ID: str | None = None
AWS_WHEREBY_ACCESS_KEY_SECRET: str | None = None
AWS_PROCESS_RECORDING_QUEUE_URL: str | None = None AWS_PROCESS_RECORDING_QUEUE_URL: str | None = None
SQS_POLLING_TIMEOUT_SECONDS: int = 60 SQS_POLLING_TIMEOUT_SECONDS: int = 60
# Daily.co integration
DAILY_API_KEY: str | None = None
DAILY_WEBHOOK_SECRET: str | None = None
DAILY_SUBDOMAIN: str | None = None
DAILY_WEBHOOK_UUID: str | None = (
None # Webhook UUID for this environment. Not used by production code
)
# Platform Configuration
DEFAULT_VIDEO_PLATFORM: Platform = WHEREBY_PLATFORM
# Zulip integration # Zulip integration
ZULIP_REALM: str | None = None ZULIP_REALM: str | None = None
ZULIP_API_KEY: str | None = None ZULIP_API_KEY: str | None = None


@@ -3,6 +3,13 @@ from reflector.settings import settings

 def get_transcripts_storage() -> Storage:
+    """
+    Get storage for processed transcript files (master credentials).
+
+    Also use this for ALL our file operations with bucket override:
+        master = get_transcripts_storage()
+        master.delete_file(key, bucket=recording.bucket_name)
+    """
     assert settings.TRANSCRIPT_STORAGE_BACKEND
     return Storage.get_instance(
         name=settings.TRANSCRIPT_STORAGE_BACKEND,
@@ -10,8 +17,53 @@ def get_transcripts_storage() -> Storage:
     )

-def get_recordings_storage() -> Storage:
+def get_whereby_storage() -> Storage:
+    """
+    Get storage config for Whereby (for passing to Whereby API).
+
+    Usage:
+        whereby_storage = get_whereby_storage()
+        key_id, secret = whereby_storage.key_credentials
+        whereby_api.create_meeting(
+            bucket=whereby_storage.bucket_name,
+            access_key_id=key_id,
+            secret=secret,
+        )
+
+    Do NOT use for our file operations - use get_transcripts_storage() instead.
+    """
+    if not settings.WHEREBY_STORAGE_AWS_BUCKET_NAME:
+        raise ValueError(
+            "WHEREBY_STORAGE_AWS_BUCKET_NAME required for Whereby with AWS storage"
+        )
     return Storage.get_instance(
-        name=settings.RECORDING_STORAGE_BACKEND,
-        settings_prefix="RECORDING_STORAGE_",
+        name="aws",
+        settings_prefix="WHEREBY_STORAGE_",
+    )
+
+
+def get_dailyco_storage() -> Storage:
+    """
+    Get storage config for Daily.co (for passing to Daily API).
+
+    Usage:
+        daily_storage = get_dailyco_storage()
+        daily_api.create_meeting(
+            bucket=daily_storage.bucket_name,
+            region=daily_storage.region,
+            role_arn=daily_storage.role_credential,
+        )
+
+    Do NOT use for our file operations - use get_transcripts_storage() instead.
+    """
+    # Fail fast if platform-specific config missing
+    if not settings.DAILYCO_STORAGE_AWS_BUCKET_NAME:
+        raise ValueError(
+            "DAILYCO_STORAGE_AWS_BUCKET_NAME required for Daily.co with AWS storage"
+        )
+    return Storage.get_instance(
+        name="aws",
+        settings_prefix="DAILYCO_STORAGE_",
     )
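The `settings_prefix` argument suggests the factory collects the matching `{PREFIX}AWS_*` settings and lowercases them into backend constructor kwargs. A plain-Python sketch of that mapping (hypothetical: `FakeSettings` and `collect_storage_config` are illustrative, not the project's actual `Storage.get_instance` internals):

```python
class FakeSettings:
    """Stand-in for the pydantic Settings object."""
    WHEREBY_STORAGE_AWS_BUCKET_NAME = "whereby-recordings"
    WHEREBY_STORAGE_AWS_REGION = "us-east-1"


def collect_storage_config(settings, prefix: str) -> dict:
    """Collect `{prefix}AWS_*` attributes into constructor kwargs,
    e.g. WHEREBY_STORAGE_AWS_BUCKET_NAME -> aws_bucket_name."""
    config = {}
    for attr in dir(settings):
        if attr.startswith(prefix):
            key = attr[len(prefix):].lower()  # AWS_BUCKET_NAME -> aws_bucket_name
            config[key] = getattr(settings, attr)
    return config


print(collect_storage_config(FakeSettings, "WHEREBY_STORAGE_"))
# {'aws_bucket_name': 'whereby-recordings', 'aws_region': 'us-east-1'}
```

Prefix-driven config is what lets both platform factories reuse the same `name="aws"` backend with different credential sets.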


@@ -1,10 +1,23 @@
import importlib import importlib
from typing import BinaryIO, Union
from pydantic import BaseModel from pydantic import BaseModel
from reflector.settings import settings from reflector.settings import settings
class StorageError(Exception):
"""Base exception for storage operations."""
pass
class StoragePermissionError(StorageError):
"""Exception raised when storage operation fails due to permission issues."""
pass
class FileResult(BaseModel): class FileResult(BaseModel):
filename: str filename: str
url: str url: str
@@ -36,26 +49,113 @@ class Storage:
return cls._registry[name](**config) return cls._registry[name](**config)
async def put_file(self, filename: str, data: bytes) -> FileResult: # Credential properties for API passthrough
return await self._put_file(filename, data) @property
def bucket_name(self) -> str:
async def _put_file(self, filename: str, data: bytes) -> FileResult: """Default bucket name for this storage instance."""
raise NotImplementedError raise NotImplementedError
async def delete_file(self, filename: str): @property
return await self._delete_file(filename) def region(self) -> str:
"""AWS region for this storage instance."""
async def _delete_file(self, filename: str):
raise NotImplementedError raise NotImplementedError
async def get_file_url(self, filename: str) -> str: @property
return await self._get_file_url(filename) def access_key_id(self) -> str | None:
"""AWS access key ID (None for role-based auth). Prefer key_credentials property."""
return None
async def _get_file_url(self, filename: str) -> str: @property
def secret_access_key(self) -> str | None:
"""AWS secret access key (None for role-based auth). Prefer key_credentials property."""
return None
@property
def role_arn(self) -> str | None:
"""AWS IAM role ARN for role-based auth (None for key-based auth). Prefer role_credential property."""
return None
@property
def key_credentials(self) -> tuple[str, str]:
"""
Get (access_key_id, secret_access_key) for key-based auth.
Raises ValueError if storage uses IAM role instead.
"""
raise NotImplementedError raise NotImplementedError
async def get_file(self, filename: str): @property
return await self._get_file(filename) def role_credential(self) -> str:
"""
async def _get_file(self, filename: str): Get IAM role ARN for role-based auth.
Raises ValueError if storage uses access keys instead.
"""
raise NotImplementedError
async def put_file(
self, filename: str, data: Union[bytes, BinaryIO], *, bucket: str | None = None
) -> FileResult:
"""Upload data. bucket: override instance default if provided."""
return await self._put_file(filename, data, bucket=bucket)
async def _put_file(
self, filename: str, data: Union[bytes, BinaryIO], *, bucket: str | None = None
) -> FileResult:
raise NotImplementedError
async def delete_file(self, filename: str, *, bucket: str | None = None):
"""Delete file. bucket: override instance default if provided."""
return await self._delete_file(filename, bucket=bucket)
async def _delete_file(self, filename: str, *, bucket: str | None = None):
raise NotImplementedError
async def get_file_url(
self,
filename: str,
operation: str = "get_object",
expires_in: int = 3600,
*,
bucket: str | None = None,
) -> str:
"""Generate presigned URL. bucket: override instance default if provided."""
return await self._get_file_url(filename, operation, expires_in, bucket=bucket)
async def _get_file_url(
self,
filename: str,
operation: str = "get_object",
expires_in: int = 3600,
*,
bucket: str | None = None,
) -> str:
raise NotImplementedError
async def get_file(self, filename: str, *, bucket: str | None = None):
"""Download file. bucket: override instance default if provided."""
return await self._get_file(filename, bucket=bucket)
async def _get_file(self, filename: str, *, bucket: str | None = None):
raise NotImplementedError
async def list_objects(
self, prefix: str = "", *, bucket: str | None = None
) -> list[str]:
"""List object keys. bucket: override instance default if provided."""
return await self._list_objects(prefix, bucket=bucket)
async def _list_objects(
self, prefix: str = "", *, bucket: str | None = None
) -> list[str]:
raise NotImplementedError
async def stream_to_fileobj(
self, filename: str, fileobj: BinaryIO, *, bucket: str | None = None
):
"""Stream file directly to file object without loading into memory.
bucket: override instance default if provided."""
return await self._stream_to_fileobj(filename, fileobj, bucket=bucket)
async def _stream_to_fileobj(
self, filename: str, fileobj: BinaryIO, *, bucket: str | None = None
):
raise NotImplementedError raise NotImplementedError


@@ -1,79 +1,236 @@
+from functools import wraps
+from typing import BinaryIO, Union
+
 import aioboto3
+from botocore.config import Config
+from botocore.exceptions import ClientError

 from reflector.logger import logger
-from reflector.storage.base import FileResult, Storage
+from reflector.storage.base import FileResult, Storage, StoragePermissionError
+
+
+def handle_s3_client_errors(operation_name: str):
+    """Decorator to handle S3 ClientError with bucket-aware messaging.
+
+    Args:
+        operation_name: Human-readable operation name for error messages (e.g., "upload", "delete")
+    """
+
+    def decorator(func):
+        @wraps(func)
+        async def wrapper(self, *args, **kwargs):
+            bucket = kwargs.get("bucket")
+            try:
+                return await func(self, *args, **kwargs)
+            except ClientError as e:
+                error_code = e.response.get("Error", {}).get("Code")
+                if error_code in ("AccessDenied", "NoSuchBucket"):
+                    actual_bucket = bucket or self._bucket_name
+                    bucket_context = (
+                        f"overridden bucket '{actual_bucket}'"
+                        if bucket
+                        else f"default bucket '{actual_bucket}'"
+                    )
+                    raise StoragePermissionError(
+                        f"S3 {operation_name} failed for {bucket_context}: {error_code}. "
+                        f"Check TRANSCRIPT_STORAGE_AWS_* credentials have permission."
+                    ) from e
+                raise
+
+        return wrapper
+
+    return decorator


 class AwsStorage(Storage):
+    """AWS S3 storage with bucket override for multi-platform recording architecture.
+
+    Master credentials access all buckets via optional bucket parameter in operations."""
+
     def __init__(
         self,
-        aws_access_key_id: str,
-        aws_secret_access_key: str,
         aws_bucket_name: str,
         aws_region: str,
+        aws_access_key_id: str | None = None,
+        aws_secret_access_key: str | None = None,
+        aws_role_arn: str | None = None,
     ):
-        if not aws_access_key_id:
-            raise ValueError("Storage `aws_storage` require `aws_access_key_id`")
-        if not aws_secret_access_key:
-            raise ValueError("Storage `aws_storage` require `aws_secret_access_key`")
         if not aws_bucket_name:
             raise ValueError("Storage `aws_storage` require `aws_bucket_name`")
         if not aws_region:
             raise ValueError("Storage `aws_storage` require `aws_region`")
+        if not aws_access_key_id and not aws_role_arn:
+            raise ValueError(
+                "Storage `aws_storage` require either `aws_access_key_id` or `aws_role_arn`"
+            )
+        if aws_role_arn and (aws_access_key_id or aws_secret_access_key):
+            raise ValueError(
+                "Storage `aws_storage` cannot use both `aws_role_arn` and access keys"
+            )
         super().__init__()
-        self.aws_bucket_name = aws_bucket_name
+        self._bucket_name = aws_bucket_name
+        self._region = aws_region
+        self._access_key_id = aws_access_key_id
+        self._secret_access_key = aws_secret_access_key
+        self._role_arn = aws_role_arn
         self.aws_folder = ""
         if "/" in aws_bucket_name:
-            self.aws_bucket_name, self.aws_folder = aws_bucket_name.split("/", 1)
+            self._bucket_name, self.aws_folder = aws_bucket_name.split("/", 1)
+        self.boto_config = Config(retries={"max_attempts": 3, "mode": "adaptive"})
         self.session = aioboto3.Session(
             aws_access_key_id=aws_access_key_id,
             aws_secret_access_key=aws_secret_access_key,
             region_name=aws_region,
         )
-        self.base_url = f"https://{aws_bucket_name}.s3.amazonaws.com/"
+        self.base_url = f"https://{self._bucket_name}.s3.amazonaws.com/"

-    async def _put_file(self, filename: str, data: bytes) -> FileResult:
-        bucket = self.aws_bucket_name
-        folder = self.aws_folder
-        logger.info(f"Uploading {filename} to S3 {bucket}/{folder}")
-        s3filename = f"{folder}/{filename}" if folder else filename
-        async with self.session.client("s3") as client:
-            await client.put_object(
-                Bucket=bucket,
-                Key=s3filename,
-                Body=data,
-            )
+    # Implement credential properties
+    @property
+    def bucket_name(self) -> str:
+        return self._bucket_name
+
+    @property
+    def region(self) -> str:
+        return self._region
+
+    @property
+    def access_key_id(self) -> str | None:
+        return self._access_key_id
+
+    @property
+    def secret_access_key(self) -> str | None:
+        return self._secret_access_key
+
+    @property
+    def role_arn(self) -> str | None:
+        return self._role_arn
+
+    @property
+    def key_credentials(self) -> tuple[str, str]:
+        """Get (access_key_id, secret_access_key) for key-based auth."""
+        if self._role_arn:
+            raise ValueError(
+                "Storage uses IAM role authentication. "
+                "Use role_credential property instead of key_credentials."
+            )
+        if not self._access_key_id or not self._secret_access_key:
+            raise ValueError("Storage access key credentials not configured")
+        return (self._access_key_id, self._secret_access_key)

-    async def _get_file_url(self, filename: str) -> FileResult:
-        bucket = self.aws_bucket_name
-        folder = self.aws_folder
-        s3filename = f"{folder}/{filename}" if folder else filename
-        async with self.session.client("s3") as client:
-            presigned_url = await client.generate_presigned_url(
-                "get_object",
-                Params={"Bucket": bucket, "Key": s3filename},
-                ExpiresIn=3600,
-            )
-            return presigned_url
+    @property
+    def role_credential(self) -> str:
+        """Get IAM role ARN for role-based auth."""
+        if self._access_key_id or self._secret_access_key:
+            raise ValueError(
+                "Storage uses access key authentication. "
+                "Use key_credentials property instead of role_credential."
+            )
+        if not self._role_arn:
+            raise ValueError("Storage IAM role ARN not configured")
+        return self._role_arn
+
+    @handle_s3_client_errors("upload")
+    async def _put_file(
+        self, filename: str, data: Union[bytes, BinaryIO], *, bucket: str | None = None
+    ) -> FileResult:
+        actual_bucket = bucket or self._bucket_name
+        folder = self.aws_folder
+        s3filename = f"{folder}/{filename}" if folder else filename
+        logger.info(f"Uploading {filename} to S3 {actual_bucket}/{folder}")
+        async with self.session.client("s3", config=self.boto_config) as client:
+            if isinstance(data, bytes):
+                await client.put_object(Bucket=actual_bucket, Key=s3filename, Body=data)
+            else:
+                # boto3 reads file-like object in chunks
+                # avoids creating extra memory copy vs bytes.getvalue() approach
+                await client.upload_fileobj(data, Bucket=actual_bucket, Key=s3filename)
+        url = await self._get_file_url(filename, bucket=bucket)
+        return FileResult(filename=filename, url=url)
+
+    @handle_s3_client_errors("presign")
+    async def _get_file_url(
+        self,
+        filename: str,
+        operation: str = "get_object",
+        expires_in: int = 3600,
+        *,
+        bucket: str | None = None,
+    ) -> str:
+        actual_bucket = bucket or self._bucket_name
+        folder = self.aws_folder
+        s3filename = f"{folder}/{filename}" if folder else filename
+        async with self.session.client("s3", config=self.boto_config) as client:
+            presigned_url = await client.generate_presigned_url(
+                operation,
+                Params={"Bucket": actual_bucket, "Key": s3filename},
+                ExpiresIn=expires_in,
+            )
+            return presigned_url

-    async def _delete_file(self, filename: str):
-        bucket = self.aws_bucket_name
-        folder = self.aws_folder
-        logger.info(f"Deleting {filename} from S3 {bucket}/{folder}")
-        s3filename = f"{folder}/{filename}" if folder else filename
-        async with self.session.client("s3") as client:
-            await client.delete_object(Bucket=bucket, Key=s3filename)
+    @handle_s3_client_errors("delete")
+    async def _delete_file(self, filename: str, *, bucket: str | None = None):
+        actual_bucket = bucket or self._bucket_name
+        folder = self.aws_folder
+        logger.info(f"Deleting {filename} from S3 {actual_bucket}/{folder}")
+        s3filename = f"{folder}/{filename}" if folder else filename
+        async with self.session.client("s3", config=self.boto_config) as client:
+            await client.delete_object(Bucket=actual_bucket, Key=s3filename)

-    async def _get_file(self, filename: str):
-        bucket = self.aws_bucket_name
-        folder = self.aws_folder
-        logger.info(f"Downloading {filename} from S3 {bucket}/{folder}")
-        s3filename = f"{folder}/{filename}" if folder else filename
-        async with self.session.client("s3") as client:
-            response = await client.get_object(Bucket=bucket, Key=s3filename)
-            return await response["Body"].read()
+    @handle_s3_client_errors("download")
+    async def _get_file(self, filename: str, *, bucket: str | None = None):
+        actual_bucket = bucket or self._bucket_name
+        folder = self.aws_folder
+        logger.info(f"Downloading {filename} from S3 {actual_bucket}/{folder}")
+        s3filename = f"{folder}/{filename}" if folder else filename
+        async with self.session.client("s3", config=self.boto_config) as client:
+            response = await client.get_object(Bucket=actual_bucket, Key=s3filename)
+            return await response["Body"].read()
+
+    @handle_s3_client_errors("list_objects")
+    async def _list_objects(
+        self, prefix: str = "", *, bucket: str | None = None
+    ) -> list[str]:
+        actual_bucket = bucket or self._bucket_name
+        folder = self.aws_folder
+        # Combine folder and prefix
+        s3prefix = f"{folder}/{prefix}" if folder else prefix
+        logger.info(f"Listing objects from S3 {actual_bucket} with prefix '{s3prefix}'")
+        keys = []
+        async with self.session.client("s3", config=self.boto_config) as client:
+            paginator = client.get_paginator("list_objects_v2")
+            async for page in paginator.paginate(Bucket=actual_bucket, Prefix=s3prefix):
+                if "Contents" in page:
+                    for obj in page["Contents"]:
+                        # Strip folder prefix from keys if present
+                        key = obj["Key"]
+                        if folder:
+                            if key.startswith(f"{folder}/"):
+                                key = key[len(folder) + 1 :]
+                            elif key == folder:
+                                # Skip folder marker itself
+                                continue
+                        keys.append(key)
+        return keys
+
+    @handle_s3_client_errors("stream")
+    async def _stream_to_fileobj(
+        self, filename: str, fileobj: BinaryIO, *, bucket: str | None = None
+    ):
+        """Stream file from S3 directly to file object without loading into memory."""
+        actual_bucket = bucket or self._bucket_name
+        folder = self.aws_folder
+        logger.info(f"Streaming {filename} from S3 {actual_bucket}/{folder}")
+        s3filename = f"{folder}/{filename}" if folder else filename
+        async with self.session.client("s3", config=self.boto_config) as client:
+            await client.download_fileobj(
+                Bucket=actual_bucket, Key=s3filename, Fileobj=fileobj
+            )


 Storage.register("aws", AwsStorage)


@@ -9,12 +9,12 @@ async def export_db(filename: str) -> None:
     filename = pathlib.Path(filename).resolve()
     settings.DATABASE_URL = f"sqlite:///{filename}"

-    from reflector.db import get_session_factory
-    from reflector.db.transcripts import transcripts_controller
+    from reflector.db import get_database, transcripts

-    session_factory = get_session_factory()
-    async with session_factory() as session:
-        transcripts = await transcripts_controller.get_all(session)
+    database = get_database()
+    await database.connect()
+    transcripts = await database.fetch_all(transcripts.select())
+    await database.disconnect()

     def export_transcript(transcript, output_dir):
         for topic in transcript.topics:


@@ -8,12 +8,12 @@ async def export_db(filename: str) -> None:
     filename = pathlib.Path(filename).resolve()
     settings.DATABASE_URL = f"sqlite:///{filename}"

-    from reflector.db import get_session_factory
-    from reflector.db.transcripts import transcripts_controller
+    from reflector.db import get_database, transcripts

-    session_factory = get_session_factory()
-    async with session_factory() as session:
-        transcripts = await transcripts_controller.get_all(session)
+    database = get_database()
+    await database.connect()
+    transcripts = await database.fetch_all(transcripts.select())
+    await database.disconnect()

     def export_transcript(transcript):
         tid = transcript.id


@@ -11,9 +11,6 @@ import time
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Literal from typing import Any, Dict, List, Literal
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.db import get_session_factory
from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
from reflector.logger import logger from reflector.logger import logger
from reflector.pipelines.main_file_pipeline import ( from reflector.pipelines.main_file_pipeline import (
@@ -53,7 +50,6 @@ TranscriptId = str
 # common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system)
 # ideally we want to get rid of it at some point
 async def prepare_entry(
-    session: AsyncSession,
     source_path: str,
     source_language: str,
     target_language: str,
@@ -61,7 +57,6 @@ async def prepare_entry(
     file_path = Path(source_path)
     transcript = await transcripts_controller.add(
-        session,
         file_path.name,
         # note that the real file upload uses SourceKind: LIVE, which is an error
         source_kind=SourceKind.FILE,
@@ -83,20 +78,16 @@ async def prepare_entry(
     logger.info(f"Copied {source_path} to {upload_path}")
     # pipelines expect entity status "uploaded"
-    await transcripts_controller.update(session, transcript, {"status": "uploaded"})
+    await transcripts_controller.update(transcript, {"status": "uploaded"})
     return transcript.id
 # same reason as prepare_entry
 async def extract_result_from_entry(
-    session: AsyncSession,
-    transcript_id: TranscriptId,
-    output_path: str,
+    transcript_id: TranscriptId, output_path: str
 ) -> None:
-    post_final_transcript = await transcripts_controller.get_by_id(
-        session, transcript_id
-    )
+    post_final_transcript = await transcripts_controller.get_by_id(transcript_id)
     # assert post_final_transcript.status == "ended"
     # File pipeline doesn't set status to "ended", only live pipeline does https://github.com/Monadical-SAS/reflector/issues/582
@@ -124,7 +115,6 @@ async def extract_result_from_entry(
 async def process_live_pipeline(
-    session: AsyncSession,
     transcript_id: TranscriptId,
 ):
     """Process transcript_id with transcription and diarization"""
@@ -133,9 +123,7 @@ async def process_live_pipeline(
     await live_pipeline_process(transcript_id=transcript_id)
     print(f"Processing complete for transcript {transcript_id}", file=sys.stderr)
-    pre_final_transcript = await transcripts_controller.get_by_id(
-        session, transcript_id
-    )
+    pre_final_transcript = await transcripts_controller.get_by_id(transcript_id)
     # assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post
     assert pre_final_transcript.status != "ended"
@@ -172,17 +160,21 @@ async def process(
     pipeline: Literal["live", "file"],
     output_path: str = None,
 ):
-    session_factory = get_session_factory()
-    async with session_factory() as session:
+    from reflector.db import get_database
+    database = get_database()
+    # db connect is a part of ceremony
+    await database.connect()
+    try:
         transcript_id = await prepare_entry(
-            session,
             source_path,
             source_language,
             target_language,
         )
         pipeline_handlers = {
-            "live": lambda tid: process_live_pipeline(session, tid),
+            "live": process_live_pipeline,
             "file": process_file_pipeline,
         }
@@ -192,7 +184,9 @@ async def process(
         await handler(transcript_id)
-        await extract_result_from_entry(session, transcript_id, output_path)
+        await extract_result_from_entry(transcript_id, output_path)
+    finally:
+        await database.disconnect()
 if __name__ == "__main__":

View File

@@ -0,0 +1,26 @@
from reflector.utils.string import NonEmptyString
DailyRoomName = str
def extract_base_room_name(daily_room_name: DailyRoomName) -> NonEmptyString:
"""
Extract base room name from Daily.co timestamped room name.
Daily.co creates rooms with timestamp suffix: {base_name}-YYYYMMDDHHMMSS
This function removes the timestamp to get the original room name.
Examples:
"daily-20251020193458" -> "daily"
"daily-2-20251020193458" -> "daily-2"
"my-room-name-20251020193458" -> "my-room-name"
Args:
daily_room_name: Full Daily.co room name with optional timestamp
Returns:
Base room name without timestamp suffix
"""
base_name = daily_room_name.rsplit("-", 1)[0]
assert base_name, f"Extracted base name is empty from: {daily_room_name}"
return base_name
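The stripping behavior above can be sketched as a standalone snippet; note that `rsplit` always drops the last `-`-separated segment, so callers should pass only timestamped room names (this is an illustrative re-implementation, not the module itself):

```python
def extract_base_room_name(daily_room_name: str) -> str:
    # Drop everything after the last "-" (the YYYYMMDDHHMMSS suffix)
    base_name = daily_room_name.rsplit("-", 1)[0]
    assert base_name, f"Extracted base name is empty from: {daily_room_name}"
    return base_name

print(extract_base_room_name("daily-20251020193458"))        # daily
print(extract_base_room_name("my-room-name-20251020193458")) # my-room-name
```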

View File

@@ -0,0 +1,9 @@
from datetime import datetime, timezone
def parse_datetime_with_timezone(iso_string: str) -> datetime:
"""Parse ISO datetime string and ensure timezone awareness (defaults to UTC if naive)."""
dt = datetime.fromisoformat(iso_string)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt
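A quick usage sketch for the helper above: a naive timestamp gains UTC, while an already-aware one keeps its offset (self-contained copy for illustration):

```python
from datetime import datetime, timezone

def parse_datetime_with_timezone(iso_string: str) -> datetime:
    # Mirrors the helper above: default naive datetimes to UTC
    dt = datetime.fromisoformat(iso_string)
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt

naive = parse_datetime_with_timezone("2025-10-20T19:34:58")        # becomes UTC
aware = parse_datetime_with_timezone("2025-10-20T19:34:58+02:00")  # offset preserved
```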

View File

@@ -1,4 +1,4 @@
-from typing import Annotated
+from typing import Annotated, TypeVar
 from pydantic import Field, TypeAdapter, constr
@@ -21,3 +21,12 @@ def try_parse_non_empty_string(s: str) -> NonEmptyString | None:
     if not s:
         return None
     return parse_non_empty_string(s)
+T = TypeVar("T", bound=str)
+def assert_equal(s1: T, s2: T) -> T:
+    if s1 != s2:
+        raise ValueError(f"assert_equal: {s1} != {s2}")
+    return s1

View File

@@ -0,0 +1,37 @@
"""URL manipulation utilities."""
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
def add_query_param(url: str, key: str, value: str) -> str:
"""
Add or update a query parameter in a URL.
Properly handles URLs with or without existing query parameters,
preserving fragments and encoding special characters.
Args:
url: The URL to modify
key: The query parameter name
value: The query parameter value
Returns:
The URL with the query parameter added or updated
Examples:
>>> add_query_param("https://example.com/room", "t", "token123")
'https://example.com/room?t=token123'
>>> add_query_param("https://example.com/room?existing=param", "t", "token123")
'https://example.com/room?existing=param&t=token123'
"""
parsed = urlparse(url)
query_params = parse_qs(parsed.query, keep_blank_values=True)
query_params[key] = [value]
new_query = urlencode(query_params, doseq=True)
new_parsed = parsed._replace(query=new_query)
return urlunparse(new_parsed)
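One behavior worth noting in the implementation above: an existing parameter with the same key is replaced rather than duplicated, and fragments survive the round trip. A hedged, self-contained sketch:

```python
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse

def add_query_param(url: str, key: str, value: str) -> str:
    # Same approach as above: parse, overwrite the key, re-encode
    parsed = urlparse(url)
    query_params = parse_qs(parsed.query, keep_blank_values=True)
    query_params[key] = [value]
    return urlunparse(parsed._replace(query=urlencode(query_params, doseq=True)))

# Existing "t" is replaced, not appended; "#section" is preserved
updated = add_query_param("https://example.com/room?t=old#section", "t", "new")
```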

View File

@@ -0,0 +1,11 @@
from .base import VideoPlatformClient
from .models import MeetingData, VideoPlatformConfig
from .registry import get_platform_client, register_platform
__all__ = [
"VideoPlatformClient",
"VideoPlatformConfig",
"MeetingData",
"get_platform_client",
"register_platform",
]

View File

@@ -0,0 +1,54 @@
from abc import ABC, abstractmethod
from datetime import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from ..schemas.platform import Platform
from ..utils.string import NonEmptyString
from .models import MeetingData, VideoPlatformConfig
if TYPE_CHECKING:
from reflector.db.rooms import Room
# note: ROOM_PREFIX_SEPARATOR may also occur inside the room name itself, so it is not a unique delimiter
ROOM_PREFIX_SEPARATOR = "-"
class VideoPlatformClient(ABC):
PLATFORM_NAME: Platform
def __init__(self, config: VideoPlatformConfig):
self.config = config
@abstractmethod
async def create_meeting(
self, room_name_prefix: NonEmptyString, end_date: datetime, room: "Room"
) -> MeetingData:
pass
@abstractmethod
async def get_room_sessions(self, room_name: str) -> List[Any] | None:
pass
@abstractmethod
async def delete_room(self, room_name: str) -> bool:
pass
@abstractmethod
async def upload_logo(self, room_name: str, logo_path: str) -> bool:
pass
@abstractmethod
def verify_webhook_signature(
self, body: bytes, signature: str, timestamp: Optional[str] = None
) -> bool:
pass
def format_recording_config(self, room: "Room") -> Dict[str, Any]:
if room.recording_type == "cloud" and self.config.s3_bucket:
return {
"type": room.recording_type,
"bucket": self.config.s3_bucket,
"region": self.config.s3_region,
"trigger": room.recording_trigger,
}
return {"type": room.recording_type}

View File

@@ -0,0 +1,198 @@
import base64
import hmac
from datetime import datetime
from hashlib import sha256
from http import HTTPStatus
from typing import Any, Dict, List, Optional
import httpx
from reflector.db.rooms import Room
from reflector.logger import logger
from reflector.storage import get_dailyco_storage
from ..schemas.platform import Platform
from ..utils.daily import DailyRoomName
from ..utils.string import NonEmptyString
from .base import ROOM_PREFIX_SEPARATOR, VideoPlatformClient
from .models import MeetingData, RecordingType, VideoPlatformConfig
class DailyClient(VideoPlatformClient):
PLATFORM_NAME: Platform = "daily"
TIMEOUT = 10
BASE_URL = "https://api.daily.co/v1"
TIMESTAMP_FORMAT = "%Y%m%d%H%M%S"
RECORDING_NONE: RecordingType = "none"
RECORDING_CLOUD: RecordingType = "cloud"
def __init__(self, config: VideoPlatformConfig):
super().__init__(config)
self.headers = {
"Authorization": f"Bearer {config.api_key}",
"Content-Type": "application/json",
}
async def create_meeting(
self, room_name_prefix: NonEmptyString, end_date: datetime, room: Room
) -> MeetingData:
"""
Daily.co rooms vs meetings:
- We create a NEW Daily.co room for each Reflector meeting
- Daily.co meeting/session starts automatically when first participant joins
- Room auto-deletes after exp time
- Meeting.room_name stores the timestamped Daily.co room name
"""
timestamp = datetime.now().strftime(self.TIMESTAMP_FORMAT)
room_name = f"{room_name_prefix}{ROOM_PREFIX_SEPARATOR}{timestamp}"
data = {
"name": room_name,
"privacy": "private" if room.is_locked else "public",
"properties": {
"enable_recording": "raw-tracks"
if room.recording_type != self.RECORDING_NONE
else False,
"enable_chat": True,
"enable_screenshare": True,
"start_video_off": False,
"start_audio_off": False,
"exp": int(end_date.timestamp()),
},
}
# Get storage config for passing to Daily API
daily_storage = get_dailyco_storage()
assert daily_storage.bucket_name, "S3 bucket must be configured"
data["properties"]["recordings_bucket"] = {
"bucket_name": daily_storage.bucket_name,
"bucket_region": daily_storage.region,
"assume_role_arn": daily_storage.role_credential,
"allow_api_access": True,
}
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.BASE_URL}/rooms",
headers=self.headers,
json=data,
timeout=self.TIMEOUT,
)
if response.status_code >= 400:
logger.error(
"Daily.co API error",
status_code=response.status_code,
response_body=response.text,
request_data=data,
)
response.raise_for_status()
result = response.json()
room_url = result["url"]
return MeetingData(
meeting_id=result["id"],
room_name=result["name"],
room_url=room_url,
host_room_url=room_url,
platform=self.PLATFORM_NAME,
extra_data=result,
)
async def get_room_sessions(self, room_name: str) -> List[Any] | None:
# no such api
return None
async def get_room_presence(self, room_name: str) -> Dict[str, Any]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.BASE_URL}/rooms/{room_name}/presence",
headers=self.headers,
timeout=self.TIMEOUT,
)
response.raise_for_status()
return response.json()
async def get_meeting_participants(self, meeting_id: str) -> Dict[str, Any]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.BASE_URL}/meetings/{meeting_id}/participants",
headers=self.headers,
timeout=self.TIMEOUT,
)
response.raise_for_status()
return response.json()
async def get_recording(self, recording_id: str) -> Dict[str, Any]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.BASE_URL}/recordings/{recording_id}",
headers=self.headers,
timeout=self.TIMEOUT,
)
response.raise_for_status()
return response.json()
async def delete_room(self, room_name: str) -> bool:
async with httpx.AsyncClient() as client:
response = await client.delete(
f"{self.BASE_URL}/rooms/{room_name}",
headers=self.headers,
timeout=self.TIMEOUT,
)
return response.status_code in (HTTPStatus.OK, HTTPStatus.NOT_FOUND)
async def upload_logo(self, room_name: str, logo_path: str) -> bool:
return True
def verify_webhook_signature(
self, body: bytes, signature: str, timestamp: Optional[str] = None
) -> bool:
"""Verify Daily.co webhook signature.
Daily.co uses:
- X-Webhook-Signature header
- X-Webhook-Timestamp header
- Signature format: HMAC-SHA256(base64_decode(secret), timestamp + '.' + body)
- Result is base64 encoded
"""
if not signature or not timestamp:
return False
try:
secret_bytes = base64.b64decode(self.config.webhook_secret)
signed_content = timestamp.encode() + b"." + body
expected = hmac.new(secret_bytes, signed_content, sha256).digest()
expected_b64 = base64.b64encode(expected).decode()
return hmac.compare_digest(expected_b64, signature)
except Exception as e:
logger.error("Daily.co webhook signature verification failed", exc_info=e)
return False
async def create_meeting_token(
self,
room_name: DailyRoomName,
enable_recording: bool,
user_id: Optional[str] = None,
) -> str:
data = {"properties": {"room_name": room_name}}
if enable_recording:
data["properties"]["start_cloud_recording"] = True
data["properties"]["enable_recording_ui"] = False
if user_id:
data["properties"]["user_id"] = user_id
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.BASE_URL}/meeting-tokens",
headers=self.headers,
json=data,
timeout=self.TIMEOUT,
)
response.raise_for_status()
return response.json()["token"]
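The signature scheme described in `verify_webhook_signature` above (HMAC-SHA256 over `timestamp + "." + body`, keyed with the base64-decoded secret, result base64-encoded) can be exercised end to end with a small sketch; the secret and payload here are made up for illustration:

```python
import base64
import hmac
from hashlib import sha256

def sign(secret_b64: str, timestamp: str, body: bytes) -> str:
    # Mirror of the verification math: HMAC-SHA256(base64_decode(secret), ts + "." + body)
    key = base64.b64decode(secret_b64)
    digest = hmac.new(key, timestamp.encode() + b"." + body, sha256).digest()
    return base64.b64encode(digest).decode()

def verify(secret_b64: str, body: bytes, signature: str, timestamp: str) -> bool:
    return hmac.compare_digest(sign(secret_b64, timestamp, body), signature)

secret = base64.b64encode(b"test-secret").decode()  # hypothetical webhook secret
body = b'{"type":"recording.started"}'
sig = sign(secret, "1730000000", body)
```

Tampering with either the timestamp or the body changes the signed content, so verification fails, which is the property the handler relies on.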

View File

@@ -0,0 +1,62 @@
from typing import Optional
from reflector.settings import settings
from reflector.storage import get_dailyco_storage, get_whereby_storage
from ..schemas.platform import WHEREBY_PLATFORM, Platform
from .base import VideoPlatformClient, VideoPlatformConfig
from .registry import get_platform_client
def get_platform_config(platform: Platform) -> VideoPlatformConfig:
if platform == WHEREBY_PLATFORM:
if not settings.WHEREBY_API_KEY:
raise ValueError(
"WHEREBY_API_KEY is required when platform='whereby'. "
"Set WHEREBY_API_KEY environment variable."
)
whereby_storage = get_whereby_storage()
key_id, secret = whereby_storage.key_credentials
return VideoPlatformConfig(
api_key=settings.WHEREBY_API_KEY,
webhook_secret=settings.WHEREBY_WEBHOOK_SECRET or "",
api_url=settings.WHEREBY_API_URL,
s3_bucket=whereby_storage.bucket_name,
s3_region=whereby_storage.region,
aws_access_key_id=key_id,
aws_access_key_secret=secret,
)
elif platform == "daily":
if not settings.DAILY_API_KEY:
raise ValueError(
"DAILY_API_KEY is required when platform='daily'. "
"Set DAILY_API_KEY environment variable."
)
if not settings.DAILY_SUBDOMAIN:
raise ValueError(
"DAILY_SUBDOMAIN is required when platform='daily'. "
"Set DAILY_SUBDOMAIN environment variable."
)
daily_storage = get_dailyco_storage()
return VideoPlatformConfig(
api_key=settings.DAILY_API_KEY,
webhook_secret=settings.DAILY_WEBHOOK_SECRET or "",
subdomain=settings.DAILY_SUBDOMAIN,
s3_bucket=daily_storage.bucket_name,
s3_region=daily_storage.region,
aws_role_arn=daily_storage.role_credential,
)
else:
raise ValueError(f"Unknown platform: {platform}")
def create_platform_client(platform: Platform) -> VideoPlatformClient:
config = get_platform_config(platform)
return get_platform_client(platform, config)
def get_platform(room_platform: Optional[Platform] = None) -> Platform:
if room_platform:
return room_platform
return settings.DEFAULT_VIDEO_PLATFORM

View File

@@ -0,0 +1,40 @@
from typing import Any, Dict, Literal, Optional
from pydantic import BaseModel, Field
from reflector.schemas.platform import WHEREBY_PLATFORM, Platform
RecordingType = Literal["none", "local", "cloud"]
class MeetingData(BaseModel):
platform: Platform
meeting_id: str = Field(description="Platform-specific meeting identifier")
room_url: str = Field(description="URL for participants to join")
host_room_url: str = Field(description="URL for hosts (may be same as room_url)")
room_name: str = Field(description="Human-readable room name")
extra_data: Dict[str, Any] = Field(default_factory=dict)
class Config:
json_schema_extra = {
"example": {
"platform": WHEREBY_PLATFORM,
"meeting_id": "12345678",
"room_url": "https://subdomain.whereby.com/room-20251008120000",
"host_room_url": "https://subdomain.whereby.com/room-20251008120000?roomKey=abc123",
"room_name": "room-20251008120000",
}
}
class VideoPlatformConfig(BaseModel):
api_key: str
webhook_secret: str
api_url: Optional[str] = None
subdomain: Optional[str] = None # Whereby/Daily subdomain
s3_bucket: Optional[str] = None
s3_region: Optional[str] = None
# Whereby uses access keys, Daily uses IAM role
aws_access_key_id: Optional[str] = None
aws_access_key_secret: Optional[str] = None
aws_role_arn: Optional[str] = None

View File

@@ -0,0 +1,35 @@
from typing import Dict, Type
from ..schemas.platform import DAILY_PLATFORM, WHEREBY_PLATFORM, Platform
from .base import VideoPlatformClient, VideoPlatformConfig
_PLATFORMS: Dict[Platform, Type[VideoPlatformClient]] = {}
def register_platform(name: Platform, client_class: Type[VideoPlatformClient]):
_PLATFORMS[name] = client_class
def get_platform_client(
platform: Platform, config: VideoPlatformConfig
) -> VideoPlatformClient:
if platform not in _PLATFORMS:
raise ValueError(f"Unknown video platform: {platform}")
client_class = _PLATFORMS[platform]
return client_class(config)
def get_available_platforms() -> list[Platform]:
return list(_PLATFORMS.keys())
def _register_builtin_platforms():
from .daily import DailyClient # noqa: PLC0415
from .whereby import WherebyClient # noqa: PLC0415
register_platform(WHEREBY_PLATFORM, WherebyClient)
register_platform(DAILY_PLATFORM, DailyClient)
_register_builtin_platforms()

View File

@@ -0,0 +1,141 @@
import hmac
import json
import re
import time
from datetime import datetime
from hashlib import sha256
from typing import Any, Dict, Optional
import httpx
from reflector.db.rooms import Room
from reflector.storage import get_whereby_storage
from ..schemas.platform import WHEREBY_PLATFORM, Platform
from ..utils.string import NonEmptyString
from .base import (
MeetingData,
VideoPlatformClient,
VideoPlatformConfig,
)
from .whereby_utils import whereby_room_name_prefix
class WherebyClient(VideoPlatformClient):
PLATFORM_NAME: Platform = WHEREBY_PLATFORM
TIMEOUT = 10 # seconds
MAX_ELAPSED_TIME = 60 * 1000 # 1 minute in milliseconds
def __init__(self, config: VideoPlatformConfig):
super().__init__(config)
self.headers = {
"Content-Type": "application/json; charset=utf-8",
"Authorization": f"Bearer {config.api_key}",
}
async def create_meeting(
self, room_name_prefix: NonEmptyString, end_date: datetime, room: Room
) -> MeetingData:
data = {
"isLocked": room.is_locked,
"roomNamePrefix": whereby_room_name_prefix(room_name_prefix),
"roomNamePattern": "uuid",
"roomMode": room.room_mode,
"endDate": end_date.isoformat(),
"fields": ["hostRoomUrl"],
}
if room.recording_type == "cloud":
# Get storage config for passing credentials to Whereby API
whereby_storage = get_whereby_storage()
key_id, secret = whereby_storage.key_credentials
data["recording"] = {
"type": room.recording_type,
"destination": {
"provider": "s3",
"bucket": whereby_storage.bucket_name,
"accessKeyId": key_id,
"accessKeySecret": secret,
"fileFormat": "mp4",
},
"startTrigger": room.recording_trigger,
}
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self.config.api_url}/meetings",
headers=self.headers,
json=data,
timeout=self.TIMEOUT,
)
response.raise_for_status()
result = response.json()
return MeetingData(
meeting_id=result["meetingId"],
room_name=result["roomName"],
room_url=result["roomUrl"],
host_room_url=result["hostRoomUrl"],
platform=self.PLATFORM_NAME,
extra_data=result,
)
async def get_room_sessions(self, room_name: str) -> Dict[str, Any]:
async with httpx.AsyncClient() as client:
response = await client.get(
f"{self.config.api_url}/insights/room-sessions?roomName={room_name}",
headers=self.headers,
timeout=self.TIMEOUT,
)
response.raise_for_status()
return response.json().get("results", [])
async def delete_room(self, room_name: str) -> bool:
return True
async def upload_logo(self, room_name: str, logo_path: str) -> bool:
async with httpx.AsyncClient() as client:
with open(logo_path, "rb") as f:
response = await client.put(
f"{self.config.api_url}/rooms/{room_name}/theme/logo",
headers={
"Authorization": f"Bearer {self.config.api_key}",
},
timeout=self.TIMEOUT,
files={"image": f},
)
response.raise_for_status()
return True
def verify_webhook_signature(
self, body: bytes, signature: str, timestamp: Optional[str] = None
) -> bool:
if not signature:
return False
matches = re.match(r"t=(.*),v1=(.*)", signature)
if not matches:
return False
ts, sig = matches.groups()
current_time = int(time.time() * 1000)
diff_time = current_time - int(ts) * 1000
if diff_time >= self.MAX_ELAPSED_TIME:
return False
body_dict = json.loads(body)
signed_payload = f"{ts}.{json.dumps(body_dict, separators=(',', ':'))}"
hmac_obj = hmac.new(
self.config.webhook_secret.encode("utf-8"),
signed_payload.encode("utf-8"),
sha256,
)
expected_signature = hmac_obj.hexdigest()
try:
return hmac.compare_digest(
expected_signature.encode("utf-8"), sig.encode("utf-8")
)
except Exception:
return False

View File

@@ -0,0 +1,38 @@
import re
from datetime import datetime
from reflector.utils.datetime import parse_datetime_with_timezone
from reflector.utils.string import NonEmptyString, parse_non_empty_string
from reflector.video_platforms.base import ROOM_PREFIX_SEPARATOR
def parse_whereby_recording_filename(
object_key: NonEmptyString,
) -> tuple[NonEmptyString, datetime]:
filename = parse_non_empty_string(object_key.rsplit(".", 1)[0])
timestamp_pattern = r"(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z)"
match = re.search(timestamp_pattern, filename)
if not match:
raise ValueError(f"No ISO timestamp found in filename: {object_key}")
timestamp_str = match.group(1)
timestamp_start = match.start(1)
room_name_part = filename[:timestamp_start]
if room_name_part.endswith(ROOM_PREFIX_SEPARATOR):
room_name_part = room_name_part[: -len(ROOM_PREFIX_SEPARATOR)]
else:
raise ValueError(
f"room name {room_name_part} doesn't end with {ROOM_PREFIX_SEPARATOR} in filename: {object_key}"
)
return parse_non_empty_string(room_name_part), parse_datetime_with_timezone(
timestamp_str
)
def whereby_room_name_prefix(room_name_prefix: NonEmptyString) -> NonEmptyString:
return room_name_prefix + ROOM_PREFIX_SEPARATOR
# room names come with a leading "/" from the whereby api but lack it elsewhere, e.g. in recording filenames
def room_name_to_whereby_api_room_name(room_name: NonEmptyString) -> NonEmptyString:
return f"/{room_name}"
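The filename convention parsed above (`{room_name}-{ISO-8601 UTC timestamp}.ext`) can be illustrated with a self-contained sketch; this re-implements the same steps without the project's helpers, so names and details here are assumptions:

```python
import re
from datetime import datetime

def parse_recording_filename(object_key: str) -> tuple[str, datetime]:
    # Strip the extension, locate the ISO timestamp, and require a "-"
    # separator between the room name and the timestamp (as above)
    filename = object_key.rsplit(".", 1)[0]
    match = re.search(r"(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}Z)", filename)
    if not match:
        raise ValueError(f"No ISO timestamp found in filename: {object_key}")
    room_part = filename[: match.start(1)]
    if not room_part.endswith("-"):
        raise ValueError(f"Missing separator before timestamp: {object_key}")
    # "Z" suffix is normalized for fromisoformat on older Pythons
    ts = datetime.fromisoformat(match.group(1).replace("Z", "+00:00"))
    return room_part[:-1], ts

room, ts = parse_recording_filename("my-room-2025-10-20T19:34:58Z.mp4")
```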

View File

@@ -0,0 +1,233 @@
import json
from typing import Any, Dict, Literal
from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel
from reflector.db.meetings import meetings_controller
from reflector.logger import logger as _logger
from reflector.settings import settings
from reflector.utils.daily import DailyRoomName
from reflector.video_platforms.factory import create_platform_client
from reflector.worker.process import process_multitrack_recording
router = APIRouter()
logger = _logger.bind(platform="daily")
class DailyTrack(BaseModel):
type: Literal["audio", "video"]
s3Key: str
size: int
class DailyWebhookEvent(BaseModel):
version: str
type: str
id: str
payload: Dict[str, Any]
event_ts: float
def _extract_room_name(event: DailyWebhookEvent) -> DailyRoomName | None:
"""Extract room name from Daily event payload.
Daily.co API inconsistency:
- participant.* events use "room" field
- recording.* events use "room_name" field
"""
return event.payload.get("room_name") or event.payload.get("room")
@router.post("/webhook")
async def webhook(request: Request):
"""Handle Daily webhook events.
Daily.co circuit-breaker: After 3+ failed responses (4xx/5xx), webhook
state→FAILED, stops sending events. Reset: scripts/recreate_daily_webhook.py
"""
body = await request.body()
signature = request.headers.get("X-Webhook-Signature", "")
timestamp = request.headers.get("X-Webhook-Timestamp", "")
client = create_platform_client("daily")
# TEMPORARY: Bypass signature check for testing
# TODO: Remove this after testing is complete
BYPASS_FOR_TESTING = True
if not BYPASS_FOR_TESTING:
if not client.verify_webhook_signature(body, signature, timestamp):
logger.warning(
"Invalid webhook signature",
signature=signature,
timestamp=timestamp,
has_body=bool(body),
)
raise HTTPException(status_code=401, detail="Invalid webhook signature")
try:
body_json = json.loads(body)
except json.JSONDecodeError:
raise HTTPException(status_code=422, detail="Invalid JSON")
if body_json.get("test") == "test":
logger.info("Received Daily webhook test event")
return {"status": "ok"}
# Parse as actual event
try:
event = DailyWebhookEvent(**body_json)
except Exception as e:
logger.error("Failed to parse webhook event", error=str(e), body=body.decode())
raise HTTPException(status_code=422, detail="Invalid event format")
# Handle participant events
if event.type == "participant.joined":
await _handle_participant_joined(event)
elif event.type == "participant.left":
await _handle_participant_left(event)
elif event.type == "recording.started":
await _handle_recording_started(event)
elif event.type == "recording.ready-to-download":
await _handle_recording_ready(event)
elif event.type == "recording.error":
await _handle_recording_error(event)
else:
logger.warning(
"Unhandled Daily webhook event type",
event_type=event.type,
payload=event.payload,
)
return {"status": "ok"}
async def _handle_participant_joined(event: DailyWebhookEvent):
daily_room_name = _extract_room_name(event)
if not daily_room_name:
logger.warning("participant.joined: no room in payload", payload=event.payload)
return
meeting = await meetings_controller.get_by_room_name(daily_room_name)
if meeting:
await meetings_controller.increment_num_clients(meeting.id)
logger.info(
"Participant joined",
meeting_id=meeting.id,
room_name=daily_room_name,
recording_type=meeting.recording_type,
recording_trigger=meeting.recording_trigger,
)
else:
logger.warning(
"participant.joined: meeting not found", room_name=daily_room_name
)
async def _handle_participant_left(event: DailyWebhookEvent):
room_name = _extract_room_name(event)
if not room_name:
return
meeting = await meetings_controller.get_by_room_name(room_name)
if meeting:
await meetings_controller.decrement_num_clients(meeting.id)
async def _handle_recording_started(event: DailyWebhookEvent):
room_name = _extract_room_name(event)
if not room_name:
logger.warning(
"recording.started: no room_name in payload", payload=event.payload
)
return
meeting = await meetings_controller.get_by_room_name(room_name)
if meeting:
logger.info(
"Recording started",
meeting_id=meeting.id,
room_name=room_name,
recording_id=event.payload.get("recording_id"),
platform="daily",
)
else:
logger.warning("recording.started: meeting not found", room_name=room_name)
async def _handle_recording_ready(event: DailyWebhookEvent):
"""Handle recording ready for download event.
Daily.co webhook payload for raw-tracks recordings:
{
"recording_id": "...",
"room_name": "test2-20251009192341",
"tracks": [
{"type": "audio", "s3Key": "monadical/test2-.../uuid-cam-audio-123.webm", "size": 400000},
{"type": "video", "s3Key": "monadical/test2-.../uuid-cam-video-456.webm", "size": 30000000}
]
}
"""
room_name = _extract_room_name(event)
recording_id = event.payload.get("recording_id")
tracks_raw = event.payload.get("tracks", [])
if not room_name or not tracks_raw:
logger.warning(
"recording.ready-to-download: missing room_name or tracks",
room_name=room_name,
has_tracks=bool(tracks_raw),
payload=event.payload,
)
return
try:
tracks = [DailyTrack(**t) for t in tracks_raw]
except Exception as e:
logger.error(
"recording.ready-to-download: invalid tracks structure",
error=str(e),
tracks=tracks_raw,
)
return
logger.info(
"Recording ready for download",
room_name=room_name,
recording_id=recording_id,
num_tracks=len(tracks),
platform="daily",
)
bucket_name = settings.DAILYCO_STORAGE_AWS_BUCKET_NAME
if not bucket_name:
logger.error(
"DAILYCO_STORAGE_AWS_BUCKET_NAME not configured; cannot process Daily recording"
)
return
track_keys = [t.s3Key for t in tracks if t.type == "audio"]
process_multitrack_recording.delay(
bucket_name=bucket_name,
daily_room_name=room_name,
recording_id=recording_id,
track_keys=track_keys,
)
async def _handle_recording_error(event: DailyWebhookEvent):
room_name = _extract_room_name(event)
error = event.payload.get("error", "Unknown error")
if room_name:
meeting = await meetings_controller.get_by_room_name(room_name)
if meeting:
logger.error(
"Recording error",
meeting_id=meeting.id,
room_name=room_name,
error=error,
platform="daily",
)

View File

@@ -5,20 +5,24 @@ from typing import Annotated, Any, Literal, Optional
 from fastapi import APIRouter, Depends, HTTPException
 from fastapi_pagination import Page
-from fastapi_pagination.ext.sqlalchemy import paginate
+from fastapi_pagination.ext.databases import apaginate
 from pydantic import BaseModel
 from redis.exceptions import LockError
-from sqlalchemy.ext.asyncio import AsyncSession
 import reflector.auth as auth
-from reflector.db import get_session
+from reflector.db import get_database
 from reflector.db.calendar_events import calendar_events_controller
 from reflector.db.meetings import meetings_controller
 from reflector.db.rooms import rooms_controller
 from reflector.redis_cache import RedisAsyncLock
+from reflector.schemas.platform import Platform
 from reflector.services.ics_sync import ics_sync_service
 from reflector.settings import settings
-from reflector.whereby import create_meeting, upload_logo
+from reflector.utils.url import add_query_param
+from reflector.video_platforms.factory import (
+    create_platform_client,
+    get_platform,
+)
 from reflector.worker.webhook import test_webhook
 logger = logging.getLogger(__name__)
@@ -42,6 +46,7 @@ class Room(BaseModel):
     ics_enabled: bool = False
     ics_last_sync: Optional[datetime] = None
     ics_last_etag: Optional[str] = None
+    platform: Platform
 class RoomDetails(Room):
@@ -69,6 +74,7 @@ class Meeting(BaseModel):
     is_active: bool = True
     calendar_event_id: str | None = None
     calendar_metadata: dict[str, Any] | None = None
+    platform: Platform
 class CreateRoom(BaseModel):
@@ -86,6 +92,7 @@ class CreateRoom(BaseModel):
     ics_url: Optional[str] = None
     ics_fetch_interval: int = 300
     ics_enabled: bool = False
+    platform: Optional[Platform] = None
 class UpdateRoom(BaseModel):
@@ -103,6 +110,7 @@ class UpdateRoom(BaseModel):
ics_url: Optional[str] = None ics_url: Optional[str] = None
ics_fetch_interval: Optional[int] = None ics_fetch_interval: Optional[int] = None
ics_enabled: Optional[bool] = None ics_enabled: Optional[bool] = None
platform: Optional[Platform] = None
class CreateRoomMeeting(BaseModel): class CreateRoomMeeting(BaseModel):
@@ -166,40 +174,40 @@ class CalendarEventResponse(BaseModel):
 router = APIRouter()


-def parse_datetime_with_timezone(iso_string: str) -> datetime:
-    """Parse ISO datetime string and ensure timezone awareness (defaults to UTC if naive)."""
-    dt = datetime.fromisoformat(iso_string)
-    if dt.tzinfo is None:
-        dt = dt.replace(tzinfo=timezone.utc)
-    return dt


 @router.get("/rooms", response_model=Page[RoomDetails])
 async def rooms_list(
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ) -> list[RoomDetails]:
     if not user and not settings.PUBLIC_MODE:
         raise HTTPException(status_code=401, detail="Not authenticated")
     user_id = user["sub"] if user else None

-    query = await rooms_controller.get_all(
-        session, user_id=user_id, order_by="-created_at", return_query=True
-    )
-    return await paginate(session, query)
+    paginated = await apaginate(
+        get_database(),
+        await rooms_controller.get_all(
+            user_id=user_id, order_by="-created_at", return_query=True
+        ),
+    )
+    for room in paginated.items:
+        room.platform = get_platform(room.platform)
+    return paginated
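The `rooms_list` rewrite above hands a raw query to fastapi-pagination's `apaginate` from the databases extension instead of paginating a SQLAlchemy session. A minimal sketch of the offset/limit arithmetic such helpers apply, in plain Python over an in-memory sequence (`PageResult` and `paginate_items` are illustrative names, not the library's API):

```python
from dataclasses import dataclass
from typing import Any, Sequence


@dataclass
class PageResult:
    items: list[Any]
    total: int
    page: int
    size: int


def paginate_items(rows: Sequence[Any], page: int = 1, size: int = 50) -> PageResult:
    # fastapi-pagination translates page/size into OFFSET/LIMIT on the query;
    # here the same arithmetic is applied to an in-memory sequence.
    offset = (page - 1) * size
    return PageResult(
        items=list(rows[offset : offset + size]),
        total=len(rows),
        page=page,
        size=size,
    )
```

The real extension runs `LIMIT`/`OFFSET` (plus a `COUNT`) against the database connection returned by `get_database()`, but the paging math is the same.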
 @router.get("/rooms/{room_id}", response_model=RoomDetails)
 async def rooms_get(
     room_id: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ):
     user_id = user["sub"] if user else None
-    room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
+    room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")
+    if not room.is_shared and (user_id is None or room.user_id != user_id):
+        raise HTTPException(status_code=403, detail="Room access denied")
+    room.platform = get_platform(room.platform)
     return room
@@ -207,37 +215,33 @@ async def rooms_get(
 async def rooms_get_by_name(
     room_name: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ):
     user_id = user["sub"] if user else None
-    room = await rooms_controller.get_by_name(session, room_name)
+    room = await rooms_controller.get_by_name(room_name)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")

-    # Convert to RoomDetails format (add webhook fields if user is owner)
     room_dict = room.__dict__.copy()
     if user_id == room.user_id:
-        # User is owner, include webhook details if available
         room_dict["webhook_url"] = getattr(room, "webhook_url", None)
         room_dict["webhook_secret"] = getattr(room, "webhook_secret", None)
     else:
-        # Non-owner, hide webhook details
         room_dict["webhook_url"] = None
         room_dict["webhook_secret"] = None
+    room_dict["platform"] = get_platform(room.platform)
     return RoomDetails(**room_dict)
 @router.post("/rooms", response_model=Room)
 async def rooms_create(
     room: CreateRoom,
-    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
+    user: Annotated[auth.UserInfo, Depends(auth.current_user)],
 ):
-    user_id = user["sub"] if user else None
+    user_id = user["sub"]
     return await rooms_controller.add(
-        session,
         name=room.name,
         user_id=user_id,
         zulip_auto_post=room.zulip_auto_post,
@@ -253,6 +257,7 @@ async def rooms_create(
         ics_url=room.ics_url,
         ics_fetch_interval=room.ics_fetch_interval,
         ics_enabled=room.ics_enabled,
+        platform=room.platform,
     )
@@ -260,29 +265,32 @@ async def rooms_create(
 async def rooms_update(
     room_id: str,
     info: UpdateRoom,
-    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
+    user: Annotated[auth.UserInfo, Depends(auth.current_user)],
 ):
-    user_id = user["sub"] if user else None
-    room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
+    user_id = user["sub"]
+    room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")
+    if room.user_id != user_id:
+        raise HTTPException(status_code=403, detail="Not authorized")
     values = info.dict(exclude_unset=True)
-    await rooms_controller.update(session, room, values)
+    await rooms_controller.update(room, values)
+    room.platform = get_platform(room.platform)
     return room
 @router.delete("/rooms/{room_id}", response_model=DeletionStatus)
 async def rooms_delete(
     room_id: str,
-    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
+    user: Annotated[auth.UserInfo, Depends(auth.current_user)],
 ):
-    user_id = user["sub"] if user else None
-    room = await rooms_controller.get_by_id(session, room_id, user_id=user_id)
+    user_id = user["sub"]
+    room = await rooms_controller.get_by_id(room_id)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")
-    await rooms_controller.remove_by_id(session, room.id, user_id=user_id)
+    if room.user_id != user_id:
+        raise HTTPException(status_code=403, detail="Not authorized")
+    await rooms_controller.remove_by_id(room.id, user_id=user_id)
     return DeletionStatus(status="ok")
@@ -291,10 +299,9 @@ async def rooms_create_meeting(
     room_name: str,
     info: CreateRoomMeeting,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ):
     user_id = user["sub"] if user else None
-    room = await rooms_controller.get_by_name(session, room_name)
+    room = await rooms_controller.get_by_name(room_name)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")
@@ -310,26 +317,28 @@ async def rooms_create_meeting(
         meeting = None
         if not info.allow_duplicated:
             meeting = await meetings_controller.get_active(
-                session, room=room, current_time=current_time
+                room=room, current_time=current_time
             )
         if meeting is None:
             end_date = current_time + timedelta(hours=8)
-            whereby_meeting = await create_meeting("", end_date=end_date, room=room)
-
-            await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
+            platform = get_platform(room.platform)
+            client = create_platform_client(platform)
+            meeting_data = await client.create_meeting(
+                room.name, end_date=end_date, room=room
+            )
+            await client.upload_logo(meeting_data.room_name, "./images/logo.png")
             meeting = await meetings_controller.create(
-                session,
-                id=whereby_meeting["meetingId"],
-                room_name=whereby_meeting["roomName"],
-                room_url=whereby_meeting["roomUrl"],
-                host_room_url=whereby_meeting["hostRoomUrl"],
-                start_date=parse_datetime_with_timezone(
-                    whereby_meeting["startDate"]
-                ),
-                end_date=parse_datetime_with_timezone(whereby_meeting["endDate"]),
+                id=meeting_data.meeting_id,
+                room_name=meeting_data.room_name,
+                room_url=meeting_data.room_url,
+                host_room_url=meeting_data.host_room_url,
+                start_date=current_time,
+                end_date=end_date,
                 room=room,
             )
     except LockError:
@@ -338,6 +347,18 @@ async def rooms_create_meeting(
             status_code=503, detail="Meeting creation in progress, please try again"
         )
+    if meeting.platform == "daily" and room.recording_trigger != "none":
+        client = create_platform_client(meeting.platform)
+        token = await client.create_meeting_token(
+            meeting.room_name,
+            enable_recording=True,
+            user_id=user_id,
+        )
+        meeting = meeting.model_copy()
+        meeting.room_url = add_query_param(meeting.room_url, "t", token)
+        if meeting.host_room_url:
+            meeting.host_room_url = add_query_param(meeting.host_room_url, "t", token)
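The hunk above appends the Daily meeting token to both room URLs via `add_query_param`. A plausible stdlib sketch of such a helper (the actual `reflector.utils.url` implementation may differ):

```python
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit


def add_query_param(url: str, key: str, value: str) -> str:
    # Split the URL, extend its query string, and reassemble it,
    # preserving any parameters already present.
    parts = urlsplit(url)
    query = parse_qsl(parts.query, keep_blank_values=True)
    query.append((key, value))
    return urlunsplit(parts._replace(query=urlencode(query)))
```

Going through `urlsplit`/`urlencode` rather than string concatenation handles URLs that already carry a query string and percent-encodes the token safely.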
     if user_id != room.user_id:
         meeting.host_room_url = ""
@@ -347,17 +368,16 @@ async def rooms_create_meeting(
 @router.post("/rooms/{room_id}/webhook/test", response_model=WebhookTestResult)
 async def rooms_test_webhook(
     room_id: str,
-    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
+    user: Annotated[auth.UserInfo, Depends(auth.current_user)],
 ):
     """Test webhook configuration by sending a sample payload."""
-    user_id = user["sub"] if user else None
-    room = await rooms_controller.get_by_id(session, room_id)
+    user_id = user["sub"]
+    room = await rooms_controller.get_by_id(room_id)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")
-    if user_id and room.user_id != user_id:
+    if room.user_id != user_id:
         raise HTTPException(
             status_code=403, detail="Not authorized to test this room's webhook"
         )
@@ -370,10 +390,9 @@ async def rooms_test_webhook(
 async def rooms_sync_ics(
     room_name: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ):
     user_id = user["sub"] if user else None
-    room = await rooms_controller.get_by_name(session, room_name)
+    room = await rooms_controller.get_by_name(room_name)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")
@@ -386,7 +405,7 @@ async def rooms_sync_ics(
     if not room.ics_enabled or not room.ics_url:
         raise HTTPException(status_code=400, detail="ICS not configured for this room")

-    result = await ics_sync_service.sync_room_calendar(session, room)
+    result = await ics_sync_service.sync_room_calendar(room)

     if result["status"] == "error":
         raise HTTPException(
@@ -400,10 +419,9 @@ async def rooms_sync_ics(
 async def rooms_ics_status(
     room_name: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ):
     user_id = user["sub"] if user else None
-    room = await rooms_controller.get_by_name(session, room_name)
+    room = await rooms_controller.get_by_name(room_name)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")
@@ -418,7 +436,7 @@ async def rooms_ics_status(
         next_sync = room.ics_last_sync + timedelta(seconds=room.ics_fetch_interval)

     events = await calendar_events_controller.get_by_room(
-        session, room.id, include_deleted=False
+        room.id, include_deleted=False
     )

     return ICSStatus(
@@ -434,16 +452,15 @@ async def rooms_ics_status(
 async def rooms_list_meetings(
     room_name: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ):
     user_id = user["sub"] if user else None
-    room = await rooms_controller.get_by_name(session, room_name)
+    room = await rooms_controller.get_by_name(room_name)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")

     events = await calendar_events_controller.get_by_room(
-        session, room.id, include_deleted=False
+        room.id, include_deleted=False
     )

     if user_id != room.user_id:
@@ -461,16 +478,15 @@ async def rooms_list_upcoming_meetings(
     room_name: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
     minutes_ahead: int = 120,
-    session: AsyncSession = Depends(get_session),
 ):
     user_id = user["sub"] if user else None
-    room = await rooms_controller.get_by_name(session, room_name)
+    room = await rooms_controller.get_by_name(room_name)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")

     events = await calendar_events_controller.get_upcoming(
-        session, room.id, minutes_ahead=minutes_ahead
+        room.id, minutes_ahead=minutes_ahead
     )

     if user_id != room.user_id:
@@ -485,20 +501,22 @@ async def rooms_list_upcoming_meetings(
 async def rooms_list_active_meetings(
     room_name: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ):
     user_id = user["sub"] if user else None
-    room = await rooms_controller.get_by_name(session, room_name)
+    room = await rooms_controller.get_by_name(room_name)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")

     current_time = datetime.now(timezone.utc)
     meetings = await meetings_controller.get_all_active_for_room(
-        session, room=room, current_time=current_time
+        room=room, current_time=current_time
     )

-    # Hide host URLs from non-owners
+    effective_platform = get_platform(room.platform)
+    for meeting in meetings:
+        meeting.platform = effective_platform
+
     if user_id != room.user_id:
         for meeting in meetings:
             meeting.host_room_url = ""
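Several handlers above normalize `room.platform` through `get_platform` before returning it, so rooms created before the multi-platform migration (with no stored platform) still report an effective one. A hedged sketch of that normalization (the real factory lives in `reflector.video_platforms.factory`; the `"whereby"` default here is an assumption for illustration):

```python
from typing import Optional

# Assumed fallback for legacy rows; the real default likely comes from settings.
DEFAULT_PLATFORM = "whereby"
KNOWN_PLATFORMS = ("whereby", "daily")


def get_platform(stored: Optional[str]) -> str:
    # Map a missing or unrecognized stored value to the default platform,
    # and pass known values through unchanged.
    if stored in KNOWN_PLATFORMS:
        return stored
    return DEFAULT_PLATFORM
```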
@@ -511,24 +529,18 @@ async def rooms_get_meeting(
     room_name: str,
     meeting_id: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ):
     """Get a single meeting by ID within a specific room."""
     user_id = user["sub"] if user else None
-    room = await rooms_controller.get_by_name(session, room_name)
+    room = await rooms_controller.get_by_name(room_name)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")

-    meeting = await meetings_controller.get_by_id(session, meeting_id)
+    meeting = await meetings_controller.get_by_id(meeting_id, room=room)
     if not meeting:
         raise HTTPException(status_code=404, detail="Meeting not found")

-    if meeting.room_id != room.id:
-        raise HTTPException(
-            status_code=403, detail="Meeting does not belong to this room"
-        )
-
     if user_id != room.user_id and not room.is_shared:
         meeting.host_room_url = ""
@@ -540,24 +552,18 @@ async def rooms_join_meeting(
     room_name: str,
     meeting_id: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ):
     user_id = user["sub"] if user else None
-    room = await rooms_controller.get_by_name(session, room_name)
+    room = await rooms_controller.get_by_name(room_name)
     if not room:
         raise HTTPException(status_code=404, detail="Room not found")

-    meeting = await meetings_controller.get_by_id(session, meeting_id)
+    meeting = await meetings_controller.get_by_id(meeting_id, room=room)
     if not meeting:
         raise HTTPException(status_code=404, detail="Meeting not found")

-    if meeting.room_id != room.id:
-        raise HTTPException(
-            status_code=403, detail="Meeting does not belong to this room"
-        )
-
     if not meeting.is_active:
         raise HTTPException(status_code=400, detail="Meeting is not active")
@@ -565,7 +571,6 @@ async def rooms_join_meeting(
     if meeting.end_date <= current_time:
         raise HTTPException(status_code=400, detail="Meeting has ended")

-    # Hide host URL from non-owners
     if user_id != room.user_id:
         meeting.host_room_url = ""


@@ -3,15 +3,12 @@ from typing import Annotated, Literal, Optional

 from fastapi import APIRouter, Depends, HTTPException, Query
 from fastapi_pagination import Page
-from fastapi_pagination.ext.sqlalchemy import paginate
+from fastapi_pagination.ext.databases import apaginate
 from jose import jwt
-from pydantic import BaseModel, Field, constr, field_serializer
-from sqlalchemy.ext.asyncio import AsyncSession
+from pydantic import AwareDatetime, BaseModel, Field, constr, field_serializer

 import reflector.auth as auth
-from reflector.db import get_session
-from reflector.db.meetings import meetings_controller
-from reflector.db.rooms import rooms_controller
+from reflector.db import get_database
 from reflector.db.search import (
     DEFAULT_SEARCH_LIMIT,
     SearchLimit,
@@ -35,6 +32,7 @@ from reflector.db.transcripts import (
 from reflector.processors.types import Transcript as ProcessorTranscript
 from reflector.processors.types import Word
 from reflector.settings import settings
+from reflector.ws_manager import get_ws_manager
 from reflector.zulip import (
     InvalidMessageError,
     get_zulip_message,
@@ -135,6 +133,21 @@ SearchOffsetParam = Annotated[
     SearchOffsetBase, Query(description="Number of results to skip")
 ]

+SearchFromDatetimeParam = Annotated[
+    AwareDatetime | None,
+    Query(
+        alias="from",
+        description="Filter transcripts created on or after this datetime (ISO 8601 with timezone)",
+    ),
+]
+SearchToDatetimeParam = Annotated[
+    AwareDatetime | None,
+    Query(
+        alias="to",
+        description="Filter transcripts created on or before this datetime (ISO 8601 with timezone)",
+    ),
+]
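The new `from`/`to` filters are typed `AwareDatetime`, so pydantic rejects naive timestamps at validation time instead of letting them reach the query. The same guard in plain Python, for illustration:

```python
from datetime import datetime


def parse_aware(value: str) -> datetime:
    # Mirror pydantic's AwareDatetime: accept ISO 8601 input only when
    # it carries an explicit UTC offset.
    dt = datetime.fromisoformat(value)
    if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
        raise ValueError("datetime must include a timezone offset")
    return dt
```

Rejecting naive datetimes up front avoids ambiguous comparisons against the timezone-aware `created_at` column.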
 class SearchResponse(BaseModel):
     results: list[SearchResult]
@@ -150,25 +163,24 @@ async def transcripts_list(
     source_kind: SourceKind | None = None,
     room_id: str | None = None,
     search_term: str | None = None,
-    session: AsyncSession = Depends(get_session),
 ):
     if not user and not settings.PUBLIC_MODE:
         raise HTTPException(status_code=401, detail="Not authenticated")
     user_id = user["sub"] if user else None
-    query = await transcripts_controller.get_all(
-        session,
-        user_id=user_id,
-        source_kind=SourceKind(source_kind) if source_kind else None,
-        room_id=room_id,
-        search_term=search_term,
-        order_by="-created_at",
-        return_query=True,
-    )
-    return await paginate(session, query)
+    return await apaginate(
+        get_database(),
+        await transcripts_controller.get_all(
+            user_id=user_id,
+            source_kind=SourceKind(source_kind) if source_kind else None,
+            room_id=room_id,
+            search_term=search_term,
+            order_by="-created_at",
+            return_query=True,
+        ),
+    )
 @router.get("/transcripts/search", response_model=SearchResponse)
 async def transcripts_search(
@@ -177,19 +189,23 @@ async def transcripts_search(
     offset: SearchOffsetParam = 0,
     room_id: Optional[str] = None,
     source_kind: Optional[SourceKind] = None,
+    from_datetime: SearchFromDatetimeParam = None,
+    to_datetime: SearchToDatetimeParam = None,
     user: Annotated[
         Optional[auth.UserInfo], Depends(auth.current_user_optional)
     ] = None,
-    session: AsyncSession = Depends(get_session),
 ):
-    """
-    Full-text search across transcript titles and content.
-    """
+    """Full-text search across transcript titles and content."""
     if not user and not settings.PUBLIC_MODE:
         raise HTTPException(status_code=401, detail="Not authenticated")
     user_id = user["sub"] if user else None
+    if from_datetime and to_datetime and from_datetime > to_datetime:
+        raise HTTPException(
+            status_code=400, detail="'from' must be less than or equal to 'to'"
+        )
     search_params = SearchParameters(
         query_text=parse_search_query_param(q),
         limit=limit,
@@ -197,9 +213,11 @@ async def transcripts_search(
         user_id=user_id,
         room_id=room_id,
         source_kind=source_kind,
+        from_datetime=from_datetime,
+        to_datetime=to_datetime,
     )
-    results, total = await search_controller.search_transcripts(session, search_params)
+    results, total = await search_controller.search_transcripts(search_params)

     return SearchResponse(
         results=results,
@@ -214,11 +232,9 @@ async def transcripts_search(
 async def transcripts_create(
     info: CreateTranscript,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ):
     user_id = user["sub"] if user else None
-    return await transcripts_controller.add(
-        session,
+    transcript = await transcripts_controller.add(
         info.name,
         source_kind=info.source_kind or SourceKind.LIVE,
         source_language=info.source_language,
@@ -226,6 +242,14 @@ async def transcripts_create(
         user_id=user_id,
     )
+
+    if user_id:
+        await get_ws_manager().send_json(
+            room_id=f"user:{user_id}",
+            message={"event": "TRANSCRIPT_CREATED", "data": {"id": transcript.id}},
+        )
+    return transcript
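`get_ws_manager().send_json` fans the `TRANSCRIPT_CREATED` event out to every socket subscribed to the per-user room. A minimal in-memory sketch of that pattern (`WsManager` and its methods here are illustrative; the production manager in `reflector.ws_manager` may be backed by a broker rather than local queues):

```python
import asyncio
from collections import defaultdict


class WsManager:
    def __init__(self) -> None:
        # room_id -> one queue per connected websocket
        self.rooms: dict[str, list[asyncio.Queue]] = defaultdict(list)

    def subscribe(self, room_id: str) -> asyncio.Queue:
        q: asyncio.Queue = asyncio.Queue()
        self.rooms[room_id].append(q)
        return q

    async def send_json(self, room_id: str, message: dict) -> None:
        # Deliver the event to every subscriber of the room; rooms with
        # no subscribers are a no-op.
        for q in self.rooms.get(room_id, []):
            await q.put(message)


async def demo() -> dict:
    mgr = WsManager()
    q = mgr.subscribe("user:42")
    await mgr.send_json(
        room_id="user:42",
        message={"event": "TRANSCRIPT_CREATED", "data": {"id": "t1"}},
    )
    return await q.get()
```

Keying the room on `user:{user_id}` means every open tab for that user receives the event, which is what lets transcript lists refresh live.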
 # ==============================================================
 # Single transcript
@@ -338,11 +362,10 @@ class GetTranscriptTopicWithWordsPerSpeaker(GetTranscriptTopic):
 async def transcript_get(
     transcript_id: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ):
     user_id = user["sub"] if user else None
     return await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )
@@ -350,38 +373,36 @@ async def transcript_get(
 async def transcript_update(
     transcript_id: str,
     info: UpdateTranscript,
-    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
+    user: Annotated[auth.UserInfo, Depends(auth.current_user)],
 ):
-    user_id = user["sub"] if user else None
+    user_id = user["sub"]
     transcript = await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )
+    if not transcripts_controller.user_can_mutate(transcript, user_id):
+        raise HTTPException(status_code=403, detail="Not authorized")
     values = info.dict(exclude_unset=True)
-    updated_transcript = await transcripts_controller.update(
-        session, transcript, values
-    )
+    updated_transcript = await transcripts_controller.update(transcript, values)
     return updated_transcript
 @router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus)
 async def transcript_delete(
     transcript_id: str,
-    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
+    user: Annotated[auth.UserInfo, Depends(auth.current_user)],
 ):
-    user_id = user["sub"] if user else None
-    transcript = await transcripts_controller.get_by_id(session, transcript_id)
+    user_id = user["sub"]
+    transcript = await transcripts_controller.get_by_id(transcript_id)
     if not transcript:
         raise HTTPException(status_code=404, detail="Transcript not found")
+    if not transcripts_controller.user_can_mutate(transcript, user_id):
+        raise HTTPException(status_code=403, detail="Not authorized")

-    if transcript.meeting_id:
-        meeting = await meetings_controller.get_by_id(session, transcript.meeting_id)
-        room = await rooms_controller.get_by_id(session, meeting.room_id)
-        if room.is_shared:
-            user_id = None
-    await transcripts_controller.remove_by_id(session, transcript.id, user_id=user_id)
+    await transcripts_controller.remove_by_id(transcript.id, user_id=user_id)
+    await get_ws_manager().send_json(
+        room_id=f"user:{user_id}",
+        message={"event": "TRANSCRIPT_DELETED", "data": {"id": transcript.id}},
+    )
     return DeletionStatus(status="ok")
@@ -392,11 +413,10 @@ async def transcript_delete(
 async def transcript_get_topics(
     transcript_id: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ):
     user_id = user["sub"] if user else None
     transcript = await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )

     # convert to GetTranscriptTopic
@@ -412,11 +432,10 @@ async def transcript_get_topics(
 async def transcript_get_topics_with_words(
     transcript_id: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ):
     user_id = user["sub"] if user else None
     transcript = await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )

     # convert to GetTranscriptTopicWithWords
@@ -434,11 +453,10 @@ async def transcript_get_topics_with_words_per_speaker(
     transcript_id: str,
     topic_id: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ):
     user_id = user["sub"] if user else None
     transcript = await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )

     # get the topic from the transcript
@@ -456,16 +474,16 @@ async def transcript_post_to_zulip(
     stream: str,
     topic: str,
     include_topics: bool,
-    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
+    user: Annotated[auth.UserInfo, Depends(auth.current_user)],
 ):
-    user_id = user["sub"] if user else None
+    user_id = user["sub"]
     transcript = await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )
     if not transcript:
         raise HTTPException(status_code=404, detail="Transcript not found")
+    if not transcripts_controller.user_can_mutate(transcript, user_id):
+        raise HTTPException(status_code=403, detail="Not authorized")

     content = get_zulip_message(transcript, include_topics)
     message_updated = False
@@ -481,5 +499,5 @@ async def transcript_post_to_zulip(
     if not message_updated:
         response = await send_message_to_zulip(stream, topic, content)
         await transcripts_controller.update(
-            session, transcript, {"zulip_message_id": response["id"]}
+            transcript, {"zulip_message_id": response["id"]}
         )


@@ -9,10 +9,8 @@ from typing import Annotated, Optional
 import httpx
 from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
 from jose import jwt
-from sqlalchemy.ext.asyncio import AsyncSession

 import reflector.auth as auth
-from reflector.db import get_session
 from reflector.db.transcripts import AudioWaveform, transcripts_controller
 from reflector.settings import settings
 from reflector.views.transcripts import ALGORITHM
@@ -34,7 +32,6 @@ async def transcript_get_audio_mp3(
     request: Request,
     transcript_id: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
     token: str | None = None,
 ):
     user_id = user["sub"] if user else None
@@ -51,7 +48,7 @@ async def transcript_get_audio_mp3(
             raise unauthorized_exception
     transcript = await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )
     if transcript.audio_location == "storage":
@@ -89,7 +86,7 @@ async def transcript_get_audio_mp3(
     return range_requests_response(
         request,
-        transcript.audio_mp3_filename.as_posix(),
+        transcript.audio_mp3_filename,
         content_type="audio/mpeg",
         content_disposition=f"attachment; filename=(unknown)",
     )
@@ -99,18 +96,13 @@ async def transcript_get_audio_mp3(
 async def transcript_get_audio_waveform(
     transcript_id: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ) -> AudioWaveform:
     user_id = user["sub"] if user else None
     transcript = await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )
     if not transcript.audio_waveform_filename.exists():
         raise HTTPException(status_code=404, detail="Audio not found")
-    audio_waveform = transcript.audio_waveform
-    if not audio_waveform:
-        raise HTTPException(status_code=404, detail="Audio waveform not found")
-    return audio_waveform
+    return transcript.audio_waveform


@@ -8,10 +8,8 @@ from typing import Annotated, Optional
 from fastapi import APIRouter, Depends, HTTPException
 from pydantic import BaseModel, ConfigDict, Field
-from sqlalchemy.ext.asyncio import AsyncSession

 import reflector.auth as auth
-from reflector.db import get_session
 from reflector.db.transcripts import TranscriptParticipant, transcripts_controller
 from reflector.views.types import DeletionStatus
@@ -39,11 +37,10 @@ class UpdateParticipant(BaseModel):
 async def transcript_get_participants(
     transcript_id: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ) -> list[Participant]:
     user_id = user["sub"] if user else None
     transcript = await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )
     if transcript.participants is None:
@@ -59,13 +56,14 @@ async def transcript_get_participants(
 async def transcript_add_participant(
     transcript_id: str,
     participant: CreateParticipant,
-    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
+    user: Annotated[auth.UserInfo, Depends(auth.current_user)],
 ) -> Participant:
-    user_id = user["sub"] if user else None
+    user_id = user["sub"]
     transcript = await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )
+    if transcript.user_id is not None and transcript.user_id != user_id:
+        raise HTTPException(status_code=403, detail="Not authorized")

     # ensure the speaker is unique
     if participant.speaker is not None and transcript.participants is not None:
@@ -77,7 +75,7 @@ async def transcript_add_participant(
         )
     obj = await transcripts_controller.upsert_participant(
-        session, transcript, TranscriptParticipant(**participant.dict())
+        transcript, TranscriptParticipant(**participant.dict())
     )
     return Participant.model_validate(obj)
@@ -87,11 +85,10 @@ async def transcript_get_participant(
     transcript_id: str,
     participant_id: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ) -> Participant:
     user_id = user["sub"] if user else None
     transcript = await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )
     for p in transcript.participants:
@@ -106,13 +103,14 @@ async def transcript_update_participant(
     transcript_id: str,
     participant_id: str,
     participant: UpdateParticipant,
-    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
+    user: Annotated[auth.UserInfo, Depends(auth.current_user)],
 ) -> Participant:
-    user_id = user["sub"] if user else None
+    user_id = user["sub"]
     transcript = await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )
+    if transcript.user_id is not None and transcript.user_id != user_id:
+        raise HTTPException(status_code=403, detail="Not authorized")

     # ensure the speaker is unique
     for p in transcript.participants:
@@ -136,7 +134,7 @@ async def transcript_update_participant(
     fields = participant.dict(exclude_unset=True)
     obj = obj.copy(update=fields)
-    await transcripts_controller.upsert_participant(session, transcript, obj)
+    await transcripts_controller.upsert_participant(transcript, obj)
     return Participant.model_validate(obj)
@@ -144,12 +142,13 @@ async def transcript_update_participant(
 async def transcript_delete_participant(
     transcript_id: str,
     participant_id: str,
-    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
+    user: Annotated[auth.UserInfo, Depends(auth.current_user)],
 ) -> DeletionStatus:
-    user_id = user["sub"] if user else None
+    user_id = user["sub"]
     transcript = await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )
-    await transcripts_controller.delete_participant(session, transcript, participant_id)
+    if transcript.user_id is not None and transcript.user_id != user_id:
+        raise HTTPException(status_code=403, detail="Not authorized")
+    await transcripts_controller.delete_participant(transcript, participant_id)
     return DeletionStatus(status="ok")
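The mutation endpoints above all gain the same ownership guard: anonymous transcripts (no `user_id`) remain mutable by anyone, while owned transcripts return 403 to anyone but their owner. A minimal standalone sketch of that predicate (the helper name and signature are illustrative, not the real controller API):

```python
from typing import Optional


def user_can_mutate(transcript_user_id: Optional[str], requester_id: Optional[str]) -> bool:
    """Anonymous transcripts (no owner) are mutable by anyone;
    owned transcripts only by their owner — mirrors the 403 guard above."""
    if transcript_user_id is None:
        return True
    return transcript_user_id == requester_id


print(user_can_mutate(None, None))        # anonymous transcript: allowed
print(user_can_mutate("alice", "alice"))  # owner: allowed
print(user_can_mutate("alice", "bob"))    # non-owner: denied
```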


@@ -3,12 +3,14 @@ from typing import Annotated, Optional
 import celery
 from fastapi import APIRouter, Depends, HTTPException
 from pydantic import BaseModel
-from sqlalchemy.ext.asyncio import AsyncSession

 import reflector.auth as auth
-from reflector.db import get_session
+from reflector.db.recordings import recordings_controller
 from reflector.db.transcripts import transcripts_controller
 from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
+from reflector.pipelines.main_multitrack_pipeline import (
+    task_pipeline_multitrack_process,
+)

 router = APIRouter()
@@ -21,11 +23,10 @@ class ProcessStatus(BaseModel):
 async def transcript_process(
     transcript_id: str,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ):
     user_id = user["sub"] if user else None
     transcript = await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )
     if transcript.locked:
@@ -36,14 +37,35 @@ async def transcript_process(
             status_code=400, detail="Recording is not ready for processing"
         )

+    # avoid duplicate scheduling for either pipeline
     if task_is_scheduled_or_active(
         "reflector.pipelines.main_file_pipeline.task_pipeline_file_process",
         transcript_id=transcript_id,
+    ) or task_is_scheduled_or_active(
+        "reflector.pipelines.main_multitrack_pipeline.task_pipeline_multitrack_process",
+        transcript_id=transcript_id,
     ):
         return ProcessStatus(status="already running")

-    # schedule a background task process the file
-    task_pipeline_file_process.delay(transcript_id=transcript_id)
+    # Determine processing mode strictly from DB to avoid S3 scans
+    bucket_name = None
+    track_keys: list[str] = []
+    if transcript.recording_id:
+        recording = await recordings_controller.get_by_id(transcript.recording_id)
+        if recording:
+            bucket_name = recording.bucket_name
+            track_keys = list(getattr(recording, "track_keys", []) or [])
+    if bucket_name:
+        task_pipeline_multitrack_process.delay(
+            transcript_id=transcript_id,
+            bucket_name=bucket_name,
+            track_keys=track_keys,
+        )
+    else:
+        # Default single-file pipeline
+        task_pipeline_file_process.delay(transcript_id=transcript_id)

     return ProcessStatus(status="ok")
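The dispatch added here selects the pipeline from database state alone: a recording row with a `bucket_name` routes to the multitrack pipeline, anything else falls back to the single-file pipeline. The decision can be sketched in isolation (the dataclass is an illustrative stand-in, not the real DB model):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Recording:
    """Illustrative stand-in for the recordings DB row."""
    bucket_name: Optional[str] = None
    track_keys: list[str] = field(default_factory=list)


def choose_pipeline(recording: Optional[Recording]) -> str:
    # mirrors the branch in transcript_process: DB state only, no S3 scan
    if recording and recording.bucket_name:
        return "multitrack"
    return "file"


print(choose_pipeline(Recording(bucket_name="daily-recordings", track_keys=["a.webm"])))
print(choose_pipeline(Recording()))  # recording exists but no bucket -> single-file
print(choose_pipeline(None))         # no recording row at all -> single-file
```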


@@ -8,10 +8,8 @@ from typing import Annotated, Optional
 from fastapi import APIRouter, Depends, HTTPException
 from pydantic import BaseModel, Field
-from sqlalchemy.ext.asyncio import AsyncSession

 import reflector.auth as auth
-from reflector.db import get_session
 from reflector.db.transcripts import transcripts_controller

 router = APIRouter()
@@ -37,13 +35,14 @@ class SpeakerMerge(BaseModel):
 async def transcript_assign_speaker(
     transcript_id: str,
     assignment: SpeakerAssignment,
-    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
+    user: Annotated[auth.UserInfo, Depends(auth.current_user)],
 ) -> SpeakerAssignmentStatus:
-    user_id = user["sub"] if user else None
+    user_id = user["sub"]
     transcript = await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )
+    if transcript.user_id is not None and transcript.user_id != user_id:
+        raise HTTPException(status_code=403, detail="Not authorized")
     if not transcript:
         raise HTTPException(status_code=404, detail="Transcript not found")
@@ -82,9 +81,7 @@ async def transcript_assign_speaker(
         # if the participant does not have a speaker, create one
         if participant.speaker is None:
             participant.speaker = transcript.find_empty_speaker()
-            await transcripts_controller.upsert_participant(
-                session, transcript, participant
-            )
+            await transcripts_controller.upsert_participant(transcript, participant)

         speaker = participant.speaker
@@ -105,7 +102,6 @@ async def transcript_assign_speaker(
     for topic in changed_topics:
         transcript.upsert_topic(topic)
     await transcripts_controller.update(
-        session,
         transcript,
         {
             "topics": transcript.topics_dump(),
@@ -119,13 +115,14 @@ async def transcript_assign_speaker(
 async def transcript_merge_speaker(
     transcript_id: str,
     merge: SpeakerMerge,
-    user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
+    user: Annotated[auth.UserInfo, Depends(auth.current_user)],
 ) -> SpeakerAssignmentStatus:
-    user_id = user["sub"] if user else None
+    user_id = user["sub"]
     transcript = await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )
+    if transcript.user_id is not None and transcript.user_id != user_id:
+        raise HTTPException(status_code=403, detail="Not authorized")
     if not transcript:
         raise HTTPException(status_code=404, detail="Transcript not found")
@@ -170,7 +167,6 @@ async def transcript_merge_speaker(
     for topic in changed_topics:
         transcript.upsert_topic(topic)
     await transcripts_controller.update(
-        session,
         transcript,
         {
             "topics": transcript.topics_dump(),


@@ -3,10 +3,8 @@ from typing import Annotated, Optional
 import av
 from fastapi import APIRouter, Depends, HTTPException, UploadFile
 from pydantic import BaseModel
-from sqlalchemy.ext.asyncio import AsyncSession

 import reflector.auth as auth
-from reflector.db import get_session
 from reflector.db.transcripts import transcripts_controller
 from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
@@ -24,11 +22,10 @@ async def transcript_record_upload(
     total_chunks: int,
     chunk: UploadFile,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ):
     user_id = user["sub"] if user else None
     transcript = await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )
     if transcript.locked:
@@ -92,7 +89,7 @@ async def transcript_record_upload(
         container.close()

     # set the status to "uploaded"
-    await transcripts_controller.update(session, transcript, {"status": "uploaded"})
+    await transcripts_controller.update(transcript, {"status": "uploaded"})

     # launch a background task to process the file
     task_pipeline_file_process.delay(transcript_id=transcript_id)


@@ -1,10 +1,8 @@
 from typing import Annotated, Optional

 from fastapi import APIRouter, Depends, HTTPException, Request
-from sqlalchemy.ext.asyncio import AsyncSession

 import reflector.auth as auth
-from reflector.db import get_session
 from reflector.db.transcripts import transcripts_controller

 from .rtc_offer import RtcOffer, rtc_offer_base
@@ -18,11 +16,10 @@ async def transcript_record_webrtc(
     params: RtcOffer,
     request: Request,
     user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
-    session: AsyncSession = Depends(get_session),
 ):
     user_id = user["sub"] if user else None
     transcript = await transcripts_controller.get_by_id_for_http(
-        session, transcript_id, user_id=user_id
+        transcript_id, user_id=user_id
     )
     if transcript.locked:


@@ -4,8 +4,11 @@ Transcripts websocket API
 """

-from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
+from typing import Optional
+
+from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
+
+import reflector.auth as auth
 from reflector.db.transcripts import transcripts_controller
 from reflector.ws_manager import get_ws_manager
@@ -21,10 +24,12 @@ async def transcript_get_websocket_events(transcript_id: str):
 async def transcript_events_websocket(
     transcript_id: str,
     websocket: WebSocket,
-    # user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
+    user: Optional[auth.UserInfo] = Depends(auth.current_user_optional),
 ):
-    # user_id = user["sub"] if user else None
-    transcript = await transcripts_controller.get_by_id(session, transcript_id)
+    user_id = user["sub"] if user else None
+    transcript = await transcripts_controller.get_by_id_for_http(
+        transcript_id, user_id=user_id
+    )
     if not transcript:
         raise HTTPException(status_code=404, detail="Transcript not found")


@@ -11,7 +11,6 @@ router = APIRouter()
 class UserInfo(BaseModel):
     sub: str
     email: Optional[str]
-    email_verified: Optional[bool]


 @router.get("/me")


@@ -0,0 +1,62 @@
+from datetime import datetime
+from typing import Annotated
+
+import structlog
+from fastapi import APIRouter, Depends, HTTPException
+from pydantic import BaseModel
+
+import reflector.auth as auth
+from reflector.db.user_api_keys import user_api_keys_controller
+from reflector.utils.string import NonEmptyString
+
+router = APIRouter()
+logger = structlog.get_logger(__name__)
+
+
+class CreateApiKeyRequest(BaseModel):
+    name: NonEmptyString | None = None
+
+
+class ApiKeyResponse(BaseModel):
+    id: NonEmptyString
+    user_id: NonEmptyString
+    name: NonEmptyString | None
+    created_at: datetime
+
+
+class CreateApiKeyResponse(ApiKeyResponse):
+    key: NonEmptyString
+
+
+@router.post("/user/api-keys", response_model=CreateApiKeyResponse)
+async def create_api_key(
+    req: CreateApiKeyRequest,
+    user: Annotated[auth.UserInfo, Depends(auth.current_user)],
+):
+    api_key_model, plaintext = await user_api_keys_controller.create_key(
+        user_id=user["sub"],
+        name=req.name,
+    )
+    return CreateApiKeyResponse(
+        **api_key_model.model_dump(),
+        key=plaintext,
+    )
+
+
+@router.get("/user/api-keys", response_model=list[ApiKeyResponse])
+async def list_api_keys(
+    user: Annotated[auth.UserInfo, Depends(auth.current_user)],
+):
+    api_keys = await user_api_keys_controller.list_by_user_id(user["sub"])
+    return [ApiKeyResponse(**k.model_dump()) for k in api_keys]
+
+
+@router.delete("/user/api-keys/{key_id}")
+async def delete_api_key(
+    key_id: NonEmptyString,
+    user: Annotated[auth.UserInfo, Depends(auth.current_user)],
+):
+    deleted = await user_api_keys_controller.delete_key(key_id, user["sub"])
+    if not deleted:
+        raise HTTPException(status_code=404)
+    return {"status": "ok"}
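Note that the plaintext key is returned only once, from the create endpoint; the list response model has no `key` field, so `user_api_keys_controller` is presumably persisting only a digest. The controller itself is not shown in this diff; a common shape for such a create/verify pair looks like the following (purely an assumption about the hidden implementation — function names and the SHA-256 scheme are illustrative):

```python
import hashlib
import secrets


def generate_api_key() -> tuple[str, str]:
    """Return (plaintext, digest); only the digest would be persisted."""
    # assumption: the real controller may use a different token/hash scheme
    plaintext = secrets.token_urlsafe(32)
    digest = hashlib.sha256(plaintext.encode()).hexdigest()
    return plaintext, digest


def verify_api_key(presented: str, stored_digest: str) -> bool:
    """Constant-time comparison against the stored digest."""
    candidate = hashlib.sha256(presented.encode()).hexdigest()
    return secrets.compare_digest(candidate, stored_digest)


key, digest = generate_api_key()
print(verify_api_key(key, digest))       # correct key verifies
print(verify_api_key("wrong", digest))   # wrong key is rejected
```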


@@ -0,0 +1,53 @@
+from typing import Optional
+
+from fastapi import APIRouter, WebSocket
+
+from reflector.auth.auth_jwt import JWTAuth  # type: ignore
+from reflector.ws_manager import get_ws_manager
+
+router = APIRouter()
+
+# Close code for unauthorized WebSocket connections
+UNAUTHORISED = 4401
+
+
+@router.websocket("/events")
+async def user_events_websocket(websocket: WebSocket):
+    # Browser can't send Authorization header for WS; use subprotocol: ["bearer", token]
+    raw_subprotocol = websocket.headers.get("sec-websocket-protocol") or ""
+    parts = [p.strip() for p in raw_subprotocol.split(",") if p.strip()]
+    token: Optional[str] = None
+    negotiated_subprotocol: Optional[str] = None
+    if len(parts) >= 2 and parts[0].lower() == "bearer":
+        negotiated_subprotocol = "bearer"
+        token = parts[1]
+
+    user_id: Optional[str] = None
+    if not token:
+        await websocket.close(code=UNAUTHORISED)
+        return
+    try:
+        payload = JWTAuth().verify_token(token)
+        user_id = payload.get("sub")
+    except Exception:
+        await websocket.close(code=UNAUTHORISED)
+        return
+    if not user_id:
+        await websocket.close(code=UNAUTHORISED)
+        return
+
+    room_id = f"user:{user_id}"
+    ws_manager = get_ws_manager()
+    await ws_manager.add_user_to_room(
+        room_id, websocket, subprotocol=negotiated_subprotocol
+    )
+    try:
+        while True:
+            await websocket.receive()
+    finally:
+        if room_id:
+            await ws_manager.remove_user_from_room(room_id, websocket)
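Because browsers cannot attach an `Authorization` header to a WebSocket handshake, the handler above smuggles the token through the `Sec-WebSocket-Protocol` header as `["bearer", token]`. That parsing step can be isolated and tested on its own (a sketch of the same logic, outside FastAPI):

```python
from typing import Optional, Tuple


def parse_bearer_subprotocol(header: Optional[str]) -> Tuple[Optional[str], Optional[str]]:
    """Extract (negotiated_subprotocol, token) from a Sec-WebSocket-Protocol value.

    The client sends 'bearer, <token>'; anything else yields (None, None)
    and the connection should be closed with the unauthorized code.
    """
    parts = [p.strip() for p in (header or "").split(",") if p.strip()]
    if len(parts) >= 2 and parts[0].lower() == "bearer":
        return "bearer", parts[1]
    return None, None


print(parse_bearer_subprotocol("bearer, eyJhbGciOi"))  # ('bearer', 'eyJhbGciOi')
print(parse_bearer_subprotocol("graphql-ws"))          # (None, None)
print(parse_bearer_subprotocol(None))                  # (None, None)
```

Echoing `"bearer"` back as the negotiated subprotocol matters: some browsers drop the connection if the server accepts without selecting one of the offered subprotocols.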


@@ -1,114 +0,0 @@
-import logging
-from datetime import datetime
-
-import httpx
-
-from reflector.db.rooms import Room
-from reflector.settings import settings
-from reflector.utils.string import parse_non_empty_string
-
-logger = logging.getLogger(__name__)
-
-
-def _get_headers():
-    api_key = parse_non_empty_string(
-        settings.WHEREBY_API_KEY, "WHEREBY_API_KEY value is required."
-    )
-    return {
-        "Content-Type": "application/json; charset=utf-8",
-        "Authorization": f"Bearer {api_key}",
-    }
-
-
-TIMEOUT = 10  # seconds
-
-
-def _get_whereby_s3_auth():
-    errors = []
-    try:
-        bucket_name = parse_non_empty_string(
-            settings.RECORDING_STORAGE_AWS_BUCKET_NAME,
-            "RECORDING_STORAGE_AWS_BUCKET_NAME value is required.",
-        )
-    except Exception as e:
-        errors.append(e)
-    try:
-        key_id = parse_non_empty_string(
-            settings.AWS_WHEREBY_ACCESS_KEY_ID,
-            "AWS_WHEREBY_ACCESS_KEY_ID value is required.",
-        )
-    except Exception as e:
-        errors.append(e)
-    try:
-        key_secret = parse_non_empty_string(
-            settings.AWS_WHEREBY_ACCESS_KEY_SECRET,
-            "AWS_WHEREBY_ACCESS_KEY_SECRET value is required.",
-        )
-    except Exception as e:
-        errors.append(e)
-    if len(errors) > 0:
-        raise Exception(
-            f"Failed to get Whereby auth settings: {', '.join(str(e) for e in errors)}"
-        )
-    return bucket_name, key_id, key_secret
-
-
-async def create_meeting(room_name_prefix: str, end_date: datetime, room: Room):
-    s3_bucket_name, s3_key_id, s3_key_secret = _get_whereby_s3_auth()
-    data = {
-        "isLocked": room.is_locked,
-        "roomNamePrefix": room_name_prefix,
-        "roomNamePattern": "uuid",
-        "roomMode": room.room_mode,
-        "endDate": end_date.isoformat(),
-        "recording": {
-            "type": room.recording_type,
-            "destination": {
-                "provider": "s3",
-                "bucket": s3_bucket_name,
-                "accessKeyId": s3_key_id,
-                "accessKeySecret": s3_key_secret,
-                "fileFormat": "mp4",
-            },
-            "startTrigger": room.recording_trigger,
-        },
-        "fields": ["hostRoomUrl"],
-    }
-    async with httpx.AsyncClient() as client:
-        response = await client.post(
-            f"{settings.WHEREBY_API_URL}/meetings",
-            headers=_get_headers(),
-            json=data,
-            timeout=TIMEOUT,
-        )
-        if response.status_code == 403:
-            logger.warning(
-                f"Failed to create meeting: access denied on Whereby: {response.text}"
-            )
-        response.raise_for_status()
-        return response.json()
-
-
-async def get_room_sessions(room_name: str):
-    async with httpx.AsyncClient() as client:
-        response = await client.get(
-            f"{settings.WHEREBY_API_URL}/insights/room-sessions?roomName={room_name}",
-            headers=_get_headers(),
-            timeout=TIMEOUT,
-        )
-        response.raise_for_status()
-        return response.json()
-
-
-async def upload_logo(room_name: str, logo_path: str):
-    async with httpx.AsyncClient() as client:
-        with open(logo_path, "rb") as f:
-            response = await client.put(
-                f"{settings.WHEREBY_API_URL}/rooms{room_name}/theme/logo",
-                headers={
-                    "Authorization": f"Bearer {settings.WHEREBY_API_KEY}",
-                },
-                timeout=TIMEOUT,
-                files={"image": f},
-            )
-            response.raise_for_status()


@@ -10,16 +10,16 @@ from typing import TypedDict
 import structlog
 from celery import shared_task
+from databases import Database
 from pydantic.types import PositiveInt
-from sqlalchemy import delete, select
-from sqlalchemy.ext.asyncio import AsyncSession

 from reflector.asynctask import asynctask
-from reflector.db.base import MeetingModel, RecordingModel, TranscriptModel
-from reflector.db.transcripts import transcripts_controller
+from reflector.db import get_database
+from reflector.db.meetings import meetings
+from reflector.db.recordings import recordings
+from reflector.db.transcripts import transcripts, transcripts_controller
 from reflector.settings import settings
-from reflector.storage import get_recordings_storage
-from reflector.worker.session_decorator import with_session
+from reflector.storage import get_transcripts_storage

 logger = structlog.get_logger(__name__)
@@ -34,49 +34,51 @@ class CleanupStats(TypedDict):
 async def delete_single_transcript(
-    session: AsyncSession, transcript_data: dict, stats: CleanupStats
+    db: Database, transcript_data: dict, stats: CleanupStats
 ):
     transcript_id = transcript_data["id"]
     meeting_id = transcript_data["meeting_id"]
     recording_id = transcript_data["recording_id"]

     try:
-        if meeting_id:
-            await session.execute(
-                delete(MeetingModel).where(MeetingModel.id == meeting_id)
-            )
-            stats["meetings_deleted"] += 1
-            logger.info("Deleted associated meeting", meeting_id=meeting_id)
-
-        if recording_id:
-            result = await session.execute(
-                select(RecordingModel).where(RecordingModel.id == recording_id)
-            )
-            recording = result.mappings().first()
-            if recording:
-                try:
-                    await get_recordings_storage().delete_file(recording["object_key"])
-                except Exception as storage_error:
-                    logger.warning(
-                        "Failed to delete recording from storage",
-                        recording_id=recording_id,
-                        object_key=recording["object_key"],
-                        error=str(storage_error),
-                    )
-            await session.execute(
-                delete(RecordingModel).where(RecordingModel.id == recording_id)
-            )
-            stats["recordings_deleted"] += 1
-            logger.info("Deleted associated recording", recording_id=recording_id)
-
-        await transcripts_controller.remove_by_id(session, transcript_id)
-        stats["transcripts_deleted"] += 1
-        logger.info(
-            "Deleted transcript",
-            transcript_id=transcript_id,
-            created_at=transcript_data["created_at"].isoformat(),
-        )
+        async with db.transaction(isolation="serializable"):
+            if meeting_id:
+                await db.execute(meetings.delete().where(meetings.c.id == meeting_id))
+                stats["meetings_deleted"] += 1
+                logger.info("Deleted associated meeting", meeting_id=meeting_id)
+
+            if recording_id:
+                recording = await db.fetch_one(
+                    recordings.select().where(recordings.c.id == recording_id)
+                )
+                if recording:
+                    try:
+                        await get_transcripts_storage().delete_file(
+                            recording["object_key"], bucket=recording["bucket_name"]
+                        )
+                    except Exception as storage_error:
+                        logger.warning(
+                            "Failed to delete recording from storage",
+                            recording_id=recording_id,
+                            object_key=recording["object_key"],
+                            error=str(storage_error),
+                        )
+                await db.execute(
+                    recordings.delete().where(recordings.c.id == recording_id)
+                )
+                stats["recordings_deleted"] += 1
+                logger.info(
+                    "Deleted associated recording", recording_id=recording_id
+                )
+
+            await transcripts_controller.remove_by_id(transcript_id)
+            stats["transcripts_deleted"] += 1
+            logger.info(
+                "Deleted transcript",
+                transcript_id=transcript_id,
+                created_at=transcript_data["created_at"].isoformat(),
+            )
     except Exception as e:
         error_msg = f"Failed to delete transcript {transcript_id}: {str(e)}"
         logger.error(error_msg, exc_info=e)
@@ -84,30 +86,18 @@ async def delete_single_transcript(
async def cleanup_old_transcripts(
    db: Database, cutoff_date: datetime, stats: CleanupStats
):
    """Delete old anonymous transcripts and their associated recordings/meetings."""
    query = transcripts.select().where(
        (transcripts.c.created_at < cutoff_date) & (transcripts.c.user_id.is_(None))
    )
    old_transcripts = await db.fetch_all(query)

    logger.info(f"Found {len(old_transcripts)} old transcripts to delete")

    for transcript_data in old_transcripts:
        await delete_single_transcript(db, transcript_data, stats)
def log_cleanup_results(stats: CleanupStats):
@@ -127,7 +117,6 @@ def log_cleanup_results(stats: CleanupStats):
async def cleanup_old_public_data(
    days: PositiveInt | None = None,
) -> CleanupStats | None:
    if days is None:
@@ -150,7 +139,8 @@ async def cleanup_old_public_data(
        "errors": [],
    }

    db = get_database()
    await cleanup_old_transcripts(db, cutoff_date, stats)

    log_cleanup_results(stats)
    return stats
@@ -161,6 +151,5 @@ async def cleanup_old_public_data(
    retry_kwargs={"max_retries": 3, "countdown": 300},
)
@asynctask
async def cleanup_old_public_data_task(days: int | None = None):
    await cleanup_old_public_data(days=days)

View File

@@ -3,26 +3,23 @@ from datetime import datetime, timedelta, timezone
import structlog
from celery import shared_task
from celery.utils.log import get_task_logger

from reflector.asynctask import asynctask
from reflector.db.calendar_events import calendar_events_controller
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import Room, rooms_controller
from reflector.redis_cache import RedisAsyncLock
from reflector.services.ics_sync import SyncStatus, ics_sync_service
from reflector.video_platforms.factory import create_platform_client, get_platform

logger = structlog.wrap_logger(get_task_logger(__name__))
@shared_task
@asynctask
async def sync_room_ics(room_id: str):
    try:
        room = await rooms_controller.get_by_id(room_id)
        if not room:
            logger.warning("Room not found for ICS sync", room_id=room_id)
            return
@@ -32,7 +29,7 @@ async def sync_room_ics(session: AsyncSession, room_id: str):
            return

        logger.info("Starting ICS sync for room", room_id=room_id, room_name=room.name)
        result = await ics_sync_service.sync_room_calendar(room)

        if result["status"] == SyncStatus.SUCCESS:
            logger.info(
@@ -58,12 +55,11 @@ async def sync_room_ics(session: AsyncSession, room_id: str):
@shared_task
@asynctask
async def sync_all_ics_calendars():
    try:
        logger.info("Starting sync for all ICS-enabled rooms")
        ics_enabled_rooms = await rooms_controller.get_ics_enabled()
        logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")

        for room in ics_enabled_rooms:
@@ -90,21 +86,17 @@ def _should_sync(room) -> bool:
MEETING_DEFAULT_DURATION = timedelta(hours=1)


async def create_upcoming_meetings_for_event(event, create_window, room: Room):
    if event.start_time <= create_window:
        return

    existing_meeting = await meetings_controller.get_by_calendar_event(event.id, room)
    if existing_meeting:
        return

    logger.info(
        "Pre-creating meeting for calendar event",
        room_id=room.id,
        event_id=event.id,
        event_title=event.title,
    )
@@ -112,21 +104,22 @@ async def create_upcoming_meetings_for_event(
    try:
        end_date = event.end_time or (event.start_time + MEETING_DEFAULT_DURATION)
        client = create_platform_client(get_platform(room.platform))
        meeting_data = await client.create_meeting(
            "",
            end_date=end_date,
            room=room,
        )
        await client.upload_logo(meeting_data.room_name, "./images/logo.png")

        meeting = await meetings_controller.create(
            id=meeting_data.meeting_id,
            room_name=meeting_data.room_name,
            room_url=meeting_data.room_url,
            host_room_url=meeting_data.host_room_url,
            start_date=event.start_time,
            end_date=end_date,
            room=room,
            calendar_event_id=event.id,
            calendar_metadata={
@@ -145,7 +138,7 @@ async def create_upcoming_meetings_for_event(
    except Exception as e:
        logger.error(
            "Failed to pre-create meeting",
            room_id=room.id,
            event_id=event.id,
            error=str(e),
        )
@@ -153,8 +146,7 @@ async def create_upcoming_meetings_for_event(
@shared_task
@asynctask
async def create_upcoming_meetings():
    async with RedisAsyncLock("create_upcoming_meetings", skip_if_locked=True) as lock:
        if not lock.acquired:
            logger.warning(
@@ -165,21 +157,18 @@ async def create_upcoming_meetings(session: AsyncSession):
        try:
            logger.info("Starting creation of upcoming meetings")
            ics_enabled_rooms = await rooms_controller.get_ics_enabled()
            now = datetime.now(timezone.utc)
            create_window = now - timedelta(minutes=6)

            for room in ics_enabled_rooms:
                events = await calendar_events_controller.get_upcoming(
                    room.id,
                    minutes_ahead=7,
                )
                for event in events:
                    await create_upcoming_meetings_for_event(event, create_window, room)

            logger.info("Completed pre-creation check for upcoming meetings")
        except Exception as e:

View File

@@ -1,5 +1,6 @@
import json
import os
import re
from datetime import datetime, timezone
from urllib.parse import unquote
@@ -10,30 +11,36 @@ from celery import shared_task
from celery.utils.log import get_task_logger
from pydantic import ValidationError
from redis.exceptions import LockError

from reflector.db.meetings import meetings_controller
from reflector.db.recordings import Recording, recordings_controller
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import (
    SourceKind,
    TranscriptParticipant,
    transcripts_controller,
)
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
from reflector.pipelines.main_live_pipeline import asynctask
from reflector.pipelines.main_multitrack_pipeline import (
    task_pipeline_multitrack_process,
)
from reflector.pipelines.topic_processing import EmptyPipeline
from reflector.processors import AudioFileWriterProcessor
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
from reflector.redis_cache import get_redis_client
from reflector.settings import settings
from reflector.storage import get_transcripts_storage
from reflector.utils.daily import DailyRoomName, extract_base_room_name
from reflector.video_platforms.factory import create_platform_client
from reflector.video_platforms.whereby_utils import (
    parse_whereby_recording_filename,
    room_name_to_whereby_api_room_name,
)

logger = structlog.wrap_logger(get_task_logger(__name__))
@shared_task
def process_messages():
    queue_url = settings.AWS_PROCESS_RECORDING_QUEUE_URL
@@ -75,49 +82,42 @@ def process_messages():
        logger.error("process_messages", error=str(e))


# only whereby supported.
@shared_task
@asynctask
async def process_recording(bucket_name: str, object_key: str):
    logger.info("Processing recording: %s/%s", bucket_name, object_key)

    room_name_part, recorded_at = parse_whereby_recording_filename(object_key)

    # we store whereby api room names, NOT whereby room names
    room_name = room_name_to_whereby_api_room_name(room_name_part)

    meeting = await meetings_controller.get_by_room_name(room_name)
    room = await rooms_controller.get_by_id(meeting.room_id)

    recording = await recordings_controller.get_by_object_key(bucket_name, object_key)
    if not recording:
        recording = await recordings_controller.create(
            Recording(
                bucket_name=bucket_name,
                object_key=object_key,
                recorded_at=recorded_at,
                meeting_id=meeting.id,
            )
        )

    transcript = await transcripts_controller.get_by_recording_id(recording.id)
    if transcript:
        await transcripts_controller.update(
            transcript,
            {
                "topics": [],
                "participants": [],
            },
        )
    else:
        transcript = await transcripts_controller.add(
            "",
            source_kind=SourceKind.ROOM,
            source_language="en",
@@ -133,15 +133,15 @@ async def process_recording(session: AsyncSession, bucket_name: str, object_key:
    upload_filename = transcript.data_path / f"upload{extension}"
    upload_filename.parent.mkdir(parents=True, exist_ok=True)

    storage = get_transcripts_storage()

    try:
        with open(upload_filename, "wb") as f:
            await storage.stream_to_fileobj(object_key, f, bucket=bucket_name)
    except Exception:
        # Clean up partial file on stream failure
        upload_filename.unlink(missing_ok=True)
        raise

    container = av.open(upload_filename.as_posix())
    try:
try: try:
@@ -153,15 +153,173 @@ async def process_recording(session: AsyncSession, bucket_name: str, object_key:
    finally:
        container.close()

    await transcripts_controller.update(transcript, {"status": "uploaded"})
    task_pipeline_file_process.delay(transcript_id=transcript.id)
@shared_task
@asynctask
async def process_multitrack_recording(
    bucket_name: str,
    daily_room_name: DailyRoomName,
    recording_id: str,
    track_keys: list[str],
):
    logger.info(
        "Processing multitrack recording",
        bucket=bucket_name,
        room_name=daily_room_name,
        recording_id=recording_id,
        provided_keys=len(track_keys),
    )

    if not track_keys:
        logger.warning("No audio track keys provided")
        return

    tz = timezone.utc
    recorded_at = datetime.now(tz)
    try:
        # track_keys is guaranteed non-empty after the early return above
        folder = os.path.basename(os.path.dirname(track_keys[0]))
        ts_match = re.search(r"(\d{14})$", folder)
        if ts_match:
            ts = ts_match.group(1)
            recorded_at = datetime.strptime(ts, "%Y%m%d%H%M%S").replace(tzinfo=tz)
    except Exception as e:
        logger.warning(
            f"Could not parse recorded_at from keys, using now() {recorded_at}",
            error=str(e),
            exc_info=True,
        )
    meeting = await meetings_controller.get_by_room_name(daily_room_name)
    room_name_base = extract_base_room_name(daily_room_name)
    room = await rooms_controller.get_by_name(room_name_base)
    if not room:
        raise Exception(f"Room not found: {room_name_base}")
    if not meeting:
        raise Exception(f"Meeting not found: {room_name_base}")
    logger.info(
        "Found existing Meeting for recording",
        meeting_id=meeting.id,
        room_name=daily_room_name,
        recording_id=recording_id,
    )

    recording = await recordings_controller.get_by_id(recording_id)
    if not recording:
        object_key_dir = os.path.dirname(track_keys[0]) if track_keys else ""
        recording = await recordings_controller.create(
            Recording(
                id=recording_id,
                bucket_name=bucket_name,
                object_key=object_key_dir,
                recorded_at=recorded_at,
                meeting_id=meeting.id,
                track_keys=track_keys,
            )
        )
    else:
        # Recording already exists; assume metadata was set at creation time
        pass

    transcript = await transcripts_controller.get_by_recording_id(recording.id)
    if transcript:
        await transcripts_controller.update(
            transcript,
            {
                "topics": [],
                "participants": [],
            },
        )
    else:
        transcript = await transcripts_controller.add(
            "",
            source_kind=SourceKind.ROOM,
            source_language="en",
            target_language="en",
            user_id=room.user_id,
            recording_id=recording.id,
            share_mode="public",
            meeting_id=meeting.id,
            room_id=room.id,
        )
    try:
        daily_client = create_platform_client("daily")
        id_to_name = {}
        id_to_user_id = {}
        mtg_session_id = None
        try:
            rec_details = await daily_client.get_recording(recording_id)
            mtg_session_id = rec_details.get("mtgSessionId")
        except Exception as e:
            logger.warning(
                "Failed to fetch Daily recording details",
                error=str(e),
                recording_id=recording_id,
                exc_info=True,
            )

        if mtg_session_id:
            try:
                payload = await daily_client.get_meeting_participants(mtg_session_id)
                for p in payload.get("data", []):
                    pid = p.get("participant_id")
                    name = p.get("user_name")
                    user_id = p.get("user_id")
                    if pid and name:
                        id_to_name[pid] = name
                    if pid and user_id:
                        id_to_user_id[pid] = user_id
            except Exception as e:
                logger.warning(
                    "Failed to fetch Daily meeting participants",
                    error=str(e),
                    mtg_session_id=mtg_session_id,
                    exc_info=True,
                )
        else:
            logger.warning(
                "No mtgSessionId found for recording; participant names may be generic",
                recording_id=recording_id,
            )

        for idx, key in enumerate(track_keys):
            base = os.path.basename(key)
            m = re.search(r"\d{13,}-([0-9a-fA-F-]{36})-cam-audio-", base)
            participant_id = m.group(1) if m else None
            default_name = f"Speaker {idx}"
            name = id_to_name.get(participant_id, default_name)
            user_id = id_to_user_id.get(participant_id)
            participant = TranscriptParticipant(
                id=participant_id, speaker=idx, name=name, user_id=user_id
            )
            await transcripts_controller.upsert_participant(transcript, participant)
    except Exception as e:
        logger.warning("Failed to map participant names", error=str(e), exc_info=True)
    task_pipeline_multitrack_process.delay(
        transcript_id=transcript.id,
        bucket_name=bucket_name,
        track_keys=track_keys,
    )
@shared_task
@asynctask
async def process_meetings():
    """
    Checks which meetings are still active and deactivates those that have ended.
@@ -177,8 +335,8 @@ async def process_meetings(session: AsyncSession):
    Uses distributed locking to prevent race conditions when multiple workers
    process the same meeting simultaneously.
    """
    logger.debug("Processing meetings")
    meetings = await meetings_controller.get_all_active()
    current_time = datetime.now(timezone.utc)
    redis_client = get_redis_client()

    processed_count = 0
@@ -202,7 +360,8 @@ async def process_meetings(session: AsyncSession):
            end_date = end_date.replace(tzinfo=timezone.utc)

        # This API call could be slow, extend lock if needed
        client = create_platform_client(meeting.platform)
        room_sessions = await client.get_room_sessions(meeting.room_name)

        try:
            # Extend lock after slow operation to ensure we still hold it
@@ -211,7 +370,6 @@ async def process_meetings(session: AsyncSession):
                logger_.warning("Lost lock for meeting, skipping")
                continue

            has_active_sessions = room_sessions and any(
                rs["endedAt"] is None for rs in room_sessions
            )
@@ -231,9 +389,7 @@ async def process_meetings(session: AsyncSession):
                logger_.debug("Meeting not yet started, keep it")

            if should_deactivate:
                await meetings_controller.update_meeting(meeting.id, is_active=False)
                logger_.info("Meeting is deactivated")

            processed_count += 1
@@ -246,72 +402,120 @@ async def process_meetings(session: AsyncSession):
        except LockError:
            pass  # Lock already released or expired

    logger.debug(
        "Processed meetings finished",
        processed_count=processed_count,
        skipped_count=skipped_count,
    )
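The skip-if-locked behaviour these tasks rely on (`RedisAsyncLock(..., skip_if_locked=True)`) can be sketched with plain `asyncio` — a simplified single-process stand-in for illustration, not the project's Redis-backed implementation:

```python
import asyncio


async def run_if_free(lock: asyncio.Lock, name: str, results: list[str]) -> None:
    # Try-don't-wait: if another worker already holds the lock,
    # skip the work entirely instead of queueing behind it.
    if lock.locked():
        results.append(f"{name}: skipped")
        return
    async with lock:
        await asyncio.sleep(0.01)  # simulate the protected work
        results.append(f"{name}: ran")


async def demo() -> list[str]:
    lock = asyncio.Lock()
    results: list[str] = []
    await asyncio.gather(
        run_if_free(lock, "worker-1", results),
        run_if_free(lock, "worker-2", results),
    )
    return results
```

In production the lock must live in shared state (Redis) so it holds across worker processes; an in-process `asyncio.Lock` only demonstrates the skip semantics.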
async def convert_audio_and_waveform(transcript) -> None:
    """Convert WebM to MP3 and generate waveform for Daily.co recordings.

    This bypasses the full file pipeline which would overwrite stub data.
    """
    try:
        logger.info(
            "Converting audio to MP3 and generating waveform",
            transcript_id=transcript.id,
        )
        upload_path = transcript.data_path / "upload.webm"
        mp3_path = transcript.audio_mp3_filename

        # Convert WebM to MP3
        mp3_writer = AudioFileWriterProcessor(path=mp3_path)
        container = av.open(str(upload_path))
        for frame in container.decode(audio=0):
            await mp3_writer.push(frame)
        await mp3_writer.flush()
        container.close()

        logger.info(
            "Converted WebM to MP3",
            transcript_id=transcript.id,
            mp3_size=mp3_path.stat().st_size,
        )

        waveform_processor = AudioWaveformProcessor(
            audio_path=mp3_path,
            waveform_path=transcript.audio_waveform_filename,
        )
        waveform_processor.set_pipeline(EmptyPipeline(logger))
        await waveform_processor.flush()

        logger.info(
            "Generated waveform",
            transcript_id=transcript.id,
            waveform_path=transcript.audio_waveform_filename,
        )

        # Update transcript status to ended (successful)
        await transcripts_controller.update(transcript, {"status": "ended"})
    except Exception as e:
        logger.error(
            "Failed to convert audio or generate waveform",
            transcript_id=transcript.id,
            error=str(e),
        )
        # Keep status as uploaded even if conversion fails
@shared_task
@asynctask
async def reprocess_failed_recordings():
    """
    Find recordings in the Whereby S3 bucket and check if they have proper
    transcriptions. If not, requeue them for processing.

    Note: Daily.co recordings are processed via webhooks, not this cron job.
    """
    logger.info("Checking Whereby recordings that need processing or reprocessing")

    if not settings.WHEREBY_STORAGE_AWS_BUCKET_NAME:
        raise ValueError(
            "WHEREBY_STORAGE_AWS_BUCKET_NAME required for Whereby recording reprocessing. "
            "Set WHEREBY_STORAGE_AWS_BUCKET_NAME environment variable."
        )

    storage = get_transcripts_storage()
    bucket_name = settings.WHEREBY_STORAGE_AWS_BUCKET_NAME

    reprocessed_count = 0
    try:
        object_keys = await storage.list_objects(prefix="", bucket=bucket_name)

        for object_key in object_keys:
            if not object_key.endswith(".mp4"):
                continue

            recording = await recordings_controller.get_by_object_key(
                bucket_name, object_key
            )
            if not recording:
                logger.info(f"Queueing recording for processing: {object_key}")
                process_recording.delay(bucket_name, object_key)
                reprocessed_count += 1
                continue

            transcript = None
            try:
                transcript = await transcripts_controller.get_by_recording_id(
                    recording.id
                )
            except ValidationError:
                await transcripts_controller.remove_by_recording_id(recording.id)
                logger.warning(
                    f"Removed invalid transcript for recording: {recording.id}"
                )

            if transcript is None or transcript.status == "error":
                logger.info(f"Queueing recording for processing: {object_key}")
                process_recording.delay(bucket_name, object_key)
                reprocessed_count += 1
    except Exception as e:
        logger.error(f"Error checking S3 bucket: {str(e)}")

View File

@@ -1,109 +0,0 @@
"""
Session management decorator for async worker tasks.

This decorator ensures that all worker tasks have a properly managed database session
that stays open for the entire duration of the task execution.
"""

import functools
from typing import Any, Callable, TypeVar

from celery import current_task

from reflector.db import get_session_factory
from reflector.db.transcripts import transcripts_controller
from reflector.logger import logger

F = TypeVar("F", bound=Callable[..., Any])


def with_session(func: F) -> F:
    """
    Decorator that provides an AsyncSession as the first argument to the decorated function.

    This should be used AFTER the @asynctask decorator on Celery tasks to ensure
    proper session management throughout the task execution.

    Example:
        @shared_task
        @asynctask
        @with_session
        async def my_task(session: AsyncSession, arg1: str, arg2: int):
            # session is automatically provided and managed
            result = await some_controller.get_by_id(session, arg1)
            ...
    """

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        session_factory = get_session_factory()
        async with session_factory() as session:
            async with session.begin():
                # Pass session as first argument to the decorated function
                return await func(session, *args, **kwargs)

    return wrapper
def with_session_and_transcript(func: F) -> F:
    """
    Decorator that provides both an AsyncSession and a Transcript to the decorated function.

    This decorator:
    1. Extracts transcript_id from kwargs
    2. Creates and manages a database session
    3. Fetches the transcript using the session
    4. Creates an enhanced logger with Celery task context
    5. Passes session, transcript, and logger to the decorated function

    This should be used AFTER the @asynctask decorator on Celery tasks.

    Example:
        @shared_task
        @asynctask
        @with_session_and_transcript
        async def my_task(session: AsyncSession, transcript: Transcript, logger: Logger, arg1: str):
            # session, transcript, and logger are automatically provided
            room = await rooms_controller.get_by_id(session, transcript.room_id)
            ...
    """

    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        transcript_id = kwargs.pop("transcript_id", None)
        if not transcript_id:
            raise ValueError(
                "transcript_id is required for @with_session_and_transcript"
            )

        session_factory = get_session_factory()
        async with session_factory() as session:
            async with session.begin():
                # Fetch the transcript
                transcript = await transcripts_controller.get_by_id(
                    session, transcript_id
                )
                if not transcript:
                    raise Exception(f"Transcript {transcript_id} not found")

                # Create enhanced logger with Celery task context
                tlogger = logger.bind(transcript_id=transcript.id)
                if current_task:
                    tlogger = tlogger.bind(
                        task_id=current_task.request.id,
                        task_name=current_task.name,
                        worker_hostname=current_task.request.hostname,
                        task_retries=current_task.request.retries,
                        transcript_id=transcript_id,
                    )

                try:
                    # Pass session, transcript, and logger to the decorated function
                    return await func(
                        session, transcript=transcript, logger=tlogger, *args, **kwargs
                    )
                except Exception:
                    tlogger.exception("Error in task execution")
                    raise

    return wrapper

View File

@@ -10,14 +10,14 @@ import httpx
import structlog
from celery import shared_task
from celery.utils.log import get_task_logger

from reflector.db.calendar_events import calendar_events_controller
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import transcripts_controller
from reflector.pipelines.main_live_pipeline import asynctask
from reflector.settings import settings
from reflector.utils.webvtt import topics_to_webvtt

logger = structlog.wrap_logger(get_task_logger(__name__))
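This file's `generate_webhook_signature(payload, secret, timestamp)` helper (visible only in the hunk header) is not shown in the diff. A common construction for timestamped webhook signatures is HMAC-SHA256 over `"{timestamp}.{payload}"` — a hedged sketch only; the project's actual message layout may differ:

```python
import hashlib
import hmac


def generate_webhook_signature(payload: bytes, secret: str, timestamp: str) -> str:
    # Signing "{timestamp}.{payload}" means a replayed body with a new
    # timestamp fails verification. This mirrors the Stripe-style scheme;
    # the exact layout used by this project is an assumption.
    message = timestamp.encode() + b"." + payload
    return hmac.new(secret.encode(), message, hashlib.sha256).hexdigest()
```

Receivers should recompute the signature and compare with `hmac.compare_digest` to avoid timing side channels.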
@@ -41,13 +41,11 @@ def generate_webhook_signature(payload: bytes, secret: str, timestamp: str) -> s
    retry_backoff_max=3600,  # Max 1 hour between retries
)
@asynctask
async def send_transcript_webhook(
    self,
    transcript_id: str,
    room_id: str,
    event_id: str,
):
    log = logger.bind(
        transcript_id=transcript_id,
@@ -57,12 +55,12 @@ async def send_transcript_webhook(
try: try:
# Fetch transcript and room # Fetch transcript and room
transcript = await transcripts_controller.get_by_id(session, transcript_id) transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript: if not transcript:
log.error("Transcript not found, skipping webhook") log.error("Transcript not found, skipping webhook")
return return
room = await rooms_controller.get_by_id(session, room_id) room = await rooms_controller.get_by_id(room_id)
if not room: if not room:
log.error("Room not found, skipping webhook") log.error("Room not found, skipping webhook")
return return
@@ -88,6 +86,18 @@ async def send_transcript_webhook(
} }
) )
# Fetch meeting and calendar event if they exist
calendar_event = None
try:
if transcript.meeting_id:
meeting = await meetings_controller.get_by_id(transcript.meeting_id)
if meeting and meeting.calendar_event_id:
calendar_event = await calendar_events_controller.get_by_id(
meeting.calendar_event_id
)
except Exception as e:
logger.error("Error fetching meeting or calendar event", error=str(e))
# Build webhook payload # Build webhook payload
frontend_url = f"{settings.UI_BASE_URL}/transcripts/{transcript.id}" frontend_url = f"{settings.UI_BASE_URL}/transcripts/{transcript.id}"
participants = [ participants = [
@@ -120,6 +130,33 @@ async def send_transcript_webhook(
}, },
} }
# Always include calendar_event field, even if no event is present
payload_data["calendar_event"] = {}
# Add calendar event data if present
if calendar_event:
calendar_data = {
"id": calendar_event.id,
"ics_uid": calendar_event.ics_uid,
"title": calendar_event.title,
"start_time": calendar_event.start_time.isoformat()
if calendar_event.start_time
else None,
"end_time": calendar_event.end_time.isoformat()
if calendar_event.end_time
else None,
}
# Add optional fields only if they exist
if calendar_event.description:
calendar_data["description"] = calendar_event.description
if calendar_event.location:
calendar_data["location"] = calendar_event.location
if calendar_event.attendees:
calendar_data["attendees"] = calendar_event.attendees
payload_data["calendar_event"] = calendar_data
# Convert to JSON # Convert to JSON
payload_json = json.dumps(payload_data, separators=(",", ":")) payload_json = json.dumps(payload_data, separators=(",", ":"))
payload_bytes = payload_json.encode("utf-8") payload_bytes = payload_json.encode("utf-8")

View File

@@ -65,8 +65,13 @@ class WebsocketManager:
         self.tasks: dict = {}
         self.pubsub_client = pubsub_client

-    async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
-        await websocket.accept()
+    async def add_user_to_room(
+        self, room_id: str, websocket: WebSocket, subprotocol: str | None = None
+    ) -> None:
+        if subprotocol:
+            await websocket.accept(subprotocol=subprotocol)
+        else:
+            await websocket.accept()
         if room_id in self.rooms:
             self.rooms[room_id].append(websocket)

View File

@@ -0,0 +1,123 @@
#!/usr/bin/env python3
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import httpx
from reflector.settings import settings
async def setup_webhook(webhook_url: str):
"""
Create or update Daily.co webhook for this environment.
Uses DAILY_WEBHOOK_UUID to identify existing webhook.
"""
if not settings.DAILY_API_KEY:
print("Error: DAILY_API_KEY not set")
return 1
headers = {
"Authorization": f"Bearer {settings.DAILY_API_KEY}",
"Content-Type": "application/json",
}
webhook_data = {
"url": webhook_url,
"eventTypes": [
"participant.joined",
"participant.left",
"recording.started",
"recording.ready-to-download",
"recording.error",
],
"hmac": settings.DAILY_WEBHOOK_SECRET,
}
async with httpx.AsyncClient() as client:
webhook_uuid = settings.DAILY_WEBHOOK_UUID
if webhook_uuid:
# Update existing webhook
print(f"Updating existing webhook {webhook_uuid}...")
try:
resp = await client.patch(
f"https://api.daily.co/v1/webhooks/{webhook_uuid}",
headers=headers,
json=webhook_data,
)
resp.raise_for_status()
result = resp.json()
print(f"✓ Updated webhook {result['uuid']} (state: {result['state']})")
print(f" URL: {result['url']}")
return 0
except httpx.HTTPStatusError as e:
if e.response.status_code == 404:
print(f"Webhook {webhook_uuid} not found, creating new one...")
webhook_uuid = None # Fall through to creation
else:
print(f"Error updating webhook: {e}")
return 1
if not webhook_uuid:
# Create new webhook
print("Creating new webhook...")
resp = await client.post(
"https://api.daily.co/v1/webhooks", headers=headers, json=webhook_data
)
resp.raise_for_status()
result = resp.json()
webhook_uuid = result["uuid"]
print(f"✓ Created webhook {webhook_uuid} (state: {result['state']})")
print(f" URL: {result['url']}")
print()
print("=" * 60)
print("IMPORTANT: Add this to your environment variables:")
print("=" * 60)
print(f"DAILY_WEBHOOK_UUID: {webhook_uuid}")
print("=" * 60)
print()
# Try to write UUID to .env file
env_file = Path(__file__).parent.parent / ".env"
if env_file.exists():
lines = env_file.read_text().splitlines()
updated = False
# Update existing DAILY_WEBHOOK_UUID line or add it
for i, line in enumerate(lines):
if line.startswith("DAILY_WEBHOOK_UUID="):
lines[i] = f"DAILY_WEBHOOK_UUID={webhook_uuid}"
updated = True
break
if not updated:
lines.append(f"DAILY_WEBHOOK_UUID={webhook_uuid}")
env_file.write_text("\n".join(lines) + "\n")
print(f"✓ Also saved to local .env file")
else:
print(f"⚠ Local .env file not found - please add manually")
return 0
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python recreate_daily_webhook.py <webhook_url>")
print(
"Example: python recreate_daily_webhook.py https://example.com/v1/daily/webhook"
)
print()
print("Behavior:")
print(" - If DAILY_WEBHOOK_UUID set: Updates existing webhook")
print(
" - If DAILY_WEBHOOK_UUID empty: Creates new webhook, saves UUID to .env"
)
sys.exit(1)
sys.exit(asyncio.run(setup_webhook(sys.argv[1])))
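The script's `.env` update loop can be read as a small pure function: replace an existing `KEY=` line if present, otherwise append one. A sketch of that logic in isolation (`upsert_env_var` is a hypothetical helper name, not part of the script):

```python
def upsert_env_var(lines: list[str], key: str, value: str) -> list[str]:
    """Replace an existing KEY=... line or append one (mirrors the script's .env logic)."""
    prefix = f"{key}="
    out = list(lines)
    for i, line in enumerate(out):
        if line.startswith(prefix):
            out[i] = f"{key}={value}"
            return out
    out.append(f"{key}={value}")
    return out

lines = ["DAILY_API_KEY=abc", "DAILY_WEBHOOK_UUID=old"]
print(upsert_env_var(lines, "DAILY_WEBHOOK_UUID", "new"))
# → ['DAILY_API_KEY=abc', 'DAILY_WEBHOOK_UUID=new']
```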

View File

@@ -1,20 +1,21 @@
-import asyncio
 import os
-import sys
+from contextlib import asynccontextmanager
 from tempfile import NamedTemporaryFile
 from unittest.mock import patch

 import pytest

+from reflector.schemas.platform import WHEREBY_PLATFORM

-@pytest.fixture(scope="session")
-def event_loop():
-    if sys.platform.startswith("win") and sys.version_info[:2] >= (3, 8):
-        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
-    loop = asyncio.new_event_loop()
-    yield loop
-    loop.close()
+@pytest.fixture(scope="session", autouse=True)
+def register_mock_platform():
+    from mocks.mock_platform import MockPlatformClient
+    from reflector.video_platforms.registry import register_platform
+
+    register_platform(WHEREBY_PLATFORM, MockPlatformClient)
+    yield

 @pytest.fixture(scope="session", autouse=True)
@@ -47,6 +48,7 @@ def docker_compose_file(pytestconfig):
 @pytest.fixture(scope="session")
 def postgres_service(docker_ip, docker_services):
+    """Ensure that PostgreSQL service is up and responsive."""
     port = docker_services.port_for("postgres_test", 5432)

     def is_responsive():
@@ -67,6 +69,7 @@ def postgres_service(docker_ip, docker_services):
     docker_services.wait_until_responsive(timeout=30.0, pause=0.1, check=is_responsive)

+    # Return connection parameters
     return {
         "host": docker_ip,
         "port": port,
@@ -76,27 +79,20 @@
     }

-@pytest.fixture(scope="session")
-def _database_url(postgres_service):
-    db_config = postgres_service
-    DATABASE_URL = (
-        f"postgresql+asyncpg://{db_config['user']}:{db_config['password']}"
-        f"@{db_config['host']}:{db_config['port']}/{db_config['dbname']}"
-    )
-
-    # Override settings
-    from reflector.settings import settings
-
-    settings.DATABASE_URL = DATABASE_URL
-
-    return DATABASE_URL
-
-@pytest.fixture(scope="session")
-def init_database():
-    from reflector.db import Base
-
-    return Base.metadata.create_all
+@pytest.fixture(scope="function", autouse=True)
+@pytest.mark.asyncio
+async def setup_database(postgres_service):
+    from reflector.db import engine, metadata, get_database  # noqa
+
+    metadata.drop_all(bind=engine)
+    metadata.create_all(bind=engine)
+    database = get_database()
+    try:
+        await database.connect()
+        yield
+    finally:
+        await database.disconnect()

 @pytest.fixture
@@ -344,17 +340,8 @@ def celery_includes():
     ]

-@pytest.fixture(autouse=True)
-async def ensure_db_session_in_app(db_session):
-    async def mock_get_session():
-        yield db_session
-
-    with patch("reflector.db._get_session", side_effect=mock_get_session):
-        yield

 @pytest.fixture
-async def client(db_session):
+async def client():
     from httpx import AsyncClient
     from reflector.app import app
@@ -363,6 +350,166 @@ async def client(db_session):
         yield ac
@pytest.fixture(autouse=True)
async def ws_manager_in_memory(monkeypatch):
"""Replace Redis-based WS manager with an in-memory implementation for tests."""
import asyncio
import json
from reflector.ws_manager import WebsocketManager
class _InMemorySubscriber:
def __init__(self, queue: asyncio.Queue):
self.queue = queue
async def get_message(self, ignore_subscribe_messages: bool = True):
try:
return await asyncio.wait_for(self.queue.get(), timeout=0.05)
except Exception:
return None
class InMemoryPubSubManager:
def __init__(self):
self.queues: dict[str, asyncio.Queue] = {}
self.connected = False
async def connect(self) -> None:
self.connected = True
async def disconnect(self) -> None:
self.connected = False
async def send_json(self, room_id: str, message: dict) -> None:
if room_id not in self.queues:
self.queues[room_id] = asyncio.Queue()
payload = json.dumps(message).encode("utf-8")
await self.queues[room_id].put(
{"channel": room_id.encode("utf-8"), "data": payload}
)
async def subscribe(self, room_id: str):
if room_id not in self.queues:
self.queues[room_id] = asyncio.Queue()
return _InMemorySubscriber(self.queues[room_id])
async def unsubscribe(self, room_id: str) -> None:
# keep queue for potential later resubscribe within same test
pass
pubsub = InMemoryPubSubManager()
ws_manager = WebsocketManager(pubsub_client=pubsub)
def _get_ws_manager():
return ws_manager
# Patch all places that imported get_ws_manager at import time
monkeypatch.setattr("reflector.ws_manager.get_ws_manager", _get_ws_manager)
monkeypatch.setattr(
"reflector.pipelines.main_live_pipeline.get_ws_manager", _get_ws_manager
)
monkeypatch.setattr(
"reflector.views.transcripts_websocket.get_ws_manager", _get_ws_manager
)
monkeypatch.setattr(
"reflector.views.user_websocket.get_ws_manager", _get_ws_manager
)
monkeypatch.setattr("reflector.views.transcripts.get_ws_manager", _get_ws_manager)
# Websocket auth: avoid OAuth2 on websocket dependencies; allow anonymous
import reflector.auth as auth
# Ensure FastAPI uses our override for routes that captured the original callable
from reflector.app import app as fastapi_app
try:
fastapi_app.dependency_overrides[auth.current_user_optional] = lambda: None
except Exception:
pass
# Stub Redis cache used by profanity filter to avoid external Redis
from reflector import redis_cache as rc
class _FakeRedis:
def __init__(self):
self._data = {}
def get(self, key):
value = self._data.get(key)
if value is None:
return None
if isinstance(value, bytes):
return value
return str(value).encode("utf-8")
def setex(self, key, duration, value):
# ignore duration for tests
if isinstance(value, bytes):
self._data[key] = value
else:
self._data[key] = str(value).encode("utf-8")
fake_redises: dict[int, _FakeRedis] = {}
def _get_redis_client(db=0):
if db not in fake_redises:
fake_redises[db] = _FakeRedis()
return fake_redises[db]
monkeypatch.setattr(rc, "get_redis_client", _get_redis_client)
yield
@pytest.fixture
@pytest.mark.asyncio
async def authenticated_client():
async with authenticated_client_ctx():
yield
@pytest.fixture
@pytest.mark.asyncio
async def authenticated_client2():
async with authenticated_client2_ctx():
yield
@asynccontextmanager
async def authenticated_client_ctx():
from reflector.app import app
from reflector.auth import current_user, current_user_optional
app.dependency_overrides[current_user] = lambda: {
"sub": "randomuserid",
"email": "test@mail.com",
}
app.dependency_overrides[current_user_optional] = lambda: {
"sub": "randomuserid",
"email": "test@mail.com",
}
yield
del app.dependency_overrides[current_user]
del app.dependency_overrides[current_user_optional]
@asynccontextmanager
async def authenticated_client2_ctx():
from reflector.app import app
from reflector.auth import current_user, current_user_optional
app.dependency_overrides[current_user] = lambda: {
"sub": "randomuserid2",
"email": "test@mail.com",
}
app.dependency_overrides[current_user_optional] = lambda: {
"sub": "randomuserid2",
"email": "test@mail.com",
}
yield
del app.dependency_overrides[current_user]
del app.dependency_overrides[current_user_optional]
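The `ws_manager_in_memory` fixture above swaps Redis pub/sub for one `asyncio.Queue` per room. The round trip it relies on can be sketched in isolation (class and method names here are simplified relative to the fixture):

```python
import asyncio
import json

class InMemoryPubSub:
    """One queue per channel; messages are JSON-encoded bytes, as in the fixture."""
    def __init__(self):
        self.queues: dict[str, asyncio.Queue] = {}

    def _queue(self, channel: str) -> asyncio.Queue:
        return self.queues.setdefault(channel, asyncio.Queue())

    async def send_json(self, channel: str, message: dict) -> None:
        await self._queue(channel).put(json.dumps(message).encode("utf-8"))

    async def get_message(self, channel: str, timeout: float = 0.05):
        # Returns None on timeout, mirroring the fixture's subscriber
        try:
            return await asyncio.wait_for(self._queue(channel).get(), timeout)
        except asyncio.TimeoutError:
            return None

async def demo():
    bus = InMemoryPubSub()
    await bus.send_json("room-1", {"event": "TRANSCRIPT"})
    raw = await bus.get_message("room-1")
    return json.loads(raw)

print(asyncio.run(demo()))  # → {'event': 'TRANSCRIPT'}
```

Because everything stays in one process, tests exercise the publish/subscribe flow without a Redis server.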
 @pytest.fixture(scope="session")
 def fake_mp3_upload():
     with patch(
@@ -373,7 +520,7 @@ def fake_mp3_upload():

 @pytest.fixture
-async def fake_transcript_with_topics(tmpdir, client, db_session):
+async def fake_transcript_with_topics(tmpdir, client):
     import shutil
     from pathlib import Path
@@ -389,10 +536,10 @@ async def fake_transcript_with_topics(tmpdir, client, db_session):
     assert response.status_code == 200
     tid = response.json()["id"]

-    transcript = await transcripts_controller.get_by_id(db_session, tid)
+    transcript = await transcripts_controller.get_by_id(tid)
     assert transcript is not None
-    await transcripts_controller.update(db_session, transcript, {"status": "ended"})
+    await transcripts_controller.update(transcript, {"status": "ended"})

     # manually copy a file at the expected location
     audio_filename = transcript.audio_mp3_filename
@@ -402,7 +549,6 @@ async def fake_transcript_with_topics(tmpdir, client, db_session):
     # create some topics
     await transcripts_controller.upsert_topic(
-        db_session,
         transcript,
         TranscriptTopic(
             title="Topic 1",
@@ -416,7 +562,6 @@ async def fake_transcript_with_topics(tmpdir, client, db_session):
         ),
     )
     await transcripts_controller.upsert_topic(
-        db_session,
         transcript,
         TranscriptTopic(
             title="Topic 2",

View File

View File

@@ -0,0 +1,112 @@
import uuid
from datetime import datetime
from typing import Any, Dict, Literal, Optional
from reflector.db.rooms import Room
from reflector.video_platforms.base import (
ROOM_PREFIX_SEPARATOR,
MeetingData,
VideoPlatformClient,
VideoPlatformConfig,
)
MockPlatform = Literal["mock"]
class MockPlatformClient(VideoPlatformClient):
PLATFORM_NAME: MockPlatform = "mock"
def __init__(self, config: VideoPlatformConfig):
super().__init__(config)
self._rooms: Dict[str, Dict[str, Any]] = {}
self._webhook_calls: list[Dict[str, Any]] = []
async def create_meeting(
self, room_name_prefix: str, end_date: datetime, room: Room
) -> MeetingData:
meeting_id = str(uuid.uuid4())
room_name = f"{room_name_prefix}{ROOM_PREFIX_SEPARATOR}{meeting_id[:8]}"
room_url = f"https://mock.video/{room_name}"
host_room_url = f"{room_url}?host=true"
self._rooms[room_name] = {
"id": meeting_id,
"name": room_name,
"url": room_url,
"host_url": host_room_url,
"end_date": end_date,
"room": room,
"participants": [],
"is_active": True,
}
return MeetingData.model_construct(
meeting_id=meeting_id,
room_name=room_name,
room_url=room_url,
host_room_url=host_room_url,
platform="whereby",
extra_data={"mock": True},
)
async def get_room_sessions(self, room_name: str) -> Dict[str, Any]:
if room_name not in self._rooms:
return {"error": "Room not found"}
room_data = self._rooms[room_name]
return {
"roomName": room_name,
"sessions": [
{
"sessionId": room_data["id"],
"startTime": datetime.utcnow().isoformat(),
"participants": room_data["participants"],
"isActive": room_data["is_active"],
}
],
}
async def delete_room(self, room_name: str) -> bool:
if room_name in self._rooms:
self._rooms[room_name]["is_active"] = False
return True
return False
async def upload_logo(self, room_name: str, logo_path: str) -> bool:
if room_name in self._rooms:
self._rooms[room_name]["logo_path"] = logo_path
return True
return False
def verify_webhook_signature(
self, body: bytes, signature: str, timestamp: Optional[str] = None
) -> bool:
return signature == "valid"
def add_participant(
self, room_name: str, participant_id: str, participant_name: str
):
if room_name in self._rooms:
self._rooms[room_name]["participants"].append(
{
"id": participant_id,
"name": participant_name,
"joined_at": datetime.utcnow().isoformat(),
}
)
def trigger_webhook(self, event_type: str, data: Dict[str, Any]):
self._webhook_calls.append(
{
"type": event_type,
"data": data,
"timestamp": datetime.utcnow().isoformat(),
}
)
def get_webhook_calls(self) -> list[Dict[str, Any]]:
return self._webhook_calls.copy()
def clear_data(self):
self._rooms.clear()
self._webhook_calls.clear()
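`MockPlatformClient` is wired in via `register_platform(WHEREBY_PLATFORM, MockPlatformClient)` in conftest. A registry of that shape is just a name-to-class mapping; the sketch below is hypothetical — the real `reflector.video_platforms.registry` may differ:

```python
from typing import Dict

_REGISTRY: Dict[str, type] = {}

def register_platform(name: str, client_cls: type) -> None:
    # Later registrations override earlier ones, which lets tests swap in mocks
    _REGISTRY[name] = client_cls

def get_platform_client(name: str, config=None):
    try:
        cls = _REGISTRY[name]
    except KeyError:
        raise ValueError(f"unknown platform: {name}") from None
    return cls(config)

class MockClient:
    PLATFORM_NAME = "mock"
    def __init__(self, config):
        self.config = config

register_platform("whereby", MockClient)
client = get_platform_client("whereby", config={"api_key": "test"})
print(client.PLATFORM_NAME)  # → mock
```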

View File

@@ -1,5 +1,5 @@
 import os
-from unittest.mock import patch
+from unittest.mock import AsyncMock, patch

 import pytest
@@ -8,7 +8,7 @@ from reflector.services.ics_sync import ICSSyncService

 @pytest.mark.asyncio
-async def test_attendee_parsing_bug(db_session):
+async def test_attendee_parsing_bug():
     """
     Test that reproduces the attendee parsing bug where a string with comma-separated
     emails gets parsed as individual characters instead of separate email addresses.
@@ -16,8 +16,8 @@ async def test_attendee_parsing_bug(db_session):
     The bug manifests as getting 29 attendees with emails like "M", "A", "I", etc.
     instead of properly parsed email addresses.
     """
+    # Create a test room
     room = await rooms_controller.add(
-        db_session,
         name="test-room",
         user_id="test-user",
         zulip_auto_post=False,
@@ -31,8 +31,8 @@ async def test_attendee_parsing_bug(db_session):
         ics_url="http://test.com/test.ics",
         ics_enabled=True,
     )
-    await db_session.flush()

+    # Read the test ICS file that reproduces the bug and update it with current time
     from datetime import datetime, timedelta, timezone

     test_ics_path = os.path.join(
@@ -41,26 +41,30 @@ async def test_attendee_parsing_bug(db_session):
     with open(test_ics_path, "r") as f:
         ics_content = f.read()

+    # Replace the dates with current time + 1 hour to ensure it's within the 24h window
     now = datetime.now(timezone.utc)
     future_time = now + timedelta(hours=1)
     end_time = future_time + timedelta(hours=1)

+    # Format dates for ICS format
     dtstart = future_time.strftime("%Y%m%dT%H%M%SZ")
     dtend = end_time.strftime("%Y%m%dT%H%M%SZ")
     dtstamp = now.strftime("%Y%m%dT%H%M%SZ")

+    # Update the ICS content with current dates
     ics_content = ics_content.replace("20250910T180000Z", dtstart)
     ics_content = ics_content.replace("20250910T190000Z", dtend)
     ics_content = ics_content.replace("20250910T174000Z", dtstamp)

+    # Create sync service and mock the fetch
     sync_service = ICSSyncService()
-    from unittest.mock import AsyncMock

     with patch.object(
         sync_service.fetch_service, "fetch_ics", new_callable=AsyncMock
     ) as mock_fetch:
         mock_fetch.return_value = ics_content

+        # Debug: Parse the ICS content directly to examine attendee parsing
         calendar = sync_service.fetch_service.parse_ics(ics_content)

         from reflector.settings import settings
@@ -76,23 +80,113 @@ async def test_attendee_parsing_bug(db_session):
             print(f"Total events in calendar: {total_events}")
             print(f"Events matching room: {len(events)}")

-        result = await sync_service.sync_room_calendar(db_session, room)
+        # Perform the sync
+        result = await sync_service.sync_room_calendar(room)

+        # Check that the sync succeeded
         assert result.get("status") == "success"
-        assert result.get("events_found", 0) >= 0
+        assert result.get("events_found", 0) >= 0  # Allow for debugging

+        # We already have the matching events from the debug code above
         assert len(events) == 1
         event = events[0]

-        attendees = event["attendees"]
+        # This is where the bug manifests - check the attendees
+        attendees = event["attendees"]

-        print(f"Number of attendees: {len(attendees)}")
-        for i, attendee in enumerate(attendees):
-            print(f"Attendee {i}: {attendee}")
+        # Print attendee info for debugging
+        print(f"Number of attendees found: {len(attendees)}")
+        for i, attendee in enumerate(attendees):
+            print(
+                f"Attendee {i}: email='{attendee.get('email')}', name='{attendee.get('name')}'"
+            )

-        assert len(attendees) == 30, f"Expected 30 attendees, got {len(attendees)}"
-        assert attendees[0]["email"] == "alice@example.com"
-        assert attendees[1]["email"] == "bob@example.com"
-        assert attendees[2]["email"] == "charlie@example.com"
-        assert any(att["email"] == "organizer@example.com" for att in attendees)
+        # With the fix, we should now get properly parsed email addresses
+        # Check that no single characters are parsed as emails
+        single_char_emails = [
+            att for att in attendees if att.get("email") and len(att["email"]) == 1
+        ]
+        if single_char_emails:
+            print(
+                f"BUG DETECTED: Found {len(single_char_emails)} single-character emails:"
+            )
+            for att in single_char_emails:
+                print(f"  - '{att['email']}'")
+
+        # Should have attendees but not single-character emails
+        assert len(attendees) > 0
+        assert (
+            len(single_char_emails) == 0
+        ), f"Found {len(single_char_emails)} single-character emails, parsing is still buggy"
+
+        # Check that all emails are valid (contain @ symbol)
+        valid_emails = [
+            att for att in attendees if att.get("email") and "@" in att["email"]
+        ]
+        assert len(valid_emails) == len(
+            attendees
+        ), "Some attendees don't have valid email addresses"
+
+        # We expect around 29 attendees (28 from the comma-separated list + 1 organizer)
+        assert (
+            len(attendees) >= 25
+        ), f"Expected around 29 attendees, got {len(attendees)}"
+
+
+@pytest.mark.asyncio
+async def test_correct_attendee_parsing():
+    """
+    Test what correct attendee parsing should look like.
+    """
+    from datetime import datetime, timezone
+
+    from icalendar import Event
+
+    from reflector.services.ics_sync import ICSFetchService
+
+    service = ICSFetchService()
+
+    # Create a properly formatted event with multiple attendees
+    event = Event()
+    event.add("uid", "test-correct-attendees")
+    event.add("summary", "Test Meeting")
+    event.add("location", "http://test.com/test")
+    event.add("dtstart", datetime.now(timezone.utc))
+    event.add("dtend", datetime.now(timezone.utc))
+
+    # Add attendees the correct way (separate ATTENDEE lines)
+    event.add("attendee", "mailto:alice@example.com", parameters={"CN": "Alice"})
+    event.add("attendee", "mailto:bob@example.com", parameters={"CN": "Bob"})
+    event.add("attendee", "mailto:charlie@example.com", parameters={"CN": "Charlie"})
+    event.add(
+        "organizer", "mailto:organizer@example.com", parameters={"CN": "Organizer"}
+    )
+
+    # Parse the event
+    result = service._parse_event(event)
+
+    assert result is not None
+    attendees = result["attendees"]
+
+    # Should have 4 attendees (3 attendees + 1 organizer)
+    assert len(attendees) == 4
+
+    # Check that all emails are valid email addresses
+    emails = [att["email"] for att in attendees if att.get("email")]
+    expected_emails = [
+        "alice@example.com",
+        "bob@example.com",
+        "charlie@example.com",
+        "organizer@example.com",
+    ]
+
+    for email in emails:
+        assert "@" in email, f"Invalid email format: {email}"
+        assert len(email) > 5, f"Email too short: {email}"
+
+    # Check that we have the expected emails
+    assert "alice@example.com" in emails
+    assert "bob@example.com" in emails
+    assert "charlie@example.com" in emails
+    assert "organizer@example.com" in emails
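The bug these tests target boils down to how a single comma-separated ATTENDEE value is handled: iterating the string yields one character per "attendee", while splitting on commas yields addresses. A stdlib-only illustration (the `mailto:` handling is simplified relative to the real parser):

```python
raw = "mailto:alice@example.com,mailto:bob@example.com"

# Buggy shape: treating the string itself as the attendee list
# yields single characters ("m", "a", "i", ...)
buggy = list(raw)
assert buggy[:3] == ["m", "a", "i"]

# Fixed shape: split the value on commas, then strip the mailto: prefix
fixed = [part.removeprefix("mailto:") for part in raw.split(",")]
print(fixed)  # → ['alice@example.com', 'bob@example.com']
assert all("@" in email for email in fixed)
```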

View File

@@ -11,11 +11,10 @@ from reflector.db.rooms import rooms_controller

 @pytest.mark.asyncio
-async def test_calendar_event_create(db_session):
+async def test_calendar_event_create():
     """Test creating a calendar event."""
     # Create a room first
     room = await rooms_controller.add(
-        db_session,
         name="test-room",
         user_id="test-user",
         zulip_auto_post=False,
@@ -45,7 +44,7 @@ async def test_calendar_event_create(db_session):
     )

     # Save event
-    saved_event = await calendar_events_controller.upsert(db_session, event)
+    saved_event = await calendar_events_controller.upsert(event)

     assert saved_event.ics_uid == "test-event-123"
     assert saved_event.title == "Team Meeting"
@@ -54,11 +53,10 @@

 @pytest.mark.asyncio
-async def test_calendar_event_get_by_room(db_session):
+async def test_calendar_event_get_by_room():
     """Test getting calendar events for a room."""
     # Create room
     room = await rooms_controller.add(
-        db_session,
         name="events-room",
         user_id="test-user",
         zulip_auto_post=False,
@@ -82,10 +80,10 @@ async def test_calendar_event_get_by_room(db_session):
             start_time=now + timedelta(hours=i),
             end_time=now + timedelta(hours=i + 1),
         )
-        await calendar_events_controller.upsert(db_session, event)
+        await calendar_events_controller.upsert(event)

     # Get events for room
-    events = await calendar_events_controller.get_by_room(db_session, room.id)
+    events = await calendar_events_controller.get_by_room(room.id)

     assert len(events) == 3
     assert all(e.room_id == room.id for e in events)
@@ -95,11 +93,10 @@

 @pytest.mark.asyncio
-async def test_calendar_event_get_upcoming(db_session):
+async def test_calendar_event_get_upcoming():
     """Test getting upcoming events within time window."""
     # Create room
     room = await rooms_controller.add(
-        db_session,
         name="upcoming-room",
         user_id="test-user",
         zulip_auto_post=False,
@@ -123,7 +120,7 @@ async def test_calendar_event_get_upcoming(db_session):
         start_time=now - timedelta(hours=2),
         end_time=now - timedelta(hours=1),
     )
-    await calendar_events_controller.upsert(db_session, past_event)
+    await calendar_events_controller.upsert(past_event)

     # Upcoming event within 30 minutes
     upcoming_event = CalendarEvent(
@@ -133,7 +130,7 @@ async def test_calendar_event_get_upcoming(db_session):
         start_time=now + timedelta(minutes=15),
         end_time=now + timedelta(minutes=45),
     )
-    await calendar_events_controller.upsert(db_session, upcoming_event)
+    await calendar_events_controller.upsert(upcoming_event)

     # Currently happening event (started 10 minutes ago, ends in 20 minutes)
     current_event = CalendarEvent(
@@ -143,7 +140,7 @@ async def test_calendar_event_get_upcoming(db_session):
         start_time=now - timedelta(minutes=10),
         end_time=now + timedelta(minutes=20),
     )
-    await calendar_events_controller.upsert(db_session, current_event)
+    await calendar_events_controller.upsert(current_event)

     # Future event beyond 30 minutes
     future_event = CalendarEvent(
@@ -153,10 +150,10 @@ async def test_calendar_event_get_upcoming(db_session):
         start_time=now + timedelta(hours=2),
         end_time=now + timedelta(hours=3),
     )
-    await calendar_events_controller.upsert(db_session, future_event)
+    await calendar_events_controller.upsert(future_event)

     # Get upcoming events (default 120 minutes) - should include current, upcoming, and future
-    upcoming = await calendar_events_controller.get_upcoming(db_session, room.id)
+    upcoming = await calendar_events_controller.get_upcoming(room.id)

     assert len(upcoming) == 3

     # Events should be sorted by start_time (current event first, then upcoming, then future)
@@ -166,7 +163,7 @@ async def test_calendar_event_get_upcoming(db_session):
     # Get upcoming with custom window
     upcoming_extended = await calendar_events_controller.get_upcoming(
-        db_session, room.id, minutes_ahead=180
+        room.id, minutes_ahead=180
     )

     assert len(upcoming_extended) == 3
@@ -177,11 +174,10 @@

 @pytest.mark.asyncio
-async def test_calendar_event_get_upcoming_includes_currently_happening(db_session):
+async def test_calendar_event_get_upcoming_includes_currently_happening():
     """Test that get_upcoming includes currently happening events but excludes ended events."""
     # Create room
     room = await rooms_controller.add(
-        db_session,
         name="current-happening-room",
         user_id="test-user",
         zulip_auto_post=False,
@@ -204,7 +200,7 @@ async def test_calendar_event_get_upcoming_includes_currently_happening(db_sessi
         start_time=now - timedelta(hours=2),
         end_time=now - timedelta(minutes=30),
     )
-    await calendar_events_controller.upsert(db_session, past_ended_event)
+    await calendar_events_controller.upsert(past_ended_event)

     # Event currently happening (started 10 minutes ago, ends in 20 minutes) - SHOULD be included
     currently_happening_event = CalendarEvent(
@@ -214,7 +210,7 @@ async def test_calendar_event_get_upcoming_includes_currently_happening(db_sessi
         start_time=now - timedelta(minutes=10),
         end_time=now + timedelta(minutes=20),
     )
-    await calendar_events_controller.upsert(db_session, currently_happening_event)
+    await calendar_events_controller.upsert(currently_happening_event)

     # Event starting soon (in 5 minutes) - SHOULD be included
     upcoming_soon_event = CalendarEvent(
@@ -224,12 +220,10 @@ async def test_calendar_event_get_upcoming_includes_currently_happening(db_sessi
         start_time=now + timedelta(minutes=5),
         end_time=now + timedelta(minutes=35),
     )
-    await calendar_events_controller.upsert(db_session, upcoming_soon_event)
+    await calendar_events_controller.upsert(upcoming_soon_event)

     # Get upcoming events
-    upcoming = await calendar_events_controller.get_upcoming(
-        db_session, room.id, minutes_ahead=30
-    )
+    upcoming = await calendar_events_controller.get_upcoming(room.id, minutes_ahead=30)

     # Should only include currently happening and upcoming soon events
     assert len(upcoming) == 2
@@ -238,11 +232,10 @@
 @pytest.mark.asyncio
async def test_calendar_event_upsert(db_session): async def test_calendar_event_upsert():
"""Test upserting (create/update) calendar events.""" """Test upserting (create/update) calendar events."""
# Create room # Create room
room = await rooms_controller.add( room = await rooms_controller.add(
db_session,
name="upsert-room", name="upsert-room",
user_id="test-user", user_id="test-user",
zulip_auto_post=False, zulip_auto_post=False,
@@ -266,30 +259,29 @@ async def test_calendar_event_upsert(db_session):
end_time=now + timedelta(hours=1), end_time=now + timedelta(hours=1),
) )
created = await calendar_events_controller.upsert(db_session, event) created = await calendar_events_controller.upsert(event)
assert created.title == "Original Title" assert created.title == "Original Title"
# Update existing event # Update existing event
event.title = "Updated Title" event.title = "Updated Title"
event.description = "Added description" event.description = "Added description"
updated = await calendar_events_controller.upsert(db_session, event) updated = await calendar_events_controller.upsert(event)
assert updated.title == "Updated Title" assert updated.title == "Updated Title"
assert updated.description == "Added description" assert updated.description == "Added description"
assert updated.ics_uid == "upsert-test" assert updated.ics_uid == "upsert-test"
# Verify only one event exists # Verify only one event exists
events = await calendar_events_controller.get_by_room(db_session, room.id) events = await calendar_events_controller.get_by_room(room.id)
assert len(events) == 1 assert len(events) == 1
assert events[0].title == "Updated Title" assert events[0].title == "Updated Title"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_calendar_event_soft_delete(db_session): async def test_calendar_event_soft_delete():
"""Test soft deleting events no longer in calendar.""" """Test soft deleting events no longer in calendar."""
# Create room # Create room
room = await rooms_controller.add( room = await rooms_controller.add(
db_session,
name="delete-room", name="delete-room",
user_id="test-user", user_id="test-user",
zulip_auto_post=False, zulip_auto_post=False,
@@ -313,36 +305,35 @@ async def test_calendar_event_soft_delete(db_session):
start_time=now + timedelta(hours=i), start_time=now + timedelta(hours=i),
end_time=now + timedelta(hours=i + 1), end_time=now + timedelta(hours=i + 1),
) )
await calendar_events_controller.upsert(db_session, event) await calendar_events_controller.upsert(event)
# Soft delete events not in current list # Soft delete events not in current list
current_ids = ["event-0", "event-2"] # Keep events 0 and 2 current_ids = ["event-0", "event-2"] # Keep events 0 and 2
deleted_count = await calendar_events_controller.soft_delete_missing( deleted_count = await calendar_events_controller.soft_delete_missing(
db_session, room.id, current_ids room.id, current_ids
) )
assert deleted_count == 2 # Should delete events 1 and 3 assert deleted_count == 2 # Should delete events 1 and 3
# Get non-deleted events # Get non-deleted events
events = await calendar_events_controller.get_by_room( events = await calendar_events_controller.get_by_room(
db_session, room.id, include_deleted=False room.id, include_deleted=False
) )
assert len(events) == 2 assert len(events) == 2
assert {e.ics_uid for e in events} == {"event-0", "event-2"} assert {e.ics_uid for e in events} == {"event-0", "event-2"}
# Get all events including deleted # Get all events including deleted
all_events = await calendar_events_controller.get_by_room( all_events = await calendar_events_controller.get_by_room(
db_session, room.id, include_deleted=True room.id, include_deleted=True
) )
assert len(all_events) == 4 assert len(all_events) == 4
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_calendar_event_past_events_not_deleted(db_session): async def test_calendar_event_past_events_not_deleted():
"""Test that past events are not soft deleted.""" """Test that past events are not soft deleted."""
# Create room # Create room
room = await rooms_controller.add( room = await rooms_controller.add(
db_session,
name="past-events-room", name="past-events-room",
user_id="test-user", user_id="test-user",
zulip_auto_post=False, zulip_auto_post=False,
@@ -365,7 +356,7 @@ async def test_calendar_event_past_events_not_deleted(db_session):
start_time=now - timedelta(hours=2), start_time=now - timedelta(hours=2),
end_time=now - timedelta(hours=1), end_time=now - timedelta(hours=1),
) )
await calendar_events_controller.upsert(db_session, past_event) await calendar_events_controller.upsert(past_event)
# Create future event # Create future event
future_event = CalendarEvent( future_event = CalendarEvent(
@@ -375,29 +366,26 @@ async def test_calendar_event_past_events_not_deleted(db_session):
start_time=now + timedelta(hours=1), start_time=now + timedelta(hours=1),
end_time=now + timedelta(hours=2), end_time=now + timedelta(hours=2),
) )
await calendar_events_controller.upsert(db_session, future_event) await calendar_events_controller.upsert(future_event)
# Try to soft delete all events (only future should be deleted) # Try to soft delete all events (only future should be deleted)
deleted_count = await calendar_events_controller.soft_delete_missing( deleted_count = await calendar_events_controller.soft_delete_missing(room.id, [])
db_session, room.id, []
)
assert deleted_count == 1 # Only future event deleted assert deleted_count == 1 # Only future event deleted
# Verify past event still exists # Verify past event still exists
events = await calendar_events_controller.get_by_room( events = await calendar_events_controller.get_by_room(
db_session, room.id, include_deleted=False room.id, include_deleted=False
) )
assert len(events) == 1 assert len(events) == 1
assert events[0].ics_uid == "past-event" assert events[0].ics_uid == "past-event"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_calendar_event_with_raw_ics_data(db_session): async def test_calendar_event_with_raw_ics_data():
"""Test storing raw ICS data with calendar event.""" """Test storing raw ICS data with calendar event."""
# Create room # Create room
room = await rooms_controller.add( room = await rooms_controller.add(
db_session,
name="raw-ics-room", name="raw-ics-room",
user_id="test-user", user_id="test-user",
zulip_auto_post=False, zulip_auto_post=False,
@@ -426,13 +414,11 @@ END:VEVENT"""
ics_raw_data=raw_ics, ics_raw_data=raw_ics,
) )
saved = await calendar_events_controller.upsert(db_session, event) saved = await calendar_events_controller.upsert(event)
assert saved.ics_raw_data == raw_ics assert saved.ics_raw_data == raw_ics
# Retrieve and verify # Retrieve and verify
retrieved = await calendar_events_controller.get_by_ics_uid( retrieved = await calendar_events_controller.get_by_ics_uid(room.id, "test-raw-123")
db_session, room.id, "test-raw-123"
)
assert retrieved is not None assert retrieved is not None
assert retrieved.ics_raw_data == raw_ics assert retrieved.ics_raw_data == raw_ics

View File

@@ -2,32 +2,26 @@ from datetime import datetime, timedelta, timezone
 from unittest.mock import AsyncMock, patch

 import pytest
-from sqlalchemy import delete, insert, select, update

-from reflector.db.base import (
-    MeetingConsentModel,
-    MeetingModel,
-    RecordingModel,
-    TranscriptModel,
-)
+from reflector.db.recordings import Recording, recordings_controller
 from reflector.db.transcripts import SourceKind, transcripts_controller
 from reflector.worker.cleanup import cleanup_old_public_data


 @pytest.mark.asyncio
-async def test_cleanup_old_public_data_skips_when_not_public(db_session):
+async def test_cleanup_old_public_data_skips_when_not_public():
     """Test that cleanup is skipped when PUBLIC_MODE is False."""
     with patch("reflector.worker.cleanup.settings") as mock_settings:
         mock_settings.PUBLIC_MODE = False
-        result = await cleanup_old_public_data(db_session)
+        result = await cleanup_old_public_data()

         # Should return early without doing anything
         assert result is None


 @pytest.mark.asyncio
-async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(db_session):
+async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts():
     """Test that old anonymous transcripts are deleted."""
     # Create old and new anonymous transcripts
     old_date = datetime.now(timezone.utc) - timedelta(days=8)
@@ -35,23 +29,22 @@ async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(db_sess
     # Create old anonymous transcript (should be deleted)
     old_transcript = await transcripts_controller.add(
-        db_session,
         name="Old Anonymous Transcript",
         source_kind=SourceKind.FILE,
         user_id=None,  # Anonymous
     )

     # Manually update created_at to be old
-    await db_session.execute(
-        update(TranscriptModel)
-        .where(TranscriptModel.id == old_transcript.id)
+    from reflector.db import get_database
+    from reflector.db.transcripts import transcripts
+
+    await get_database().execute(
+        transcripts.update()
+        .where(transcripts.c.id == old_transcript.id)
         .values(created_at=old_date)
     )
-    await db_session.commit()

     # Create new anonymous transcript (should NOT be deleted)
     new_transcript = await transcripts_controller.add(
-        db_session,
         name="New Anonymous Transcript",
         source_kind=SourceKind.FILE,
         user_id=None,  # Anonymous
@@ -59,265 +52,230 @@ async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(db_sess
     # Create old transcript with user (should NOT be deleted)
     old_user_transcript = await transcripts_controller.add(
-        db_session,
         name="Old User Transcript",
         source_kind=SourceKind.FILE,
-        user_id="user-123",
+        user_id="user123",
     )
-    await db_session.execute(
-        update(TranscriptModel)
-        .where(TranscriptModel.id == old_user_transcript.id)
+    await get_database().execute(
+        transcripts.update()
+        .where(transcripts.c.id == old_user_transcript.id)
         .values(created_at=old_date)
     )
-    await db_session.commit()

-    # Mock settings for public mode
     with patch("reflector.worker.cleanup.settings") as mock_settings:
         mock_settings.PUBLIC_MODE = True
         mock_settings.PUBLIC_DATA_RETENTION_DAYS = 7

-        # Mock delete_single_transcript to track what gets deleted
-        with patch("reflector.worker.cleanup.delete_single_transcript") as mock_delete:
-            mock_delete.return_value = None
-            # Run cleanup with test session
-            await cleanup_old_public_data(db_session)
-            # Verify only old anonymous transcript was deleted
-            assert mock_delete.call_count == 1
-            # The function is called with session_factory, transcript_data dict, and stats dict
-            call_args = mock_delete.call_args[0]
-            transcript_data = call_args[1]
-            assert transcript_data["id"] == old_transcript.id
+        # Mock the storage deletion
+        with patch("reflector.db.transcripts.get_transcripts_storage") as mock_storage:
+            mock_storage.return_value.delete_file = AsyncMock()
+            result = await cleanup_old_public_data()
+
+            # Check results
+            assert result["transcripts_deleted"] == 1
+            assert result["errors"] == []
+
+            # Verify old anonymous transcript was deleted
+            assert await transcripts_controller.get_by_id(old_transcript.id) is None
+            # Verify new anonymous transcript still exists
+            assert await transcripts_controller.get_by_id(new_transcript.id) is not None
+            # Verify user transcript still exists
+            assert await transcripts_controller.get_by_id(old_user_transcript.id) is not None


 @pytest.mark.asyncio
-async def test_cleanup_deletes_associated_meeting_and_recording(db_session):
-    """Test that cleanup deletes associated meetings and recordings."""
+async def test_cleanup_deletes_associated_meeting_and_recording():
+    """Test that meetings and recordings associated with old transcripts are deleted."""
+    from reflector.db import get_database
+    from reflector.db.meetings import meetings
+    from reflector.db.transcripts import transcripts

     old_date = datetime.now(timezone.utc) - timedelta(days=8)

+    # Create a meeting
+    meeting_id = "test-meeting-for-transcript"
+    await get_database().execute(
+        meetings.insert().values(
+            id=meeting_id,
+            room_name="Meeting with Transcript",
+            room_url="https://example.com/meeting",
+            host_room_url="https://example.com/meeting-host",
+            start_date=old_date,
+            end_date=old_date + timedelta(hours=1),
+            room_id=None,
+        )
+    )
+
+    # Create a recording
+    recording = await recordings_controller.create(
+        Recording(
+            bucket_name="test-bucket",
+            object_key="test-recording.mp4",
+            recorded_at=old_date,
+        )
+    )

     # Create an old transcript with both meeting and recording
     old_transcript = await transcripts_controller.add(
-        db_session,
         name="Old Transcript with Meeting and Recording",
-        source_kind=SourceKind.FILE,
+        source_kind=SourceKind.ROOM,
         user_id=None,
+        meeting_id=meeting_id,
+        recording_id=recording.id,
     )
-    await db_session.execute(
-        update(TranscriptModel)
-        .where(TranscriptModel.id == old_transcript.id)
+
+    # Update created_at to be old
+    await get_database().execute(
+        transcripts.update()
+        .where(transcripts.c.id == old_transcript.id)
         .values(created_at=old_date)
     )
-    await db_session.commit()

-    # Create associated meeting directly
-    meeting_id = "test-meeting-id"
-    await db_session.execute(
-        insert(MeetingModel).values(
-            id=meeting_id,
-            room_id=None,
-            room_name="test-room",
-            room_url="https://example.com/room",
-            host_room_url="https://example.com/room-host",
-            start_date=old_date,
-            end_date=old_date + timedelta(hours=1),
-            is_active=False,
-            num_clients=0,
-            is_locked=False,
-            room_mode="normal",
-            recording_type="cloud",
-            recording_trigger="automatic",
-        )
-    )
-    # Create associated recording directly
-    recording_id = "test-recording-id"
-    await db_session.execute(
-        insert(RecordingModel).values(
-            id=recording_id,
-            meeting_id=meeting_id,
-            url="https://example.com/recording.mp4",
-            object_key="recordings/test.mp4",
-            duration=3600.0,
-            created_at=old_date,
-        )
-    )
-    await db_session.commit()
-    # Update transcript with meeting_id and recording_id
-    await db_session.execute(
-        update(TranscriptModel)
-        .where(TranscriptModel.id == old_transcript.id)
-        .values(meeting_id=meeting_id, recording_id=recording_id)
-    )
-    await db_session.commit()

-    # Mock settings
     with patch("reflector.worker.cleanup.settings") as mock_settings:
         mock_settings.PUBLIC_MODE = True
         mock_settings.PUBLIC_DATA_RETENTION_DAYS = 7

         # Mock storage deletion
-        with patch("reflector.worker.cleanup.get_recordings_storage") as mock_storage:
+        with patch("reflector.worker.cleanup.get_transcripts_storage") as mock_storage:
             mock_storage.return_value.delete_file = AsyncMock()
-            # Run cleanup with test session
-            await cleanup_old_public_data(db_session)
+            result = await cleanup_old_public_data()

-            # Verify transcript was deleted
-            result = await db_session.execute(
-                select(TranscriptModel).where(TranscriptModel.id == old_transcript.id)
-            )
-            transcript = result.scalar_one_or_none()
-            assert transcript is None
+            # Check results
+            assert result["transcripts_deleted"] == 1
+            assert result["meetings_deleted"] == 1
+            assert result["recordings_deleted"] == 1
+            assert result["errors"] == []

-            # Verify meeting was deleted
-            result = await db_session.execute(
-                select(MeetingModel).where(MeetingModel.id == meeting_id)
-            )
-            meeting = result.scalar_one_or_none()
-            assert meeting is None
+            # Verify transcript was deleted
+            assert await transcripts_controller.get_by_id(old_transcript.id) is None

-            # Verify recording was deleted
-            result = await db_session.execute(
-                select(RecordingModel).where(RecordingModel.id == recording_id)
-            )
-            recording = result.scalar_one_or_none()
-            assert recording is None
+            # Verify meeting was deleted
+            query = meetings.select().where(meetings.c.id == meeting_id)
+            meeting_result = await get_database().fetch_one(query)
+            assert meeting_result is None
+
+            # Verify recording was deleted
+            assert await recordings_controller.get_by_id(recording.id) is None


 @pytest.mark.asyncio
-async def test_cleanup_handles_errors_gracefully(db_session):
-    """Test that cleanup continues even if individual deletions fail."""
+async def test_cleanup_handles_errors_gracefully():
+    """Test that cleanup continues even when individual deletions fail."""
     old_date = datetime.now(timezone.utc) - timedelta(days=8)

     # Create multiple old transcripts
     transcript1 = await transcripts_controller.add(
-        db_session,
         name="Transcript 1",
         source_kind=SourceKind.FILE,
         user_id=None,
     )
-    await db_session.execute(
-        update(TranscriptModel)
-        .where(TranscriptModel.id == transcript1.id)
-        .values(created_at=old_date)
-    )
     transcript2 = await transcripts_controller.add(
-        db_session,
         name="Transcript 2",
         source_kind=SourceKind.FILE,
         user_id=None,
     )
-    await db_session.execute(
-        update(TranscriptModel)
-        .where(TranscriptModel.id == transcript2.id)
-        .values(created_at=old_date)
-    )
-    await db_session.commit()
+
+    # Update created_at to be old
+    from reflector.db import get_database
+    from reflector.db.transcripts import transcripts
+
+    for t_id in [transcript1.id, transcript2.id]:
+        await get_database().execute(
+            transcripts.update()
+            .where(transcripts.c.id == t_id)
+            .values(created_at=old_date)
+        )

     with patch("reflector.worker.cleanup.settings") as mock_settings:
         mock_settings.PUBLIC_MODE = True
         mock_settings.PUBLIC_DATA_RETENTION_DAYS = 7

-        # Mock delete_single_transcript to fail on first call but succeed on second
-        with patch("reflector.worker.cleanup.delete_single_transcript") as mock_delete:
-            mock_delete.side_effect = [Exception("Delete failed"), None]
-            # Run cleanup with test session - should not raise exception
-            await cleanup_old_public_data(db_session)
+        # Mock remove_by_id to fail for the first transcript
+        original_remove = transcripts_controller.remove_by_id
+        call_count = 0
+
+        async def mock_remove_by_id(transcript_id, user_id=None):
+            nonlocal call_count
+            call_count += 1
+            if call_count == 1:
+                raise Exception("Simulated deletion error")
+            return await original_remove(transcript_id, user_id)

-            # Both transcripts should have been attempted to delete
-            assert mock_delete.call_count == 2
+        with patch.object(
+            transcripts_controller, "remove_by_id", side_effect=mock_remove_by_id
+        ):
+            result = await cleanup_old_public_data()
+
+            # Should have one successful deletion and one error
+            assert result["transcripts_deleted"] == 1
+            assert len(result["errors"]) == 1
+            assert "Failed to delete transcript" in result["errors"][0]


 @pytest.mark.asyncio
-async def test_meeting_consent_cascade_delete(db_session):
-    """Test that meeting_consent entries are cascade deleted with meetings."""
-    old_date = datetime.now(timezone.utc) - timedelta(days=8)
-
-    # Create an old transcript
-    transcript = await transcripts_controller.add(
-        db_session,
-        name="Transcript with Meeting",
-        source_kind=SourceKind.FILE,
-        user_id=None,
-    )
-    await db_session.execute(
-        update(TranscriptModel)
-        .where(TranscriptModel.id == transcript.id)
-        .values(created_at=old_date)
-    )
-    await db_session.commit()
+async def test_meeting_consent_cascade_delete():
+    """Test that meeting_consent records are automatically deleted when meeting is deleted."""
+    from reflector.db import get_database
+    from reflector.db.meetings import (
+        meeting_consent,
+        meeting_consent_controller,
+        meetings,
+    )

-    # Create a meeting directly
-    meeting_id = "test-meeting-consent"
-    await db_session.execute(
-        insert(MeetingModel).values(
+    # Create a meeting
+    meeting_id = "test-cascade-meeting"
+    await get_database().execute(
+        meetings.insert().values(
             id=meeting_id,
+            room_name="Test Meeting for CASCADE",
+            room_url="https://example.com/cascade-test",
+            host_room_url="https://example.com/cascade-test-host",
+            start_date=datetime.now(timezone.utc),
+            end_date=datetime.now(timezone.utc) + timedelta(hours=1),
             room_id=None,
-            room_name="test-room",
-            room_url="https://example.com/room",
-            host_room_url="https://example.com/room-host",
-            start_date=old_date,
-            end_date=old_date + timedelta(hours=1),
-            is_active=False,
-            num_clients=0,
-            is_locked=False,
-            room_mode="normal",
-            recording_type="cloud",
-            recording_trigger="automatic",
         )
     )
-    await db_session.commit()

-    # Update transcript with meeting_id
-    await db_session.execute(
-        update(TranscriptModel)
-        .where(TranscriptModel.id == transcript.id)
-        .values(meeting_id=meeting_id)
-    )
-    await db_session.commit()
+    # Create consent records for this meeting
+    consent1_id = "consent-1"
+    consent2_id = "consent-2"

-    # Create meeting_consent entries
-    await db_session.execute(
-        insert(MeetingConsentModel).values(
-            id="consent-1",
+    await get_database().execute(
+        meeting_consent.insert().values(
+            id=consent1_id,
             meeting_id=meeting_id,
-            user_id="user-1",
+            user_id="user1",
             consent_given=True,
-            consent_timestamp=old_date,
+            consent_timestamp=datetime.now(timezone.utc),
         )
     )
-    await db_session.execute(
-        insert(MeetingConsentModel).values(
-            id="consent-2",
-            meeting_id=meeting_id,
-            user_id="user-2",
-            consent_given=True,
-            consent_timestamp=old_date,
-        )
-    )
-    await db_session.commit()
+    await get_database().execute(
+        meeting_consent.insert().values(
+            id=consent2_id,
+            meeting_id=meeting_id,
+            user_id="user2",
+            consent_given=False,
+            consent_timestamp=datetime.now(timezone.utc),
+        )
     )

-    # Verify consent entries exist
-    result = await db_session.execute(
-        select(MeetingConsentModel).where(MeetingConsentModel.meeting_id == meeting_id)
-    )
-    consents = result.scalars().all()
+    # Verify consent records exist
+    consents = await meeting_consent_controller.get_by_meeting_id(meeting_id)
     assert len(consents) == 2

-    # Delete the transcript and meeting
-    await db_session.execute(
-        delete(TranscriptModel).where(TranscriptModel.id == transcript.id)
-    )
-    await db_session.execute(delete(MeetingModel).where(MeetingModel.id == meeting_id))
-    await db_session.commit()
+    # Delete the meeting
+    await get_database().execute(meetings.delete().where(meetings.c.id == meeting_id))

-    # Verify consent entries were cascade deleted
-    result = await db_session.execute(
-        select(MeetingConsentModel).where(MeetingConsentModel.meeting_id == meeting_id)
-    )
-    consents = result.scalars().all()
-    assert len(consents) == 0
+    # Verify meeting is deleted
+    query = meetings.select().where(meetings.c.id == meeting_id)
+    result = await get_database().fetch_one(query)
+    assert result is None
+
+    # Verify consent records are automatically deleted (CASCADE DELETE)
+    consents_after = await meeting_consent_controller.get_by_meeting_id(meeting_id)
+    assert len(consents_after) == 0

View File

@@ -0,0 +1,330 @@
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from reflector.db.meetings import (
MeetingConsent,
meeting_consent_controller,
meetings_controller,
)
from reflector.db.recordings import Recording, recordings_controller
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import SourceKind, transcripts_controller
from reflector.pipelines.main_live_pipeline import cleanup_consent
@pytest.mark.asyncio
async def test_consent_cleanup_deletes_multitrack_files():
room = await rooms_controller.add(
name="Test Room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic",
is_shared=False,
platform="daily",
)
# Create meeting
meeting = await meetings_controller.create(
id="test-multitrack-meeting",
room_name="test-room-20250101120000",
room_url="https://test.daily.co/test-room",
host_room_url="https://test.daily.co/test-room",
start_date=datetime.now(timezone.utc),
end_date=datetime.now(timezone.utc),
room=room,
)
track_keys = [
"recordings/test-room-20250101120000/track-0.webm",
"recordings/test-room-20250101120000/track-1.webm",
"recordings/test-room-20250101120000/track-2.webm",
]
recording = await recordings_controller.create(
Recording(
bucket_name="test-bucket",
object_key="recordings/test-room-20250101120000", # Folder path
recorded_at=datetime.now(timezone.utc),
meeting_id=meeting.id,
track_keys=track_keys,
)
)
# Create transcript
transcript = await transcripts_controller.add(
name="Test Multitrack Transcript",
source_kind=SourceKind.ROOM,
recording_id=recording.id,
meeting_id=meeting.id,
)
# Add consent denial
await meeting_consent_controller.upsert(
MeetingConsent(
meeting_id=meeting.id,
user_id="test-user",
consent_given=False,
consent_timestamp=datetime.now(timezone.utc),
)
)
# Mock get_transcripts_storage (master credentials with bucket override)
with patch(
"reflector.pipelines.main_live_pipeline.get_transcripts_storage"
) as mock_get_transcripts_storage:
mock_master_storage = MagicMock()
mock_master_storage.delete_file = AsyncMock()
mock_get_transcripts_storage.return_value = mock_master_storage
await cleanup_consent(transcript_id=transcript.id)
# Verify master storage was used with bucket override for all track keys
assert mock_master_storage.delete_file.call_count == 3
deleted_keys = []
for call_args in mock_master_storage.delete_file.call_args_list:
key = call_args[0][0]
bucket_kwarg = call_args[1].get("bucket")
deleted_keys.append(key)
assert bucket_kwarg == "test-bucket" # Verify bucket override!
assert set(deleted_keys) == set(track_keys)
updated_transcript = await transcripts_controller.get_by_id(transcript.id)
assert updated_transcript.audio_deleted is True
@pytest.mark.asyncio
async def test_consent_cleanup_handles_missing_track_keys():
room = await rooms_controller.add(
name="Test Room 2",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic",
is_shared=False,
platform="daily",
)
# Create meeting
meeting = await meetings_controller.create(
id="test-multitrack-meeting-2",
room_name="test-room-20250101120001",
room_url="https://test.daily.co/test-room-2",
host_room_url="https://test.daily.co/test-room-2",
start_date=datetime.now(timezone.utc),
end_date=datetime.now(timezone.utc),
room=room,
)
recording = await recordings_controller.create(
Recording(
bucket_name="test-bucket",
object_key="recordings/old-style-recording.mp4",
recorded_at=datetime.now(timezone.utc),
meeting_id=meeting.id,
track_keys=None,
)
)
transcript = await transcripts_controller.add(
name="Test Old-Style Transcript",
source_kind=SourceKind.ROOM,
recording_id=recording.id,
meeting_id=meeting.id,
)
# Add consent denial
await meeting_consent_controller.upsert(
MeetingConsent(
meeting_id=meeting.id,
user_id="test-user-2",
consent_given=False,
consent_timestamp=datetime.now(timezone.utc),
)
)
# Mock get_transcripts_storage (master credentials with bucket override)
with patch(
"reflector.pipelines.main_live_pipeline.get_transcripts_storage"
) as mock_get_transcripts_storage:
mock_master_storage = MagicMock()
mock_master_storage.delete_file = AsyncMock()
mock_get_transcripts_storage.return_value = mock_master_storage
await cleanup_consent(transcript_id=transcript.id)
# Verify master storage was used with bucket override
assert mock_master_storage.delete_file.call_count == 1
call_args = mock_master_storage.delete_file.call_args
assert call_args[0][0] == recording.object_key
assert call_args[1].get("bucket") == "test-bucket" # Verify bucket override!
@pytest.mark.asyncio
async def test_consent_cleanup_empty_track_keys_falls_back():
    room = await rooms_controller.add(
        name="Test Room 3",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic",
        is_shared=False,
        platform="daily",
    )

    # Create meeting
    meeting = await meetings_controller.create(
        id="test-multitrack-meeting-3",
        room_name="test-room-20250101120002",
        room_url="https://test.daily.co/test-room-3",
        host_room_url="https://test.daily.co/test-room-3",
        start_date=datetime.now(timezone.utc),
        end_date=datetime.now(timezone.utc),
        room=room,
    )

    recording = await recordings_controller.create(
        Recording(
            bucket_name="test-bucket",
            object_key="recordings/fallback-recording.mp4",
            recorded_at=datetime.now(timezone.utc),
            meeting_id=meeting.id,
            track_keys=[],
        )
    )

    transcript = await transcripts_controller.add(
        name="Test Empty Track Keys Transcript",
        source_kind=SourceKind.ROOM,
        recording_id=recording.id,
        meeting_id=meeting.id,
    )

    # Add consent denial
    await meeting_consent_controller.upsert(
        MeetingConsent(
            meeting_id=meeting.id,
            user_id="test-user-3",
            consent_given=False,
            consent_timestamp=datetime.now(timezone.utc),
        )
    )

    # Mock get_transcripts_storage (master credentials with bucket override)
    with patch(
        "reflector.pipelines.main_live_pipeline.get_transcripts_storage"
    ) as mock_get_transcripts_storage:
        mock_master_storage = MagicMock()
        mock_master_storage.delete_file = AsyncMock()
        mock_get_transcripts_storage.return_value = mock_master_storage

        # Run cleanup
        await cleanup_consent(transcript_id=transcript.id)

        # Verify master storage was used with bucket override
        assert mock_master_storage.delete_file.call_count == 1
        call_args = mock_master_storage.delete_file.call_args
        assert call_args[0][0] == recording.object_key
        assert call_args[1].get("bucket") == "test-bucket"  # Verify bucket override!
@pytest.mark.asyncio
async def test_consent_cleanup_partial_failure_doesnt_mark_deleted():
    room = await rooms_controller.add(
        name="Test Room 4",
        user_id="test-user",
        zulip_auto_post=False,
        zulip_stream="",
        zulip_topic="",
        is_locked=False,
        room_mode="normal",
        recording_type="cloud",
        recording_trigger="automatic",
        is_shared=False,
        platform="daily",
    )

    # Create meeting
    meeting = await meetings_controller.create(
        id="test-multitrack-meeting-4",
        room_name="test-room-20250101120003",
        room_url="https://test.daily.co/test-room-4",
        host_room_url="https://test.daily.co/test-room-4",
        start_date=datetime.now(timezone.utc),
        end_date=datetime.now(timezone.utc),
        room=room,
    )

    track_keys = [
        "recordings/test-room-20250101120003/track-0.webm",
        "recordings/test-room-20250101120003/track-1.webm",
        "recordings/test-room-20250101120003/track-2.webm",
    ]
    recording = await recordings_controller.create(
        Recording(
            bucket_name="test-bucket",
            object_key="recordings/test-room-20250101120003",
            recorded_at=datetime.now(timezone.utc),
            meeting_id=meeting.id,
            track_keys=track_keys,
        )
    )

    # Create transcript
    transcript = await transcripts_controller.add(
        name="Test Partial Failure Transcript",
        source_kind=SourceKind.ROOM,
        recording_id=recording.id,
        meeting_id=meeting.id,
    )

    # Add consent denial
    await meeting_consent_controller.upsert(
        MeetingConsent(
            meeting_id=meeting.id,
            user_id="test-user-4",
            consent_given=False,
            consent_timestamp=datetime.now(timezone.utc),
        )
    )

    # Mock get_transcripts_storage (master credentials with bucket override) with partial failure
    with patch(
        "reflector.pipelines.main_live_pipeline.get_transcripts_storage"
    ) as mock_get_transcripts_storage:
        mock_master_storage = MagicMock()
        call_count = 0

        async def delete_side_effect(key, bucket=None):
            nonlocal call_count
            call_count += 1
            if call_count == 2:
                raise Exception("S3 deletion failed")

        mock_master_storage.delete_file = AsyncMock(side_effect=delete_side_effect)
        mock_get_transcripts_storage.return_value = mock_master_storage

        await cleanup_consent(transcript_id=transcript.id)

        # Verify master storage was called with bucket override
        assert mock_master_storage.delete_file.call_count == 3

    updated_transcript = await transcripts_controller.get_by_id(transcript.id)
    assert (
        updated_transcript.audio_deleted is None
        or updated_transcript.audio_deleted is False
    )
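The partial-failure behaviour these tests pin down can be sketched as a deletion loop that attempts every track, falls back to the single `object_key` when `track_keys` is empty, and reports success only if every delete succeeded, so the caller knows not to mark `audio_deleted` early. This is a sketch with assumed names (`delete_tracks`, `FakeStorage`), not the actual `cleanup_consent` implementation:

```python
import asyncio


async def delete_tracks(storage, track_keys, object_key, bucket):
    """Delete each per-participant track, falling back to object_key when
    track_keys is empty. Returns True only if every delete succeeded."""
    keys = track_keys or [object_key]
    ok = True
    for key in keys:
        try:
            await storage.delete_file(key, bucket=bucket)
        except Exception:
            ok = False  # keep going: delete as much as possible
    return ok


class FakeStorage:
    """Records delete calls; raises on one configured key."""

    def __init__(self, fail_on=None):
        self.calls = []
        self.fail_on = fail_on

    async def delete_file(self, key, bucket=None):
        self.calls.append((key, bucket))
        if key == self.fail_on:
            raise Exception("S3 deletion failed")


storage = FakeStorage(fail_on="track-1")
ok = asyncio.run(
    delete_tracks(storage, ["track-0", "track-1", "track-2"], "obj", "test-bucket")
)
# All three tracks are attempted even though track-1 fails, but ok is False,
# mirroring test_consent_cleanup_partial_failure_doesnt_mark_deleted.
```

The empty-`track_keys` case then degenerates to a single delete of `object_key`, matching the fallback test above.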

View File

@@ -4,8 +4,9 @@ from unittest.mock import AsyncMock, MagicMock, patch
 import pytest
 from icalendar import Calendar, Event
+from reflector.db import get_database
 from reflector.db.calendar_events import calendar_events_controller
-from reflector.db.rooms import rooms_controller
+from reflector.db.rooms import rooms, rooms_controller
 from reflector.services.ics_sync import ics_sync_service
 from reflector.worker.ics_sync import (
     _should_sync,
@@ -14,9 +15,8 @@ from reflector.worker.ics_sync import (
 @pytest.mark.asyncio
-async def test_sync_room_ics_task(db_session):
+async def test_sync_room_ics_task():
     room = await rooms_controller.add(
-        db_session,
         name="task-test-room",
         user_id="test-user",
         zulip_auto_post=False,
@@ -30,7 +30,6 @@ async def test_sync_room_ics_task(db_session):
         ics_url="https://calendar.example.com/task.ics",
         ics_enabled=True,
     )
-    await db_session.flush()
 
     cal = Calendar()
     event = Event()
@@ -46,22 +45,21 @@ async def test_sync_room_ics_task(db_session):
     ics_content = cal.to_ical().decode("utf-8")
 
     with patch(
-        "reflector.services.ics_sync.ICSFetchService.fetch_ics",
-        new_callable=AsyncMock,
+        "reflector.services.ics_sync.ICSFetchService.fetch_ics", new_callable=AsyncMock
     ) as mock_fetch:
         mock_fetch.return_value = ics_content
 
-        await ics_sync_service.sync_room_calendar(db_session, room)
+        # Call the service directly instead of the Celery task to avoid event loop issues
+        await ics_sync_service.sync_room_calendar(room)
 
-        events = await calendar_events_controller.get_by_room(db_session, room.id)
+        events = await calendar_events_controller.get_by_room(room.id)
         assert len(events) == 1
         assert events[0].ics_uid == "task-event-1"
 
 @pytest.mark.asyncio
-async def test_sync_room_ics_disabled(db_session):
+async def test_sync_room_ics_disabled():
     room = await rooms_controller.add(
-        db_session,
         name="disabled-room",
         user_id="test-user",
         zulip_auto_post=False,
@@ -75,16 +73,16 @@ async def test_sync_room_ics_disabled(db_session):
         ics_enabled=False,
     )
 
-    result = await ics_sync_service.sync_room_calendar(db_session, room)
+    # Test that disabled rooms are skipped by the service
+    result = await ics_sync_service.sync_room_calendar(room)
 
-    events = await calendar_events_controller.get_by_room(db_session, room.id)
+    events = await calendar_events_controller.get_by_room(room.id)
     assert len(events) == 0
 
 @pytest.mark.asyncio
-async def test_sync_all_ics_calendars(db_session):
+async def test_sync_all_ics_calendars():
     room1 = await rooms_controller.add(
-        db_session,
         name="sync-all-1",
         user_id="test-user",
         zulip_auto_post=False,
@@ -100,7 +98,6 @@ async def test_sync_all_ics_calendars(db_session):
     )
     room2 = await rooms_controller.add(
-        db_session,
         name="sync-all-2",
         user_id="test-user",
         zulip_auto_post=False,
@@ -116,7 +113,6 @@ async def test_sync_all_ics_calendars(db_session):
     )
     room3 = await rooms_controller.add(
-        db_session,
         name="sync-all-3",
         user_id="test-user",
         zulip_auto_post=False,
@@ -131,11 +127,17 @@ async def test_sync_all_ics_calendars(db_session):
     )
 
     with patch("reflector.worker.ics_sync.sync_room_ics.delay") as mock_delay:
-        ics_enabled_rooms = await rooms_controller.get_ics_enabled(db_session)
+        # Directly call the sync_all logic without the Celery wrapper
+        query = rooms.select().where(
+            rooms.c.ics_enabled == True, rooms.c.ics_url != None
+        )
+        all_rooms = await get_database().fetch_all(query)
 
-        for room in ics_enabled_rooms:
+        for room_data in all_rooms:
+            room_id = room_data["id"]
+            room = await rooms_controller.get_by_id(room_id)
             if room and _should_sync(room):
-                sync_room_ics.delay(room.id)
+                sync_room_ics.delay(room_id)
 
     assert mock_delay.call_count == 2
     called_room_ids = [call.args[0] for call in mock_delay.call_args_list]
@@ -161,11 +163,10 @@ async def test_should_sync_logic():
 @pytest.mark.asyncio
-async def test_sync_respects_fetch_interval(db_session):
+async def test_sync_respects_fetch_interval():
     now = datetime.now(timezone.utc)
 
     room1 = await rooms_controller.add(
-        db_session,
         name="interval-test-1",
         user_id="test-user",
         zulip_auto_post=False,
@@ -182,13 +183,11 @@ async def test_sync_respects_fetch_interval(db_session):
     )
     await rooms_controller.update(
-        db_session,
         room1,
         {"ics_last_sync": now - timedelta(seconds=100)},
     )
 
     room2 = await rooms_controller.add(
-        db_session,
         name="interval-test-2",
         user_id="test-user",
         zulip_auto_post=False,
@@ -205,26 +204,30 @@ async def test_sync_respects_fetch_interval(db_session):
     )
     await rooms_controller.update(
-        db_session,
         room2,
         {"ics_last_sync": now - timedelta(seconds=100)},
     )
 
     with patch("reflector.worker.ics_sync.sync_room_ics.delay") as mock_delay:
-        ics_enabled_rooms = await rooms_controller.get_ics_enabled(db_session)
+        # Test the sync logic without the Celery wrapper
+        query = rooms.select().where(
+            rooms.c.ics_enabled == True, rooms.c.ics_url != None
+        )
+        all_rooms = await get_database().fetch_all(query)
 
-        for room in ics_enabled_rooms:
+        for room_data in all_rooms:
+            room_id = room_data["id"]
+            room = await rooms_controller.get_by_id(room_id)
             if room and _should_sync(room):
-                sync_room_ics.delay(room.id)
+                sync_room_ics.delay(room_id)
 
     assert mock_delay.call_count == 1
     assert mock_delay.call_args[0][0] == room2.id
 
 @pytest.mark.asyncio
-async def test_sync_handles_errors_gracefully(db_session):
+async def test_sync_handles_errors_gracefully():
     room = await rooms_controller.add(
-        db_session,
         name="error-task-room",
         user_id="test-user",
         zulip_auto_post=False,
@@ -244,8 +247,9 @@ async def test_sync_handles_errors_gracefully(db_session):
     ) as mock_fetch:
         mock_fetch.side_effect = Exception("Network error")
 
-        result = await ics_sync_service.sync_room_calendar(db_session, room)
+        # Call the service directly to test error handling
+        result = await ics_sync_service.sync_room_calendar(room)
 
         assert result["status"] == "error"
 
-    events = await calendar_events_controller.get_by_room(db_session, room.id)
+    events = await calendar_events_controller.get_by_room(room.id)
     assert len(events) == 0
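The fetch-interval gate exercised by `test_sync_respects_fetch_interval` above can be sketched as a pure function: sync when no previous sync exists, or when the configured interval has elapsed since `ics_last_sync`. This is an assumed sketch of what `_should_sync` checks (the real helper reads these fields off the Room model), not its actual source:

```python
from datetime import datetime, timedelta, timezone


def should_sync(last_sync, fetch_interval_s, now=None):
    """Return True when a room's ICS feed is due for a refresh.

    last_sync: datetime of the previous sync, or None if never synced.
    fetch_interval_s: minimum seconds between syncs (ics_fetch_interval).
    """
    now = now or datetime.now(timezone.utc)
    if last_sync is None:
        return True  # never synced: always due
    return (now - last_sync).total_seconds() >= fetch_interval_s


now = datetime.now(timezone.utc)
# Synced 100s ago with a 300s interval: skipped, like room1 in the test.
print(should_sync(now - timedelta(seconds=100), 300, now))  # False
# Synced 100s ago with a 60s interval: due, like room2 in the test.
print(should_sync(now - timedelta(seconds=100), 60, now))   # True
```

Under this gate, only rooms whose interval has elapsed get a `sync_room_ics.delay(...)` call, which is why the test asserts `mock_delay.call_count == 1` for `room2` alone.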

View File

@@ -134,10 +134,9 @@ async def test_ics_fetch_service_extract_room_events():
 @pytest.mark.asyncio
-async def test_ics_sync_service_sync_room_calendar(db_session):
+async def test_ics_sync_service_sync_room_calendar():
     # Create room
     room = await rooms_controller.add(
-        db_session,
         name="sync-test",
         user_id="test-user",
         zulip_auto_post=False,
@@ -151,7 +150,6 @@ async def test_ics_sync_service_sync_room_calendar(db_session):
         ics_url="https://calendar.example.com/test.ics",
         ics_enabled=True,
     )
-    await db_session.flush()
 
     # Mock ICS content
     cal = Calendar()
@@ -177,7 +175,7 @@ async def test_ics_sync_service_sync_room_calendar(db_session):
         mock_fetch.return_value = ics_content
 
         # First sync
-        result = await sync_service.sync_room_calendar(db_session, room)
+        result = await sync_service.sync_room_calendar(room)
         assert result["status"] == "success"
         assert result["events_found"] == 1
@@ -186,20 +184,18 @@
         assert result["events_deleted"] == 0
 
         # Verify event was created
-        events = await calendar_events_controller.get_by_room(db_session, room.id)
+        events = await calendar_events_controller.get_by_room(room.id)
         assert len(events) == 1
         assert events[0].ics_uid == "sync-event-1"
         assert events[0].title == "Sync Test Meeting"
 
         # Second sync with same content (should be unchanged)
         # Refresh room to get updated etag and force sync by setting old sync time
-        room = await rooms_controller.get_by_id(db_session, room.id)
+        room = await rooms_controller.get_by_id(room.id)
         await rooms_controller.update(
-            db_session,
-            room,
-            {"ics_last_sync": datetime.now(timezone.utc) - timedelta(minutes=10)},
+            room, {"ics_last_sync": datetime.now(timezone.utc) - timedelta(minutes=10)}
         )
-        result = await sync_service.sync_room_calendar(db_session, room)
+        result = await sync_service.sync_room_calendar(room)
         assert result["status"] == "unchanged"
 
         # Third sync with updated event
@@ -210,15 +206,15 @@
         mock_fetch.return_value = ics_content
 
         # Force sync by clearing etag
-        await rooms_controller.update(db_session, room, {"ics_last_etag": None})
+        await rooms_controller.update(room, {"ics_last_etag": None})
-        result = await sync_service.sync_room_calendar(db_session, room)
+        result = await sync_service.sync_room_calendar(room)
         assert result["status"] == "success"
         assert result["events_created"] == 0
         assert result["events_updated"] == 1
 
         # Verify event was updated
-        events = await calendar_events_controller.get_by_room(db_session, room.id)
+        events = await calendar_events_controller.get_by_room(room.id)
         assert len(events) == 1
         assert events[0].title == "Updated Meeting Title"
@@ -251,7 +247,7 @@ async def test_ics_sync_service_skip_disabled():
     room.ics_enabled = False
     room.ics_url = "https://calendar.example.com/test.ics"
 
-    result = await service.sync_room_calendar(MagicMock(), room)
+    result = await service.sync_room_calendar(room)
     assert result["status"] == "skipped"
     assert result["reason"] == "ICS not configured"
@@ -259,16 +255,15 @@ async def test_ics_sync_service_skip_disabled():
     room.ics_enabled = True
     room.ics_url = None
 
-    result = await service.sync_room_calendar(MagicMock(), room)
+    result = await service.sync_room_calendar(room)
     assert result["status"] == "skipped"
     assert result["reason"] == "ICS not configured"
 
 @pytest.mark.asyncio
-async def test_ics_sync_service_error_handling(db_session):
+async def test_ics_sync_service_error_handling():
     # Create room
     room = await rooms_controller.add(
-        db_session,
         name="error-test",
         user_id="test-user",
         zulip_auto_post=False,
@@ -282,7 +277,6 @@ async def test_ics_sync_service_error_handling(db_session):
         ics_url="https://calendar.example.com/error.ics",
         ics_enabled=True,
     )
-    await db_session.flush()
 
     sync_service = ICSSyncService()
@@ -291,6 +285,6 @@
     ) as mock_fetch:
         mock_fetch.side_effect = Exception("Network error")
 
-        result = await sync_service.sync_room_calendar(db_session, room)
+        result = await sync_service.sync_room_calendar(room)
         assert result["status"] == "error"
         assert "Network error" in result["error"]

View File

@@ -10,11 +10,10 @@ from reflector.db.rooms import rooms_controller
 @pytest.mark.asyncio
-async def test_multiple_active_meetings_per_room(db_session):
+async def test_multiple_active_meetings_per_room():
     """Test that multiple active meetings can exist for the same room."""
     # Create a room
     room = await rooms_controller.add(
-        db_session,
         name="test-room",
         user_id="test-user",
         zulip_auto_post=False,
@@ -32,7 +31,6 @@
     # Create first meeting
     meeting1 = await meetings_controller.create(
-        db_session,
         id="meeting-1",
         room_name="test-meeting-1",
         room_url="https://whereby.com/test-1",
@@ -44,7 +42,6 @@
     # Create second meeting for the same room (should succeed now)
     meeting2 = await meetings_controller.create(
-        db_session,
         id="meeting-2",
         room_name="test-meeting-2",
         room_url="https://whereby.com/test-2",
@@ -56,7 +53,7 @@
     # Both meetings should be active
     active_meetings = await meetings_controller.get_all_active_for_room(
-        db_session, room=room, current_time=current_time
+        room=room, current_time=current_time
     )
     assert len(active_meetings) == 2
@@ -65,11 +62,10 @@
 @pytest.mark.asyncio
-async def test_get_active_by_calendar_event(db_session):
+async def test_get_active_by_calendar_event():
     """Test getting active meeting by calendar event ID."""
     # Create a room
     room = await rooms_controller.add(
-        db_session,
         name="test-room",
         user_id="test-user",
         zulip_auto_post=False,
@@ -90,14 +86,13 @@
         start_time=datetime.now(timezone.utc),
         end_time=datetime.now(timezone.utc) + timedelta(hours=1),
     )
-    event = await calendar_events_controller.upsert(db_session, event)
+    event = await calendar_events_controller.upsert(event)
 
     current_time = datetime.now(timezone.utc)
     end_time = current_time + timedelta(hours=2)
 
     # Create meeting linked to calendar event
     meeting = await meetings_controller.create(
-        db_session,
         id="meeting-cal-1",
         room_name="test-meeting-cal",
         room_url="https://whereby.com/test-cal",
@@ -111,7 +106,7 @@
     # Should find the meeting by calendar event
     found_meeting = await meetings_controller.get_active_by_calendar_event(
-        db_session, room=room, calendar_event_id=event.id, current_time=current_time
+        room=room, calendar_event_id=event.id, current_time=current_time
     )
 
     assert found_meeting is not None
@@ -120,11 +115,10 @@
 @pytest.mark.asyncio
-async def test_calendar_meeting_deactivates_after_scheduled_end(db_session):
+async def test_calendar_meeting_deactivates_after_scheduled_end():
     """Test that unused calendar meetings deactivate after scheduled end time."""
     # Create a room
     room = await rooms_controller.add(
-        db_session,
         name="test-room",
         user_id="test-user",
         zulip_auto_post=False,
@@ -145,13 +139,12 @@
         start_time=datetime.now(timezone.utc) - timedelta(hours=2),
         end_time=datetime.now(timezone.utc) - timedelta(minutes=35),
     )
-    event = await calendar_events_controller.upsert(db_session, event)
+    event = await calendar_events_controller.upsert(event)
 
     current_time = datetime.now(timezone.utc)
 
     # Create meeting linked to calendar event
     meeting = await meetings_controller.create(
-        db_session,
         id="meeting-unused",
         room_name="test-meeting-unused",
         room_url="https://whereby.com/test-unused",
@@ -168,9 +161,7 @@
     # Simulate process_meetings logic for unused calendar meeting past end time
     if meeting.calendar_event_id and current_time > meeting.end_date:
         # In real code, we'd check has_had_sessions = False here
-        await meetings_controller.update_meeting(
-            db_session, meeting.id, is_active=False
-        )
+        await meetings_controller.update_meeting(meeting.id, is_active=False)
 
-    updated_meeting = await meetings_controller.get_by_id(db_session, meeting.id)
+    updated_meeting = await meetings_controller.get_by_id(meeting.id)
     assert updated_meeting.is_active is False  # Deactivated after scheduled end

View File

@@ -101,37 +101,21 @@ async def mock_transcript_in_db(tmpdir):
         target_language="en",
     )
 
-    # Mock all transcripts controller methods that are used in the pipeline
+    # Mock the controller to return our transcript
     try:
         with patch(
             "reflector.pipelines.main_file_pipeline.transcripts_controller.get_by_id"
         ) as mock_get:
             mock_get.return_value = transcript
             with patch(
-                "reflector.pipelines.main_file_pipeline.transcripts_controller.update"
-            ) as mock_update:
-                mock_update.return_value = transcript
+                "reflector.pipelines.main_live_pipeline.transcripts_controller.get_by_id"
+            ) as mock_get2:
+                mock_get2.return_value = transcript
                 with patch(
-                    "reflector.pipelines.main_file_pipeline.transcripts_controller.set_status"
-                ) as mock_set_status:
-                    mock_set_status.return_value = None
-                    with patch(
-                        "reflector.pipelines.main_file_pipeline.transcripts_controller.upsert_topic"
-                    ) as mock_upsert_topic:
-                        mock_upsert_topic.return_value = None
-                        with patch(
-                            "reflector.pipelines.main_file_pipeline.transcripts_controller.append_event"
-                        ) as mock_append_event:
-                            mock_append_event.return_value = None
-                            with patch(
-                                "reflector.pipelines.main_live_pipeline.transcripts_controller.get_by_id"
-                            ) as mock_get2:
-                                mock_get2.return_value = transcript
-                                with patch(
-                                    "reflector.pipelines.main_live_pipeline.transcripts_controller.update"
-                                ) as mock_update2:
-                                    mock_update2.return_value = None
-                                    yield transcript
+                    "reflector.pipelines.main_live_pipeline.transcripts_controller.update"
+                ) as mock_update:
+                    mock_update.return_value = None
+                    yield transcript
     finally:
         # Restore original DATA_DIR
         settings.DATA_DIR = original_data_dir
@@ -143,18 +127,27 @@ async def mock_storage():
     from reflector.storage.base import Storage
 
     class TestStorage(Storage):
-        async def _put_file(self, path, data):
+        async def _put_file(self, path, data, bucket=None):
             return None
 
-        async def _get_file_url(self, path):
+        async def _get_file_url(
+            self,
+            path,
+            operation: str = "get_object",
+            expires_in: int = 3600,
+            bucket=None,
+        ):
             return f"http://test-storage/{path}"
 
-        async def _get_file(self, path):
+        async def _get_file(self, path, bucket=None):
             return b"test_audio_data"
 
-        async def _delete_file(self, path):
+        async def _delete_file(self, path, bucket=None):
             return None
 
+        async def _stream_to_fileobj(self, path, fileobj, bucket=None):
+            fileobj.write(b"test_audio_data")
 
     storage = TestStorage()
     # Add mock tracking for verification
     storage._put_file = AsyncMock(side_effect=storage._put_file)
@@ -197,7 +190,7 @@ async def mock_waveform_processor():
 async def mock_topic_detector():
     """Mock TranscriptTopicDetectorProcessor"""
     with patch(
-        "reflector.pipelines.main_file_pipeline.TranscriptTopicDetectorProcessor"
+        "reflector.pipelines.topic_processing.TranscriptTopicDetectorProcessor"
     ) as mock_topic_class:
         mock_topic = AsyncMock()
         mock_topic.set_pipeline = MagicMock()
@@ -234,7 +227,7 @@ async def mock_topic_detector():
 async def mock_title_processor():
     """Mock TranscriptFinalTitleProcessor"""
     with patch(
-        "reflector.pipelines.main_file_pipeline.TranscriptFinalTitleProcessor"
+        "reflector.pipelines.topic_processing.TranscriptFinalTitleProcessor"
    ) as mock_title_class:
         mock_title = AsyncMock()
         mock_title.set_pipeline = MagicMock()
@@ -263,7 +256,7 @@ async def mock_title_processor():
 async def mock_summary_processor():
     """Mock TranscriptFinalSummaryProcessor"""
     with patch(
-        "reflector.pipelines.main_file_pipeline.TranscriptFinalSummaryProcessor"
+        "reflector.pipelines.topic_processing.TranscriptFinalSummaryProcessor"
     ) as mock_summary_class:
         mock_summary = AsyncMock()
         mock_summary.set_pipeline = MagicMock()
@@ -624,11 +617,7 @@ async def test_pipeline_file_process_no_transcript():
     # Should raise an exception for missing transcript when get_transcript is called
     with pytest.raises(Exception, match="Transcript not found"):
-        # Use a mock session - the controller is mocked to return None anyway
-        from unittest.mock import MagicMock
-
-        mock_session = MagicMock()
-        await pipeline.get_transcript(mock_session)
+        await pipeline.get_transcript()
 
 @pytest.mark.asyncio
@pytest.mark.asyncio @pytest.mark.asyncio

View File

@@ -10,10 +10,9 @@ from reflector.db.rooms import rooms_controller
 @pytest.mark.asyncio
-async def test_room_create_with_ics_fields(db_session):
+async def test_room_create_with_ics_fields():
     """Test creating a room with ICS calendar fields."""
     room = await rooms_controller.add(
-        db_session,
         name="test-room",
         user_id="test-user",
         zulip_auto_post=False,
@@ -41,11 +40,10 @@
 @pytest.mark.asyncio
-async def test_room_update_ics_configuration(db_session):
+async def test_room_update_ics_configuration():
     """Test updating room ICS configuration."""
     # Create room without ICS
     room = await rooms_controller.add(
-        db_session,
         name="update-test",
         user_id="test-user",
         zulip_auto_post=False,
@@ -63,7 +61,6 @@
     # Update with ICS configuration
     await rooms_controller.update(
-        db_session,
         room,
         {
             "ics_url": "https://outlook.office365.com/owa/calendar/test/calendar.ics",
@@ -80,10 +77,9 @@
 @pytest.mark.asyncio
-async def test_room_ics_sync_metadata(db_session):
+async def test_room_ics_sync_metadata():
     """Test updating room ICS sync metadata."""
     room = await rooms_controller.add(
-        db_session,
         name="sync-test",
         user_id="test-user",
         zulip_auto_post=False,
@@ -101,7 +97,6 @@
     # Update sync metadata
     sync_time = datetime.now(timezone.utc)
     await rooms_controller.update(
-        db_session,
         room,
         {
             "ics_last_sync": sync_time,
@@ -114,11 +109,10 @@
 @pytest.mark.asyncio
-async def test_room_get_with_ics_fields(db_session):
+async def test_room_get_with_ics_fields():
     """Test retrieving room with ICS fields."""
     # Create room
     created_room = await rooms_controller.add(
-        db_session,
         name="get-test",
         user_id="test-user",
         zulip_auto_post=False,
@@ -135,14 +129,14 @@
     )
 
     # Get by ID
-    room = await rooms_controller.get_by_id(db_session, created_room.id)
+    room = await rooms_controller.get_by_id(created_room.id)
     assert room is not None
     assert room.ics_url == "webcal://calendar.example.com/feed.ics"
     assert room.ics_fetch_interval == 900
     assert room.ics_enabled is True
 
     # Get by name
-    room = await rooms_controller.get_by_name(db_session, "get-test")
+    room = await rooms_controller.get_by_name("get-test")
     assert room is not None
     assert room.ics_url == "webcal://calendar.example.com/feed.ics"
     assert room.ics_fetch_interval == 900
@@ -150,11 +144,10 @@
 @pytest.mark.asyncio
-async def test_room_list_with_ics_enabled_filter(db_session):
+async def test_room_list_with_ics_enabled_filter():
     """Test listing rooms filtered by ICS enabled status."""
     # Create rooms with and without ICS
     room1 = await rooms_controller.add(
-        db_session,
         name="ics-enabled-1",
         user_id="test-user",
         zulip_auto_post=False,
@@ -170,7 +163,6 @@
     )
     room2 = await rooms_controller.add(
-        db_session,
         name="ics-disabled",
         user_id="test-user",
         zulip_auto_post=False,
@@ -185,7 +177,6 @@
     )
     room3 = await rooms_controller.add(
-        db_session,
         name="ics-enabled-2",
         user_id="test-user",
         zulip_auto_post=False,
@@ -201,20 +192,19 @@
     )
 
     # Get all rooms
all_rooms = await rooms_controller.get_all(db_session) all_rooms = await rooms_controller.get_all()
assert len(all_rooms) == 3 assert len(all_rooms) == 3
# Filter for ICS-enabled rooms (would need to implement this in controller) # Filter for ICS-enabled rooms (would need to implement this in controller)
ics_rooms = [r for r in all_rooms if r.ics_enabled] ics_rooms = [r for r in all_rooms if r["ics_enabled"]]
assert len(ics_rooms) == 2 assert len(ics_rooms) == 2
assert all(r.ics_enabled for r in ics_rooms) assert all(r["ics_enabled"] for r in ics_rooms)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_room_default_ics_values(db_session): async def test_room_default_ics_values():
"""Test that ICS fields have correct default values.""" """Test that ICS fields have correct default values."""
room = await rooms_controller.add( room = await rooms_controller.add(
db_session,
name="default-test", name="default-test",
user_id="test-user", user_id="test-user",
zulip_auto_post=False, zulip_auto_post=False,


@@ -11,14 +11,21 @@ from reflector.db.rooms import rooms_controller
 @pytest.fixture
 async def authenticated_client(client):
     from reflector.app import app
-    from reflector.auth import current_user_optional
+    from reflector.auth import current_user, current_user_optional
+    app.dependency_overrides[current_user] = lambda: {
+        "sub": "test-user",
+        "email": "test@example.com",
+    }
     app.dependency_overrides[current_user_optional] = lambda: {
         "sub": "test-user",
         "email": "test@example.com",
     }
-    yield client
-    del app.dependency_overrides[current_user_optional]
+    try:
+        yield client
+    finally:
+        del app.dependency_overrides[current_user]
+        del app.dependency_overrides[current_user_optional]
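The reworked fixture overrides both the required and the optional user dependency, and wraps the `yield` in `try`/`finally` so the overrides are removed even when a test body raises. A minimal standalone sketch of that pattern, with a plain dict standing in for FastAPI's `app.dependency_overrides` (the key names here are illustrative, not the project's API):

```python
from contextlib import contextmanager

# Stand-in for app.dependency_overrides; in FastAPI the keys would be
# the dependency callables themselves.
dependency_overrides: dict = {}

@contextmanager
def override_user(fake_user: dict):
    # Install overrides for both the required and the optional dependency.
    dependency_overrides["current_user"] = lambda: fake_user
    dependency_overrides["current_user_optional"] = lambda: fake_user
    try:
        yield
    finally:
        # Cleanup runs even if the test body raises, so one failing test
        # cannot leak its fake user into the next one.
        del dependency_overrides["current_user"]
        del dependency_overrides["current_user_optional"]

with override_user({"sub": "test-user", "email": "test@example.com"}):
    active = dependency_overrides["current_user"]()["sub"]
leaked = bool(dependency_overrides)
```

Without the `finally`, the original fixture left the override installed whenever an assertion failed mid-test, which is exactly the cross-test pollution the new version avoids.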
 @pytest.mark.asyncio
@@ -41,6 +48,7 @@ async def test_create_room_with_ics_fields(authenticated_client):
             "ics_url": "https://calendar.example.com/test.ics",
             "ics_fetch_interval": 600,
             "ics_enabled": True,
+            "platform": "daily",
         },
     )
     assert response.status_code == 200
@@ -68,6 +76,7 @@ async def test_update_room_ics_configuration(authenticated_client):
             "is_shared": False,
             "webhook_url": "",
             "webhook_secret": "",
+            "platform": "daily",
         },
     )
     assert response.status_code == 200
@@ -89,10 +98,9 @@ async def test_update_room_ics_configuration(authenticated_client):
 @pytest.mark.asyncio
-async def test_trigger_ics_sync(authenticated_client, db_session):
+async def test_trigger_ics_sync(authenticated_client):
     client = authenticated_client
     room = await rooms_controller.add(
-        db_session,
         name="sync-api-room",
         user_id="test-user",
         zulip_auto_post=False,
@@ -105,6 +113,7 @@ async def test_trigger_ics_sync(authenticated_client, db_session):
         is_shared=False,
         ics_url="https://calendar.example.com/api.ics",
         ics_enabled=True,
+        platform="daily",
     )
     cal = Calendar()
@@ -134,9 +143,8 @@ async def test_trigger_ics_sync(authenticated_client, db_session):
 @pytest.mark.asyncio
-async def test_trigger_ics_sync_unauthorized(client, db_session):
+async def test_trigger_ics_sync_unauthorized(client):
     room = await rooms_controller.add(
-        db_session,
         name="sync-unauth-room",
         user_id="owner-123",
         zulip_auto_post=False,
@@ -149,6 +157,7 @@ async def test_trigger_ics_sync_unauthorized(client, db_session):
         is_shared=False,
         ics_url="https://calendar.example.com/api.ics",
         ics_enabled=True,
+        platform="daily",
     )
     response = await client.post(f"/rooms/{room.name}/ics/sync")
@@ -157,10 +166,9 @@ async def test_trigger_ics_sync_unauthorized(client, db_session):
 @pytest.mark.asyncio
-async def test_trigger_ics_sync_not_configured(authenticated_client, db_session):
+async def test_trigger_ics_sync_not_configured(authenticated_client):
     client = authenticated_client
     room = await rooms_controller.add(
-        db_session,
         name="sync-not-configured",
         user_id="test-user",
         zulip_auto_post=False,
@@ -172,6 +180,7 @@ async def test_trigger_ics_sync_not_configured(authenticated_client, db_session)
         recording_trigger="automatic-2nd-participant",
         is_shared=False,
         ics_enabled=False,
+        platform="daily",
     )
     response = await client.post(f"/rooms/{room.name}/ics/sync")
@@ -180,10 +189,9 @@ async def test_trigger_ics_sync_not_configured(authenticated_client, db_session)
 @pytest.mark.asyncio
-async def test_get_ics_status(authenticated_client, db_session):
+async def test_get_ics_status(authenticated_client):
     client = authenticated_client
     room = await rooms_controller.add(
-        db_session,
         name="status-room",
         user_id="test-user",
         zulip_auto_post=False,
@@ -197,11 +205,11 @@ async def test_get_ics_status(authenticated_client, db_session):
         ics_url="https://calendar.example.com/status.ics",
         ics_enabled=True,
         ics_fetch_interval=300,
+        platform="daily",
     )
     now = datetime.now(timezone.utc)
     await rooms_controller.update(
-        db_session,
         room,
         {"ics_last_sync": now, "ics_last_etag": "test-etag"},
     )
@@ -215,9 +223,8 @@ async def test_get_ics_status(authenticated_client, db_session):
 @pytest.mark.asyncio
-async def test_get_ics_status_unauthorized(client, db_session):
+async def test_get_ics_status_unauthorized(client):
     room = await rooms_controller.add(
-        db_session,
         name="status-unauth",
         user_id="owner-456",
         zulip_auto_post=False,
@@ -230,6 +237,7 @@ async def test_get_ics_status_unauthorized(client, db_session):
         is_shared=False,
         ics_url="https://calendar.example.com/status.ics",
         ics_enabled=True,
+        platform="daily",
     )
     response = await client.get(f"/rooms/{room.name}/ics/status")
@@ -238,10 +246,9 @@ async def test_get_ics_status_unauthorized(client, db_session):
 @pytest.mark.asyncio
-async def test_list_room_meetings(authenticated_client, db_session):
+async def test_list_room_meetings(authenticated_client):
     client = authenticated_client
     room = await rooms_controller.add(
-        db_session,
         name="meetings-room",
         user_id="test-user",
         zulip_auto_post=False,
@@ -252,6 +259,7 @@ async def test_list_room_meetings(authenticated_client, db_session):
         recording_type="cloud",
         recording_trigger="automatic-2nd-participant",
         is_shared=False,
+        platform="daily",
     )
     now = datetime.now(timezone.utc)
@@ -262,7 +270,7 @@ async def test_list_room_meetings(authenticated_client, db_session):
         start_time=now - timedelta(hours=2),
         end_time=now - timedelta(hours=1),
     )
-    await calendar_events_controller.upsert(db_session, event1)
+    await calendar_events_controller.upsert(event1)
     event2 = CalendarEvent(
         room_id=room.id,
@@ -273,7 +281,7 @@ async def test_list_room_meetings(authenticated_client, db_session):
         end_time=now + timedelta(hours=2),
         attendees=[{"email": "test@example.com"}],
     )
-    await calendar_events_controller.upsert(db_session, event2)
+    await calendar_events_controller.upsert(event2)
     response = await client.get(f"/rooms/{room.name}/meetings")
     assert response.status_code == 200
@@ -286,9 +294,8 @@ async def test_list_room_meetings(authenticated_client, db_session):
 @pytest.mark.asyncio
-async def test_list_room_meetings_non_owner(client, db_session):
+async def test_list_room_meetings_non_owner(client):
     room = await rooms_controller.add(
-        db_session,
         name="meetings-privacy",
         user_id="owner-789",
         zulip_auto_post=False,
@@ -299,6 +306,7 @@ async def test_list_room_meetings_non_owner(client, db_session):
         recording_type="cloud",
         recording_trigger="automatic-2nd-participant",
         is_shared=False,
+        platform="daily",
     )
     event = CalendarEvent(
@@ -310,7 +318,7 @@ async def test_list_room_meetings_non_owner(client, db_session):
         end_time=datetime.now(timezone.utc) + timedelta(hours=2),
         attendees=[{"email": "private@example.com"}],
     )
-    await calendar_events_controller.upsert(db_session, event)
+    await calendar_events_controller.upsert(event)
     response = await client.get(f"/rooms/{room.name}/meetings")
     assert response.status_code == 200
@@ -322,10 +330,9 @@
 @pytest.mark.asyncio
-async def test_list_upcoming_meetings(authenticated_client, db_session):
+async def test_list_upcoming_meetings(authenticated_client):
     client = authenticated_client
     room = await rooms_controller.add(
-        db_session,
         name="upcoming-room",
         user_id="test-user",
         zulip_auto_post=False,
@@ -336,6 +343,7 @@ async def test_list_upcoming_meetings(authenticated_client, db_session):
         recording_type="cloud",
         recording_trigger="automatic-2nd-participant",
         is_shared=False,
+        platform="daily",
     )
     now = datetime.now(timezone.utc)
@@ -347,7 +355,7 @@ async def test_list_upcoming_meetings(authenticated_client, db_session):
         start_time=now - timedelta(hours=1),
         end_time=now - timedelta(minutes=30),
     )
-    await calendar_events_controller.upsert(db_session, past_event)
+    await calendar_events_controller.upsert(past_event)
     soon_event = CalendarEvent(
         room_id=room.id,
@@ -356,7 +364,7 @@ async def test_list_upcoming_meetings(authenticated_client, db_session):
         start_time=now + timedelta(minutes=15),
         end_time=now + timedelta(minutes=45),
     )
-    await calendar_events_controller.upsert(db_session, soon_event)
+    await calendar_events_controller.upsert(soon_event)
     later_event = CalendarEvent(
         room_id=room.id,
@@ -365,7 +373,7 @@ async def test_list_upcoming_meetings(authenticated_client, db_session):
         start_time=now + timedelta(hours=2),
         end_time=now + timedelta(hours=3),
     )
-    await calendar_events_controller.upsert(db_session, later_event)
+    await calendar_events_controller.upsert(later_event)
     response = await client.get(f"/rooms/{room.name}/meetings/upcoming")
     assert response.status_code == 200
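The upcoming-meetings test seeds three events (one past, one starting soon, one later) so the expected filter is pinned down by construction. A sketch of the selection the test appears to exercise, assuming "upcoming" means events that have not yet ended, ordered by start time (`Event` and `upcoming` here are hypothetical stand-ins, not the endpoint's implementation):

```python
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone

@dataclass
class Event:
    title: str
    start_time: datetime
    end_time: datetime

def upcoming(events: list[Event], now: datetime) -> list[Event]:
    # Keep events that have not ended yet, soonest start first
    # (assumed semantics, mirroring the test fixtures).
    return sorted(
        (e for e in events if e.end_time > now),
        key=lambda e: e.start_time,
    )

now = datetime.now(timezone.utc)
events = [
    Event("past", now - timedelta(hours=1), now - timedelta(minutes=30)),
    Event("soon", now + timedelta(minutes=15), now + timedelta(minutes=45)),
    Event("later", now + timedelta(hours=2), now + timedelta(hours=3)),
]
names = [e.title for e in upcoming(events, now)]
# "past" is excluded; "soon" sorts before "later"
```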


@@ -2,40 +2,40 @@
 import json
 from datetime import datetime, timezone
+from unittest.mock import AsyncMock, patch
 import pytest
-from sqlalchemy import delete, insert
-from reflector.db.base import TranscriptModel
+from reflector.db import get_database
 from reflector.db.search import (
     SearchController,
     SearchParameters,
     SearchResult,
     search_controller,
 )
-from reflector.db.transcripts import SourceKind
+from reflector.db.transcripts import SourceKind, transcripts
 @pytest.mark.asyncio
-async def test_search_postgresql_only(db_session):
+async def test_search_postgresql_only():
     params = SearchParameters(query_text="any query here")
-    results, total = await search_controller.search_transcripts(db_session, params)
+    results, total = await search_controller.search_transcripts(params)
     assert results == []
     assert total == 0
     params_empty = SearchParameters(query_text=None)
     results_empty, total_empty = await search_controller.search_transcripts(
-        db_session, params_empty
+        params_empty
     )
     assert isinstance(results_empty, list)
     assert isinstance(total_empty, int)
 @pytest.mark.asyncio
-async def test_search_with_empty_query(db_session):
+async def test_search_with_empty_query():
     """Test that empty query returns all transcripts."""
     params = SearchParameters(query_text=None)
-    results, total = await search_controller.search_transcripts(db_session, params)
+    results, total = await search_controller.search_transcripts(params)
     assert isinstance(results, list)
     assert isinstance(total, int)
@@ -45,13 +45,13 @@ async def test_search_with_empty_query(db_session):
 @pytest.mark.asyncio
-async def test_empty_transcript_title_only_match(db_session):
+async def test_empty_transcript_title_only_match():
     """Test that transcripts with title-only matches return empty snippets."""
     test_id = "test-empty-9b3f2a8d"
     try:
-        await db_session.execute(
-            delete(TranscriptModel).where(TranscriptModel.id == test_id)
+        await get_database().execute(
+            transcripts.delete().where(transcripts.c.id == test_id)
         )
         test_data = {
@@ -77,11 +77,10 @@ async def test_empty_transcript_title_only_match(db_session):
             "user_id": "test-user-1",
         }
-        await db_session.execute(insert(TranscriptModel).values(**test_data))
-        await db_session.commit()
+        await get_database().execute(transcripts.insert().values(**test_data))
         params = SearchParameters(query_text="empty", user_id="test-user-1")
-        results, total = await search_controller.search_transcripts(db_session, params)
+        results, total = await search_controller.search_transcripts(params)
         assert total >= 1
         found = next((r for r in results if r.id == test_id), None)
@@ -90,20 +89,20 @@ async def test_empty_transcript_title_only_match(db_session):
         assert found.total_match_count == 0
     finally:
-        await db_session.execute(
-            delete(TranscriptModel).where(TranscriptModel.id == test_id)
+        await get_database().execute(
+            transcripts.delete().where(transcripts.c.id == test_id)
         )
-        await db_session.commit()
+        await get_database().disconnect()
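Each data-driven search test above follows the same lifecycle: delete any leftover row for the fixed test id, insert the fixture row, run the query under test, and clean up in `finally`; with the global database the `commit()` calls drop out. The skeleton of that pattern, sketched against stdlib `sqlite3` rather than the project's database layer (table name and query are illustrative):

```python
import sqlite3

# In-memory stand-in for the project's database.
db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE transcripts (id TEXT PRIMARY KEY, title TEXT)")

test_id = "test-empty-9b3f2a8d"
try:
    # Delete leftovers from any previous run, then insert the fixture row.
    db.execute("DELETE FROM transcripts WHERE id = ?", (test_id,))
    db.execute(
        "INSERT INTO transcripts VALUES (?, ?)", (test_id, "Empty transcript")
    )
    # Exercise the query under test (LIKE is case-insensitive for ASCII
    # in SQLite, so "Empty" matches "%empty%").
    rows = db.execute(
        "SELECT id FROM transcripts WHERE title LIKE ?", ("%empty%",)
    ).fetchall()
    found = any(r[0] == test_id for r in rows)
finally:
    # Cleanup runs whether or not the assertions above raised.
    db.execute("DELETE FROM transcripts WHERE id = ?", (test_id,))
    db.close()
```

Putting the pre-delete inside the `try` and the cleanup in `finally` keeps a failed assertion from leaving rows behind for the next run, which matters when tests share a fixed id.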
 @pytest.mark.asyncio
-async def test_search_with_long_summary(db_session):
+async def test_search_with_long_summary():
     """Test that long_summary content is searchable."""
     test_id = "test-long-summary-8a9f3c2d"
     try:
-        await db_session.execute(
-            delete(TranscriptModel).where(TranscriptModel.id == test_id)
+        await get_database().execute(
+            transcripts.delete().where(transcripts.c.id == test_id)
         )
         test_data = {
@@ -132,11 +131,10 @@ Basic meeting content without special keywords.""",
             "user_id": "test-user-2",
         }
-        await db_session.execute(insert(TranscriptModel).values(**test_data))
-        await db_session.commit()
+        await get_database().execute(transcripts.insert().values(**test_data))
         params = SearchParameters(query_text="quantum computing", user_id="test-user-2")
-        results, total = await search_controller.search_transcripts(db_session, params)
+        results, total = await search_controller.search_transcripts(params)
         assert total >= 1
         found = any(r.id == test_id for r in results)
@@ -148,19 +146,19 @@ Basic meeting content without special keywords.""",
         assert "quantum computing" in test_result.search_snippets[0].lower()
     finally:
-        await db_session.execute(
-            delete(TranscriptModel).where(TranscriptModel.id == test_id)
+        await get_database().execute(
+            transcripts.delete().where(transcripts.c.id == test_id)
         )
-        await db_session.commit()
+        await get_database().disconnect()
 @pytest.mark.asyncio
-async def test_postgresql_search_with_data(db_session):
+async def test_postgresql_search_with_data():
     test_id = "test-search-e2e-7f3a9b2c"
     try:
-        await db_session.execute(
-            delete(TranscriptModel).where(TranscriptModel.id == test_id)
+        await get_database().execute(
+            transcripts.delete().where(transcripts.c.id == test_id)
         )
         test_data = {
@@ -198,17 +196,16 @@ We need to implement PostgreSQL tsvector for better performance.""",
             "user_id": "test-user-3",
         }
-        await db_session.execute(insert(TranscriptModel).values(**test_data))
-        await db_session.commit()
+        await get_database().execute(transcripts.insert().values(**test_data))
         params = SearchParameters(query_text="planning", user_id="test-user-3")
-        results, total = await search_controller.search_transcripts(db_session, params)
+        results, total = await search_controller.search_transcripts(params)
         assert total >= 1
         found = any(r.id == test_id for r in results)
         assert found, "Should find test transcript by title word"
         params = SearchParameters(query_text="tsvector", user_id="test-user-3")
-        results, total = await search_controller.search_transcripts(db_session, params)
+        results, total = await search_controller.search_transcripts(params)
         assert total >= 1
         found = any(r.id == test_id for r in results)
         assert found, "Should find test transcript by webvtt content"
@@ -216,7 +213,7 @@ We need to implement PostgreSQL tsvector for better performance.""",
         params = SearchParameters(
             query_text="engineering planning", user_id="test-user-3"
         )
-        results, total = await search_controller.search_transcripts(db_session, params)
+        results, total = await search_controller.search_transcripts(params)
         assert total >= 1
         found = any(r.id == test_id for r in results)
         assert found, "Should find test transcript by multiple words"
@@ -231,7 +228,7 @@ We need to implement PostgreSQL tsvector for better performance.""",
         params = SearchParameters(
             query_text="tsvector OR nosuchword", user_id="test-user-3"
         )
-        results, total = await search_controller.search_transcripts(db_session, params)
+        results, total = await search_controller.search_transcripts(params)
         assert total >= 1
         found = any(r.id == test_id for r in results)
         assert found, "Should find test transcript with OR query"
@@ -239,16 +236,16 @@ We need to implement PostgreSQL tsvector for better performance.""",
         params = SearchParameters(
             query_text='"full-text search"', user_id="test-user-3"
         )
-        results, total = await search_controller.search_transcripts(db_session, params)
+        results, total = await search_controller.search_transcripts(params)
         assert total >= 1
         found = any(r.id == test_id for r in results)
         assert found, "Should find test transcript by exact phrase"
     finally:
-        await db_session.execute(
-            delete(TranscriptModel).where(TranscriptModel.id == test_id)
+        await get_database().execute(
+            transcripts.delete().where(transcripts.c.id == test_id)
        )
-        await db_session.commit()
+        await get_database().disconnect()
 @pytest.fixture
@@ -314,56 +311,87 @@ class TestSearchControllerFilters:
     """Test SearchController functionality with various filters."""
     @pytest.mark.asyncio
-    async def test_search_with_source_kind_filter(self, db_session):
+    async def test_search_with_source_kind_filter(self):
         """Test search filtering by source_kind."""
         controller = SearchController()
-        params = SearchParameters(query_text="test", source_kind=SourceKind.LIVE)
-        # This should not fail, even if no results are found
-        results, total = await controller.search_transcripts(db_session, params)
-        assert isinstance(results, list)
-        assert isinstance(total, int)
-        assert total >= 0
+        with (
+            patch("reflector.db.search.is_postgresql", return_value=True),
+            patch("reflector.db.search.get_database") as mock_db,
+        ):
+            mock_db.return_value.fetch_all = AsyncMock(return_value=[])
+            mock_db.return_value.fetch_val = AsyncMock(return_value=0)
+            params = SearchParameters(query_text="test", source_kind=SourceKind.LIVE)
+            results, total = await controller.search_transcripts(params)
+            assert results == []
+            assert total == 0
+            mock_db.return_value.fetch_all.assert_called_once()
     @pytest.mark.asyncio
-    async def test_search_with_single_room_id(self, db_session):
+    async def test_search_with_single_room_id(self):
         """Test search filtering by single room ID (currently supported)."""
         controller = SearchController()
-        params = SearchParameters(
-            query_text="test",
-            room_id="room1",
-        )
-        # This should not fail, even if no results are found
-        results, total = await controller.search_transcripts(db_session, params)
-        assert isinstance(results, list)
-        assert isinstance(total, int)
-        assert total >= 0
+        with (
+            patch("reflector.db.search.is_postgresql", return_value=True),
+            patch("reflector.db.search.get_database") as mock_db,
+        ):
+            mock_db.return_value.fetch_all = AsyncMock(return_value=[])
+            mock_db.return_value.fetch_val = AsyncMock(return_value=0)
+            params = SearchParameters(
+                query_text="test",
+                room_id="room1",
+            )
+            results, total = await controller.search_transcripts(params)
+            assert results == []
+            assert total == 0
+            mock_db.return_value.fetch_all.assert_called_once()
     @pytest.mark.asyncio
-    async def test_search_result_includes_available_fields(
-        self, db_session, mock_db_result
-    ):
+    async def test_search_result_includes_available_fields(self, mock_db_result):
         """Test that search results include available fields like source_kind."""
-        # Test that the search method works and returns SearchResult objects
         controller = SearchController()
-        params = SearchParameters(query_text="test")
-        results, total = await controller.search_transcripts(db_session, params)
-        assert isinstance(results, list)
-        assert isinstance(total, int)
-        assert total >= 0
-        # If any results exist, verify they are SearchResult objects
-        for result in results:
-            assert isinstance(result, SearchResult)
-            assert hasattr(result, "id")
-            assert hasattr(result, "title")
-            assert hasattr(result, "rank")
-            assert hasattr(result, "source_kind")
+        with (
+            patch("reflector.db.search.is_postgresql", return_value=True),
+            patch("reflector.db.search.get_database") as mock_db,
+        ):
+            class MockRow:
+                def __init__(self, data):
+                    self._data = data
+                    self._mapping = data
+                def __iter__(self):
+                    return iter(self._data.items())
+                def __getitem__(self, key):
+                    return self._data[key]
+                def keys(self):
+                    return self._data.keys()
+            mock_row = MockRow(mock_db_result)
+            mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
+            mock_db.return_value.fetch_val = AsyncMock(return_value=1)
+            params = SearchParameters(query_text="test")
+            results, total = await controller.search_transcripts(params)
+            assert total == 1
+            assert len(results) == 1
+            result = results[0]
+            assert isinstance(result, SearchResult)
+            assert result.id == "test-transcript-id"
+            assert result.title == "Test Transcript"
+            assert result.rank == 0.95
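The rewritten filter tests no longer touch a database at all: `is_postgresql` and `get_database` are patched, `fetch_all`/`fetch_val` become `AsyncMock`s, and a mapping-like `MockRow` stands in for a driver row. The essence of that pattern, sketched standalone with stdlib only (`search` here is a hypothetical consumer standing in for the patched controller internals):

```python
import asyncio
from unittest.mock import AsyncMock

class MockRow:
    """Mapping-like stand-in for a database row (mirrors the MockRow in the diff)."""
    def __init__(self, data):
        self._data = data
        self._mapping = data  # row._mapping, as exposed by SQLAlchemy-style rows
    def __getitem__(self, key):
        return self._data[key]
    def keys(self):
        return self._data.keys()

async def search(db):
    # Hypothetical consumer: fetches rows and a count, converts rows to dicts,
    # the way SearchResult objects are built from row mappings.
    rows = await db.fetch_all("SELECT ... (query irrelevant under the mock)")
    total = await db.fetch_val("SELECT count(*) ...")
    return [{k: row[k] for k in row.keys()} for row in rows], total

db = AsyncMock()
db.fetch_all.return_value = [MockRow({"id": "test-transcript-id", "rank": 0.95})]
db.fetch_val.return_value = 1

results, total = asyncio.run(search(db))
db.fetch_all.assert_called_once()  # the query path was exercised exactly once
```

Because `AsyncMock` attributes are themselves awaitable mocks, setting `.return_value` is enough to script both calls; the trade-off is that such tests verify the row-to-result plumbing, not the SQL itself.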
 class TestSearchEndpointParsing:

Some files were not shown because too many files have changed in this diff.