Compare commits

..

9 Commits

Author SHA1 Message Date
af86c47f1d chore(main): release 0.14.0 (#670) 2025-10-08 14:57:31 -06:00
5f6910e513 feat: Add calendar event data to transcript webhook payload (#689)
* feat: add calendar event data to transcript webhook payload and implement get_by_id method

* Update server/reflector/worker/webhook.py

Co-authored-by: pr-agent-monadical[bot] <198624643+pr-agent-monadical[bot]@users.noreply.github.com>

* Update server/reflector/worker/webhook.py

Co-authored-by: pr-agent-monadical[bot] <198624643+pr-agent-monadical[bot]@users.noreply.github.com>

* style: format conditional time fields with line breaks for better readability

* docs: add calendar event fields to transcript.completed webhook payload schema

---------

Co-authored-by: pr-agent-monadical[bot] <198624643+pr-agent-monadical[bot]@users.noreply.github.com>
2025-10-08 11:11:57 -05:00
9a71af145e fix: update transcript list on reprocess (#676)
* Update transcript list on reprocess

* Fix transcript create

* Fix multiple sockets issue

* Pass token in sec websocket protocol

* userEvent parse example

* transcript list invalidation non-abstraction

* Emit only relevant events to the user room

* Add ws close code const

* Refactor user websocket endpoint

* Refactor user events provider

---------

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
2025-10-07 19:11:30 +02:00
eef6dc3903 fix: upgrade nemo toolkit (#678) 2025-10-07 16:45:02 +02:00
Igor Monadical
1dee255fed parakeet endpoint doc (#679)
Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
2025-10-07 10:41:01 -04:00
5d98754305 fix: security review (#656)
* Add security review doc

* Add tests to reproduce security issues

* Fix security issues

* Fix tests

* Set auth auth backend for tests

* Fix ics api tests

* Fix transcript mutate check

* Update frontent env var names

* Remove permissions doc
2025-09-29 23:07:49 +02:00
Igor Monadical
969bd84fcc feat: container build for www / github (#672)
Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
2025-09-24 12:27:45 -04:00
Igor Monadical
36608849ec fix: restore feature boolean logic (#671)
Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
2025-09-24 11:57:49 -04:00
Igor Monadical
5bf64b5a41 feat: docker-compose for production frontend (#664)
* docker-compose for production frontend

* fix: Remove external Redis port mapping for Coolify compatibility

Redis should only be accessible within the internal Docker network in Coolify deployments to avoid port conflicts with other applications.

* fix: Remove external port mapping for web service in Coolify

Coolify handles port exposure through its proxy (Traefik), so services should not expose ports directly in the docker-compose file.

* server side client envs

* missing vars

* nextjs experimental

* fix claude 'fix'

* remove build env vars compose

* docker

* remove ports for coolify

* review

* cleanup

---------

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
2025-09-24 11:15:27 -04:00
103 changed files with 5426 additions and 5126 deletions

57
.github/workflows/docker-frontend.yml vendored Normal file
View File

@@ -0,0 +1,57 @@
name: Build and Push Frontend Docker Image
on:
push:
branches:
- main
paths:
- 'www/**'
- '.github/workflows/docker-frontend.yml'
workflow_dispatch:
env:
REGISTRY: ghcr.io
IMAGE_NAME: ${{ github.repository }}-frontend
jobs:
build-and-push:
runs-on: ubuntu-latest
permissions:
contents: read
packages: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Log in to GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Extract metadata
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
tags: |
type=ref,event=branch
type=sha,prefix={{branch}}-
type=raw,value=latest,enable={{is_default_branch}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build and push Docker image
uses: docker/build-push-action@v5
with:
context: ./www
file: ./www/Dockerfile
push: true
tags: ${{ steps.meta.outputs.tags }}
labels: ${{ steps.meta.outputs.labels }}
cache-from: type=gha
cache-to: type=gha,mode=max
platforms: linux/amd64,linux/arm64

View File

@@ -1,5 +1,22 @@
# Changelog
## [0.14.0](https://github.com/Monadical-SAS/reflector/compare/v0.13.1...v0.14.0) (2025-10-08)
### Features
* Add calendar event data to transcript webhook payload ([#689](https://github.com/Monadical-SAS/reflector/issues/689)) ([5f6910e](https://github.com/Monadical-SAS/reflector/commit/5f6910e5131b7f28f86c9ecdcc57fed8412ee3cd))
* container build for www / github ([#672](https://github.com/Monadical-SAS/reflector/issues/672)) ([969bd84](https://github.com/Monadical-SAS/reflector/commit/969bd84fcc14851d1a101412a0ba115f1b7cde82))
* docker-compose for production frontend ([#664](https://github.com/Monadical-SAS/reflector/issues/664)) ([5bf64b5](https://github.com/Monadical-SAS/reflector/commit/5bf64b5a41f64535e22849b4bb11734d4dbb4aae))
### Bug Fixes
* restore feature boolean logic ([#671](https://github.com/Monadical-SAS/reflector/issues/671)) ([3660884](https://github.com/Monadical-SAS/reflector/commit/36608849ec64e953e3be456172502762e3c33df9))
* security review ([#656](https://github.com/Monadical-SAS/reflector/issues/656)) ([5d98754](https://github.com/Monadical-SAS/reflector/commit/5d98754305c6c540dd194dda268544f6d88bfaf8))
* update transcript list on reprocess ([#676](https://github.com/Monadical-SAS/reflector/issues/676)) ([9a71af1](https://github.com/Monadical-SAS/reflector/commit/9a71af145ee9b833078c78d0c684590ab12e9f0e))
* upgrade nemo toolkit ([#678](https://github.com/Monadical-SAS/reflector/issues/678)) ([eef6dc3](https://github.com/Monadical-SAS/reflector/commit/eef6dc39037329b65804297786d852dddb0557f9))
## [0.13.1](https://github.com/Monadical-SAS/reflector/compare/v0.13.0...v0.13.1) (2025-09-22)

View File

@@ -151,7 +151,7 @@ All endpoints prefixed `/v1/`:
**Frontend** (`www/.env`):
- `NEXTAUTH_URL`, `NEXTAUTH_SECRET` - Authentication configuration
- `NEXT_PUBLIC_REFLECTOR_API_URL` - Backend API endpoint
- `REFLECTOR_API_URL` - Backend API endpoint
- `REFLECTOR_DOMAIN_CONFIG` - Feature flags and domain settings
## Testing Strategy

View File

@@ -168,6 +168,13 @@ You can manually process an audio file by calling the process tool:
uv run python -m reflector.tools.process path/to/audio.wav
```
## Build-time env variables
Next.js projects are more used to NEXT_PUBLIC_ prefixed buildtime vars. We don't have those for the reason we need to serve a ccustomizable prebuild docker container.
Instead, all the variables are runtime. Variables needed to the frontend are served to the frontend app at initial render.
It also means there's no static prebuild and no static files to serve for js/html.
## Feature Flags
@@ -177,24 +184,24 @@ Reflector uses environment variable-based feature flags to control application f
| Feature Flag | Environment Variable |
|-------------|---------------------|
| `requireLogin` | `NEXT_PUBLIC_FEATURE_REQUIRE_LOGIN` |
| `privacy` | `NEXT_PUBLIC_FEATURE_PRIVACY` |
| `browse` | `NEXT_PUBLIC_FEATURE_BROWSE` |
| `sendToZulip` | `NEXT_PUBLIC_FEATURE_SEND_TO_ZULIP` |
| `rooms` | `NEXT_PUBLIC_FEATURE_ROOMS` |
| `requireLogin` | `FEATURE_REQUIRE_LOGIN` |
| `privacy` | `FEATURE_PRIVACY` |
| `browse` | `FEATURE_BROWSE` |
| `sendToZulip` | `FEATURE_SEND_TO_ZULIP` |
| `rooms` | `FEATURE_ROOMS` |
### Setting Feature Flags
Feature flags are controlled via environment variables using the pattern `NEXT_PUBLIC_FEATURE_{FEATURE_NAME}` where `{FEATURE_NAME}` is the SCREAMING_SNAKE_CASE version of the feature name.
Feature flags are controlled via environment variables using the pattern `FEATURE_{FEATURE_NAME}` where `{FEATURE_NAME}` is the SCREAMING_SNAKE_CASE version of the feature name.
**Examples:**
```bash
# Enable user authentication requirement
NEXT_PUBLIC_FEATURE_REQUIRE_LOGIN=true
FEATURE_REQUIRE_LOGIN=true
# Disable browse functionality
NEXT_PUBLIC_FEATURE_BROWSE=false
FEATURE_BROWSE=false
# Enable Zulip integration
NEXT_PUBLIC_FEATURE_SEND_TO_ZULIP=true
FEATURE_SEND_TO_ZULIP=true
```

39
docker-compose.prod.yml Normal file
View File

@@ -0,0 +1,39 @@
# Production Docker Compose configuration for Frontend
# Usage: docker compose -f docker-compose.prod.yml up -d
services:
web:
build:
context: ./www
dockerfile: Dockerfile
image: reflector-frontend:latest
environment:
- KV_URL=${KV_URL:-redis://redis:6379}
- SITE_URL=${SITE_URL}
- API_URL=${API_URL}
- WEBSOCKET_URL=${WEBSOCKET_URL}
- NEXTAUTH_URL=${NEXTAUTH_URL:-http://localhost:3000}
- NEXTAUTH_SECRET=${NEXTAUTH_SECRET:-changeme-in-production}
- AUTHENTIK_ISSUER=${AUTHENTIK_ISSUER}
- AUTHENTIK_CLIENT_ID=${AUTHENTIK_CLIENT_ID}
- AUTHENTIK_CLIENT_SECRET=${AUTHENTIK_CLIENT_SECRET}
- AUTHENTIK_REFRESH_TOKEN_URL=${AUTHENTIK_REFRESH_TOKEN_URL}
- SENTRY_DSN=${SENTRY_DSN}
- SENTRY_IGNORE_API_RESOLUTION_ERROR=${SENTRY_IGNORE_API_RESOLUTION_ERROR:-1}
depends_on:
- redis
restart: unless-stopped
redis:
image: redis:7.2-alpine
restart: unless-stopped
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 30s
timeout: 3s
retries: 3
volumes:
- redis_data:/data
volumes:
redis_data:

View File

@@ -39,7 +39,7 @@ services:
ports:
- 6379:6379
web:
image: node:18
image: node:22-alpine
ports:
- "3000:3000"
command: sh -c "corepack enable && pnpm install && pnpm dev"
@@ -50,6 +50,8 @@ services:
- /app/node_modules
env_file:
- ./www/.env.local
environment:
- NODE_ENV=development
postgres:
image: postgres:17

View File

@@ -77,7 +77,7 @@ image = (
.pip_install(
"hf_transfer==0.1.9",
"huggingface_hub[hf-xet]==0.31.2",
"nemo_toolkit[asr]==2.3.0",
"nemo_toolkit[asr]==2.5.0",
"cuda-python==12.8.0",
"fastapi==0.115.12",
"numpy<2",

View File

@@ -1,118 +0,0 @@
# AsyncIO Event Loop Analysis for test_attendee_parsing_bug.py
## Problem Summary
The test passes but encounters an error during teardown where asyncpg tries to use a different/closed event loop, resulting in:
- `RuntimeError: Task got Future attached to a different loop`
- `RuntimeError: Event loop is closed`
## Root Cause Analysis
### 1. Multiple Event Loop Creation Points
The test environment creates event loops at different scopes:
1. **Session-scoped loop** (conftest.py:27-34):
- Created once per test session
- Used by session-scoped fixtures
- Closed after all tests complete
2. **Function-scoped loop** (pytest-asyncio default):
- Created for each async test function
- This is the loop that runs the actual test
- Closed immediately after test completes
3. **AsyncPG internal loop**:
- AsyncPG connections store a reference to the loop they were created with
- Used for connection lifecycle management
### 2. Event Loop Lifecycle Mismatch
The issue occurs because:
1. **Session fixture creates database connection** on session-scoped loop
2. **Test runs** on function-scoped loop (different from session loop)
3. **During teardown**, the session fixture tries to rollback/close using the original session loop
4. **AsyncPG connection** still references the function-scoped loop which is now closed
5. **Conflict**: SQLAlchemy tries to use session loop, but asyncpg Future is attached to the closed function loop
### 3. Configuration Issues
Current pytest configuration:
- `asyncio_mode = "auto"` in pyproject.toml
- `asyncio_default_fixture_loop_scope=session` (shown in test output)
- `asyncio_default_test_loop_scope=function` (shown in test output)
This mismatch between fixture loop scope (session) and test loop scope (function) causes the problem.
## Solutions
### Option 1: Align Loop Scopes (Recommended)
Change pytest-asyncio configuration to use consistent loop scopes:
```python
# pyproject.toml
[tool.pytest.ini_options]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "function" # Change from session to function
```
### Option 2: Use Function-Scoped Database Fixture
Change the `session` fixture scope from session to function:
```python
@pytest_asyncio.fixture # Remove scope="session"
async def session(setup_database):
# ... existing code ...
```
### Option 3: Explicit Loop Management
Ensure all async operations use the same loop:
```python
@pytest_asyncio.fixture
async def session(setup_database, event_loop):
# Force using the current event loop
engine = create_async_engine(
settings.DATABASE_URL,
echo=False,
poolclass=NullPool,
connect_args={"loop": event_loop} # Pass explicit loop
)
# ... rest of fixture ...
```
### Option 4: Upgrade pytest-asyncio
The current version (1.1.0) has known issues with loop management. Consider upgrading to the latest version which has better loop scope handling.
## Immediate Workaround
For the test to run cleanly without the teardown error, you can:
1. Add explicit cleanup in the test:
```python
@pytest.mark.asyncio
async def test_attendee_parsing_bug(session):
# ... existing test code ...
# Explicit cleanup before fixture teardown
await session.commit() # or await session.close()
```
2. Or suppress the teardown error (not recommended for production):
```python
@pytest.fixture
async def session(setup_database):
# ... existing setup ...
try:
yield session
await session.rollback()
except RuntimeError as e:
if "Event loop is closed" not in str(e):
raise
finally:
await session.close()
```
## Recommendation
The cleanest solution is to align the loop scopes by setting both fixture and test loop scopes to "function" scope. This ensures each test gets its own clean event loop and avoids cross-contamination between tests.

View File

@@ -14,7 +14,7 @@ Webhooks are configured at the room level with two fields:
### `transcript.completed`
Triggered when a transcript has been fully processed, including transcription, diarization, summarization, and topic detection.
Triggered when a transcript has been fully processed, including transcription, diarization, summarization, topic detection and calendar event integration.
### `test`
@@ -128,6 +128,27 @@ This event includes a convenient URL for accessing the transcript:
"room": {
"id": "room-789",
"name": "Product Team Room"
},
"calendar_event": {
"id": "calendar-event-123",
"ics_uid": "event-123",
"title": "Q3 Product Planning Meeting",
"start_time": "2025-08-27T12:00:00Z",
"end_time": "2025-08-27T12:30:00Z",
"description": "Team discussed Q3 product roadmap, prioritizing mobile app features and API improvements.",
"location": "Conference Room 1",
"attendees": [
{
"id": "participant-1",
"name": "John Doe",
"speaker": "Speaker 1"
},
{
"id": "participant-2",
"name": "Jane Smith",
"speaker": "Speaker 2"
}
]
}
}
```

View File

@@ -27,7 +27,7 @@ AUTH_JWT_AUDIENCE=
#TRANSCRIPT_MODAL_API_KEY=xxxxx
TRANSCRIPT_BACKEND=modal
TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-web.modal.run
TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-parakeet-web.modal.run
TRANSCRIPT_MODAL_API_KEY=
## =======================================================

View File

@@ -1,583 +0,0 @@
# Celery to TaskIQ Migration Guide
## Executive Summary
This document outlines the migration path from Celery to TaskIQ for the Reflector project. TaskIQ is a modern, async-first distributed task queue that provides similar functionality to Celery while being designed specifically for async Python applications.
## Current Celery Usage Analysis
### Key Patterns in Use
1. **Task Decorators**: `@shared_task`, `@asynctask`, `@with_session` decorators
2. **Task Invocation**: `.delay()`, `.si()` for signatures
3. **Workflow Patterns**: `chain()`, `group()`, `chord()` for complex pipelines
4. **Scheduled Tasks**: Celery Beat with crontab and periodic schedules
5. **Session Management**: Custom `@with_session` and `@with_session_and_transcript` decorators
6. **Retry Logic**: Auto-retry with exponential backoff
7. **Redis Backend**: Using Redis for broker and result backend
### Critical Files to Migrate
- `reflector/worker/app.py` - Celery app configuration and beat schedule
- `reflector/worker/session_decorator.py` - Session management decorators
- `reflector/pipelines/main_file_pipeline.py` - File processing pipeline
- `reflector/pipelines/main_live_pipeline.py` - Live streaming pipeline (10 tasks)
- `reflector/worker/process.py` - Background processing tasks
- `reflector/worker/ics_sync.py` - Calendar sync tasks
- `reflector/worker/cleanup.py` - Cleanup tasks
- `reflector/worker/webhook.py` - Webhook notifications
## TaskIQ Architecture Mapping
### 1. Installation
```bash
# Remove Celery dependencies
uv remove celery flower
# Install TaskIQ with Redis support
uv add taskiq taskiq-redis taskiq-pipelines
```
### 2. Broker Configuration
#### Current (Celery)
```python
# reflector/worker/app.py
from celery import Celery
app = Celery(
"reflector",
broker=settings.CELERY_BROKER_URL,
backend=settings.CELERY_RESULT_BACKEND,
include=[...],
)
```
#### New (TaskIQ)
```python
# reflector/worker/broker.py
from taskiq_redis import RedisAsyncResultBackend, RedisStreamBroker
from taskiq import PipelineMiddleware, SimpleRetryMiddleware
result_backend = RedisAsyncResultBackend(
redis_url=settings.REDIS_URL,
result_ex_time=86400, # 24 hours
)
broker = RedisStreamBroker(
url=settings.REDIS_URL,
max_connection_pool_size=10,
).with_result_backend(result_backend).with_middlewares(
PipelineMiddleware(), # For chain/group/chord support
SimpleRetryMiddleware(default_retry_count=3),
)
# For testing environment
if os.environ.get("ENVIRONMENT") == "pytest":
from taskiq import InMemoryBroker
broker = InMemoryBroker(await_inplace=True)
```
### 3. Task Definition Migration
#### Current (Celery)
```python
@shared_task
@asynctask
@with_session
async def task_pipeline_file_process(session: AsyncSession, transcript_id: str):
pipeline = PipelineMainFile(transcript_id=transcript_id)
await pipeline.process()
```
#### New (TaskIQ)
```python
from taskiq import TaskiqDepends
from reflector.worker.broker import broker
from reflector.worker.dependencies import get_db_session
@broker.task
async def task_pipeline_file_process(transcript_id: str):
# Use get_session for proper test mocking
async for session in get_session():
pipeline = PipelineMainFile(transcript_id=transcript_id)
await pipeline.process()
```
### 4. Session Management
#### Current Session Decorators (Keep Using These!)
```python
# reflector/worker/session_decorator.py
def with_session(func):
@functools.wraps(func)
async def wrapper(*args, **kwargs):
async with get_session_context() as session:
return await func(session, *args, **kwargs)
return wrapper
```
#### Session Management Strategy
**⚠️ CRITICAL**: The key insight is to maintain consistent session management patterns:
1. **For Worker Tasks**: Continue using `@with_session` decorator pattern
2. **For FastAPI endpoints**: Use `get_session` dependency injection
3. **Never use `get_session_factory()` directly** in application code
```python
# APPROACH 1: Simple migration keeping decorator pattern
from reflector.worker.session_decorator import with_session
@taskiq_broker.task
@with_session
async def task_pipeline_file_process(session, *, transcript_id: str):
# Session is provided by decorator, just like Celery version
transcript = await transcripts_controller.get_by_id(session, transcript_id)
pipeline = PipelineMainFile(transcript_id=transcript_id)
await pipeline.process()
# APPROACH 2: For test compatibility without decorator
from reflector.db import get_session
@taskiq_broker.task
async def task_pipeline_file_process(transcript_id: str):
# Use get_session which is mocked in tests
async for session in get_session():
transcript = await transcripts_controller.get_by_id(session, transcript_id)
pipeline = PipelineMainFile(transcript_id=transcript_id)
await pipeline.process()
# APPROACH 3: Future - TaskIQ dependency injection (after full migration)
from taskiq import TaskiqDepends
async def get_session_context():
"""Context manager version of get_session for consistency"""
async for session in get_session():
yield session
@taskiq_broker.task
async def task_pipeline_file_process(
transcript_id: str,
session: AsyncSession = TaskiqDepends(get_session_context)
):
transcript = await transcripts_controller.get_by_id(session, transcript_id)
pipeline = PipelineMainFile(transcript_id=transcript_id)
await pipeline.process()
```
**Key Points:**
- `@with_session` decorator works with TaskIQ tasks (remove `@asynctask`, keep `@with_session`)
- For testing: `get_session()` from `reflector.db` is properly mocked
- Never call `get_session_factory()` directly - always use the abstractions
### 5. Task Invocation
#### Current (Celery)
```python
# Simple async execution
task_pipeline_file_process.delay(transcript_id=transcript.id)
# With signature for chaining
task_cleanup_consent.si(transcript_id=transcript_id)
```
#### New (TaskIQ)
```python
# Simple async execution
await task_pipeline_file_process.kiq(transcript_id=transcript.id)
# With kicker for advanced configuration
await task_cleanup_consent.kicker().with_labels(
priority="high"
).kiq(transcript_id=transcript_id)
```
### 6. Workflow Patterns (Chain, Group, Chord)
#### Current (Celery)
```python
from celery import chain, group, chord
# Chain example
post_chain = chain(
task_cleanup_consent.si(transcript_id=transcript_id),
task_pipeline_post_to_zulip.si(transcript_id=transcript_id),
task_send_webhook_if_needed.si(transcript_id=transcript_id),
)
# Chord example (parallel + callback)
chain = chord(
group(chain_mp3_and_diarize, chain_title_preview),
chain_final_summaries,
) | task_pipeline_post_to_zulip.si(transcript_id=transcript_id)
```
#### New (TaskIQ with Pipelines)
```python
from taskiq_pipelines import Pipeline
from taskiq import gather
# Chain example using Pipeline
post_pipeline = (
Pipeline(broker, task_cleanup_consent)
.call_next(task_pipeline_post_to_zulip, transcript_id=transcript_id)
.call_next(task_send_webhook_if_needed, transcript_id=transcript_id)
)
await post_pipeline.kiq(transcript_id=transcript_id)
# Parallel execution with gather
results = await gather([
chain_mp3_and_diarize.kiq(transcript_id),
chain_title_preview.kiq(transcript_id),
])
# Then execute callback
await chain_final_summaries.kiq(transcript_id, results)
await task_pipeline_post_to_zulip.kiq(transcript_id)
```
### 7. Scheduled Tasks (Celery Beat → TaskIQ Scheduler)
#### Current (Celery Beat)
```python
# reflector/worker/app.py
app.conf.beat_schedule = {
"process_messages": {
"task": "reflector.worker.process.process_messages",
"schedule": float(settings.SQS_POLLING_TIMEOUT_SECONDS),
},
"reprocess_failed_recordings": {
"task": "reflector.worker.process.reprocess_failed_recordings",
"schedule": crontab(hour=5, minute=0),
},
}
```
#### New (TaskIQ Scheduler)
```python
# reflector/worker/scheduler.py
from taskiq import TaskiqScheduler
from taskiq_redis import ListRedisScheduleSource
schedule_source = ListRedisScheduleSource(settings.REDIS_URL)
# Define scheduled tasks with decorators
@broker.task(
schedule=[
{
"cron": f"*/{int(settings.SQS_POLLING_TIMEOUT_SECONDS)} * * * * *"
}
]
)
async def process_messages():
# Task implementation
pass
@broker.task(
schedule=[{"cron": "0 5 * * *"}] # Daily at 5 AM
)
async def reprocess_failed_recordings():
# Task implementation
pass
# Initialize scheduler
scheduler = TaskiqScheduler(broker, sources=[schedule_source])
# Run scheduler (separate process)
# taskiq scheduler reflector.worker.scheduler:scheduler
```
### 8. Retry Configuration
#### Current (Celery)
```python
@shared_task(
bind=True,
max_retries=30,
default_retry_delay=60,
retry_backoff=True,
retry_backoff_max=3600,
)
async def task_send_webhook_if_needed(self, ...):
try:
# Task logic
except Exception as exc:
raise self.retry(exc=exc)
```
#### New (TaskIQ)
```python
from taskiq.middlewares import SimpleRetryMiddleware
# Global middleware configuration (1:1 with Celery defaults)
broker = broker.with_middlewares(
SimpleRetryMiddleware(default_retry_count=3),
)
# For specific tasks with custom retry logic:
@broker.task(retry_on_error=True, max_retries=30)
async def task_send_webhook_if_needed(...):
# Task logic - exceptions auto-retry
pass
```
## Testing Migration
### Current Pytest Setup (Celery)
```python
# tests/conftest.py
@pytest.fixture(scope="session")
def celery_config():
return {
"broker_url": "memory://",
"result_backend": "cache+memory://",
}
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
async def test_task():
pass
```
### New Pytest Setup (TaskIQ)
```python
# tests/conftest.py
import pytest
from taskiq import InMemoryBroker
from reflector.worker.broker import broker
@pytest.fixture(scope="function", autouse=True)
async def setup_taskiq_broker():
"""Replace broker with InMemoryBroker for testing"""
original_broker = broker
test_broker = InMemoryBroker(await_inplace=True)
# Copy task registrations
for task_name, task in original_broker._tasks.items():
test_broker.register_task(task.original_function, task_name=task_name)
yield test_broker
await test_broker.shutdown()
@pytest.fixture
async def taskiq_with_db_session(db_session):
"""Setup TaskIQ with database session"""
from reflector.worker.broker import broker
broker.add_dependency_context({
AsyncSession: db_session
})
yield
broker.custom_dependency_context = {}
# Test example
@pytest.mark.anyio
async def test_task(taskiq_with_db_session):
result = await task_pipeline_file_process("transcript-id")
assert result is not None
```
## Migration Steps
### Phase 1: Setup (Week 1)
1. **Install TaskIQ packages**
```bash
uv add taskiq taskiq-redis taskiq-pipelines
```
2. **Create new broker configuration**
- Create `reflector/worker/broker.py` with TaskIQ broker setup
- Create `reflector/worker/dependencies.py` for dependency injection
3. **Update settings**
- Keep existing Redis configuration
- Add TaskIQ-specific settings if needed
### Phase 2: Parallel Running (Week 2-3)
1. **Migrate simple tasks first**
- Start with `cleanup.py` (1 task)
- Move to `webhook.py` (1 task)
- Test thoroughly in isolation
2. **Setup dual-mode operation**
- Keep Celery tasks running
- Add TaskIQ versions alongside
- Use feature flags to switch between them
### Phase 3: Complex Tasks (Week 3-4)
1. **Migrate pipeline tasks**
- Convert `main_file_pipeline.py`
- Convert `main_live_pipeline.py` (most complex with 10 tasks)
- Ensure chain/group/chord patterns work
2. **Migrate scheduled tasks**
- Setup TaskIQ scheduler
- Convert beat schedule to TaskIQ schedules
- Test cron patterns
### Phase 4: Testing & Validation (Week 4-5)
1. **Update test suite**
- Replace Celery fixtures with TaskIQ fixtures
- Update all test files
- Ensure coverage remains the same
2. **Performance testing**
- Compare task execution times
- Monitor Redis memory usage
- Test under load
### Phase 5: Cutover (Week 5-6)
1. **Final migration**
- Remove Celery dependencies
- Update deployment scripts
- Update documentation
2. **Monitoring**
- Setup TaskIQ monitoring (if available)
- Create health checks
- Document operational procedures
## Key Differences to Note
### Advantages of TaskIQ
1. **Native async support** - No need for `@asynctask` wrapper
2. **Dependency injection** - Cleaner than decorators for session management
3. **Type hints** - Better IDE support and autocompletion
4. **Modern Python** - Designed for Python 3.7+
5. **Simpler testing** - InMemoryBroker makes testing easier
### Potential Challenges
1. **Less mature ecosystem** - Fewer third-party integrations
2. **Documentation** - Less comprehensive than Celery
3. **Monitoring tools** - No Flower equivalent (may need custom solution)
4. **Community support** - Smaller community than Celery
## Command Line Changes
### Current (Celery)
```bash
# Start worker
celery -A reflector.worker.app worker --loglevel=info
# Start beat scheduler
celery -A reflector.worker.app beat
```
### New (TaskIQ)
```bash
# Start worker
taskiq worker reflector.worker.broker:broker
# Start scheduler
taskiq scheduler reflector.worker.scheduler:scheduler
# With custom settings
taskiq worker reflector.worker.broker:broker --workers 4 --log-level INFO
```
## Rollback Plan
If issues arise during migration:
1. **Keep Celery code in version control** - Tag the last Celery version
2. **Maintain dual broker setup** - Can switch back via environment variable
3. **Database compatibility** - No schema changes required
4. **Redis compatibility** - Both use Redis, easy to switch back
## Success Criteria
1. ✅ All tasks migrated and functioning
2. ✅ Test coverage maintained at current levels
3. ✅ Performance equal or better than Celery
4. ✅ Scheduled tasks running reliably
5. ✅ Error handling and retries working correctly
6. ✅ WebSocket notifications still functioning
7. ✅ Pipeline processing maintaining same behavior
## Monitoring & Operations
### Health Checks
```python
# reflector/worker/healthcheck.py
@broker.task
async def healthcheck_ping():
"""TaskIQ health check task"""
return {"status": "healthy", "timestamp": datetime.now()}
```
### Metrics Collection
- Task execution times
- Success/failure rates
- Queue depths
- Worker utilization
## Key Implementation Points - MUST READ
### Critical Changes Required
1. **Session Management in Tasks**
- ✅ **VERIFIED**: Tasks MUST use `get_session()` from `reflector.db` for test compatibility
- ❌ Do NOT use `get_session_factory()` directly in tasks - it bypasses test mocks
- ✅ The test database session IS properly shared when using `get_session()`
2. **Task Invocation Changes**
- Replace `.delay()` with `await .kiq()`
- All task invocations become async/await
- No need to commit sessions before task invocation (controllers handle this)
3. **Broker Configuration**
- TaskIQ broker must be initialized in `worker/app.py`
- Use `InMemoryBroker(await_inplace=True)` for testing
- Use `RedisStreamBroker` for production
4. **Test Setup Requirements**
- Set `os.environ["ENVIRONMENT"] = "pytest"` at top of test files
- Add TaskIQ broker fixture to test functions
- Keep Celery fixtures for now (dual-mode operation)
5. **Import Pattern Changes**
```python
# Each file needs both imports during migration
from reflector.pipelines.main_file_pipeline import (
task_pipeline_file_process, # Celery version
task_pipeline_file_process_taskiq, # TaskIQ version
)
```
6. **Decorator Changes**
- Remove `@asynctask` - TaskIQ is async-native
- **Keep `@with_session`** - it works with TaskIQ tasks!
- Remove `@shared_task` from TaskIQ version
- Keep `@shared_task` on Celery version for backward compatibility
## Verified POC Results
✅ **Database transactions work correctly** across test and TaskIQ tasks
✅ **Tasks execute immediately** in tests with `InMemoryBroker(await_inplace=True)`
✅ **Session mocking works** when using `get_session()` properly
✅ **"OK" output confirmed** - TaskIQ task executes and accesses test data
## Conclusion
The migration from Celery to TaskIQ is feasible and offers several advantages for an async-first codebase like Reflector. The key challenges will be:
1. Migrating complex pipeline patterns (chain/chord)
2. Ensuring scheduled task reliability
3. **SOLVED**: Maintaining session management patterns - use `get_session()`
4. Updating the test suite
The phased approach allows for gradual migration with minimal risk. The ability to run both systems in parallel provides a safety net during the transition period.
## Appendix: Quick Reference
| Celery | TaskIQ |
|--------|--------|
| `@shared_task` | `@broker.task` |
| `.delay()` | `.kiq()` |
| `.apply_async()` | `.kicker().kiq()` |
| `chain()` | `Pipeline()` |
| `group()` | `gather()` |
| `chord()` | `gather() + callback` |
| `@task.retry()` | `retry_on_error=True` |
| Celery Beat | TaskIQ Scheduler |
| `celery worker` | `taskiq worker` |
| Flower | Custom monitoring needed |

View File

@@ -3,7 +3,7 @@ from logging.config import fileConfig
from alembic import context
from sqlalchemy import engine_from_config, pool
from reflector.db.base import metadata
from reflector.db import metadata
from reflector.settings import settings
# this is the Alembic Config object, which provides

View File

@@ -23,16 +23,14 @@ def upgrade() -> None:
op.drop_column("transcript", "search_vector_en")
# Recreate the search vector column with long_summary included
op.execute(
"""
op.execute("""
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
GENERATED ALWAYS AS (
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') ||
setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')
) STORED
"""
)
""")
# Recreate the GIN index for the search vector
op.create_index(
@@ -49,15 +47,13 @@ def downgrade() -> None:
op.drop_column("transcript", "search_vector_en")
# Recreate the original search vector column without long_summary
op.execute(
"""
op.execute("""
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
GENERATED ALWAYS AS (
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')
) STORED
"""
)
""")
# Recreate the GIN index for the search vector
op.create_index(

View File

@@ -21,15 +21,13 @@ def upgrade() -> None:
if conn.dialect.name != "postgresql":
return
op.execute(
"""
op.execute("""
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
GENERATED ALWAYS AS (
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')
) STORED
"""
)
""")
op.create_index(
"idx_transcript_search_vector_en",

View File

@@ -19,14 +19,12 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Set room_id to NULL for meetings that reference non-existent rooms
op.execute(
"""
op.execute("""
UPDATE meeting
SET room_id = NULL
WHERE room_id IS NOT NULL
AND room_id NOT IN (SELECT id FROM room WHERE id IS NOT NULL)
"""
)
""")
def downgrade() -> None:

View File

@@ -28,7 +28,7 @@ def upgrade() -> None:
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
# Select all rows from the transcript table
results = bind.execute(select(transcript.c.id, transcript.c.topics))
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
for row in results:
transcript_id = row["id"]
@@ -58,7 +58,7 @@ def downgrade() -> None:
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
# Select all rows from the transcript table
results = bind.execute(select(transcript.c.id, transcript.c.topics))
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
for row in results:
transcript_id = row["id"]

View File

@@ -36,7 +36,9 @@ def upgrade() -> None:
# select only the one with duration = 0
results = bind.execute(
select(transcript.c.id, transcript.c.duration).where(transcript.c.duration == 0)
select([transcript.c.id, transcript.c.duration]).where(
transcript.c.duration == 0
)
)
data_dir = Path(settings.DATA_DIR)

View File

@@ -28,7 +28,7 @@ def upgrade() -> None:
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
# Select all rows from the transcript table
results = bind.execute(select(transcript.c.id, transcript.c.topics))
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
for row in results:
transcript_id = row["id"]
@@ -58,7 +58,7 @@ def downgrade() -> None:
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
# Select all rows from the transcript table
results = bind.execute(select(transcript.c.id, transcript.c.topics))
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
for row in results:
transcript_id = row["id"]

View File

@@ -27,8 +27,7 @@ def upgrade() -> None:
# Populate room_id for existing ROOM-type transcripts
# This joins through recording -> meeting -> room to get the room_id
op.execute(
"""
op.execute("""
UPDATE transcript AS t
SET room_id = r.id
FROM recording rec
@@ -37,13 +36,11 @@ def upgrade() -> None:
WHERE t.recording_id = rec.id
AND t.source_kind = 'room'
AND t.room_id IS NULL
"""
)
""")
# Fix missing meeting_id for ROOM-type transcripts
# The meeting_id field exists but was never populated
op.execute(
"""
op.execute("""
UPDATE transcript AS t
SET meeting_id = rec.meeting_id
FROM recording rec
@@ -51,8 +48,7 @@ def upgrade() -> None:
AND t.source_kind = 'room'
AND t.meeting_id IS NULL
AND rec.meeting_id IS NOT NULL
"""
)
""")
def downgrade() -> None:

View File

@@ -19,13 +19,14 @@ dependencies = [
"sentry-sdk[fastapi]>=1.29.2",
"httpx>=0.24.1",
"fastapi-pagination>=0.12.6",
"sqlalchemy>=2.0.0",
"asyncpg>=0.29.0",
"databases[aiosqlite, asyncpg]>=0.7.0",
"sqlalchemy<1.5",
"alembic>=1.11.3",
"nltk>=3.8.1",
"prometheus-fastapi-instrumentator>=6.1.0",
"sentencepiece>=0.1.99",
"protobuf>=4.24.3",
"celery>=5.3.4",
"redis>=5.0.1",
"python-jose[cryptography]>=3.3.0",
"python-multipart>=0.0.6",
@@ -38,8 +39,6 @@ dependencies = [
"pytest-env>=1.1.5",
"webvtt-py>=0.5.0",
"icalendar>=6.0.0",
"taskiq>=0.11.18",
"taskiq-redis>=1.1.0",
]
[dependency-groups]
@@ -47,7 +46,6 @@ dev = [
"black>=24.1.1",
"stamina>=23.1.0",
"pyinstrument>=4.6.1",
"pytest-async-sqlalchemy>=0.2.0",
]
tests = [
"pytest-cov>=4.1.0",
@@ -56,6 +54,7 @@ tests = [
"pytest>=7.4.0",
"httpx-ws>=0.4.1",
"pytest-httpx>=0.23.1",
"pytest-celery>=0.0.0",
"pytest-recording>=0.13.4",
"pytest-docker>=3.2.3",
"asgi-lifespan>=2.1.0",
@@ -112,15 +111,13 @@ source = ["reflector"]
[tool.pytest_env]
ENVIRONMENT = "pytest"
DATABASE_URL = "postgresql+asyncpg://test_user:test_password@localhost:15432/reflector_test"
DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_test"
AUTH_BACKEND = "jwt"
[tool.pytest.ini_options]
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
testpaths = ["tests"]
asyncio_mode = "auto"
asyncio_debug = true
asyncio_default_fixture_loop_scope = "session"
asyncio_default_test_loop_scope = "session"
markers = [
"model_api: tests for the unified model-serving HTTP API (backend- and hardware-agnostic)",
]

View File

@@ -26,6 +26,7 @@ from reflector.views.transcripts_upload import router as transcripts_upload_rout
from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
from reflector.views.transcripts_websocket import router as transcripts_websocket_router
from reflector.views.user import router as user_router
from reflector.views.user_websocket import router as user_ws_router
from reflector.views.whereby import router as whereby_router
from reflector.views.zulip import router as zulip_router
@@ -65,6 +66,12 @@ app.add_middleware(
allow_headers=["*"],
)
@app.get("/health")
async def health():
return {"status": "healthy"}
# metrics
instrumentator = Instrumentator(
excluded_handlers=["/docs", "/metrics"],
@@ -84,12 +91,13 @@ app.include_router(transcripts_websocket_router, prefix="/v1")
app.include_router(transcripts_webrtc_router, prefix="/v1")
app.include_router(transcripts_process_router, prefix="/v1")
app.include_router(user_router, prefix="/v1")
app.include_router(user_ws_router, prefix="/v1")
app.include_router(zulip_router, prefix="/v1")
app.include_router(whereby_router, prefix="/v1")
add_pagination(app)
# prepare taskiq
from reflector.worker import app as taskiq_app # noqa
# prepare celery
from reflector.worker import app as celery_app # noqa
# simpler openapi id

View File

@@ -0,0 +1,27 @@
import asyncio
import functools
from reflector.db import get_database
def asynctask(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
async def run_with_db():
database = get_database()
await database.connect()
try:
return await f(*args, **kwargs)
finally:
await database.disconnect()
coro = run_with_db()
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
return loop.run_until_complete(coro)
return asyncio.run(coro)
return wrapper

View File

@@ -1,82 +1,48 @@
from contextlib import asynccontextmanager
from typing import AsyncGenerator
import contextvars
from typing import Optional
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
import databases
import sqlalchemy
from reflector.db.base import Base as Base
from reflector.db.base import metadata as metadata
from reflector.events import subscribers_shutdown, subscribers_startup
from reflector.settings import settings
_engine: AsyncEngine | None = None
_session_factory: async_sessionmaker[AsyncSession] | None = None
metadata = sqlalchemy.MetaData()
def get_engine() -> AsyncEngine:
global _engine
if _engine is None:
_engine = create_async_engine(
settings.DATABASE_URL,
echo=False,
pool_pre_ping=True,
_database_context: contextvars.ContextVar[Optional[databases.Database]] = (
contextvars.ContextVar("database", default=None)
)
return _engine
def get_session_factory() -> async_sessionmaker[AsyncSession]:
global _session_factory
if _session_factory is None:
_session_factory = async_sessionmaker(
get_engine(),
class_=AsyncSession,
expire_on_commit=False,
)
return _session_factory
async def _get_session() -> AsyncGenerator[AsyncSession, None]:
# necessary implementation to ease mocking on pytest
async with get_session_factory()() as session:
yield session
async def get_session() -> AsyncGenerator[AsyncSession, None]:
"""
Get a database session, fastapi dependency injection style
"""
async for session in _get_session():
yield session
@asynccontextmanager
async def get_session_context():
"""
Get a database session as an async context manager
"""
async for session in _get_session():
yield session
def get_database() -> databases.Database:
"""Get database instance for current asyncio context"""
db = _database_context.get()
if db is None:
db = databases.Database(settings.DATABASE_URL)
_database_context.set(db)
return db
# import models
import reflector.db.calendar_events # noqa
import reflector.db.meetings # noqa
import reflector.db.recordings # noqa
import reflector.db.rooms # noqa
import reflector.db.transcripts # noqa
kwargs = {}
if "postgres" not in settings.DATABASE_URL:
raise Exception("Only postgres database is supported in reflector")
engine = sqlalchemy.create_engine(settings.DATABASE_URL, **kwargs)
@subscribers_startup.append
async def database_connect(_):
get_engine()
database = get_database()
await database.connect()
@subscribers_shutdown.append
async def database_disconnect(_):
global _engine
if _engine:
await _engine.dispose()
_engine = None
database = get_database()
await database.disconnect()

View File

@@ -1,237 +0,0 @@
from datetime import datetime
from typing import Optional
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSONB, TSVECTOR
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(AsyncAttrs, DeclarativeBase):
pass
class TranscriptModel(Base):
__tablename__ = "transcript"
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
name: Mapped[Optional[str]] = mapped_column(sa.String)
status: Mapped[Optional[str]] = mapped_column(sa.String)
locked: Mapped[Optional[bool]] = mapped_column(sa.Boolean)
duration: Mapped[Optional[float]] = mapped_column(sa.Float)
created_at: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
title: Mapped[Optional[str]] = mapped_column(sa.String)
short_summary: Mapped[Optional[str]] = mapped_column(sa.String)
long_summary: Mapped[Optional[str]] = mapped_column(sa.String)
topics: Mapped[Optional[list]] = mapped_column(sa.JSON)
events: Mapped[Optional[list]] = mapped_column(sa.JSON)
participants: Mapped[Optional[list]] = mapped_column(sa.JSON)
source_language: Mapped[Optional[str]] = mapped_column(sa.String)
target_language: Mapped[Optional[str]] = mapped_column(sa.String)
reviewed: Mapped[bool] = mapped_column(
sa.Boolean, nullable=False, server_default=sa.text("false")
)
audio_location: Mapped[str] = mapped_column(
sa.String, nullable=False, server_default="local"
)
user_id: Mapped[Optional[str]] = mapped_column(sa.String)
share_mode: Mapped[str] = mapped_column(
sa.String, nullable=False, server_default="private"
)
meeting_id: Mapped[Optional[str]] = mapped_column(sa.String)
recording_id: Mapped[Optional[str]] = mapped_column(sa.String)
zulip_message_id: Mapped[Optional[int]] = mapped_column(sa.Integer)
source_kind: Mapped[str] = mapped_column(
sa.String, nullable=False
) # Enum will be handled separately
audio_deleted: Mapped[Optional[bool]] = mapped_column(sa.Boolean)
room_id: Mapped[Optional[str]] = mapped_column(sa.String)
webvtt: Mapped[Optional[str]] = mapped_column(sa.Text)
__table_args__ = (
sa.Index("idx_transcript_recording_id", "recording_id"),
sa.Index("idx_transcript_user_id", "user_id"),
sa.Index("idx_transcript_created_at", "created_at"),
sa.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
sa.Index("idx_transcript_room_id", "room_id"),
sa.Index("idx_transcript_source_kind", "source_kind"),
sa.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
)
TranscriptModel.search_vector_en = sa.Column(
"search_vector_en",
TSVECTOR,
sa.Computed(
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
"setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
persisted=True,
),
)
class RoomModel(Base):
__tablename__ = "room"
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
name: Mapped[str] = mapped_column(sa.String, nullable=False, unique=True)
user_id: Mapped[str] = mapped_column(sa.String, nullable=False)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), nullable=False
)
zulip_auto_post: Mapped[bool] = mapped_column(
sa.Boolean, nullable=False, server_default=sa.text("false")
)
zulip_stream: Mapped[Optional[str]] = mapped_column(sa.String)
zulip_topic: Mapped[Optional[str]] = mapped_column(sa.String)
is_locked: Mapped[bool] = mapped_column(
sa.Boolean, nullable=False, server_default=sa.text("false")
)
room_mode: Mapped[str] = mapped_column(
sa.String, nullable=False, server_default="normal"
)
recording_type: Mapped[str] = mapped_column(
sa.String, nullable=False, server_default="cloud"
)
recording_trigger: Mapped[str] = mapped_column(
sa.String, nullable=False, server_default="automatic-2nd-participant"
)
is_shared: Mapped[bool] = mapped_column(
sa.Boolean, nullable=False, server_default=sa.text("false")
)
webhook_url: Mapped[Optional[str]] = mapped_column(sa.String)
webhook_secret: Mapped[Optional[str]] = mapped_column(sa.String)
ics_url: Mapped[Optional[str]] = mapped_column(sa.Text)
ics_fetch_interval: Mapped[Optional[int]] = mapped_column(
sa.Integer, server_default=sa.text("300")
)
ics_enabled: Mapped[bool] = mapped_column(
sa.Boolean, nullable=False, server_default=sa.text("false")
)
ics_last_sync: Mapped[Optional[datetime]] = mapped_column(
sa.DateTime(timezone=True)
)
ics_last_etag: Mapped[Optional[str]] = mapped_column(sa.Text)
__table_args__ = (
sa.Index("idx_room_is_shared", "is_shared"),
sa.Index("idx_room_ics_enabled", "ics_enabled"),
)
class MeetingModel(Base):
__tablename__ = "meeting"
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
room_name: Mapped[Optional[str]] = mapped_column(sa.String)
room_url: Mapped[Optional[str]] = mapped_column(sa.String)
host_room_url: Mapped[Optional[str]] = mapped_column(sa.String)
start_date: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
end_date: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
room_id: Mapped[Optional[str]] = mapped_column(
sa.String, sa.ForeignKey("room.id", ondelete="CASCADE")
)
is_locked: Mapped[bool] = mapped_column(
sa.Boolean, nullable=False, server_default=sa.text("false")
)
room_mode: Mapped[str] = mapped_column(
sa.String, nullable=False, server_default="normal"
)
recording_type: Mapped[str] = mapped_column(
sa.String, nullable=False, server_default="cloud"
)
recording_trigger: Mapped[str] = mapped_column(
sa.String, nullable=False, server_default="automatic-2nd-participant"
)
num_clients: Mapped[int] = mapped_column(
sa.Integer, nullable=False, server_default=sa.text("0")
)
is_active: Mapped[bool] = mapped_column(
sa.Boolean, nullable=False, server_default=sa.text("true")
)
calendar_event_id: Mapped[Optional[str]] = mapped_column(
sa.String,
sa.ForeignKey(
"calendar_event.id",
ondelete="SET NULL",
name="fk_meeting_calendar_event_id",
),
)
calendar_metadata: Mapped[Optional[dict]] = mapped_column(JSONB)
__table_args__ = (
sa.Index("idx_meeting_room_id", "room_id"),
sa.Index("idx_meeting_calendar_event", "calendar_event_id"),
)
class MeetingConsentModel(Base):
__tablename__ = "meeting_consent"
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
meeting_id: Mapped[str] = mapped_column(
sa.String, sa.ForeignKey("meeting.id", ondelete="CASCADE"), nullable=False
)
user_id: Mapped[Optional[str]] = mapped_column(sa.String)
consent_given: Mapped[bool] = mapped_column(sa.Boolean, nullable=False)
consent_timestamp: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), nullable=False
)
class RecordingModel(Base):
__tablename__ = "recording"
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
meeting_id: Mapped[str] = mapped_column(
sa.String, sa.ForeignKey("meeting.id", ondelete="CASCADE"), nullable=False
)
url: Mapped[str] = mapped_column(sa.String, nullable=False)
object_key: Mapped[str] = mapped_column(sa.String, nullable=False)
duration: Mapped[Optional[float]] = mapped_column(sa.Float)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), nullable=False
)
__table_args__ = (sa.Index("idx_recording_meeting_id", "meeting_id"),)
class CalendarEventModel(Base):
__tablename__ = "calendar_event"
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
room_id: Mapped[str] = mapped_column(
sa.String, sa.ForeignKey("room.id", ondelete="CASCADE"), nullable=False
)
ics_uid: Mapped[str] = mapped_column(sa.Text, nullable=False)
title: Mapped[Optional[str]] = mapped_column(sa.Text)
description: Mapped[Optional[str]] = mapped_column(sa.Text)
start_time: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), nullable=False
)
end_time: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), nullable=False
)
attendees: Mapped[Optional[dict]] = mapped_column(JSONB)
location: Mapped[Optional[str]] = mapped_column(sa.Text)
ics_raw_data: Mapped[Optional[str]] = mapped_column(sa.Text)
last_synced: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), nullable=False
)
is_deleted: Mapped[bool] = mapped_column(
sa.Boolean, nullable=False, server_default=sa.text("false")
)
created_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), nullable=False
)
updated_at: Mapped[datetime] = mapped_column(
sa.DateTime(timezone=True), nullable=False
)
__table_args__ = (
sa.Index("idx_calendar_event_room_start", "room_id", "start_time"),
)
metadata = Base.metadata

View File

@@ -2,17 +2,45 @@ from datetime import datetime, timedelta, timezone
from typing import Any
import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel, Field
from sqlalchemy.dialects.postgresql import JSONB
from reflector.db.base import CalendarEventModel
from reflector.db import get_database, metadata
from reflector.utils import generate_uuid4
calendar_events = sa.Table(
"calendar_event",
metadata,
sa.Column("id", sa.String, primary_key=True),
sa.Column(
"room_id",
sa.String,
sa.ForeignKey("room.id", ondelete="CASCADE", name="fk_calendar_event_room_id"),
nullable=False,
),
sa.Column("ics_uid", sa.Text, nullable=False),
sa.Column("title", sa.Text),
sa.Column("description", sa.Text),
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("end_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("attendees", JSONB),
sa.Column("location", sa.Text),
sa.Column("ics_raw_data", sa.Text),
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=False),
sa.Column("is_deleted", sa.Boolean, nullable=False, server_default=sa.false()),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.UniqueConstraint("room_id", "ics_uid", name="uq_room_calendar_event"),
sa.Index("idx_calendar_event_room_start", "room_id", "start_time"),
sa.Index(
"idx_calendar_event_deleted",
"is_deleted",
postgresql_where=sa.text("NOT is_deleted"),
),
)
class CalendarEvent(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str = Field(default_factory=generate_uuid4)
room_id: str
ics_uid: str
@@ -30,159 +58,129 @@ class CalendarEvent(BaseModel):
class CalendarEventController:
async def get_upcoming_events(
self,
session: AsyncSession,
room_id: str,
current_time: datetime,
buffer_minutes: int = 15,
) -> list[CalendarEvent]:
buffer_time = current_time + timedelta(minutes=buffer_minutes)
query = (
select(CalendarEventModel)
.where(
sa.and_(
CalendarEventModel.room_id == room_id,
CalendarEventModel.start_time <= buffer_time,
CalendarEventModel.end_time > current_time,
)
)
.order_by(CalendarEventModel.start_time)
)
result = await session.execute(query)
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
async def get_by_id(
self, session: AsyncSession, event_id: str
) -> CalendarEvent | None:
query = select(CalendarEventModel).where(CalendarEventModel.id == event_id)
result = await session.execute(query)
row = result.scalar_one_or_none()
if not row:
return None
return CalendarEvent.model_validate(row)
async def get_by_ics_uid(
self, session: AsyncSession, room_id: str, ics_uid: str
) -> CalendarEvent | None:
query = select(CalendarEventModel).where(
sa.and_(
CalendarEventModel.room_id == room_id,
CalendarEventModel.ics_uid == ics_uid,
)
)
result = await session.execute(query)
row = result.scalar_one_or_none()
if not row:
return None
return CalendarEvent.model_validate(row)
async def upsert(
self, session: AsyncSession, event: CalendarEvent
) -> CalendarEvent:
existing = await self.get_by_ics_uid(session, event.room_id, event.ics_uid)
if existing:
event.updated_at = datetime.now(timezone.utc)
query = (
update(CalendarEventModel)
.where(CalendarEventModel.id == existing.id)
.values(**event.model_dump(exclude={"id"}))
)
await session.execute(query)
await session.commit()
return event
else:
new_event = CalendarEventModel(**event.model_dump())
session.add(new_event)
await session.commit()
return event
async def delete_old_events(
self, session: AsyncSession, room_id: str, cutoff_date: datetime
) -> int:
query = delete(CalendarEventModel).where(
sa.and_(
CalendarEventModel.room_id == room_id,
CalendarEventModel.end_time < cutoff_date,
)
)
result = await session.execute(query)
await session.commit()
return result.rowcount
async def delete_events_not_in_list(
self, session: AsyncSession, room_id: str, keep_ics_uids: list[str]
) -> int:
if not keep_ics_uids:
query = delete(CalendarEventModel).where(
CalendarEventModel.room_id == room_id
)
else:
query = delete(CalendarEventModel).where(
sa.and_(
CalendarEventModel.room_id == room_id,
CalendarEventModel.ics_uid.notin_(keep_ics_uids),
)
)
result = await session.execute(query)
await session.commit()
return result.rowcount
async def get_by_room(
self, session: AsyncSession, room_id: str, include_deleted: bool = True
self,
room_id: str,
include_deleted: bool = False,
start_after: datetime | None = None,
end_before: datetime | None = None,
) -> list[CalendarEvent]:
query = select(CalendarEventModel).where(CalendarEventModel.room_id == room_id)
query = calendar_events.select().where(calendar_events.c.room_id == room_id)
if not include_deleted:
query = query.where(CalendarEventModel.is_deleted == False)
result = await session.execute(query)
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
query = query.where(calendar_events.c.is_deleted == False)
if start_after:
query = query.where(calendar_events.c.start_time >= start_after)
if end_before:
query = query.where(calendar_events.c.end_time <= end_before)
query = query.order_by(calendar_events.c.start_time.asc())
results = await get_database().fetch_all(query)
return [CalendarEvent(**result) for result in results]
async def get_upcoming(
self, session: AsyncSession, room_id: str, minutes_ahead: int = 120
self, room_id: str, minutes_ahead: int = 120
) -> list[CalendarEvent]:
"""Get upcoming events for a room within the specified minutes, including currently happening events."""
now = datetime.now(timezone.utc)
buffer_time = now + timedelta(minutes=minutes_ahead)
future_time = now + timedelta(minutes=minutes_ahead)
query = (
select(CalendarEventModel)
calendar_events.select()
.where(
sa.and_(
CalendarEventModel.room_id == room_id,
CalendarEventModel.start_time <= buffer_time,
CalendarEventModel.end_time > now,
CalendarEventModel.is_deleted == False,
calendar_events.c.room_id == room_id,
calendar_events.c.is_deleted == False,
calendar_events.c.start_time <= future_time,
calendar_events.c.end_time >= now,
)
)
.order_by(CalendarEventModel.start_time)
.order_by(calendar_events.c.start_time.asc())
)
result = await session.execute(query)
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
results = await get_database().fetch_all(query)
return [CalendarEvent(**result) for result in results]
async def get_by_id(self, event_id: str) -> CalendarEvent | None:
query = calendar_events.select().where(calendar_events.c.id == event_id)
result = await get_database().fetch_one(query)
return CalendarEvent(**result) if result else None
async def get_by_ics_uid(self, room_id: str, ics_uid: str) -> CalendarEvent | None:
query = calendar_events.select().where(
sa.and_(
calendar_events.c.room_id == room_id,
calendar_events.c.ics_uid == ics_uid,
)
)
result = await get_database().fetch_one(query)
return CalendarEvent(**result) if result else None
async def upsert(self, event: CalendarEvent) -> CalendarEvent:
existing = await self.get_by_ics_uid(event.room_id, event.ics_uid)
if existing:
event.id = existing.id
event.created_at = existing.created_at
event.updated_at = datetime.now(timezone.utc)
query = (
calendar_events.update()
.where(calendar_events.c.id == existing.id)
.values(**event.model_dump())
)
else:
query = calendar_events.insert().values(**event.model_dump())
await get_database().execute(query)
return event
async def soft_delete_missing(
self, session: AsyncSession, room_id: str, current_ics_uids: list[str]
self, room_id: str, current_ics_uids: list[str]
) -> int:
query = (
update(CalendarEventModel)
"""Soft delete future events that are no longer in the calendar."""
now = datetime.now(timezone.utc)
select_query = calendar_events.select().where(
sa.and_(
calendar_events.c.room_id == room_id,
calendar_events.c.start_time > now,
calendar_events.c.is_deleted == False,
calendar_events.c.ics_uid.notin_(current_ics_uids)
if current_ics_uids
else True,
)
)
to_delete = await get_database().fetch_all(select_query)
delete_count = len(to_delete)
if delete_count > 0:
update_query = (
calendar_events.update()
.where(
sa.and_(
CalendarEventModel.room_id == room_id,
(
CalendarEventModel.ics_uid.notin_(current_ics_uids)
calendar_events.c.room_id == room_id,
calendar_events.c.start_time > now,
calendar_events.c.is_deleted == False,
calendar_events.c.ics_uid.notin_(current_ics_uids)
if current_ics_uids
else True
),
CalendarEventModel.end_time > datetime.now(timezone.utc),
else True,
)
)
.values(is_deleted=True)
.values(is_deleted=True, updated_at=now)
)
result = await session.execute(query)
await session.commit()
await get_database().execute(update_query)
return delete_count
async def delete_by_room(self, room_id: str) -> int:
query = calendar_events.delete().where(calendar_events.c.room_id == room_id)
result = await get_database().execute(query)
return result.rowcount

View File

@@ -2,18 +2,80 @@ from datetime import datetime
from typing import Any, Literal
import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel, Field
from sqlalchemy.dialects.postgresql import JSONB
from reflector.db.base import MeetingConsentModel, MeetingModel
from reflector.db import get_database, metadata
from reflector.db.rooms import Room
from reflector.utils import generate_uuid4
meetings = sa.Table(
"meeting",
metadata,
sa.Column("id", sa.String, primary_key=True),
sa.Column("room_name", sa.String),
sa.Column("room_url", sa.String),
sa.Column("host_room_url", sa.String),
sa.Column("start_date", sa.DateTime(timezone=True)),
sa.Column("end_date", sa.DateTime(timezone=True)),
sa.Column(
"room_id",
sa.String,
sa.ForeignKey("room.id", ondelete="CASCADE"),
nullable=True,
),
sa.Column("is_locked", sa.Boolean, nullable=False, server_default=sa.false()),
sa.Column("room_mode", sa.String, nullable=False, server_default="normal"),
sa.Column("recording_type", sa.String, nullable=False, server_default="cloud"),
sa.Column(
"recording_trigger",
sa.String,
nullable=False,
server_default="automatic-2nd-participant",
),
sa.Column(
"num_clients",
sa.Integer,
nullable=False,
server_default=sa.text("0"),
),
sa.Column(
"is_active",
sa.Boolean,
nullable=False,
server_default=sa.true(),
),
sa.Column(
"calendar_event_id",
sa.String,
sa.ForeignKey(
"calendar_event.id",
ondelete="SET NULL",
name="fk_meeting_calendar_event_id",
),
),
sa.Column("calendar_metadata", JSONB),
sa.Index("idx_meeting_room_id", "room_id"),
sa.Index("idx_meeting_calendar_event", "calendar_event_id"),
)
meeting_consent = sa.Table(
"meeting_consent",
metadata,
sa.Column("id", sa.String, primary_key=True),
sa.Column(
"meeting_id",
sa.String,
sa.ForeignKey("meeting.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("user_id", sa.String),
sa.Column("consent_given", sa.Boolean, nullable=False),
sa.Column("consent_timestamp", sa.DateTime(timezone=True), nullable=False),
)
class MeetingConsent(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str = Field(default_factory=generate_uuid4)
meeting_id: str
user_id: str | None = None
@@ -22,8 +84,6 @@ class MeetingConsent(BaseModel):
class Meeting(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
room_name: str
room_url: str
@@ -46,7 +106,6 @@ class Meeting(BaseModel):
class MeetingController:
async def create(
self,
session: AsyncSession,
id: str,
room_name: str,
room_url: str,
@@ -72,198 +131,170 @@ class MeetingController:
calendar_event_id=calendar_event_id,
calendar_metadata=calendar_metadata,
)
new_meeting = MeetingModel(**meeting.model_dump())
session.add(new_meeting)
await session.commit()
query = meetings.insert().values(**meeting.model_dump())
await get_database().execute(query)
return meeting
async def get_all_active(self, session: AsyncSession) -> list[Meeting]:
query = select(MeetingModel).where(MeetingModel.is_active)
result = await session.execute(query)
return [Meeting.model_validate(row) for row in result.scalars().all()]
async def get_all_active(self) -> list[Meeting]:
query = meetings.select().where(meetings.c.is_active)
return await get_database().fetch_all(query)
async def get_by_room_name(
self,
session: AsyncSession,
room_name: str,
) -> Meeting | None:
"""
Get a meeting by room name.
For backward compatibility, returns the most recent meeting.
"""
end_date = getattr(meetings.c, "end_date")
query = (
select(MeetingModel)
.where(MeetingModel.room_name == room_name)
.order_by(MeetingModel.end_date.desc())
meetings.select()
.where(meetings.c.room_name == room_name)
.order_by(end_date.desc())
)
result = await session.execute(query)
row = result.scalar_one_or_none()
if not row:
result = await get_database().fetch_one(query)
if not result:
return None
return Meeting.model_validate(row)
async def get_active(
self, session: AsyncSession, room: Room, current_time: datetime
) -> Meeting | None:
return Meeting(**result)
async def get_active(self, room: Room, current_time: datetime) -> Meeting | None:
"""
Get latest active meeting for a room.
For backward compatibility, returns the most recent active meeting.
"""
end_date = getattr(meetings.c, "end_date")
query = (
select(MeetingModel)
meetings.select()
.where(
sa.and_(
MeetingModel.room_id == room.id,
MeetingModel.end_date > current_time,
MeetingModel.is_active,
meetings.c.room_id == room.id,
meetings.c.end_date > current_time,
meetings.c.is_active,
)
)
.order_by(MeetingModel.end_date.desc())
.order_by(end_date.desc())
)
result = await session.execute(query)
row = result.scalar_one_or_none()
if not row:
result = await get_database().fetch_one(query)
if not result:
return None
return Meeting.model_validate(row)
return Meeting(**result)
async def get_all_active_for_room(
self, session: AsyncSession, room: Room, current_time: datetime
self, room: Room, current_time: datetime
) -> list[Meeting]:
end_date = getattr(meetings.c, "end_date")
query = (
select(MeetingModel)
meetings.select()
.where(
sa.and_(
MeetingModel.room_id == room.id,
MeetingModel.end_date > current_time,
MeetingModel.is_active,
meetings.c.room_id == room.id,
meetings.c.end_date > current_time,
meetings.c.is_active,
)
)
.order_by(MeetingModel.end_date.desc())
.order_by(end_date.desc())
)
result = await session.execute(query)
return [Meeting.model_validate(row) for row in result.scalars().all()]
results = await get_database().fetch_all(query)
return [Meeting(**result) for result in results]
async def get_active_by_calendar_event(
self,
session: AsyncSession,
room: Room,
calendar_event_id: str,
current_time: datetime,
self, room: Room, calendar_event_id: str, current_time: datetime
) -> Meeting | None:
"""
Get active meeting for a specific calendar event.
"""
query = select(MeetingModel).where(
query = meetings.select().where(
sa.and_(
MeetingModel.room_id == room.id,
MeetingModel.calendar_event_id == calendar_event_id,
MeetingModel.end_date > current_time,
MeetingModel.is_active,
meetings.c.room_id == room.id,
meetings.c.calendar_event_id == calendar_event_id,
meetings.c.end_date > current_time,
meetings.c.is_active,
)
)
result = await session.execute(query)
row = result.scalar_one_or_none()
if not row:
result = await get_database().fetch_one(query)
if not result:
return None
return Meeting.model_validate(row)
return Meeting(**result)
async def get_by_id(
self, session: AsyncSession, meeting_id: str, **kwargs
) -> Meeting | None:
query = select(MeetingModel).where(MeetingModel.id == meeting_id)
result = await session.execute(query)
row = result.scalar_one_or_none()
if not row:
async def get_by_id(self, meeting_id: str, **kwargs) -> Meeting | None:
query = meetings.select().where(meetings.c.id == meeting_id)
result = await get_database().fetch_one(query)
if not result:
return None
return Meeting.model_validate(row)
return Meeting(**result)
async def get_by_calendar_event(
self, session: AsyncSession, calendar_event_id: str
) -> Meeting | None:
query = select(MeetingModel).where(
MeetingModel.calendar_event_id == calendar_event_id
async def get_by_calendar_event(self, calendar_event_id: str) -> Meeting | None:
query = meetings.select().where(
meetings.c.calendar_event_id == calendar_event_id
)
result = await session.execute(query)
row = result.scalar_one_or_none()
if not row:
result = await get_database().fetch_one(query)
if not result:
return None
return Meeting.model_validate(row)
return Meeting(**result)
async def update_meeting(self, session: AsyncSession, meeting_id: str, **kwargs):
query = (
update(MeetingModel).where(MeetingModel.id == meeting_id).values(**kwargs)
)
await session.execute(query)
await session.commit()
async def update_meeting(self, meeting_id: str, **kwargs):
query = meetings.update().where(meetings.c.id == meeting_id).values(**kwargs)
await get_database().execute(query)
class MeetingConsentController:
async def get_by_meeting_id(
self, session: AsyncSession, meeting_id: str
) -> list[MeetingConsent]:
query = select(MeetingConsentModel).where(
MeetingConsentModel.meeting_id == meeting_id
async def get_by_meeting_id(self, meeting_id: str) -> list[MeetingConsent]:
query = meeting_consent.select().where(
meeting_consent.c.meeting_id == meeting_id
)
result = await session.execute(query)
return [MeetingConsent.model_validate(row) for row in result.scalars().all()]
results = await get_database().fetch_all(query)
return [MeetingConsent(**result) for result in results]
async def get_by_meeting_and_user(
self, session: AsyncSession, meeting_id: str, user_id: str
self, meeting_id: str, user_id: str
) -> MeetingConsent | None:
"""Get existing consent for a specific user and meeting"""
query = select(MeetingConsentModel).where(
sa.and_(
MeetingConsentModel.meeting_id == meeting_id,
MeetingConsentModel.user_id == user_id,
query = meeting_consent.select().where(
meeting_consent.c.meeting_id == meeting_id,
meeting_consent.c.user_id == user_id,
)
)
result = await session.execute(query)
row = result.scalar_one_or_none()
if row is None:
result = await get_database().fetch_one(query)
if result is None:
return None
return MeetingConsent.model_validate(row)
return MeetingConsent(**result)
async def upsert(
self, session: AsyncSession, consent: MeetingConsent
) -> MeetingConsent:
async def upsert(self, consent: MeetingConsent) -> MeetingConsent:
if consent.user_id:
# For authenticated users, check if consent already exists
# not transactional but we're ok with that; the consents ain't deleted anyways
existing = await self.get_by_meeting_and_user(
session, consent.meeting_id, consent.user_id
consent.meeting_id, consent.user_id
)
if existing:
query = (
update(MeetingConsentModel)
.where(MeetingConsentModel.id == existing.id)
meeting_consent.update()
.where(meeting_consent.c.id == existing.id)
.values(
consent_given=consent.consent_given,
consent_timestamp=consent.consent_timestamp,
)
)
await session.execute(query)
await session.commit()
await get_database().execute(query)
existing.consent_given = consent.consent_given
existing.consent_timestamp = consent.consent_timestamp
return existing
new_consent = MeetingConsentModel(**consent.model_dump())
session.add(new_consent)
await session.commit()
query = meeting_consent.insert().values(**consent.model_dump())
await get_database().execute(query)
return consent
async def has_any_denial(self, session: AsyncSession, meeting_id: str) -> bool:
async def has_any_denial(self, meeting_id: str) -> bool:
"""Check if any participant denied consent for this meeting"""
query = select(MeetingConsentModel).where(
sa.and_(
MeetingConsentModel.meeting_id == meeting_id,
MeetingConsentModel.consent_given.is_(False),
query = meeting_consent.select().where(
meeting_consent.c.meeting_id == meeting_id,
meeting_consent.c.consent_given.is_(False),
)
)
result = await session.execute(query)
row = result.scalar_one_or_none()
return row is not None
result = await get_database().fetch_one(query)
return result is not None
meetings_controller = MeetingController()

View File

@@ -1,79 +1,61 @@
from datetime import datetime, timezone
from datetime import datetime
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
import sqlalchemy as sa
from pydantic import BaseModel, Field
from reflector.db.base import RecordingModel
from reflector.db import get_database, metadata
from reflector.utils import generate_uuid4
recordings = sa.Table(
"recording",
metadata,
sa.Column("id", sa.String, primary_key=True),
sa.Column("bucket_name", sa.String, nullable=False),
sa.Column("object_key", sa.String, nullable=False),
sa.Column("recorded_at", sa.DateTime(timezone=True), nullable=False),
sa.Column(
"status",
sa.String,
nullable=False,
server_default="pending",
),
sa.Column("meeting_id", sa.String),
sa.Index("idx_recording_meeting_id", "meeting_id"),
)
class Recording(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str = Field(default_factory=generate_uuid4)
meeting_id: str
url: str
bucket_name: str
object_key: str
duration: float | None = None
created_at: datetime
recorded_at: datetime
status: Literal["pending", "processing", "completed", "failed"] = "pending"
meeting_id: str | None = None
class RecordingController:
async def create(
self,
session: AsyncSession,
meeting_id: str,
url: str,
object_key: str,
duration: float | None = None,
created_at: datetime | None = None,
):
if created_at is None:
created_at = datetime.now(timezone.utc)
recording = Recording(
meeting_id=meeting_id,
url=url,
object_key=object_key,
duration=duration,
created_at=created_at,
)
new_recording = RecordingModel(**recording.model_dump())
session.add(new_recording)
await session.commit()
async def create(self, recording: Recording):
query = recordings.insert().values(**recording.model_dump())
await get_database().execute(query)
return recording
async def get_by_id(
self, session: AsyncSession, recording_id: str
) -> Recording | None:
"""
Get a recording by id
"""
query = select(RecordingModel).where(RecordingModel.id == recording_id)
result = await session.execute(query)
row = result.scalar_one_or_none()
if not row:
return None
return Recording.model_validate(row)
async def get_by_id(self, id: str) -> Recording:
query = recordings.select().where(recordings.c.id == id)
result = await get_database().fetch_one(query)
return Recording(**result) if result else None
async def get_by_meeting_id(
self, session: AsyncSession, meeting_id: str
) -> list[Recording]:
"""
Get all recordings for a meeting
"""
query = select(RecordingModel).where(RecordingModel.meeting_id == meeting_id)
result = await session.execute(query)
return [Recording.model_validate(row) for row in result.scalars().all()]
async def get_by_object_key(self, bucket_name: str, object_key: str) -> Recording:
query = recordings.select().where(
recordings.c.bucket_name == bucket_name,
recordings.c.object_key == object_key,
)
result = await get_database().fetch_one(query)
return Recording(**result) if result else None
async def remove_by_id(self, session: AsyncSession, recording_id: str) -> None:
"""
Remove a recording by id
"""
query = delete(RecordingModel).where(RecordingModel.id == recording_id)
await session.execute(query)
await session.commit()
async def remove_by_id(self, id: str) -> None:
query = recordings.delete().where(recordings.c.id == id)
await get_database().execute(query)
recordings_controller = RecordingController()

View File

@@ -3,19 +3,59 @@ from datetime import datetime, timezone
from sqlite3 import IntegrityError
from typing import Literal
import sqlalchemy
from fastapi import HTTPException
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import or_
from pydantic import BaseModel, Field
from sqlalchemy.sql import false, or_
from reflector.db.base import RoomModel
from reflector.db import get_database, metadata
from reflector.utils import generate_uuid4
rooms = sqlalchemy.Table(
"room",
metadata,
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
sqlalchemy.Column("name", sqlalchemy.String, nullable=False, unique=True),
sqlalchemy.Column("user_id", sqlalchemy.String, nullable=False),
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True), nullable=False),
sqlalchemy.Column(
"zulip_auto_post", sqlalchemy.Boolean, nullable=False, server_default=false()
),
sqlalchemy.Column("zulip_stream", sqlalchemy.String),
sqlalchemy.Column("zulip_topic", sqlalchemy.String),
sqlalchemy.Column(
"is_locked", sqlalchemy.Boolean, nullable=False, server_default=false()
),
sqlalchemy.Column(
"room_mode", sqlalchemy.String, nullable=False, server_default="normal"
),
sqlalchemy.Column(
"recording_type", sqlalchemy.String, nullable=False, server_default="cloud"
),
sqlalchemy.Column(
"recording_trigger",
sqlalchemy.String,
nullable=False,
server_default="automatic-2nd-participant",
),
sqlalchemy.Column(
"is_shared", sqlalchemy.Boolean, nullable=False, server_default=false()
),
sqlalchemy.Column("webhook_url", sqlalchemy.String, nullable=True),
sqlalchemy.Column("webhook_secret", sqlalchemy.String, nullable=True),
sqlalchemy.Column("ics_url", sqlalchemy.Text),
sqlalchemy.Column("ics_fetch_interval", sqlalchemy.Integer, server_default="300"),
sqlalchemy.Column(
"ics_enabled", sqlalchemy.Boolean, nullable=False, server_default=false()
),
sqlalchemy.Column("ics_last_sync", sqlalchemy.DateTime(timezone=True)),
sqlalchemy.Column("ics_last_etag", sqlalchemy.Text),
sqlalchemy.Index("idx_room_is_shared", "is_shared"),
sqlalchemy.Index("idx_room_ics_enabled", "ics_enabled"),
)
class Room(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str = Field(default_factory=generate_uuid4)
name: str
user_id: str
@@ -42,7 +82,6 @@ class Room(BaseModel):
class RoomController:
async def get_all(
self,
session: AsyncSession,
user_id: str | None = None,
order_by: str | None = None,
return_query: bool = False,
@@ -56,14 +95,14 @@ class RoomController:
Parameters:
- `order_by`: field to order by, e.g. "-created_at"
"""
query = select(RoomModel)
query = rooms.select()
if user_id is not None:
query = query.where(or_(RoomModel.user_id == user_id, RoomModel.is_shared))
query = query.where(or_(rooms.c.user_id == user_id, rooms.c.is_shared))
else:
query = query.where(RoomModel.is_shared)
query = query.where(rooms.c.is_shared)
if order_by is not None:
field = getattr(RoomModel, order_by[1:])
field = getattr(rooms.c, order_by[1:])
if order_by.startswith("-"):
field = field.desc()
query = query.order_by(field)
@@ -71,12 +110,11 @@ class RoomController:
if return_query:
return query
result = await session.execute(query)
return [Room.model_validate(row) for row in result.scalars().all()]
results = await get_database().fetch_all(query)
return results
async def add(
self,
session: AsyncSession,
name: str,
user_id: str,
zulip_auto_post: bool,
@@ -116,27 +154,23 @@ class RoomController:
ics_fetch_interval=ics_fetch_interval,
ics_enabled=ics_enabled,
)
new_room = RoomModel(**room.model_dump())
session.add(new_room)
query = rooms.insert().values(**room.model_dump())
try:
await session.flush()
await get_database().execute(query)
except IntegrityError:
raise HTTPException(status_code=400, detail="Room name is not unique")
return room
async def update(
self, session: AsyncSession, room: Room, values: dict, mutate=True
):
async def update(self, room: Room, values: dict, mutate=True):
"""
Update a room fields with key/values in values
"""
if values.get("webhook_url") and not values.get("webhook_secret"):
values["webhook_secret"] = secrets.token_urlsafe(32)
query = update(RoomModel).where(RoomModel.id == room.id).values(**values)
query = rooms.update().where(rooms.c.id == room.id).values(**values)
try:
await session.execute(query)
await session.flush()
await get_database().execute(query)
except IntegrityError:
raise HTTPException(status_code=400, detail="Room name is not unique")
@@ -144,79 +178,67 @@ class RoomController:
for key, value in values.items():
setattr(room, key, value)
async def get_by_id(
self, session: AsyncSession, room_id: str, **kwargs
) -> Room | None:
async def get_by_id(self, room_id: str, **kwargs) -> Room | None:
"""
Get a room by id
"""
query = select(RoomModel).where(RoomModel.id == room_id)
query = rooms.select().where(rooms.c.id == room_id)
if "user_id" in kwargs:
query = query.where(RoomModel.user_id == kwargs["user_id"])
result = await session.execute(query)
row = result.scalars().first()
if not row:
query = query.where(rooms.c.user_id == kwargs["user_id"])
result = await get_database().fetch_one(query)
if not result:
return None
return Room.model_validate(row)
return Room(**result)
async def get_by_name(
self, session: AsyncSession, room_name: str, **kwargs
) -> Room | None:
async def get_by_name(self, room_name: str, **kwargs) -> Room | None:
"""
Get a room by name
"""
query = select(RoomModel).where(RoomModel.name == room_name)
query = rooms.select().where(rooms.c.name == room_name)
if "user_id" in kwargs:
query = query.where(RoomModel.user_id == kwargs["user_id"])
result = await session.execute(query)
row = result.scalars().first()
if not row:
query = query.where(rooms.c.user_id == kwargs["user_id"])
result = await get_database().fetch_one(query)
if not result:
return None
return Room.model_validate(row)
return Room(**result)
async def get_by_id_for_http(
self, session: AsyncSession, meeting_id: str, user_id: str | None
) -> Room:
async def get_by_id_for_http(self, meeting_id: str, user_id: str | None) -> Room:
"""
Get a room by ID for HTTP request.
If not found, it will raise a 404 error.
"""
query = select(RoomModel).where(RoomModel.id == meeting_id)
result = await session.execute(query)
row = result.scalars().first()
if not row:
query = rooms.select().where(rooms.c.id == meeting_id)
result = await get_database().fetch_one(query)
if not result:
raise HTTPException(status_code=404, detail="Room not found")
room = Room.model_validate(row)
room = Room(**result)
return room
async def get_ics_enabled(self, session: AsyncSession) -> list[Room]:
query = select(RoomModel).where(
RoomModel.ics_enabled == True, RoomModel.ics_url != None
async def get_ics_enabled(self) -> list[Room]:
query = rooms.select().where(
rooms.c.ics_enabled == True, rooms.c.ics_url != None
)
result = await session.execute(query)
results = result.scalars().all()
return [Room(**row.__dict__) for row in results]
results = await get_database().fetch_all(query)
return [Room(**result) for result in results]
async def remove_by_id(
self,
session: AsyncSession,
room_id: str,
user_id: str | None = None,
) -> None:
"""
Remove a room by id
"""
room = await self.get_by_id(session, room_id, user_id=user_id)
room = await self.get_by_id(room_id, user_id=user_id)
if not room:
return
if user_id is not None and room.user_id != user_id:
return
query = delete(RoomModel).where(RoomModel.id == room_id)
await session.execute(query)
await session.flush()
query = rooms.delete().where(rooms.c.id == room_id)
await get_database().execute(query)
rooms_controller = RoomController()

View File

@@ -8,6 +8,7 @@ from typing import Annotated, Any, Dict, Iterator
import sqlalchemy
import webvtt
from databases.interfaces import Record as DbRecord
from fastapi import HTTPException
from pydantic import (
BaseModel,
@@ -19,10 +20,11 @@ from pydantic import (
constr,
field_serializer,
)
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.db.base import RoomModel, TranscriptModel
from reflector.db.transcripts import SourceKind, TranscriptStatus
from reflector.db import get_database
from reflector.db.rooms import rooms
from reflector.db.transcripts import SourceKind, TranscriptStatus, transcripts
from reflector.db.utils import is_postgresql
from reflector.logger import logger
from reflector.utils.string import NonEmptyString, try_parse_non_empty_string
@@ -329,30 +331,36 @@ class SearchController:
@classmethod
async def search_transcripts(
cls, session: AsyncSession, params: SearchParameters
cls, params: SearchParameters
) -> tuple[list[SearchResult], int]:
"""
Full-text search for transcripts using PostgreSQL tsvector.
Returns (results, total_count).
"""
if not is_postgresql():
logger.warning(
"Full-text search requires PostgreSQL. Returning empty results."
)
return [], 0
base_columns = [
TranscriptModel.id,
TranscriptModel.title,
TranscriptModel.created_at,
TranscriptModel.duration,
TranscriptModel.status,
TranscriptModel.user_id,
TranscriptModel.room_id,
TranscriptModel.source_kind,
TranscriptModel.webvtt,
TranscriptModel.long_summary,
transcripts.c.id,
transcripts.c.title,
transcripts.c.created_at,
transcripts.c.duration,
transcripts.c.status,
transcripts.c.user_id,
transcripts.c.room_id,
transcripts.c.source_kind,
transcripts.c.webvtt,
transcripts.c.long_summary,
sqlalchemy.case(
(
TranscriptModel.room_id.isnot(None) & RoomModel.id.is_(None),
transcripts.c.room_id.isnot(None) & rooms.c.id.is_(None),
"Deleted Room",
),
else_=RoomModel.name,
else_=rooms.c.name,
).label("room_name"),
]
search_query = None
@@ -361,7 +369,7 @@ class SearchController:
"english", params.query_text
)
rank_column = sqlalchemy.func.ts_rank(
TranscriptModel.search_vector_en,
transcripts.c.search_vector_en,
search_query,
32, # normalization flag: rank/(rank+1) for 0-1 range
).label("rank")
@@ -369,51 +377,47 @@ class SearchController:
rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")
columns = base_columns + [rank_column]
base_query = (
sqlalchemy.select(*columns)
.select_from(TranscriptModel)
.outerjoin(RoomModel, TranscriptModel.room_id == RoomModel.id)
base_query = sqlalchemy.select(columns).select_from(
transcripts.join(rooms, transcripts.c.room_id == rooms.c.id, isouter=True)
)
if params.query_text is not None:
# because already initialized based on params.query_text presence above
assert search_query is not None
base_query = base_query.where(
TranscriptModel.search_vector_en.op("@@")(search_query)
transcripts.c.search_vector_en.op("@@")(search_query)
)
if params.user_id:
base_query = base_query.where(
sqlalchemy.or_(
TranscriptModel.user_id == params.user_id, RoomModel.is_shared
transcripts.c.user_id == params.user_id, rooms.c.is_shared
)
)
else:
base_query = base_query.where(RoomModel.is_shared)
base_query = base_query.where(rooms.c.is_shared)
if params.room_id:
base_query = base_query.where(TranscriptModel.room_id == params.room_id)
base_query = base_query.where(transcripts.c.room_id == params.room_id)
if params.source_kind:
base_query = base_query.where(
TranscriptModel.source_kind == params.source_kind
transcripts.c.source_kind == params.source_kind
)
if params.query_text is not None:
order_by = sqlalchemy.desc(sqlalchemy.text("rank"))
else:
order_by = sqlalchemy.desc(TranscriptModel.created_at)
order_by = sqlalchemy.desc(transcripts.c.created_at)
query = base_query.order_by(order_by).limit(params.limit).offset(params.offset)
result = await session.execute(query)
rs = result.mappings().all()
rs = await get_database().fetch_all(query)
count_query = sqlalchemy.select(sqlalchemy.func.count()).select_from(
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
base_query.alias("search_results")
)
count_result = await session.execute(count_query)
total = count_result.scalar()
total = await get_database().fetch_val(count_query)
def _process_result(r: dict) -> SearchResult:
def _process_result(r: DbRecord) -> SearchResult:
r_dict: Dict[str, Any] = dict(r)
webvtt_raw: str | None = r_dict.pop("webvtt", None)

View File

@@ -2,18 +2,22 @@ import enum
import json
import os
import shutil
from contextlib import asynccontextmanager
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Literal
import sqlalchemy
from fastapi import HTTPException
from pydantic import BaseModel, ConfigDict, Field, field_serializer
from sqlalchemy import delete, insert, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import or_
from sqlalchemy import Enum
from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.sql import false, or_
from reflector.db.base import RoomModel, TranscriptModel
from reflector.db import get_database, metadata
from reflector.db.recordings import recordings_controller
from reflector.db.rooms import rooms
from reflector.db.utils import is_postgresql
from reflector.logger import logger
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
@@ -28,6 +32,91 @@ class SourceKind(enum.StrEnum):
FILE = enum.auto()
transcripts = sqlalchemy.Table(
"transcript",
metadata,
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
sqlalchemy.Column("name", sqlalchemy.String),
sqlalchemy.Column("status", sqlalchemy.String),
sqlalchemy.Column("locked", sqlalchemy.Boolean),
sqlalchemy.Column("duration", sqlalchemy.Float),
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True)),
sqlalchemy.Column("title", sqlalchemy.String),
sqlalchemy.Column("short_summary", sqlalchemy.String),
sqlalchemy.Column("long_summary", sqlalchemy.String),
sqlalchemy.Column("topics", sqlalchemy.JSON),
sqlalchemy.Column("events", sqlalchemy.JSON),
sqlalchemy.Column("participants", sqlalchemy.JSON),
sqlalchemy.Column("source_language", sqlalchemy.String),
sqlalchemy.Column("target_language", sqlalchemy.String),
sqlalchemy.Column(
"reviewed", sqlalchemy.Boolean, nullable=False, server_default=false()
),
sqlalchemy.Column(
"audio_location",
sqlalchemy.String,
nullable=False,
server_default="local",
),
# with user attached, optional
sqlalchemy.Column("user_id", sqlalchemy.String),
sqlalchemy.Column(
"share_mode",
sqlalchemy.String,
nullable=False,
server_default="private",
),
sqlalchemy.Column(
"meeting_id",
sqlalchemy.String,
),
sqlalchemy.Column("recording_id", sqlalchemy.String),
sqlalchemy.Column("zulip_message_id", sqlalchemy.Integer),
sqlalchemy.Column(
"source_kind",
Enum(SourceKind, values_callable=lambda obj: [e.value for e in obj]),
nullable=False,
),
# indicative field: whether associated audio is deleted
# the main "audio deleted" is the presence of the audio itself / consents not-given
# same field could've been in recording/meeting, and it's maybe even ok to dupe it at need
sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
sqlalchemy.Column("room_id", sqlalchemy.String),
sqlalchemy.Column("webvtt", sqlalchemy.Text),
sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
sqlalchemy.Index("idx_transcript_user_id", "user_id"),
sqlalchemy.Index("idx_transcript_created_at", "created_at"),
sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
sqlalchemy.Index("idx_transcript_room_id", "room_id"),
sqlalchemy.Index("idx_transcript_source_kind", "source_kind"),
sqlalchemy.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
)
# Add PostgreSQL-specific full-text search column
# This matches the migration in migrations/versions/116b2f287eab_add_full_text_search.py
if is_postgresql():
transcripts.append_column(
sqlalchemy.Column(
"search_vector_en",
TSVECTOR,
sqlalchemy.Computed(
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
"setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
persisted=True,
),
)
)
# Add GIN index for the search vector
transcripts.append_constraint(
sqlalchemy.Index(
"idx_transcript_search_vector_en",
"search_vector_en",
postgresql_using="gin",
)
)
def generate_transcript_name() -> str:
now = datetime.now(timezone.utc)
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
@@ -102,8 +191,6 @@ class TranscriptParticipant(BaseModel):
class Transcript(BaseModel):
"""Full transcript model with all fields."""
model_config = ConfigDict(from_attributes=True)
id: str = Field(default_factory=generate_uuid4)
user_id: str | None = None
name: str = Field(default_factory=generate_transcript_name)
@@ -272,7 +359,6 @@ class Transcript(BaseModel):
class TranscriptController:
async def get_all(
self,
session: AsyncSession,
user_id: str | None = None,
order_by: str | None = None,
filter_empty: bool | None = False,
@@ -297,114 +383,102 @@ class TranscriptController:
- `search_term`: filter transcripts by search term
"""
query = select(TranscriptModel).join(
RoomModel, TranscriptModel.room_id == RoomModel.id, isouter=True
query = transcripts.select().join(
rooms, transcripts.c.room_id == rooms.c.id, isouter=True
)
if user_id:
query = query.where(
or_(TranscriptModel.user_id == user_id, RoomModel.is_shared)
or_(transcripts.c.user_id == user_id, rooms.c.is_shared)
)
else:
query = query.where(RoomModel.is_shared)
query = query.where(rooms.c.is_shared)
if source_kind:
query = query.where(TranscriptModel.source_kind == source_kind)
query = query.where(transcripts.c.source_kind == source_kind)
if room_id:
query = query.where(TranscriptModel.room_id == room_id)
query = query.where(transcripts.c.room_id == room_id)
if search_term:
query = query.where(TranscriptModel.title.ilike(f"%{search_term}%"))
query = query.where(transcripts.c.title.ilike(f"%{search_term}%"))
# Exclude heavy JSON columns from list queries
# Get all ORM column attributes except excluded ones
transcript_columns = [
getattr(TranscriptModel, col.name)
for col in TranscriptModel.__table__.c
if col.name not in exclude_columns
col for col in transcripts.c if col.name not in exclude_columns
]
query = query.with_only_columns(
*transcript_columns,
RoomModel.name.label("room_name"),
transcript_columns
+ [
rooms.c.name.label("room_name"),
]
)
if order_by is not None:
field = getattr(TranscriptModel, order_by[1:])
field = getattr(transcripts.c, order_by[1:])
if order_by.startswith("-"):
field = field.desc()
query = query.order_by(field)
if filter_empty:
query = query.filter(TranscriptModel.status != "idle")
query = query.filter(transcripts.c.status != "idle")
if filter_recording:
query = query.filter(TranscriptModel.status != "recording")
query = query.filter(transcripts.c.status != "recording")
# print(query.compile(compile_kwargs={"literal_binds": True}))
if return_query:
return query
result = await session.execute(query)
return [dict(row) for row in result.mappings().all()]
results = await get_database().fetch_all(query)
return results
async def get_by_id(
self, session: AsyncSession, transcript_id: str, **kwargs
) -> Transcript | None:
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
"""
Get a transcript by id
"""
query = select(TranscriptModel).where(TranscriptModel.id == transcript_id)
query = transcripts.select().where(transcripts.c.id == transcript_id)
if "user_id" in kwargs:
query = query.where(TranscriptModel.user_id == kwargs["user_id"])
result = await session.execute(query)
row = result.scalar_one_or_none()
if not row:
query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await get_database().fetch_one(query)
if not result:
return None
return Transcript.model_validate(row)
return Transcript(**result)
async def get_by_recording_id(
self, session: AsyncSession, recording_id: str, **kwargs
self, recording_id: str, **kwargs
) -> Transcript | None:
"""
Get a transcript by recording_id
"""
query = select(TranscriptModel).where(
TranscriptModel.recording_id == recording_id
)
query = transcripts.select().where(transcripts.c.recording_id == recording_id)
if "user_id" in kwargs:
query = query.where(TranscriptModel.user_id == kwargs["user_id"])
result = await session.execute(query)
row = result.scalar_one_or_none()
if not row:
query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await get_database().fetch_one(query)
if not result:
return None
return Transcript.model_validate(row)
return Transcript(**result)
async def get_by_room_id(
self, session: AsyncSession, room_id: str, **kwargs
) -> list[Transcript]:
async def get_by_room_id(self, room_id: str, **kwargs) -> list[Transcript]:
"""
Get transcripts by room_id (direct access without joins)
"""
query = select(TranscriptModel).where(TranscriptModel.room_id == room_id)
query = transcripts.select().where(transcripts.c.room_id == room_id)
if "user_id" in kwargs:
query = query.where(TranscriptModel.user_id == kwargs["user_id"])
query = query.where(transcripts.c.user_id == kwargs["user_id"])
if "order_by" in kwargs:
order_by = kwargs["order_by"]
field = getattr(TranscriptModel, order_by[1:])
field = getattr(transcripts.c, order_by[1:])
if order_by.startswith("-"):
field = field.desc()
query = query.order_by(field)
results = await session.execute(query)
return [
Transcript.model_validate(dict(row)) for row in results.mappings().all()
]
results = await get_database().fetch_all(query)
return [Transcript(**result) for result in results]
async def get_by_id_for_http(
self,
session: AsyncSession,
transcript_id: str,
user_id: str | None,
) -> Transcript:
@@ -417,14 +491,13 @@ class TranscriptController:
This method checks the share mode of the transcript and the user_id
to determine if the user can access the transcript.
"""
query = select(TranscriptModel).where(TranscriptModel.id == transcript_id)
result = await session.execute(query)
row = result.scalar_one_or_none()
if not row:
query = transcripts.select().where(transcripts.c.id == transcript_id)
result = await get_database().fetch_one(query)
if not result:
raise HTTPException(status_code=404, detail="Transcript not found")
# if the transcript is anonymous, share mode is not checked
transcript = Transcript.model_validate(row)
transcript = Transcript(**result)
if transcript.user_id is None:
return transcript
@@ -447,7 +520,6 @@ class TranscriptController:
async def add(
self,
session: AsyncSession,
name: str,
source_kind: SourceKind,
source_language: str = "en",
@@ -472,20 +544,14 @@ class TranscriptController:
meeting_id=meeting_id,
room_id=room_id,
)
query = insert(TranscriptModel).values(**transcript.model_dump())
await session.execute(query)
await session.commit()
query = transcripts.insert().values(**transcript.model_dump())
await get_database().execute(query)
return transcript
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
# using mutate=True is discouraged
async def update(
self,
session: AsyncSession,
transcript: Transcript,
values: dict,
commit=True,
mutate=False,
self, transcript: Transcript, values: dict, mutate=False
) -> Transcript:
"""
Update a transcript fields with key/values in values.
@@ -494,13 +560,11 @@ class TranscriptController:
values = TranscriptController._handle_topics_update(values)
query = (
update(TranscriptModel)
.where(TranscriptModel.id == transcript.id)
transcripts.update()
.where(transcripts.c.id == transcript.id)
.values(**values)
)
await session.execute(query)
if commit:
await session.commit()
await get_database().execute(query)
if mutate:
for key, value in values.items():
setattr(transcript, key, value)
@@ -529,14 +593,13 @@ class TranscriptController:
async def remove_by_id(
self,
session: AsyncSession,
transcript_id: str,
user_id: str | None = None,
) -> None:
"""
Remove a transcript by id
"""
transcript = await self.get_by_id(session, transcript_id)
transcript = await self.get_by_id(transcript_id)
if not transcript:
return
if user_id is not None and transcript.user_id != user_id:
@@ -556,7 +619,7 @@ class TranscriptController:
if transcript.recording_id:
try:
recording = await recordings_controller.get_by_id(
session, transcript.recording_id
transcript.recording_id
)
if recording:
try:
@@ -567,49 +630,59 @@ class TranscriptController:
exc_info=e,
recording_id=transcript.recording_id,
)
await recordings_controller.remove_by_id(
session, transcript.recording_id
)
await recordings_controller.remove_by_id(transcript.recording_id)
except Exception as e:
logger.warning(
"Failed to delete recording row",
exc_info=e,
recording_id=transcript.recording_id,
)
query = delete(TranscriptModel).where(TranscriptModel.id == transcript_id)
await session.execute(query)
await session.commit()
query = transcripts.delete().where(transcripts.c.id == transcript_id)
await get_database().execute(query)
async def remove_by_recording_id(self, session: AsyncSession, recording_id: str):
async def remove_by_recording_id(self, recording_id: str):
"""
Remove a transcript by recording_id
"""
query = delete(TranscriptModel).where(
TranscriptModel.recording_id == recording_id
)
await session.execute(query)
await session.commit()
query = transcripts.delete().where(transcripts.c.recording_id == recording_id)
await get_database().execute(query)
@staticmethod
def user_can_mutate(transcript: Transcript, user_id: str | None) -> bool:
"""
Returns True if the given user is allowed to modify the transcript.
Policy:
- Anonymous transcripts (user_id is None) cannot be modified via API
- Only the owner (matching user_id) can modify their transcript
"""
if transcript.user_id is None:
return False
return user_id and transcript.user_id == user_id
@asynccontextmanager
async def transaction(self):
"""
A context manager for database transaction
"""
async with get_database().transaction(isolation="serializable"):
yield
async def append_event(
self,
session: AsyncSession,
transcript: Transcript,
event: str,
data: Any,
commit=True,
) -> TranscriptEvent:
"""
Append an event to a transcript
"""
resp = transcript.add_event(event=event, data=data)
await self.update(
session, transcript, {"events": transcript.events_dump()}, commit=commit
)
await self.update(transcript, {"events": transcript.events_dump()})
return resp
async def upsert_topic(
self,
session: AsyncSession,
transcript: Transcript,
topic: TranscriptTopic,
) -> TranscriptEvent:
@@ -617,9 +690,9 @@ class TranscriptController:
Upsert topics to a transcript
"""
transcript.upsert_topic(topic)
await self.update(session, transcript, {"topics": transcript.topics_dump()})
await self.update(transcript, {"topics": transcript.topics_dump()})
async def move_mp3_to_storage(self, session: AsyncSession, transcript: Transcript):
async def move_mp3_to_storage(self, transcript: Transcript):
"""
Move mp3 file to storage
"""
@@ -643,16 +716,12 @@ class TranscriptController:
# indicate on the transcript that the audio is now on storage
# mutates transcript argument
await self.update(
session, transcript, {"audio_location": "storage"}, mutate=True
)
await self.update(transcript, {"audio_location": "storage"}, mutate=True)
# unlink the local file
transcript.audio_mp3_filename.unlink(missing_ok=True)
async def download_mp3_from_storage(
self, session: AsyncSession, transcript: Transcript
):
async def download_mp3_from_storage(self, transcript: Transcript):
"""
Download audio from storage
"""
@@ -664,7 +733,6 @@ class TranscriptController:
async def upsert_participant(
self,
session: AsyncSession,
transcript: Transcript,
participant: TranscriptParticipant,
) -> TranscriptParticipant:
@@ -672,14 +740,11 @@ class TranscriptController:
Add/update a participant to a transcript
"""
result = transcript.upsert_participant(participant)
await self.update(
session, transcript, {"participants": transcript.participants_dump()}
)
await self.update(transcript, {"participants": transcript.participants_dump()})
return result
async def delete_participant(
self,
session: AsyncSession,
transcript: Transcript,
participant_id: str,
):
@@ -687,37 +752,28 @@ class TranscriptController:
Delete a participant from a transcript
"""
transcript.delete_participant(participant_id)
await self.update(
session, transcript, {"participants": transcript.participants_dump()}
)
await self.update(transcript, {"participants": transcript.participants_dump()})
async def set_status(
self, session: AsyncSession, transcript_id: str, status: TranscriptStatus
self, transcript_id: str, status: TranscriptStatus
) -> TranscriptEvent | None:
"""
Update the status of a transcript
Will add an event STATUS + update the status field of transcript
"""
transcript = await self.get_by_id(session, transcript_id)
async with self.transaction():
transcript = await self.get_by_id(transcript_id)
if not transcript:
raise Exception(f"Transcript {transcript_id} not found")
if transcript.status == status:
return
resp = await self.append_event(
session,
transcript=transcript,
event="STATUS",
data=StrValue(value=status),
commit=False,
)
await self.update(
session,
transcript,
{"status": status},
commit=False,
)
await session.commit()
await self.update(transcript, {"status": status})
return resp

View File

@@ -0,0 +1,9 @@
"""Database utility functions."""
from reflector.db import get_database
def is_postgresql() -> bool:
return get_database().url.scheme and get_database().url.scheme.startswith(
"postgresql"
)

View File

@@ -12,8 +12,9 @@ from pathlib import Path
import av
import structlog
from sqlalchemy.ext.asyncio import AsyncSession
from celery import chain, shared_task
from reflector.asynctask import asynctask
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import (
SourceKind,
@@ -25,8 +26,8 @@ from reflector.logger import logger
from reflector.pipelines.main_live_pipeline import (
PipelineMainBase,
broadcast_to_sockets,
task_cleanup_consent_taskiq,
task_pipeline_post_to_zulip_taskiq,
task_cleanup_consent,
task_pipeline_post_to_zulip,
)
from reflector.processors import (
AudioFileWriterProcessor,
@@ -52,9 +53,7 @@ from reflector.processors.types import (
)
from reflector.settings import settings
from reflector.storage import get_transcripts_storage
from reflector.worker.app import taskiq_broker
from reflector.worker.session_decorator import catch_exception, with_session
from reflector.worker.webhook import send_transcript_webhook_taskiq
from reflector.worker.webhook import send_transcript_webhook
class EmptyPipeline:
@@ -96,23 +95,19 @@ class PipelineMainFile(PipelineMainBase):
)
@broadcast_to_sockets
async def set_status(
self,
session: AsyncSession,
transcript_id: str,
status: TranscriptStatus,
):
return await transcripts_controller.set_status(session, transcript_id, status)
async def set_status(self, transcript_id: str, status: TranscriptStatus):
async with self.lock_transaction():
return await transcripts_controller.set_status(transcript_id, status)
async def process(self, session: AsyncSession, file_path: Path):
async def process(self, file_path: Path):
"""Main entry point for file processing"""
self.logger.info(f"Starting file pipeline for {file_path}")
transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
transcript = await self.get_transcript()
# Clear transcript as we're going to regenerate everything
async with self.transaction():
await transcripts_controller.update(
session,
transcript,
{
"events": [],
@@ -128,7 +123,6 @@ class PipelineMainFile(PipelineMainBase):
# Run parallel processing
await self.run_parallel_processing(
session,
audio_path,
audio_url,
transcript.source_language,
@@ -137,7 +131,7 @@ class PipelineMainFile(PipelineMainBase):
self.logger.info("File pipeline complete")
await transcripts_controller.set_status(session, transcript.id, "ended")
await self.set_status(transcript.id, "ended")
async def extract_and_write_audio(
self, file_path: Path, transcript: Transcript
@@ -199,7 +193,6 @@ class PipelineMainFile(PipelineMainBase):
async def run_parallel_processing(
self,
session,
audio_path: Path,
audio_url: str,
source_language: str,
@@ -213,7 +206,7 @@ class PipelineMainFile(PipelineMainBase):
# Phase 1: Parallel processing of independent tasks
transcription_task = self.transcribe_file(audio_url, source_language)
diarization_task = self.diarize_file(audio_url)
waveform_task = self.generate_waveform(session, audio_path)
waveform_task = self.generate_waveform(audio_path)
results = await asyncio.gather(
transcription_task, diarization_task, waveform_task, return_exceptions=True
@@ -261,7 +254,7 @@ class PipelineMainFile(PipelineMainBase):
)
results = await asyncio.gather(
self.generate_title(topics),
self.generate_summaries(session, topics),
self.generate_summaries(topics),
return_exceptions=True,
)
@@ -313,9 +306,9 @@ class PipelineMainFile(PipelineMainBase):
self.logger.error(f"Diarization failed: {e}")
return None
async def generate_waveform(self, session: AsyncSession, audio_path: Path):
async def generate_waveform(self, audio_path: Path):
"""Generate and save waveform"""
transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
transcript = await self.get_transcript()
processor = AudioWaveformProcessor(
audio_path=audio_path,
@@ -368,13 +361,13 @@ class PipelineMainFile(PipelineMainBase):
await processor.flush()
async def generate_summaries(self, session, topics: list[TitleSummary]):
async def generate_summaries(self, topics: list[TitleSummary]):
"""Generate long and short summaries from topics"""
if not topics:
self.logger.warning("No topics for summary generation")
return
transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
transcript = await self.get_transcript()
processor = TranscriptFinalSummaryProcessor(
transcript=transcript,
callback=self.on_long_summary,
@@ -388,15 +381,16 @@ class PipelineMainFile(PipelineMainBase):
await processor.flush()
@taskiq_broker.task
@with_session
async def task_send_webhook_if_needed(session, *, transcript_id: str):
transcript = await transcripts_controller.get_by_id(session, transcript_id)
@shared_task
@asynctask
async def task_send_webhook_if_needed(*, transcript_id: str):
"""Send webhook if this is a room recording with webhook configured"""
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
return
if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
room = await rooms_controller.get_by_id(session, transcript.room_id)
room = await rooms_controller.get_by_id(transcript.room_id)
if room and room.webhook_url:
logger.info(
"Dispatching webhook",
@@ -404,23 +398,25 @@ async def task_send_webhook_if_needed(session, *, transcript_id: str):
room_id=room.id,
webhook_url=room.webhook_url,
)
await send_transcript_webhook_taskiq.kiq(
send_transcript_webhook.delay(
transcript_id, room.id, event_id=uuid.uuid4().hex
)
@taskiq_broker.task
@catch_exception
@with_session
async def task_pipeline_file_process(session: AsyncSession, *, transcript_id: str):
transcript = await transcripts_controller.get_by_id(session, transcript_id)
@shared_task
@asynctask
async def task_pipeline_file_process(*, transcript_id: str):
"""Celery task for file pipeline processing"""
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise Exception(f"Transcript {transcript_id} not found")
pipeline = PipelineMainFile(transcript_id=transcript_id)
try:
await pipeline.set_status(session, transcript_id, "processing")
await pipeline.set_status(transcript_id, "processing")
# Find the file to process
audio_file = next(transcript.data_path.glob("upload.*"), None)
if not audio_file:
audio_file = next(transcript.data_path.glob("audio.*"), None)
@@ -428,18 +424,16 @@ async def task_pipeline_file_process(session: AsyncSession, *, transcript_id: st
if not audio_file:
raise Exception("No audio file found to process")
await pipeline.process(session, audio_file)
await pipeline.process(audio_file)
except Exception:
logger.error("Error while processing the file", exc_info=True)
try:
await pipeline.set_status(session, transcript_id, "error")
except:
logger.error(
"Error setting status in task_pipeline_file_process during exception, ignoring it"
)
await pipeline.set_status(transcript_id, "error")
raise
await task_cleanup_consent_taskiq.kiq(transcript_id=transcript_id)
await task_pipeline_post_to_zulip_taskiq.kiq(transcript_id=transcript_id)
await task_send_webhook_if_needed.kiq(transcript_id=transcript_id)
# Run post-processing chain: consent cleanup -> zulip -> webhook
post_chain = chain(
task_cleanup_consent.si(transcript_id=transcript_id),
task_pipeline_post_to_zulip.si(transcript_id=transcript_id),
task_send_webhook_if_needed.si(transcript_id=transcript_id),
)
post_chain.delay()

View File

@@ -12,16 +12,17 @@ It is directly linked to our data model.
"""
import asyncio
import functools
from contextlib import asynccontextmanager
from typing import Generic
import av
import boto3
from celery import chord, current_task, group, shared_task
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from structlog import BoundLogger as Logger
from reflector.db import get_session_context
from reflector.asynctask import asynctask
from reflector.db.meetings import meeting_consent_controller, meetings_controller
from reflector.db.recordings import recordings_controller
from reflector.db.rooms import rooms_controller
@@ -61,8 +62,6 @@ from reflector.processors.types import (
from reflector.processors.types import Transcript as TranscriptProcessorType
from reflector.settings import settings
from reflector.storage import get_transcripts_storage
from reflector.worker.app import taskiq_broker
from reflector.worker.session_decorator import with_session_and_transcript
from reflector.ws_manager import WebsocketManager, get_ws_manager
from reflector.zulip import (
get_zulip_message,
@@ -86,6 +85,53 @@ def broadcast_to_sockets(func):
message=resp.model_dump(mode="json"),
)
transcript = await transcripts_controller.get_by_id(self.transcript_id)
if transcript and transcript.user_id:
# Emit only relevant events to the user room to avoid noisy updates.
# Allowed: STATUS, FINAL_TITLE, DURATION. All are prefixed with TRANSCRIPT_
allowed_user_events = {"STATUS", "FINAL_TITLE", "DURATION"}
if resp.event in allowed_user_events:
await self.ws_manager.send_json(
room_id=f"user:{transcript.user_id}",
message={
"event": f"TRANSCRIPT_{resp.event}",
"data": {"id": self.transcript_id, **resp.data},
},
)
return wrapper
def get_transcript(func):
"""
Decorator to fetch the transcript from the database from the first argument
"""
@functools.wraps(func)
async def wrapper(**kwargs):
transcript_id = kwargs.pop("transcript_id")
transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
if not transcript:
raise Exception("Transcript {transcript_id} not found")
# Enhanced logger with Celery task context
tlogger = logger.bind(transcript_id=transcript.id)
if current_task:
tlogger = tlogger.bind(
task_id=current_task.request.id,
task_name=current_task.name,
worker_hostname=current_task.request.hostname,
task_retries=current_task.request.retries,
transcript_id=transcript_id,
)
try:
result = await func(transcript=transcript, logger=tlogger, **kwargs)
return result
except Exception as exc:
tlogger.error("Pipeline error", function_name=func.__name__, exc_info=exc)
raise
return wrapper
@@ -107,9 +153,11 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
self._ws_manager = get_ws_manager()
return self._ws_manager
async def get_transcript(self, session: AsyncSession) -> Transcript:
async def get_transcript(self) -> Transcript:
# fetch the transcript
result = await transcripts_controller.get_by_id(session, self.transcript_id)
result = await transcripts_controller.get_by_id(
transcript_id=self.transcript_id
)
if not result:
raise Exception("Transcript not found")
return result
@@ -139,10 +187,10 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
yield
@asynccontextmanager
async def locked_session(self):
async def transaction(self):
async with self.lock_transaction():
async with get_session_context() as session:
yield session
async with transcripts_controller.transaction():
yield
@broadcast_to_sockets
async def on_status(self, status):
@@ -173,17 +221,13 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
# when the status of the pipeline changes, update the transcript
async with self._lock:
async with get_session_context() as session:
return await transcripts_controller.set_status(
session, self.transcript_id, status
)
return await transcripts_controller.set_status(self.transcript_id, status)
@broadcast_to_sockets
async def on_transcript(self, data):
async with self.locked_session() as session:
transcript = await self.get_transcript(session)
async with self.transaction():
transcript = await self.get_transcript()
return await transcripts_controller.append_event(
session,
transcript=transcript,
event="TRANSCRIPT",
data=TranscriptText(text=data.text, translation=data.translation),
@@ -200,11 +244,10 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
)
if isinstance(data, TitleSummaryWithIdProcessorType):
topic.id = data.id
async with self.locked_session() as session:
transcript = await self.get_transcript(session)
await transcripts_controller.upsert_topic(session, transcript, topic)
async with self.transaction():
transcript = await self.get_transcript()
await transcripts_controller.upsert_topic(transcript, topic)
return await transcripts_controller.append_event(
session,
transcript=transcript,
event="TOPIC",
data=topic,
@@ -213,18 +256,16 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets
async def on_title(self, data):
final_title = TranscriptFinalTitle(title=data.title)
async with self.locked_session() as session:
transcript = await self.get_transcript(session)
async with self.transaction():
transcript = await self.get_transcript()
if not transcript.title:
await transcripts_controller.update(
session,
transcript,
{
"title": final_title.title,
},
)
return await transcripts_controller.append_event(
session,
transcript=transcript,
event="FINAL_TITLE",
data=final_title,
@@ -233,17 +274,15 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets
async def on_long_summary(self, data):
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
async with self.locked_session() as session:
transcript = await self.get_transcript(session)
async with self.transaction():
transcript = await self.get_transcript()
await transcripts_controller.update(
session,
transcript,
{
"long_summary": final_long_summary.long_summary,
},
)
return await transcripts_controller.append_event(
session,
transcript=transcript,
event="FINAL_LONG_SUMMARY",
data=final_long_summary,
@@ -254,17 +293,15 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
final_short_summary = TranscriptFinalShortSummary(
short_summary=data.short_summary
)
async with self.locked_session() as session:
transcript = await self.get_transcript(session)
async with self.transaction():
transcript = await self.get_transcript()
await transcripts_controller.update(
session,
transcript,
{
"short_summary": final_short_summary.short_summary,
},
)
return await transcripts_controller.append_event(
session,
transcript=transcript,
event="FINAL_SHORT_SUMMARY",
data=final_short_summary,
@@ -272,30 +309,29 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
@broadcast_to_sockets
async def on_duration(self, data):
async with self.locked_session() as session:
async with self.transaction():
duration = TranscriptDuration(duration=data)
transcript = await self.get_transcript(session)
transcript = await self.get_transcript()
await transcripts_controller.update(
session,
transcript,
{
"duration": duration.duration,
},
)
return await transcripts_controller.append_event(
session, transcript=transcript, event="DURATION", data=duration
transcript=transcript, event="DURATION", data=duration
)
@broadcast_to_sockets
async def on_waveform(self, data):
async with self.locked_session() as session:
async with self.transaction():
waveform = TranscriptWaveform(waveform=data)
transcript = await self.get_transcript(session)
transcript = await self.get_transcript()
return await transcripts_controller.append_event(
session, transcript=transcript, event="WAVEFORM", data=waveform
transcript=transcript, event="WAVEFORM", data=waveform
)
@@ -308,8 +344,7 @@ class PipelineMainLive(PipelineMainBase):
async def create(self) -> Pipeline:
# create a context for the whole rtc transaction
# add a customised logger to the context
async with get_session_context() as session:
transcript = await self.get_transcript(session)
transcript = await self.get_transcript()
processors = [
AudioFileWriterProcessor(
@@ -357,8 +392,7 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
# now let's start the pipeline by pushing information to the
# first processor diarization processor
# XXX translation is lost when converting our data model to the processor model
async with get_session_context() as session:
transcript = await self.get_transcript(session)
transcript = await self.get_transcript()
# diarization works only if the file is uploaded to an external storage
if transcript.audio_location == "local":
@@ -391,8 +425,7 @@ class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
async def create(self) -> Pipeline:
# get transcript
async with get_session_context() as session:
self._transcript = transcript = await self.get_transcript(session)
self._transcript = transcript = await self.get_transcript()
# create pipeline
processors = self.get_processors()
@@ -452,7 +485,8 @@ class PipelineMainWaveform(PipelineMainFromTopics):
]
async def pipeline_remove_upload(session, transcript: Transcript, logger: Logger):
@get_transcript
async def pipeline_remove_upload(transcript: Transcript, logger: Logger):
# for future changes: note that there's also a consent process happens, beforehand and users may not consent with keeping files. currently, we delete regardless, so it's no need for that
logger.info("Starting remove upload")
uploads = transcript.data_path.glob("upload.*")
@@ -461,14 +495,16 @@ async def pipeline_remove_upload(session, transcript: Transcript, logger: Logger
logger.info("Remove upload done")
async def pipeline_waveform(session, transcript: Transcript, logger: Logger):
@get_transcript
async def pipeline_waveform(transcript: Transcript, logger: Logger):
logger.info("Starting waveform")
runner = PipelineMainWaveform(transcript_id=transcript.id)
await runner.run()
logger.info("Waveform done")
async def pipeline_convert_to_mp3(session, transcript: Transcript, logger: Logger):
@get_transcript
async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
logger.info("Starting convert to mp3")
# If the audio wav is not available, just skip
@@ -494,7 +530,8 @@ async def pipeline_convert_to_mp3(session, transcript: Transcript, logger: Logge
logger.info("Convert to mp3 done")
async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
@get_transcript
async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
if not settings.TRANSCRIPT_STORAGE_BACKEND:
logger.info("No storage backend configured, skipping mp3 upload")
return
@@ -512,49 +549,49 @@ async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
return
# Upload to external storage and delete the file
await transcripts_controller.move_mp3_to_storage(session, transcript)
await transcripts_controller.move_mp3_to_storage(transcript)
logger.info("Upload mp3 done")
async def pipeline_diarization(session, transcript: Transcript, logger: Logger):
@get_transcript
async def pipeline_diarization(transcript: Transcript, logger: Logger):
logger.info("Starting diarization")
runner = PipelineMainDiarization(transcript_id=transcript.id)
await runner.run()
logger.info("Diarization done")
async def pipeline_title(session, transcript: Transcript, logger: Logger):
@get_transcript
async def pipeline_title(transcript: Transcript, logger: Logger):
logger.info("Starting title")
runner = PipelineMainTitle(transcript_id=transcript.id)
await runner.run()
logger.info("Title done")
async def pipeline_summaries(session, transcript: Transcript, logger: Logger):
@get_transcript
async def pipeline_summaries(transcript: Transcript, logger: Logger):
logger.info("Starting summaries")
runner = PipelineMainFinalSummaries(transcript_id=transcript.id)
await runner.run()
logger.info("Summaries done")
async def cleanup_consent(session, transcript: Transcript, logger: Logger):
@get_transcript
async def cleanup_consent(transcript: Transcript, logger: Logger):
logger.info("Starting consent cleanup")
consent_denied = False
recording = None
try:
if transcript.recording_id:
recording = await recordings_controller.get_by_id(
session, transcript.recording_id
)
recording = await recordings_controller.get_by_id(transcript.recording_id)
if recording and recording.meeting_id:
meeting = await meetings_controller.get_by_id(
session, recording.meeting_id
)
meeting = await meetings_controller.get_by_id(recording.meeting_id)
if meeting:
consent_denied = await meeting_consent_controller.has_any_denial(
session, meeting.id
meeting.id
)
except Exception as e:
logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
@@ -583,7 +620,7 @@ async def cleanup_consent(session, transcript: Transcript, logger: Logger):
logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
# non-transactional, files marked for deletion not actually deleted is possible
await transcripts_controller.update(session, transcript, {"audio_deleted": True})
await transcripts_controller.update(transcript, {"audio_deleted": True})
# 2. Delete processed audio from transcript storage S3 bucket
if transcript.audio_location == "storage":
storage = get_transcripts_storage()
@@ -607,14 +644,15 @@ async def cleanup_consent(session, transcript: Transcript, logger: Logger):
logger.info("Consent cleanup done")
async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger):
@get_transcript
async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
logger.info("Starting post to zulip")
if not transcript.recording_id:
logger.info("Transcript has no recording")
return
recording = await recordings_controller.get_by_id(session, transcript.recording_id)
recording = await recordings_controller.get_by_id(transcript.recording_id)
if not recording:
logger.info("Recording not found")
return
@@ -623,12 +661,12 @@ async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger
logger.info("Recording has no meeting")
return
meeting = await meetings_controller.get_by_id(session, recording.meeting_id)
meeting = await meetings_controller.get_by_id(recording.meeting_id)
if not meeting:
logger.info("No meeting found for this recording")
return
room = await rooms_controller.get_by_id(session, meeting.room_id)
room = await rooms_controller.get_by_id(meeting.room_id)
if not room:
logger.error(f"Missing room for a meeting {meeting.id}")
return
@@ -654,7 +692,7 @@ async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger
room.zulip_stream, room.zulip_topic, message
)
await transcripts_controller.update(
session, transcript, {"zulip_message_id": response["id"]}
transcript, {"zulip_message_id": response["id"]}
)
logger.info("Posted to zulip")
@@ -665,120 +703,92 @@ async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger
# ===================================================================
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_remove_upload(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_remove_upload(session, transcript=transcript, logger=logger)
@shared_task
@asynctask
async def task_pipeline_remove_upload(*, transcript_id: str):
await pipeline_remove_upload(transcript_id=transcript_id)
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_waveform(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_waveform(session, transcript=transcript, logger=logger)
@shared_task
@asynctask
async def task_pipeline_waveform(*, transcript_id: str):
await pipeline_waveform(transcript_id=transcript_id)
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_convert_to_mp3(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_convert_to_mp3(session, transcript=transcript, logger=logger)
@shared_task
@asynctask
async def task_pipeline_convert_to_mp3(*, transcript_id: str):
await pipeline_convert_to_mp3(transcript_id=transcript_id)
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_upload_mp3(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_upload_mp3(session, transcript=transcript, logger=logger)
@shared_task
@asynctask
async def task_pipeline_upload_mp3(*, transcript_id: str):
await pipeline_upload_mp3(transcript_id=transcript_id)
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_diarization(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_diarization(session, transcript=transcript, logger=logger)
@shared_task
@asynctask
async def task_pipeline_diarization(*, transcript_id: str):
await pipeline_diarization(transcript_id=transcript_id)
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_title(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_title(session, transcript=transcript, logger=logger)
@shared_task
@asynctask
async def task_pipeline_title(*, transcript_id: str):
await pipeline_title(transcript_id=transcript_id)
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_final_summaries(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
await pipeline_summaries(session, transcript=transcript, logger=logger)
@shared_task
@asynctask
async def task_pipeline_final_summaries(*, transcript_id: str):
await pipeline_summaries(transcript_id=transcript_id)
@taskiq_broker.task
@with_session_and_transcript
async def task_cleanup_consent(session, *, transcript: Transcript, logger: Logger):
await cleanup_consent(session, transcript=transcript, logger=logger)
@shared_task
@asynctask
async def task_cleanup_consent(*, transcript_id: str):
await cleanup_consent(transcript_id=transcript_id)
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_post_to_zulip(
session, *, transcript: Transcript, logger: Logger
):
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
@shared_task
@asynctask
async def task_pipeline_post_to_zulip(*, transcript_id: str):
await pipeline_post_to_zulip(transcript_id=transcript_id)
@taskiq_broker.task
@with_session_and_transcript
async def task_cleanup_consent_taskiq(
session, *, transcript: Transcript, logger: Logger
):
await cleanup_consent(session, transcript=transcript, logger=logger)
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_post_to_zulip_taskiq(
session, *, transcript: Transcript, logger: Logger
):
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
async def pipeline_post(*, transcript_id: str):
await task_pipeline_post_sequential.kiq(transcript_id=transcript_id)
@taskiq_broker.task
async def task_pipeline_post_sequential(*, transcript_id: str):
await task_pipeline_waveform.kiq(transcript_id=transcript_id)
await task_pipeline_convert_to_mp3.kiq(transcript_id=transcript_id)
await task_pipeline_upload_mp3.kiq(transcript_id=transcript_id)
await task_pipeline_remove_upload.kiq(transcript_id=transcript_id)
await task_pipeline_diarization.kiq(transcript_id=transcript_id)
await task_cleanup_consent.kiq(transcript_id=transcript_id)
await asyncio.gather(
task_pipeline_title.kiq(transcript_id=transcript_id),
task_pipeline_final_summaries.kiq(transcript_id=transcript_id),
def pipeline_post(*, transcript_id: str):
"""
Run the post pipeline
"""
chain_mp3_and_diarize = (
task_pipeline_waveform.si(transcript_id=transcript_id)
| task_pipeline_convert_to_mp3.si(transcript_id=transcript_id)
| task_pipeline_upload_mp3.si(transcript_id=transcript_id)
| task_pipeline_remove_upload.si(transcript_id=transcript_id)
| task_pipeline_diarization.si(transcript_id=transcript_id)
| task_cleanup_consent.si(transcript_id=transcript_id)
)
chain_title_preview = task_pipeline_title.si(transcript_id=transcript_id)
chain_final_summaries = task_pipeline_final_summaries.si(
transcript_id=transcript_id
)
await task_pipeline_post_to_zulip.kiq(transcript_id=transcript_id)
chain = chord(
group(chain_mp3_and_diarize, chain_title_preview),
chain_final_summaries,
) | task_pipeline_post_to_zulip.si(transcript_id=transcript_id)
return chain.delay()
async def pipeline_process(session, transcript: Transcript, logger: Logger):
@get_transcript
async def pipeline_process(transcript: Transcript, logger: Logger):
try:
if transcript.audio_location == "storage":
await transcripts_controller.download_mp3_from_storage(transcript)
transcript.audio_waveform_filename.unlink(missing_ok=True)
await transcripts_controller.update(
session,
transcript,
{
"topics": [],
@@ -816,7 +826,6 @@ async def pipeline_process(session, transcript: Transcript, logger: Logger):
except Exception as exc:
logger.error("Pipeline error", exc_info=exc)
await transcripts_controller.update(
session,
transcript,
{
"status": "error",
@@ -827,9 +836,7 @@ async def pipeline_process(session, transcript: Transcript, logger: Logger):
logger.info("Pipeline ended")
@taskiq_broker.task
@with_session_and_transcript
async def task_pipeline_process(
session, *, transcript: Transcript, logger: Logger, transcript_id: str
):
return await pipeline_process(session, transcript=transcript, logger=logger)
@shared_task
@asynctask
async def task_pipeline_process(*, transcript_id: str):
return await pipeline_process(transcript_id=transcript_id)

View File

@@ -55,7 +55,6 @@ import httpx
import pytz
import structlog
from icalendar import Calendar, Event
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.db.calendar_events import CalendarEvent, calendar_events_controller
from reflector.db.rooms import Room, rooms_controller
@@ -248,21 +247,15 @@ class ICSFetchService:
)
att_data: AttendeeData = {
"email": clean_email,
"name": (
att.params.get("CN")
"name": att.params.get("CN")
if hasattr(att, "params") and email == email_parts[0]
else None
),
"status": (
att.params.get("PARTSTAT")
else None,
"status": att.params.get("PARTSTAT")
if hasattr(att, "params") and email == email_parts[0]
else None
),
"role": (
att.params.get("ROLE")
else None,
"role": att.params.get("ROLE")
if hasattr(att, "params") and email == email_parts[0]
else None
),
else None,
}
final_attendees.append(att_data)
else:
@@ -270,9 +263,9 @@ class ICSFetchService:
att_data: AttendeeData = {
"email": email_str,
"name": att.params.get("CN") if hasattr(att, "params") else None,
"status": (
att.params.get("PARTSTAT") if hasattr(att, "params") else None
),
"status": att.params.get("PARTSTAT")
if hasattr(att, "params")
else None,
"role": att.params.get("ROLE") if hasattr(att, "params") else None,
}
final_attendees.append(att_data)
@@ -287,9 +280,9 @@ class ICSFetchService:
)
org_data: AttendeeData = {
"email": org_email,
"name": (
organizer.params.get("CN") if hasattr(organizer, "params") else None
),
"name": organizer.params.get("CN")
if hasattr(organizer, "params")
else None,
"role": "ORGANIZER",
}
final_attendees.append(org_data)
@@ -301,7 +294,7 @@ class ICSSyncService:
def __init__(self):
self.fetch_service = ICSFetchService()
async def sync_room_calendar(self, session: AsyncSession, room: Room) -> SyncResult:
async def sync_room_calendar(self, room: Room) -> SyncResult:
async with RedisAsyncLock(
f"ics_sync_room:{room.id}", skip_if_locked=True
) as lock:
@@ -312,11 +305,9 @@ class ICSSyncService:
"reason": "Sync already in progress",
}
return await self._sync_room_calendar(session, room)
return await self._sync_room_calendar(room)
async def _sync_room_calendar(
self, session: AsyncSession, room: Room
) -> SyncResult:
async def _sync_room_calendar(self, room: Room) -> SyncResult:
if not room.ics_enabled or not room.ics_url:
return {"status": SyncStatus.SKIPPED, "reason": "ICS not configured"}
@@ -349,11 +340,10 @@ class ICSSyncService:
events, total_events = self.fetch_service.extract_room_events(
calendar, room.name, room_url
)
sync_result = await self._sync_events_to_database(session, room.id, events)
sync_result = await self._sync_events_to_database(room.id, events)
# Update room sync metadata
await rooms_controller.update(
session,
room,
{
"ics_last_sync": datetime.now(timezone.utc),
@@ -382,7 +372,7 @@ class ICSSyncService:
return time_since_sync.total_seconds() >= room.ics_fetch_interval
async def _sync_events_to_database(
self, session: AsyncSession, room_id: str, events: list[EventData]
self, room_id: str, events: list[EventData]
) -> SyncStats:
created = 0
updated = 0
@@ -392,7 +382,7 @@ class ICSSyncService:
for event_data in events:
calendar_event = CalendarEvent(room_id=room_id, **event_data)
existing = await calendar_events_controller.get_by_ics_uid(
session, room_id, event_data["ics_uid"]
room_id, event_data["ics_uid"]
)
if existing:
@@ -400,12 +390,12 @@ class ICSSyncService:
else:
created += 1
await calendar_events_controller.upsert(session, calendar_event)
await calendar_events_controller.upsert(calendar_event)
current_ics_uids.append(event_data["ics_uid"])
# Soft delete events that are no longer in calendar
deleted = await calendar_events_controller.soft_delete_missing(
session, room_id, current_ics_uids
room_id, current_ics_uids
)
return {

View File

@@ -9,11 +9,12 @@ async def export_db(filename: str) -> None:
filename = pathlib.Path(filename).resolve()
settings.DATABASE_URL = f"sqlite:///{filename}"
from reflector.db import get_session_context
from reflector.db.transcripts import transcripts_controller
from reflector.db import get_database, transcripts
async with get_session_context() as session:
transcripts = await transcripts_controller.get_all(session)
database = get_database()
await database.connect()
transcripts = await database.fetch_all(transcripts.select())
await database.disconnect()
def export_transcript(transcript, output_dir):
for topic in transcript.topics:

View File

@@ -8,11 +8,12 @@ async def export_db(filename: str) -> None:
filename = pathlib.Path(filename).resolve()
settings.DATABASE_URL = f"sqlite:///{filename}"
from reflector.db import get_session_context
from reflector.db.transcripts import transcripts_controller
from reflector.db import get_database, transcripts
async with get_session_context() as session:
transcripts = await transcripts_controller.get_all(session)
database = get_database()
await database.connect()
transcripts = await database.fetch_all(transcripts.select())
await database.disconnect()
def export_transcript(transcript):
tid = transcript.id

View File

@@ -7,12 +7,10 @@ import asyncio
import json
import shutil
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Literal
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.db import get_session_context
from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
from reflector.logger import logger
from reflector.pipelines.main_file_pipeline import (
@@ -52,7 +50,6 @@ TranscriptId = str
# common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system)
# ideally we want to get rid of it at some point
async def prepare_entry(
session: AsyncSession,
source_path: str,
source_language: str,
target_language: str,
@@ -60,7 +57,6 @@ async def prepare_entry(
file_path = Path(source_path)
transcript = await transcripts_controller.add(
session,
file_path.name,
# note that the real file upload has SourceKind: LIVE for the reason of it's an error
source_kind=SourceKind.FILE,
@@ -82,20 +78,16 @@ async def prepare_entry(
logger.info(f"Copied {source_path} to {upload_path}")
# pipelines expect entity status "uploaded"
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
await transcripts_controller.update(transcript, {"status": "uploaded"})
return transcript.id
# same reason as prepare_entry
async def extract_result_from_entry(
session: AsyncSession,
transcript_id: TranscriptId,
output_path: str,
transcript_id: TranscriptId, output_path: str
) -> None:
post_final_transcript = await transcripts_controller.get_by_id(
session, transcript_id
)
post_final_transcript = await transcripts_controller.get_by_id(transcript_id)
# assert post_final_transcript.status == "ended"
# File pipeline doesn't set status to "ended", only live pipeline does https://github.com/Monadical-SAS/reflector/issues/582
@@ -123,7 +115,6 @@ async def extract_result_from_entry(
async def process_live_pipeline(
session: AsyncSession,
transcript_id: TranscriptId,
):
"""Process transcript_id with transcription and diarization"""
@@ -132,14 +123,18 @@ async def process_live_pipeline(
await live_pipeline_process(transcript_id=transcript_id)
print(f"Processing complete for transcript {transcript_id}", file=sys.stderr)
pre_final_transcript = await transcripts_controller.get_by_id(
session, transcript_id
)
pre_final_transcript = await transcripts_controller.get_by_id(transcript_id)
# assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post
assert pre_final_transcript.status != "ended"
await live_pipeline_post(transcript_id=transcript_id)
# at this point, diarization is running but we have no access to it. run diarization in parallel - one will hopefully win after polling
result = live_pipeline_post(transcript_id=transcript_id)
# result.ready() blocks even without await; it mutates result also
while not result.ready():
print(f"Status: {result.state}")
time.sleep(2)
async def process_file_pipeline(
@@ -147,7 +142,13 @@ async def process_file_pipeline(
):
"""Process audio/video file using the optimized file pipeline"""
await task_pipeline_file_process.kiq(transcript_id=transcript_id)
# task_pipeline_file_process is a Celery task, need to use .delay() for async execution
result = task_pipeline_file_process.delay(transcript_id=transcript_id)
# Wait for the Celery task to complete
while not result.ready():
print(f"File pipeline status: {result.state}", file=sys.stderr)
time.sleep(2)
logger.info("File pipeline processing complete")
@@ -159,16 +160,21 @@ async def process(
pipeline: Literal["live", "file"],
output_path: str = None,
):
async with get_session_context() as session:
from reflector.db import get_database
database = get_database()
# db connect is a part of ceremony
await database.connect()
try:
transcript_id = await prepare_entry(
session,
source_path,
source_language,
target_language,
)
pipeline_handlers = {
"live": lambda tid: process_live_pipeline(session, tid),
"live": process_live_pipeline,
"file": process_file_pipeline,
}
@@ -178,7 +184,9 @@ async def process(
await handler(transcript_id)
await extract_result_from_entry(session, transcript_id, output_path)
await extract_result_from_entry(transcript_id, output_path)
finally:
await database.disconnect()
if __name__ == "__main__":

View File

@@ -1,10 +1,14 @@
import argparse
import asyncio
from reflector.pipelines.main_live_pipeline import pipeline_post
from reflector.app import celery_app # noqa
from reflector.pipelines.main_live_pipeline import task_pipeline_main_post
parser = argparse.ArgumentParser()
parser.add_argument("transcript_id", type=str)
parser.add_argument("--delay", action="store_true")
args = parser.parse_args()
asyncio.run(pipeline_post(transcript_id=args.transcript_id))
if args.delay:
task_pipeline_main_post.delay(args.transcript_id)
else:
task_pipeline_main_post(args.transcript_id)

View File

@@ -5,13 +5,12 @@ from typing import Annotated, Any, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi_pagination import Page
from fastapi_pagination.ext.sqlalchemy import paginate
from fastapi_pagination.ext.databases import apaginate
from pydantic import BaseModel
from redis.exceptions import LockError
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth
from reflector.db import get_session
from reflector.db import get_database
from reflector.db.calendar_events import calendar_events_controller
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller
@@ -177,29 +176,31 @@ def parse_datetime_with_timezone(iso_string: str) -> datetime:
@router.get("/rooms", response_model=Page[RoomDetails])
async def rooms_list(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> list[RoomDetails]:
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None
query = await rooms_controller.get_all(
session, user_id=user_id, order_by="-created_at", return_query=True
return await apaginate(
get_database(),
await rooms_controller.get_all(
user_id=user_id, order_by="-created_at", return_query=True
),
)
return await paginate(session, query)
@router.get("/rooms/{room_id}", response_model=RoomDetails)
async def rooms_get(
room_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
if not room.is_shared and (user_id is None or room.user_id != user_id):
raise HTTPException(status_code=403, detail="Room access denied")
return room
@@ -207,10 +208,9 @@ async def rooms_get(
async def rooms_get_by_name(
room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(session, room_name)
room = await rooms_controller.get_by_name(room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
@@ -231,13 +231,11 @@ async def rooms_get_by_name(
@router.post("/rooms", response_model=Room)
async def rooms_create(
room: CreateRoom,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
):
user_id = user["sub"] if user else None
user_id = user["sub"]
return await rooms_controller.add(
session,
name=room.name,
user_id=user_id,
zulip_auto_post=room.zulip_auto_post,
@@ -260,29 +258,31 @@ async def rooms_create(
async def rooms_update(
room_id: str,
info: UpdateRoom,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
user_id = user["sub"]
room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
if room.user_id != user_id:
raise HTTPException(status_code=403, detail="Not authorized")
values = info.dict(exclude_unset=True)
await rooms_controller.update(session, room, values)
await rooms_controller.update(room, values)
return room
@router.delete("/rooms/{room_id}", response_model=DeletionStatus)
async def rooms_delete(
room_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_id(session, room_id, user_id=user_id)
user_id = user["sub"]
room = await rooms_controller.get_by_id(room_id)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
await rooms_controller.remove_by_id(session, room.id, user_id=user_id)
if room.user_id != user_id:
raise HTTPException(status_code=403, detail="Not authorized")
await rooms_controller.remove_by_id(room.id, user_id=user_id)
return DeletionStatus(status="ok")
@@ -291,10 +291,9 @@ async def rooms_create_meeting(
room_name: str,
info: CreateRoomMeeting,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(session, room_name)
room = await rooms_controller.get_by_name(room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
@@ -310,7 +309,7 @@ async def rooms_create_meeting(
meeting = None
if not info.allow_duplicated:
meeting = await meetings_controller.get_active(
session, room=room, current_time=current_time
room=room, current_time=current_time
)
if meeting is None:
@@ -321,7 +320,6 @@ async def rooms_create_meeting(
await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
meeting = await meetings_controller.create(
session,
id=whereby_meeting["meetingId"],
room_name=whereby_meeting["roomName"],
room_url=whereby_meeting["roomUrl"],
@@ -347,17 +345,16 @@ async def rooms_create_meeting(
@router.post("/rooms/{room_id}/webhook/test", response_model=WebhookTestResult)
async def rooms_test_webhook(
room_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
):
"""Test webhook configuration by sending a sample payload."""
user_id = user["sub"] if user else None
user_id = user["sub"]
room = await rooms_controller.get_by_id(session, room_id)
room = await rooms_controller.get_by_id(room_id)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
if user_id and room.user_id != user_id:
if room.user_id != user_id:
raise HTTPException(
status_code=403, detail="Not authorized to test this room's webhook"
)
@@ -370,10 +367,9 @@ async def rooms_test_webhook(
async def rooms_sync_ics(
room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(session, room_name)
room = await rooms_controller.get_by_name(room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
@@ -386,7 +382,7 @@ async def rooms_sync_ics(
if not room.ics_enabled or not room.ics_url:
raise HTTPException(status_code=400, detail="ICS not configured for this room")
result = await ics_sync_service.sync_room_calendar(session, room)
result = await ics_sync_service.sync_room_calendar(room)
if result["status"] == "error":
raise HTTPException(
@@ -400,10 +396,9 @@ async def rooms_sync_ics(
async def rooms_ics_status(
room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(session, room_name)
room = await rooms_controller.get_by_name(room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
@@ -418,7 +413,7 @@ async def rooms_ics_status(
next_sync = room.ics_last_sync + timedelta(seconds=room.ics_fetch_interval)
events = await calendar_events_controller.get_by_room(
session, room.id, include_deleted=False
room.id, include_deleted=False
)
return ICSStatus(
@@ -434,16 +429,15 @@ async def rooms_ics_status(
async def rooms_list_meetings(
room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(session, room_name)
room = await rooms_controller.get_by_name(room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
events = await calendar_events_controller.get_by_room(
session, room.id, include_deleted=False
room.id, include_deleted=False
)
if user_id != room.user_id:
@@ -461,16 +455,15 @@ async def rooms_list_upcoming_meetings(
room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
minutes_ahead: int = 120,
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(session, room_name)
room = await rooms_controller.get_by_name(room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
events = await calendar_events_controller.get_upcoming(
session, room.id, minutes_ahead=minutes_ahead
room.id, minutes_ahead=minutes_ahead
)
if user_id != room.user_id:
@@ -485,17 +478,16 @@ async def rooms_list_upcoming_meetings(
async def rooms_list_active_meetings(
room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(session, room_name)
room = await rooms_controller.get_by_name(room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
current_time = datetime.now(timezone.utc)
meetings = await meetings_controller.get_all_active_for_room(
session, room=room, current_time=current_time
room=room, current_time=current_time
)
# Hide host URLs from non-owners
@@ -511,16 +503,15 @@ async def rooms_get_meeting(
room_name: str,
meeting_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
"""Get a single meeting by ID within a specific room."""
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(session, room_name)
room = await rooms_controller.get_by_name(room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
meeting = await meetings_controller.get_by_id(session, meeting_id)
meeting = await meetings_controller.get_by_id(meeting_id)
if not meeting:
raise HTTPException(status_code=404, detail="Meeting not found")
@@ -540,15 +531,14 @@ async def rooms_join_meeting(
room_name: str,
meeting_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(session, room_name)
room = await rooms_controller.get_by_name(room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
meeting = await meetings_controller.get_by_id(session, meeting_id)
meeting = await meetings_controller.get_by_id(meeting_id)
if not meeting:
raise HTTPException(status_code=404, detail="Meeting not found")

View File

@@ -3,15 +3,12 @@ from typing import Annotated, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi_pagination import Page
from fastapi_pagination.ext.sqlalchemy import paginate
from fastapi_pagination.ext.databases import apaginate
from jose import jwt
from pydantic import BaseModel, Field, constr, field_serializer
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth
from reflector.db import get_session
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller
from reflector.db import get_database
from reflector.db.search import (
DEFAULT_SEARCH_LIMIT,
SearchLimit,
@@ -35,6 +32,7 @@ from reflector.db.transcripts import (
from reflector.processors.types import Transcript as ProcessorTranscript
from reflector.processors.types import Word
from reflector.settings import settings
from reflector.ws_manager import get_ws_manager
from reflector.zulip import (
InvalidMessageError,
get_zulip_message,
@@ -150,25 +148,24 @@ async def transcripts_list(
source_kind: SourceKind | None = None,
room_id: str | None = None,
search_term: str | None = None,
session: AsyncSession = Depends(get_session),
):
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None
query = await transcripts_controller.get_all(
session,
return await apaginate(
get_database(),
await transcripts_controller.get_all(
user_id=user_id,
source_kind=SourceKind(source_kind) if source_kind else None,
room_id=room_id,
search_term=search_term,
order_by="-created_at",
return_query=True,
),
)
return await paginate(session, query)
@router.get("/transcripts/search", response_model=SearchResponse)
async def transcripts_search(
@@ -180,7 +177,6 @@ async def transcripts_search(
user: Annotated[
Optional[auth.UserInfo], Depends(auth.current_user_optional)
] = None,
session: AsyncSession = Depends(get_session),
):
"""
Full-text search across transcript titles and content.
@@ -199,7 +195,7 @@ async def transcripts_search(
source_kind=source_kind,
)
results, total = await search_controller.search_transcripts(session, search_params)
results, total = await search_controller.search_transcripts(search_params)
return SearchResponse(
results=results,
@@ -214,11 +210,9 @@ async def transcripts_search(
async def transcripts_create(
info: CreateTranscript,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
return await transcripts_controller.add(
session,
transcript = await transcripts_controller.add(
info.name,
source_kind=info.source_kind or SourceKind.LIVE,
source_language=info.source_language,
@@ -226,6 +220,14 @@ async def transcripts_create(
user_id=user_id,
)
if user_id:
await get_ws_manager().send_json(
room_id=f"user:{user_id}",
message={"event": "TRANSCRIPT_CREATED", "data": {"id": transcript.id}},
)
return transcript
# ==============================================================
# Single transcript
@@ -338,11 +340,10 @@ class GetTranscriptTopicWithWordsPerSpeaker(GetTranscriptTopic):
async def transcript_get(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
return await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
@@ -350,38 +351,36 @@ async def transcript_get(
async def transcript_update(
transcript_id: str,
info: UpdateTranscript,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
):
user_id = user["sub"] if user else None
user_id = user["sub"]
transcript = await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
if not transcripts_controller.user_can_mutate(transcript, user_id):
raise HTTPException(status_code=403, detail="Not authorized")
values = info.dict(exclude_unset=True)
updated_transcript = await transcripts_controller.update(
session, transcript, values
)
updated_transcript = await transcripts_controller.update(transcript, values)
return updated_transcript
@router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus)
async def transcript_delete(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(session, transcript_id)
user_id = user["sub"]
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if not transcripts_controller.user_can_mutate(transcript, user_id):
raise HTTPException(status_code=403, detail="Not authorized")
if transcript.meeting_id:
meeting = await meetings_controller.get_by_id(session, transcript.meeting_id)
room = await rooms_controller.get_by_id(session, meeting.room_id)
if room.is_shared:
user_id = None
await transcripts_controller.remove_by_id(session, transcript.id, user_id=user_id)
await transcripts_controller.remove_by_id(transcript.id, user_id=user_id)
await get_ws_manager().send_json(
room_id=f"user:{user_id}",
message={"event": "TRANSCRIPT_DELETED", "data": {"id": transcript.id}},
)
return DeletionStatus(status="ok")
@@ -392,11 +391,10 @@ async def transcript_delete(
async def transcript_get_topics(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
# convert to GetTranscriptTopic
@@ -412,11 +410,10 @@ async def transcript_get_topics(
async def transcript_get_topics_with_words(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
# convert to GetTranscriptTopicWithWords
@@ -434,11 +431,10 @@ async def transcript_get_topics_with_words_per_speaker(
transcript_id: str,
topic_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
# get the topic from the transcript
@@ -456,16 +452,16 @@ async def transcript_post_to_zulip(
stream: str,
topic: str,
include_topics: bool,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
):
user_id = user["sub"] if user else None
user_id = user["sub"]
transcript = await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
if not transcripts_controller.user_can_mutate(transcript, user_id):
raise HTTPException(status_code=403, detail="Not authorized")
content = get_zulip_message(transcript, include_topics)
message_updated = False
@@ -481,5 +477,5 @@ async def transcript_post_to_zulip(
if not message_updated:
response = await send_message_to_zulip(stream, topic, content)
await transcripts_controller.update(
session, transcript, {"zulip_message_id": response["id"]}
transcript, {"zulip_message_id": response["id"]}
)

View File

@@ -9,10 +9,8 @@ from typing import Annotated, Optional
import httpx
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from jose import jwt
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import AudioWaveform, transcripts_controller
from reflector.settings import settings
from reflector.views.transcripts import ALGORITHM
@@ -34,7 +32,6 @@ async def transcript_get_audio_mp3(
request: Request,
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
token: str | None = None,
):
user_id = user["sub"] if user else None
@@ -51,7 +48,7 @@ async def transcript_get_audio_mp3(
raise unauthorized_exception
transcript = await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
if transcript.audio_location == "storage":
@@ -89,7 +86,7 @@ async def transcript_get_audio_mp3(
return range_requests_response(
request,
transcript.audio_mp3_filename.as_posix(),
transcript.audio_mp3_filename,
content_type="audio/mpeg",
content_disposition=f"attachment; filename={filename}",
)
@@ -99,18 +96,13 @@ async def transcript_get_audio_mp3(
async def transcript_get_audio_waveform(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> AudioWaveform:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
if not transcript.audio_waveform_filename.exists():
raise HTTPException(status_code=404, detail="Audio not found")
audio_waveform = transcript.audio_waveform
if not audio_waveform:
raise HTTPException(status_code=404, detail="Audio waveform not found")
return audio_waveform
return transcript.audio_waveform

View File

@@ -8,10 +8,8 @@ from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import TranscriptParticipant, transcripts_controller
from reflector.views.types import DeletionStatus
@@ -39,11 +37,10 @@ class UpdateParticipant(BaseModel):
async def transcript_get_participants(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> list[Participant]:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
if transcript.participants is None:
@@ -59,13 +56,14 @@ async def transcript_get_participants(
async def transcript_add_participant(
transcript_id: str,
participant: CreateParticipant,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
) -> Participant:
user_id = user["sub"] if user else None
user_id = user["sub"]
transcript = await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
if transcript.user_id is not None and transcript.user_id != user_id:
raise HTTPException(status_code=403, detail="Not authorized")
# ensure the speaker is unique
if participant.speaker is not None and transcript.participants is not None:
@@ -77,7 +75,7 @@ async def transcript_add_participant(
)
obj = await transcripts_controller.upsert_participant(
session, transcript, TranscriptParticipant(**participant.dict())
transcript, TranscriptParticipant(**participant.dict())
)
return Participant.model_validate(obj)
@@ -87,11 +85,10 @@ async def transcript_get_participant(
transcript_id: str,
participant_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
) -> Participant:
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
for p in transcript.participants:
@@ -106,13 +103,14 @@ async def transcript_update_participant(
transcript_id: str,
participant_id: str,
participant: UpdateParticipant,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
) -> Participant:
user_id = user["sub"] if user else None
user_id = user["sub"]
transcript = await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
if transcript.user_id is not None and transcript.user_id != user_id:
raise HTTPException(status_code=403, detail="Not authorized")
# ensure the speaker is unique
for p in transcript.participants:
@@ -136,7 +134,7 @@ async def transcript_update_participant(
fields = participant.dict(exclude_unset=True)
obj = obj.copy(update=fields)
await transcripts_controller.upsert_participant(session, transcript, obj)
await transcripts_controller.upsert_participant(transcript, obj)
return Participant.model_validate(obj)
@@ -144,12 +142,13 @@ async def transcript_update_participant(
async def transcript_delete_participant(
transcript_id: str,
participant_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
) -> DeletionStatus:
user_id = user["sub"] if user else None
user_id = user["sub"]
transcript = await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
await transcripts_controller.delete_participant(session, transcript, participant_id)
if transcript.user_id is not None and transcript.user_id != user_id:
raise HTTPException(status_code=403, detail="Not authorized")
await transcripts_controller.delete_participant(transcript, participant_id)
return DeletionStatus(status="ok")

View File

@@ -1,11 +1,10 @@
from typing import Annotated, Optional
import celery
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import transcripts_controller
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
@@ -20,11 +19,10 @@ class ProcessStatus(BaseModel):
async def transcript_process(
transcript_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
if transcript.locked:
@@ -35,6 +33,24 @@ async def transcript_process(
status_code=400, detail="Recording is not ready for processing"
)
await task_pipeline_file_process.kiq(transcript_id=transcript_id)
if task_is_scheduled_or_active(
"reflector.pipelines.main_file_pipeline.task_pipeline_file_process",
transcript_id=transcript_id,
):
return ProcessStatus(status="already running")
# schedule a background task process the file
task_pipeline_file_process.delay(transcript_id=transcript_id)
return ProcessStatus(status="ok")
def task_is_scheduled_or_active(task_name: str, **kwargs):
inspect = celery.current_app.control.inspect()
for worker, tasks in (inspect.scheduled() | inspect.active()).items():
for task in tasks:
if task["name"] == task_name and task["kwargs"] == kwargs:
return True
return False

View File

@@ -8,10 +8,8 @@ from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import transcripts_controller
router = APIRouter()
@@ -37,13 +35,14 @@ class SpeakerMerge(BaseModel):
async def transcript_assign_speaker(
transcript_id: str,
assignment: SpeakerAssignment,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
) -> SpeakerAssignmentStatus:
user_id = user["sub"] if user else None
user_id = user["sub"]
transcript = await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
if transcript.user_id is not None and transcript.user_id != user_id:
raise HTTPException(status_code=403, detail="Not authorized")
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
@@ -82,9 +81,7 @@ async def transcript_assign_speaker(
# if the participant does not have a speaker, create one
if participant.speaker is None:
participant.speaker = transcript.find_empty_speaker()
await transcripts_controller.upsert_participant(
session, transcript, participant
)
await transcripts_controller.upsert_participant(transcript, participant)
speaker = participant.speaker
@@ -105,7 +102,6 @@ async def transcript_assign_speaker(
for topic in changed_topics:
transcript.upsert_topic(topic)
await transcripts_controller.update(
session,
transcript,
{
"topics": transcript.topics_dump(),
@@ -119,13 +115,14 @@ async def transcript_assign_speaker(
async def transcript_merge_speaker(
transcript_id: str,
merge: SpeakerMerge,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
) -> SpeakerAssignmentStatus:
user_id = user["sub"] if user else None
user_id = user["sub"]
transcript = await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
if transcript.user_id is not None and transcript.user_id != user_id:
raise HTTPException(status_code=403, detail="Not authorized")
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
@@ -170,7 +167,6 @@ async def transcript_merge_speaker(
for topic in changed_topics:
transcript.upsert_topic(topic)
await transcripts_controller.update(
session,
transcript,
{
"topics": transcript.topics_dump(),

View File

@@ -3,10 +3,8 @@ from typing import Annotated, Optional
import av
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import transcripts_controller
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
@@ -24,11 +22,10 @@ async def transcript_record_upload(
total_chunks: int,
chunk: UploadFile,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
if transcript.locked:
@@ -92,8 +89,9 @@ async def transcript_record_upload(
container.close()
# set the status to "uploaded"
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
await transcripts_controller.update(transcript, {"status": "uploaded"})
await task_pipeline_file_process.kiq(transcript_id=transcript_id)
# launch a background task to process the file
task_pipeline_file_process.delay(transcript_id=transcript_id)
return UploadStatus(status="ok")

View File

@@ -1,10 +1,8 @@
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.ext.asyncio import AsyncSession
import reflector.auth as auth
from reflector.db import get_session
from reflector.db.transcripts import transcripts_controller
from .rtc_offer import RtcOffer, rtc_offer_base
@@ -18,11 +16,10 @@ async def transcript_record_webrtc(
params: RtcOffer,
request: Request,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
session: AsyncSession = Depends(get_session),
):
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
session, transcript_id, user_id=user_id
transcript_id, user_id=user_id
)
if transcript.locked:

View File

@@ -4,10 +4,11 @@ Transcripts websocket API
"""
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Optional
from reflector.db import get_session
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
import reflector.auth as auth
from reflector.db.transcripts import transcripts_controller
from reflector.ws_manager import get_ws_manager
@@ -23,11 +24,12 @@ async def transcript_get_websocket_events(transcript_id: str):
async def transcript_events_websocket(
transcript_id: str,
websocket: WebSocket,
session: AsyncSession = Depends(get_session),
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
user: Optional[auth.UserInfo] = Depends(auth.current_user_optional),
):
# user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id(session, transcript_id)
user_id = user["sub"] if user else None
transcript = await transcripts_controller.get_by_id_for_http(
transcript_id, user_id=user_id
)
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")

View File

@@ -0,0 +1,53 @@
from typing import Optional
from fastapi import APIRouter, WebSocket
from reflector.auth.auth_jwt import JWTAuth # type: ignore
from reflector.ws_manager import get_ws_manager
router = APIRouter()
# Close code for unauthorized WebSocket connections
UNAUTHORISED = 4401
@router.websocket("/events")
async def user_events_websocket(websocket: WebSocket):
# Browser can't send Authorization header for WS; use subprotocol: ["bearer", token]
raw_subprotocol = websocket.headers.get("sec-websocket-protocol") or ""
parts = [p.strip() for p in raw_subprotocol.split(",") if p.strip()]
token: Optional[str] = None
negotiated_subprotocol: Optional[str] = None
if len(parts) >= 2 and parts[0].lower() == "bearer":
negotiated_subprotocol = "bearer"
token = parts[1]
user_id: Optional[str] = None
if not token:
await websocket.close(code=UNAUTHORISED)
return
try:
payload = JWTAuth().verify_token(token)
user_id = payload.get("sub")
except Exception:
await websocket.close(code=UNAUTHORISED)
return
if not user_id:
await websocket.close(code=UNAUTHORISED)
return
room_id = f"user:{user_id}"
ws_manager = get_ws_manager()
await ws_manager.add_user_to_room(
room_id, websocket, subprotocol=negotiated_subprotocol
)
try:
while True:
await websocket.receive()
finally:
if room_id:
await ws_manager.remove_user_from_room(room_id, websocket)

View File

@@ -1,21 +1,68 @@
import os
import celery
import structlog
from taskiq import InMemoryBroker
from taskiq_redis import RedisAsyncResultBackend, RedisStreamBroker
from celery import Celery
from celery.schedules import crontab
from reflector.settings import settings
logger = structlog.get_logger(__name__)
env = os.environ.get("ENVIRONMENT")
if env and env == "pytest":
taskiq_broker = InMemoryBroker(await_inplace=True)
if celery.current_app.main != "default":
logger.info(f"Celery already configured ({celery.current_app})")
app = celery.current_app
else:
result_backend = RedisAsyncResultBackend(
redis_url=settings.CELERY_BROKER_URL,
result_ex_time=86400,
app = Celery(__name__)
app.conf.broker_url = settings.CELERY_BROKER_URL
app.conf.result_backend = settings.CELERY_RESULT_BACKEND
app.conf.broker_connection_retry_on_startup = True
app.autodiscover_tasks(
[
"reflector.pipelines.main_live_pipeline",
"reflector.worker.healthcheck",
"reflector.worker.process",
"reflector.worker.cleanup",
"reflector.worker.ics_sync",
]
)
taskiq_broker = RedisStreamBroker(
url=settings.CELERY_BROKER_URL,
).with_result_backend(result_backend)
# crontab
app.conf.beat_schedule = {
"process_messages": {
"task": "reflector.worker.process.process_messages",
"schedule": float(settings.SQS_POLLING_TIMEOUT_SECONDS),
},
"process_meetings": {
"task": "reflector.worker.process.process_meetings",
"schedule": float(settings.SQS_POLLING_TIMEOUT_SECONDS),
},
"reprocess_failed_recordings": {
"task": "reflector.worker.process.reprocess_failed_recordings",
"schedule": crontab(hour=5, minute=0), # Midnight EST
},
"sync_all_ics_calendars": {
"task": "reflector.worker.ics_sync.sync_all_ics_calendars",
"schedule": 60.0, # Run every minute to check which rooms need sync
},
"create_upcoming_meetings": {
"task": "reflector.worker.ics_sync.create_upcoming_meetings",
"schedule": 30.0, # Run every 30 seconds to create upcoming meetings
},
}
if settings.PUBLIC_MODE:
app.conf.beat_schedule["cleanup_old_public_data"] = {
"task": "reflector.worker.cleanup.cleanup_old_public_data_task",
"schedule": crontab(hour=3, minute=0),
}
logger.info(
"Public mode cleanup enabled",
retention_days=settings.PUBLIC_DATA_RETENTION_DAYS,
)
if settings.HEALTHCHECK_URL:
app.conf.beat_schedule["healthcheck_ping"] = {
"task": "reflector.worker.healthcheck.healthcheck_ping",
"schedule": 60.0 * 10,
}
logger.info("Healthcheck enabled", url=settings.HEALTHCHECK_URL)
else:
logger.warning("Healthcheck disabled, no url configured")

View File

@@ -9,16 +9,17 @@ from datetime import datetime, timedelta, timezone
from typing import TypedDict
import structlog
from celery import shared_task
from databases import Database
from pydantic.types import PositiveInt
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.db.base import MeetingModel, RecordingModel, TranscriptModel
from reflector.db.transcripts import transcripts_controller
from reflector.asynctask import asynctask
from reflector.db import get_database
from reflector.db.meetings import meetings
from reflector.db.recordings import recordings
from reflector.db.transcripts import transcripts, transcripts_controller
from reflector.settings import settings
from reflector.storage import get_recordings_storage
from reflector.worker.app import taskiq_broker
from reflector.worker.session_decorator import with_session
logger = structlog.get_logger(__name__)
@@ -33,28 +34,28 @@ class CleanupStats(TypedDict):
async def delete_single_transcript(
session: AsyncSession, transcript_data: dict, stats: CleanupStats
db: Database, transcript_data: dict, stats: CleanupStats
):
transcript_id = transcript_data["id"]
meeting_id = transcript_data["meeting_id"]
recording_id = transcript_data["recording_id"]
try:
async with db.transaction(isolation="serializable"):
if meeting_id:
await session.execute(
delete(MeetingModel).where(MeetingModel.id == meeting_id)
)
await db.execute(meetings.delete().where(meetings.c.id == meeting_id))
stats["meetings_deleted"] += 1
logger.info("Deleted associated meeting", meeting_id=meeting_id)
if recording_id:
result = await session.execute(
select(RecordingModel).where(RecordingModel.id == recording_id)
recording = await db.fetch_one(
recordings.select().where(recordings.c.id == recording_id)
)
recording = result.mappings().first()
if recording:
try:
await get_recordings_storage().delete_file(recording["object_key"])
await get_recordings_storage().delete_file(
recording["object_key"]
)
except Exception as storage_error:
logger.warning(
"Failed to delete recording from storage",
@@ -63,13 +64,15 @@ async def delete_single_transcript(
error=str(storage_error),
)
await session.execute(
delete(RecordingModel).where(RecordingModel.id == recording_id)
await db.execute(
recordings.delete().where(recordings.c.id == recording_id)
)
stats["recordings_deleted"] += 1
logger.info("Deleted associated recording", recording_id=recording_id)
logger.info(
"Deleted associated recording", recording_id=recording_id
)
await transcripts_controller.remove_by_id(session, transcript_id)
await transcripts_controller.remove_by_id(transcript_id)
stats["transcripts_deleted"] += 1
logger.info(
"Deleted transcript",
@@ -83,30 +86,18 @@ async def delete_single_transcript(
async def cleanup_old_transcripts(
session: AsyncSession, cutoff_date: datetime, stats: CleanupStats
db: Database, cutoff_date: datetime, stats: CleanupStats
):
"""Delete old anonymous transcripts and their associated recordings/meetings."""
query = select(
TranscriptModel.id,
TranscriptModel.meeting_id,
TranscriptModel.recording_id,
TranscriptModel.created_at,
).where(
(TranscriptModel.created_at < cutoff_date) & (TranscriptModel.user_id.is_(None))
query = transcripts.select().where(
(transcripts.c.created_at < cutoff_date) & (transcripts.c.user_id.is_(None))
)
result = await session.execute(query)
old_transcripts = result.mappings().all()
old_transcripts = await db.fetch_all(query)
logger.info(f"Found {len(old_transcripts)} old transcripts to delete")
for transcript_data in old_transcripts:
try:
await delete_single_transcript(session, transcript_data, stats)
except Exception as e:
error_msg = f"Failed to delete transcript {transcript_data['id']}: {str(e)}"
logger.error(error_msg, exc_info=e)
stats["errors"].append(error_msg)
await delete_single_transcript(db, transcript_data, stats)
def log_cleanup_results(stats: CleanupStats):
@@ -126,7 +117,6 @@ def log_cleanup_results(stats: CleanupStats):
async def cleanup_old_public_data(
session: AsyncSession,
days: PositiveInt | None = None,
) -> CleanupStats | None:
if days is None:
@@ -149,13 +139,17 @@ async def cleanup_old_public_data(
"errors": [],
}
await cleanup_old_transcripts(session, cutoff_date, stats)
db = get_database()
await cleanup_old_transcripts(db, cutoff_date, stats)
log_cleanup_results(stats)
return stats
@taskiq_broker.task
@with_session
async def cleanup_old_public_data_task(session: AsyncSession, days: int | None = None):
await cleanup_old_public_data(session, days=days)
@shared_task(
autoretry_for=(Exception,),
retry_kwargs={"max_retries": 3, "countdown": 300},
)
@asynctask
async def cleanup_old_public_data_task(days: int | None = None):
await cleanup_old_public_data(days=days)

View File

@@ -1,13 +1,13 @@
import httpx
import structlog
from celery import shared_task
from reflector.settings import settings
from reflector.worker.app import taskiq_broker
logger = structlog.get_logger(__name__)
@taskiq_broker.task
@shared_task
def healthcheck_ping():
url = settings.HEALTHCHECK_URL
if not url:

View File

@@ -1,25 +1,25 @@
from datetime import datetime, timedelta, timezone
import structlog
from sqlalchemy.ext.asyncio import AsyncSession
from celery import shared_task
from celery.utils.log import get_task_logger
from reflector.asynctask import asynctask
from reflector.db.calendar_events import calendar_events_controller
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller
from reflector.redis_cache import RedisAsyncLock
from reflector.services.ics_sync import SyncStatus, ics_sync_service
from reflector.whereby import create_meeting, upload_logo
from reflector.worker.app import taskiq_broker
from reflector.worker.session_decorator import with_session
logger = structlog.get_logger(__name__)
logger = structlog.wrap_logger(get_task_logger(__name__))
@taskiq_broker.task
@with_session
async def sync_room_ics(session: AsyncSession, room_id: str):
@shared_task
@asynctask
async def sync_room_ics(room_id: str):
try:
room = await rooms_controller.get_by_id(session, room_id)
room = await rooms_controller.get_by_id(room_id)
if not room:
logger.warning("Room not found for ICS sync", room_id=room_id)
return
@@ -29,7 +29,7 @@ async def sync_room_ics(session: AsyncSession, room_id: str):
return
logger.info("Starting ICS sync for room", room_id=room_id, room_name=room.name)
result = await ics_sync_service.sync_room_calendar(session, room)
result = await ics_sync_service.sync_room_calendar(room)
if result["status"] == SyncStatus.SUCCESS:
logger.info(
@@ -53,13 +53,13 @@ async def sync_room_ics(session: AsyncSession, room_id: str):
logger.error("Unexpected error during ICS sync", room_id=room_id, error=str(e))
@taskiq_broker.task
@with_session
async def sync_all_ics_calendars(session: AsyncSession):
@shared_task
@asynctask
async def sync_all_ics_calendars():
try:
logger.info("Starting sync for all ICS-enabled rooms")
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
ics_enabled_rooms = await rooms_controller.get_ics_enabled()
logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")
for room in ics_enabled_rooms:
@@ -67,7 +67,7 @@ async def sync_all_ics_calendars(session: AsyncSession):
logger.debug("Skipping room, not time to sync yet", room_id=room.id)
continue
await sync_room_ics.kiq(room.id)
sync_room_ics.delay(room.id)
logger.info("Queued sync tasks for all eligible rooms")
@@ -86,14 +86,10 @@ def _should_sync(room) -> bool:
MEETING_DEFAULT_DURATION = timedelta(hours=1)
async def create_upcoming_meetings_for_event(
session: AsyncSession, event, create_window, room_id, room
):
async def create_upcoming_meetings_for_event(event, create_window, room_id, room):
if event.start_time <= create_window:
return
existing_meeting = await meetings_controller.get_by_calendar_event(
session, event.id
)
existing_meeting = await meetings_controller.get_by_calendar_event(event.id)
if existing_meeting:
return
@@ -116,7 +112,6 @@ async def create_upcoming_meetings_for_event(
await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
meeting = await meetings_controller.create(
session,
id=whereby_meeting["meetingId"],
room_name=whereby_meeting["roomName"],
room_url=whereby_meeting["roomUrl"],
@@ -147,9 +142,9 @@ async def create_upcoming_meetings_for_event(
)
@taskiq_broker.task
@with_session
async def create_upcoming_meetings(session: AsyncSession):
@shared_task
@asynctask
async def create_upcoming_meetings():
async with RedisAsyncLock("create_upcoming_meetings", skip_if_locked=True) as lock:
if not lock.acquired:
logger.warning(
@@ -160,20 +155,19 @@ async def create_upcoming_meetings(session: AsyncSession):
try:
logger.info("Starting creation of upcoming meetings")
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
ics_enabled_rooms = await rooms_controller.get_ics_enabled()
now = datetime.now(timezone.utc)
create_window = now - timedelta(minutes=6)
for room in ics_enabled_rooms:
events = await calendar_events_controller.get_upcoming(
session,
room.id,
minutes_ahead=7,
)
for event in events:
await create_upcoming_meetings_for_event(
session, event, create_window, room.id, room
event, create_window, room.id, room
)
logger.info("Completed pre-creation check for upcoming meetings")

View File

@@ -6,22 +6,22 @@ from urllib.parse import unquote
import av
import boto3
import structlog
from celery import shared_task
from celery.utils.log import get_task_logger
from pydantic import ValidationError
from redis.exceptions import LockError
from sqlalchemy.ext.asyncio import AsyncSession
from reflector.db.meetings import meetings_controller
from reflector.db.recordings import Recording, recordings_controller
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import SourceKind, transcripts_controller
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
from reflector.pipelines.main_live_pipeline import asynctask
from reflector.redis_cache import get_redis_client
from reflector.settings import settings
from reflector.whereby import get_room_sessions
from reflector.worker.app import taskiq_broker
from reflector.worker.session_decorator import with_session
logger = structlog.get_logger(__name__)
logger = structlog.wrap_logger(get_task_logger(__name__))
def parse_datetime_with_timezone(iso_string: str) -> datetime:
@@ -32,8 +32,8 @@ def parse_datetime_with_timezone(iso_string: str) -> datetime:
return dt
@taskiq_broker.task
async def process_messages():
@shared_task
def process_messages():
queue_url = settings.AWS_PROCESS_RECORDING_QUEUE_URL
if not queue_url:
logger.warning("No process recording queue url")
@@ -64,7 +64,7 @@ async def process_messages():
if record["eventName"].startswith("ObjectCreated"):
bucket = record["s3"]["bucket"]["name"]
key = unquote(record["s3"]["object"]["key"])
await process_recording.kiq(bucket, key)
process_recording.delay(bucket, key)
sqs.delete_message(QueueUrl=queue_url, ReceiptHandle=receipt_handle)
logger.info("Processed and deleted message: %s", message)
@@ -73,40 +73,32 @@ async def process_messages():
logger.error("process_messages", error=str(e))
@taskiq_broker.task
@with_session
async def process_recording(session: AsyncSession, bucket_name: str, object_key: str):
@shared_task
@asynctask
async def process_recording(bucket_name: str, object_key: str):
logger.info("Processing recording: %s/%s", bucket_name, object_key)
# extract a guid and a datetime from the object key
room_name = f"/{object_key[:36]}"
recorded_at = parse_datetime_with_timezone(object_key[37:57])
meeting = await meetings_controller.get_by_room_name(session, room_name)
if not meeting:
logger.warning("Room not found, may be deleted ?", room_name=room_name)
return
meeting = await meetings_controller.get_by_room_name(room_name)
room = await rooms_controller.get_by_id(meeting.room_id)
room = await rooms_controller.get_by_id(session, meeting.room_id)
recording = await recordings_controller.get_by_object_key(
session, bucket_name, object_key
)
recording = await recordings_controller.get_by_object_key(bucket_name, object_key)
if not recording:
recording = await recordings_controller.create(
session,
Recording(
bucket_name=bucket_name,
object_key=object_key,
recorded_at=recorded_at,
meeting_id=meeting.id,
),
)
)
transcript = await transcripts_controller.get_by_recording_id(session, recording.id)
transcript = await transcripts_controller.get_by_recording_id(recording.id)
if transcript:
await transcripts_controller.update(
session,
transcript,
{
"topics": [],
@@ -114,7 +106,6 @@ async def process_recording(session: AsyncSession, bucket_name: str, object_key:
)
else:
transcript = await transcripts_controller.add(
session,
"",
source_kind=SourceKind.ROOM,
source_language="en",
@@ -150,14 +141,14 @@ async def process_recording(session: AsyncSession, bucket_name: str, object_key:
finally:
container.close()
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
await transcripts_controller.update(transcript, {"status": "uploaded"})
await task_pipeline_file_process.kiq(transcript_id=transcript.id)
task_pipeline_file_process.delay(transcript_id=transcript.id)
@taskiq_broker.task
@with_session
async def process_meetings(session: AsyncSession):
@shared_task
@asynctask
async def process_meetings():
"""
Checks which meetings are still active and deactivates those that have ended.
@@ -174,7 +165,7 @@ async def process_meetings(session: AsyncSession):
process the same meeting simultaneously.
"""
logger.info("Processing meetings")
meetings = await meetings_controller.get_all_active(session)
meetings = await meetings_controller.get_all_active()
current_time = datetime.now(timezone.utc)
redis_client = get_redis_client()
processed_count = 0
@@ -227,9 +218,7 @@ async def process_meetings(session: AsyncSession):
logger_.debug("Meeting not yet started, keep it")
if should_deactivate:
await meetings_controller.update_meeting(
session, meeting.id, is_active=False
)
await meetings_controller.update_meeting(meeting.id, is_active=False)
logger_.info("Meeting is deactivated")
processed_count += 1
@@ -249,9 +238,9 @@ async def process_meetings(session: AsyncSession):
)
@taskiq_broker.task
@with_session
async def reprocess_failed_recordings(session: AsyncSession):
@shared_task
@asynctask
async def reprocess_failed_recordings():
"""
Find recordings in the S3 bucket and check if they have proper transcriptions.
If not, requeue them for processing.
@@ -282,30 +271,28 @@ async def reprocess_failed_recordings(session: AsyncSession):
continue
recording = await recordings_controller.get_by_object_key(
session, bucket_name, object_key
bucket_name, object_key
)
if not recording:
logger.info(f"Queueing recording for processing: {object_key}")
await process_recording.kiq(bucket_name, object_key)
process_recording.delay(bucket_name, object_key)
reprocessed_count += 1
continue
transcript = None
try:
transcript = await transcripts_controller.get_by_recording_id(
session, recording.id
recording.id
)
except ValidationError:
await transcripts_controller.remove_by_recording_id(
session, recording.id
)
await transcripts_controller.remove_by_recording_id(recording.id)
logger.warning(
f"Removed invalid transcript for recording: {recording.id}"
)
if transcript is None or transcript.status == "error":
logger.info(f"Queueing recording for processing: {object_key}")
await process_recording.kiq(bucket_name, object_key)
process_recording.delay(bucket_name, object_key)
reprocessed_count += 1
except Exception as e:

View File

@@ -1,109 +0,0 @@
"""
Session management decorator for async worker tasks.
This decorator ensures that all worker tasks have a properly managed database session
that stays open for the entire duration of the task execution.
"""
import functools
from typing import Any, Callable, TypeVar
from reflector.db import get_session_context
from reflector.db.transcripts import transcripts_controller
from reflector.logger import logger
F = TypeVar("F", bound=Callable[..., Any])
def with_session(func: F) -> F:
"""
Decorator that provides an AsyncSession as the first argument to the decorated function.
This should be used with TaskIQ tasks to ensure proper session management
throughout the task execution.
Example:
@taskiq_broker.task
@with_session
async def my_task(session: AsyncSession, arg1: str, arg2: int):
# session is automatically provided and managed
result = await some_controller.get_by_id(session, arg1)
...
"""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
async with get_session_context() as session:
# Pass session as first argument to the decorated function
return await func(session, *args, **kwargs)
return wrapper
def with_session_and_transcript(func: F) -> F:
"""
Decorator that provides both an AsyncSession and a Transcript to the decorated function.
This decorator:
1. Extracts transcript_id from kwargs
2. Creates and manages a database session
3. Fetches the transcript using the session
4. Creates an enhanced logger with Celery task context
5. Passes session, transcript, and logger to the decorated function
This should be used with TaskIQ tasks.
Example:
@taskiq_broker.task
@with_session_and_transcript
async def my_task(session: AsyncSession, transcript: Transcript, logger: Logger, arg1: str):
# session, transcript, and logger are automatically provided
room = await rooms_controller.get_by_id(session, transcript.room_id)
...
"""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
transcript_id = kwargs.pop("transcript_id", None)
if not transcript_id:
raise ValueError(
"transcript_id is required for @with_session_and_transcript"
)
async with get_session_context() as session:
transcript = await transcripts_controller.get_by_id(session, transcript_id)
if not transcript:
raise Exception(f"Transcript {transcript_id} not found")
tlogger = logger.bind(transcript_id=transcript.id)
try:
return await func(
session, transcript=transcript, logger=tlogger, *args, **kwargs
)
except Exception:
tlogger.exception("Error in task execution")
raise
return wrapper
def catch_exception(func: F) -> F:
"""
Decorator that catches exceptions and logs them using structlog.
"""
@functools.wraps(func)
async def wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except Exception:
logger.exception(
"Exception caught in function execution",
func=func.__name__,
args=args,
kwargs=kwargs,
)
raise
return wrapper

View File

@@ -1,76 +0,0 @@
"""
TaskIQ broker configuration for Reflector.
This module provides a production-ready TaskIQ broker configuration that handles
both test and production environments correctly. It includes retry middleware
for 1:1 parity with Celery and proper logging setup.
"""
import os
import structlog
from taskiq import InMemoryBroker
from taskiq.middlewares import SimpleRetryMiddleware
from taskiq_redis import RedisAsyncResultBackend, RedisStreamBroker
from reflector.settings import settings
logger = structlog.get_logger(__name__)
def create_taskiq_broker():
"""
Create and configure the TaskIQ broker based on environment.
Returns:
Configured TaskIQ broker instance with appropriate backend and middleware.
"""
env = os.environ.get("ENVIRONMENT")
if env == "pytest":
# Test environment: Use InMemoryBroker with immediate execution
logger.info("Configuring TaskIQ InMemoryBroker for test environment")
broker = InMemoryBroker(await_inplace=True)
else:
# Production environment: Use Redis broker with result backend
logger.info(
"Configuring TaskIQ RedisStreamBroker for production environment",
redis_url=settings.CELERY_BROKER_URL,
)
# Configure Redis result backend
result_backend = RedisAsyncResultBackend(
redis_url=settings.CELERY_BROKER_URL,
result_ex_time=86400, # Results expire after 24 hours
)
# Configure Redis stream broker
broker = RedisStreamBroker(
url=settings.CELERY_BROKER_URL,
stream_name="taskiq:stream", # Custom stream name for clarity
consumer_group="taskiq:workers", # Consumer group for load balancing
).with_result_backend(result_backend)
# Add retry middleware for production parity with Celery
# This provides automatic retries on task failures
retry_middleware = SimpleRetryMiddleware(
default_retry_count=3, # Match Celery's default retry behavior
)
broker.add_middlewares(retry_middleware)
logger.info(
"TaskIQ broker configured successfully",
broker_type=type(broker).__name__,
has_result_backend=hasattr(broker, "_result_backend"),
middleware_count=len(broker.middlewares),
)
return broker
# Create the global broker instance
taskiq_broker = create_taskiq_broker()
# Export the broker for use in task definitions
__all__ = ["taskiq_broker"]

View File

@@ -8,16 +8,18 @@ from datetime import datetime, timezone
import httpx
import structlog
from sqlalchemy.ext.asyncio import AsyncSession
from celery import shared_task
from celery.utils.log import get_task_logger
from reflector.db.calendar_events import calendar_events_controller
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller
from reflector.db.transcripts import transcripts_controller
from reflector.pipelines.main_live_pipeline import asynctask
from reflector.settings import settings
from reflector.utils.webvtt import topics_to_webvtt
from reflector.worker.app import taskiq_broker
from reflector.worker.session_decorator import with_session
logger = structlog.get_logger(__name__)
logger = structlog.wrap_logger(get_task_logger(__name__))
def generate_webhook_signature(payload: bytes, secret: str, timestamp: str) -> str:
@@ -31,29 +33,34 @@ def generate_webhook_signature(payload: bytes, secret: str, timestamp: str) -> s
return hmac_obj.hexdigest()
@taskiq_broker.task
@with_session
async def send_transcript_webhook_taskiq(
@shared_task(
bind=True,
max_retries=30,
default_retry_delay=60,
retry_backoff=True,
retry_backoff_max=3600, # Max 1 hour between retries
)
@asynctask
async def send_transcript_webhook(
self,
transcript_id: str,
room_id: str,
event_id: str,
session: AsyncSession,
):
retry_count = 0
log = logger.bind(
transcript_id=transcript_id,
room_id=room_id,
retry_count=retry_count,
retry_count=self.request.retries,
)
try:
transcript = await transcripts_controller.get_by_id(session, transcript_id)
# Fetch transcript and room
transcript = await transcripts_controller.get_by_id(transcript_id)
if not transcript:
log.error("Transcript not found, skipping webhook")
return
room = await rooms_controller.get_by_id(session, room_id)
room = await rooms_controller.get_by_id(room_id)
if not room:
log.error("Room not found, skipping webhook")
return
@@ -62,9 +69,11 @@ async def send_transcript_webhook_taskiq(
log.info("No webhook URL configured for room, skipping")
return
# Generate WebVTT content from topics
topics_data = []
if transcript.topics:
# Build topics data with diarized content per topic
for topic in transcript.topics:
topic_webvtt = topics_to_webvtt([topic]) if topic.words else ""
topics_data.append(
@@ -77,6 +86,19 @@ async def send_transcript_webhook_taskiq(
}
)
# Fetch meeting and calendar event if they exist
calendar_event = None
try:
if transcript.meeting_id:
meeting = await meetings_controller.get_by_id(transcript.meeting_id)
if meeting and meeting.calendar_event_id:
calendar_event = await calendar_events_controller.get_by_id(
meeting.calendar_event_id
)
except Exception as e:
logger.error("Error fetching meeting or calendar event", error=str(e))
# Build webhook payload
frontend_url = f"{settings.UI_BASE_URL}/transcripts/{transcript.id}"
participants = [
{"id": p.id, "name": p.name, "speaker": p.speaker}
@@ -108,14 +130,43 @@ async def send_transcript_webhook_taskiq(
},
}
# Always include calendar_event field, even if no event is present
payload_data["calendar_event"] = {}
# Add calendar event data if present
if calendar_event:
calendar_data = {
"id": calendar_event.id,
"ics_uid": calendar_event.ics_uid,
"title": calendar_event.title,
"start_time": calendar_event.start_time.isoformat()
if calendar_event.start_time
else None,
"end_time": calendar_event.end_time.isoformat()
if calendar_event.end_time
else None,
}
# Add optional fields only if they exist
if calendar_event.description:
calendar_data["description"] = calendar_event.description
if calendar_event.location:
calendar_data["location"] = calendar_event.location
if calendar_event.attendees:
calendar_data["attendees"] = calendar_event.attendees
payload_data["calendar_event"] = calendar_data
# Convert to JSON
payload_json = json.dumps(payload_data, separators=(",", ":"))
payload_bytes = payload_json.encode("utf-8")
# Generate signature if secret is configured
headers = {
"Content-Type": "application/json",
"User-Agent": "Reflector-Webhook/1.0",
"X-Webhook-Event": "transcript.completed",
"X-Webhook-Retry": str(retry_count),
"X-Webhook-Retry": str(self.request.retries),
}
if room.webhook_secret:
@@ -125,6 +176,7 @@ async def send_transcript_webhook_taskiq(
)
headers["X-Webhook-Signature"] = f"t={timestamp},v1={signature}"
# Send webhook with timeout
async with httpx.AsyncClient(timeout=30.0) as client:
log.info(
"Sending webhook",
@@ -150,22 +202,26 @@ async def send_transcript_webhook_taskiq(
log.error(
"Webhook failed with HTTP error",
status_code=e.response.status_code,
response_text=e.response.text[:500],
response_text=e.response.text[:500], # First 500 chars
)
# Don't retry on client errors (4xx)
if 400 <= e.response.status_code < 500:
log.error("Client error, not retrying")
return
raise
# Retry on server errors (5xx)
raise self.retry(exc=e)
except (httpx.ConnectError, httpx.TimeoutException) as e:
# Retry on network errors
log.error("Webhook failed with connection error", error=str(e))
raise
raise self.retry(exc=e)
except Exception as e:
# Retry on unexpected errors
log.exception("Unexpected error in webhook task", error=str(e))
raise
raise self.retry(exc=e)
async def test_webhook(room_id: str) -> dict:

View File

@@ -65,7 +65,12 @@ class WebsocketManager:
self.tasks: dict = {}
self.pubsub_client = pubsub_client
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
async def add_user_to_room(
self, room_id: str, websocket: WebSocket, subprotocol: str | None = None
) -> None:
if subprotocol:
await websocket.accept(subprotocol=subprotocol)
else:
await websocket.accept()
if room_id in self.rooms:

View File

@@ -1,86 +0,0 @@
# TaskIQ Migration Implementation Plan
## Phase 1: Core Infrastructure Setup
### 1.1 Create TaskIQ Broker Configuration
- [ ] Create `reflector/worker/taskiq_broker.py` with broker setup
- [ ] Configure Redis broker with proper connection pooling
- [ ] Add retry middleware for 1:1 parity with Celery
- [ ] Setup test/production environment detection
### 1.2 Session Management Utilities
- [ ] Create `get_session_context()` function in `reflector/db.py`
- [ ] Ensure `@with_session` decorator works with TaskIQ
- [ ] Verify test mocking works with new session approach
## Phase 2: Simple Task Migration (Start Small)
### 2.1 Migrate Single Tasks First
- [ ] `reflector/worker/cleanup.py` - 1 task, simple logic
- [ ] `reflector/worker/webhook.py` - 1 task with retry logic
- [ ] Test each migrated task individually
### 2.2 Create Dual-Mode Tasks
- [ ] Keep Celery version with `@shared_task`
- [ ] Add TaskIQ version without `@asynctask`
- [ ] Use feature flag to switch between versions
## Phase 3: Complex Pipeline Migration
### 3.1 File Processing Pipeline
- [ ] Migrate `task_pipeline_file_process` completely
- [ ] Handle all sub-tasks in the pipeline
- [ ] Migrate chain/group/chord patterns to TaskIQ
### 3.2 Live Processing Pipeline
- [ ] Migrate all 10 tasks in `main_live_pipeline.py`
- [ ] Convert complex chord patterns
- [ ] Ensure WebSocket notifications still work
## Phase 4: Scheduled Tasks Migration
### 4.1 Convert Celery Beat to TaskIQ Scheduler
- [ ] Create `reflector/worker/scheduler.py`
- [ ] Migrate all scheduled tasks
- [ ] Setup TaskIQ scheduler service
## Phase 5: Testing Infrastructure
### 5.1 Update Test Fixtures
- [ ] Create TaskIQ test fixtures in `conftest.py`
- [ ] Ensure dual-mode testing (both Celery and TaskIQ)
- [ ] Verify all existing tests pass
### 5.2 Migration-Specific Tests
- [ ] Test session management across tasks
- [ ] Test retry logic parity
- [ ] Test scheduled task execution
## Phase 6: Deployment & Monitoring
### 6.1 Update Deployment Scripts
- [ ] Update Docker configurations
- [ ] Create TaskIQ worker startup scripts
- [ ] Setup health checks for TaskIQ
### 6.2 Monitoring Setup
- [ ] Create TaskIQ metrics collection
- [ ] Setup alerting for failed tasks
- [ ] Create migration rollback plan
## Execution Order
1. **Week 1**: Phase 1 + Phase 2.1
2. **Week 2**: Phase 2.2 + Phase 3.1
3. **Week 3**: Phase 3.2 + Phase 4
4. **Week 4**: Phase 5
5. **Week 5**: Phase 6 + Testing
6. **Week 6**: Cutover + Monitoring
## Success Metrics
- All tests passing with TaskIQ
- No performance degradation
- Successful parallel running for 1 week
- Zero data loss during migration
- Rollback tested and documented

View File

@@ -1,21 +1,11 @@
import asyncio
import os
import sys
from contextlib import asynccontextmanager
from tempfile import NamedTemporaryFile
from unittest.mock import patch
import pytest
@pytest.fixture(scope="session")
def event_loop():
if sys.platform.startswith("win") and sys.version_info[:2] >= (3, 8):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
loop = asyncio.new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="session", autouse=True)
def settings_configuration():
# theses settings are linked to monadical for pytest-recording
@@ -46,6 +36,7 @@ def docker_compose_file(pytestconfig):
@pytest.fixture(scope="session")
def postgres_service(docker_ip, docker_services):
"""Ensure that PostgreSQL service is up and responsive."""
port = docker_services.port_for("postgres_test", 5432)
def is_responsive():
@@ -66,6 +57,7 @@ def postgres_service(docker_ip, docker_services):
docker_services.wait_until_responsive(timeout=30.0, pause=0.1, check=is_responsive)
# Return connection parameters
return {
"host": docker_ip,
"port": port,
@@ -75,27 +67,20 @@ def postgres_service(docker_ip, docker_services):
}
@pytest.fixture(scope="session")
def _database_url(postgres_service):
db_config = postgres_service
DATABASE_URL = (
f"postgresql+asyncpg://{db_config['user']}:{db_config['password']}"
f"@{db_config['host']}:{db_config['port']}/{db_config['dbname']}"
)
@pytest.fixture(scope="function", autouse=True)
@pytest.mark.asyncio
async def setup_database(postgres_service):
from reflector.db import engine, metadata, get_database # noqa
# Override settings
from reflector.settings import settings
metadata.drop_all(bind=engine)
metadata.create_all(bind=engine)
database = get_database()
settings.DATABASE_URL = DATABASE_URL
return DATABASE_URL
@pytest.fixture(scope="session")
def init_database():
from reflector.db import Base
return Base.metadata.create_all
try:
await database.connect()
yield
finally:
await database.disconnect()
@pytest.fixture
@@ -321,96 +306,30 @@ async def dummy_storage():
yield
# from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
# from sqlalchemy.orm import sessionmaker
@pytest.fixture(scope="session")
def celery_enable_logging():
return True
# @pytest.fixture()
# async def db_connection(sqla_engine):
# connection = await sqla_engine.connect()
# try:
# yield connection
# finally:
# await connection.close()
@pytest.fixture(scope="session")
def celery_config():
with NamedTemporaryFile() as f:
yield {
"broker_url": "memory://",
"result_backend": f"db+sqlite:///{f.name}",
}
# @pytest.fixture()
# async def db_session_maker(db_connection):
# Session = async_sessionmaker(
# db_connection,
# expire_on_commit=False,
# class_=AsyncSession,
# )
# yield Session
# @pytest.fixture()
# async def db_session(db_session_maker, db_connection):
# """
# Fixture that returns a SQLAlchemy session with a SAVEPOINT, and the rollback to it
# after the test completes.
# """
# session = db_session_maker(
# bind=db_connection,
# join_transaction_mode="create_savepoint",
# )
# try:
# yield session
# finally:
# await session.close()
# @pytest.fixture(autouse=True)
# async def ensure_db_session_in_app(db_connection, db_session_maker):
# async def mock_get_session():
# session = db_session_maker(
# bind=db_connection, join_transaction_mode="create_savepoint"
# )
# try:
# yield session
# finally:
# await session.close()
# with patch("reflector.db._get_session", side_effect=mock_get_session):
# yield
# @pytest.fixture()
# async def db_session(sqla_engine):
# """
# Fixture that returns a SQLAlchemy session with a SAVEPOINT, and the rollback to it
# after the test completes.
# """
# from sqlalchemy.ext.asyncio import AsyncSession
# from sqlalchemy.orm import sessionmaker
# connection = await sqla_engine.connect()
# trans = await connection.begin()
# Session = sessionmaker(connection, expire_on_commit=False, class_=AsyncSession)
# session = Session()
# try:
# yield session
# finally:
# await session.close()
# await trans.rollback()
# await connection.close()
@pytest.fixture(autouse=True)
async def ensure_db_session_in_app(db_session):
async def mock_get_session():
yield db_session
with patch("reflector.db._get_session", side_effect=mock_get_session):
yield
@pytest.fixture(scope="session")
def celery_includes():
return [
"reflector.pipelines.main_live_pipeline",
"reflector.pipelines.main_file_pipeline",
]
@pytest.fixture
async def client(db_session):
async def client():
from httpx import AsyncClient
from reflector.app import app
@@ -419,6 +338,166 @@ async def client(db_session):
yield ac
@pytest.fixture(autouse=True)
async def ws_manager_in_memory(monkeypatch):
"""Replace Redis-based WS manager with an in-memory implementation for tests."""
import asyncio
import json
from reflector.ws_manager import WebsocketManager
class _InMemorySubscriber:
def __init__(self, queue: asyncio.Queue):
self.queue = queue
async def get_message(self, ignore_subscribe_messages: bool = True):
try:
return await asyncio.wait_for(self.queue.get(), timeout=0.05)
except Exception:
return None
class InMemoryPubSubManager:
def __init__(self):
self.queues: dict[str, asyncio.Queue] = {}
self.connected = False
async def connect(self) -> None:
self.connected = True
async def disconnect(self) -> None:
self.connected = False
async def send_json(self, room_id: str, message: dict) -> None:
if room_id not in self.queues:
self.queues[room_id] = asyncio.Queue()
payload = json.dumps(message).encode("utf-8")
await self.queues[room_id].put(
{"channel": room_id.encode("utf-8"), "data": payload}
)
async def subscribe(self, room_id: str):
if room_id not in self.queues:
self.queues[room_id] = asyncio.Queue()
return _InMemorySubscriber(self.queues[room_id])
async def unsubscribe(self, room_id: str) -> None:
# keep queue for potential later resubscribe within same test
pass
pubsub = InMemoryPubSubManager()
ws_manager = WebsocketManager(pubsub_client=pubsub)
def _get_ws_manager():
return ws_manager
# Patch all places that imported get_ws_manager at import time
monkeypatch.setattr("reflector.ws_manager.get_ws_manager", _get_ws_manager)
monkeypatch.setattr(
"reflector.pipelines.main_live_pipeline.get_ws_manager", _get_ws_manager
)
monkeypatch.setattr(
"reflector.views.transcripts_websocket.get_ws_manager", _get_ws_manager
)
monkeypatch.setattr(
"reflector.views.user_websocket.get_ws_manager", _get_ws_manager
)
monkeypatch.setattr("reflector.views.transcripts.get_ws_manager", _get_ws_manager)
# Websocket auth: avoid OAuth2 on websocket dependencies; allow anonymous
import reflector.auth as auth
# Ensure FastAPI uses our override for routes that captured the original callable
from reflector.app import app as fastapi_app
try:
fastapi_app.dependency_overrides[auth.current_user_optional] = lambda: None
except Exception:
pass
# Stub Redis cache used by profanity filter to avoid external Redis
from reflector import redis_cache as rc
class _FakeRedis:
def __init__(self):
self._data = {}
def get(self, key):
value = self._data.get(key)
if value is None:
return None
if isinstance(value, bytes):
return value
return str(value).encode("utf-8")
def setex(self, key, duration, value):
# ignore duration for tests
if isinstance(value, bytes):
self._data[key] = value
else:
self._data[key] = str(value).encode("utf-8")
fake_redises: dict[int, _FakeRedis] = {}
def _get_redis_client(db=0):
if db not in fake_redises:
fake_redises[db] = _FakeRedis()
return fake_redises[db]
monkeypatch.setattr(rc, "get_redis_client", _get_redis_client)
yield
@pytest.fixture
@pytest.mark.asyncio
async def authenticated_client():
async with authenticated_client_ctx():
yield
@pytest.fixture
@pytest.mark.asyncio
async def authenticated_client2():
async with authenticated_client2_ctx():
yield
@asynccontextmanager
async def authenticated_client_ctx():
from reflector.app import app
from reflector.auth import current_user, current_user_optional
app.dependency_overrides[current_user] = lambda: {
"sub": "randomuserid",
"email": "test@mail.com",
}
app.dependency_overrides[current_user_optional] = lambda: {
"sub": "randomuserid",
"email": "test@mail.com",
}
yield
del app.dependency_overrides[current_user]
del app.dependency_overrides[current_user_optional]
@asynccontextmanager
async def authenticated_client2_ctx():
from reflector.app import app
from reflector.auth import current_user, current_user_optional
app.dependency_overrides[current_user] = lambda: {
"sub": "randomuserid2",
"email": "test@mail.com",
}
app.dependency_overrides[current_user_optional] = lambda: {
"sub": "randomuserid2",
"email": "test@mail.com",
}
yield
del app.dependency_overrides[current_user]
del app.dependency_overrides[current_user_optional]
@pytest.fixture(scope="session")
def fake_mp3_upload():
with patch(
@@ -429,19 +508,7 @@ def fake_mp3_upload():
@pytest.fixture
async def taskiq_broker():
from reflector.worker.app import taskiq_broker
await taskiq_broker.startup()
try:
yield taskiq_broker
finally:
await taskiq_broker.shutdown()
@pytest.fixture
async def fake_transcript_with_topics(tmpdir, client, db_session):
async def fake_transcript_with_topics(tmpdir, client):
import shutil
from pathlib import Path
@@ -457,10 +524,10 @@ async def fake_transcript_with_topics(tmpdir, client, db_session):
assert response.status_code == 200
tid = response.json()["id"]
transcript = await transcripts_controller.get_by_id(db_session, tid)
transcript = await transcripts_controller.get_by_id(tid)
assert transcript is not None
await transcripts_controller.update(db_session, transcript, {"status": "ended"})
await transcripts_controller.update(transcript, {"status": "ended"})
# manually copy a file at the expected location
audio_filename = transcript.audio_mp3_filename
@@ -470,7 +537,6 @@ async def fake_transcript_with_topics(tmpdir, client, db_session):
# create some topics
await transcripts_controller.upsert_topic(
db_session,
transcript,
TranscriptTopic(
title="Topic 1",
@@ -484,7 +550,6 @@ async def fake_transcript_with_topics(tmpdir, client, db_session):
),
)
await transcripts_controller.upsert_topic(
db_session,
transcript,
TranscriptTopic(
title="Topic 2",

View File

@@ -1,5 +1,5 @@
import os
from unittest.mock import patch
from unittest.mock import AsyncMock, patch
import pytest
@@ -8,7 +8,7 @@ from reflector.services.ics_sync import ICSSyncService
@pytest.mark.asyncio
async def test_attendee_parsing_bug(db_session):
async def test_attendee_parsing_bug():
"""
Test that reproduces the attendee parsing bug where a string with comma-separated
emails gets parsed as individual characters instead of separate email addresses.
@@ -16,8 +16,8 @@ async def test_attendee_parsing_bug(db_session):
The bug manifests as getting 29 attendees with emails like "M", "A", "I", etc.
instead of properly parsed email addresses.
"""
# Create a test room
room = await rooms_controller.add(
db_session,
name="test-room",
user_id="test-user",
zulip_auto_post=False,
@@ -31,8 +31,8 @@ async def test_attendee_parsing_bug(db_session):
ics_url="http://test.com/test.ics",
ics_enabled=True,
)
await db_session.flush()
# Read the test ICS file that reproduces the bug and update it with current time
from datetime import datetime, timedelta, timezone
test_ics_path = os.path.join(
@@ -41,26 +41,30 @@ async def test_attendee_parsing_bug(db_session):
with open(test_ics_path, "r") as f:
ics_content = f.read()
# Replace the dates with current time + 1 hour to ensure it's within the 24h window
now = datetime.now(timezone.utc)
future_time = now + timedelta(hours=1)
end_time = future_time + timedelta(hours=1)
# Format dates for ICS format
dtstart = future_time.strftime("%Y%m%dT%H%M%SZ")
dtend = end_time.strftime("%Y%m%dT%H%M%SZ")
dtstamp = now.strftime("%Y%m%dT%H%M%SZ")
# Update the ICS content with current dates
ics_content = ics_content.replace("20250910T180000Z", dtstart)
ics_content = ics_content.replace("20250910T190000Z", dtend)
ics_content = ics_content.replace("20250910T174000Z", dtstamp)
# Create sync service and mock the fetch
sync_service = ICSSyncService()
from unittest.mock import AsyncMock
with patch.object(
sync_service.fetch_service, "fetch_ics", new_callable=AsyncMock
) as mock_fetch:
mock_fetch.return_value = ics_content
# Debug: Parse the ICS content directly to examine attendee parsing
calendar = sync_service.fetch_service.parse_ics(ics_content)
from reflector.settings import settings
@@ -76,23 +80,113 @@ async def test_attendee_parsing_bug(db_session):
print(f"Total events in calendar: {total_events}")
print(f"Events matching room: {len(events)}")
result = await sync_service.sync_room_calendar(db_session, room)
# Perform the sync
result = await sync_service.sync_room_calendar(room)
# Check that the sync succeeded
assert result.get("status") == "success"
assert result.get("events_found", 0) >= 0
assert result.get("events_found", 0) >= 0 # Allow for debugging
# We already have the matching events from the debug code above
assert len(events) == 1
event = events[0]
# This is where the bug manifests - check the attendees
attendees = event["attendees"]
print(f"Number of attendees: {len(attendees)}")
# Print attendee info for debugging
print(f"Number of attendees found: {len(attendees)}")
for i, attendee in enumerate(attendees):
print(f"Attendee {i}: {attendee}")
print(
f"Attendee {i}: email='{attendee.get('email')}', name='{attendee.get('name')}'"
)
assert len(attendees) == 30, f"Expected 30 attendees, got {len(attendees)}"
# With the fix, we should now get properly parsed email addresses
# Check that no single characters are parsed as emails
single_char_emails = [
att for att in attendees if att.get("email") and len(att["email"]) == 1
]
assert attendees[0]["email"] == "alice@example.com"
assert attendees[1]["email"] == "bob@example.com"
assert attendees[2]["email"] == "charlie@example.com"
assert any(att["email"] == "organizer@example.com" for att in attendees)
if single_char_emails:
print(
f"BUG DETECTED: Found {len(single_char_emails)} single-character emails:"
)
for att in single_char_emails:
print(f" - '{att['email']}'")
# Should have attendees but not single-character emails
assert len(attendees) > 0
assert (
len(single_char_emails) == 0
), f"Found {len(single_char_emails)} single-character emails, parsing is still buggy"
# Check that all emails are valid (contain @ symbol)
valid_emails = [
att for att in attendees if att.get("email") and "@" in att["email"]
]
assert len(valid_emails) == len(
attendees
), "Some attendees don't have valid email addresses"
# We expect around 29 attendees (28 from the comma-separated list + 1 organizer)
assert (
len(attendees) >= 25
), f"Expected around 29 attendees, got {len(attendees)}"
@pytest.mark.asyncio
async def test_correct_attendee_parsing():
"""
Test what correct attendee parsing should look like.
"""
from datetime import datetime, timezone
from icalendar import Event
from reflector.services.ics_sync import ICSFetchService
service = ICSFetchService()
# Create a properly formatted event with multiple attendees
event = Event()
event.add("uid", "test-correct-attendees")
event.add("summary", "Test Meeting")
event.add("location", "http://test.com/test")
event.add("dtstart", datetime.now(timezone.utc))
event.add("dtend", datetime.now(timezone.utc))
# Add attendees the correct way (separate ATTENDEE lines)
event.add("attendee", "mailto:alice@example.com", parameters={"CN": "Alice"})
event.add("attendee", "mailto:bob@example.com", parameters={"CN": "Bob"})
event.add("attendee", "mailto:charlie@example.com", parameters={"CN": "Charlie"})
event.add(
"organizer", "mailto:organizer@example.com", parameters={"CN": "Organizer"}
)
# Parse the event
result = service._parse_event(event)
assert result is not None
attendees = result["attendees"]
# Should have 4 attendees (3 attendees + 1 organizer)
assert len(attendees) == 4
# Check that all emails are valid email addresses
emails = [att["email"] for att in attendees if att.get("email")]
expected_emails = [
"alice@example.com",
"bob@example.com",
"charlie@example.com",
"organizer@example.com",
]
for email in emails:
assert "@" in email, f"Invalid email format: {email}"
assert len(email) > 5, f"Email too short: {email}"
# Check that we have the expected emails
assert "alice@example.com" in emails
assert "bob@example.com" in emails
assert "charlie@example.com" in emails
assert "organizer@example.com" in emails

View File

@@ -11,11 +11,10 @@ from reflector.db.rooms import rooms_controller
@pytest.mark.asyncio
async def test_calendar_event_create(db_session):
async def test_calendar_event_create():
"""Test creating a calendar event."""
# Create a room first
room = await rooms_controller.add(
db_session,
name="test-room",
user_id="test-user",
zulip_auto_post=False,
@@ -45,7 +44,7 @@ async def test_calendar_event_create(db_session):
)
# Save event
saved_event = await calendar_events_controller.upsert(db_session, event)
saved_event = await calendar_events_controller.upsert(event)
assert saved_event.ics_uid == "test-event-123"
assert saved_event.title == "Team Meeting"
@@ -54,11 +53,10 @@ async def test_calendar_event_create(db_session):
@pytest.mark.asyncio
async def test_calendar_event_get_by_room(db_session):
async def test_calendar_event_get_by_room():
"""Test getting calendar events for a room."""
# Create room
room = await rooms_controller.add(
db_session,
name="events-room",
user_id="test-user",
zulip_auto_post=False,
@@ -82,10 +80,10 @@ async def test_calendar_event_get_by_room(db_session):
start_time=now + timedelta(hours=i),
end_time=now + timedelta(hours=i + 1),
)
await calendar_events_controller.upsert(db_session, event)
await calendar_events_controller.upsert(event)
# Get events for room
events = await calendar_events_controller.get_by_room(db_session, room.id)
events = await calendar_events_controller.get_by_room(room.id)
assert len(events) == 3
assert all(e.room_id == room.id for e in events)
@@ -95,11 +93,10 @@ async def test_calendar_event_get_by_room(db_session):
@pytest.mark.asyncio
async def test_calendar_event_get_upcoming(db_session):
async def test_calendar_event_get_upcoming():
"""Test getting upcoming events within time window."""
# Create room
room = await rooms_controller.add(
db_session,
name="upcoming-room",
user_id="test-user",
zulip_auto_post=False,
@@ -123,7 +120,7 @@ async def test_calendar_event_get_upcoming(db_session):
start_time=now - timedelta(hours=2),
end_time=now - timedelta(hours=1),
)
await calendar_events_controller.upsert(db_session, past_event)
await calendar_events_controller.upsert(past_event)
# Upcoming event within 30 minutes
upcoming_event = CalendarEvent(
@@ -133,7 +130,7 @@ async def test_calendar_event_get_upcoming(db_session):
start_time=now + timedelta(minutes=15),
end_time=now + timedelta(minutes=45),
)
await calendar_events_controller.upsert(db_session, upcoming_event)
await calendar_events_controller.upsert(upcoming_event)
# Currently happening event (started 10 minutes ago, ends in 20 minutes)
current_event = CalendarEvent(
@@ -143,7 +140,7 @@ async def test_calendar_event_get_upcoming(db_session):
start_time=now - timedelta(minutes=10),
end_time=now + timedelta(minutes=20),
)
await calendar_events_controller.upsert(db_session, current_event)
await calendar_events_controller.upsert(current_event)
# Future event beyond 30 minutes
future_event = CalendarEvent(
@@ -153,10 +150,10 @@ async def test_calendar_event_get_upcoming(db_session):
start_time=now + timedelta(hours=2),
end_time=now + timedelta(hours=3),
)
await calendar_events_controller.upsert(db_session, future_event)
await calendar_events_controller.upsert(future_event)
# Get upcoming events (default 120 minutes) - should include current, upcoming, and future
upcoming = await calendar_events_controller.get_upcoming(db_session, room.id)
upcoming = await calendar_events_controller.get_upcoming(room.id)
assert len(upcoming) == 3
# Events should be sorted by start_time (current event first, then upcoming, then future)
@@ -166,7 +163,7 @@ async def test_calendar_event_get_upcoming(db_session):
# Get upcoming with custom window
upcoming_extended = await calendar_events_controller.get_upcoming(
db_session, room.id, minutes_ahead=180
room.id, minutes_ahead=180
)
assert len(upcoming_extended) == 3
@@ -177,11 +174,10 @@ async def test_calendar_event_get_upcoming(db_session):
@pytest.mark.asyncio
async def test_calendar_event_get_upcoming_includes_currently_happening(db_session):
async def test_calendar_event_get_upcoming_includes_currently_happening():
"""Test that get_upcoming includes currently happening events but excludes ended events."""
# Create room
room = await rooms_controller.add(
db_session,
name="current-happening-room",
user_id="test-user",
zulip_auto_post=False,
@@ -204,7 +200,7 @@ async def test_calendar_event_get_upcoming_includes_currently_happening(db_sessi
start_time=now - timedelta(hours=2),
end_time=now - timedelta(minutes=30),
)
await calendar_events_controller.upsert(db_session, past_ended_event)
await calendar_events_controller.upsert(past_ended_event)
# Event currently happening (started 10 minutes ago, ends in 20 minutes) - SHOULD be included
currently_happening_event = CalendarEvent(
@@ -214,7 +210,7 @@ async def test_calendar_event_get_upcoming_includes_currently_happening(db_sessi
start_time=now - timedelta(minutes=10),
end_time=now + timedelta(minutes=20),
)
await calendar_events_controller.upsert(db_session, currently_happening_event)
await calendar_events_controller.upsert(currently_happening_event)
# Event starting soon (in 5 minutes) - SHOULD be included
upcoming_soon_event = CalendarEvent(
@@ -224,12 +220,10 @@ async def test_calendar_event_get_upcoming_includes_currently_happening(db_sessi
start_time=now + timedelta(minutes=5),
end_time=now + timedelta(minutes=35),
)
await calendar_events_controller.upsert(db_session, upcoming_soon_event)
await calendar_events_controller.upsert(upcoming_soon_event)
# Get upcoming events
upcoming = await calendar_events_controller.get_upcoming(
db_session, room.id, minutes_ahead=30
)
upcoming = await calendar_events_controller.get_upcoming(room.id, minutes_ahead=30)
# Should only include currently happening and upcoming soon events
assert len(upcoming) == 2
@@ -238,11 +232,10 @@ async def test_calendar_event_get_upcoming_includes_currently_happening(db_sessi
@pytest.mark.asyncio
async def test_calendar_event_upsert(db_session):
async def test_calendar_event_upsert():
"""Test upserting (create/update) calendar events."""
# Create room
room = await rooms_controller.add(
db_session,
name="upsert-room",
user_id="test-user",
zulip_auto_post=False,
@@ -266,30 +259,29 @@ async def test_calendar_event_upsert(db_session):
end_time=now + timedelta(hours=1),
)
created = await calendar_events_controller.upsert(db_session, event)
created = await calendar_events_controller.upsert(event)
assert created.title == "Original Title"
# Update existing event
event.title = "Updated Title"
event.description = "Added description"
updated = await calendar_events_controller.upsert(db_session, event)
updated = await calendar_events_controller.upsert(event)
assert updated.title == "Updated Title"
assert updated.description == "Added description"
assert updated.ics_uid == "upsert-test"
# Verify only one event exists
events = await calendar_events_controller.get_by_room(db_session, room.id)
events = await calendar_events_controller.get_by_room(room.id)
assert len(events) == 1
assert events[0].title == "Updated Title"
@pytest.mark.asyncio
async def test_calendar_event_soft_delete(db_session):
async def test_calendar_event_soft_delete():
"""Test soft deleting events no longer in calendar."""
# Create room
room = await rooms_controller.add(
db_session,
name="delete-room",
user_id="test-user",
zulip_auto_post=False,
@@ -313,36 +305,35 @@ async def test_calendar_event_soft_delete(db_session):
start_time=now + timedelta(hours=i),
end_time=now + timedelta(hours=i + 1),
)
await calendar_events_controller.upsert(db_session, event)
await calendar_events_controller.upsert(event)
# Soft delete events not in current list
current_ids = ["event-0", "event-2"] # Keep events 0 and 2
deleted_count = await calendar_events_controller.soft_delete_missing(
db_session, room.id, current_ids
room.id, current_ids
)
assert deleted_count == 2 # Should delete events 1 and 3
# Get non-deleted events
events = await calendar_events_controller.get_by_room(
db_session, room.id, include_deleted=False
room.id, include_deleted=False
)
assert len(events) == 2
assert {e.ics_uid for e in events} == {"event-0", "event-2"}
# Get all events including deleted
all_events = await calendar_events_controller.get_by_room(
db_session, room.id, include_deleted=True
room.id, include_deleted=True
)
assert len(all_events) == 4
@pytest.mark.asyncio
async def test_calendar_event_past_events_not_deleted(db_session):
async def test_calendar_event_past_events_not_deleted():
"""Test that past events are not soft deleted."""
# Create room
room = await rooms_controller.add(
db_session,
name="past-events-room",
user_id="test-user",
zulip_auto_post=False,
@@ -365,7 +356,7 @@ async def test_calendar_event_past_events_not_deleted(db_session):
start_time=now - timedelta(hours=2),
end_time=now - timedelta(hours=1),
)
await calendar_events_controller.upsert(db_session, past_event)
await calendar_events_controller.upsert(past_event)
# Create future event
future_event = CalendarEvent(
@@ -375,29 +366,26 @@ async def test_calendar_event_past_events_not_deleted(db_session):
start_time=now + timedelta(hours=1),
end_time=now + timedelta(hours=2),
)
await calendar_events_controller.upsert(db_session, future_event)
await calendar_events_controller.upsert(future_event)
# Try to soft delete all events (only future should be deleted)
deleted_count = await calendar_events_controller.soft_delete_missing(
db_session, room.id, []
)
deleted_count = await calendar_events_controller.soft_delete_missing(room.id, [])
assert deleted_count == 1 # Only future event deleted
# Verify past event still exists
events = await calendar_events_controller.get_by_room(
db_session, room.id, include_deleted=False
room.id, include_deleted=False
)
assert len(events) == 1
assert events[0].ics_uid == "past-event"
@pytest.mark.asyncio
async def test_calendar_event_with_raw_ics_data(db_session):
async def test_calendar_event_with_raw_ics_data():
"""Test storing raw ICS data with calendar event."""
# Create room
room = await rooms_controller.add(
db_session,
name="raw-ics-room",
user_id="test-user",
zulip_auto_post=False,
@@ -426,13 +414,11 @@ END:VEVENT"""
ics_raw_data=raw_ics,
)
saved = await calendar_events_controller.upsert(db_session, event)
saved = await calendar_events_controller.upsert(event)
assert saved.ics_raw_data == raw_ics
# Retrieve and verify
retrieved = await calendar_events_controller.get_by_ics_uid(
db_session, room.id, "test-raw-123"
)
retrieved = await calendar_events_controller.get_by_ics_uid(room.id, "test-raw-123")
assert retrieved is not None
assert retrieved.ics_raw_data == raw_ics

View File

@@ -2,32 +2,26 @@ from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, patch
import pytest
from sqlalchemy import delete, insert, select, update
from reflector.db.base import (
MeetingConsentModel,
MeetingModel,
RecordingModel,
TranscriptModel,
)
from reflector.db.recordings import Recording, recordings_controller
from reflector.db.transcripts import SourceKind, transcripts_controller
from reflector.worker.cleanup import cleanup_old_public_data
@pytest.mark.asyncio
async def test_cleanup_old_public_data_skips_when_not_public(db_session):
async def test_cleanup_old_public_data_skips_when_not_public():
"""Test that cleanup is skipped when PUBLIC_MODE is False."""
with patch("reflector.worker.cleanup.settings") as mock_settings:
mock_settings.PUBLIC_MODE = False
result = await cleanup_old_public_data(db_session)
result = await cleanup_old_public_data()
# Should return early without doing anything
assert result is None
@pytest.mark.asyncio
async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(db_session):
async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts():
"""Test that old anonymous transcripts are deleted."""
# Create old and new anonymous transcripts
old_date = datetime.now(timezone.utc) - timedelta(days=8)
@@ -35,23 +29,22 @@ async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(db_sess
# Create old anonymous transcript (should be deleted)
old_transcript = await transcripts_controller.add(
db_session,
name="Old Anonymous Transcript",
source_kind=SourceKind.FILE,
user_id=None, # Anonymous
)
# Manually update created_at to be old
await db_session.execute(
update(TranscriptModel)
.where(TranscriptModel.id == old_transcript.id)
from reflector.db import get_database
from reflector.db.transcripts import transcripts
await get_database().execute(
transcripts.update()
.where(transcripts.c.id == old_transcript.id)
.values(created_at=old_date)
)
await db_session.commit()
# Create new anonymous transcript (should NOT be deleted)
new_transcript = await transcripts_controller.add(
db_session,
name="New Anonymous Transcript",
source_kind=SourceKind.FILE,
user_id=None, # Anonymous
@@ -59,265 +52,234 @@ async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(db_sess
# Create old transcript with user (should NOT be deleted)
old_user_transcript = await transcripts_controller.add(
db_session,
name="Old User Transcript",
source_kind=SourceKind.FILE,
user_id="user-123",
user_id="user123",
)
await db_session.execute(
update(TranscriptModel)
.where(TranscriptModel.id == old_user_transcript.id)
await get_database().execute(
transcripts.update()
.where(transcripts.c.id == old_user_transcript.id)
.values(created_at=old_date)
)
await db_session.commit()
# Mock settings for public mode
with patch("reflector.worker.cleanup.settings") as mock_settings:
mock_settings.PUBLIC_MODE = True
mock_settings.PUBLIC_DATA_RETENTION_DAYS = 7
# Mock delete_single_transcript to track what gets deleted
with patch("reflector.worker.cleanup.delete_single_transcript") as mock_delete:
mock_delete.return_value = None
# Mock the storage deletion
with patch("reflector.db.transcripts.get_transcripts_storage") as mock_storage:
mock_storage.return_value.delete_file = AsyncMock()
# Run cleanup with test session
await cleanup_old_public_data(db_session)
result = await cleanup_old_public_data()
# Verify only old anonymous transcript was deleted
assert mock_delete.call_count == 1
# The function is called with session_factory, transcript_data dict, and stats dict
call_args = mock_delete.call_args[0]
transcript_data = call_args[1]
assert transcript_data["id"] == old_transcript.id
# Check results
assert result["transcripts_deleted"] == 1
assert result["errors"] == []
# Verify old anonymous transcript was deleted
assert await transcripts_controller.get_by_id(old_transcript.id) is None
# Verify new anonymous transcript still exists
assert await transcripts_controller.get_by_id(new_transcript.id) is not None
# Verify user transcript still exists
assert await transcripts_controller.get_by_id(old_user_transcript.id) is not None
@pytest.mark.asyncio
async def test_cleanup_deletes_associated_meeting_and_recording(db_session):
"""Test that cleanup deletes associated meetings and recordings."""
async def test_cleanup_deletes_associated_meeting_and_recording():
"""Test that meetings and recordings associated with old transcripts are deleted."""
from reflector.db import get_database
from reflector.db.meetings import meetings
from reflector.db.transcripts import transcripts
old_date = datetime.now(timezone.utc) - timedelta(days=8)
# Create a meeting
meeting_id = "test-meeting-for-transcript"
await get_database().execute(
meetings.insert().values(
id=meeting_id,
room_name="Meeting with Transcript",
room_url="https://example.com/meeting",
host_room_url="https://example.com/meeting-host",
start_date=old_date,
end_date=old_date + timedelta(hours=1),
room_id=None,
)
)
# Create a recording
recording = await recordings_controller.create(
Recording(
bucket_name="test-bucket",
object_key="test-recording.mp4",
recorded_at=old_date,
)
)
# Create an old transcript with both meeting and recording
old_transcript = await transcripts_controller.add(
db_session,
name="Old Transcript with Meeting and Recording",
source_kind=SourceKind.FILE,
source_kind=SourceKind.ROOM,
user_id=None,
meeting_id=meeting_id,
recording_id=recording.id,
)
await db_session.execute(
update(TranscriptModel)
.where(TranscriptModel.id == old_transcript.id)
# Update created_at to be old
await get_database().execute(
transcripts.update()
.where(transcripts.c.id == old_transcript.id)
.values(created_at=old_date)
)
await db_session.commit()
# Create associated meeting directly
meeting_id = "test-meeting-id"
await db_session.execute(
insert(MeetingModel).values(
id=meeting_id,
room_id=None,
room_name="test-room",
room_url="https://example.com/room",
host_room_url="https://example.com/room-host",
start_date=old_date,
end_date=old_date + timedelta(hours=1),
is_active=False,
num_clients=0,
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic",
)
)
# Create associated recording directly
recording_id = "test-recording-id"
await db_session.execute(
insert(RecordingModel).values(
id=recording_id,
meeting_id=meeting_id,
url="https://example.com/recording.mp4",
object_key="recordings/test.mp4",
duration=3600.0,
created_at=old_date,
)
)
await db_session.commit()
# Update transcript with meeting_id and recording_id
await db_session.execute(
update(TranscriptModel)
.where(TranscriptModel.id == old_transcript.id)
.values(meeting_id=meeting_id, recording_id=recording_id)
)
await db_session.commit()
# Mock settings
with patch("reflector.worker.cleanup.settings") as mock_settings:
mock_settings.PUBLIC_MODE = True
mock_settings.PUBLIC_DATA_RETENTION_DAYS = 7
# Mock storage deletion
with patch("reflector.worker.cleanup.get_recordings_storage") as mock_storage:
with patch("reflector.db.transcripts.get_transcripts_storage") as mock_storage:
mock_storage.return_value.delete_file = AsyncMock()
with patch(
"reflector.worker.cleanup.get_recordings_storage"
) as mock_rec_storage:
mock_rec_storage.return_value.delete_file = AsyncMock()
# Run cleanup with test session
await cleanup_old_public_data(db_session)
result = await cleanup_old_public_data()
# Check results
assert result["transcripts_deleted"] == 1
assert result["meetings_deleted"] == 1
assert result["recordings_deleted"] == 1
assert result["errors"] == []
# Verify transcript was deleted
result = await db_session.execute(
select(TranscriptModel).where(TranscriptModel.id == old_transcript.id)
)
transcript = result.scalar_one_or_none()
assert transcript is None
assert await transcripts_controller.get_by_id(old_transcript.id) is None
# Verify meeting was deleted
result = await db_session.execute(
select(MeetingModel).where(MeetingModel.id == meeting_id)
)
meeting = result.scalar_one_or_none()
assert meeting is None
query = meetings.select().where(meetings.c.id == meeting_id)
meeting_result = await get_database().fetch_one(query)
assert meeting_result is None
# Verify recording was deleted
result = await db_session.execute(
select(RecordingModel).where(RecordingModel.id == recording_id)
)
recording = result.scalar_one_or_none()
assert recording is None
assert await recordings_controller.get_by_id(recording.id) is None
@pytest.mark.asyncio
async def test_cleanup_handles_errors_gracefully(db_session):
"""Test that cleanup continues even if individual deletions fail."""
async def test_cleanup_handles_errors_gracefully():
"""Test that cleanup continues even when individual deletions fail."""
old_date = datetime.now(timezone.utc) - timedelta(days=8)
# Create multiple old transcripts
transcript1 = await transcripts_controller.add(
db_session,
name="Transcript 1",
source_kind=SourceKind.FILE,
user_id=None,
)
await db_session.execute(
update(TranscriptModel)
.where(TranscriptModel.id == transcript1.id)
.values(created_at=old_date)
)
transcript2 = await transcripts_controller.add(
db_session,
name="Transcript 2",
source_kind=SourceKind.FILE,
user_id=None,
)
await db_session.execute(
update(TranscriptModel)
.where(TranscriptModel.id == transcript2.id)
# Update created_at to be old
from reflector.db import get_database
from reflector.db.transcripts import transcripts
for t_id in [transcript1.id, transcript2.id]:
await get_database().execute(
transcripts.update()
.where(transcripts.c.id == t_id)
.values(created_at=old_date)
)
await db_session.commit()
with patch("reflector.worker.cleanup.settings") as mock_settings:
mock_settings.PUBLIC_MODE = True
mock_settings.PUBLIC_DATA_RETENTION_DAYS = 7
# Mock delete_single_transcript to fail on first call but succeed on second
with patch("reflector.worker.cleanup.delete_single_transcript") as mock_delete:
mock_delete.side_effect = [Exception("Delete failed"), None]
# Mock remove_by_id to fail for the first transcript
original_remove = transcripts_controller.remove_by_id
call_count = 0
# Run cleanup with test session - should not raise exception
await cleanup_old_public_data(db_session)
async def mock_remove_by_id(transcript_id, user_id=None):
nonlocal call_count
call_count += 1
if call_count == 1:
raise Exception("Simulated deletion error")
return await original_remove(transcript_id, user_id)
# Both transcripts should have been attempted to delete
assert mock_delete.call_count == 2
with patch.object(
transcripts_controller, "remove_by_id", side_effect=mock_remove_by_id
):
result = await cleanup_old_public_data()
# Should have one successful deletion and one error
assert result["transcripts_deleted"] == 1
assert len(result["errors"]) == 1
assert "Failed to delete transcript" in result["errors"][0]
@pytest.mark.asyncio
async def test_meeting_consent_cascade_delete(db_session):
"""Test that meeting_consent entries are cascade deleted with meetings."""
old_date = datetime.now(timezone.utc) - timedelta(days=8)
# Create an old transcript
transcript = await transcripts_controller.add(
db_session,
name="Transcript with Meeting",
source_kind=SourceKind.FILE,
user_id=None,
async def test_meeting_consent_cascade_delete():
"""Test that meeting_consent records are automatically deleted when meeting is deleted."""
from reflector.db import get_database
from reflector.db.meetings import (
meeting_consent,
meeting_consent_controller,
meetings,
)
await db_session.execute(
update(TranscriptModel)
.where(TranscriptModel.id == transcript.id)
.values(created_at=old_date)
)
await db_session.commit()
# Create a meeting directly
meeting_id = "test-meeting-consent"
await db_session.execute(
insert(MeetingModel).values(
# Create a meeting
meeting_id = "test-cascade-meeting"
await get_database().execute(
meetings.insert().values(
id=meeting_id,
room_name="Test Meeting for CASCADE",
room_url="https://example.com/cascade-test",
host_room_url="https://example.com/cascade-test-host",
start_date=datetime.now(timezone.utc),
end_date=datetime.now(timezone.utc) + timedelta(hours=1),
room_id=None,
room_name="test-room",
room_url="https://example.com/room",
host_room_url="https://example.com/room-host",
start_date=old_date,
end_date=old_date + timedelta(hours=1),
is_active=False,
num_clients=0,
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic",
)
)
await db_session.commit()
# Update transcript with meeting_id
await db_session.execute(
update(TranscriptModel)
.where(TranscriptModel.id == transcript.id)
.values(meeting_id=meeting_id)
)
await db_session.commit()
# Create consent records for this meeting
consent1_id = "consent-1"
consent2_id = "consent-2"
# Create meeting_consent entries
await db_session.execute(
insert(MeetingConsentModel).values(
id="consent-1",
await get_database().execute(
meeting_consent.insert().values(
id=consent1_id,
meeting_id=meeting_id,
user_id="user-1",
user_id="user1",
consent_given=True,
consent_timestamp=old_date,
consent_timestamp=datetime.now(timezone.utc),
)
)
await db_session.execute(
insert(MeetingConsentModel).values(
id="consent-2",
meeting_id=meeting_id,
user_id="user-2",
consent_given=True,
consent_timestamp=old_date,
)
)
await db_session.commit()
# Verify consent entries exist
result = await db_session.execute(
select(MeetingConsentModel).where(MeetingConsentModel.meeting_id == meeting_id)
await get_database().execute(
meeting_consent.insert().values(
id=consent2_id,
meeting_id=meeting_id,
user_id="user2",
consent_given=False,
consent_timestamp=datetime.now(timezone.utc),
)
consents = result.scalars().all()
)
# Verify consent records exist
consents = await meeting_consent_controller.get_by_meeting_id(meeting_id)
assert len(consents) == 2
# Delete the transcript and meeting
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == transcript.id)
)
await db_session.execute(delete(MeetingModel).where(MeetingModel.id == meeting_id))
await db_session.commit()
# Delete the meeting
await get_database().execute(meetings.delete().where(meetings.c.id == meeting_id))
# Verify consent entries were cascade deleted
result = await db_session.execute(
select(MeetingConsentModel).where(MeetingConsentModel.meeting_id == meeting_id)
)
consents = result.scalars().all()
assert len(consents) == 0
# Verify meeting is deleted
query = meetings.select().where(meetings.c.id == meeting_id)
result = await get_database().fetch_one(query)
assert result is None
# Verify consent records are automatically deleted (CASCADE DELETE)
consents_after = await meeting_consent_controller.get_by_meeting_id(meeting_id)
assert len(consents_after) == 0

View File

@@ -4,8 +4,9 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from icalendar import Calendar, Event
from reflector.db import get_database
from reflector.db.calendar_events import calendar_events_controller
from reflector.db.rooms import rooms_controller
from reflector.db.rooms import rooms, rooms_controller
from reflector.services.ics_sync import ics_sync_service
from reflector.worker.ics_sync import (
_should_sync,
@@ -14,9 +15,8 @@ from reflector.worker.ics_sync import (
@pytest.mark.asyncio
async def test_sync_room_ics_task(db_session):
async def test_sync_room_ics_task():
room = await rooms_controller.add(
db_session,
name="task-test-room",
user_id="test-user",
zulip_auto_post=False,
@@ -30,7 +30,6 @@ async def test_sync_room_ics_task(db_session):
ics_url="https://calendar.example.com/task.ics",
ics_enabled=True,
)
await db_session.flush()
cal = Calendar()
event = Event()
@@ -46,22 +45,21 @@ async def test_sync_room_ics_task(db_session):
ics_content = cal.to_ical().decode("utf-8")
with patch(
"reflector.services.ics_sync.ICSFetchService.fetch_ics",
new_callable=AsyncMock,
"reflector.services.ics_sync.ICSFetchService.fetch_ics", new_callable=AsyncMock
) as mock_fetch:
mock_fetch.return_value = ics_content
await ics_sync_service.sync_room_calendar(db_session, room)
# Call the service directly instead of the Celery task to avoid event loop issues
await ics_sync_service.sync_room_calendar(room)
events = await calendar_events_controller.get_by_room(db_session, room.id)
events = await calendar_events_controller.get_by_room(room.id)
assert len(events) == 1
assert events[0].ics_uid == "task-event-1"
@pytest.mark.asyncio
async def test_sync_room_ics_disabled(db_session):
async def test_sync_room_ics_disabled():
room = await rooms_controller.add(
db_session,
name="disabled-room",
user_id="test-user",
zulip_auto_post=False,
@@ -75,16 +73,16 @@ async def test_sync_room_ics_disabled(db_session):
ics_enabled=False,
)
result = await ics_sync_service.sync_room_calendar(db_session, room)
# Test that disabled rooms are skipped by the service
result = await ics_sync_service.sync_room_calendar(room)
events = await calendar_events_controller.get_by_room(db_session, room.id)
events = await calendar_events_controller.get_by_room(room.id)
assert len(events) == 0
@pytest.mark.asyncio
async def test_sync_all_ics_calendars(db_session):
async def test_sync_all_ics_calendars():
room1 = await rooms_controller.add(
db_session,
name="sync-all-1",
user_id="test-user",
zulip_auto_post=False,
@@ -100,7 +98,6 @@ async def test_sync_all_ics_calendars(db_session):
)
room2 = await rooms_controller.add(
db_session,
name="sync-all-2",
user_id="test-user",
zulip_auto_post=False,
@@ -116,7 +113,6 @@ async def test_sync_all_ics_calendars(db_session):
)
room3 = await rooms_controller.add(
db_session,
name="sync-all-3",
user_id="test-user",
zulip_auto_post=False,
@@ -130,15 +126,21 @@ async def test_sync_all_ics_calendars(db_session):
ics_enabled=False,
)
with patch("reflector.worker.ics_sync.sync_room_ics.kiq") as mock_kiq:
ics_enabled_rooms = await rooms_controller.get_ics_enabled(db_session)
with patch("reflector.worker.ics_sync.sync_room_ics.delay") as mock_delay:
# Directly call the sync_all logic without the Celery wrapper
query = rooms.select().where(
rooms.c.ics_enabled == True, rooms.c.ics_url != None
)
all_rooms = await get_database().fetch_all(query)
for room in ics_enabled_rooms:
for room_data in all_rooms:
room_id = room_data["id"]
room = await rooms_controller.get_by_id(room_id)
if room and _should_sync(room):
await sync_room_ics.kiq(room.id)
sync_room_ics.delay(room_id)
assert mock_kiq.call_count == 2
called_room_ids = [call.args[0] for call in mock_kiq.call_args_list]
assert mock_delay.call_count == 2
called_room_ids = [call.args[0] for call in mock_delay.call_args_list]
assert room1.id in called_room_ids
assert room2.id in called_room_ids
assert room3.id not in called_room_ids
@@ -161,11 +163,10 @@ async def test_should_sync_logic():
@pytest.mark.asyncio
async def test_sync_respects_fetch_interval(db_session):
async def test_sync_respects_fetch_interval():
now = datetime.now(timezone.utc)
room1 = await rooms_controller.add(
db_session,
name="interval-test-1",
user_id="test-user",
zulip_auto_post=False,
@@ -182,13 +183,11 @@ async def test_sync_respects_fetch_interval(db_session):
)
await rooms_controller.update(
db_session,
room1,
{"ics_last_sync": now - timedelta(seconds=100)},
)
room2 = await rooms_controller.add(
db_session,
name="interval-test-2",
user_id="test-user",
zulip_auto_post=False,
@@ -205,26 +204,30 @@ async def test_sync_respects_fetch_interval(db_session):
)
await rooms_controller.update(
db_session,
room2,
{"ics_last_sync": now - timedelta(seconds=100)},
)
with patch("reflector.worker.ics_sync.sync_room_ics.kiq") as mock_kiq:
ics_enabled_rooms = await rooms_controller.get_ics_enabled(db_session)
with patch("reflector.worker.ics_sync.sync_room_ics.delay") as mock_delay:
# Test the sync logic without the Celery wrapper
query = rooms.select().where(
rooms.c.ics_enabled == True, rooms.c.ics_url != None
)
all_rooms = await get_database().fetch_all(query)
for room in ics_enabled_rooms:
for room_data in all_rooms:
room_id = room_data["id"]
room = await rooms_controller.get_by_id(room_id)
if room and _should_sync(room):
await sync_room_ics.kiq(room.id)
sync_room_ics.delay(room_id)
assert mock_kiq.call_count == 1
assert mock_kiq.call_args[0][0] == room2.id
assert mock_delay.call_count == 1
assert mock_delay.call_args[0][0] == room2.id
@pytest.mark.asyncio
async def test_sync_handles_errors_gracefully(db_session):
async def test_sync_handles_errors_gracefully():
room = await rooms_controller.add(
db_session,
name="error-task-room",
user_id="test-user",
zulip_auto_post=False,
@@ -244,8 +247,9 @@ async def test_sync_handles_errors_gracefully(db_session):
) as mock_fetch:
mock_fetch.side_effect = Exception("Network error")
result = await ics_sync_service.sync_room_calendar(db_session, room)
# Call the service directly to test error handling
result = await ics_sync_service.sync_room_calendar(room)
assert result["status"] == "error"
events = await calendar_events_controller.get_by_room(db_session, room.id)
events = await calendar_events_controller.get_by_room(room.id)
assert len(events) == 0

View File

@@ -134,10 +134,9 @@ async def test_ics_fetch_service_extract_room_events():
@pytest.mark.asyncio
async def test_ics_sync_service_sync_room_calendar(db_session):
async def test_ics_sync_service_sync_room_calendar():
# Create room
room = await rooms_controller.add(
db_session,
name="sync-test",
user_id="test-user",
zulip_auto_post=False,
@@ -151,7 +150,6 @@ async def test_ics_sync_service_sync_room_calendar(db_session):
ics_url="https://calendar.example.com/test.ics",
ics_enabled=True,
)
await db_session.flush()
# Mock ICS content
cal = Calendar()
@@ -177,7 +175,7 @@ async def test_ics_sync_service_sync_room_calendar(db_session):
mock_fetch.return_value = ics_content
# First sync
result = await sync_service.sync_room_calendar(db_session, room)
result = await sync_service.sync_room_calendar(room)
assert result["status"] == "success"
assert result["events_found"] == 1
@@ -186,20 +184,18 @@ async def test_ics_sync_service_sync_room_calendar(db_session):
assert result["events_deleted"] == 0
# Verify event was created
events = await calendar_events_controller.get_by_room(db_session, room.id)
events = await calendar_events_controller.get_by_room(room.id)
assert len(events) == 1
assert events[0].ics_uid == "sync-event-1"
assert events[0].title == "Sync Test Meeting"
# Second sync with same content (should be unchanged)
# Refresh room to get updated etag and force sync by setting old sync time
room = await rooms_controller.get_by_id(db_session, room.id)
room = await rooms_controller.get_by_id(room.id)
await rooms_controller.update(
db_session,
room,
{"ics_last_sync": datetime.now(timezone.utc) - timedelta(minutes=10)},
room, {"ics_last_sync": datetime.now(timezone.utc) - timedelta(minutes=10)}
)
result = await sync_service.sync_room_calendar(db_session, room)
result = await sync_service.sync_room_calendar(room)
assert result["status"] == "unchanged"
# Third sync with updated event
@@ -210,15 +206,15 @@ async def test_ics_sync_service_sync_room_calendar(db_session):
mock_fetch.return_value = ics_content
# Force sync by clearing etag
await rooms_controller.update(db_session, room, {"ics_last_etag": None})
await rooms_controller.update(room, {"ics_last_etag": None})
result = await sync_service.sync_room_calendar(db_session, room)
result = await sync_service.sync_room_calendar(room)
assert result["status"] == "success"
assert result["events_created"] == 0
assert result["events_updated"] == 1
# Verify event was updated
events = await calendar_events_controller.get_by_room(db_session, room.id)
events = await calendar_events_controller.get_by_room(room.id)
assert len(events) == 1
assert events[0].title == "Updated Meeting Title"
@@ -251,7 +247,7 @@ async def test_ics_sync_service_skip_disabled():
room.ics_enabled = False
room.ics_url = "https://calendar.example.com/test.ics"
result = await service.sync_room_calendar(MagicMock(), room)
result = await service.sync_room_calendar(room)
assert result["status"] == "skipped"
assert result["reason"] == "ICS not configured"
@@ -259,16 +255,15 @@ async def test_ics_sync_service_skip_disabled():
room.ics_enabled = True
room.ics_url = None
result = await service.sync_room_calendar(MagicMock(), room)
result = await service.sync_room_calendar(room)
assert result["status"] == "skipped"
assert result["reason"] == "ICS not configured"
@pytest.mark.asyncio
async def test_ics_sync_service_error_handling(db_session):
async def test_ics_sync_service_error_handling():
# Create room
room = await rooms_controller.add(
db_session,
name="error-test",
user_id="test-user",
zulip_auto_post=False,
@@ -282,7 +277,6 @@ async def test_ics_sync_service_error_handling(db_session):
ics_url="https://calendar.example.com/error.ics",
ics_enabled=True,
)
await db_session.flush()
sync_service = ICSSyncService()
@@ -291,6 +285,6 @@ async def test_ics_sync_service_error_handling(db_session):
) as mock_fetch:
mock_fetch.side_effect = Exception("Network error")
result = await sync_service.sync_room_calendar(db_session, room)
result = await sync_service.sync_room_calendar(room)
assert result["status"] == "error"
assert "Network error" in result["error"]

View File

@@ -10,11 +10,10 @@ from reflector.db.rooms import rooms_controller
@pytest.mark.asyncio
async def test_multiple_active_meetings_per_room(db_session):
async def test_multiple_active_meetings_per_room():
"""Test that multiple active meetings can exist for the same room."""
# Create a room
room = await rooms_controller.add(
db_session,
name="test-room",
user_id="test-user",
zulip_auto_post=False,
@@ -32,7 +31,6 @@ async def test_multiple_active_meetings_per_room(db_session):
# Create first meeting
meeting1 = await meetings_controller.create(
db_session,
id="meeting-1",
room_name="test-meeting-1",
room_url="https://whereby.com/test-1",
@@ -44,7 +42,6 @@ async def test_multiple_active_meetings_per_room(db_session):
# Create second meeting for the same room (should succeed now)
meeting2 = await meetings_controller.create(
db_session,
id="meeting-2",
room_name="test-meeting-2",
room_url="https://whereby.com/test-2",
@@ -56,7 +53,7 @@ async def test_multiple_active_meetings_per_room(db_session):
# Both meetings should be active
active_meetings = await meetings_controller.get_all_active_for_room(
db_session, room=room, current_time=current_time
room=room, current_time=current_time
)
assert len(active_meetings) == 2
@@ -65,11 +62,10 @@ async def test_multiple_active_meetings_per_room(db_session):
@pytest.mark.asyncio
async def test_get_active_by_calendar_event(db_session):
async def test_get_active_by_calendar_event():
"""Test getting active meeting by calendar event ID."""
# Create a room
room = await rooms_controller.add(
db_session,
name="test-room",
user_id="test-user",
zulip_auto_post=False,
@@ -90,14 +86,13 @@ async def test_get_active_by_calendar_event(db_session):
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc) + timedelta(hours=1),
)
event = await calendar_events_controller.upsert(db_session, event)
event = await calendar_events_controller.upsert(event)
current_time = datetime.now(timezone.utc)
end_time = current_time + timedelta(hours=2)
# Create meeting linked to calendar event
meeting = await meetings_controller.create(
db_session,
id="meeting-cal-1",
room_name="test-meeting-cal",
room_url="https://whereby.com/test-cal",
@@ -111,7 +106,7 @@ async def test_get_active_by_calendar_event(db_session):
# Should find the meeting by calendar event
found_meeting = await meetings_controller.get_active_by_calendar_event(
db_session, room=room, calendar_event_id=event.id, current_time=current_time
room=room, calendar_event_id=event.id, current_time=current_time
)
assert found_meeting is not None
@@ -120,11 +115,10 @@ async def test_get_active_by_calendar_event(db_session):
@pytest.mark.asyncio
async def test_calendar_meeting_deactivates_after_scheduled_end(db_session):
async def test_calendar_meeting_deactivates_after_scheduled_end():
"""Test that unused calendar meetings deactivate after scheduled end time."""
# Create a room
room = await rooms_controller.add(
db_session,
name="test-room",
user_id="test-user",
zulip_auto_post=False,
@@ -145,13 +139,12 @@ async def test_calendar_meeting_deactivates_after_scheduled_end(db_session):
start_time=datetime.now(timezone.utc) - timedelta(hours=2),
end_time=datetime.now(timezone.utc) - timedelta(minutes=35),
)
event = await calendar_events_controller.upsert(db_session, event)
event = await calendar_events_controller.upsert(event)
current_time = datetime.now(timezone.utc)
# Create meeting linked to calendar event
meeting = await meetings_controller.create(
db_session,
id="meeting-unused",
room_name="test-meeting-unused",
room_url="https://whereby.com/test-unused",
@@ -168,9 +161,7 @@ async def test_calendar_meeting_deactivates_after_scheduled_end(db_session):
# Simulate process_meetings logic for unused calendar meeting past end time
if meeting.calendar_event_id and current_time > meeting.end_date:
# In real code, we'd check has_had_sessions = False here
await meetings_controller.update_meeting(
db_session, meeting.id, is_active=False
)
await meetings_controller.update_meeting(meeting.id, is_active=False)
updated_meeting = await meetings_controller.get_by_id(db_session, meeting.id)
updated_meeting = await meetings_controller.get_by_id(meeting.id)
assert updated_meeting.is_active is False # Deactivated after scheduled end

View File

@@ -101,36 +101,20 @@ async def mock_transcript_in_db(tmpdir):
target_language="en",
)
# Mock all transcripts controller methods that are used in the pipeline
# Mock the controller to return our transcript
try:
with patch(
"reflector.pipelines.main_file_pipeline.transcripts_controller.get_by_id"
) as mock_get:
mock_get.return_value = transcript
with patch(
"reflector.pipelines.main_file_pipeline.transcripts_controller.update"
) as mock_update:
mock_update.return_value = transcript
with patch(
"reflector.pipelines.main_file_pipeline.transcripts_controller.set_status"
) as mock_set_status:
mock_set_status.return_value = None
with patch(
"reflector.pipelines.main_file_pipeline.transcripts_controller.upsert_topic"
) as mock_upsert_topic:
mock_upsert_topic.return_value = None
with patch(
"reflector.pipelines.main_file_pipeline.transcripts_controller.append_event"
) as mock_append_event:
mock_append_event.return_value = None
with patch(
"reflector.pipelines.main_live_pipeline.transcripts_controller.get_by_id"
) as mock_get2:
mock_get2.return_value = transcript
with patch(
"reflector.pipelines.main_live_pipeline.transcripts_controller.update"
) as mock_update2:
mock_update2.return_value = None
) as mock_update:
mock_update.return_value = None
yield transcript
finally:
# Restore original DATA_DIR
@@ -297,7 +281,6 @@ async def mock_summary_processor():
@pytest.mark.asyncio
async def test_pipeline_main_file_process(
db_session,
tmpdir,
mock_transcript_in_db,
dummy_file_transcript,
@@ -378,7 +361,7 @@ async def test_pipeline_main_file_process(
mock_av.side_effect = [mock_container, mock_decode_container]
# Run the pipeline
await pipeline.process(db_session, upload_path)
await pipeline.process(upload_path)
# Verify audio extraction and writing
assert mock_audio_file_writer.push.called
@@ -423,7 +406,6 @@ async def test_pipeline_main_file_process(
@pytest.mark.asyncio
async def test_pipeline_main_file_with_video(
db_session,
tmpdir,
mock_transcript_in_db,
dummy_file_transcript,
@@ -470,7 +452,7 @@ async def test_pipeline_main_file_with_video(
mock_av.side_effect = [mock_container, mock_decode_container]
# Run the pipeline
await pipeline.process(db_session, upload_path)
await pipeline.process(upload_path)
# Verify audio extraction from video
assert mock_audio_file_writer.push.called
@@ -488,7 +470,6 @@ async def test_pipeline_main_file_with_video(
@pytest.mark.asyncio
async def test_pipeline_main_file_no_diarization(
db_session,
tmpdir,
mock_transcript_in_db,
dummy_file_transcript,
@@ -536,7 +517,7 @@ async def test_pipeline_main_file_no_diarization(
mock_av.side_effect = [mock_container, mock_decode_container]
# Run the pipeline
await pipeline.process(db_session, upload_path)
await pipeline.process(upload_path)
# Verify the pipeline completed without diarization
assert mock_storage._put_file.called
@@ -550,7 +531,6 @@ async def test_pipeline_main_file_no_diarization(
@pytest.mark.asyncio
async def test_task_pipeline_file_process(
db_session,
tmpdir,
mock_transcript_in_db,
dummy_file_transcript,
@@ -597,7 +577,7 @@ async def test_task_pipeline_file_process(
from reflector.pipelines.main_file_pipeline import PipelineMainFile
pipeline = PipelineMainFile(transcript_id=mock_transcript_in_db.id)
await pipeline.process(db_session, upload_path)
await pipeline.process(upload_path)
# Verify the pipeline was executed through the task
assert mock_audio_file_writer.push.called
@@ -628,16 +608,11 @@ async def test_pipeline_file_process_no_transcript():
# Should raise an exception for missing transcript when get_transcript is called
with pytest.raises(Exception, match="Transcript not found"):
# Use a mock session - the controller is mocked to return None anyway
from unittest.mock import MagicMock
mock_session = MagicMock()
await pipeline.get_transcript(mock_session)
await pipeline.get_transcript()
@pytest.mark.asyncio
async def test_pipeline_file_process_no_audio_file(
db_session,
mock_transcript_in_db,
):
"""
@@ -655,4 +630,4 @@ async def test_pipeline_file_process_no_audio_file(
# This should fail when trying to open the file with av
with pytest.raises(Exception):
await pipeline.process(db_session, non_existent_path)
await pipeline.process(non_existent_path)

View File

@@ -10,10 +10,9 @@ from reflector.db.rooms import rooms_controller
@pytest.mark.asyncio
async def test_room_create_with_ics_fields(db_session):
async def test_room_create_with_ics_fields():
"""Test creating a room with ICS calendar fields."""
room = await rooms_controller.add(
db_session,
name="test-room",
user_id="test-user",
zulip_auto_post=False,
@@ -41,11 +40,10 @@ async def test_room_create_with_ics_fields(db_session):
@pytest.mark.asyncio
async def test_room_update_ics_configuration(db_session):
async def test_room_update_ics_configuration():
"""Test updating room ICS configuration."""
# Create room without ICS
room = await rooms_controller.add(
db_session,
name="update-test",
user_id="test-user",
zulip_auto_post=False,
@@ -63,7 +61,6 @@ async def test_room_update_ics_configuration(db_session):
# Update with ICS configuration
await rooms_controller.update(
db_session,
room,
{
"ics_url": "https://outlook.office365.com/owa/calendar/test/calendar.ics",
@@ -80,10 +77,9 @@ async def test_room_update_ics_configuration(db_session):
@pytest.mark.asyncio
async def test_room_ics_sync_metadata(db_session):
async def test_room_ics_sync_metadata():
"""Test updating room ICS sync metadata."""
room = await rooms_controller.add(
db_session,
name="sync-test",
user_id="test-user",
zulip_auto_post=False,
@@ -101,7 +97,6 @@ async def test_room_ics_sync_metadata(db_session):
# Update sync metadata
sync_time = datetime.now(timezone.utc)
await rooms_controller.update(
db_session,
room,
{
"ics_last_sync": sync_time,
@@ -114,11 +109,10 @@ async def test_room_ics_sync_metadata(db_session):
@pytest.mark.asyncio
async def test_room_get_with_ics_fields(db_session):
async def test_room_get_with_ics_fields():
"""Test retrieving room with ICS fields."""
# Create room
created_room = await rooms_controller.add(
db_session,
name="get-test",
user_id="test-user",
zulip_auto_post=False,
@@ -135,14 +129,14 @@ async def test_room_get_with_ics_fields(db_session):
)
# Get by ID
room = await rooms_controller.get_by_id(db_session, created_room.id)
room = await rooms_controller.get_by_id(created_room.id)
assert room is not None
assert room.ics_url == "webcal://calendar.example.com/feed.ics"
assert room.ics_fetch_interval == 900
assert room.ics_enabled is True
# Get by name
room = await rooms_controller.get_by_name(db_session, "get-test")
room = await rooms_controller.get_by_name("get-test")
assert room is not None
assert room.ics_url == "webcal://calendar.example.com/feed.ics"
assert room.ics_fetch_interval == 900
@@ -150,11 +144,10 @@ async def test_room_get_with_ics_fields(db_session):
@pytest.mark.asyncio
async def test_room_list_with_ics_enabled_filter(db_session):
async def test_room_list_with_ics_enabled_filter():
"""Test listing rooms filtered by ICS enabled status."""
# Create rooms with and without ICS
room1 = await rooms_controller.add(
db_session,
name="ics-enabled-1",
user_id="test-user",
zulip_auto_post=False,
@@ -170,7 +163,6 @@ async def test_room_list_with_ics_enabled_filter(db_session):
)
room2 = await rooms_controller.add(
db_session,
name="ics-disabled",
user_id="test-user",
zulip_auto_post=False,
@@ -185,7 +177,6 @@ async def test_room_list_with_ics_enabled_filter(db_session):
)
room3 = await rooms_controller.add(
db_session,
name="ics-enabled-2",
user_id="test-user",
zulip_auto_post=False,
@@ -201,20 +192,19 @@ async def test_room_list_with_ics_enabled_filter(db_session):
)
# Get all rooms
all_rooms = await rooms_controller.get_all(db_session)
all_rooms = await rooms_controller.get_all()
assert len(all_rooms) == 3
# Filter for ICS-enabled rooms (would need to implement this in controller)
ics_rooms = [r for r in all_rooms if r.ics_enabled]
ics_rooms = [r for r in all_rooms if r["ics_enabled"]]
assert len(ics_rooms) == 2
assert all(r.ics_enabled for r in ics_rooms)
assert all(r["ics_enabled"] for r in ics_rooms)
@pytest.mark.asyncio
async def test_room_default_ics_values(db_session):
async def test_room_default_ics_values():
"""Test that ICS fields have correct default values."""
room = await rooms_controller.add(
db_session,
name="default-test",
user_id="test-user",
zulip_auto_post=False,

View File

@@ -11,13 +11,20 @@ from reflector.db.rooms import rooms_controller
@pytest.fixture
async def authenticated_client(client):
from reflector.app import app
from reflector.auth import current_user_optional
from reflector.auth import current_user, current_user_optional
app.dependency_overrides[current_user] = lambda: {
"sub": "test-user",
"email": "test@example.com",
}
app.dependency_overrides[current_user_optional] = lambda: {
"sub": "test-user",
"email": "test@example.com",
}
try:
yield client
finally:
del app.dependency_overrides[current_user]
del app.dependency_overrides[current_user_optional]
@@ -89,10 +96,9 @@ async def test_update_room_ics_configuration(authenticated_client):
@pytest.mark.asyncio
async def test_trigger_ics_sync(authenticated_client, db_session):
async def test_trigger_ics_sync(authenticated_client):
client = authenticated_client
room = await rooms_controller.add(
db_session,
name="sync-api-room",
user_id="test-user",
zulip_auto_post=False,
@@ -134,9 +140,8 @@ async def test_trigger_ics_sync(authenticated_client, db_session):
@pytest.mark.asyncio
async def test_trigger_ics_sync_unauthorized(client, db_session):
async def test_trigger_ics_sync_unauthorized(client):
room = await rooms_controller.add(
db_session,
name="sync-unauth-room",
user_id="owner-123",
zulip_auto_post=False,
@@ -157,10 +162,9 @@ async def test_trigger_ics_sync_unauthorized(client, db_session):
@pytest.mark.asyncio
async def test_trigger_ics_sync_not_configured(authenticated_client, db_session):
async def test_trigger_ics_sync_not_configured(authenticated_client):
client = authenticated_client
room = await rooms_controller.add(
db_session,
name="sync-not-configured",
user_id="test-user",
zulip_auto_post=False,
@@ -180,10 +184,9 @@ async def test_trigger_ics_sync_not_configured(authenticated_client, db_session)
@pytest.mark.asyncio
async def test_get_ics_status(authenticated_client, db_session):
async def test_get_ics_status(authenticated_client):
client = authenticated_client
room = await rooms_controller.add(
db_session,
name="status-room",
user_id="test-user",
zulip_auto_post=False,
@@ -201,7 +204,6 @@ async def test_get_ics_status(authenticated_client, db_session):
now = datetime.now(timezone.utc)
await rooms_controller.update(
db_session,
room,
{"ics_last_sync": now, "ics_last_etag": "test-etag"},
)
@@ -215,9 +217,8 @@ async def test_get_ics_status(authenticated_client, db_session):
@pytest.mark.asyncio
async def test_get_ics_status_unauthorized(client, db_session):
async def test_get_ics_status_unauthorized(client):
room = await rooms_controller.add(
db_session,
name="status-unauth",
user_id="owner-456",
zulip_auto_post=False,
@@ -238,10 +239,9 @@ async def test_get_ics_status_unauthorized(client, db_session):
@pytest.mark.asyncio
async def test_list_room_meetings(authenticated_client, db_session):
async def test_list_room_meetings(authenticated_client):
client = authenticated_client
room = await rooms_controller.add(
db_session,
name="meetings-room",
user_id="test-user",
zulip_auto_post=False,
@@ -262,7 +262,7 @@ async def test_list_room_meetings(authenticated_client, db_session):
start_time=now - timedelta(hours=2),
end_time=now - timedelta(hours=1),
)
await calendar_events_controller.upsert(db_session, event1)
await calendar_events_controller.upsert(event1)
event2 = CalendarEvent(
room_id=room.id,
@@ -273,7 +273,7 @@ async def test_list_room_meetings(authenticated_client, db_session):
end_time=now + timedelta(hours=2),
attendees=[{"email": "test@example.com"}],
)
await calendar_events_controller.upsert(db_session, event2)
await calendar_events_controller.upsert(event2)
response = await client.get(f"/rooms/{room.name}/meetings")
assert response.status_code == 200
@@ -286,9 +286,8 @@ async def test_list_room_meetings(authenticated_client, db_session):
@pytest.mark.asyncio
async def test_list_room_meetings_non_owner(client, db_session):
async def test_list_room_meetings_non_owner(client):
room = await rooms_controller.add(
db_session,
name="meetings-privacy",
user_id="owner-789",
zulip_auto_post=False,
@@ -310,7 +309,7 @@ async def test_list_room_meetings_non_owner(client, db_session):
end_time=datetime.now(timezone.utc) + timedelta(hours=2),
attendees=[{"email": "private@example.com"}],
)
await calendar_events_controller.upsert(db_session, event)
await calendar_events_controller.upsert(event)
response = await client.get(f"/rooms/{room.name}/meetings")
assert response.status_code == 200
@@ -322,10 +321,9 @@ async def test_list_room_meetings_non_owner(client, db_session):
@pytest.mark.asyncio
async def test_list_upcoming_meetings(authenticated_client, db_session):
async def test_list_upcoming_meetings(authenticated_client):
client = authenticated_client
room = await rooms_controller.add(
db_session,
name="upcoming-room",
user_id="test-user",
zulip_auto_post=False,
@@ -347,7 +345,7 @@ async def test_list_upcoming_meetings(authenticated_client, db_session):
start_time=now - timedelta(hours=1),
end_time=now - timedelta(minutes=30),
)
await calendar_events_controller.upsert(db_session, past_event)
await calendar_events_controller.upsert(past_event)
soon_event = CalendarEvent(
room_id=room.id,
@@ -356,7 +354,7 @@ async def test_list_upcoming_meetings(authenticated_client, db_session):
start_time=now + timedelta(minutes=15),
end_time=now + timedelta(minutes=45),
)
await calendar_events_controller.upsert(db_session, soon_event)
await calendar_events_controller.upsert(soon_event)
later_event = CalendarEvent(
room_id=room.id,
@@ -365,7 +363,7 @@ async def test_list_upcoming_meetings(authenticated_client, db_session):
start_time=now + timedelta(hours=2),
end_time=now + timedelta(hours=3),
)
await calendar_events_controller.upsert(db_session, later_event)
await calendar_events_controller.upsert(later_event)
response = await client.get(f"/rooms/{room.name}/meetings/upcoming")
assert response.status_code == 200

View File

@@ -2,40 +2,40 @@
import json
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
from sqlalchemy import delete, insert
from reflector.db.base import TranscriptModel
from reflector.db import get_database
from reflector.db.search import (
SearchController,
SearchParameters,
SearchResult,
search_controller,
)
from reflector.db.transcripts import SourceKind
from reflector.db.transcripts import SourceKind, transcripts
@pytest.mark.asyncio
async def test_search_postgresql_only(db_session):
async def test_search_postgresql_only():
params = SearchParameters(query_text="any query here")
results, total = await search_controller.search_transcripts(db_session, params)
results, total = await search_controller.search_transcripts(params)
assert results == []
assert total == 0
params_empty = SearchParameters(query_text=None)
results_empty, total_empty = await search_controller.search_transcripts(
db_session, params_empty
params_empty
)
assert isinstance(results_empty, list)
assert isinstance(total_empty, int)
@pytest.mark.asyncio
async def test_search_with_empty_query(db_session):
async def test_search_with_empty_query():
"""Test that empty query returns all transcripts."""
params = SearchParameters(query_text=None)
results, total = await search_controller.search_transcripts(db_session, params)
results, total = await search_controller.search_transcripts(params)
assert isinstance(results, list)
assert isinstance(total, int)
@@ -45,13 +45,13 @@ async def test_search_with_empty_query(db_session):
@pytest.mark.asyncio
async def test_empty_transcript_title_only_match(db_session):
async def test_empty_transcript_title_only_match():
"""Test that transcripts with title-only matches return empty snippets."""
test_id = "test-empty-9b3f2a8d"
try:
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
test_data = {
@@ -77,11 +77,10 @@ async def test_empty_transcript_title_only_match(db_session):
"user_id": "test-user-1",
}
await db_session.execute(insert(TranscriptModel).values(**test_data))
await db_session.commit()
await get_database().execute(transcripts.insert().values(**test_data))
params = SearchParameters(query_text="empty", user_id="test-user-1")
results, total = await search_controller.search_transcripts(db_session, params)
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = next((r for r in results if r.id == test_id), None)
@@ -90,20 +89,20 @@ async def test_empty_transcript_title_only_match(db_session):
assert found.total_match_count == 0
finally:
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
await db_session.commit()
await get_database().disconnect()
@pytest.mark.asyncio
async def test_search_with_long_summary(db_session):
async def test_search_with_long_summary():
"""Test that long_summary content is searchable."""
test_id = "test-long-summary-8a9f3c2d"
try:
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
test_data = {
@@ -132,11 +131,10 @@ Basic meeting content without special keywords.""",
"user_id": "test-user-2",
}
await db_session.execute(insert(TranscriptModel).values(**test_data))
await db_session.commit()
await get_database().execute(transcripts.insert().values(**test_data))
params = SearchParameters(query_text="quantum computing", user_id="test-user-2")
results, total = await search_controller.search_transcripts(db_session, params)
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
@@ -148,19 +146,19 @@ Basic meeting content without special keywords.""",
assert "quantum computing" in test_result.search_snippets[0].lower()
finally:
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
await db_session.commit()
await get_database().disconnect()
@pytest.mark.asyncio
async def test_postgresql_search_with_data(db_session):
async def test_postgresql_search_with_data():
test_id = "test-search-e2e-7f3a9b2c"
try:
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
test_data = {
@@ -198,17 +196,16 @@ We need to implement PostgreSQL tsvector for better performance.""",
"user_id": "test-user-3",
}
await db_session.execute(insert(TranscriptModel).values(**test_data))
await db_session.commit()
await get_database().execute(transcripts.insert().values(**test_data))
params = SearchParameters(query_text="planning", user_id="test-user-3")
results, total = await search_controller.search_transcripts(db_session, params)
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by title word"
params = SearchParameters(query_text="tsvector", user_id="test-user-3")
results, total = await search_controller.search_transcripts(db_session, params)
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by webvtt content"
@@ -216,7 +213,7 @@ We need to implement PostgreSQL tsvector for better performance.""",
params = SearchParameters(
query_text="engineering planning", user_id="test-user-3"
)
results, total = await search_controller.search_transcripts(db_session, params)
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by multiple words"
@@ -231,7 +228,7 @@ We need to implement PostgreSQL tsvector for better performance.""",
params = SearchParameters(
query_text="tsvector OR nosuchword", user_id="test-user-3"
)
results, total = await search_controller.search_transcripts(db_session, params)
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript with OR query"
@@ -239,16 +236,16 @@ We need to implement PostgreSQL tsvector for better performance.""",
params = SearchParameters(
query_text='"full-text search"', user_id="test-user-3"
)
results, total = await search_controller.search_transcripts(db_session, params)
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by exact phrase"
finally:
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
await db_session.commit()
await get_database().disconnect()
@pytest.fixture
@@ -314,56 +311,87 @@ class TestSearchControllerFilters:
"""Test SearchController functionality with various filters."""
@pytest.mark.asyncio
async def test_search_with_source_kind_filter(self, db_session):
async def test_search_with_source_kind_filter(self):
"""Test search filtering by source_kind."""
controller = SearchController()
with (
patch("reflector.db.search.is_postgresql", return_value=True),
patch("reflector.db.search.get_database") as mock_db,
):
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
params = SearchParameters(query_text="test", source_kind=SourceKind.LIVE)
# This should not fail, even if no results are found
results, total = await controller.search_transcripts(db_session, params)
results, total = await controller.search_transcripts(params)
assert isinstance(results, list)
assert isinstance(total, int)
assert total >= 0
assert results == []
assert total == 0
mock_db.return_value.fetch_all.assert_called_once()
@pytest.mark.asyncio
async def test_search_with_single_room_id(self, db_session):
async def test_search_with_single_room_id(self):
"""Test search filtering by single room ID (currently supported)."""
controller = SearchController()
with (
patch("reflector.db.search.is_postgresql", return_value=True),
patch("reflector.db.search.get_database") as mock_db,
):
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
params = SearchParameters(
query_text="test",
room_id="room1",
)
# This should not fail, even if no results are found
results, total = await controller.search_transcripts(db_session, params)
results, total = await controller.search_transcripts(params)
assert isinstance(results, list)
assert isinstance(total, int)
assert total >= 0
assert results == []
assert total == 0
mock_db.return_value.fetch_all.assert_called_once()
@pytest.mark.asyncio
async def test_search_result_includes_available_fields(
self, db_session, mock_db_result
):
async def test_search_result_includes_available_fields(self, mock_db_result):
"""Test that search results include available fields like source_kind."""
# Test that the search method works and returns SearchResult objects
controller = SearchController()
with (
patch("reflector.db.search.is_postgresql", return_value=True),
patch("reflector.db.search.get_database") as mock_db,
):
class MockRow:
def __init__(self, data):
self._data = data
self._mapping = data
def __iter__(self):
return iter(self._data.items())
def __getitem__(self, key):
return self._data[key]
def keys(self):
return self._data.keys()
mock_row = MockRow(mock_db_result)
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
mock_db.return_value.fetch_val = AsyncMock(return_value=1)
params = SearchParameters(query_text="test")
results, total = await controller.search_transcripts(db_session, params)
results, total = await controller.search_transcripts(params)
assert isinstance(results, list)
assert isinstance(total, int)
assert total >= 0
assert total == 1
assert len(results) == 1
# If any results exist, verify they are SearchResult objects
for result in results:
result = results[0]
assert isinstance(result, SearchResult)
assert hasattr(result, "id")
assert hasattr(result, "title")
assert hasattr(result, "rank")
assert hasattr(result, "source_kind")
assert result.id == "test-transcript-id"
assert result.title == "Test Transcript"
assert result.rank == 0.95
class TestSearchEndpointParsing:

View File

@@ -4,21 +4,21 @@ import json
from datetime import datetime, timezone
import pytest
from sqlalchemy import delete, insert
from reflector.db.base import TranscriptModel
from reflector.db import get_database
from reflector.db.search import SearchParameters, search_controller
from reflector.db.transcripts import transcripts
@pytest.mark.asyncio
async def test_long_summary_snippet_prioritization(db_session):
async def test_long_summary_snippet_prioritization():
"""Test that snippets from long_summary are prioritized over webvtt content."""
test_id = "test-snippet-priority-3f9a2b8c"
try:
# Clean up any existing test data
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
test_data = {
@@ -57,11 +57,11 @@ We need to consider various implementation approaches.""",
"user_id": "test-user-priority",
}
await db_session.execute(insert(TranscriptModel).values(**test_data))
await get_database().execute(transcripts.insert().values(**test_data))
# Search for "robotics" which appears in both long_summary and webvtt
params = SearchParameters(query_text="robotics", user_id="test-user-priority")
results, total = await search_controller.search_transcripts(db_session, params)
results, total = await search_controller.search_transcripts(params)
assert total >= 1
test_result = next((r for r in results if r.id == test_id), None)
@@ -86,20 +86,20 @@ We need to consider various implementation approaches.""",
), f"Snippet should contain search term: {snippet}"
finally:
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
await db_session.commit()
await get_database().disconnect()
@pytest.mark.asyncio
async def test_long_summary_only_search(db_session):
async def test_long_summary_only_search():
"""Test searching for content that only exists in long_summary."""
test_id = "test-long-only-8b3c9f2a"
try:
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
test_data = {
@@ -135,11 +135,11 @@ Discussion of timeline and deliverables.""",
"user_id": "test-user-long",
}
await db_session.execute(insert(TranscriptModel).values(**test_data))
await get_database().execute(transcripts.insert().values(**test_data))
# Search for terms only in long_summary
params = SearchParameters(query_text="cryptocurrency", user_id="test-user-long")
results, total = await search_controller.search_transcripts(db_session, params)
results, total = await search_controller.search_transcripts(params)
found = any(r.id == test_id for r in results)
assert found, "Should find transcript by long_summary-only content"
@@ -154,15 +154,13 @@ Discussion of timeline and deliverables.""",
# Search for "yield farming" - a more specific term
params2 = SearchParameters(query_text="yield farming", user_id="test-user-long")
results2, total2 = await search_controller.search_transcripts(
db_session, params2
)
results2, total2 = await search_controller.search_transcripts(params2)
found2 = any(r.id == test_id for r in results2)
assert found2, "Should find transcript by specific long_summary phrase"
finally:
await db_session.execute(
delete(TranscriptModel).where(TranscriptModel.id == test_id)
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
await db_session.commit()
await get_database().disconnect()

View File

@@ -0,0 +1,384 @@
import asyncio
import shutil
import threading
import time
from pathlib import Path
import pytest
from httpx_ws import aconnect_ws
from uvicorn import Config, Server
from reflector import zulip as zulip_module
from reflector.app import app
from reflector.db import get_database
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import Room, rooms_controller
from reflector.db.transcripts import (
SourceKind,
TranscriptTopic,
transcripts_controller,
)
from reflector.processors.types import Word
from reflector.settings import settings
from reflector.views.transcripts import create_access_token
@pytest.mark.asyncio
async def test_anonymous_cannot_delete_transcript_in_shared_room(client):
# Create a shared room with a fake owner id so meeting has a room_id
room = await rooms_controller.add(
name="shared-room-test",
user_id="owner-1",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=True,
webhook_url="",
webhook_secret="",
)
# Create a meeting for that room (so transcript.meeting_id links to the shared room)
meeting = await meetings_controller.create(
id="meeting-sec-test",
room_name="room-sec-test",
room_url="room-url",
host_room_url="host-url",
start_date=Room.model_fields["created_at"].default_factory(),
end_date=Room.model_fields["created_at"].default_factory(),
room=room,
)
# Create a transcript owned by someone else and link it to meeting
t = await transcripts_controller.add(
name="to-delete",
source_kind=SourceKind.LIVE,
user_id="owner-2",
meeting_id=meeting.id,
room_id=room.id,
share_mode="private",
)
# Anonymous DELETE should be rejected
del_resp = await client.delete(f"/transcripts/{t.id}")
assert del_resp.status_code == 401, del_resp.text
@pytest.mark.asyncio
async def test_anonymous_cannot_mutate_participants_on_public_transcript(client):
# Create a public transcript with no owner
t = await transcripts_controller.add(
name="public-transcript",
source_kind=SourceKind.LIVE,
user_id=None,
share_mode="public",
)
# Anonymous POST participant must be rejected
resp = await client.post(
f"/transcripts/{t.id}/participants",
json={"name": "AnonUser", "speaker": 0},
)
assert resp.status_code == 401, resp.text
@pytest.mark.asyncio
async def test_anonymous_cannot_update_and_delete_room(client):
# Create room as owner id "owner-3" via controller
room = await rooms_controller.add(
name="room-anon-update-delete",
user_id="owner-3",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
webhook_url="",
webhook_secret="",
)
# Anonymous PATCH via API (no auth)
resp = await client.patch(
f"/rooms/{room.id}",
json={
"name": "room-anon-updated",
"zulip_auto_post": False,
"zulip_stream": "",
"zulip_topic": "",
"is_locked": False,
"room_mode": "normal",
"recording_type": "cloud",
"recording_trigger": "automatic-2nd-participant",
"is_shared": False,
"webhook_url": "",
"webhook_secret": "",
},
)
# Expect authentication required
assert resp.status_code == 401, resp.text
# Anonymous DELETE via API
del_resp = await client.delete(f"/rooms/{room.id}")
assert del_resp.status_code == 401, del_resp.text
@pytest.mark.asyncio
async def test_anonymous_cannot_post_transcript_to_zulip(client, monkeypatch):
# Create a public transcript with some content
t = await transcripts_controller.add(
name="zulip-public",
source_kind=SourceKind.LIVE,
user_id=None,
share_mode="public",
)
# Mock send/update calls
def _fake_send_message_to_zulip(stream, topic, content):
return {"id": 12345}
async def _fake_update_message(message_id, stream, topic, content):
return {"result": "success"}
monkeypatch.setattr(
zulip_module, "send_message_to_zulip", _fake_send_message_to_zulip
)
monkeypatch.setattr(zulip_module, "update_zulip_message", _fake_update_message)
# Anonymous POST to Zulip endpoint
resp = await client.post(
f"/transcripts/{t.id}/zulip",
params={"stream": "general", "topic": "Updates", "include_topics": False},
)
assert resp.status_code == 401, resp.text
@pytest.mark.asyncio
async def test_anonymous_cannot_assign_speaker_on_public_transcript(client):
# Create public transcript
t = await transcripts_controller.add(
name="public-assign",
source_kind=SourceKind.LIVE,
user_id=None,
share_mode="public",
)
# Add a topic with words to be reassigned
topic = TranscriptTopic(
title="T1",
summary="S1",
timestamp=0.0,
transcript="Hello",
words=[Word(start=0.0, end=1.0, text="Hello", speaker=0)],
)
transcript = await transcripts_controller.get_by_id(t.id)
await transcripts_controller.upsert_topic(transcript, topic)
# Anonymous assign speaker over time range covering the word
resp = await client.patch(
f"/transcripts/{t.id}/speaker/assign",
json={
"speaker": 1,
"timestamp_from": 0.0,
"timestamp_to": 1.0,
},
)
assert resp.status_code == 401, resp.text
# Minimal server fixture for websocket tests
@pytest.fixture
def appserver_ws_simple(setup_database):
host = "127.0.0.1"
port = 1256
server_started = threading.Event()
server_exception = None
server_instance = None
def run_server():
nonlocal server_exception, server_instance
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
config = Config(app=app, host=host, port=port, loop=loop)
server_instance = Server(config)
async def start_server():
database = get_database()
await database.connect()
try:
await server_instance.serve()
finally:
await database.disconnect()
server_started.set()
loop.run_until_complete(start_server())
except Exception as e:
server_exception = e
server_started.set()
finally:
loop.close()
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
server_started.wait(timeout=30)
if server_exception:
raise server_exception
time.sleep(0.5)
yield host, port
if server_instance:
server_instance.should_exit = True
server_thread.join(timeout=30)
@pytest.mark.asyncio
async def test_websocket_denies_anonymous_on_private_transcript(appserver_ws_simple):
host, port = appserver_ws_simple
# Create a private transcript owned by someone
t = await transcripts_controller.add(
name="private-ws",
source_kind=SourceKind.LIVE,
user_id="owner-x",
share_mode="private",
)
base_url = f"http://{host}:{port}/v1"
# Anonymous connect should be denied
with pytest.raises(Exception):
async with aconnect_ws(f"{base_url}/transcripts/{t.id}/events") as ws:
await ws.close()
@pytest.mark.asyncio
async def test_anonymous_cannot_update_public_transcript(client):
t = await transcripts_controller.add(
name="update-me",
source_kind=SourceKind.LIVE,
user_id=None,
share_mode="public",
)
resp = await client.patch(
f"/transcripts/{t.id}",
json={"title": "New Title From Anonymous"},
)
assert resp.status_code == 401, resp.text
@pytest.mark.asyncio
async def test_anonymous_cannot_get_nonshared_room_by_id(client):
room = await rooms_controller.add(
name="private-room-exposed",
user_id="owner-z",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
webhook_url="",
webhook_secret="",
)
resp = await client.get(f"/rooms/{room.id}")
assert resp.status_code == 403, resp.text
@pytest.mark.asyncio
async def test_anonymous_cannot_call_rooms_webhook_test(client):
room = await rooms_controller.add(
name="room-webhook-test",
user_id="owner-y",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
webhook_url="http://localhost.invalid/webhook",
webhook_secret="secret",
)
# Anonymous caller
resp = await client.post(f"/rooms/{room.id}/webhook/test")
assert resp.status_code == 401, resp.text
@pytest.mark.asyncio
async def test_anonymous_cannot_create_room(client):
payload = {
"name": "room-create-auth-required",
"zulip_auto_post": False,
"zulip_stream": "",
"zulip_topic": "",
"is_locked": False,
"room_mode": "normal",
"recording_type": "cloud",
"recording_trigger": "automatic-2nd-participant",
"is_shared": False,
"webhook_url": "",
"webhook_secret": "",
}
resp = await client.post("/rooms", json=payload)
assert resp.status_code == 401, resp.text
@pytest.mark.asyncio
async def test_list_search_401_when_public_mode_false(client, monkeypatch):
monkeypatch.setattr(settings, "PUBLIC_MODE", False)
resp = await client.get("/transcripts")
assert resp.status_code == 401
resp = await client.get("/transcripts/search", params={"q": "hello"})
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_audio_mp3_requires_token_for_owned_transcript(
client, tmpdir, monkeypatch
):
# Use temp data dir
monkeypatch.setattr(settings, "DATA_DIR", Path(tmpdir).as_posix())
# Create owner transcript and attach a local mp3
t = await transcripts_controller.add(
name="owned-audio",
source_kind=SourceKind.LIVE,
user_id="owner-a",
share_mode="private",
)
tr = await transcripts_controller.get_by_id(t.id)
await transcripts_controller.update(tr, {"status": "ended"})
# copy fixture audio to transcript path
audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
tr.audio_mp3_filename.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(audio_path, tr.audio_mp3_filename)
# Anonymous GET without token should be 403 or 404 depending on access; we call mp3
resp = await client.get(f"/transcripts/{t.id}/audio/mp3")
assert resp.status_code == 403
# With token should succeed
token = create_access_token(
{"sub": tr.user_id}, expires_delta=__import__("datetime").timedelta(minutes=15)
)
resp2 = await client.get(f"/transcripts/{t.id}/audio/mp3", params={"token": token})
assert resp2.status_code == 200

View File

@@ -1,5 +1,3 @@
from contextlib import asynccontextmanager
import pytest
@@ -19,7 +17,7 @@ async def test_transcript_create(client):
@pytest.mark.asyncio
async def test_transcript_get_update_name(client):
async def test_transcript_get_update_name(authenticated_client, client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["name"] == "test"
@@ -40,7 +38,7 @@ async def test_transcript_get_update_name(client):
@pytest.mark.asyncio
async def test_transcript_get_update_locked(client):
async def test_transcript_get_update_locked(authenticated_client, client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["locked"] is False
@@ -61,7 +59,7 @@ async def test_transcript_get_update_locked(client):
@pytest.mark.asyncio
async def test_transcript_get_update_summary(client):
async def test_transcript_get_update_summary(authenticated_client, client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["long_summary"] is None
@@ -89,7 +87,7 @@ async def test_transcript_get_update_summary(client):
@pytest.mark.asyncio
async def test_transcript_get_update_title(client):
async def test_transcript_get_update_title(authenticated_client, client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["title"] is None
@@ -127,56 +125,6 @@ async def test_transcripts_list_anonymous(client):
settings.PUBLIC_MODE = False
@asynccontextmanager
async def authenticated_client_ctx():
from reflector.app import app
from reflector.auth import current_user, current_user_optional
app.dependency_overrides[current_user] = lambda: {
"sub": "randomuserid",
"email": "test@mail.com",
}
app.dependency_overrides[current_user_optional] = lambda: {
"sub": "randomuserid",
"email": "test@mail.com",
}
yield
del app.dependency_overrides[current_user]
del app.dependency_overrides[current_user_optional]
@asynccontextmanager
async def authenticated_client2_ctx():
from reflector.app import app
from reflector.auth import current_user, current_user_optional
app.dependency_overrides[current_user] = lambda: {
"sub": "randomuserid2",
"email": "test@mail.com",
}
app.dependency_overrides[current_user_optional] = lambda: {
"sub": "randomuserid2",
"email": "test@mail.com",
}
yield
del app.dependency_overrides[current_user]
del app.dependency_overrides[current_user_optional]
@pytest.fixture
@pytest.mark.asyncio
async def authenticated_client():
async with authenticated_client_ctx():
yield
@pytest.fixture
@pytest.mark.asyncio
async def authenticated_client2():
async with authenticated_client2_ctx():
yield
@pytest.mark.asyncio
async def test_transcripts_list_authenticated(authenticated_client, client):
# XXX this test is a bit fragile, as it depends on the storage which
@@ -199,7 +147,7 @@ async def test_transcripts_list_authenticated(authenticated_client, client):
@pytest.mark.asyncio
async def test_transcript_delete(client):
async def test_transcript_delete(authenticated_client, client):
response = await client.post("/transcripts", json={"name": "testdel1"})
assert response.status_code == 200
assert response.json()["name"] == "testdel1"
@@ -214,7 +162,7 @@ async def test_transcript_delete(client):
@pytest.mark.asyncio
async def test_transcript_mark_reviewed(client):
async def test_transcript_mark_reviewed(authenticated_client, client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["name"] == "test"

View File

@@ -5,7 +5,7 @@ import pytest
@pytest.fixture
async def fake_transcript(tmpdir, client, db_session):
async def fake_transcript(tmpdir, client):
from reflector.settings import settings
from reflector.views.transcripts import transcripts_controller
@@ -16,10 +16,10 @@ async def fake_transcript(tmpdir, client, db_session):
assert response.status_code == 200
tid = response.json()["id"]
transcript = await transcripts_controller.get_by_id(db_session, tid)
transcript = await transcripts_controller.get_by_id(tid)
assert transcript is not None
await transcripts_controller.update(db_session, transcript, {"status": "ended"})
await transcripts_controller.update(transcript, {"status": "ended"})
# manually copy a file at the expected location
audio_filename = transcript.audio_mp3_filename
@@ -111,7 +111,9 @@ async def test_transcript_audio_download_range_with_seek(
@pytest.mark.asyncio
async def test_transcript_delete_with_audio(fake_transcript, client):
async def test_transcript_delete_with_audio(
authenticated_client, fake_transcript, client
):
response = await client.delete(f"/transcripts/{fake_transcript.id}")
assert response.status_code == 200
assert response.json()["status"] == "ok"

View File

@@ -2,7 +2,7 @@ import pytest
@pytest.mark.asyncio
async def test_transcript_participants(client):
async def test_transcript_participants(authenticated_client, client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
@@ -39,7 +39,7 @@ async def test_transcript_participants(client):
@pytest.mark.asyncio
async def test_transcript_participants_same_speaker(client):
async def test_transcript_participants_same_speaker(authenticated_client, client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
@@ -62,7 +62,7 @@ async def test_transcript_participants_same_speaker(client):
@pytest.mark.asyncio
async def test_transcript_participants_update_name(client):
async def test_transcript_participants_update_name(authenticated_client, client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
@@ -100,7 +100,7 @@ async def test_transcript_participants_update_name(client):
@pytest.mark.asyncio
async def test_transcript_participants_update_speaker(client):
async def test_transcript_participants_update_speaker(authenticated_client, client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []

View File

@@ -1,11 +1,9 @@
import os
import asyncio
import time
import pytest
from httpx import ASGITransport, AsyncClient
# Set environment for TaskIQ to use InMemoryBroker
os.environ["ENVIRONMENT"] = "pytest"
@pytest.fixture
async def app_lifespan():
@@ -25,16 +23,9 @@ async def client(app_lifespan):
)
@pytest.fixture
async def taskiq_broker():
from reflector.worker.app import taskiq_broker
# Broker is already initialized as InMemoryBroker due to ENVIRONMENT=pytest
await taskiq_broker.startup()
yield taskiq_broker
await taskiq_broker.shutdown()
@pytest.mark.usefixtures("setup_database")
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_process(
tmpdir,
@@ -44,10 +35,7 @@ async def test_transcript_process(
dummy_file_diarization,
dummy_storage,
client,
taskiq_broker,
db_session,
):
print("IN TEST", db_session)
# create a transcript
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
@@ -68,14 +56,18 @@ async def test_transcript_process(
assert response.status_code == 200
assert response.json()["status"] == "ok"
# Wait for all tasks to complete since we're using InMemoryBroker
await taskiq_broker.wait_all()
# Ensure it's finished ok
# wait for processing to finish (max 1 minute)
timeout_seconds = 60
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds:
# fetch the transcript and check if it is ended
resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200
print(resp.json())
assert resp.json()["status"] in ("ended", "error")
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
else:
pytest.fail(f"Initial processing timed out after {timeout_seconds} seconds")
# restart the processing
response = await client.post(
@@ -83,15 +75,20 @@ async def test_transcript_process(
)
assert response.status_code == 200
assert response.json()["status"] == "ok"
await asyncio.sleep(2)
# Wait for all tasks to complete since we're using InMemoryBroker
await taskiq_broker.wait_all()
# Ensure it's finished ok
# wait for processing to finish (max 1 minute)
timeout_seconds = 60
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds:
# fetch the transcript and check if it is ended
resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200
print(resp.json())
assert resp.json()["status"] in ("ended", "error")
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
else:
pytest.fail(f"Restart processing timed out after {timeout_seconds} seconds")
# check the transcript is ended
transcript = resp.json()

View File

@@ -2,84 +2,33 @@ from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
from sqlalchemy import insert
from reflector.db.base import MeetingModel, RoomModel
from reflector.db.recordings import recordings_controller
from reflector.db.recordings import Recording, recordings_controller
from reflector.db.transcripts import SourceKind, transcripts_controller
@pytest.mark.asyncio
async def test_recording_deleted_with_transcript(db_session):
"""Test that a recording is deleted when its associated transcript is deleted."""
# First create a room and meeting to satisfy foreign key constraints
room_id = "test-room"
await db_session.execute(
insert(RoomModel).values(
id=room_id,
name="test-room",
user_id="test-user",
created_at=datetime.now(timezone.utc),
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic",
is_shared=False,
)
)
meeting_id = "test-meeting"
await db_session.execute(
insert(MeetingModel).values(
id=meeting_id,
room_id=room_id,
room_name="test-room",
room_url="https://example.com/room",
host_room_url="https://example.com/room-host",
start_date=datetime.now(timezone.utc),
end_date=datetime.now(timezone.utc),
is_active=False,
num_clients=0,
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic",
)
)
await db_session.commit()
# Now create a recording
async def test_recording_deleted_with_transcript():
recording = await recordings_controller.create(
db_session,
meeting_id=meeting_id,
url="https://example.com/recording.mp4",
object_key="recordings/test.mp4",
duration=3600.0,
created_at=datetime.now(timezone.utc),
Recording(
bucket_name="test-bucket",
object_key="recording.mp4",
recorded_at=datetime.now(timezone.utc),
)
)
# Create a transcript associated with the recording
transcript = await transcripts_controller.add(
db_session,
name="Test Transcript",
source_kind=SourceKind.ROOM,
recording_id=recording.id,
)
# Mock the storage deletion
with patch("reflector.db.transcripts.get_recordings_storage") as mock_get_storage:
storage_instance = mock_get_storage.return_value
storage_instance.delete_file = AsyncMock()
# Delete the transcript
await transcripts_controller.remove_by_id(db_session, transcript.id)
await transcripts_controller.remove_by_id(transcript.id)
# Verify that the recording file was deleted from storage
storage_instance.delete_file.assert_awaited_once_with(recording.object_key)
# Verify both the recording and transcript are deleted
assert await recordings_controller.get_by_id(db_session, recording.id) is None
assert await transcripts_controller.get_by_id(db_session, transcript.id) is None
assert await recordings_controller.get_by_id(recording.id) is None
assert await transcripts_controller.get_by_id(transcript.id) is None

View File

@@ -49,12 +49,11 @@ class ThreadedUvicorn:
@pytest.fixture
def appserver(tmpdir, database):
def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker):
import threading
from reflector.app import app
# Database connection handled by SQLAlchemy engine
from reflector.db import get_database
from reflector.settings import settings
DATA_DIR = settings.DATA_DIR
@@ -78,8 +77,13 @@ def appserver(tmpdir, database):
server_instance = Server(config)
async def start_server():
# Database connections managed by SQLAlchemy engine
# Initialize database connection in this event loop
database = get_database()
await database.connect()
try:
await server_instance.serve()
finally:
await database.disconnect()
# Signal that server is starting
server_started.set()
@@ -111,6 +115,14 @@ def appserver(tmpdir, database):
settings.DATA_DIR = DATA_DIR
@pytest.fixture(scope="session")
def celery_includes():
return ["reflector.pipelines.main_live_pipeline"]
@pytest.mark.usefixtures("setup_database")
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket(
tmpdir,
@@ -156,7 +168,7 @@ async def test_transcript_rtc_and_websocket(
except Exception as e:
print(f"Test websocket: EXCEPTION {e}")
finally:
await ws.close()
ws.close()
print("Test websocket: DISCONNECTED")
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
@@ -273,6 +285,9 @@ async def test_transcript_rtc_and_websocket(
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
@pytest.mark.usefixtures("setup_database")
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_rtc_and_websocket_and_fr(
tmpdir,

View File

@@ -2,7 +2,9 @@ import pytest
@pytest.mark.asyncio
async def test_transcript_reassign_speaker(fake_transcript_with_topics, client):
async def test_transcript_reassign_speaker(
authenticated_client, fake_transcript_with_topics, client
):
transcript_id = fake_transcript_with_topics.id
# check the transcript exists
@@ -114,7 +116,9 @@ async def test_transcript_reassign_speaker(fake_transcript_with_topics, client):
@pytest.mark.asyncio
async def test_transcript_merge_speaker(fake_transcript_with_topics, client):
async def test_transcript_merge_speaker(
authenticated_client, fake_transcript_with_topics, client
):
transcript_id = fake_transcript_with_topics.id
# check the transcript exists
@@ -181,7 +185,7 @@ async def test_transcript_merge_speaker(fake_transcript_with_topics, client):
@pytest.mark.asyncio
async def test_transcript_reassign_with_participant(
fake_transcript_with_topics, client
authenticated_client, fake_transcript_with_topics, client
):
transcript_id = fake_transcript_with_topics.id
@@ -347,7 +351,9 @@ async def test_transcript_reassign_with_participant(
@pytest.mark.asyncio
async def test_transcript_reassign_edge_cases(fake_transcript_with_topics, client):
async def test_transcript_reassign_edge_cases(
authenticated_client, fake_transcript_with_topics, client
):
transcript_id = fake_transcript_with_topics.id
# check the transcript exists

View File

@@ -4,6 +4,9 @@ import time
import pytest
@pytest.mark.usefixtures("setup_database")
@pytest.mark.usefixtures("celery_session_app")
@pytest.mark.usefixtures("celery_session_worker")
@pytest.mark.asyncio
async def test_transcript_upload_file(
tmpdir,

View File

@@ -0,0 +1,156 @@
import asyncio
import threading
import time
import pytest
from httpx_ws import aconnect_ws
from uvicorn import Config, Server
@pytest.fixture
def appserver_ws_user(setup_database):
from reflector.app import app
from reflector.db import get_database
host = "127.0.0.1"
port = 1257
server_started = threading.Event()
server_exception = None
server_instance = None
def run_server():
nonlocal server_exception, server_instance
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
config = Config(app=app, host=host, port=port, loop=loop)
server_instance = Server(config)
async def start_server():
database = get_database()
await database.connect()
try:
await server_instance.serve()
finally:
await database.disconnect()
server_started.set()
loop.run_until_complete(start_server())
except Exception as e:
server_exception = e
server_started.set()
finally:
loop.close()
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
server_started.wait(timeout=30)
if server_exception:
raise server_exception
time.sleep(0.5)
yield host, port
if server_instance:
server_instance.should_exit = True
server_thread.join(timeout=30)
@pytest.fixture(autouse=True)
def patch_jwt_verification(monkeypatch):
"""Patch JWT verification to accept HS256 tokens signed with SECRET_KEY for tests."""
from jose import jwt
from reflector.settings import settings
def _verify_token(self, token: str):
# Do not validate audience in tests
return jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) # type: ignore[arg-type]
monkeypatch.setattr(
"reflector.auth.auth_jwt.JWTAuth.verify_token", _verify_token, raising=True
)
def _make_dummy_jwt(sub: str = "user123") -> str:
# Create a short HS256 JWT using the app secret to pass verification in tests
from datetime import datetime, timedelta, timezone
from jose import jwt
from reflector.settings import settings
payload = {
"sub": sub,
"email": f"{sub}@example.com",
"exp": datetime.now(timezone.utc) + timedelta(minutes=5),
}
# Note: production uses RS256 public key verification; tests can sign with SECRET_KEY
return jwt.encode(payload, settings.SECRET_KEY, algorithm="HS256")
@pytest.mark.asyncio
async def test_user_ws_rejects_missing_subprotocol(appserver_ws_user):
host, port = appserver_ws_user
base_ws = f"http://{host}:{port}/v1/events"
# No subprotocol/header with token
with pytest.raises(Exception):
async with aconnect_ws(base_ws) as ws: # type: ignore
# Should close during handshake; if not, close explicitly
await ws.close()
@pytest.mark.asyncio
async def test_user_ws_rejects_invalid_token(appserver_ws_user):
host, port = appserver_ws_user
base_ws = f"http://{host}:{port}/v1/events"
# Send wrong token via WebSocket subprotocols
protocols = ["bearer", "totally-invalid-token"]
with pytest.raises(Exception):
async with aconnect_ws(base_ws, subprotocols=protocols) as ws: # type: ignore
await ws.close()
@pytest.mark.asyncio
async def test_user_ws_accepts_valid_token_and_receives_events(appserver_ws_user):
host, port = appserver_ws_user
base_ws = f"http://{host}:{port}/v1/events"
token = _make_dummy_jwt("user-abc")
subprotocols = ["bearer", token]
# Connect and then trigger an event via HTTP create
async with aconnect_ws(base_ws, subprotocols=subprotocols) as ws:
# Emit an event to the user's room via a standard HTTP action
from httpx import AsyncClient
from reflector.app import app
from reflector.auth import current_user, current_user_optional
# Override auth dependencies so HTTP request is performed as the same user
app.dependency_overrides[current_user] = lambda: {
"sub": "user-abc",
"email": "user-abc@example.com",
}
app.dependency_overrides[current_user_optional] = lambda: {
"sub": "user-abc",
"email": "user-abc@example.com",
}
async with AsyncClient(app=app, base_url=f"http://{host}:{port}/v1") as ac:
# Create a transcript as this user so that the server publishes TRANSCRIPT_CREATED to user room
resp = await ac.post("/transcripts", json={"name": "WS Test"})
assert resp.status_code == 200
# Receive the published event
msg = await ws.receive_json()
assert msg["event"] == "TRANSCRIPT_CREATED"
assert "id" in msg["data"]
# Clean overrides
del app.dependency_overrides[current_user]
del app.dependency_overrides[current_user_optional]

View File

@@ -1,14 +1,13 @@
"""Integration tests for WebVTT auto-update functionality in Transcript model."""
import pytest
from sqlalchemy import select
from reflector.db.base import TranscriptModel
from reflector.db import get_database
from reflector.db.transcripts import (
SourceKind,
TranscriptController,
TranscriptTopic,
transcripts_controller,
transcripts,
)
from reflector.processors.types import Word
@@ -17,35 +16,30 @@ from reflector.processors.types import Word
class TestWebVTTAutoUpdate:
"""Test that WebVTT field auto-updates when Transcript is created or modified."""
async def test_webvtt_not_updated_on_transcript_creation_without_topics(
self, db_session
):
async def test_webvtt_not_updated_on_transcript_creation_without_topics(self):
"""WebVTT should be None when creating transcript without topics."""
# Using global transcripts_controller
controller = TranscriptController()
transcript = await transcripts_controller.add(
db_session,
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
try:
result = await db_session.execute(
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
result = await get_database().fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
row = result.scalar_one_or_none()
assert row is not None
assert row.webvtt is None
assert result is not None
assert result["webvtt"] is None
finally:
await transcripts_controller.remove_by_id(db_session, transcript.id)
await controller.remove_by_id(transcript.id)
async def test_webvtt_updated_on_upsert_topic(self, db_session):
async def test_webvtt_updated_on_upsert_topic(self):
"""WebVTT should update when upserting topics via upsert_topic method."""
# Using global transcripts_controller
controller = TranscriptController()
transcript = await transcripts_controller.add(
db_session,
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
@@ -62,15 +56,14 @@ class TestWebVTTAutoUpdate:
],
)
await transcripts_controller.upsert_topic(db_session, transcript, topic)
await controller.upsert_topic(transcript, topic)
result = await db_session.execute(
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
result = await get_database().fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
row = result.scalar_one_or_none()
assert row is not None
webvtt = row.webvtt
assert result is not None
webvtt = result["webvtt"]
assert webvtt is not None
assert "WEBVTT" in webvtt
@@ -78,14 +71,13 @@ class TestWebVTTAutoUpdate:
assert "<v Speaker0>" in webvtt
finally:
await transcripts_controller.remove_by_id(db_session, transcript.id)
await controller.remove_by_id(transcript.id)
async def test_webvtt_updated_on_direct_topics_update(self, db_session):
async def test_webvtt_updated_on_direct_topics_update(self):
"""WebVTT should update when updating topics field directly."""
# Using global transcripts_controller
controller = TranscriptController()
transcript = await transcripts_controller.add(
db_session,
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
@@ -104,32 +96,28 @@ class TestWebVTTAutoUpdate:
}
]
await transcripts_controller.update(
db_session, transcript, {"topics": topics_data}
)
await controller.update(transcript, {"topics": topics_data})
# Fetch from DB
result = await db_session.execute(
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
result = await get_database().fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
row = result.scalar_one_or_none()
assert row is not None
webvtt = row.webvtt
assert result is not None
webvtt = result["webvtt"]
assert webvtt is not None
assert "WEBVTT" in webvtt
assert "First sentence" in webvtt
finally:
await transcripts_controller.remove_by_id(db_session, transcript.id)
await controller.remove_by_id(transcript.id)
async def test_webvtt_updated_manually_with_handle_topics_update(self, db_session):
async def test_webvtt_updated_manually_with_handle_topics_update(self):
"""Test that _handle_topics_update works when called manually."""
# Using global transcripts_controller
controller = TranscriptController()
transcript = await transcripts_controller.add(
db_session,
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
@@ -150,16 +138,15 @@ class TestWebVTTAutoUpdate:
values = {"topics": transcript.topics_dump()}
await transcripts_controller.update(db_session, transcript, values)
await controller.update(transcript, values)
# Fetch from DB
result = await db_session.execute(
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
result = await get_database().fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
row = result.scalar_one_or_none()
assert row is not None
webvtt = row.webvtt
assert result is not None
webvtt = result["webvtt"]
assert webvtt is not None
assert "WEBVTT" in webvtt
@@ -167,14 +154,13 @@ class TestWebVTTAutoUpdate:
assert "<v Speaker0>" in webvtt
finally:
await transcripts_controller.remove_by_id(db_session, transcript.id)
await controller.remove_by_id(transcript.id)
async def test_webvtt_update_with_non_sequential_topics_fails(self, db_session):
async def test_webvtt_update_with_non_sequential_topics_fails(self):
"""Test that non-sequential topics raise assertion error."""
# Using global transcripts_controller
controller = TranscriptController()
transcript = await transcripts_controller.add(
db_session,
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
@@ -200,14 +186,13 @@ class TestWebVTTAutoUpdate:
assert "Words are not in sequence" in str(exc_info.value)
finally:
await transcripts_controller.remove_by_id(db_session, transcript.id)
await controller.remove_by_id(transcript.id)
async def test_multiple_speakers_in_webvtt(self, db_session):
async def test_multiple_speakers_in_webvtt(self):
"""Test WebVTT generation with multiple speakers."""
# Using global transcripts_controller
controller = TranscriptController()
transcript = await transcripts_controller.add(
db_session,
transcript = await controller.add(
name="Test Transcript",
source_kind=SourceKind.FILE,
)
@@ -228,16 +213,15 @@ class TestWebVTTAutoUpdate:
transcript.upsert_topic(topic)
values = {"topics": transcript.topics_dump()}
await transcripts_controller.update(db_session, transcript, values)
await controller.update(transcript, values)
# Fetch from DB
result = await db_session.execute(
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
result = await get_database().fetch_one(
transcripts.select().where(transcripts.c.id == transcript.id)
)
row = result.scalar_one_or_none()
assert row is not None
webvtt = row.webvtt
assert result is not None
webvtt = result["webvtt"]
assert webvtt is not None
assert "<v Speaker0>" in webvtt
@@ -247,4 +231,4 @@ class TestWebVTTAutoUpdate:
assert "Goodbye" in webvtt
finally:
await transcripts_controller.remove_by_id(db_session, transcript.id)
await controller.remove_by_id(transcript.id)

3394
server/uv.lock generated

File diff suppressed because it is too large Load Diff

14
www/.dockerignore Normal file
View File

@@ -0,0 +1,14 @@
.env
.env.*
.env.local
.env.development
.env.production
node_modules
.next
.git
.gitignore
*.md
.DS_Store
coverage
.pnpm-store
*.log

View File

@@ -1,9 +1,5 @@
# Environment
ENVIRONMENT=development
NEXT_PUBLIC_ENV=development
# Site Configuration
NEXT_PUBLIC_SITE_URL=http://localhost:3000
SITE_URL=http://localhost:3000
# Nextauth envs
# not used in app code but in lib code
@@ -18,16 +14,16 @@ AUTHENTIK_CLIENT_ID=your-client-id-here
AUTHENTIK_CLIENT_SECRET=your-client-secret-here
# Feature Flags
# NEXT_PUBLIC_FEATURE_REQUIRE_LOGIN=true
# NEXT_PUBLIC_FEATURE_PRIVACY=false
# NEXT_PUBLIC_FEATURE_BROWSE=true
# NEXT_PUBLIC_FEATURE_SEND_TO_ZULIP=true
# NEXT_PUBLIC_FEATURE_ROOMS=true
# FEATURE_REQUIRE_LOGIN=true
# FEATURE_PRIVACY=false
# FEATURE_BROWSE=true
# FEATURE_SEND_TO_ZULIP=true
# FEATURE_ROOMS=true
# API URLs
NEXT_PUBLIC_API_URL=http://127.0.0.1:1250
NEXT_PUBLIC_WEBSOCKET_URL=ws://127.0.0.1:1250
NEXT_PUBLIC_AUTH_CALLBACK_URL=http://localhost:3000/auth-callback
API_URL=http://127.0.0.1:1250
WEBSOCKET_URL=ws://127.0.0.1:1250
AUTH_CALLBACK_URL=http://localhost:3000/auth-callback
# Sentry
# SENTRY_DSN=https://your-dsn@sentry.io/project-id

81
www/DOCKER_README.md Normal file
View File

@@ -0,0 +1,81 @@
# Docker Production Build Guide
## Overview
The Docker image builds without any environment variables and requires all configuration to be provided at runtime.
## Environment Variables (ALL Runtime)
### Required Runtime Variables
```bash
API_URL # Backend API URL (e.g., https://api.example.com)
WEBSOCKET_URL # WebSocket URL (e.g., wss://api.example.com)
NEXTAUTH_URL # NextAuth base URL (e.g., https://app.example.com)
NEXTAUTH_SECRET # Random secret for NextAuth (generate with: openssl rand -base64 32)
KV_URL # Redis URL (e.g., redis://redis:6379)
```
### Optional Runtime Variables
```bash
SITE_URL # Frontend URL (defaults to NEXTAUTH_URL)
AUTHENTIK_ISSUER # OAuth issuer URL
AUTHENTIK_CLIENT_ID # OAuth client ID
AUTHENTIK_CLIENT_SECRET # OAuth client secret
AUTHENTIK_REFRESH_TOKEN_URL # OAuth token refresh URL
FEATURE_REQUIRE_LOGIN=false # Require authentication
FEATURE_PRIVACY=true # Enable privacy features
FEATURE_BROWSE=true # Enable browsing features
FEATURE_SEND_TO_ZULIP=false # Enable Zulip integration
FEATURE_ROOMS=true # Enable rooms feature
SENTRY_DSN # Sentry error tracking
AUTH_CALLBACK_URL # OAuth callback URL
```
## Building the Image
### Option 1: Using Docker Compose
1. Build the image (no environment variables needed):
```bash
docker compose -f docker-compose.prod.yml build
```
2. Create a `.env` file with runtime variables
3. Run with environment variables:
```bash
docker compose -f docker-compose.prod.yml --env-file .env up -d
```
### Option 2: Using Docker CLI
1. Build the image (no build args):
```bash
docker build -t reflector-frontend:latest ./www
```
2. Run with environment variables:
```bash
docker run -d \
-p 3000:3000 \
-e API_URL=https://api.example.com \
-e WEBSOCKET_URL=wss://api.example.com \
-e NEXTAUTH_URL=https://app.example.com \
-e NEXTAUTH_SECRET=your-secret \
-e KV_URL=redis://redis:6379 \
-e AUTHENTIK_ISSUER=https://auth.example.com/application/o/reflector \
-e AUTHENTIK_CLIENT_ID=your-client-id \
-e AUTHENTIK_CLIENT_SECRET=your-client-secret \
-e AUTHENTIK_REFRESH_TOKEN_URL=https://auth.example.com/application/o/token/ \
-e FEATURE_REQUIRE_LOGIN=true \
reflector-frontend:latest
```

View File

@@ -24,7 +24,8 @@ COPY --link . .
ENV NEXT_TELEMETRY_DISABLED 1
# If using npm comment out above and use below instead
RUN pnpm build
# next.js has the feature of excluding build step planned https://github.com/vercel/next.js/discussions/46544
RUN pnpm build-production
# RUN npm run build
# Production image, copy all the files and run next
@@ -51,6 +52,10 @@ USER nextjs
EXPOSE 3000
ENV PORT 3000
ENV HOSTNAME localhost
ENV HOSTNAME 0.0.0.0
HEALTHCHECK --interval=30s --timeout=3s --start-period=40s --retries=3 \
CMD wget --no-verbose --tries=1 --spider http://127.0.0.1:3000/api/health \
|| exit 1
CMD ["node", "server.js"]

View File

@@ -200,7 +200,13 @@ export default function ICSSettings({
<HStack gap={0} position="relative" width="100%">
<Input
ref={roomUrlInputRef}
value={roomAbsoluteUrl(parseNonEmptyString(roomName))}
value={roomAbsoluteUrl(
parseNonEmptyString(
roomName,
true,
"panic! roomName is required",
),
)}
readOnly
onClick={handleRoomUrlClick}
cursor="pointer"

View File

@@ -274,15 +274,31 @@ export function RoomTable({
<IconButton
aria-label="Force sync calendar"
onClick={() =>
handleForceSync(parseNonEmptyString(room.name))
handleForceSync(
parseNonEmptyString(
room.name,
true,
"panic! room.name is required",
),
)
}
size="sm"
variant="ghost"
disabled={syncingRooms.has(
parseNonEmptyString(room.name),
parseNonEmptyString(
room.name,
true,
"panic! room.name is required",
),
)}
>
{syncingRooms.has(parseNonEmptyString(room.name)) ? (
{syncingRooms.has(
parseNonEmptyString(
room.name,
true,
"panic! room.name is required",
),
) ? (
<Spinner size="sm" />
) : (
<CalendarSyncIcon />
@@ -297,7 +313,13 @@ export function RoomTable({
<IconButton
aria-label="Copy URL"
onClick={() =>
onCopyUrl(parseNonEmptyString(room.name))
onCopyUrl(
parseNonEmptyString(
room.name,
true,
"panic! room.name is required",
),
)
}
size="sm"
variant="ghost"

View File

@@ -833,7 +833,13 @@ export default function RoomsList() {
<Field.Root>
<ICSSettings
roomName={
room.name ? parseNonEmptyString(room.name) : null
room.name
? parseNonEmptyString(
room.name,
true,
"panic! room.name required",
)
: null
}
icsUrl={room.icsUrl}
icsEnabled={room.icsEnabled}

View File

@@ -62,7 +62,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
useEffect(() => {
document.onkeyup = (e) => {
if (e.key === "a" && process.env.NEXT_PUBLIC_ENV === "development") {
if (e.key === "a" && process.env.NODE_ENV === "development") {
const segments: GetTranscriptSegmentTopic[] = [
{
speaker: 1,
@@ -201,7 +201,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
setFinalSummary({ summary: "This is the final summary" });
}
if (e.key === "z" && process.env.NEXT_PUBLIC_ENV === "development") {
if (e.key === "z" && process.env.NODE_ENV === "development") {
setTranscriptTextLive(
"This text is in English, and it is a pretty long sentence to test the limits",
);

View File

@@ -261,7 +261,11 @@ export default function Room(details: RoomDetails) {
const params = use(details.params);
const wherebyLoaded = useWhereby();
const wherebyRef = useRef<HTMLElement>(null);
const roomName = parseNonEmptyString(params.roomName);
const roomName = parseNonEmptyString(
params.roomName,
true,
"panic! params.roomName is required",
);
const router = useRouter();
const auth = useAuth();
const status = auth.status;
@@ -308,7 +312,14 @@ export default function Room(details: RoomDetails) {
const handleMeetingSelect = (selectedMeeting: Meeting) => {
router.push(
roomMeetingUrl(roomName, parseNonEmptyString(selectedMeeting.id)),
roomMeetingUrl(
roomName,
parseNonEmptyString(
selectedMeeting.id,
true,
"panic! selectedMeeting.id is required",
),
),
);
};

View File

@@ -0,0 +1,38 @@
import { NextResponse } from "next/server";
export async function GET() {
const health = {
status: "healthy",
timestamp: new Date().toISOString(),
uptime: process.uptime(),
environment: process.env.NODE_ENV,
checks: {
redis: await checkRedis(),
},
};
const allHealthy = Object.values(health.checks).every((check) => check);
return NextResponse.json(health, {
status: allHealthy ? 200 : 503,
});
}
async function checkRedis(): Promise<boolean> {
try {
if (!process.env.KV_URL) {
return false;
}
const { tokenCacheRedis } = await import("../../lib/redisClient");
const testKey = `health:check:${Date.now()}`;
await tokenCacheRedis.setex(testKey, 10, "OK");
const value = await tokenCacheRedis.get(testKey);
await tokenCacheRedis.del(testKey);
return value === "OK";
} catch (error) {
console.error("Redis health check failed:", error);
return false;
}
}

View File

@@ -6,7 +6,10 @@ import ErrorMessage from "./(errors)/errorMessage";
import { RecordingConsentProvider } from "./recordingConsentContext";
import { ErrorBoundary } from "@sentry/nextjs";
import { Providers } from "./providers";
import { assertExistsAndNonEmptyString } from "./lib/utils";
import { getNextEnvVar } from "./lib/nextBuild";
import { getClientEnv } from "./lib/clientEnv";
export const dynamic = "force-dynamic";
const poppins = Poppins({
subsets: ["latin"],
@@ -21,13 +24,11 @@ export const viewport: Viewport = {
maximumScale: 1,
};
const NEXT_PUBLIC_SITE_URL = assertExistsAndNonEmptyString(
process.env.NEXT_PUBLIC_SITE_URL,
"NEXT_PUBLIC_SITE_URL required",
);
const SITE_URL = getNextEnvVar("SITE_URL");
const env = getClientEnv();
export const metadata: Metadata = {
metadataBase: new URL(NEXT_PUBLIC_SITE_URL),
metadataBase: new URL(SITE_URL),
title: {
template: "%s Reflector",
default: "Reflector - AI-Powered Meeting Transcriptions by Monadical",
@@ -74,15 +75,16 @@ export default async function RootLayout({
}) {
return (
<html lang="en" className={poppins.className} suppressHydrationWarning>
<body className={"h-[100svh] w-[100svw] overflow-x-hidden relative"}>
<RecordingConsentProvider>
<body
className={"h-[100svh] w-[100svw] overflow-x-hidden relative"}
data-env={JSON.stringify(env)}
>
<ErrorBoundary fallback={<p>"something went really wrong"</p>}>
<ErrorProvider>
<ErrorMessage />
<Providers>{children}</Providers>
</ErrorProvider>
</ErrorBoundary>
</RecordingConsentProvider>
</body>
</html>
);

View File

@@ -0,0 +1,180 @@
"use client";
import React, { useEffect, useRef } from "react";
import { useQueryClient } from "@tanstack/react-query";
import { WEBSOCKET_URL } from "./apiClient";
import { useAuth } from "./AuthProvider";
import { z } from "zod";
import { invalidateTranscriptLists, TRANSCRIPT_SEARCH_URL } from "./apiHooks";
const UserEvent = z.object({
event: z.string(),
});
type UserEvent = z.TypeOf<typeof UserEvent>;
class UserEventsStore {
private socket: WebSocket | null = null;
private listeners: Set<(event: MessageEvent) => void> = new Set();
private closeTimeoutId: number | null = null;
private isConnecting = false;
ensureConnection(url: string, subprotocols?: string[]) {
if (typeof window === "undefined") return;
if (this.closeTimeoutId !== null) {
clearTimeout(this.closeTimeoutId);
this.closeTimeoutId = null;
}
if (this.isConnecting) return;
if (
this.socket &&
(this.socket.readyState === WebSocket.OPEN ||
this.socket.readyState === WebSocket.CONNECTING)
) {
return;
}
this.isConnecting = true;
const ws = new WebSocket(url, subprotocols || []);
this.socket = ws;
ws.onmessage = (event: MessageEvent) => {
this.listeners.forEach((listener) => {
try {
listener(event);
} catch (err) {
console.error("UserEvents listener error", err);
}
});
};
ws.onopen = () => {
if (this.socket === ws) this.isConnecting = false;
};
ws.onclose = () => {
if (this.socket === ws) {
this.socket = null;
this.isConnecting = false;
}
};
ws.onerror = () => {
if (this.socket === ws) this.isConnecting = false;
};
}
subscribe(listener: (event: MessageEvent) => void): () => void {
this.listeners.add(listener);
if (this.closeTimeoutId !== null) {
clearTimeout(this.closeTimeoutId);
this.closeTimeoutId = null;
}
return () => {
this.listeners.delete(listener);
if (this.listeners.size === 0) {
this.closeTimeoutId = window.setTimeout(() => {
if (this.socket) {
try {
this.socket.close();
} catch (err) {
console.warn("Error closing user events socket", err);
}
}
this.socket = null;
this.closeTimeoutId = null;
}, 1000);
}
};
}
}
const sharedStore = new UserEventsStore();
export function UserEventsProvider({
children,
}: {
children: React.ReactNode;
}) {
const auth = useAuth();
const queryClient = useQueryClient();
const tokenRef = useRef<string | null>(null);
const detachRef = useRef<(() => void) | null>(null);
useEffect(() => {
// Only tear down when the user is truly unauthenticated
if (auth.status === "unauthenticated") {
if (detachRef.current) {
try {
detachRef.current();
} catch (err) {
console.warn("Error detaching UserEvents listener", err);
}
detachRef.current = null;
}
tokenRef.current = null;
return;
}
// During loading/refreshing, keep the existing connection intact
if (auth.status !== "authenticated") {
return;
}
// Authenticated: pin the initial token for the lifetime of this WS connection
if (!tokenRef.current && auth.accessToken) {
tokenRef.current = auth.accessToken;
}
const pinnedToken = tokenRef.current;
const url = `${WEBSOCKET_URL}/v1/events`;
// Ensure a single shared connection
sharedStore.ensureConnection(
url,
pinnedToken ? ["bearer", pinnedToken] : undefined,
);
// Subscribe once; avoid re-subscribing during transient status changes
if (!detachRef.current) {
const onMessage = (event: MessageEvent) => {
try {
const msg = UserEvent.parse(JSON.parse(event.data));
const eventName = msg.event;
const invalidateList = () => invalidateTranscriptLists(queryClient);
switch (eventName) {
case "TRANSCRIPT_CREATED":
case "TRANSCRIPT_DELETED":
case "TRANSCRIPT_STATUS":
case "TRANSCRIPT_FINAL_TITLE":
case "TRANSCRIPT_DURATION":
invalidateList().then(() => {});
break;
default:
// Ignore other content events for list updates
break;
}
} catch (err) {
console.warn("Invalid user event message", event.data);
}
};
const unsubscribe = sharedStore.subscribe(onMessage);
detachRef.current = unsubscribe;
}
}, [auth.status, queryClient]);
// On unmount, detach the listener and clear the pinned token
useEffect(() => {
return () => {
if (detachRef.current) {
try {
detachRef.current();
} catch (err) {
console.warn("Error detaching UserEvents listener on unmount", err);
}
detachRef.current = null;
}
tokenRef.current = null;
};
}, []);
return <>{children}</>;
}

View File

@@ -3,21 +3,19 @@
import createClient from "openapi-fetch";
import type { paths } from "../reflector-api";
import createFetchClient from "openapi-react-query";
import { assertExistsAndNonEmptyString, parseNonEmptyString } from "./utils";
import { parseNonEmptyString } from "./utils";
import { isBuildPhase } from "./next";
import { getSession } from "next-auth/react";
import { assertExtendedToken } from "./types";
import { getClientEnv } from "./clientEnv";
export const API_URL = !isBuildPhase
? assertExistsAndNonEmptyString(
process.env.NEXT_PUBLIC_API_URL,
"NEXT_PUBLIC_API_URL required",
)
? getClientEnv().API_URL
: "http://localhost";
// TODO decide strict validation or not
export const WEBSOCKET_URL =
process.env.NEXT_PUBLIC_WEBSOCKET_URL || "ws://127.0.0.1:1250";
export const WEBSOCKET_URL = !isBuildPhase
? getClientEnv().WEBSOCKET_URL || "ws://127.0.0.1:1250"
: "ws://localhost";
export const client = createClient<paths>({
baseUrl: API_URL,
@@ -44,7 +42,7 @@ client.use({
if (token !== null) {
request.headers.set(
"Authorization",
`Bearer ${parseNonEmptyString(token)}`,
`Bearer ${parseNonEmptyString(token, true, "panic! token is required")}`,
);
}
// XXX Only set Content-Type if not already set (FormData will set its own boundary)

View File

@@ -2,7 +2,7 @@
import { $api } from "./apiClient";
import { useError } from "../(errors)/errorContext";
import { useQueryClient } from "@tanstack/react-query";
import { QueryClient, useQueryClient } from "@tanstack/react-query";
import type { components } from "../reflector-api";
import { useAuth } from "./AuthProvider";
@@ -40,6 +40,13 @@ export function useRoomsList(page: number = 1) {
type SourceKind = components["schemas"]["SourceKind"];
export const TRANSCRIPT_SEARCH_URL = "/v1/transcripts/search" as const;
export const invalidateTranscriptLists = (queryClient: QueryClient) =>
queryClient.invalidateQueries({
queryKey: ["get", TRANSCRIPT_SEARCH_URL],
});
export function useTranscriptsSearch(
q: string = "",
options: {
@@ -51,7 +58,7 @@ export function useTranscriptsSearch(
) {
return $api.useQuery(
"get",
"/v1/transcripts/search",
TRANSCRIPT_SEARCH_URL,
{
params: {
query: {
@@ -76,7 +83,7 @@ export function useTranscriptDelete() {
return $api.useMutation("delete", "/v1/transcripts/{transcript_id}", {
onSuccess: () => {
return queryClient.invalidateQueries({
queryKey: ["get", "/v1/transcripts/search"],
queryKey: ["get", TRANSCRIPT_SEARCH_URL],
});
},
onError: (error) => {
@@ -613,7 +620,7 @@ export function useTranscriptCreate() {
return $api.useMutation("post", "/v1/transcripts", {
onSuccess: () => {
return queryClient.invalidateQueries({
queryKey: ["get", "/v1/transcripts/search"],
queryKey: ["get", TRANSCRIPT_SEARCH_URL],
});
},
onError: (error) => {

View File

@@ -18,26 +18,25 @@ import {
deleteTokenCache,
} from "./redisTokenCache";
import { tokenCacheRedis, redlock } from "./redisClient";
import { isBuildPhase } from "./next";
import { sequenceThrows } from "./errorUtils";
import { featureEnabled } from "./features";
import { getNextEnvVar } from "./nextBuild";
const TOKEN_CACHE_TTL = REFRESH_ACCESS_TOKEN_BEFORE;
const getAuthentikClientId = () =>
assertExistsAndNonEmptyString(
process.env.AUTHENTIK_CLIENT_ID,
"AUTHENTIK_CLIENT_ID required",
);
const getAuthentikClientSecret = () =>
assertExistsAndNonEmptyString(
process.env.AUTHENTIK_CLIENT_SECRET,
"AUTHENTIK_CLIENT_SECRET required",
);
const getAuthentikClientId = () => getNextEnvVar("AUTHENTIK_CLIENT_ID");
const getAuthentikClientSecret = () => getNextEnvVar("AUTHENTIK_CLIENT_SECRET");
const getAuthentikRefreshTokenUrl = () =>
assertExistsAndNonEmptyString(
process.env.AUTHENTIK_REFRESH_TOKEN_URL,
"AUTHENTIK_REFRESH_TOKEN_URL required",
);
getNextEnvVar("AUTHENTIK_REFRESH_TOKEN_URL");
const getAuthentikIssuer = () => {
const stringUrl = getNextEnvVar("AUTHENTIK_ISSUER");
try {
new URL(stringUrl);
} catch (e) {
throw new Error("AUTHENTIK_ISSUER is not a valid URL: " + stringUrl);
}
return stringUrl;
};
export const authOptions = (): AuthOptions =>
featureEnabled("requireLogin")
@@ -45,16 +44,17 @@ export const authOptions = (): AuthOptions =>
providers: [
AuthentikProvider({
...(() => {
const [clientId, clientSecret] = sequenceThrows(
const [clientId, clientSecret, issuer] = sequenceThrows(
getAuthentikClientId,
getAuthentikClientSecret,
getAuthentikIssuer,
);
return {
clientId,
clientSecret,
issuer,
};
})(),
issuer: process.env.AUTHENTIK_ISSUER,
authorization: {
params: {
scope: "openid email profile offline_access",

91
www/app/lib/clientEnv.ts Normal file
View File

@@ -0,0 +1,91 @@
import {
assertExists,
assertExistsAndNonEmptyString,
NonEmptyString,
parseNonEmptyString,
} from "./utils";
import { isBuildPhase } from "./next";
import { getNextEnvVar } from "./nextBuild";
export const FEATURE_REQUIRE_LOGIN_ENV_NAME = "FEATURE_REQUIRE_LOGIN" as const;
export const FEATURE_PRIVACY_ENV_NAME = "FEATURE_PRIVACY" as const;
export const FEATURE_BROWSE_ENV_NAME = "FEATURE_BROWSE" as const;
export const FEATURE_SEND_TO_ZULIP_ENV_NAME = "FEATURE_SEND_TO_ZULIP" as const;
export const FEATURE_ROOMS_ENV_NAME = "FEATURE_ROOMS" as const;
const FEATURE_ENV_NAMES = [
FEATURE_REQUIRE_LOGIN_ENV_NAME,
FEATURE_PRIVACY_ENV_NAME,
FEATURE_BROWSE_ENV_NAME,
FEATURE_SEND_TO_ZULIP_ENV_NAME,
FEATURE_ROOMS_ENV_NAME,
] as const;
export type FeatureEnvName = (typeof FEATURE_ENV_NAMES)[number];
export type EnvFeaturePartial = {
[key in FeatureEnvName]: boolean | null;
};
// CONTRACT: isomorphic with JSON.stringify
export type ClientEnvCommon = EnvFeaturePartial & {
API_URL: NonEmptyString;
WEBSOCKET_URL: NonEmptyString | null;
};
let clientEnv: ClientEnvCommon | null = null;
export const getClientEnvClient = (): ClientEnvCommon => {
if (typeof window === "undefined") {
throw new Error(
"getClientEnv() called during SSR - this should only be called in browser environment",
);
}
if (clientEnv) return clientEnv;
clientEnv = assertExists(
JSON.parse(
assertExistsAndNonEmptyString(
document.body.dataset.env,
"document.body.dataset.env is missing",
),
),
"document.body.dataset.env is parsed to nullish",
);
return clientEnv!;
};
const parseBooleanString = (str: string | undefined): boolean | null => {
if (str === undefined) return null;
return str === "true";
};
export const getClientEnvServer = (): ClientEnvCommon => {
if (typeof window !== "undefined") {
throw new Error(
"getClientEnv() not called during SSR - this should only be called in server environment",
);
}
if (clientEnv) return clientEnv;
const features = FEATURE_ENV_NAMES.reduce((acc, x) => {
acc[x] = parseBooleanString(process.env[x]);
return acc;
}, {} as EnvFeaturePartial);
if (isBuildPhase) {
return {
API_URL: getNextEnvVar("API_URL"),
WEBSOCKET_URL: getNextEnvVar("WEBSOCKET_URL"),
...features,
};
}
clientEnv = {
API_URL: getNextEnvVar("API_URL"),
WEBSOCKET_URL: getNextEnvVar("WEBSOCKET_URL"),
...features,
};
return clientEnv;
};
export const getClientEnv =
typeof window === "undefined" ? getClientEnvServer : getClientEnvClient;

View File

@@ -1,3 +1,13 @@
import {
FEATURE_BROWSE_ENV_NAME,
FEATURE_PRIVACY_ENV_NAME,
FEATURE_REQUIRE_LOGIN_ENV_NAME,
FEATURE_ROOMS_ENV_NAME,
FEATURE_SEND_TO_ZULIP_ENV_NAME,
FeatureEnvName,
getClientEnv,
} from "./clientEnv";
export const FEATURES = [
"requireLogin",
"privacy",
@@ -18,38 +28,30 @@ export const DEFAULT_FEATURES: Features = {
rooms: true,
} as const;
function parseBooleanEnv(
value: string | undefined,
defaultValue: boolean = false,
): boolean {
if (!value) return defaultValue;
return value.toLowerCase() === "true";
}
export const ENV_TO_FEATURE: {
[k in FeatureEnvName]: FeatureName;
} = {
FEATURE_REQUIRE_LOGIN: "requireLogin",
FEATURE_PRIVACY: "privacy",
FEATURE_BROWSE: "browse",
FEATURE_SEND_TO_ZULIP: "sendToZulip",
FEATURE_ROOMS: "rooms",
} as const;
// WARNING: keep process.env.* as-is, next.js won't see them if you generate dynamically
const features: Features = {
requireLogin: parseBooleanEnv(
process.env.NEXT_PUBLIC_FEATURE_REQUIRE_LOGIN,
DEFAULT_FEATURES.requireLogin,
),
privacy: parseBooleanEnv(
process.env.NEXT_PUBLIC_FEATURE_PRIVACY,
DEFAULT_FEATURES.privacy,
),
browse: parseBooleanEnv(
process.env.NEXT_PUBLIC_FEATURE_BROWSE,
DEFAULT_FEATURES.browse,
),
sendToZulip: parseBooleanEnv(
process.env.NEXT_PUBLIC_FEATURE_SEND_TO_ZULIP,
DEFAULT_FEATURES.sendToZulip,
),
rooms: parseBooleanEnv(
process.env.NEXT_PUBLIC_FEATURE_ROOMS,
DEFAULT_FEATURES.rooms,
),
export const FEATURE_TO_ENV: {
[k in FeatureName]: FeatureEnvName;
} = {
requireLogin: "FEATURE_REQUIRE_LOGIN",
privacy: "FEATURE_PRIVACY",
browse: "FEATURE_BROWSE",
sendToZulip: "FEATURE_SEND_TO_ZULIP",
rooms: "FEATURE_ROOMS",
};
const features = getClientEnv();
export const featureEnabled = (featureName: FeatureName): boolean => {
return features[featureName];
const isSet = features[FEATURE_TO_ENV[featureName]];
if (isSet === null) return DEFAULT_FEATURES[featureName];
return isSet;
};

17
www/app/lib/nextBuild.ts Normal file
View File

@@ -0,0 +1,17 @@
import { isBuildPhase } from "./next";
import { assertExistsAndNonEmptyString, NonEmptyString } from "./utils";
const _getNextEnvVar = (name: string, e?: string): NonEmptyString =>
isBuildPhase
? (() => {
throw new Error(
"panic! getNextEnvVar called during build phase; we don't support build envs",
);
})()
: assertExistsAndNonEmptyString(
process.env[name],
`${name} is required; ${e}`,
);
export const getNextEnvVar = (name: string, e?: string): NonEmptyString =>
_getNextEnvVar(name, e);

Some files were not shown because too many files have changed in this diff Show More