mirror of
https://github.com/Monadical-SAS/reflector.git
synced 2025-12-21 20:59:05 +00:00
Compare commits
16 Commits
mathieu/sq
...
v0.16.0
| Author | SHA1 | Date | |
|---|---|---|---|
| dc4b737daa | |||
|
|
0baff7abf7 | ||
|
|
962c40e2b6 | ||
|
|
3c4b9f2103 | ||
|
|
c6c035aacf | ||
| c086b91445 | |||
|
|
9a258abc02 | ||
| af86c47f1d | |||
| 5f6910e513 | |||
| 9a71af145e | |||
| eef6dc3903 | |||
|
|
1dee255fed | ||
| 5d98754305 | |||
|
|
969bd84fcc | ||
|
|
36608849ec | ||
|
|
5bf64b5a41 |
2
.github/workflows/deploy.yml
vendored
2
.github/workflows/deploy.yml
vendored
@@ -1,4 +1,4 @@
|
|||||||
name: Deploy to Amazon ECS
|
name: Build container/push to container registry
|
||||||
|
|
||||||
on: [workflow_dispatch]
|
on: [workflow_dispatch]
|
||||||
|
|
||||||
|
|||||||
57
.github/workflows/docker-frontend.yml
vendored
Normal file
57
.github/workflows/docker-frontend.yml
vendored
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
name: Build and Push Frontend Docker Image
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
paths:
|
||||||
|
- 'www/**'
|
||||||
|
- '.github/workflows/docker-frontend.yml'
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
env:
|
||||||
|
REGISTRY: ghcr.io
|
||||||
|
IMAGE_NAME: ${{ github.repository }}-frontend
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build-and-push:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
packages: write
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Log in to GitHub Container Registry
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: ${{ env.REGISTRY }}
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.GITHUB_TOKEN }}
|
||||||
|
|
||||||
|
- name: Extract metadata
|
||||||
|
id: meta
|
||||||
|
uses: docker/metadata-action@v5
|
||||||
|
with:
|
||||||
|
images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
|
||||||
|
tags: |
|
||||||
|
type=ref,event=branch
|
||||||
|
type=sha,prefix={{branch}}-
|
||||||
|
type=raw,value=latest,enable={{is_default_branch}}
|
||||||
|
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
|
|
||||||
|
- name: Build and push Docker image
|
||||||
|
uses: docker/build-push-action@v5
|
||||||
|
with:
|
||||||
|
context: ./www
|
||||||
|
file: ./www/Dockerfile
|
||||||
|
push: true
|
||||||
|
tags: ${{ steps.meta.outputs.tags }}
|
||||||
|
labels: ${{ steps.meta.outputs.labels }}
|
||||||
|
cache-from: type=gha
|
||||||
|
cache-to: type=gha,mode=max
|
||||||
|
platforms: linux/amd64,linux/arm64
|
||||||
31
CHANGELOG.md
31
CHANGELOG.md
@@ -1,5 +1,36 @@
|
|||||||
# Changelog
|
# Changelog
|
||||||
|
|
||||||
|
## [0.16.0](https://github.com/Monadical-SAS/reflector/compare/v0.15.0...v0.16.0) (2025-10-24)
|
||||||
|
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
* search date filter ([#710](https://github.com/Monadical-SAS/reflector/issues/710)) ([962c40e](https://github.com/Monadical-SAS/reflector/commit/962c40e2b6428ac42fd10aea926782d7a6f3f902))
|
||||||
|
|
||||||
|
## [0.15.0](https://github.com/Monadical-SAS/reflector/compare/v0.14.0...v0.15.0) (2025-10-20)
|
||||||
|
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
* api tokens ([#705](https://github.com/Monadical-SAS/reflector/issues/705)) ([9a258ab](https://github.com/Monadical-SAS/reflector/commit/9a258abc0209b0ac3799532a507ea6a9125d703a))
|
||||||
|
|
||||||
|
## [0.14.0](https://github.com/Monadical-SAS/reflector/compare/v0.13.1...v0.14.0) (2025-10-08)
|
||||||
|
|
||||||
|
|
||||||
|
### Features
|
||||||
|
|
||||||
|
* Add calendar event data to transcript webhook payload ([#689](https://github.com/Monadical-SAS/reflector/issues/689)) ([5f6910e](https://github.com/Monadical-SAS/reflector/commit/5f6910e5131b7f28f86c9ecdcc57fed8412ee3cd))
|
||||||
|
* container build for www / github ([#672](https://github.com/Monadical-SAS/reflector/issues/672)) ([969bd84](https://github.com/Monadical-SAS/reflector/commit/969bd84fcc14851d1a101412a0ba115f1b7cde82))
|
||||||
|
* docker-compose for production frontend ([#664](https://github.com/Monadical-SAS/reflector/issues/664)) ([5bf64b5](https://github.com/Monadical-SAS/reflector/commit/5bf64b5a41f64535e22849b4bb11734d4dbb4aae))
|
||||||
|
|
||||||
|
|
||||||
|
### Bug Fixes
|
||||||
|
|
||||||
|
* restore feature boolean logic ([#671](https://github.com/Monadical-SAS/reflector/issues/671)) ([3660884](https://github.com/Monadical-SAS/reflector/commit/36608849ec64e953e3be456172502762e3c33df9))
|
||||||
|
* security review ([#656](https://github.com/Monadical-SAS/reflector/issues/656)) ([5d98754](https://github.com/Monadical-SAS/reflector/commit/5d98754305c6c540dd194dda268544f6d88bfaf8))
|
||||||
|
* update transcript list on reprocess ([#676](https://github.com/Monadical-SAS/reflector/issues/676)) ([9a71af1](https://github.com/Monadical-SAS/reflector/commit/9a71af145ee9b833078c78d0c684590ab12e9f0e))
|
||||||
|
* upgrade nemo toolkit ([#678](https://github.com/Monadical-SAS/reflector/issues/678)) ([eef6dc3](https://github.com/Monadical-SAS/reflector/commit/eef6dc39037329b65804297786d852dddb0557f9))
|
||||||
|
|
||||||
## [0.13.1](https://github.com/Monadical-SAS/reflector/compare/v0.13.0...v0.13.1) (2025-09-22)
|
## [0.13.1](https://github.com/Monadical-SAS/reflector/compare/v0.13.0...v0.13.1) (2025-09-22)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ All endpoints prefixed `/v1/`:
|
|||||||
|
|
||||||
**Frontend** (`www/.env`):
|
**Frontend** (`www/.env`):
|
||||||
- `NEXTAUTH_URL`, `NEXTAUTH_SECRET` - Authentication configuration
|
- `NEXTAUTH_URL`, `NEXTAUTH_SECRET` - Authentication configuration
|
||||||
- `NEXT_PUBLIC_REFLECTOR_API_URL` - Backend API endpoint
|
- `REFLECTOR_API_URL` - Backend API endpoint
|
||||||
- `REFLECTOR_DOMAIN_CONFIG` - Feature flags and domain settings
|
- `REFLECTOR_DOMAIN_CONFIG` - Feature flags and domain settings
|
||||||
|
|
||||||
## Testing Strategy
|
## Testing Strategy
|
||||||
|
|||||||
25
README.md
25
README.md
@@ -168,6 +168,13 @@ You can manually process an audio file by calling the process tool:
|
|||||||
uv run python -m reflector.tools.process path/to/audio.wav
|
uv run python -m reflector.tools.process path/to/audio.wav
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Build-time env variables
|
||||||
|
|
||||||
|
Next.js projects are more used to NEXT_PUBLIC_ prefixed buildtime vars. We don't have those for the reason we need to serve a ccustomizable prebuild docker container.
|
||||||
|
|
||||||
|
Instead, all the variables are runtime. Variables needed to the frontend are served to the frontend app at initial render.
|
||||||
|
|
||||||
|
It also means there's no static prebuild and no static files to serve for js/html.
|
||||||
|
|
||||||
## Feature Flags
|
## Feature Flags
|
||||||
|
|
||||||
@@ -177,24 +184,24 @@ Reflector uses environment variable-based feature flags to control application f
|
|||||||
|
|
||||||
| Feature Flag | Environment Variable |
|
| Feature Flag | Environment Variable |
|
||||||
|-------------|---------------------|
|
|-------------|---------------------|
|
||||||
| `requireLogin` | `NEXT_PUBLIC_FEATURE_REQUIRE_LOGIN` |
|
| `requireLogin` | `FEATURE_REQUIRE_LOGIN` |
|
||||||
| `privacy` | `NEXT_PUBLIC_FEATURE_PRIVACY` |
|
| `privacy` | `FEATURE_PRIVACY` |
|
||||||
| `browse` | `NEXT_PUBLIC_FEATURE_BROWSE` |
|
| `browse` | `FEATURE_BROWSE` |
|
||||||
| `sendToZulip` | `NEXT_PUBLIC_FEATURE_SEND_TO_ZULIP` |
|
| `sendToZulip` | `FEATURE_SEND_TO_ZULIP` |
|
||||||
| `rooms` | `NEXT_PUBLIC_FEATURE_ROOMS` |
|
| `rooms` | `FEATURE_ROOMS` |
|
||||||
|
|
||||||
### Setting Feature Flags
|
### Setting Feature Flags
|
||||||
|
|
||||||
Feature flags are controlled via environment variables using the pattern `NEXT_PUBLIC_FEATURE_{FEATURE_NAME}` where `{FEATURE_NAME}` is the SCREAMING_SNAKE_CASE version of the feature name.
|
Feature flags are controlled via environment variables using the pattern `FEATURE_{FEATURE_NAME}` where `{FEATURE_NAME}` is the SCREAMING_SNAKE_CASE version of the feature name.
|
||||||
|
|
||||||
**Examples:**
|
**Examples:**
|
||||||
```bash
|
```bash
|
||||||
# Enable user authentication requirement
|
# Enable user authentication requirement
|
||||||
NEXT_PUBLIC_FEATURE_REQUIRE_LOGIN=true
|
FEATURE_REQUIRE_LOGIN=true
|
||||||
|
|
||||||
# Disable browse functionality
|
# Disable browse functionality
|
||||||
NEXT_PUBLIC_FEATURE_BROWSE=false
|
FEATURE_BROWSE=false
|
||||||
|
|
||||||
# Enable Zulip integration
|
# Enable Zulip integration
|
||||||
NEXT_PUBLIC_FEATURE_SEND_TO_ZULIP=true
|
FEATURE_SEND_TO_ZULIP=true
|
||||||
```
|
```
|
||||||
|
|||||||
39
docker-compose.prod.yml
Normal file
39
docker-compose.prod.yml
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# Production Docker Compose configuration for Frontend
|
||||||
|
# Usage: docker compose -f docker-compose.prod.yml up -d
|
||||||
|
|
||||||
|
services:
|
||||||
|
web:
|
||||||
|
build:
|
||||||
|
context: ./www
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
image: reflector-frontend:latest
|
||||||
|
environment:
|
||||||
|
- KV_URL=${KV_URL:-redis://redis:6379}
|
||||||
|
- SITE_URL=${SITE_URL}
|
||||||
|
- API_URL=${API_URL}
|
||||||
|
- WEBSOCKET_URL=${WEBSOCKET_URL}
|
||||||
|
- NEXTAUTH_URL=${NEXTAUTH_URL:-http://localhost:3000}
|
||||||
|
- NEXTAUTH_SECRET=${NEXTAUTH_SECRET:-changeme-in-production}
|
||||||
|
- AUTHENTIK_ISSUER=${AUTHENTIK_ISSUER}
|
||||||
|
- AUTHENTIK_CLIENT_ID=${AUTHENTIK_CLIENT_ID}
|
||||||
|
- AUTHENTIK_CLIENT_SECRET=${AUTHENTIK_CLIENT_SECRET}
|
||||||
|
- AUTHENTIK_REFRESH_TOKEN_URL=${AUTHENTIK_REFRESH_TOKEN_URL}
|
||||||
|
- SENTRY_DSN=${SENTRY_DSN}
|
||||||
|
- SENTRY_IGNORE_API_RESOLUTION_ERROR=${SENTRY_IGNORE_API_RESOLUTION_ERROR:-1}
|
||||||
|
depends_on:
|
||||||
|
- redis
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
redis:
|
||||||
|
image: redis:7.2-alpine
|
||||||
|
restart: unless-stopped
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "redis-cli", "ping"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 3s
|
||||||
|
retries: 3
|
||||||
|
volumes:
|
||||||
|
- redis_data:/data
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
redis_data:
|
||||||
@@ -39,7 +39,7 @@ services:
|
|||||||
ports:
|
ports:
|
||||||
- 6379:6379
|
- 6379:6379
|
||||||
web:
|
web:
|
||||||
image: node:18
|
image: node:22-alpine
|
||||||
ports:
|
ports:
|
||||||
- "3000:3000"
|
- "3000:3000"
|
||||||
command: sh -c "corepack enable && pnpm install && pnpm dev"
|
command: sh -c "corepack enable && pnpm install && pnpm dev"
|
||||||
@@ -50,6 +50,8 @@ services:
|
|||||||
- /app/node_modules
|
- /app/node_modules
|
||||||
env_file:
|
env_file:
|
||||||
- ./www/.env.local
|
- ./www/.env.local
|
||||||
|
environment:
|
||||||
|
- NODE_ENV=development
|
||||||
|
|
||||||
postgres:
|
postgres:
|
||||||
image: postgres:17
|
image: postgres:17
|
||||||
@@ -77,7 +77,7 @@ image = (
|
|||||||
.pip_install(
|
.pip_install(
|
||||||
"hf_transfer==0.1.9",
|
"hf_transfer==0.1.9",
|
||||||
"huggingface_hub[hf-xet]==0.31.2",
|
"huggingface_hub[hf-xet]==0.31.2",
|
||||||
"nemo_toolkit[asr]==2.3.0",
|
"nemo_toolkit[asr]==2.5.0",
|
||||||
"cuda-python==12.8.0",
|
"cuda-python==12.8.0",
|
||||||
"fastapi==0.115.12",
|
"fastapi==0.115.12",
|
||||||
"numpy<2",
|
"numpy<2",
|
||||||
|
|||||||
@@ -1,3 +1,29 @@
|
|||||||
|
## API Key Management
|
||||||
|
|
||||||
|
### Finding Your User ID
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Get your OAuth sub (user ID) - requires authentication
|
||||||
|
curl -H "Authorization: Bearer <your_jwt>" http://localhost:1250/v1/me
|
||||||
|
# Returns: {"sub": "your-oauth-sub-here", "email": "...", ...}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Creating API Keys
|
||||||
|
|
||||||
|
```bash
|
||||||
|
curl -X POST http://localhost:1250/v1/user/api-keys \
|
||||||
|
-H "Authorization: Bearer <your_jwt>" \
|
||||||
|
-H "Content-Type: application/json" \
|
||||||
|
-d '{"name": "My API Key"}'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Using API Keys
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Use X-API-Key header instead of Authorization
|
||||||
|
curl -H "X-API-Key: <your_api_key>" http://localhost:1250/v1/transcripts
|
||||||
|
```
|
||||||
|
|
||||||
## AWS S3/SQS usage clarification
|
## AWS S3/SQS usage clarification
|
||||||
|
|
||||||
Whereby.com uploads recordings directly to our S3 bucket when meetings end.
|
Whereby.com uploads recordings directly to our S3 bucket when meetings end.
|
||||||
|
|||||||
@@ -1,118 +0,0 @@
|
|||||||
# AsyncIO Event Loop Analysis for test_attendee_parsing_bug.py
|
|
||||||
|
|
||||||
## Problem Summary
|
|
||||||
The test passes but encounters an error during teardown where asyncpg tries to use a different/closed event loop, resulting in:
|
|
||||||
- `RuntimeError: Task got Future attached to a different loop`
|
|
||||||
- `RuntimeError: Event loop is closed`
|
|
||||||
|
|
||||||
## Root Cause Analysis
|
|
||||||
|
|
||||||
### 1. Multiple Event Loop Creation Points
|
|
||||||
|
|
||||||
The test environment creates event loops at different scopes:
|
|
||||||
|
|
||||||
1. **Session-scoped loop** (conftest.py:27-34):
|
|
||||||
- Created once per test session
|
|
||||||
- Used by session-scoped fixtures
|
|
||||||
- Closed after all tests complete
|
|
||||||
|
|
||||||
2. **Function-scoped loop** (pytest-asyncio default):
|
|
||||||
- Created for each async test function
|
|
||||||
- This is the loop that runs the actual test
|
|
||||||
- Closed immediately after test completes
|
|
||||||
|
|
||||||
3. **AsyncPG internal loop**:
|
|
||||||
- AsyncPG connections store a reference to the loop they were created with
|
|
||||||
- Used for connection lifecycle management
|
|
||||||
|
|
||||||
### 2. Event Loop Lifecycle Mismatch
|
|
||||||
|
|
||||||
The issue occurs because:
|
|
||||||
|
|
||||||
1. **Session fixture creates database connection** on session-scoped loop
|
|
||||||
2. **Test runs** on function-scoped loop (different from session loop)
|
|
||||||
3. **During teardown**, the session fixture tries to rollback/close using the original session loop
|
|
||||||
4. **AsyncPG connection** still references the function-scoped loop which is now closed
|
|
||||||
5. **Conflict**: SQLAlchemy tries to use session loop, but asyncpg Future is attached to the closed function loop
|
|
||||||
|
|
||||||
### 3. Configuration Issues
|
|
||||||
|
|
||||||
Current pytest configuration:
|
|
||||||
- `asyncio_mode = "auto"` in pyproject.toml
|
|
||||||
- `asyncio_default_fixture_loop_scope=session` (shown in test output)
|
|
||||||
- `asyncio_default_test_loop_scope=function` (shown in test output)
|
|
||||||
|
|
||||||
This mismatch between fixture loop scope (session) and test loop scope (function) causes the problem.
|
|
||||||
|
|
||||||
## Solutions
|
|
||||||
|
|
||||||
### Option 1: Align Loop Scopes (Recommended)
|
|
||||||
Change pytest-asyncio configuration to use consistent loop scopes:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# pyproject.toml
|
|
||||||
[tool.pytest.ini_options]
|
|
||||||
asyncio_mode = "auto"
|
|
||||||
asyncio_default_fixture_loop_scope = "function" # Change from session to function
|
|
||||||
```
|
|
||||||
|
|
||||||
### Option 2: Use Function-Scoped Database Fixture
|
|
||||||
Change the `session` fixture scope from session to function:
|
|
||||||
|
|
||||||
```python
|
|
||||||
@pytest_asyncio.fixture # Remove scope="session"
|
|
||||||
async def session(setup_database):
|
|
||||||
# ... existing code ...
|
|
||||||
```
|
|
||||||
|
|
||||||
### Option 3: Explicit Loop Management
|
|
||||||
Ensure all async operations use the same loop:
|
|
||||||
|
|
||||||
```python
|
|
||||||
@pytest_asyncio.fixture
|
|
||||||
async def session(setup_database, event_loop):
|
|
||||||
# Force using the current event loop
|
|
||||||
engine = create_async_engine(
|
|
||||||
settings.DATABASE_URL,
|
|
||||||
echo=False,
|
|
||||||
poolclass=NullPool,
|
|
||||||
connect_args={"loop": event_loop} # Pass explicit loop
|
|
||||||
)
|
|
||||||
# ... rest of fixture ...
|
|
||||||
```
|
|
||||||
|
|
||||||
### Option 4: Upgrade pytest-asyncio
|
|
||||||
The current version (1.1.0) has known issues with loop management. Consider upgrading to the latest version which has better loop scope handling.
|
|
||||||
|
|
||||||
## Immediate Workaround
|
|
||||||
|
|
||||||
For the test to run cleanly without the teardown error, you can:
|
|
||||||
|
|
||||||
1. Add explicit cleanup in the test:
|
|
||||||
```python
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_attendee_parsing_bug(session):
|
|
||||||
# ... existing test code ...
|
|
||||||
|
|
||||||
# Explicit cleanup before fixture teardown
|
|
||||||
await session.commit() # or await session.close()
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Or suppress the teardown error (not recommended for production):
|
|
||||||
```python
|
|
||||||
@pytest.fixture
|
|
||||||
async def session(setup_database):
|
|
||||||
# ... existing setup ...
|
|
||||||
try:
|
|
||||||
yield session
|
|
||||||
await session.rollback()
|
|
||||||
except RuntimeError as e:
|
|
||||||
if "Event loop is closed" not in str(e):
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
await session.close()
|
|
||||||
```
|
|
||||||
|
|
||||||
## Recommendation
|
|
||||||
|
|
||||||
The cleanest solution is to align the loop scopes by setting both fixture and test loop scopes to "function" scope. This ensures each test gets its own clean event loop and avoids cross-contamination between tests.
|
|
||||||
@@ -14,7 +14,7 @@ Webhooks are configured at the room level with two fields:
|
|||||||
|
|
||||||
### `transcript.completed`
|
### `transcript.completed`
|
||||||
|
|
||||||
Triggered when a transcript has been fully processed, including transcription, diarization, summarization, and topic detection.
|
Triggered when a transcript has been fully processed, including transcription, diarization, summarization, topic detection and calendar event integration.
|
||||||
|
|
||||||
### `test`
|
### `test`
|
||||||
|
|
||||||
@@ -128,6 +128,27 @@ This event includes a convenient URL for accessing the transcript:
|
|||||||
"room": {
|
"room": {
|
||||||
"id": "room-789",
|
"id": "room-789",
|
||||||
"name": "Product Team Room"
|
"name": "Product Team Room"
|
||||||
|
},
|
||||||
|
"calendar_event": {
|
||||||
|
"id": "calendar-event-123",
|
||||||
|
"ics_uid": "event-123",
|
||||||
|
"title": "Q3 Product Planning Meeting",
|
||||||
|
"start_time": "2025-08-27T12:00:00Z",
|
||||||
|
"end_time": "2025-08-27T12:30:00Z",
|
||||||
|
"description": "Team discussed Q3 product roadmap, prioritizing mobile app features and API improvements.",
|
||||||
|
"location": "Conference Room 1",
|
||||||
|
"attendees": [
|
||||||
|
{
|
||||||
|
"id": "participant-1",
|
||||||
|
"name": "John Doe",
|
||||||
|
"speaker": "Speaker 1"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "participant-2",
|
||||||
|
"name": "Jane Smith",
|
||||||
|
"speaker": "Speaker 2"
|
||||||
|
}
|
||||||
|
]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ AUTH_JWT_AUDIENCE=
|
|||||||
#TRANSCRIPT_MODAL_API_KEY=xxxxx
|
#TRANSCRIPT_MODAL_API_KEY=xxxxx
|
||||||
|
|
||||||
TRANSCRIPT_BACKEND=modal
|
TRANSCRIPT_BACKEND=modal
|
||||||
TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-web.modal.run
|
TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-parakeet-web.modal.run
|
||||||
TRANSCRIPT_MODAL_API_KEY=
|
TRANSCRIPT_MODAL_API_KEY=
|
||||||
|
|
||||||
## =======================================================
|
## =======================================================
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from logging.config import fileConfig
|
|||||||
from alembic import context
|
from alembic import context
|
||||||
from sqlalchemy import engine_from_config, pool
|
from sqlalchemy import engine_from_config, pool
|
||||||
|
|
||||||
from reflector.db.base import metadata
|
from reflector.db import metadata
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
# this is the Alembic Config object, which provides
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ def upgrade() -> None:
|
|||||||
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
||||||
|
|
||||||
# Select all rows from the transcript table
|
# Select all rows from the transcript table
|
||||||
results = bind.execute(select(transcript.c.id, transcript.c.topics))
|
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
|
||||||
|
|
||||||
for row in results:
|
for row in results:
|
||||||
transcript_id = row["id"]
|
transcript_id = row["id"]
|
||||||
@@ -58,7 +58,7 @@ def downgrade() -> None:
|
|||||||
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
||||||
|
|
||||||
# Select all rows from the transcript table
|
# Select all rows from the transcript table
|
||||||
results = bind.execute(select(transcript.c.id, transcript.c.topics))
|
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
|
||||||
|
|
||||||
for row in results:
|
for row in results:
|
||||||
transcript_id = row["id"]
|
transcript_id = row["id"]
|
||||||
|
|||||||
@@ -36,7 +36,9 @@ def upgrade() -> None:
|
|||||||
|
|
||||||
# select only the one with duration = 0
|
# select only the one with duration = 0
|
||||||
results = bind.execute(
|
results = bind.execute(
|
||||||
select(transcript.c.id, transcript.c.duration).where(transcript.c.duration == 0)
|
select([transcript.c.id, transcript.c.duration]).where(
|
||||||
|
transcript.c.duration == 0
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
data_dir = Path(settings.DATA_DIR)
|
data_dir = Path(settings.DATA_DIR)
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ def upgrade() -> None:
|
|||||||
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
||||||
|
|
||||||
# Select all rows from the transcript table
|
# Select all rows from the transcript table
|
||||||
results = bind.execute(select(transcript.c.id, transcript.c.topics))
|
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
|
||||||
|
|
||||||
for row in results:
|
for row in results:
|
||||||
transcript_id = row["id"]
|
transcript_id = row["id"]
|
||||||
@@ -58,7 +58,7 @@ def downgrade() -> None:
|
|||||||
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
transcript = table("transcript", column("id", sa.String), column("topics", sa.JSON))
|
||||||
|
|
||||||
# Select all rows from the transcript table
|
# Select all rows from the transcript table
|
||||||
results = bind.execute(select(transcript.c.id, transcript.c.topics))
|
results = bind.execute(select([transcript.c.id, transcript.c.topics]))
|
||||||
|
|
||||||
for row in results:
|
for row in results:
|
||||||
transcript_id = row["id"]
|
transcript_id = row["id"]
|
||||||
|
|||||||
38
server/migrations/versions/9e3f7b2a4c8e_add_user_api_keys.py
Normal file
38
server/migrations/versions/9e3f7b2a4c8e_add_user_api_keys.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""add user api keys
|
||||||
|
|
||||||
|
Revision ID: 9e3f7b2a4c8e
|
||||||
|
Revises: dc035ff72fd5
|
||||||
|
Create Date: 2025-10-17 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "9e3f7b2a4c8e"
|
||||||
|
down_revision: Union[str, None] = "dc035ff72fd5"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"user_api_key",
|
||||||
|
sa.Column("id", sa.String(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.String(), nullable=False),
|
||||||
|
sa.Column("key_hash", sa.String(), nullable=False),
|
||||||
|
sa.Column("name", sa.String(), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
with op.batch_alter_table("user_api_key", schema=None) as batch_op:
|
||||||
|
batch_op.create_index("idx_user_api_key_hash", ["key_hash"], unique=True)
|
||||||
|
batch_op.create_index("idx_user_api_key_user_id", ["user_id"], unique=False)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("user_api_key")
|
||||||
@@ -19,8 +19,8 @@ dependencies = [
|
|||||||
"sentry-sdk[fastapi]>=1.29.2",
|
"sentry-sdk[fastapi]>=1.29.2",
|
||||||
"httpx>=0.24.1",
|
"httpx>=0.24.1",
|
||||||
"fastapi-pagination>=0.12.6",
|
"fastapi-pagination>=0.12.6",
|
||||||
"sqlalchemy>=2.0.0",
|
"databases[aiosqlite, asyncpg]>=0.7.0",
|
||||||
"asyncpg>=0.29.0",
|
"sqlalchemy<1.5",
|
||||||
"alembic>=1.11.3",
|
"alembic>=1.11.3",
|
||||||
"nltk>=3.8.1",
|
"nltk>=3.8.1",
|
||||||
"prometheus-fastapi-instrumentator>=6.1.0",
|
"prometheus-fastapi-instrumentator>=6.1.0",
|
||||||
@@ -46,7 +46,6 @@ dev = [
|
|||||||
"black>=24.1.1",
|
"black>=24.1.1",
|
||||||
"stamina>=23.1.0",
|
"stamina>=23.1.0",
|
||||||
"pyinstrument>=4.6.1",
|
"pyinstrument>=4.6.1",
|
||||||
"pytest-async-sqlalchemy>=0.2.0",
|
|
||||||
]
|
]
|
||||||
tests = [
|
tests = [
|
||||||
"pytest-cov>=4.1.0",
|
"pytest-cov>=4.1.0",
|
||||||
@@ -112,15 +111,13 @@ source = ["reflector"]
|
|||||||
|
|
||||||
[tool.pytest_env]
|
[tool.pytest_env]
|
||||||
ENVIRONMENT = "pytest"
|
ENVIRONMENT = "pytest"
|
||||||
DATABASE_URL = "postgresql+asyncpg://test_user:test_password@localhost:15432/reflector_test"
|
DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_test"
|
||||||
|
AUTH_BACKEND = "jwt"
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
[tool.pytest.ini_options]
|
||||||
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
|
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
asyncio_debug = true
|
|
||||||
asyncio_default_fixture_loop_scope = "session"
|
|
||||||
asyncio_default_test_loop_scope = "session"
|
|
||||||
markers = [
|
markers = [
|
||||||
"model_api: tests for the unified model-serving HTTP API (backend- and hardware-agnostic)",
|
"model_api: tests for the unified model-serving HTTP API (backend- and hardware-agnostic)",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -26,6 +26,8 @@ from reflector.views.transcripts_upload import router as transcripts_upload_rout
|
|||||||
from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
|
from reflector.views.transcripts_webrtc import router as transcripts_webrtc_router
|
||||||
from reflector.views.transcripts_websocket import router as transcripts_websocket_router
|
from reflector.views.transcripts_websocket import router as transcripts_websocket_router
|
||||||
from reflector.views.user import router as user_router
|
from reflector.views.user import router as user_router
|
||||||
|
from reflector.views.user_api_keys import router as user_api_keys_router
|
||||||
|
from reflector.views.user_websocket import router as user_ws_router
|
||||||
from reflector.views.whereby import router as whereby_router
|
from reflector.views.whereby import router as whereby_router
|
||||||
from reflector.views.zulip import router as zulip_router
|
from reflector.views.zulip import router as zulip_router
|
||||||
|
|
||||||
@@ -65,6 +67,12 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health():
|
||||||
|
return {"status": "healthy"}
|
||||||
|
|
||||||
|
|
||||||
# metrics
|
# metrics
|
||||||
instrumentator = Instrumentator(
|
instrumentator = Instrumentator(
|
||||||
excluded_handlers=["/docs", "/metrics"],
|
excluded_handlers=["/docs", "/metrics"],
|
||||||
@@ -84,6 +92,8 @@ app.include_router(transcripts_websocket_router, prefix="/v1")
|
|||||||
app.include_router(transcripts_webrtc_router, prefix="/v1")
|
app.include_router(transcripts_webrtc_router, prefix="/v1")
|
||||||
app.include_router(transcripts_process_router, prefix="/v1")
|
app.include_router(transcripts_process_router, prefix="/v1")
|
||||||
app.include_router(user_router, prefix="/v1")
|
app.include_router(user_router, prefix="/v1")
|
||||||
|
app.include_router(user_api_keys_router, prefix="/v1")
|
||||||
|
app.include_router(user_ws_router, prefix="/v1")
|
||||||
app.include_router(zulip_router, prefix="/v1")
|
app.include_router(zulip_router, prefix="/v1")
|
||||||
app.include_router(whereby_router, prefix="/v1")
|
app.include_router(whereby_router, prefix="/v1")
|
||||||
add_pagination(app)
|
add_pagination(app)
|
||||||
|
|||||||
@@ -1,14 +1,21 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
|
from reflector.db import get_database
|
||||||
|
|
||||||
|
|
||||||
def asynctask(f):
|
def asynctask(f):
|
||||||
@functools.wraps(f)
|
@functools.wraps(f)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
async def run_async():
|
async def run_with_db():
|
||||||
|
database = get_database()
|
||||||
|
await database.connect()
|
||||||
|
try:
|
||||||
return await f(*args, **kwargs)
|
return await f(*args, **kwargs)
|
||||||
|
finally:
|
||||||
|
await database.disconnect()
|
||||||
|
|
||||||
coro = run_async()
|
coro = run_with_db()
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
|
|||||||
@@ -1,14 +1,16 @@
|
|||||||
from typing import Annotated, Optional
|
from typing import Annotated, List, Optional
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends, HTTPException
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import APIKeyHeader, OAuth2PasswordBearer
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from reflector.db.user_api_keys import user_api_keys_controller
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False)
|
||||||
|
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||||
|
|
||||||
jwt_public_key = open(f"reflector/auth/jwt/keys/{settings.AUTH_JWT_PUBLIC_KEY}").read()
|
jwt_public_key = open(f"reflector/auth/jwt/keys/{settings.AUTH_JWT_PUBLIC_KEY}").read()
|
||||||
jwt_algorithm = settings.AUTH_JWT_ALGORITHM
|
jwt_algorithm = settings.AUTH_JWT_ALGORITHM
|
||||||
@@ -26,7 +28,7 @@ class JWTException(Exception):
|
|||||||
|
|
||||||
class UserInfo(BaseModel):
|
class UserInfo(BaseModel):
|
||||||
sub: str
|
sub: str
|
||||||
email: str
|
email: Optional[str] = None
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key):
|
||||||
return getattr(self, key)
|
return getattr(self, key)
|
||||||
@@ -58,34 +60,53 @@ def authenticated(token: Annotated[str, Depends(oauth2_scheme)]):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def current_user(
|
async def _authenticate_user(
|
||||||
token: Annotated[Optional[str], Depends(oauth2_scheme)],
|
jwt_token: Optional[str],
|
||||||
jwtauth: JWTAuth = Depends(),
|
api_key: Optional[str],
|
||||||
):
|
jwtauth: JWTAuth,
|
||||||
if token is None:
|
) -> UserInfo | None:
|
||||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
user_infos: List[UserInfo] = []
|
||||||
|
if api_key:
|
||||||
|
user_api_key = await user_api_keys_controller.verify_key(api_key)
|
||||||
|
if user_api_key:
|
||||||
|
user_infos.append(UserInfo(sub=user_api_key.user_id, email=None))
|
||||||
|
|
||||||
|
if jwt_token:
|
||||||
try:
|
try:
|
||||||
payload = jwtauth.verify_token(token)
|
payload = jwtauth.verify_token(jwt_token)
|
||||||
sub = payload["sub"]
|
sub = payload["sub"]
|
||||||
email = payload["email"]
|
email = payload["email"]
|
||||||
return UserInfo(sub=sub, email=email)
|
user_infos.append(UserInfo(sub=sub, email=email))
|
||||||
except JWTError as e:
|
except JWTError as e:
|
||||||
logger.error(f"JWT error: {e}")
|
logger.error(f"JWT error: {e}")
|
||||||
raise HTTPException(status_code=401, detail="Invalid authentication")
|
raise HTTPException(status_code=401, detail="Invalid authentication")
|
||||||
|
|
||||||
|
if len(user_infos) == 0:
|
||||||
def current_user_optional(
|
|
||||||
token: Annotated[Optional[str], Depends(oauth2_scheme)],
|
|
||||||
jwtauth: JWTAuth = Depends(),
|
|
||||||
):
|
|
||||||
# we accept no token, but if one is provided, it must be a valid one.
|
|
||||||
if token is None:
|
|
||||||
return None
|
return None
|
||||||
try:
|
|
||||||
payload = jwtauth.verify_token(token)
|
if len(set([x.sub for x in user_infos])) > 1:
|
||||||
sub = payload["sub"]
|
raise JWTException(
|
||||||
email = payload["email"]
|
status_code=401,
|
||||||
return UserInfo(sub=sub, email=email)
|
detail="Invalid authentication: more than one user provided",
|
||||||
except JWTError as e:
|
)
|
||||||
logger.error(f"JWT error: {e}")
|
|
||||||
raise HTTPException(status_code=401, detail="Invalid authentication")
|
return user_infos[0]
|
||||||
|
|
||||||
|
|
||||||
|
async def current_user(
|
||||||
|
jwt_token: Annotated[Optional[str], Depends(oauth2_scheme)],
|
||||||
|
api_key: Annotated[Optional[str], Depends(api_key_header)],
|
||||||
|
jwtauth: JWTAuth = Depends(),
|
||||||
|
):
|
||||||
|
user = await _authenticate_user(jwt_token, api_key, jwtauth)
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def current_user_optional(
|
||||||
|
jwt_token: Annotated[Optional[str], Depends(oauth2_scheme)],
|
||||||
|
api_key: Annotated[Optional[str], Depends(api_key_header)],
|
||||||
|
jwtauth: JWTAuth = Depends(),
|
||||||
|
):
|
||||||
|
return await _authenticate_user(jwt_token, api_key, jwtauth)
|
||||||
|
|||||||
@@ -1,69 +1,49 @@
|
|||||||
from typing import AsyncGenerator
|
import contextvars
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import (
|
import databases
|
||||||
AsyncEngine,
|
import sqlalchemy
|
||||||
AsyncSession,
|
|
||||||
async_sessionmaker,
|
|
||||||
create_async_engine,
|
|
||||||
)
|
|
||||||
|
|
||||||
from reflector.db.base import Base as Base
|
|
||||||
from reflector.db.base import metadata as metadata
|
|
||||||
from reflector.events import subscribers_shutdown, subscribers_startup
|
from reflector.events import subscribers_shutdown, subscribers_startup
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
_engine: AsyncEngine | None = None
|
metadata = sqlalchemy.MetaData()
|
||||||
_session_factory: async_sessionmaker[AsyncSession] | None = None
|
|
||||||
|
|
||||||
|
_database_context: contextvars.ContextVar[Optional[databases.Database]] = (
|
||||||
def get_engine() -> AsyncEngine:
|
contextvars.ContextVar("database", default=None)
|
||||||
global _engine
|
|
||||||
if _engine is None:
|
|
||||||
_engine = create_async_engine(
|
|
||||||
settings.DATABASE_URL,
|
|
||||||
echo=False,
|
|
||||||
pool_pre_ping=True,
|
|
||||||
)
|
)
|
||||||
return _engine
|
|
||||||
|
|
||||||
|
|
||||||
def get_session_factory() -> async_sessionmaker[AsyncSession]:
|
def get_database() -> databases.Database:
|
||||||
global _session_factory
|
"""Get database instance for current asyncio context"""
|
||||||
if _session_factory is None:
|
db = _database_context.get()
|
||||||
_session_factory = async_sessionmaker(
|
if db is None:
|
||||||
get_engine(),
|
db = databases.Database(settings.DATABASE_URL)
|
||||||
class_=AsyncSession,
|
_database_context.set(db)
|
||||||
expire_on_commit=False,
|
return db
|
||||||
)
|
|
||||||
return _session_factory
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_session() -> AsyncGenerator[AsyncSession, None]:
|
|
||||||
# necessary implementation to ease mocking on pytest
|
|
||||||
async with get_session_factory()() as session:
|
|
||||||
yield session
|
|
||||||
|
|
||||||
|
|
||||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
|
||||||
async for session in _get_session():
|
|
||||||
yield session
|
|
||||||
|
|
||||||
|
|
||||||
|
# import models
|
||||||
import reflector.db.calendar_events # noqa
|
import reflector.db.calendar_events # noqa
|
||||||
import reflector.db.meetings # noqa
|
import reflector.db.meetings # noqa
|
||||||
import reflector.db.recordings # noqa
|
import reflector.db.recordings # noqa
|
||||||
import reflector.db.rooms # noqa
|
import reflector.db.rooms # noqa
|
||||||
import reflector.db.transcripts # noqa
|
import reflector.db.transcripts # noqa
|
||||||
|
import reflector.db.user_api_keys # noqa
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
if "postgres" not in settings.DATABASE_URL:
|
||||||
|
raise Exception("Only postgres database is supported in reflector")
|
||||||
|
engine = sqlalchemy.create_engine(settings.DATABASE_URL, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@subscribers_startup.append
|
@subscribers_startup.append
|
||||||
async def database_connect(_):
|
async def database_connect(_):
|
||||||
get_engine()
|
database = get_database()
|
||||||
|
await database.connect()
|
||||||
|
|
||||||
|
|
||||||
@subscribers_shutdown.append
|
@subscribers_shutdown.append
|
||||||
async def database_disconnect(_):
|
async def database_disconnect(_):
|
||||||
global _engine
|
database = get_database()
|
||||||
if _engine:
|
await database.disconnect()
|
||||||
await _engine.dispose()
|
|
||||||
_engine = None
|
|
||||||
|
|||||||
@@ -1,237 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from sqlalchemy.dialects.postgresql import JSONB, TSVECTOR
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
|
||||||
|
|
||||||
|
|
||||||
class Base(AsyncAttrs, DeclarativeBase):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class TranscriptModel(Base):
|
|
||||||
__tablename__ = "transcript"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
|
||||||
name: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
status: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
locked: Mapped[Optional[bool]] = mapped_column(sa.Boolean)
|
|
||||||
duration: Mapped[Optional[float]] = mapped_column(sa.Float)
|
|
||||||
created_at: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
|
|
||||||
title: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
short_summary: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
long_summary: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
topics: Mapped[Optional[list]] = mapped_column(sa.JSON)
|
|
||||||
events: Mapped[Optional[list]] = mapped_column(sa.JSON)
|
|
||||||
participants: Mapped[Optional[list]] = mapped_column(sa.JSON)
|
|
||||||
source_language: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
target_language: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
reviewed: Mapped[bool] = mapped_column(
|
|
||||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
|
||||||
)
|
|
||||||
audio_location: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False, server_default="local"
|
|
||||||
)
|
|
||||||
user_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
share_mode: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False, server_default="private"
|
|
||||||
)
|
|
||||||
meeting_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
recording_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
zulip_message_id: Mapped[Optional[int]] = mapped_column(sa.Integer)
|
|
||||||
source_kind: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False
|
|
||||||
) # Enum will be handled separately
|
|
||||||
audio_deleted: Mapped[Optional[bool]] = mapped_column(sa.Boolean)
|
|
||||||
room_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
webvtt: Mapped[Optional[str]] = mapped_column(sa.Text)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
sa.Index("idx_transcript_recording_id", "recording_id"),
|
|
||||||
sa.Index("idx_transcript_user_id", "user_id"),
|
|
||||||
sa.Index("idx_transcript_created_at", "created_at"),
|
|
||||||
sa.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
|
|
||||||
sa.Index("idx_transcript_room_id", "room_id"),
|
|
||||||
sa.Index("idx_transcript_source_kind", "source_kind"),
|
|
||||||
sa.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
TranscriptModel.search_vector_en = sa.Column(
|
|
||||||
"search_vector_en",
|
|
||||||
TSVECTOR,
|
|
||||||
sa.Computed(
|
|
||||||
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
|
|
||||||
"setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
|
|
||||||
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
|
|
||||||
persisted=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RoomModel(Base):
|
|
||||||
__tablename__ = "room"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
|
||||||
name: Mapped[str] = mapped_column(sa.String, nullable=False, unique=True)
|
|
||||||
user_id: Mapped[str] = mapped_column(sa.String, nullable=False)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True), nullable=False
|
|
||||||
)
|
|
||||||
zulip_auto_post: Mapped[bool] = mapped_column(
|
|
||||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
|
||||||
)
|
|
||||||
zulip_stream: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
zulip_topic: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
is_locked: Mapped[bool] = mapped_column(
|
|
||||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
|
||||||
)
|
|
||||||
room_mode: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False, server_default="normal"
|
|
||||||
)
|
|
||||||
recording_type: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False, server_default="cloud"
|
|
||||||
)
|
|
||||||
recording_trigger: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False, server_default="automatic-2nd-participant"
|
|
||||||
)
|
|
||||||
is_shared: Mapped[bool] = mapped_column(
|
|
||||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
|
||||||
)
|
|
||||||
webhook_url: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
webhook_secret: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
ics_url: Mapped[Optional[str]] = mapped_column(sa.Text)
|
|
||||||
ics_fetch_interval: Mapped[Optional[int]] = mapped_column(
|
|
||||||
sa.Integer, server_default=sa.text("300")
|
|
||||||
)
|
|
||||||
ics_enabled: Mapped[bool] = mapped_column(
|
|
||||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
|
||||||
)
|
|
||||||
ics_last_sync: Mapped[Optional[datetime]] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True)
|
|
||||||
)
|
|
||||||
ics_last_etag: Mapped[Optional[str]] = mapped_column(sa.Text)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
sa.Index("idx_room_is_shared", "is_shared"),
|
|
||||||
sa.Index("idx_room_ics_enabled", "ics_enabled"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MeetingModel(Base):
|
|
||||||
__tablename__ = "meeting"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
|
||||||
room_name: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
room_url: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
host_room_url: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
start_date: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
|
|
||||||
end_date: Mapped[Optional[datetime]] = mapped_column(sa.DateTime(timezone=True))
|
|
||||||
room_id: Mapped[Optional[str]] = mapped_column(
|
|
||||||
sa.String, sa.ForeignKey("room.id", ondelete="CASCADE")
|
|
||||||
)
|
|
||||||
is_locked: Mapped[bool] = mapped_column(
|
|
||||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
|
||||||
)
|
|
||||||
room_mode: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False, server_default="normal"
|
|
||||||
)
|
|
||||||
recording_type: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False, server_default="cloud"
|
|
||||||
)
|
|
||||||
recording_trigger: Mapped[str] = mapped_column(
|
|
||||||
sa.String, nullable=False, server_default="automatic-2nd-participant"
|
|
||||||
)
|
|
||||||
num_clients: Mapped[int] = mapped_column(
|
|
||||||
sa.Integer, nullable=False, server_default=sa.text("0")
|
|
||||||
)
|
|
||||||
is_active: Mapped[bool] = mapped_column(
|
|
||||||
sa.Boolean, nullable=False, server_default=sa.text("true")
|
|
||||||
)
|
|
||||||
calendar_event_id: Mapped[Optional[str]] = mapped_column(
|
|
||||||
sa.String,
|
|
||||||
sa.ForeignKey(
|
|
||||||
"calendar_event.id",
|
|
||||||
ondelete="SET NULL",
|
|
||||||
name="fk_meeting_calendar_event_id",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
calendar_metadata: Mapped[Optional[dict]] = mapped_column(JSONB)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
sa.Index("idx_meeting_room_id", "room_id"),
|
|
||||||
sa.Index("idx_meeting_calendar_event", "calendar_event_id"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MeetingConsentModel(Base):
|
|
||||||
__tablename__ = "meeting_consent"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
|
||||||
meeting_id: Mapped[str] = mapped_column(
|
|
||||||
sa.String, sa.ForeignKey("meeting.id", ondelete="CASCADE"), nullable=False
|
|
||||||
)
|
|
||||||
user_id: Mapped[Optional[str]] = mapped_column(sa.String)
|
|
||||||
consent_given: Mapped[bool] = mapped_column(sa.Boolean, nullable=False)
|
|
||||||
consent_timestamp: Mapped[datetime] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True), nullable=False
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RecordingModel(Base):
|
|
||||||
__tablename__ = "recording"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
|
||||||
meeting_id: Mapped[str] = mapped_column(
|
|
||||||
sa.String, sa.ForeignKey("meeting.id", ondelete="CASCADE"), nullable=False
|
|
||||||
)
|
|
||||||
url: Mapped[str] = mapped_column(sa.String, nullable=False)
|
|
||||||
object_key: Mapped[str] = mapped_column(sa.String, nullable=False)
|
|
||||||
duration: Mapped[Optional[float]] = mapped_column(sa.Float)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True), nullable=False
|
|
||||||
)
|
|
||||||
|
|
||||||
__table_args__ = (sa.Index("idx_recording_meeting_id", "meeting_id"),)
|
|
||||||
|
|
||||||
|
|
||||||
class CalendarEventModel(Base):
|
|
||||||
__tablename__ = "calendar_event"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(sa.String, primary_key=True)
|
|
||||||
room_id: Mapped[str] = mapped_column(
|
|
||||||
sa.String, sa.ForeignKey("room.id", ondelete="CASCADE"), nullable=False
|
|
||||||
)
|
|
||||||
ics_uid: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
|
||||||
title: Mapped[Optional[str]] = mapped_column(sa.Text)
|
|
||||||
description: Mapped[Optional[str]] = mapped_column(sa.Text)
|
|
||||||
start_time: Mapped[datetime] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True), nullable=False
|
|
||||||
)
|
|
||||||
end_time: Mapped[datetime] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True), nullable=False
|
|
||||||
)
|
|
||||||
attendees: Mapped[Optional[dict]] = mapped_column(JSONB)
|
|
||||||
location: Mapped[Optional[str]] = mapped_column(sa.Text)
|
|
||||||
ics_raw_data: Mapped[Optional[str]] = mapped_column(sa.Text)
|
|
||||||
last_synced: Mapped[datetime] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True), nullable=False
|
|
||||||
)
|
|
||||||
is_deleted: Mapped[bool] = mapped_column(
|
|
||||||
sa.Boolean, nullable=False, server_default=sa.text("false")
|
|
||||||
)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True), nullable=False
|
|
||||||
)
|
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
|
||||||
sa.DateTime(timezone=True), nullable=False
|
|
||||||
)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
sa.Index("idx_calendar_event_room_start", "room_id", "start_time"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
metadata = Base.metadata
|
|
||||||
@@ -2,17 +2,45 @@ from datetime import datetime, timedelta, timezone
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import delete, select, update
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from reflector.db.base import CalendarEventModel
|
from reflector.db import get_database, metadata
|
||||||
from reflector.utils import generate_uuid4
|
from reflector.utils import generate_uuid4
|
||||||
|
|
||||||
|
calendar_events = sa.Table(
|
||||||
|
"calendar_event",
|
||||||
|
metadata,
|
||||||
|
sa.Column("id", sa.String, primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"room_id",
|
||||||
|
sa.String,
|
||||||
|
sa.ForeignKey("room.id", ondelete="CASCADE", name="fk_calendar_event_room_id"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("ics_uid", sa.Text, nullable=False),
|
||||||
|
sa.Column("title", sa.Text),
|
||||||
|
sa.Column("description", sa.Text),
|
||||||
|
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("end_time", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("attendees", JSONB),
|
||||||
|
sa.Column("location", sa.Text),
|
||||||
|
sa.Column("ics_raw_data", sa.Text),
|
||||||
|
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("is_deleted", sa.Boolean, nullable=False, server_default=sa.false()),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.UniqueConstraint("room_id", "ics_uid", name="uq_room_calendar_event"),
|
||||||
|
sa.Index("idx_calendar_event_room_start", "room_id", "start_time"),
|
||||||
|
sa.Index(
|
||||||
|
"idx_calendar_event_deleted",
|
||||||
|
"is_deleted",
|
||||||
|
postgresql_where=sa.text("NOT is_deleted"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class CalendarEvent(BaseModel):
|
class CalendarEvent(BaseModel):
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: str = Field(default_factory=generate_uuid4)
|
id: str = Field(default_factory=generate_uuid4)
|
||||||
room_id: str
|
room_id: str
|
||||||
ics_uid: str
|
ics_uid: str
|
||||||
@@ -30,157 +58,129 @@ class CalendarEvent(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class CalendarEventController:
|
class CalendarEventController:
|
||||||
async def get_upcoming_events(
|
|
||||||
self,
|
|
||||||
session: AsyncSession,
|
|
||||||
room_id: str,
|
|
||||||
current_time: datetime,
|
|
||||||
buffer_minutes: int = 15,
|
|
||||||
) -> list[CalendarEvent]:
|
|
||||||
buffer_time = current_time + timedelta(minutes=buffer_minutes)
|
|
||||||
|
|
||||||
query = (
|
|
||||||
select(CalendarEventModel)
|
|
||||||
.where(
|
|
||||||
sa.and_(
|
|
||||||
CalendarEventModel.room_id == room_id,
|
|
||||||
CalendarEventModel.start_time <= buffer_time,
|
|
||||||
CalendarEventModel.end_time > current_time,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
.order_by(CalendarEventModel.start_time)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await session.execute(query)
|
|
||||||
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
|
|
||||||
|
|
||||||
async def get_by_id(
|
|
||||||
self, session: AsyncSession, event_id: str
|
|
||||||
) -> CalendarEvent | None:
|
|
||||||
query = select(CalendarEventModel).where(CalendarEventModel.id == event_id)
|
|
||||||
result = await session.execute(query)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if not row:
|
|
||||||
return None
|
|
||||||
return CalendarEvent.model_validate(row)
|
|
||||||
|
|
||||||
async def get_by_ics_uid(
|
|
||||||
self, session: AsyncSession, room_id: str, ics_uid: str
|
|
||||||
) -> CalendarEvent | None:
|
|
||||||
query = select(CalendarEventModel).where(
|
|
||||||
sa.and_(
|
|
||||||
CalendarEventModel.room_id == room_id,
|
|
||||||
CalendarEventModel.ics_uid == ics_uid,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
result = await session.execute(query)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if not row:
|
|
||||||
return None
|
|
||||||
return CalendarEvent.model_validate(row)
|
|
||||||
|
|
||||||
async def upsert(
|
|
||||||
self, session: AsyncSession, event: CalendarEvent
|
|
||||||
) -> CalendarEvent:
|
|
||||||
existing = await self.get_by_ics_uid(session, event.room_id, event.ics_uid)
|
|
||||||
|
|
||||||
if existing:
|
|
||||||
event.updated_at = datetime.now(timezone.utc)
|
|
||||||
query = (
|
|
||||||
update(CalendarEventModel)
|
|
||||||
.where(CalendarEventModel.id == existing.id)
|
|
||||||
.values(**event.model_dump(exclude={"id"}))
|
|
||||||
)
|
|
||||||
await session.execute(query)
|
|
||||||
await session.commit()
|
|
||||||
return event
|
|
||||||
else:
|
|
||||||
new_event = CalendarEventModel(**event.model_dump())
|
|
||||||
session.add(new_event)
|
|
||||||
await session.commit()
|
|
||||||
return event
|
|
||||||
|
|
||||||
async def delete_old_events(
|
|
||||||
self, session: AsyncSession, room_id: str, cutoff_date: datetime
|
|
||||||
) -> int:
|
|
||||||
query = delete(CalendarEventModel).where(
|
|
||||||
sa.and_(
|
|
||||||
CalendarEventModel.room_id == room_id,
|
|
||||||
CalendarEventModel.end_time < cutoff_date,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
result = await session.execute(query)
|
|
||||||
await session.commit()
|
|
||||||
return result.rowcount
|
|
||||||
|
|
||||||
async def delete_events_not_in_list(
|
|
||||||
self, session: AsyncSession, room_id: str, keep_ics_uids: list[str]
|
|
||||||
) -> int:
|
|
||||||
if not keep_ics_uids:
|
|
||||||
query = delete(CalendarEventModel).where(
|
|
||||||
CalendarEventModel.room_id == room_id
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
query = delete(CalendarEventModel).where(
|
|
||||||
sa.and_(
|
|
||||||
CalendarEventModel.room_id == room_id,
|
|
||||||
CalendarEventModel.ics_uid.notin_(keep_ics_uids),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await session.execute(query)
|
|
||||||
await session.commit()
|
|
||||||
return result.rowcount
|
|
||||||
|
|
||||||
async def get_by_room(
|
async def get_by_room(
|
||||||
self, session: AsyncSession, room_id: str, include_deleted: bool = True
|
self,
|
||||||
|
room_id: str,
|
||||||
|
include_deleted: bool = False,
|
||||||
|
start_after: datetime | None = None,
|
||||||
|
end_before: datetime | None = None,
|
||||||
) -> list[CalendarEvent]:
|
) -> list[CalendarEvent]:
|
||||||
query = select(CalendarEventModel).where(CalendarEventModel.room_id == room_id)
|
query = calendar_events.select().where(calendar_events.c.room_id == room_id)
|
||||||
|
|
||||||
if not include_deleted:
|
if not include_deleted:
|
||||||
query = query.where(CalendarEventModel.is_deleted == False)
|
query = query.where(calendar_events.c.is_deleted == False)
|
||||||
result = await session.execute(query)
|
|
||||||
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
|
if start_after:
|
||||||
|
query = query.where(calendar_events.c.start_time >= start_after)
|
||||||
|
|
||||||
|
if end_before:
|
||||||
|
query = query.where(calendar_events.c.end_time <= end_before)
|
||||||
|
|
||||||
|
query = query.order_by(calendar_events.c.start_time.asc())
|
||||||
|
|
||||||
|
results = await get_database().fetch_all(query)
|
||||||
|
return [CalendarEvent(**result) for result in results]
|
||||||
|
|
||||||
async def get_upcoming(
|
async def get_upcoming(
|
||||||
self, session: AsyncSession, room_id: str, minutes_ahead: int = 120
|
self, room_id: str, minutes_ahead: int = 120
|
||||||
) -> list[CalendarEvent]:
|
) -> list[CalendarEvent]:
|
||||||
|
"""Get upcoming events for a room within the specified minutes, including currently happening events."""
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
buffer_time = now + timedelta(minutes=minutes_ahead)
|
future_time = now + timedelta(minutes=minutes_ahead)
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
select(CalendarEventModel)
|
calendar_events.select()
|
||||||
.where(
|
.where(
|
||||||
sa.and_(
|
sa.and_(
|
||||||
CalendarEventModel.room_id == room_id,
|
calendar_events.c.room_id == room_id,
|
||||||
CalendarEventModel.start_time <= buffer_time,
|
calendar_events.c.is_deleted == False,
|
||||||
CalendarEventModel.end_time > now,
|
calendar_events.c.start_time <= future_time,
|
||||||
CalendarEventModel.is_deleted == False,
|
calendar_events.c.end_time >= now,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.order_by(CalendarEventModel.start_time)
|
.order_by(calendar_events.c.start_time.asc())
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await session.execute(query)
|
results = await get_database().fetch_all(query)
|
||||||
return [CalendarEvent.model_validate(row) for row in result.scalars().all()]
|
return [CalendarEvent(**result) for result in results]
|
||||||
|
|
||||||
|
async def get_by_id(self, event_id: str) -> CalendarEvent | None:
|
||||||
|
query = calendar_events.select().where(calendar_events.c.id == event_id)
|
||||||
|
result = await get_database().fetch_one(query)
|
||||||
|
return CalendarEvent(**result) if result else None
|
||||||
|
|
||||||
|
async def get_by_ics_uid(self, room_id: str, ics_uid: str) -> CalendarEvent | None:
|
||||||
|
query = calendar_events.select().where(
|
||||||
|
sa.and_(
|
||||||
|
calendar_events.c.room_id == room_id,
|
||||||
|
calendar_events.c.ics_uid == ics_uid,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
result = await get_database().fetch_one(query)
|
||||||
|
return CalendarEvent(**result) if result else None
|
||||||
|
|
||||||
|
async def upsert(self, event: CalendarEvent) -> CalendarEvent:
|
||||||
|
existing = await self.get_by_ics_uid(event.room_id, event.ics_uid)
|
||||||
|
|
||||||
|
if existing:
|
||||||
|
event.id = existing.id
|
||||||
|
event.created_at = existing.created_at
|
||||||
|
event.updated_at = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
query = (
|
||||||
|
calendar_events.update()
|
||||||
|
.where(calendar_events.c.id == existing.id)
|
||||||
|
.values(**event.model_dump())
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query = calendar_events.insert().values(**event.model_dump())
|
||||||
|
|
||||||
|
await get_database().execute(query)
|
||||||
|
return event
|
||||||
|
|
||||||
async def soft_delete_missing(
|
async def soft_delete_missing(
|
||||||
self, session: AsyncSession, room_id: str, current_ics_uids: list[str]
|
self, room_id: str, current_ics_uids: list[str]
|
||||||
) -> int:
|
) -> int:
|
||||||
query = (
|
"""Soft delete future events that are no longer in the calendar."""
|
||||||
update(CalendarEventModel)
|
now = datetime.now(timezone.utc)
|
||||||
.where(
|
|
||||||
|
select_query = calendar_events.select().where(
|
||||||
sa.and_(
|
sa.and_(
|
||||||
CalendarEventModel.room_id == room_id,
|
calendar_events.c.room_id == room_id,
|
||||||
CalendarEventModel.ics_uid.notin_(current_ics_uids)
|
calendar_events.c.start_time > now,
|
||||||
|
calendar_events.c.is_deleted == False,
|
||||||
|
calendar_events.c.ics_uid.notin_(current_ics_uids)
|
||||||
if current_ics_uids
|
if current_ics_uids
|
||||||
else True,
|
else True,
|
||||||
CalendarEventModel.end_time > datetime.now(timezone.utc),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.values(is_deleted=True)
|
|
||||||
|
to_delete = await get_database().fetch_all(select_query)
|
||||||
|
delete_count = len(to_delete)
|
||||||
|
|
||||||
|
if delete_count > 0:
|
||||||
|
update_query = (
|
||||||
|
calendar_events.update()
|
||||||
|
.where(
|
||||||
|
sa.and_(
|
||||||
|
calendar_events.c.room_id == room_id,
|
||||||
|
calendar_events.c.start_time > now,
|
||||||
|
calendar_events.c.is_deleted == False,
|
||||||
|
calendar_events.c.ics_uid.notin_(current_ics_uids)
|
||||||
|
if current_ics_uids
|
||||||
|
else True,
|
||||||
)
|
)
|
||||||
result = await session.execute(query)
|
)
|
||||||
await session.commit()
|
.values(is_deleted=True, updated_at=now)
|
||||||
|
)
|
||||||
|
|
||||||
|
await get_database().execute(update_query)
|
||||||
|
|
||||||
|
return delete_count
|
||||||
|
|
||||||
|
async def delete_by_room(self, room_id: str) -> int:
|
||||||
|
query = calendar_events.delete().where(calendar_events.c.room_id == room_id)
|
||||||
|
result = await get_database().execute(query)
|
||||||
return result.rowcount
|
return result.rowcount
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -2,18 +2,80 @@ from datetime import datetime
|
|||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select, update
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from reflector.db.base import MeetingConsentModel, MeetingModel
|
from reflector.db import get_database, metadata
|
||||||
from reflector.db.rooms import Room
|
from reflector.db.rooms import Room
|
||||||
from reflector.utils import generate_uuid4
|
from reflector.utils import generate_uuid4
|
||||||
|
|
||||||
|
meetings = sa.Table(
|
||||||
|
"meeting",
|
||||||
|
metadata,
|
||||||
|
sa.Column("id", sa.String, primary_key=True),
|
||||||
|
sa.Column("room_name", sa.String),
|
||||||
|
sa.Column("room_url", sa.String),
|
||||||
|
sa.Column("host_room_url", sa.String),
|
||||||
|
sa.Column("start_date", sa.DateTime(timezone=True)),
|
||||||
|
sa.Column("end_date", sa.DateTime(timezone=True)),
|
||||||
|
sa.Column(
|
||||||
|
"room_id",
|
||||||
|
sa.String,
|
||||||
|
sa.ForeignKey("room.id", ondelete="CASCADE"),
|
||||||
|
nullable=True,
|
||||||
|
),
|
||||||
|
sa.Column("is_locked", sa.Boolean, nullable=False, server_default=sa.false()),
|
||||||
|
sa.Column("room_mode", sa.String, nullable=False, server_default="normal"),
|
||||||
|
sa.Column("recording_type", sa.String, nullable=False, server_default="cloud"),
|
||||||
|
sa.Column(
|
||||||
|
"recording_trigger",
|
||||||
|
sa.String,
|
||||||
|
nullable=False,
|
||||||
|
server_default="automatic-2nd-participant",
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"num_clients",
|
||||||
|
sa.Integer,
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("0"),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"is_active",
|
||||||
|
sa.Boolean,
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.true(),
|
||||||
|
),
|
||||||
|
sa.Column(
|
||||||
|
"calendar_event_id",
|
||||||
|
sa.String,
|
||||||
|
sa.ForeignKey(
|
||||||
|
"calendar_event.id",
|
||||||
|
ondelete="SET NULL",
|
||||||
|
name="fk_meeting_calendar_event_id",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
sa.Column("calendar_metadata", JSONB),
|
||||||
|
sa.Index("idx_meeting_room_id", "room_id"),
|
||||||
|
sa.Index("idx_meeting_calendar_event", "calendar_event_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
meeting_consent = sa.Table(
|
||||||
|
"meeting_consent",
|
||||||
|
metadata,
|
||||||
|
sa.Column("id", sa.String, primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"meeting_id",
|
||||||
|
sa.String,
|
||||||
|
sa.ForeignKey("meeting.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("user_id", sa.String),
|
||||||
|
sa.Column("consent_given", sa.Boolean, nullable=False),
|
||||||
|
sa.Column("consent_timestamp", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MeetingConsent(BaseModel):
|
class MeetingConsent(BaseModel):
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: str = Field(default_factory=generate_uuid4)
|
id: str = Field(default_factory=generate_uuid4)
|
||||||
meeting_id: str
|
meeting_id: str
|
||||||
user_id: str | None = None
|
user_id: str | None = None
|
||||||
@@ -22,8 +84,6 @@ class MeetingConsent(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Meeting(BaseModel):
|
class Meeting(BaseModel):
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
room_name: str
|
room_name: str
|
||||||
room_url: str
|
room_url: str
|
||||||
@@ -46,7 +106,6 @@ class Meeting(BaseModel):
|
|||||||
class MeetingController:
|
class MeetingController:
|
||||||
async def create(
|
async def create(
|
||||||
self,
|
self,
|
||||||
session: AsyncSession,
|
|
||||||
id: str,
|
id: str,
|
||||||
room_name: str,
|
room_name: str,
|
||||||
room_url: str,
|
room_url: str,
|
||||||
@@ -72,198 +131,170 @@ class MeetingController:
|
|||||||
calendar_event_id=calendar_event_id,
|
calendar_event_id=calendar_event_id,
|
||||||
calendar_metadata=calendar_metadata,
|
calendar_metadata=calendar_metadata,
|
||||||
)
|
)
|
||||||
new_meeting = MeetingModel(**meeting.model_dump())
|
query = meetings.insert().values(**meeting.model_dump())
|
||||||
session.add(new_meeting)
|
await get_database().execute(query)
|
||||||
await session.commit()
|
|
||||||
return meeting
|
return meeting
|
||||||
|
|
||||||
async def get_all_active(self, session: AsyncSession) -> list[Meeting]:
|
async def get_all_active(self) -> list[Meeting]:
|
||||||
query = select(MeetingModel).where(MeetingModel.is_active)
|
query = meetings.select().where(meetings.c.is_active)
|
||||||
result = await session.execute(query)
|
return await get_database().fetch_all(query)
|
||||||
return [Meeting.model_validate(row) for row in result.scalars().all()]
|
|
||||||
|
|
||||||
async def get_by_room_name(
|
async def get_by_room_name(
|
||||||
self,
|
self,
|
||||||
session: AsyncSession,
|
|
||||||
room_name: str,
|
room_name: str,
|
||||||
) -> Meeting | None:
|
) -> Meeting | None:
|
||||||
"""
|
"""
|
||||||
Get a meeting by room name.
|
Get a meeting by room name.
|
||||||
For backward compatibility, returns the most recent meeting.
|
For backward compatibility, returns the most recent meeting.
|
||||||
"""
|
"""
|
||||||
|
end_date = getattr(meetings.c, "end_date")
|
||||||
query = (
|
query = (
|
||||||
select(MeetingModel)
|
meetings.select()
|
||||||
.where(MeetingModel.room_name == room_name)
|
.where(meetings.c.room_name == room_name)
|
||||||
.order_by(MeetingModel.end_date.desc())
|
.order_by(end_date.desc())
|
||||||
)
|
)
|
||||||
result = await session.execute(query)
|
result = await get_database().fetch_one(query)
|
||||||
row = result.scalar_one_or_none()
|
if not result:
|
||||||
if not row:
|
|
||||||
return None
|
return None
|
||||||
return Meeting.model_validate(row)
|
|
||||||
|
|
||||||
async def get_active(
|
return Meeting(**result)
|
||||||
self, session: AsyncSession, room: Room, current_time: datetime
|
|
||||||
) -> Meeting | None:
|
async def get_active(self, room: Room, current_time: datetime) -> Meeting | None:
|
||||||
"""
|
"""
|
||||||
Get latest active meeting for a room.
|
Get latest active meeting for a room.
|
||||||
For backward compatibility, returns the most recent active meeting.
|
For backward compatibility, returns the most recent active meeting.
|
||||||
"""
|
"""
|
||||||
|
end_date = getattr(meetings.c, "end_date")
|
||||||
query = (
|
query = (
|
||||||
select(MeetingModel)
|
meetings.select()
|
||||||
.where(
|
.where(
|
||||||
sa.and_(
|
sa.and_(
|
||||||
MeetingModel.room_id == room.id,
|
meetings.c.room_id == room.id,
|
||||||
MeetingModel.end_date > current_time,
|
meetings.c.end_date > current_time,
|
||||||
MeetingModel.is_active,
|
meetings.c.is_active,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.order_by(MeetingModel.end_date.desc())
|
.order_by(end_date.desc())
|
||||||
)
|
)
|
||||||
result = await session.execute(query)
|
result = await get_database().fetch_one(query)
|
||||||
row = result.scalar_one_or_none()
|
if not result:
|
||||||
if not row:
|
|
||||||
return None
|
return None
|
||||||
return Meeting.model_validate(row)
|
|
||||||
|
return Meeting(**result)
|
||||||
|
|
||||||
async def get_all_active_for_room(
|
async def get_all_active_for_room(
|
||||||
self, session: AsyncSession, room: Room, current_time: datetime
|
self, room: Room, current_time: datetime
|
||||||
) -> list[Meeting]:
|
) -> list[Meeting]:
|
||||||
|
end_date = getattr(meetings.c, "end_date")
|
||||||
query = (
|
query = (
|
||||||
select(MeetingModel)
|
meetings.select()
|
||||||
.where(
|
.where(
|
||||||
sa.and_(
|
sa.and_(
|
||||||
MeetingModel.room_id == room.id,
|
meetings.c.room_id == room.id,
|
||||||
MeetingModel.end_date > current_time,
|
meetings.c.end_date > current_time,
|
||||||
MeetingModel.is_active,
|
meetings.c.is_active,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
.order_by(MeetingModel.end_date.desc())
|
.order_by(end_date.desc())
|
||||||
)
|
)
|
||||||
result = await session.execute(query)
|
results = await get_database().fetch_all(query)
|
||||||
return [Meeting.model_validate(row) for row in result.scalars().all()]
|
return [Meeting(**result) for result in results]
|
||||||
|
|
||||||
async def get_active_by_calendar_event(
|
async def get_active_by_calendar_event(
|
||||||
self,
|
self, room: Room, calendar_event_id: str, current_time: datetime
|
||||||
session: AsyncSession,
|
|
||||||
room: Room,
|
|
||||||
calendar_event_id: str,
|
|
||||||
current_time: datetime,
|
|
||||||
) -> Meeting | None:
|
) -> Meeting | None:
|
||||||
"""
|
"""
|
||||||
Get active meeting for a specific calendar event.
|
Get active meeting for a specific calendar event.
|
||||||
"""
|
"""
|
||||||
query = select(MeetingModel).where(
|
query = meetings.select().where(
|
||||||
sa.and_(
|
sa.and_(
|
||||||
MeetingModel.room_id == room.id,
|
meetings.c.room_id == room.id,
|
||||||
MeetingModel.calendar_event_id == calendar_event_id,
|
meetings.c.calendar_event_id == calendar_event_id,
|
||||||
MeetingModel.end_date > current_time,
|
meetings.c.end_date > current_time,
|
||||||
MeetingModel.is_active,
|
meetings.c.is_active,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
result = await session.execute(query)
|
result = await get_database().fetch_one(query)
|
||||||
row = result.scalar_one_or_none()
|
if not result:
|
||||||
if not row:
|
|
||||||
return None
|
return None
|
||||||
return Meeting.model_validate(row)
|
return Meeting(**result)
|
||||||
|
|
||||||
async def get_by_id(
|
async def get_by_id(self, meeting_id: str, **kwargs) -> Meeting | None:
|
||||||
self, session: AsyncSession, meeting_id: str, **kwargs
|
query = meetings.select().where(meetings.c.id == meeting_id)
|
||||||
) -> Meeting | None:
|
result = await get_database().fetch_one(query)
|
||||||
query = select(MeetingModel).where(MeetingModel.id == meeting_id)
|
if not result:
|
||||||
result = await session.execute(query)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if not row:
|
|
||||||
return None
|
return None
|
||||||
return Meeting.model_validate(row)
|
return Meeting(**result)
|
||||||
|
|
||||||
async def get_by_calendar_event(
|
async def get_by_calendar_event(self, calendar_event_id: str) -> Meeting | None:
|
||||||
self, session: AsyncSession, calendar_event_id: str
|
query = meetings.select().where(
|
||||||
) -> Meeting | None:
|
meetings.c.calendar_event_id == calendar_event_id
|
||||||
query = select(MeetingModel).where(
|
|
||||||
MeetingModel.calendar_event_id == calendar_event_id
|
|
||||||
)
|
)
|
||||||
result = await session.execute(query)
|
result = await get_database().fetch_one(query)
|
||||||
row = result.scalar_one_or_none()
|
if not result:
|
||||||
if not row:
|
|
||||||
return None
|
return None
|
||||||
return Meeting.model_validate(row)
|
return Meeting(**result)
|
||||||
|
|
||||||
async def update_meeting(self, session: AsyncSession, meeting_id: str, **kwargs):
|
async def update_meeting(self, meeting_id: str, **kwargs):
|
||||||
query = (
|
query = meetings.update().where(meetings.c.id == meeting_id).values(**kwargs)
|
||||||
update(MeetingModel).where(MeetingModel.id == meeting_id).values(**kwargs)
|
await get_database().execute(query)
|
||||||
)
|
|
||||||
await session.execute(query)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
|
|
||||||
class MeetingConsentController:
|
class MeetingConsentController:
|
||||||
async def get_by_meeting_id(
|
async def get_by_meeting_id(self, meeting_id: str) -> list[MeetingConsent]:
|
||||||
self, session: AsyncSession, meeting_id: str
|
query = meeting_consent.select().where(
|
||||||
) -> list[MeetingConsent]:
|
meeting_consent.c.meeting_id == meeting_id
|
||||||
query = select(MeetingConsentModel).where(
|
|
||||||
MeetingConsentModel.meeting_id == meeting_id
|
|
||||||
)
|
)
|
||||||
result = await session.execute(query)
|
results = await get_database().fetch_all(query)
|
||||||
return [MeetingConsent.model_validate(row) for row in result.scalars().all()]
|
return [MeetingConsent(**result) for result in results]
|
||||||
|
|
||||||
async def get_by_meeting_and_user(
|
async def get_by_meeting_and_user(
|
||||||
self, session: AsyncSession, meeting_id: str, user_id: str
|
self, meeting_id: str, user_id: str
|
||||||
) -> MeetingConsent | None:
|
) -> MeetingConsent | None:
|
||||||
"""Get existing consent for a specific user and meeting"""
|
"""Get existing consent for a specific user and meeting"""
|
||||||
query = select(MeetingConsentModel).where(
|
query = meeting_consent.select().where(
|
||||||
sa.and_(
|
meeting_consent.c.meeting_id == meeting_id,
|
||||||
MeetingConsentModel.meeting_id == meeting_id,
|
meeting_consent.c.user_id == user_id,
|
||||||
MeetingConsentModel.user_id == user_id,
|
|
||||||
)
|
)
|
||||||
)
|
result = await get_database().fetch_one(query)
|
||||||
result = await session.execute(query)
|
if result is None:
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is None:
|
|
||||||
return None
|
return None
|
||||||
return MeetingConsent.model_validate(row)
|
return MeetingConsent(**result)
|
||||||
|
|
||||||
async def upsert(
|
async def upsert(self, consent: MeetingConsent) -> MeetingConsent:
|
||||||
self, session: AsyncSession, consent: MeetingConsent
|
|
||||||
) -> MeetingConsent:
|
|
||||||
if consent.user_id:
|
if consent.user_id:
|
||||||
# For authenticated users, check if consent already exists
|
# For authenticated users, check if consent already exists
|
||||||
# not transactional but we're ok with that; the consents ain't deleted anyways
|
# not transactional but we're ok with that; the consents ain't deleted anyways
|
||||||
existing = await self.get_by_meeting_and_user(
|
existing = await self.get_by_meeting_and_user(
|
||||||
session, consent.meeting_id, consent.user_id
|
consent.meeting_id, consent.user_id
|
||||||
)
|
)
|
||||||
if existing:
|
if existing:
|
||||||
query = (
|
query = (
|
||||||
update(MeetingConsentModel)
|
meeting_consent.update()
|
||||||
.where(MeetingConsentModel.id == existing.id)
|
.where(meeting_consent.c.id == existing.id)
|
||||||
.values(
|
.values(
|
||||||
consent_given=consent.consent_given,
|
consent_given=consent.consent_given,
|
||||||
consent_timestamp=consent.consent_timestamp,
|
consent_timestamp=consent.consent_timestamp,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await session.execute(query)
|
await get_database().execute(query)
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
existing.consent_given = consent.consent_given
|
existing.consent_given = consent.consent_given
|
||||||
existing.consent_timestamp = consent.consent_timestamp
|
existing.consent_timestamp = consent.consent_timestamp
|
||||||
return existing
|
return existing
|
||||||
|
|
||||||
new_consent = MeetingConsentModel(**consent.model_dump())
|
query = meeting_consent.insert().values(**consent.model_dump())
|
||||||
session.add(new_consent)
|
await get_database().execute(query)
|
||||||
await session.commit()
|
|
||||||
return consent
|
return consent
|
||||||
|
|
||||||
async def has_any_denial(self, session: AsyncSession, meeting_id: str) -> bool:
|
async def has_any_denial(self, meeting_id: str) -> bool:
|
||||||
"""Check if any participant denied consent for this meeting"""
|
"""Check if any participant denied consent for this meeting"""
|
||||||
query = select(MeetingConsentModel).where(
|
query = meeting_consent.select().where(
|
||||||
sa.and_(
|
meeting_consent.c.meeting_id == meeting_id,
|
||||||
MeetingConsentModel.meeting_id == meeting_id,
|
meeting_consent.c.consent_given.is_(False),
|
||||||
MeetingConsentModel.consent_given.is_(False),
|
|
||||||
)
|
)
|
||||||
)
|
result = await get_database().fetch_one(query)
|
||||||
result = await session.execute(query)
|
return result is not None
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
return row is not None
|
|
||||||
|
|
||||||
|
|
||||||
meetings_controller = MeetingController()
|
meetings_controller = MeetingController()
|
||||||
|
|||||||
@@ -1,79 +1,61 @@
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
import sqlalchemy as sa
|
||||||
from sqlalchemy import delete, select
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from reflector.db.base import RecordingModel
|
from reflector.db import get_database, metadata
|
||||||
from reflector.utils import generate_uuid4
|
from reflector.utils import generate_uuid4
|
||||||
|
|
||||||
|
recordings = sa.Table(
|
||||||
|
"recording",
|
||||||
|
metadata,
|
||||||
|
sa.Column("id", sa.String, primary_key=True),
|
||||||
|
sa.Column("bucket_name", sa.String, nullable=False),
|
||||||
|
sa.Column("object_key", sa.String, nullable=False),
|
||||||
|
sa.Column("recorded_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"status",
|
||||||
|
sa.String,
|
||||||
|
nullable=False,
|
||||||
|
server_default="pending",
|
||||||
|
),
|
||||||
|
sa.Column("meeting_id", sa.String),
|
||||||
|
sa.Index("idx_recording_meeting_id", "meeting_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Recording(BaseModel):
|
class Recording(BaseModel):
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: str = Field(default_factory=generate_uuid4)
|
id: str = Field(default_factory=generate_uuid4)
|
||||||
meeting_id: str
|
bucket_name: str
|
||||||
url: str
|
|
||||||
object_key: str
|
object_key: str
|
||||||
duration: float | None = None
|
recorded_at: datetime
|
||||||
created_at: datetime
|
status: Literal["pending", "processing", "completed", "failed"] = "pending"
|
||||||
|
meeting_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class RecordingController:
|
class RecordingController:
|
||||||
async def create(
|
async def create(self, recording: Recording):
|
||||||
self,
|
query = recordings.insert().values(**recording.model_dump())
|
||||||
session: AsyncSession,
|
await get_database().execute(query)
|
||||||
meeting_id: str,
|
|
||||||
url: str,
|
|
||||||
object_key: str,
|
|
||||||
duration: float | None = None,
|
|
||||||
created_at: datetime | None = None,
|
|
||||||
):
|
|
||||||
if created_at is None:
|
|
||||||
created_at = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
recording = Recording(
|
|
||||||
meeting_id=meeting_id,
|
|
||||||
url=url,
|
|
||||||
object_key=object_key,
|
|
||||||
duration=duration,
|
|
||||||
created_at=created_at,
|
|
||||||
)
|
|
||||||
new_recording = RecordingModel(**recording.model_dump())
|
|
||||||
session.add(new_recording)
|
|
||||||
await session.commit()
|
|
||||||
return recording
|
return recording
|
||||||
|
|
||||||
async def get_by_id(
|
async def get_by_id(self, id: str) -> Recording:
|
||||||
self, session: AsyncSession, recording_id: str
|
query = recordings.select().where(recordings.c.id == id)
|
||||||
) -> Recording | None:
|
result = await get_database().fetch_one(query)
|
||||||
"""
|
return Recording(**result) if result else None
|
||||||
Get a recording by id
|
|
||||||
"""
|
|
||||||
query = select(RecordingModel).where(RecordingModel.id == recording_id)
|
|
||||||
result = await session.execute(query)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if not row:
|
|
||||||
return None
|
|
||||||
return Recording.model_validate(row)
|
|
||||||
|
|
||||||
async def get_by_meeting_id(
|
async def get_by_object_key(self, bucket_name: str, object_key: str) -> Recording:
|
||||||
self, session: AsyncSession, meeting_id: str
|
query = recordings.select().where(
|
||||||
) -> list[Recording]:
|
recordings.c.bucket_name == bucket_name,
|
||||||
"""
|
recordings.c.object_key == object_key,
|
||||||
Get all recordings for a meeting
|
)
|
||||||
"""
|
result = await get_database().fetch_one(query)
|
||||||
query = select(RecordingModel).where(RecordingModel.meeting_id == meeting_id)
|
return Recording(**result) if result else None
|
||||||
result = await session.execute(query)
|
|
||||||
return [Recording.model_validate(row) for row in result.scalars().all()]
|
|
||||||
|
|
||||||
async def remove_by_id(self, session: AsyncSession, recording_id: str) -> None:
|
async def remove_by_id(self, id: str) -> None:
|
||||||
"""
|
query = recordings.delete().where(recordings.c.id == id)
|
||||||
Remove a recording by id
|
await get_database().execute(query)
|
||||||
"""
|
|
||||||
query = delete(RecordingModel).where(RecordingModel.id == recording_id)
|
|
||||||
await session.execute(query)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
|
|
||||||
recordings_controller = RecordingController()
|
recordings_controller = RecordingController()
|
||||||
|
|||||||
@@ -3,19 +3,59 @@ from datetime import datetime, timezone
|
|||||||
from sqlite3 import IntegrityError
|
from sqlite3 import IntegrityError
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import delete, select, update
|
from sqlalchemy.sql import false, or_
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from sqlalchemy.sql import or_
|
|
||||||
|
|
||||||
from reflector.db.base import RoomModel
|
from reflector.db import get_database, metadata
|
||||||
from reflector.utils import generate_uuid4
|
from reflector.utils import generate_uuid4
|
||||||
|
|
||||||
|
rooms = sqlalchemy.Table(
|
||||||
|
"room",
|
||||||
|
metadata,
|
||||||
|
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
|
||||||
|
sqlalchemy.Column("name", sqlalchemy.String, nullable=False, unique=True),
|
||||||
|
sqlalchemy.Column("user_id", sqlalchemy.String, nullable=False),
|
||||||
|
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True), nullable=False),
|
||||||
|
sqlalchemy.Column(
|
||||||
|
"zulip_auto_post", sqlalchemy.Boolean, nullable=False, server_default=false()
|
||||||
|
),
|
||||||
|
sqlalchemy.Column("zulip_stream", sqlalchemy.String),
|
||||||
|
sqlalchemy.Column("zulip_topic", sqlalchemy.String),
|
||||||
|
sqlalchemy.Column(
|
||||||
|
"is_locked", sqlalchemy.Boolean, nullable=False, server_default=false()
|
||||||
|
),
|
||||||
|
sqlalchemy.Column(
|
||||||
|
"room_mode", sqlalchemy.String, nullable=False, server_default="normal"
|
||||||
|
),
|
||||||
|
sqlalchemy.Column(
|
||||||
|
"recording_type", sqlalchemy.String, nullable=False, server_default="cloud"
|
||||||
|
),
|
||||||
|
sqlalchemy.Column(
|
||||||
|
"recording_trigger",
|
||||||
|
sqlalchemy.String,
|
||||||
|
nullable=False,
|
||||||
|
server_default="automatic-2nd-participant",
|
||||||
|
),
|
||||||
|
sqlalchemy.Column(
|
||||||
|
"is_shared", sqlalchemy.Boolean, nullable=False, server_default=false()
|
||||||
|
),
|
||||||
|
sqlalchemy.Column("webhook_url", sqlalchemy.String, nullable=True),
|
||||||
|
sqlalchemy.Column("webhook_secret", sqlalchemy.String, nullable=True),
|
||||||
|
sqlalchemy.Column("ics_url", sqlalchemy.Text),
|
||||||
|
sqlalchemy.Column("ics_fetch_interval", sqlalchemy.Integer, server_default="300"),
|
||||||
|
sqlalchemy.Column(
|
||||||
|
"ics_enabled", sqlalchemy.Boolean, nullable=False, server_default=false()
|
||||||
|
),
|
||||||
|
sqlalchemy.Column("ics_last_sync", sqlalchemy.DateTime(timezone=True)),
|
||||||
|
sqlalchemy.Column("ics_last_etag", sqlalchemy.Text),
|
||||||
|
sqlalchemy.Index("idx_room_is_shared", "is_shared"),
|
||||||
|
sqlalchemy.Index("idx_room_ics_enabled", "ics_enabled"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Room(BaseModel):
|
class Room(BaseModel):
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: str = Field(default_factory=generate_uuid4)
|
id: str = Field(default_factory=generate_uuid4)
|
||||||
name: str
|
name: str
|
||||||
user_id: str
|
user_id: str
|
||||||
@@ -42,7 +82,6 @@ class Room(BaseModel):
|
|||||||
class RoomController:
|
class RoomController:
|
||||||
async def get_all(
|
async def get_all(
|
||||||
self,
|
self,
|
||||||
session: AsyncSession,
|
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
order_by: str | None = None,
|
order_by: str | None = None,
|
||||||
return_query: bool = False,
|
return_query: bool = False,
|
||||||
@@ -56,14 +95,14 @@ class RoomController:
|
|||||||
Parameters:
|
Parameters:
|
||||||
- `order_by`: field to order by, e.g. "-created_at"
|
- `order_by`: field to order by, e.g. "-created_at"
|
||||||
"""
|
"""
|
||||||
query = select(RoomModel)
|
query = rooms.select()
|
||||||
if user_id is not None:
|
if user_id is not None:
|
||||||
query = query.where(or_(RoomModel.user_id == user_id, RoomModel.is_shared))
|
query = query.where(or_(rooms.c.user_id == user_id, rooms.c.is_shared))
|
||||||
else:
|
else:
|
||||||
query = query.where(RoomModel.is_shared)
|
query = query.where(rooms.c.is_shared)
|
||||||
|
|
||||||
if order_by is not None:
|
if order_by is not None:
|
||||||
field = getattr(RoomModel, order_by[1:])
|
field = getattr(rooms.c, order_by[1:])
|
||||||
if order_by.startswith("-"):
|
if order_by.startswith("-"):
|
||||||
field = field.desc()
|
field = field.desc()
|
||||||
query = query.order_by(field)
|
query = query.order_by(field)
|
||||||
@@ -71,12 +110,11 @@ class RoomController:
|
|||||||
if return_query:
|
if return_query:
|
||||||
return query
|
return query
|
||||||
|
|
||||||
result = await session.execute(query)
|
results = await get_database().fetch_all(query)
|
||||||
return [Room.model_validate(row) for row in result.scalars().all()]
|
return results
|
||||||
|
|
||||||
async def add(
|
async def add(
|
||||||
self,
|
self,
|
||||||
session: AsyncSession,
|
|
||||||
name: str,
|
name: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
zulip_auto_post: bool,
|
zulip_auto_post: bool,
|
||||||
@@ -116,27 +154,23 @@ class RoomController:
|
|||||||
ics_fetch_interval=ics_fetch_interval,
|
ics_fetch_interval=ics_fetch_interval,
|
||||||
ics_enabled=ics_enabled,
|
ics_enabled=ics_enabled,
|
||||||
)
|
)
|
||||||
new_room = RoomModel(**room.model_dump())
|
query = rooms.insert().values(**room.model_dump())
|
||||||
session.add(new_room)
|
|
||||||
try:
|
try:
|
||||||
await session.flush()
|
await get_database().execute(query)
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
raise HTTPException(status_code=400, detail="Room name is not unique")
|
raise HTTPException(status_code=400, detail="Room name is not unique")
|
||||||
return room
|
return room
|
||||||
|
|
||||||
async def update(
|
async def update(self, room: Room, values: dict, mutate=True):
|
||||||
self, session: AsyncSession, room: Room, values: dict, mutate=True
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Update a room fields with key/values in values
|
Update a room fields with key/values in values
|
||||||
"""
|
"""
|
||||||
if values.get("webhook_url") and not values.get("webhook_secret"):
|
if values.get("webhook_url") and not values.get("webhook_secret"):
|
||||||
values["webhook_secret"] = secrets.token_urlsafe(32)
|
values["webhook_secret"] = secrets.token_urlsafe(32)
|
||||||
|
|
||||||
query = update(RoomModel).where(RoomModel.id == room.id).values(**values)
|
query = rooms.update().where(rooms.c.id == room.id).values(**values)
|
||||||
try:
|
try:
|
||||||
await session.execute(query)
|
await get_database().execute(query)
|
||||||
await session.flush()
|
|
||||||
except IntegrityError:
|
except IntegrityError:
|
||||||
raise HTTPException(status_code=400, detail="Room name is not unique")
|
raise HTTPException(status_code=400, detail="Room name is not unique")
|
||||||
|
|
||||||
@@ -144,79 +178,67 @@ class RoomController:
|
|||||||
for key, value in values.items():
|
for key, value in values.items():
|
||||||
setattr(room, key, value)
|
setattr(room, key, value)
|
||||||
|
|
||||||
async def get_by_id(
|
async def get_by_id(self, room_id: str, **kwargs) -> Room | None:
|
||||||
self, session: AsyncSession, room_id: str, **kwargs
|
|
||||||
) -> Room | None:
|
|
||||||
"""
|
"""
|
||||||
Get a room by id
|
Get a room by id
|
||||||
"""
|
"""
|
||||||
query = select(RoomModel).where(RoomModel.id == room_id)
|
query = rooms.select().where(rooms.c.id == room_id)
|
||||||
if "user_id" in kwargs:
|
if "user_id" in kwargs:
|
||||||
query = query.where(RoomModel.user_id == kwargs["user_id"])
|
query = query.where(rooms.c.user_id == kwargs["user_id"])
|
||||||
result = await session.execute(query)
|
result = await get_database().fetch_one(query)
|
||||||
row = result.scalars().first()
|
if not result:
|
||||||
if not row:
|
|
||||||
return None
|
return None
|
||||||
return Room.model_validate(row)
|
return Room(**result)
|
||||||
|
|
||||||
async def get_by_name(
|
async def get_by_name(self, room_name: str, **kwargs) -> Room | None:
|
||||||
self, session: AsyncSession, room_name: str, **kwargs
|
|
||||||
) -> Room | None:
|
|
||||||
"""
|
"""
|
||||||
Get a room by name
|
Get a room by name
|
||||||
"""
|
"""
|
||||||
query = select(RoomModel).where(RoomModel.name == room_name)
|
query = rooms.select().where(rooms.c.name == room_name)
|
||||||
if "user_id" in kwargs:
|
if "user_id" in kwargs:
|
||||||
query = query.where(RoomModel.user_id == kwargs["user_id"])
|
query = query.where(rooms.c.user_id == kwargs["user_id"])
|
||||||
result = await session.execute(query)
|
result = await get_database().fetch_one(query)
|
||||||
row = result.scalars().first()
|
if not result:
|
||||||
if not row:
|
|
||||||
return None
|
return None
|
||||||
return Room.model_validate(row)
|
return Room(**result)
|
||||||
|
|
||||||
async def get_by_id_for_http(
|
async def get_by_id_for_http(self, meeting_id: str, user_id: str | None) -> Room:
|
||||||
self, session: AsyncSession, meeting_id: str, user_id: str | None
|
|
||||||
) -> Room:
|
|
||||||
"""
|
"""
|
||||||
Get a room by ID for HTTP request.
|
Get a room by ID for HTTP request.
|
||||||
|
|
||||||
If not found, it will raise a 404 error.
|
If not found, it will raise a 404 error.
|
||||||
"""
|
"""
|
||||||
query = select(RoomModel).where(RoomModel.id == meeting_id)
|
query = rooms.select().where(rooms.c.id == meeting_id)
|
||||||
result = await session.execute(query)
|
result = await get_database().fetch_one(query)
|
||||||
row = result.scalars().first()
|
if not result:
|
||||||
if not row:
|
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
room = Room.model_validate(row)
|
room = Room(**result)
|
||||||
|
|
||||||
return room
|
return room
|
||||||
|
|
||||||
async def get_ics_enabled(self, session: AsyncSession) -> list[Room]:
|
async def get_ics_enabled(self) -> list[Room]:
|
||||||
query = select(RoomModel).where(
|
query = rooms.select().where(
|
||||||
RoomModel.ics_enabled == True, RoomModel.ics_url != None
|
rooms.c.ics_enabled == True, rooms.c.ics_url != None
|
||||||
)
|
)
|
||||||
result = await session.execute(query)
|
results = await get_database().fetch_all(query)
|
||||||
results = result.scalars().all()
|
return [Room(**result) for result in results]
|
||||||
return [Room(**row.__dict__) for row in results]
|
|
||||||
|
|
||||||
async def remove_by_id(
|
async def remove_by_id(
|
||||||
self,
|
self,
|
||||||
session: AsyncSession,
|
|
||||||
room_id: str,
|
room_id: str,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Remove a room by id
|
Remove a room by id
|
||||||
"""
|
"""
|
||||||
room = await self.get_by_id(session, room_id, user_id=user_id)
|
room = await self.get_by_id(room_id, user_id=user_id)
|
||||||
if not room:
|
if not room:
|
||||||
return
|
return
|
||||||
if user_id is not None and room.user_id != user_id:
|
if user_id is not None and room.user_id != user_id:
|
||||||
return
|
return
|
||||||
query = delete(RoomModel).where(RoomModel.id == room_id)
|
query = rooms.delete().where(rooms.c.id == room_id)
|
||||||
await session.execute(query)
|
await get_database().execute(query)
|
||||||
await session.flush()
|
|
||||||
|
|
||||||
|
|
||||||
rooms_controller = RoomController()
|
rooms_controller = RoomController()
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from typing import Annotated, Any, Dict, Iterator
|
|||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
import webvtt
|
import webvtt
|
||||||
|
from databases.interfaces import Record as DbRecord
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
@@ -19,10 +20,11 @@ from pydantic import (
|
|||||||
constr,
|
constr,
|
||||||
field_serializer,
|
field_serializer,
|
||||||
)
|
)
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from reflector.db.base import RoomModel, TranscriptModel
|
from reflector.db import get_database
|
||||||
from reflector.db.transcripts import SourceKind, TranscriptStatus
|
from reflector.db.rooms import rooms
|
||||||
|
from reflector.db.transcripts import SourceKind, TranscriptStatus, transcripts
|
||||||
|
from reflector.db.utils import is_postgresql
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.utils.string import NonEmptyString, try_parse_non_empty_string
|
from reflector.utils.string import NonEmptyString, try_parse_non_empty_string
|
||||||
|
|
||||||
@@ -133,6 +135,8 @@ class SearchParameters(BaseModel):
|
|||||||
user_id: str | None = None
|
user_id: str | None = None
|
||||||
room_id: str | None = None
|
room_id: str | None = None
|
||||||
source_kind: SourceKind | None = None
|
source_kind: SourceKind | None = None
|
||||||
|
from_datetime: datetime | None = None
|
||||||
|
to_datetime: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
class SearchResultDB(BaseModel):
|
class SearchResultDB(BaseModel):
|
||||||
@@ -329,30 +333,36 @@ class SearchController:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def search_transcripts(
|
async def search_transcripts(
|
||||||
cls, session: AsyncSession, params: SearchParameters
|
cls, params: SearchParameters
|
||||||
) -> tuple[list[SearchResult], int]:
|
) -> tuple[list[SearchResult], int]:
|
||||||
"""
|
"""
|
||||||
Full-text search for transcripts using PostgreSQL tsvector.
|
Full-text search for transcripts using PostgreSQL tsvector.
|
||||||
Returns (results, total_count).
|
Returns (results, total_count).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if not is_postgresql():
|
||||||
|
logger.warning(
|
||||||
|
"Full-text search requires PostgreSQL. Returning empty results."
|
||||||
|
)
|
||||||
|
return [], 0
|
||||||
|
|
||||||
base_columns = [
|
base_columns = [
|
||||||
TranscriptModel.id,
|
transcripts.c.id,
|
||||||
TranscriptModel.title,
|
transcripts.c.title,
|
||||||
TranscriptModel.created_at,
|
transcripts.c.created_at,
|
||||||
TranscriptModel.duration,
|
transcripts.c.duration,
|
||||||
TranscriptModel.status,
|
transcripts.c.status,
|
||||||
TranscriptModel.user_id,
|
transcripts.c.user_id,
|
||||||
TranscriptModel.room_id,
|
transcripts.c.room_id,
|
||||||
TranscriptModel.source_kind,
|
transcripts.c.source_kind,
|
||||||
TranscriptModel.webvtt,
|
transcripts.c.webvtt,
|
||||||
TranscriptModel.long_summary,
|
transcripts.c.long_summary,
|
||||||
sqlalchemy.case(
|
sqlalchemy.case(
|
||||||
(
|
(
|
||||||
TranscriptModel.room_id.isnot(None) & RoomModel.id.is_(None),
|
transcripts.c.room_id.isnot(None) & rooms.c.id.is_(None),
|
||||||
"Deleted Room",
|
"Deleted Room",
|
||||||
),
|
),
|
||||||
else_=RoomModel.name,
|
else_=rooms.c.name,
|
||||||
).label("room_name"),
|
).label("room_name"),
|
||||||
]
|
]
|
||||||
search_query = None
|
search_query = None
|
||||||
@@ -361,7 +371,7 @@ class SearchController:
|
|||||||
"english", params.query_text
|
"english", params.query_text
|
||||||
)
|
)
|
||||||
rank_column = sqlalchemy.func.ts_rank(
|
rank_column = sqlalchemy.func.ts_rank(
|
||||||
TranscriptModel.search_vector_en,
|
transcripts.c.search_vector_en,
|
||||||
search_query,
|
search_query,
|
||||||
32, # normalization flag: rank/(rank+1) for 0-1 range
|
32, # normalization flag: rank/(rank+1) for 0-1 range
|
||||||
).label("rank")
|
).label("rank")
|
||||||
@@ -369,51 +379,55 @@ class SearchController:
|
|||||||
rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")
|
rank_column = sqlalchemy.cast(1.0, sqlalchemy.Float).label("rank")
|
||||||
|
|
||||||
columns = base_columns + [rank_column]
|
columns = base_columns + [rank_column]
|
||||||
base_query = (
|
base_query = sqlalchemy.select(columns).select_from(
|
||||||
sqlalchemy.select(*columns)
|
transcripts.join(rooms, transcripts.c.room_id == rooms.c.id, isouter=True)
|
||||||
.select_from(TranscriptModel)
|
|
||||||
.outerjoin(RoomModel, TranscriptModel.room_id == RoomModel.id)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.query_text is not None:
|
if params.query_text is not None:
|
||||||
# because already initialized based on params.query_text presence above
|
# because already initialized based on params.query_text presence above
|
||||||
assert search_query is not None
|
assert search_query is not None
|
||||||
base_query = base_query.where(
|
base_query = base_query.where(
|
||||||
TranscriptModel.search_vector_en.op("@@")(search_query)
|
transcripts.c.search_vector_en.op("@@")(search_query)
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.user_id:
|
if params.user_id:
|
||||||
base_query = base_query.where(
|
base_query = base_query.where(
|
||||||
sqlalchemy.or_(
|
sqlalchemy.or_(
|
||||||
TranscriptModel.user_id == params.user_id, RoomModel.is_shared
|
transcripts.c.user_id == params.user_id, rooms.c.is_shared
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
base_query = base_query.where(RoomModel.is_shared)
|
base_query = base_query.where(rooms.c.is_shared)
|
||||||
if params.room_id:
|
if params.room_id:
|
||||||
base_query = base_query.where(TranscriptModel.room_id == params.room_id)
|
base_query = base_query.where(transcripts.c.room_id == params.room_id)
|
||||||
if params.source_kind:
|
if params.source_kind:
|
||||||
base_query = base_query.where(
|
base_query = base_query.where(
|
||||||
TranscriptModel.source_kind == params.source_kind
|
transcripts.c.source_kind == params.source_kind
|
||||||
|
)
|
||||||
|
if params.from_datetime:
|
||||||
|
base_query = base_query.where(
|
||||||
|
transcripts.c.created_at >= params.from_datetime
|
||||||
|
)
|
||||||
|
if params.to_datetime:
|
||||||
|
base_query = base_query.where(
|
||||||
|
transcripts.c.created_at <= params.to_datetime
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.query_text is not None:
|
if params.query_text is not None:
|
||||||
order_by = sqlalchemy.desc(sqlalchemy.text("rank"))
|
order_by = sqlalchemy.desc(sqlalchemy.text("rank"))
|
||||||
else:
|
else:
|
||||||
order_by = sqlalchemy.desc(TranscriptModel.created_at)
|
order_by = sqlalchemy.desc(transcripts.c.created_at)
|
||||||
|
|
||||||
query = base_query.order_by(order_by).limit(params.limit).offset(params.offset)
|
query = base_query.order_by(order_by).limit(params.limit).offset(params.offset)
|
||||||
|
|
||||||
result = await session.execute(query)
|
rs = await get_database().fetch_all(query)
|
||||||
rs = result.mappings().all()
|
|
||||||
|
|
||||||
count_query = sqlalchemy.select(sqlalchemy.func.count()).select_from(
|
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
|
||||||
base_query.alias("search_results")
|
base_query.alias("search_results")
|
||||||
)
|
)
|
||||||
count_result = await session.execute(count_query)
|
total = await get_database().fetch_val(count_query)
|
||||||
total = count_result.scalar()
|
|
||||||
|
|
||||||
def _process_result(r: dict) -> SearchResult:
|
def _process_result(r: DbRecord) -> SearchResult:
|
||||||
r_dict: Dict[str, Any] = dict(r)
|
r_dict: Dict[str, Any] = dict(r)
|
||||||
|
|
||||||
webvtt_raw: str | None = r_dict.pop("webvtt", None)
|
webvtt_raw: str | None = r_dict.pop("webvtt", None)
|
||||||
|
|||||||
@@ -7,14 +7,17 @@ from datetime import datetime, timedelta, timezone
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||||
from sqlalchemy import delete, insert, select, update
|
from sqlalchemy import Enum
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.dialects.postgresql import TSVECTOR
|
||||||
from sqlalchemy.sql import or_
|
from sqlalchemy.sql import false, or_
|
||||||
|
|
||||||
from reflector.db.base import RoomModel, TranscriptModel
|
from reflector.db import get_database, metadata
|
||||||
from reflector.db.recordings import recordings_controller
|
from reflector.db.recordings import recordings_controller
|
||||||
|
from reflector.db.rooms import rooms
|
||||||
|
from reflector.db.utils import is_postgresql
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.processors.types import Word as ProcessorWord
|
from reflector.processors.types import Word as ProcessorWord
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
@@ -29,6 +32,91 @@ class SourceKind(enum.StrEnum):
|
|||||||
FILE = enum.auto()
|
FILE = enum.auto()
|
||||||
|
|
||||||
|
|
||||||
|
transcripts = sqlalchemy.Table(
|
||||||
|
"transcript",
|
||||||
|
metadata,
|
||||||
|
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
|
||||||
|
sqlalchemy.Column("name", sqlalchemy.String),
|
||||||
|
sqlalchemy.Column("status", sqlalchemy.String),
|
||||||
|
sqlalchemy.Column("locked", sqlalchemy.Boolean),
|
||||||
|
sqlalchemy.Column("duration", sqlalchemy.Float),
|
||||||
|
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True)),
|
||||||
|
sqlalchemy.Column("title", sqlalchemy.String),
|
||||||
|
sqlalchemy.Column("short_summary", sqlalchemy.String),
|
||||||
|
sqlalchemy.Column("long_summary", sqlalchemy.String),
|
||||||
|
sqlalchemy.Column("topics", sqlalchemy.JSON),
|
||||||
|
sqlalchemy.Column("events", sqlalchemy.JSON),
|
||||||
|
sqlalchemy.Column("participants", sqlalchemy.JSON),
|
||||||
|
sqlalchemy.Column("source_language", sqlalchemy.String),
|
||||||
|
sqlalchemy.Column("target_language", sqlalchemy.String),
|
||||||
|
sqlalchemy.Column(
|
||||||
|
"reviewed", sqlalchemy.Boolean, nullable=False, server_default=false()
|
||||||
|
),
|
||||||
|
sqlalchemy.Column(
|
||||||
|
"audio_location",
|
||||||
|
sqlalchemy.String,
|
||||||
|
nullable=False,
|
||||||
|
server_default="local",
|
||||||
|
),
|
||||||
|
# with user attached, optional
|
||||||
|
sqlalchemy.Column("user_id", sqlalchemy.String),
|
||||||
|
sqlalchemy.Column(
|
||||||
|
"share_mode",
|
||||||
|
sqlalchemy.String,
|
||||||
|
nullable=False,
|
||||||
|
server_default="private",
|
||||||
|
),
|
||||||
|
sqlalchemy.Column(
|
||||||
|
"meeting_id",
|
||||||
|
sqlalchemy.String,
|
||||||
|
),
|
||||||
|
sqlalchemy.Column("recording_id", sqlalchemy.String),
|
||||||
|
sqlalchemy.Column("zulip_message_id", sqlalchemy.Integer),
|
||||||
|
sqlalchemy.Column(
|
||||||
|
"source_kind",
|
||||||
|
Enum(SourceKind, values_callable=lambda obj: [e.value for e in obj]),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
# indicative field: whether associated audio is deleted
|
||||||
|
# the main "audio deleted" is the presence of the audio itself / consents not-given
|
||||||
|
# same field could've been in recording/meeting, and it's maybe even ok to dupe it at need
|
||||||
|
sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
|
||||||
|
sqlalchemy.Column("room_id", sqlalchemy.String),
|
||||||
|
sqlalchemy.Column("webvtt", sqlalchemy.Text),
|
||||||
|
sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
|
||||||
|
sqlalchemy.Index("idx_transcript_user_id", "user_id"),
|
||||||
|
sqlalchemy.Index("idx_transcript_created_at", "created_at"),
|
||||||
|
sqlalchemy.Index("idx_transcript_user_id_recording_id", "user_id", "recording_id"),
|
||||||
|
sqlalchemy.Index("idx_transcript_room_id", "room_id"),
|
||||||
|
sqlalchemy.Index("idx_transcript_source_kind", "source_kind"),
|
||||||
|
sqlalchemy.Index("idx_transcript_room_id_created_at", "room_id", "created_at"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add PostgreSQL-specific full-text search column
|
||||||
|
# This matches the migration in migrations/versions/116b2f287eab_add_full_text_search.py
|
||||||
|
if is_postgresql():
|
||||||
|
transcripts.append_column(
|
||||||
|
sqlalchemy.Column(
|
||||||
|
"search_vector_en",
|
||||||
|
TSVECTOR,
|
||||||
|
sqlalchemy.Computed(
|
||||||
|
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
|
||||||
|
"setweight(to_tsvector('english', coalesce(long_summary, '')), 'B') || "
|
||||||
|
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'C')",
|
||||||
|
persisted=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Add GIN index for the search vector
|
||||||
|
transcripts.append_constraint(
|
||||||
|
sqlalchemy.Index(
|
||||||
|
"idx_transcript_search_vector_en",
|
||||||
|
"search_vector_en",
|
||||||
|
postgresql_using="gin",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def generate_transcript_name() -> str:
|
def generate_transcript_name() -> str:
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
return f"Transcript {now.strftime('%Y-%m-%d %H:%M:%S')}"
|
||||||
@@ -103,8 +191,6 @@ class TranscriptParticipant(BaseModel):
|
|||||||
class Transcript(BaseModel):
|
class Transcript(BaseModel):
|
||||||
"""Full transcript model with all fields."""
|
"""Full transcript model with all fields."""
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: str = Field(default_factory=generate_uuid4)
|
id: str = Field(default_factory=generate_uuid4)
|
||||||
user_id: str | None = None
|
user_id: str | None = None
|
||||||
name: str = Field(default_factory=generate_transcript_name)
|
name: str = Field(default_factory=generate_transcript_name)
|
||||||
@@ -273,7 +359,6 @@ class Transcript(BaseModel):
|
|||||||
class TranscriptController:
|
class TranscriptController:
|
||||||
async def get_all(
|
async def get_all(
|
||||||
self,
|
self,
|
||||||
session: AsyncSession,
|
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
order_by: str | None = None,
|
order_by: str | None = None,
|
||||||
filter_empty: bool | None = False,
|
filter_empty: bool | None = False,
|
||||||
@@ -298,114 +383,102 @@ class TranscriptController:
|
|||||||
- `search_term`: filter transcripts by search term
|
- `search_term`: filter transcripts by search term
|
||||||
"""
|
"""
|
||||||
|
|
||||||
query = select(TranscriptModel).join(
|
query = transcripts.select().join(
|
||||||
RoomModel, TranscriptModel.room_id == RoomModel.id, isouter=True
|
rooms, transcripts.c.room_id == rooms.c.id, isouter=True
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_id:
|
if user_id:
|
||||||
query = query.where(
|
query = query.where(
|
||||||
or_(TranscriptModel.user_id == user_id, RoomModel.is_shared)
|
or_(transcripts.c.user_id == user_id, rooms.c.is_shared)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
query = query.where(RoomModel.is_shared)
|
query = query.where(rooms.c.is_shared)
|
||||||
|
|
||||||
if source_kind:
|
if source_kind:
|
||||||
query = query.where(TranscriptModel.source_kind == source_kind)
|
query = query.where(transcripts.c.source_kind == source_kind)
|
||||||
|
|
||||||
if room_id:
|
if room_id:
|
||||||
query = query.where(TranscriptModel.room_id == room_id)
|
query = query.where(transcripts.c.room_id == room_id)
|
||||||
|
|
||||||
if search_term:
|
if search_term:
|
||||||
query = query.where(TranscriptModel.title.ilike(f"%{search_term}%"))
|
query = query.where(transcripts.c.title.ilike(f"%{search_term}%"))
|
||||||
|
|
||||||
# Exclude heavy JSON columns from list queries
|
# Exclude heavy JSON columns from list queries
|
||||||
# Get all ORM column attributes except excluded ones
|
|
||||||
transcript_columns = [
|
transcript_columns = [
|
||||||
getattr(TranscriptModel, col.name)
|
col for col in transcripts.c if col.name not in exclude_columns
|
||||||
for col in TranscriptModel.__table__.c
|
|
||||||
if col.name not in exclude_columns
|
|
||||||
]
|
]
|
||||||
|
|
||||||
query = query.with_only_columns(
|
query = query.with_only_columns(
|
||||||
*transcript_columns,
|
transcript_columns
|
||||||
RoomModel.name.label("room_name"),
|
+ [
|
||||||
|
rooms.c.name.label("room_name"),
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
if order_by is not None:
|
if order_by is not None:
|
||||||
field = getattr(TranscriptModel, order_by[1:])
|
field = getattr(transcripts.c, order_by[1:])
|
||||||
if order_by.startswith("-"):
|
if order_by.startswith("-"):
|
||||||
field = field.desc()
|
field = field.desc()
|
||||||
query = query.order_by(field)
|
query = query.order_by(field)
|
||||||
|
|
||||||
if filter_empty:
|
if filter_empty:
|
||||||
query = query.filter(TranscriptModel.status != "idle")
|
query = query.filter(transcripts.c.status != "idle")
|
||||||
|
|
||||||
if filter_recording:
|
if filter_recording:
|
||||||
query = query.filter(TranscriptModel.status != "recording")
|
query = query.filter(transcripts.c.status != "recording")
|
||||||
|
|
||||||
# print(query.compile(compile_kwargs={"literal_binds": True}))
|
# print(query.compile(compile_kwargs={"literal_binds": True}))
|
||||||
|
|
||||||
if return_query:
|
if return_query:
|
||||||
return query
|
return query
|
||||||
|
|
||||||
result = await session.execute(query)
|
results = await get_database().fetch_all(query)
|
||||||
return [dict(row) for row in result.mappings().all()]
|
return results
|
||||||
|
|
||||||
async def get_by_id(
|
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
|
||||||
self, session: AsyncSession, transcript_id: str, **kwargs
|
|
||||||
) -> Transcript | None:
|
|
||||||
"""
|
"""
|
||||||
Get a transcript by id
|
Get a transcript by id
|
||||||
"""
|
"""
|
||||||
query = select(TranscriptModel).where(TranscriptModel.id == transcript_id)
|
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||||
if "user_id" in kwargs:
|
if "user_id" in kwargs:
|
||||||
query = query.where(TranscriptModel.user_id == kwargs["user_id"])
|
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||||
result = await session.execute(query)
|
result = await get_database().fetch_one(query)
|
||||||
row = result.scalar_one_or_none()
|
if not result:
|
||||||
if not row:
|
|
||||||
return None
|
return None
|
||||||
return Transcript.model_validate(row)
|
return Transcript(**result)
|
||||||
|
|
||||||
async def get_by_recording_id(
|
async def get_by_recording_id(
|
||||||
self, session: AsyncSession, recording_id: str, **kwargs
|
self, recording_id: str, **kwargs
|
||||||
) -> Transcript | None:
|
) -> Transcript | None:
|
||||||
"""
|
"""
|
||||||
Get a transcript by recording_id
|
Get a transcript by recording_id
|
||||||
"""
|
"""
|
||||||
query = select(TranscriptModel).where(
|
query = transcripts.select().where(transcripts.c.recording_id == recording_id)
|
||||||
TranscriptModel.recording_id == recording_id
|
|
||||||
)
|
|
||||||
if "user_id" in kwargs:
|
if "user_id" in kwargs:
|
||||||
query = query.where(TranscriptModel.user_id == kwargs["user_id"])
|
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||||
result = await session.execute(query)
|
result = await get_database().fetch_one(query)
|
||||||
row = result.scalar_one_or_none()
|
if not result:
|
||||||
if not row:
|
|
||||||
return None
|
return None
|
||||||
return Transcript.model_validate(row)
|
return Transcript(**result)
|
||||||
|
|
||||||
async def get_by_room_id(
|
async def get_by_room_id(self, room_id: str, **kwargs) -> list[Transcript]:
|
||||||
self, session: AsyncSession, room_id: str, **kwargs
|
|
||||||
) -> list[Transcript]:
|
|
||||||
"""
|
"""
|
||||||
Get transcripts by room_id (direct access without joins)
|
Get transcripts by room_id (direct access without joins)
|
||||||
"""
|
"""
|
||||||
query = select(TranscriptModel).where(TranscriptModel.room_id == room_id)
|
query = transcripts.select().where(transcripts.c.room_id == room_id)
|
||||||
if "user_id" in kwargs:
|
if "user_id" in kwargs:
|
||||||
query = query.where(TranscriptModel.user_id == kwargs["user_id"])
|
query = query.where(transcripts.c.user_id == kwargs["user_id"])
|
||||||
if "order_by" in kwargs:
|
if "order_by" in kwargs:
|
||||||
order_by = kwargs["order_by"]
|
order_by = kwargs["order_by"]
|
||||||
field = getattr(TranscriptModel, order_by[1:])
|
field = getattr(transcripts.c, order_by[1:])
|
||||||
if order_by.startswith("-"):
|
if order_by.startswith("-"):
|
||||||
field = field.desc()
|
field = field.desc()
|
||||||
query = query.order_by(field)
|
query = query.order_by(field)
|
||||||
results = await session.execute(query)
|
results = await get_database().fetch_all(query)
|
||||||
return [
|
return [Transcript(**result) for result in results]
|
||||||
Transcript.model_validate(dict(row)) for row in results.mappings().all()
|
|
||||||
]
|
|
||||||
|
|
||||||
async def get_by_id_for_http(
|
async def get_by_id_for_http(
|
||||||
self,
|
self,
|
||||||
session: AsyncSession,
|
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
user_id: str | None,
|
user_id: str | None,
|
||||||
) -> Transcript:
|
) -> Transcript:
|
||||||
@@ -418,14 +491,13 @@ class TranscriptController:
|
|||||||
This method checks the share mode of the transcript and the user_id
|
This method checks the share mode of the transcript and the user_id
|
||||||
to determine if the user can access the transcript.
|
to determine if the user can access the transcript.
|
||||||
"""
|
"""
|
||||||
query = select(TranscriptModel).where(TranscriptModel.id == transcript_id)
|
query = transcripts.select().where(transcripts.c.id == transcript_id)
|
||||||
result = await session.execute(query)
|
result = await get_database().fetch_one(query)
|
||||||
row = result.scalar_one_or_none()
|
if not result:
|
||||||
if not row:
|
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
|
||||||
# if the transcript is anonymous, share mode is not checked
|
# if the transcript is anonymous, share mode is not checked
|
||||||
transcript = Transcript.model_validate(row)
|
transcript = Transcript(**result)
|
||||||
if transcript.user_id is None:
|
if transcript.user_id is None:
|
||||||
return transcript
|
return transcript
|
||||||
|
|
||||||
@@ -448,7 +520,6 @@ class TranscriptController:
|
|||||||
|
|
||||||
async def add(
|
async def add(
|
||||||
self,
|
self,
|
||||||
session: AsyncSession,
|
|
||||||
name: str,
|
name: str,
|
||||||
source_kind: SourceKind,
|
source_kind: SourceKind,
|
||||||
source_language: str = "en",
|
source_language: str = "en",
|
||||||
@@ -473,15 +544,14 @@ class TranscriptController:
|
|||||||
meeting_id=meeting_id,
|
meeting_id=meeting_id,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
)
|
)
|
||||||
query = insert(TranscriptModel).values(**transcript.model_dump())
|
query = transcripts.insert().values(**transcript.model_dump())
|
||||||
await session.execute(query)
|
await get_database().execute(query)
|
||||||
await session.commit()
|
|
||||||
return transcript
|
return transcript
|
||||||
|
|
||||||
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
|
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
|
||||||
# using mutate=True is discouraged
|
# using mutate=True is discouraged
|
||||||
async def update(
|
async def update(
|
||||||
self, session: AsyncSession, transcript: Transcript, values: dict, mutate=False
|
self, transcript: Transcript, values: dict, mutate=False
|
||||||
) -> Transcript:
|
) -> Transcript:
|
||||||
"""
|
"""
|
||||||
Update a transcript fields with key/values in values.
|
Update a transcript fields with key/values in values.
|
||||||
@@ -490,12 +560,11 @@ class TranscriptController:
|
|||||||
values = TranscriptController._handle_topics_update(values)
|
values = TranscriptController._handle_topics_update(values)
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
update(TranscriptModel)
|
transcripts.update()
|
||||||
.where(TranscriptModel.id == transcript.id)
|
.where(transcripts.c.id == transcript.id)
|
||||||
.values(**values)
|
.values(**values)
|
||||||
)
|
)
|
||||||
await session.execute(query)
|
await get_database().execute(query)
|
||||||
await session.commit()
|
|
||||||
if mutate:
|
if mutate:
|
||||||
for key, value in values.items():
|
for key, value in values.items():
|
||||||
setattr(transcript, key, value)
|
setattr(transcript, key, value)
|
||||||
@@ -524,14 +593,13 @@ class TranscriptController:
|
|||||||
|
|
||||||
async def remove_by_id(
|
async def remove_by_id(
|
||||||
self,
|
self,
|
||||||
session: AsyncSession,
|
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Remove a transcript by id
|
Remove a transcript by id
|
||||||
"""
|
"""
|
||||||
transcript = await self.get_by_id(session, transcript_id)
|
transcript = await self.get_by_id(transcript_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
return
|
return
|
||||||
if user_id is not None and transcript.user_id != user_id:
|
if user_id is not None and transcript.user_id != user_id:
|
||||||
@@ -551,7 +619,7 @@ class TranscriptController:
|
|||||||
if transcript.recording_id:
|
if transcript.recording_id:
|
||||||
try:
|
try:
|
||||||
recording = await recordings_controller.get_by_id(
|
recording = await recordings_controller.get_by_id(
|
||||||
session, transcript.recording_id
|
transcript.recording_id
|
||||||
)
|
)
|
||||||
if recording:
|
if recording:
|
||||||
try:
|
try:
|
||||||
@@ -562,40 +630,46 @@ class TranscriptController:
|
|||||||
exc_info=e,
|
exc_info=e,
|
||||||
recording_id=transcript.recording_id,
|
recording_id=transcript.recording_id,
|
||||||
)
|
)
|
||||||
await recordings_controller.remove_by_id(
|
await recordings_controller.remove_by_id(transcript.recording_id)
|
||||||
session, transcript.recording_id
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to delete recording row",
|
"Failed to delete recording row",
|
||||||
exc_info=e,
|
exc_info=e,
|
||||||
recording_id=transcript.recording_id,
|
recording_id=transcript.recording_id,
|
||||||
)
|
)
|
||||||
query = delete(TranscriptModel).where(TranscriptModel.id == transcript_id)
|
query = transcripts.delete().where(transcripts.c.id == transcript_id)
|
||||||
await session.execute(query)
|
await get_database().execute(query)
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
async def remove_by_recording_id(self, session: AsyncSession, recording_id: str):
|
async def remove_by_recording_id(self, recording_id: str):
|
||||||
"""
|
"""
|
||||||
Remove a transcript by recording_id
|
Remove a transcript by recording_id
|
||||||
"""
|
"""
|
||||||
query = delete(TranscriptModel).where(
|
query = transcripts.delete().where(transcripts.c.recording_id == recording_id)
|
||||||
TranscriptModel.recording_id == recording_id
|
await get_database().execute(query)
|
||||||
)
|
|
||||||
await session.execute(query)
|
@staticmethod
|
||||||
await session.commit()
|
def user_can_mutate(transcript: Transcript, user_id: str | None) -> bool:
|
||||||
|
"""
|
||||||
|
Returns True if the given user is allowed to modify the transcript.
|
||||||
|
|
||||||
|
Policy:
|
||||||
|
- Anonymous transcripts (user_id is None) cannot be modified via API
|
||||||
|
- Only the owner (matching user_id) can modify their transcript
|
||||||
|
"""
|
||||||
|
if transcript.user_id is None:
|
||||||
|
return False
|
||||||
|
return user_id and transcript.user_id == user_id
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def transaction(self, session: AsyncSession):
|
async def transaction(self):
|
||||||
"""
|
"""
|
||||||
A context manager for database transaction
|
A context manager for database transaction
|
||||||
"""
|
"""
|
||||||
async with session.begin():
|
async with get_database().transaction(isolation="serializable"):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
async def append_event(
|
async def append_event(
|
||||||
self,
|
self,
|
||||||
session: AsyncSession,
|
|
||||||
transcript: Transcript,
|
transcript: Transcript,
|
||||||
event: str,
|
event: str,
|
||||||
data: Any,
|
data: Any,
|
||||||
@@ -604,12 +678,11 @@ class TranscriptController:
|
|||||||
Append an event to a transcript
|
Append an event to a transcript
|
||||||
"""
|
"""
|
||||||
resp = transcript.add_event(event=event, data=data)
|
resp = transcript.add_event(event=event, data=data)
|
||||||
await self.update(session, transcript, {"events": transcript.events_dump()})
|
await self.update(transcript, {"events": transcript.events_dump()})
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
async def upsert_topic(
|
async def upsert_topic(
|
||||||
self,
|
self,
|
||||||
session: AsyncSession,
|
|
||||||
transcript: Transcript,
|
transcript: Transcript,
|
||||||
topic: TranscriptTopic,
|
topic: TranscriptTopic,
|
||||||
) -> TranscriptEvent:
|
) -> TranscriptEvent:
|
||||||
@@ -617,9 +690,9 @@ class TranscriptController:
|
|||||||
Upsert topics to a transcript
|
Upsert topics to a transcript
|
||||||
"""
|
"""
|
||||||
transcript.upsert_topic(topic)
|
transcript.upsert_topic(topic)
|
||||||
await self.update(session, transcript, {"topics": transcript.topics_dump()})
|
await self.update(transcript, {"topics": transcript.topics_dump()})
|
||||||
|
|
||||||
async def move_mp3_to_storage(self, session: AsyncSession, transcript: Transcript):
|
async def move_mp3_to_storage(self, transcript: Transcript):
|
||||||
"""
|
"""
|
||||||
Move mp3 file to storage
|
Move mp3 file to storage
|
||||||
"""
|
"""
|
||||||
@@ -643,16 +716,12 @@ class TranscriptController:
|
|||||||
|
|
||||||
# indicate on the transcript that the audio is now on storage
|
# indicate on the transcript that the audio is now on storage
|
||||||
# mutates transcript argument
|
# mutates transcript argument
|
||||||
await self.update(
|
await self.update(transcript, {"audio_location": "storage"}, mutate=True)
|
||||||
session, transcript, {"audio_location": "storage"}, mutate=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# unlink the local file
|
# unlink the local file
|
||||||
transcript.audio_mp3_filename.unlink(missing_ok=True)
|
transcript.audio_mp3_filename.unlink(missing_ok=True)
|
||||||
|
|
||||||
async def download_mp3_from_storage(
|
async def download_mp3_from_storage(self, transcript: Transcript):
|
||||||
self, session: AsyncSession, transcript: Transcript
|
|
||||||
):
|
|
||||||
"""
|
"""
|
||||||
Download audio from storage
|
Download audio from storage
|
||||||
"""
|
"""
|
||||||
@@ -664,7 +733,6 @@ class TranscriptController:
|
|||||||
|
|
||||||
async def upsert_participant(
|
async def upsert_participant(
|
||||||
self,
|
self,
|
||||||
session: AsyncSession,
|
|
||||||
transcript: Transcript,
|
transcript: Transcript,
|
||||||
participant: TranscriptParticipant,
|
participant: TranscriptParticipant,
|
||||||
) -> TranscriptParticipant:
|
) -> TranscriptParticipant:
|
||||||
@@ -672,14 +740,11 @@ class TranscriptController:
|
|||||||
Add/update a participant to a transcript
|
Add/update a participant to a transcript
|
||||||
"""
|
"""
|
||||||
result = transcript.upsert_participant(participant)
|
result = transcript.upsert_participant(participant)
|
||||||
await self.update(
|
await self.update(transcript, {"participants": transcript.participants_dump()})
|
||||||
session, transcript, {"participants": transcript.participants_dump()}
|
|
||||||
)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async def delete_participant(
|
async def delete_participant(
|
||||||
self,
|
self,
|
||||||
session: AsyncSession,
|
|
||||||
transcript: Transcript,
|
transcript: Transcript,
|
||||||
participant_id: str,
|
participant_id: str,
|
||||||
):
|
):
|
||||||
@@ -687,31 +752,28 @@ class TranscriptController:
|
|||||||
Delete a participant from a transcript
|
Delete a participant from a transcript
|
||||||
"""
|
"""
|
||||||
transcript.delete_participant(participant_id)
|
transcript.delete_participant(participant_id)
|
||||||
await self.update(
|
await self.update(transcript, {"participants": transcript.participants_dump()})
|
||||||
session, transcript, {"participants": transcript.participants_dump()}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def set_status(
|
async def set_status(
|
||||||
self, session: AsyncSession, transcript_id: str, status: TranscriptStatus
|
self, transcript_id: str, status: TranscriptStatus
|
||||||
) -> TranscriptEvent | None:
|
) -> TranscriptEvent | None:
|
||||||
"""
|
"""
|
||||||
Update the status of a transcript
|
Update the status of a transcript
|
||||||
|
|
||||||
Will add an event STATUS + update the status field of transcript
|
Will add an event STATUS + update the status field of transcript
|
||||||
"""
|
"""
|
||||||
async with self.transaction(session):
|
async with self.transaction():
|
||||||
transcript = await self.get_by_id(session, transcript_id)
|
transcript = await self.get_by_id(transcript_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise Exception(f"Transcript {transcript_id} not found")
|
raise Exception(f"Transcript {transcript_id} not found")
|
||||||
if transcript.status == status:
|
if transcript.status == status:
|
||||||
return
|
return
|
||||||
resp = await self.append_event(
|
resp = await self.append_event(
|
||||||
session,
|
|
||||||
transcript=transcript,
|
transcript=transcript,
|
||||||
event="STATUS",
|
event="STATUS",
|
||||||
data=StrValue(value=status),
|
data=StrValue(value=status),
|
||||||
)
|
)
|
||||||
await self.update(session, transcript, {"status": status})
|
await self.update(transcript, {"status": status})
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
90
server/reflector/db/user_api_keys.py
Normal file
90
server/reflector/db/user_api_keys.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
import hmac
|
||||||
|
import secrets
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from hashlib import sha256
|
||||||
|
|
||||||
|
import sqlalchemy
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from reflector.db import get_database, metadata
|
||||||
|
from reflector.settings import settings
|
||||||
|
from reflector.utils import generate_uuid4
|
||||||
|
from reflector.utils.string import NonEmptyString
|
||||||
|
|
||||||
|
user_api_keys = sqlalchemy.Table(
|
||||||
|
"user_api_key",
|
||||||
|
metadata,
|
||||||
|
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
|
||||||
|
sqlalchemy.Column("user_id", sqlalchemy.String, nullable=False),
|
||||||
|
sqlalchemy.Column("key_hash", sqlalchemy.String, nullable=False),
|
||||||
|
sqlalchemy.Column("name", sqlalchemy.String, nullable=True),
|
||||||
|
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True), nullable=False),
|
||||||
|
sqlalchemy.Index("idx_user_api_key_hash", "key_hash", unique=True),
|
||||||
|
sqlalchemy.Index("idx_user_api_key_user_id", "user_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UserApiKey(BaseModel):
|
||||||
|
id: NonEmptyString = Field(default_factory=generate_uuid4)
|
||||||
|
user_id: NonEmptyString
|
||||||
|
key_hash: NonEmptyString
|
||||||
|
name: NonEmptyString | None = None
|
||||||
|
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
|
||||||
|
class UserApiKeyController:
|
||||||
|
@staticmethod
|
||||||
|
def generate_key() -> NonEmptyString:
|
||||||
|
return secrets.token_urlsafe(48)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def hash_key(key: NonEmptyString) -> str:
|
||||||
|
return hmac.new(
|
||||||
|
settings.SECRET_KEY.encode(), key.encode(), digestmod=sha256
|
||||||
|
).hexdigest()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def create_key(
|
||||||
|
cls,
|
||||||
|
user_id: NonEmptyString,
|
||||||
|
name: NonEmptyString | None = None,
|
||||||
|
) -> tuple[UserApiKey, NonEmptyString]:
|
||||||
|
plaintext = cls.generate_key()
|
||||||
|
api_key = UserApiKey(
|
||||||
|
user_id=user_id,
|
||||||
|
key_hash=cls.hash_key(plaintext),
|
||||||
|
name=name,
|
||||||
|
)
|
||||||
|
query = user_api_keys.insert().values(**api_key.model_dump())
|
||||||
|
await get_database().execute(query)
|
||||||
|
return api_key, plaintext
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def verify_key(cls, plaintext_key: NonEmptyString) -> UserApiKey | None:
|
||||||
|
key_hash = cls.hash_key(plaintext_key)
|
||||||
|
query = user_api_keys.select().where(
|
||||||
|
user_api_keys.c.key_hash == key_hash,
|
||||||
|
)
|
||||||
|
result = await get_database().fetch_one(query)
|
||||||
|
return UserApiKey(**result) if result else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def list_by_user_id(user_id: NonEmptyString) -> list[UserApiKey]:
|
||||||
|
query = (
|
||||||
|
user_api_keys.select()
|
||||||
|
.where(user_api_keys.c.user_id == user_id)
|
||||||
|
.order_by(user_api_keys.c.created_at.desc())
|
||||||
|
)
|
||||||
|
results = await get_database().fetch_all(query)
|
||||||
|
return [UserApiKey(**r) for r in results]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def delete_key(key_id: NonEmptyString, user_id: NonEmptyString) -> bool:
|
||||||
|
query = user_api_keys.delete().where(
|
||||||
|
(user_api_keys.c.id == key_id) & (user_api_keys.c.user_id == user_id)
|
||||||
|
)
|
||||||
|
result = await get_database().execute(query)
|
||||||
|
return result > 0
|
||||||
|
|
||||||
|
|
||||||
|
user_api_keys_controller = UserApiKeyController()
|
||||||
9
server/reflector/db/utils.py
Normal file
9
server/reflector/db/utils.py
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
"""Database utility functions."""
|
||||||
|
|
||||||
|
from reflector.db import get_database
|
||||||
|
|
||||||
|
|
||||||
|
def is_postgresql() -> bool:
|
||||||
|
return get_database().url.scheme and get_database().url.scheme.startswith(
|
||||||
|
"postgresql"
|
||||||
|
)
|
||||||
@@ -13,10 +13,8 @@ from pathlib import Path
|
|||||||
import av
|
import av
|
||||||
import structlog
|
import structlog
|
||||||
from celery import chain, shared_task
|
from celery import chain, shared_task
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from reflector.asynctask import asynctask
|
from reflector.asynctask import asynctask
|
||||||
from reflector.db import get_session_factory
|
|
||||||
from reflector.db.rooms import rooms_controller
|
from reflector.db.rooms import rooms_controller
|
||||||
from reflector.db.transcripts import (
|
from reflector.db.transcripts import (
|
||||||
SourceKind,
|
SourceKind,
|
||||||
@@ -55,7 +53,6 @@ from reflector.processors.types import (
|
|||||||
)
|
)
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.storage import get_transcripts_storage
|
from reflector.storage import get_transcripts_storage
|
||||||
from reflector.worker.session_decorator import with_session
|
|
||||||
from reflector.worker.webhook import send_transcript_webhook
|
from reflector.worker.webhook import send_transcript_webhook
|
||||||
|
|
||||||
|
|
||||||
@@ -100,23 +97,17 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def set_status(self, transcript_id: str, status: TranscriptStatus):
|
async def set_status(self, transcript_id: str, status: TranscriptStatus):
|
||||||
async with self.lock_transaction():
|
async with self.lock_transaction():
|
||||||
async with get_session_factory()() as session:
|
return await transcripts_controller.set_status(transcript_id, status)
|
||||||
return await transcripts_controller.set_status(
|
|
||||||
session, transcript_id, status
|
|
||||||
)
|
|
||||||
|
|
||||||
async def process(self, file_path: Path):
|
async def process(self, file_path: Path):
|
||||||
"""Main entry point for file processing"""
|
"""Main entry point for file processing"""
|
||||||
self.logger.info(f"Starting file pipeline for {file_path}")
|
self.logger.info(f"Starting file pipeline for {file_path}")
|
||||||
|
|
||||||
async with get_session_factory()() as session:
|
transcript = await self.get_transcript()
|
||||||
transcript = await transcripts_controller.get_by_id(
|
|
||||||
session, self.transcript_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# Clear transcript as we're going to regenerate everything
|
# Clear transcript as we're going to regenerate everything
|
||||||
|
async with self.transaction():
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
session,
|
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"events": [],
|
"events": [],
|
||||||
@@ -132,7 +123,6 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
|
|
||||||
# Run parallel processing
|
# Run parallel processing
|
||||||
await self.run_parallel_processing(
|
await self.run_parallel_processing(
|
||||||
session,
|
|
||||||
audio_path,
|
audio_path,
|
||||||
audio_url,
|
audio_url,
|
||||||
transcript.source_language,
|
transcript.source_language,
|
||||||
@@ -141,8 +131,7 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
|
|
||||||
self.logger.info("File pipeline complete")
|
self.logger.info("File pipeline complete")
|
||||||
|
|
||||||
async with get_session_factory()() as session:
|
await self.set_status(transcript.id, "ended")
|
||||||
await transcripts_controller.set_status(session, transcript.id, "ended")
|
|
||||||
|
|
||||||
async def extract_and_write_audio(
|
async def extract_and_write_audio(
|
||||||
self, file_path: Path, transcript: Transcript
|
self, file_path: Path, transcript: Transcript
|
||||||
@@ -204,7 +193,6 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
|
|
||||||
async def run_parallel_processing(
|
async def run_parallel_processing(
|
||||||
self,
|
self,
|
||||||
session,
|
|
||||||
audio_path: Path,
|
audio_path: Path,
|
||||||
audio_url: str,
|
audio_url: str,
|
||||||
source_language: str,
|
source_language: str,
|
||||||
@@ -218,7 +206,7 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
# Phase 1: Parallel processing of independent tasks
|
# Phase 1: Parallel processing of independent tasks
|
||||||
transcription_task = self.transcribe_file(audio_url, source_language)
|
transcription_task = self.transcribe_file(audio_url, source_language)
|
||||||
diarization_task = self.diarize_file(audio_url)
|
diarization_task = self.diarize_file(audio_url)
|
||||||
waveform_task = self.generate_waveform(session, audio_path)
|
waveform_task = self.generate_waveform(audio_path)
|
||||||
|
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
transcription_task, diarization_task, waveform_task, return_exceptions=True
|
transcription_task, diarization_task, waveform_task, return_exceptions=True
|
||||||
@@ -266,7 +254,7 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
)
|
)
|
||||||
results = await asyncio.gather(
|
results = await asyncio.gather(
|
||||||
self.generate_title(topics),
|
self.generate_title(topics),
|
||||||
self.generate_summaries(session, topics),
|
self.generate_summaries(topics),
|
||||||
return_exceptions=True,
|
return_exceptions=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -318,9 +306,9 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
self.logger.error(f"Diarization failed: {e}")
|
self.logger.error(f"Diarization failed: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def generate_waveform(self, session: AsyncSession, audio_path: Path):
|
async def generate_waveform(self, audio_path: Path):
|
||||||
"""Generate and save waveform"""
|
"""Generate and save waveform"""
|
||||||
transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
|
transcript = await self.get_transcript()
|
||||||
|
|
||||||
processor = AudioWaveformProcessor(
|
processor = AudioWaveformProcessor(
|
||||||
audio_path=audio_path,
|
audio_path=audio_path,
|
||||||
@@ -373,13 +361,13 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
|
|
||||||
await processor.flush()
|
await processor.flush()
|
||||||
|
|
||||||
async def generate_summaries(self, session, topics: list[TitleSummary]):
|
async def generate_summaries(self, topics: list[TitleSummary]):
|
||||||
"""Generate long and short summaries from topics"""
|
"""Generate long and short summaries from topics"""
|
||||||
if not topics:
|
if not topics:
|
||||||
self.logger.warning("No topics for summary generation")
|
self.logger.warning("No topics for summary generation")
|
||||||
return
|
return
|
||||||
|
|
||||||
transcript = await transcripts_controller.get_by_id(session, self.transcript_id)
|
transcript = await self.get_transcript()
|
||||||
processor = TranscriptFinalSummaryProcessor(
|
processor = TranscriptFinalSummaryProcessor(
|
||||||
transcript=transcript,
|
transcript=transcript,
|
||||||
callback=self.on_long_summary,
|
callback=self.on_long_summary,
|
||||||
@@ -395,15 +383,14 @@ class PipelineMainFile(PipelineMainBase):
|
|||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
@with_session
|
async def task_send_webhook_if_needed(*, transcript_id: str):
|
||||||
async def task_send_webhook_if_needed(session, *, transcript_id: str):
|
|
||||||
"""Send webhook if this is a room recording with webhook configured"""
|
"""Send webhook if this is a room recording with webhook configured"""
|
||||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
return
|
return
|
||||||
|
|
||||||
if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
|
if transcript.source_kind == SourceKind.ROOM and transcript.room_id:
|
||||||
room = await rooms_controller.get_by_id(session, transcript.room_id)
|
room = await rooms_controller.get_by_id(transcript.room_id)
|
||||||
if room and room.webhook_url:
|
if room and room.webhook_url:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Dispatching webhook",
|
"Dispatching webhook",
|
||||||
@@ -418,10 +405,10 @@ async def task_send_webhook_if_needed(session, *, transcript_id: str):
|
|||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
@with_session
|
async def task_pipeline_file_process(*, transcript_id: str):
|
||||||
async def task_pipeline_file_process(session, *, transcript_id: str):
|
|
||||||
"""Celery task for file pipeline processing"""
|
"""Celery task for file pipeline processing"""
|
||||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
|
||||||
|
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise Exception(f"Transcript {transcript_id} not found")
|
raise Exception(f"Transcript {transcript_id} not found")
|
||||||
|
|
||||||
@@ -439,7 +426,12 @@ async def task_pipeline_file_process(session, *, transcript_id: str):
|
|||||||
|
|
||||||
await pipeline.process(audio_file)
|
await pipeline.process(audio_file)
|
||||||
|
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"File pipeline failed for transcript {transcript_id}: {type(e).__name__}: {str(e)}",
|
||||||
|
exc_info=True,
|
||||||
|
transcript_id=transcript_id,
|
||||||
|
)
|
||||||
await pipeline.set_status(transcript_id, "error")
|
await pipeline.set_status(transcript_id, "error")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|||||||
@@ -20,11 +20,9 @@ import av
|
|||||||
import boto3
|
import boto3
|
||||||
from celery import chord, current_task, group, shared_task
|
from celery import chord, current_task, group, shared_task
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
from structlog import BoundLogger as Logger
|
from structlog import BoundLogger as Logger
|
||||||
|
|
||||||
from reflector.asynctask import asynctask
|
from reflector.asynctask import asynctask
|
||||||
from reflector.db import get_session_factory
|
|
||||||
from reflector.db.meetings import meeting_consent_controller, meetings_controller
|
from reflector.db.meetings import meeting_consent_controller, meetings_controller
|
||||||
from reflector.db.recordings import recordings_controller
|
from reflector.db.recordings import recordings_controller
|
||||||
from reflector.db.rooms import rooms_controller
|
from reflector.db.rooms import rooms_controller
|
||||||
@@ -64,7 +62,6 @@ from reflector.processors.types import (
|
|||||||
from reflector.processors.types import Transcript as TranscriptProcessorType
|
from reflector.processors.types import Transcript as TranscriptProcessorType
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.storage import get_transcripts_storage
|
from reflector.storage import get_transcripts_storage
|
||||||
from reflector.worker.session_decorator import with_session_and_transcript
|
|
||||||
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
from reflector.ws_manager import WebsocketManager, get_ws_manager
|
||||||
from reflector.zulip import (
|
from reflector.zulip import (
|
||||||
get_zulip_message,
|
get_zulip_message,
|
||||||
@@ -88,6 +85,20 @@ def broadcast_to_sockets(func):
|
|||||||
message=resp.model_dump(mode="json"),
|
message=resp.model_dump(mode="json"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
transcript = await transcripts_controller.get_by_id(self.transcript_id)
|
||||||
|
if transcript and transcript.user_id:
|
||||||
|
# Emit only relevant events to the user room to avoid noisy updates.
|
||||||
|
# Allowed: STATUS, FINAL_TITLE, DURATION. All are prefixed with TRANSCRIPT_
|
||||||
|
allowed_user_events = {"STATUS", "FINAL_TITLE", "DURATION"}
|
||||||
|
if resp.event in allowed_user_events:
|
||||||
|
await self.ws_manager.send_json(
|
||||||
|
room_id=f"user:{transcript.user_id}",
|
||||||
|
message={
|
||||||
|
"event": f"TRANSCRIPT_{resp.event}",
|
||||||
|
"data": {"id": self.transcript_id, **resp.data},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
@@ -99,10 +110,9 @@ def get_transcript(func):
|
|||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(**kwargs):
|
async def wrapper(**kwargs):
|
||||||
transcript_id = kwargs.pop("transcript_id")
|
transcript_id = kwargs.pop("transcript_id")
|
||||||
async with get_session_factory()() as session:
|
transcript = await transcripts_controller.get_by_id(transcript_id=transcript_id)
|
||||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise Exception(f"Transcript {transcript_id} not found")
|
raise Exception("Transcript {transcript_id} not found")
|
||||||
|
|
||||||
# Enhanced logger with Celery task context
|
# Enhanced logger with Celery task context
|
||||||
tlogger = logger.bind(transcript_id=transcript.id)
|
tlogger = logger.bind(transcript_id=transcript.id)
|
||||||
@@ -143,9 +153,11 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
self._ws_manager = get_ws_manager()
|
self._ws_manager = get_ws_manager()
|
||||||
return self._ws_manager
|
return self._ws_manager
|
||||||
|
|
||||||
async def get_transcript(self, session: AsyncSession) -> Transcript:
|
async def get_transcript(self) -> Transcript:
|
||||||
# fetch the transcript
|
# fetch the transcript
|
||||||
result = await transcripts_controller.get_by_id(session, self.transcript_id)
|
result = await transcripts_controller.get_by_id(
|
||||||
|
transcript_id=self.transcript_id
|
||||||
|
)
|
||||||
if not result:
|
if not result:
|
||||||
raise Exception("Transcript not found")
|
raise Exception("Transcript not found")
|
||||||
return result
|
return result
|
||||||
@@ -177,8 +189,8 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def transaction(self):
|
async def transaction(self):
|
||||||
async with self.lock_transaction():
|
async with self.lock_transaction():
|
||||||
async with get_session_factory()() as session:
|
async with transcripts_controller.transaction():
|
||||||
yield session
|
yield
|
||||||
|
|
||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def on_status(self, status):
|
async def on_status(self, status):
|
||||||
@@ -209,17 +221,13 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
|
|
||||||
# when the status of the pipeline changes, update the transcript
|
# when the status of the pipeline changes, update the transcript
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
async with get_session_factory()() as session:
|
return await transcripts_controller.set_status(self.transcript_id, status)
|
||||||
return await transcripts_controller.set_status(
|
|
||||||
session, self.transcript_id, status
|
|
||||||
)
|
|
||||||
|
|
||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def on_transcript(self, data):
|
async def on_transcript(self, data):
|
||||||
async with self.transaction() as session:
|
async with self.transaction():
|
||||||
transcript = await self.get_transcript(session)
|
transcript = await self.get_transcript()
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
session,
|
|
||||||
transcript=transcript,
|
transcript=transcript,
|
||||||
event="TRANSCRIPT",
|
event="TRANSCRIPT",
|
||||||
data=TranscriptText(text=data.text, translation=data.translation),
|
data=TranscriptText(text=data.text, translation=data.translation),
|
||||||
@@ -236,11 +244,10 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
)
|
)
|
||||||
if isinstance(data, TitleSummaryWithIdProcessorType):
|
if isinstance(data, TitleSummaryWithIdProcessorType):
|
||||||
topic.id = data.id
|
topic.id = data.id
|
||||||
async with self.transaction() as session:
|
async with self.transaction():
|
||||||
transcript = await self.get_transcript(session)
|
transcript = await self.get_transcript()
|
||||||
await transcripts_controller.upsert_topic(session, transcript, topic)
|
await transcripts_controller.upsert_topic(transcript, topic)
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
session,
|
|
||||||
transcript=transcript,
|
transcript=transcript,
|
||||||
event="TOPIC",
|
event="TOPIC",
|
||||||
data=topic,
|
data=topic,
|
||||||
@@ -249,18 +256,16 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def on_title(self, data):
|
async def on_title(self, data):
|
||||||
final_title = TranscriptFinalTitle(title=data.title)
|
final_title = TranscriptFinalTitle(title=data.title)
|
||||||
async with self.transaction() as session:
|
async with self.transaction():
|
||||||
transcript = await self.get_transcript(session)
|
transcript = await self.get_transcript()
|
||||||
if not transcript.title:
|
if not transcript.title:
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
session,
|
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"title": final_title.title,
|
"title": final_title.title,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
session,
|
|
||||||
transcript=transcript,
|
transcript=transcript,
|
||||||
event="FINAL_TITLE",
|
event="FINAL_TITLE",
|
||||||
data=final_title,
|
data=final_title,
|
||||||
@@ -269,17 +274,15 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def on_long_summary(self, data):
|
async def on_long_summary(self, data):
|
||||||
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
|
final_long_summary = TranscriptFinalLongSummary(long_summary=data.long_summary)
|
||||||
async with self.transaction() as session:
|
async with self.transaction():
|
||||||
transcript = await self.get_transcript(session)
|
transcript = await self.get_transcript()
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
session,
|
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"long_summary": final_long_summary.long_summary,
|
"long_summary": final_long_summary.long_summary,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
session,
|
|
||||||
transcript=transcript,
|
transcript=transcript,
|
||||||
event="FINAL_LONG_SUMMARY",
|
event="FINAL_LONG_SUMMARY",
|
||||||
data=final_long_summary,
|
data=final_long_summary,
|
||||||
@@ -290,17 +293,15 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
final_short_summary = TranscriptFinalShortSummary(
|
final_short_summary = TranscriptFinalShortSummary(
|
||||||
short_summary=data.short_summary
|
short_summary=data.short_summary
|
||||||
)
|
)
|
||||||
async with self.transaction() as session:
|
async with self.transaction():
|
||||||
transcript = await self.get_transcript(session)
|
transcript = await self.get_transcript()
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
session,
|
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"short_summary": final_short_summary.short_summary,
|
"short_summary": final_short_summary.short_summary,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
session,
|
|
||||||
transcript=transcript,
|
transcript=transcript,
|
||||||
event="FINAL_SHORT_SUMMARY",
|
event="FINAL_SHORT_SUMMARY",
|
||||||
data=final_short_summary,
|
data=final_short_summary,
|
||||||
@@ -308,30 +309,29 @@ class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]
|
|||||||
|
|
||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def on_duration(self, data):
|
async def on_duration(self, data):
|
||||||
async with self.transaction() as session:
|
async with self.transaction():
|
||||||
duration = TranscriptDuration(duration=data)
|
duration = TranscriptDuration(duration=data)
|
||||||
|
|
||||||
transcript = await self.get_transcript(session)
|
transcript = await self.get_transcript()
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
session,
|
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"duration": duration.duration,
|
"duration": duration.duration,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
session, transcript=transcript, event="DURATION", data=duration
|
transcript=transcript, event="DURATION", data=duration
|
||||||
)
|
)
|
||||||
|
|
||||||
@broadcast_to_sockets
|
@broadcast_to_sockets
|
||||||
async def on_waveform(self, data):
|
async def on_waveform(self, data):
|
||||||
async with self.transaction() as session:
|
async with self.transaction():
|
||||||
waveform = TranscriptWaveform(waveform=data)
|
waveform = TranscriptWaveform(waveform=data)
|
||||||
|
|
||||||
transcript = await self.get_transcript(session)
|
transcript = await self.get_transcript()
|
||||||
|
|
||||||
return await transcripts_controller.append_event(
|
return await transcripts_controller.append_event(
|
||||||
session, transcript=transcript, event="WAVEFORM", data=waveform
|
transcript=transcript, event="WAVEFORM", data=waveform
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -344,8 +344,7 @@ class PipelineMainLive(PipelineMainBase):
|
|||||||
async def create(self) -> Pipeline:
|
async def create(self) -> Pipeline:
|
||||||
# create a context for the whole rtc transaction
|
# create a context for the whole rtc transaction
|
||||||
# add a customised logger to the context
|
# add a customised logger to the context
|
||||||
async with get_session_factory()() as session:
|
transcript = await self.get_transcript()
|
||||||
transcript = await self.get_transcript(session)
|
|
||||||
|
|
||||||
processors = [
|
processors = [
|
||||||
AudioFileWriterProcessor(
|
AudioFileWriterProcessor(
|
||||||
@@ -393,8 +392,7 @@ class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
|
|||||||
# now let's start the pipeline by pushing information to the
|
# now let's start the pipeline by pushing information to the
|
||||||
# first processor diarization processor
|
# first processor diarization processor
|
||||||
# XXX translation is lost when converting our data model to the processor model
|
# XXX translation is lost when converting our data model to the processor model
|
||||||
async with get_session_factory()() as session:
|
transcript = await self.get_transcript()
|
||||||
transcript = await self.get_transcript(session)
|
|
||||||
|
|
||||||
# diarization works only if the file is uploaded to an external storage
|
# diarization works only if the file is uploaded to an external storage
|
||||||
if transcript.audio_location == "local":
|
if transcript.audio_location == "local":
|
||||||
@@ -427,8 +425,7 @@ class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
|
|||||||
|
|
||||||
async def create(self) -> Pipeline:
|
async def create(self) -> Pipeline:
|
||||||
# get transcript
|
# get transcript
|
||||||
async with get_session_factory()() as session:
|
self._transcript = transcript = await self.get_transcript()
|
||||||
self._transcript = transcript = await self.get_transcript(session)
|
|
||||||
|
|
||||||
# create pipeline
|
# create pipeline
|
||||||
processors = self.get_processors()
|
processors = self.get_processors()
|
||||||
@@ -533,7 +530,8 @@ async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
|
|||||||
logger.info("Convert to mp3 done")
|
logger.info("Convert to mp3 done")
|
||||||
|
|
||||||
|
|
||||||
async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
|
@get_transcript
|
||||||
|
async def pipeline_upload_mp3(transcript: Transcript, logger: Logger):
|
||||||
if not settings.TRANSCRIPT_STORAGE_BACKEND:
|
if not settings.TRANSCRIPT_STORAGE_BACKEND:
|
||||||
logger.info("No storage backend configured, skipping mp3 upload")
|
logger.info("No storage backend configured, skipping mp3 upload")
|
||||||
return
|
return
|
||||||
@@ -551,7 +549,7 @@ async def pipeline_upload_mp3(session, transcript: Transcript, logger: Logger):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# Upload to external storage and delete the file
|
# Upload to external storage and delete the file
|
||||||
await transcripts_controller.move_mp3_to_storage(session, transcript)
|
await transcripts_controller.move_mp3_to_storage(transcript)
|
||||||
|
|
||||||
logger.info("Upload mp3 done")
|
logger.info("Upload mp3 done")
|
||||||
|
|
||||||
@@ -580,23 +578,20 @@ async def pipeline_summaries(transcript: Transcript, logger: Logger):
|
|||||||
logger.info("Summaries done")
|
logger.info("Summaries done")
|
||||||
|
|
||||||
|
|
||||||
async def cleanup_consent(session, transcript: Transcript, logger: Logger):
|
@get_transcript
|
||||||
|
async def cleanup_consent(transcript: Transcript, logger: Logger):
|
||||||
logger.info("Starting consent cleanup")
|
logger.info("Starting consent cleanup")
|
||||||
|
|
||||||
consent_denied = False
|
consent_denied = False
|
||||||
recording = None
|
recording = None
|
||||||
try:
|
try:
|
||||||
if transcript.recording_id:
|
if transcript.recording_id:
|
||||||
recording = await recordings_controller.get_by_id(
|
recording = await recordings_controller.get_by_id(transcript.recording_id)
|
||||||
session, transcript.recording_id
|
|
||||||
)
|
|
||||||
if recording and recording.meeting_id:
|
if recording and recording.meeting_id:
|
||||||
meeting = await meetings_controller.get_by_id(
|
meeting = await meetings_controller.get_by_id(recording.meeting_id)
|
||||||
session, recording.meeting_id
|
|
||||||
)
|
|
||||||
if meeting:
|
if meeting:
|
||||||
consent_denied = await meeting_consent_controller.has_any_denial(
|
consent_denied = await meeting_consent_controller.has_any_denial(
|
||||||
session, meeting.id
|
meeting.id
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
|
logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
|
||||||
@@ -625,7 +620,7 @@ async def cleanup_consent(session, transcript: Transcript, logger: Logger):
|
|||||||
logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
|
logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
|
||||||
|
|
||||||
# non-transactional, files marked for deletion not actually deleted is possible
|
# non-transactional, files marked for deletion not actually deleted is possible
|
||||||
await transcripts_controller.update(session, transcript, {"audio_deleted": True})
|
await transcripts_controller.update(transcript, {"audio_deleted": True})
|
||||||
# 2. Delete processed audio from transcript storage S3 bucket
|
# 2. Delete processed audio from transcript storage S3 bucket
|
||||||
if transcript.audio_location == "storage":
|
if transcript.audio_location == "storage":
|
||||||
storage = get_transcripts_storage()
|
storage = get_transcripts_storage()
|
||||||
@@ -649,14 +644,15 @@ async def cleanup_consent(session, transcript: Transcript, logger: Logger):
|
|||||||
logger.info("Consent cleanup done")
|
logger.info("Consent cleanup done")
|
||||||
|
|
||||||
|
|
||||||
async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger):
|
@get_transcript
|
||||||
|
async def pipeline_post_to_zulip(transcript: Transcript, logger: Logger):
|
||||||
logger.info("Starting post to zulip")
|
logger.info("Starting post to zulip")
|
||||||
|
|
||||||
if not transcript.recording_id:
|
if not transcript.recording_id:
|
||||||
logger.info("Transcript has no recording")
|
logger.info("Transcript has no recording")
|
||||||
return
|
return
|
||||||
|
|
||||||
recording = await recordings_controller.get_by_id(session, transcript.recording_id)
|
recording = await recordings_controller.get_by_id(transcript.recording_id)
|
||||||
if not recording:
|
if not recording:
|
||||||
logger.info("Recording not found")
|
logger.info("Recording not found")
|
||||||
return
|
return
|
||||||
@@ -665,12 +661,12 @@ async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger
|
|||||||
logger.info("Recording has no meeting")
|
logger.info("Recording has no meeting")
|
||||||
return
|
return
|
||||||
|
|
||||||
meeting = await meetings_controller.get_by_id(session, recording.meeting_id)
|
meeting = await meetings_controller.get_by_id(recording.meeting_id)
|
||||||
if not meeting:
|
if not meeting:
|
||||||
logger.info("No meeting found for this recording")
|
logger.info("No meeting found for this recording")
|
||||||
return
|
return
|
||||||
|
|
||||||
room = await rooms_controller.get_by_id(session, meeting.room_id)
|
room = await rooms_controller.get_by_id(meeting.room_id)
|
||||||
if not room:
|
if not room:
|
||||||
logger.error(f"Missing room for a meeting {meeting.id}")
|
logger.error(f"Missing room for a meeting {meeting.id}")
|
||||||
return
|
return
|
||||||
@@ -696,7 +692,7 @@ async def pipeline_post_to_zulip(session, transcript: Transcript, logger: Logger
|
|||||||
room.zulip_stream, room.zulip_topic, message
|
room.zulip_stream, room.zulip_topic, message
|
||||||
)
|
)
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
session, transcript, {"zulip_message_id": response["id"]}
|
transcript, {"zulip_message_id": response["id"]}
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Posted to zulip")
|
logger.info("Posted to zulip")
|
||||||
@@ -727,11 +723,8 @@ async def task_pipeline_convert_to_mp3(*, transcript_id: str):
|
|||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
@with_session_and_transcript
|
async def task_pipeline_upload_mp3(*, transcript_id: str):
|
||||||
async def task_pipeline_upload_mp3(
|
await pipeline_upload_mp3(transcript_id=transcript_id)
|
||||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
|
||||||
):
|
|
||||||
await pipeline_upload_mp3(session, transcript=transcript, logger=logger)
|
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@@ -754,20 +747,14 @@ async def task_pipeline_final_summaries(*, transcript_id: str):
|
|||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
@with_session_and_transcript
|
async def task_cleanup_consent(*, transcript_id: str):
|
||||||
async def task_cleanup_consent(
|
await cleanup_consent(transcript_id=transcript_id)
|
||||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
|
||||||
):
|
|
||||||
await cleanup_consent(session, transcript=transcript, logger=logger)
|
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
@with_session_and_transcript
|
async def task_pipeline_post_to_zulip(*, transcript_id: str):
|
||||||
async def task_pipeline_post_to_zulip(
|
await pipeline_post_to_zulip(transcript_id=transcript_id)
|
||||||
session, *, transcript: Transcript, logger: Logger, transcript_id: str
|
|
||||||
):
|
|
||||||
await pipeline_post_to_zulip(session, transcript=transcript, logger=logger)
|
|
||||||
|
|
||||||
|
|
||||||
def pipeline_post(*, transcript_id: str):
|
def pipeline_post(*, transcript_id: str):
|
||||||
@@ -799,11 +786,9 @@ def pipeline_post(*, transcript_id: str):
|
|||||||
async def pipeline_process(transcript: Transcript, logger: Logger):
|
async def pipeline_process(transcript: Transcript, logger: Logger):
|
||||||
try:
|
try:
|
||||||
if transcript.audio_location == "storage":
|
if transcript.audio_location == "storage":
|
||||||
async with get_session_factory()() as session:
|
|
||||||
await transcripts_controller.download_mp3_from_storage(transcript)
|
await transcripts_controller.download_mp3_from_storage(transcript)
|
||||||
transcript.audio_waveform_filename.unlink(missing_ok=True)
|
transcript.audio_waveform_filename.unlink(missing_ok=True)
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
session,
|
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"topics": [],
|
"topics": [],
|
||||||
@@ -840,9 +825,7 @@ async def pipeline_process(transcript: Transcript, logger: Logger):
|
|||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("Pipeline error", exc_info=exc)
|
logger.error("Pipeline error", exc_info=exc)
|
||||||
async with get_session_factory()() as session:
|
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
session,
|
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"status": "error",
|
"status": "error",
|
||||||
|
|||||||
@@ -56,6 +56,16 @@ class FileTranscriptModalProcessor(FileTranscriptProcessor):
|
|||||||
},
|
},
|
||||||
follow_redirects=True,
|
follow_redirects=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
error_body = response.text
|
||||||
|
self.logger.error(
|
||||||
|
"Modal API error",
|
||||||
|
audio_url=data.audio_url,
|
||||||
|
status_code=response.status_code,
|
||||||
|
error_body=error_body,
|
||||||
|
)
|
||||||
|
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
result = response.json()
|
result = response.json()
|
||||||
|
|
||||||
|
|||||||
@@ -34,8 +34,16 @@ TOPIC_PROMPT = dedent(
|
|||||||
class TopicResponse(BaseModel):
|
class TopicResponse(BaseModel):
|
||||||
"""Structured response for topic detection"""
|
"""Structured response for topic detection"""
|
||||||
|
|
||||||
title: str = Field(description="A descriptive title for the topic being discussed")
|
title: str = Field(
|
||||||
summary: str = Field(description="A concise 1-2 sentence summary of the discussion")
|
description="A descriptive title for the topic being discussed",
|
||||||
|
validation_alias="Title",
|
||||||
|
)
|
||||||
|
summary: str = Field(
|
||||||
|
description="A concise 1-2 sentence summary of the discussion",
|
||||||
|
validation_alias="Summary",
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = {"populate_by_name": True}
|
||||||
|
|
||||||
|
|
||||||
class TranscriptTopicDetectorProcessor(Processor):
|
class TranscriptTopicDetectorProcessor(Processor):
|
||||||
|
|||||||
@@ -55,7 +55,6 @@ import httpx
|
|||||||
import pytz
|
import pytz
|
||||||
import structlog
|
import structlog
|
||||||
from icalendar import Calendar, Event
|
from icalendar import Calendar, Event
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from reflector.db.calendar_events import CalendarEvent, calendar_events_controller
|
from reflector.db.calendar_events import CalendarEvent, calendar_events_controller
|
||||||
from reflector.db.rooms import Room, rooms_controller
|
from reflector.db.rooms import Room, rooms_controller
|
||||||
@@ -295,7 +294,7 @@ class ICSSyncService:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.fetch_service = ICSFetchService()
|
self.fetch_service = ICSFetchService()
|
||||||
|
|
||||||
async def sync_room_calendar(self, session: AsyncSession, room: Room) -> SyncResult:
|
async def sync_room_calendar(self, room: Room) -> SyncResult:
|
||||||
async with RedisAsyncLock(
|
async with RedisAsyncLock(
|
||||||
f"ics_sync_room:{room.id}", skip_if_locked=True
|
f"ics_sync_room:{room.id}", skip_if_locked=True
|
||||||
) as lock:
|
) as lock:
|
||||||
@@ -306,11 +305,9 @@ class ICSSyncService:
|
|||||||
"reason": "Sync already in progress",
|
"reason": "Sync already in progress",
|
||||||
}
|
}
|
||||||
|
|
||||||
return await self._sync_room_calendar(session, room)
|
return await self._sync_room_calendar(room)
|
||||||
|
|
||||||
async def _sync_room_calendar(
|
async def _sync_room_calendar(self, room: Room) -> SyncResult:
|
||||||
self, session: AsyncSession, room: Room
|
|
||||||
) -> SyncResult:
|
|
||||||
if not room.ics_enabled or not room.ics_url:
|
if not room.ics_enabled or not room.ics_url:
|
||||||
return {"status": SyncStatus.SKIPPED, "reason": "ICS not configured"}
|
return {"status": SyncStatus.SKIPPED, "reason": "ICS not configured"}
|
||||||
|
|
||||||
@@ -343,11 +340,10 @@ class ICSSyncService:
|
|||||||
events, total_events = self.fetch_service.extract_room_events(
|
events, total_events = self.fetch_service.extract_room_events(
|
||||||
calendar, room.name, room_url
|
calendar, room.name, room_url
|
||||||
)
|
)
|
||||||
sync_result = await self._sync_events_to_database(session, room.id, events)
|
sync_result = await self._sync_events_to_database(room.id, events)
|
||||||
|
|
||||||
# Update room sync metadata
|
# Update room sync metadata
|
||||||
await rooms_controller.update(
|
await rooms_controller.update(
|
||||||
session,
|
|
||||||
room,
|
room,
|
||||||
{
|
{
|
||||||
"ics_last_sync": datetime.now(timezone.utc),
|
"ics_last_sync": datetime.now(timezone.utc),
|
||||||
@@ -376,7 +372,7 @@ class ICSSyncService:
|
|||||||
return time_since_sync.total_seconds() >= room.ics_fetch_interval
|
return time_since_sync.total_seconds() >= room.ics_fetch_interval
|
||||||
|
|
||||||
async def _sync_events_to_database(
|
async def _sync_events_to_database(
|
||||||
self, session: AsyncSession, room_id: str, events: list[EventData]
|
self, room_id: str, events: list[EventData]
|
||||||
) -> SyncStats:
|
) -> SyncStats:
|
||||||
created = 0
|
created = 0
|
||||||
updated = 0
|
updated = 0
|
||||||
@@ -386,7 +382,7 @@ class ICSSyncService:
|
|||||||
for event_data in events:
|
for event_data in events:
|
||||||
calendar_event = CalendarEvent(room_id=room_id, **event_data)
|
calendar_event = CalendarEvent(room_id=room_id, **event_data)
|
||||||
existing = await calendar_events_controller.get_by_ics_uid(
|
existing = await calendar_events_controller.get_by_ics_uid(
|
||||||
session, room_id, event_data["ics_uid"]
|
room_id, event_data["ics_uid"]
|
||||||
)
|
)
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
@@ -394,12 +390,12 @@ class ICSSyncService:
|
|||||||
else:
|
else:
|
||||||
created += 1
|
created += 1
|
||||||
|
|
||||||
await calendar_events_controller.upsert(session, calendar_event)
|
await calendar_events_controller.upsert(calendar_event)
|
||||||
current_ics_uids.append(event_data["ics_uid"])
|
current_ics_uids.append(event_data["ics_uid"])
|
||||||
|
|
||||||
# Soft delete events that are no longer in calendar
|
# Soft delete events that are no longer in calendar
|
||||||
deleted = await calendar_events_controller.soft_delete_missing(
|
deleted = await calendar_events_controller.soft_delete_missing(
|
||||||
session, room_id, current_ics_uids
|
room_id, current_ics_uids
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -9,12 +9,12 @@ async def export_db(filename: str) -> None:
|
|||||||
filename = pathlib.Path(filename).resolve()
|
filename = pathlib.Path(filename).resolve()
|
||||||
settings.DATABASE_URL = f"sqlite:///{filename}"
|
settings.DATABASE_URL = f"sqlite:///{filename}"
|
||||||
|
|
||||||
from reflector.db import get_session_factory
|
from reflector.db import get_database, transcripts
|
||||||
from reflector.db.transcripts import transcripts_controller
|
|
||||||
|
|
||||||
session_factory = get_session_factory()
|
database = get_database()
|
||||||
async with session_factory() as session:
|
await database.connect()
|
||||||
transcripts = await transcripts_controller.get_all(session)
|
transcripts = await database.fetch_all(transcripts.select())
|
||||||
|
await database.disconnect()
|
||||||
|
|
||||||
def export_transcript(transcript, output_dir):
|
def export_transcript(transcript, output_dir):
|
||||||
for topic in transcript.topics:
|
for topic in transcript.topics:
|
||||||
|
|||||||
@@ -8,12 +8,12 @@ async def export_db(filename: str) -> None:
|
|||||||
filename = pathlib.Path(filename).resolve()
|
filename = pathlib.Path(filename).resolve()
|
||||||
settings.DATABASE_URL = f"sqlite:///{filename}"
|
settings.DATABASE_URL = f"sqlite:///{filename}"
|
||||||
|
|
||||||
from reflector.db import get_session_factory
|
from reflector.db import get_database, transcripts
|
||||||
from reflector.db.transcripts import transcripts_controller
|
|
||||||
|
|
||||||
session_factory = get_session_factory()
|
database = get_database()
|
||||||
async with session_factory() as session:
|
await database.connect()
|
||||||
transcripts = await transcripts_controller.get_all(session)
|
transcripts = await database.fetch_all(transcripts.select())
|
||||||
|
await database.disconnect()
|
||||||
|
|
||||||
def export_transcript(transcript):
|
def export_transcript(transcript):
|
||||||
tid = transcript.id
|
tid = transcript.id
|
||||||
|
|||||||
@@ -11,9 +11,6 @@ import time
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Literal
|
from typing import Any, Dict, List, Literal
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from reflector.db import get_session_factory
|
|
||||||
from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
|
from reflector.db.transcripts import SourceKind, TranscriptTopic, transcripts_controller
|
||||||
from reflector.logger import logger
|
from reflector.logger import logger
|
||||||
from reflector.pipelines.main_file_pipeline import (
|
from reflector.pipelines.main_file_pipeline import (
|
||||||
@@ -53,7 +50,6 @@ TranscriptId = str
|
|||||||
# common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system)
|
# common interface for every flow: it needs an Entry in db with specific ceremony (file path + status + actual file in file system)
|
||||||
# ideally we want to get rid of it at some point
|
# ideally we want to get rid of it at some point
|
||||||
async def prepare_entry(
|
async def prepare_entry(
|
||||||
session: AsyncSession,
|
|
||||||
source_path: str,
|
source_path: str,
|
||||||
source_language: str,
|
source_language: str,
|
||||||
target_language: str,
|
target_language: str,
|
||||||
@@ -61,7 +57,6 @@ async def prepare_entry(
|
|||||||
file_path = Path(source_path)
|
file_path = Path(source_path)
|
||||||
|
|
||||||
transcript = await transcripts_controller.add(
|
transcript = await transcripts_controller.add(
|
||||||
session,
|
|
||||||
file_path.name,
|
file_path.name,
|
||||||
# note that the real file upload has SourceKind: LIVE for the reason of it's an error
|
# note that the real file upload has SourceKind: LIVE for the reason of it's an error
|
||||||
source_kind=SourceKind.FILE,
|
source_kind=SourceKind.FILE,
|
||||||
@@ -83,20 +78,16 @@ async def prepare_entry(
|
|||||||
logger.info(f"Copied {source_path} to {upload_path}")
|
logger.info(f"Copied {source_path} to {upload_path}")
|
||||||
|
|
||||||
# pipelines expect entity status "uploaded"
|
# pipelines expect entity status "uploaded"
|
||||||
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
|
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
||||||
|
|
||||||
return transcript.id
|
return transcript.id
|
||||||
|
|
||||||
|
|
||||||
# same reason as prepare_entry
|
# same reason as prepare_entry
|
||||||
async def extract_result_from_entry(
|
async def extract_result_from_entry(
|
||||||
session: AsyncSession,
|
transcript_id: TranscriptId, output_path: str
|
||||||
transcript_id: TranscriptId,
|
|
||||||
output_path: str,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
post_final_transcript = await transcripts_controller.get_by_id(
|
post_final_transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||||
session, transcript_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# assert post_final_transcript.status == "ended"
|
# assert post_final_transcript.status == "ended"
|
||||||
# File pipeline doesn't set status to "ended", only live pipeline does https://github.com/Monadical-SAS/reflector/issues/582
|
# File pipeline doesn't set status to "ended", only live pipeline does https://github.com/Monadical-SAS/reflector/issues/582
|
||||||
@@ -124,7 +115,6 @@ async def extract_result_from_entry(
|
|||||||
|
|
||||||
|
|
||||||
async def process_live_pipeline(
|
async def process_live_pipeline(
|
||||||
session: AsyncSession,
|
|
||||||
transcript_id: TranscriptId,
|
transcript_id: TranscriptId,
|
||||||
):
|
):
|
||||||
"""Process transcript_id with transcription and diarization"""
|
"""Process transcript_id with transcription and diarization"""
|
||||||
@@ -133,9 +123,7 @@ async def process_live_pipeline(
|
|||||||
await live_pipeline_process(transcript_id=transcript_id)
|
await live_pipeline_process(transcript_id=transcript_id)
|
||||||
print(f"Processing complete for transcript {transcript_id}", file=sys.stderr)
|
print(f"Processing complete for transcript {transcript_id}", file=sys.stderr)
|
||||||
|
|
||||||
pre_final_transcript = await transcripts_controller.get_by_id(
|
pre_final_transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||||
session, transcript_id
|
|
||||||
)
|
|
||||||
|
|
||||||
# assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post
|
# assert documented behaviour: after process, the pipeline isn't ended. this is the reason of calling pipeline_post
|
||||||
assert pre_final_transcript.status != "ended"
|
assert pre_final_transcript.status != "ended"
|
||||||
@@ -172,17 +160,21 @@ async def process(
|
|||||||
pipeline: Literal["live", "file"],
|
pipeline: Literal["live", "file"],
|
||||||
output_path: str = None,
|
output_path: str = None,
|
||||||
):
|
):
|
||||||
session_factory = get_session_factory()
|
from reflector.db import get_database
|
||||||
async with session_factory() as session:
|
|
||||||
|
database = get_database()
|
||||||
|
# db connect is a part of ceremony
|
||||||
|
await database.connect()
|
||||||
|
|
||||||
|
try:
|
||||||
transcript_id = await prepare_entry(
|
transcript_id = await prepare_entry(
|
||||||
session,
|
|
||||||
source_path,
|
source_path,
|
||||||
source_language,
|
source_language,
|
||||||
target_language,
|
target_language,
|
||||||
)
|
)
|
||||||
|
|
||||||
pipeline_handlers = {
|
pipeline_handlers = {
|
||||||
"live": lambda tid: process_live_pipeline(session, tid),
|
"live": process_live_pipeline,
|
||||||
"file": process_file_pipeline,
|
"file": process_file_pipeline,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,7 +184,9 @@ async def process(
|
|||||||
|
|
||||||
await handler(transcript_id)
|
await handler(transcript_id)
|
||||||
|
|
||||||
await extract_result_from_entry(session, transcript_id, output_path)
|
await extract_result_from_entry(transcript_id, output_path)
|
||||||
|
finally:
|
||||||
|
await database.disconnect()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -5,13 +5,12 @@ from typing import Annotated, Any, Literal, Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi_pagination import Page
|
from fastapi_pagination import Page
|
||||||
from fastapi_pagination.ext.sqlalchemy import paginate
|
from fastapi_pagination.ext.databases import apaginate
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from redis.exceptions import LockError
|
from redis.exceptions import LockError
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
from reflector.db import get_session
|
from reflector.db import get_database
|
||||||
from reflector.db.calendar_events import calendar_events_controller
|
from reflector.db.calendar_events import calendar_events_controller
|
||||||
from reflector.db.meetings import meetings_controller
|
from reflector.db.meetings import meetings_controller
|
||||||
from reflector.db.rooms import rooms_controller
|
from reflector.db.rooms import rooms_controller
|
||||||
@@ -177,29 +176,31 @@ def parse_datetime_with_timezone(iso_string: str) -> datetime:
|
|||||||
@router.get("/rooms", response_model=Page[RoomDetails])
|
@router.get("/rooms", response_model=Page[RoomDetails])
|
||||||
async def rooms_list(
|
async def rooms_list(
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
) -> list[RoomDetails]:
|
) -> list[RoomDetails]:
|
||||||
if not user and not settings.PUBLIC_MODE:
|
if not user and not settings.PUBLIC_MODE:
|
||||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||||
|
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
|
|
||||||
query = await rooms_controller.get_all(
|
return await apaginate(
|
||||||
session, user_id=user_id, order_by="-created_at", return_query=True
|
get_database(),
|
||||||
|
await rooms_controller.get_all(
|
||||||
|
user_id=user_id, order_by="-created_at", return_query=True
|
||||||
|
),
|
||||||
)
|
)
|
||||||
return await paginate(session, query)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/rooms/{room_id}", response_model=RoomDetails)
|
@router.get("/rooms/{room_id}", response_model=RoomDetails)
|
||||||
async def rooms_get(
|
async def rooms_get(
|
||||||
room_id: str,
|
room_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
|
room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id)
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
if not room.is_shared and (user_id is None or room.user_id != user_id):
|
||||||
|
raise HTTPException(status_code=403, detail="Room access denied")
|
||||||
return room
|
return room
|
||||||
|
|
||||||
|
|
||||||
@@ -207,10 +208,9 @@ async def rooms_get(
|
|||||||
async def rooms_get_by_name(
|
async def rooms_get_by_name(
|
||||||
room_name: str,
|
room_name: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_name(session, room_name)
|
room = await rooms_controller.get_by_name(room_name)
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
@@ -231,13 +231,11 @@ async def rooms_get_by_name(
|
|||||||
@router.post("/rooms", response_model=Room)
|
@router.post("/rooms", response_model=Room)
|
||||||
async def rooms_create(
|
async def rooms_create(
|
||||||
room: CreateRoom,
|
room: CreateRoom,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"]
|
||||||
|
|
||||||
return await rooms_controller.add(
|
return await rooms_controller.add(
|
||||||
session,
|
|
||||||
name=room.name,
|
name=room.name,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
zulip_auto_post=room.zulip_auto_post,
|
zulip_auto_post=room.zulip_auto_post,
|
||||||
@@ -260,29 +258,31 @@ async def rooms_create(
|
|||||||
async def rooms_update(
|
async def rooms_update(
|
||||||
room_id: str,
|
room_id: str,
|
||||||
info: UpdateRoom,
|
info: UpdateRoom,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"]
|
||||||
room = await rooms_controller.get_by_id_for_http(session, room_id, user_id=user_id)
|
room = await rooms_controller.get_by_id_for_http(room_id, user_id=user_id)
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
if room.user_id != user_id:
|
||||||
|
raise HTTPException(status_code=403, detail="Not authorized")
|
||||||
values = info.dict(exclude_unset=True)
|
values = info.dict(exclude_unset=True)
|
||||||
await rooms_controller.update(session, room, values)
|
await rooms_controller.update(room, values)
|
||||||
return room
|
return room
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/rooms/{room_id}", response_model=DeletionStatus)
|
@router.delete("/rooms/{room_id}", response_model=DeletionStatus)
|
||||||
async def rooms_delete(
|
async def rooms_delete(
|
||||||
room_id: str,
|
room_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"]
|
||||||
room = await rooms_controller.get_by_id(session, room_id, user_id=user_id)
|
room = await rooms_controller.get_by_id(room_id)
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
await rooms_controller.remove_by_id(session, room.id, user_id=user_id)
|
if room.user_id != user_id:
|
||||||
|
raise HTTPException(status_code=403, detail="Not authorized")
|
||||||
|
await rooms_controller.remove_by_id(room.id, user_id=user_id)
|
||||||
return DeletionStatus(status="ok")
|
return DeletionStatus(status="ok")
|
||||||
|
|
||||||
|
|
||||||
@@ -291,10 +291,9 @@ async def rooms_create_meeting(
|
|||||||
room_name: str,
|
room_name: str,
|
||||||
info: CreateRoomMeeting,
|
info: CreateRoomMeeting,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_name(session, room_name)
|
room = await rooms_controller.get_by_name(room_name)
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
@@ -310,7 +309,7 @@ async def rooms_create_meeting(
|
|||||||
meeting = None
|
meeting = None
|
||||||
if not info.allow_duplicated:
|
if not info.allow_duplicated:
|
||||||
meeting = await meetings_controller.get_active(
|
meeting = await meetings_controller.get_active(
|
||||||
session, room=room, current_time=current_time
|
room=room, current_time=current_time
|
||||||
)
|
)
|
||||||
|
|
||||||
if meeting is None:
|
if meeting is None:
|
||||||
@@ -321,7 +320,6 @@ async def rooms_create_meeting(
|
|||||||
await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
|
await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
|
||||||
|
|
||||||
meeting = await meetings_controller.create(
|
meeting = await meetings_controller.create(
|
||||||
session,
|
|
||||||
id=whereby_meeting["meetingId"],
|
id=whereby_meeting["meetingId"],
|
||||||
room_name=whereby_meeting["roomName"],
|
room_name=whereby_meeting["roomName"],
|
||||||
room_url=whereby_meeting["roomUrl"],
|
room_url=whereby_meeting["roomUrl"],
|
||||||
@@ -347,17 +345,16 @@ async def rooms_create_meeting(
|
|||||||
@router.post("/rooms/{room_id}/webhook/test", response_model=WebhookTestResult)
|
@router.post("/rooms/{room_id}/webhook/test", response_model=WebhookTestResult)
|
||||||
async def rooms_test_webhook(
|
async def rooms_test_webhook(
|
||||||
room_id: str,
|
room_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
"""Test webhook configuration by sending a sample payload."""
|
"""Test webhook configuration by sending a sample payload."""
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"]
|
||||||
|
|
||||||
room = await rooms_controller.get_by_id(session, room_id)
|
room = await rooms_controller.get_by_id(room_id)
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
if user_id and room.user_id != user_id:
|
if room.user_id != user_id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=403, detail="Not authorized to test this room's webhook"
|
status_code=403, detail="Not authorized to test this room's webhook"
|
||||||
)
|
)
|
||||||
@@ -370,10 +367,9 @@ async def rooms_test_webhook(
|
|||||||
async def rooms_sync_ics(
|
async def rooms_sync_ics(
|
||||||
room_name: str,
|
room_name: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_name(session, room_name)
|
room = await rooms_controller.get_by_name(room_name)
|
||||||
|
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
@@ -386,7 +382,7 @@ async def rooms_sync_ics(
|
|||||||
if not room.ics_enabled or not room.ics_url:
|
if not room.ics_enabled or not room.ics_url:
|
||||||
raise HTTPException(status_code=400, detail="ICS not configured for this room")
|
raise HTTPException(status_code=400, detail="ICS not configured for this room")
|
||||||
|
|
||||||
result = await ics_sync_service.sync_room_calendar(session, room)
|
result = await ics_sync_service.sync_room_calendar(room)
|
||||||
|
|
||||||
if result["status"] == "error":
|
if result["status"] == "error":
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -400,10 +396,9 @@ async def rooms_sync_ics(
|
|||||||
async def rooms_ics_status(
|
async def rooms_ics_status(
|
||||||
room_name: str,
|
room_name: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_name(session, room_name)
|
room = await rooms_controller.get_by_name(room_name)
|
||||||
|
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
@@ -418,7 +413,7 @@ async def rooms_ics_status(
|
|||||||
next_sync = room.ics_last_sync + timedelta(seconds=room.ics_fetch_interval)
|
next_sync = room.ics_last_sync + timedelta(seconds=room.ics_fetch_interval)
|
||||||
|
|
||||||
events = await calendar_events_controller.get_by_room(
|
events = await calendar_events_controller.get_by_room(
|
||||||
session, room.id, include_deleted=False
|
room.id, include_deleted=False
|
||||||
)
|
)
|
||||||
|
|
||||||
return ICSStatus(
|
return ICSStatus(
|
||||||
@@ -434,16 +429,15 @@ async def rooms_ics_status(
|
|||||||
async def rooms_list_meetings(
|
async def rooms_list_meetings(
|
||||||
room_name: str,
|
room_name: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_name(session, room_name)
|
room = await rooms_controller.get_by_name(room_name)
|
||||||
|
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
events = await calendar_events_controller.get_by_room(
|
events = await calendar_events_controller.get_by_room(
|
||||||
session, room.id, include_deleted=False
|
room.id, include_deleted=False
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_id != room.user_id:
|
if user_id != room.user_id:
|
||||||
@@ -461,16 +455,15 @@ async def rooms_list_upcoming_meetings(
|
|||||||
room_name: str,
|
room_name: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
minutes_ahead: int = 120,
|
minutes_ahead: int = 120,
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_name(session, room_name)
|
room = await rooms_controller.get_by_name(room_name)
|
||||||
|
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
events = await calendar_events_controller.get_upcoming(
|
events = await calendar_events_controller.get_upcoming(
|
||||||
session, room.id, minutes_ahead=minutes_ahead
|
room.id, minutes_ahead=minutes_ahead
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_id != room.user_id:
|
if user_id != room.user_id:
|
||||||
@@ -485,17 +478,16 @@ async def rooms_list_upcoming_meetings(
|
|||||||
async def rooms_list_active_meetings(
|
async def rooms_list_active_meetings(
|
||||||
room_name: str,
|
room_name: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_name(session, room_name)
|
room = await rooms_controller.get_by_name(room_name)
|
||||||
|
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
current_time = datetime.now(timezone.utc)
|
current_time = datetime.now(timezone.utc)
|
||||||
meetings = await meetings_controller.get_all_active_for_room(
|
meetings = await meetings_controller.get_all_active_for_room(
|
||||||
session, room=room, current_time=current_time
|
room=room, current_time=current_time
|
||||||
)
|
)
|
||||||
|
|
||||||
# Hide host URLs from non-owners
|
# Hide host URLs from non-owners
|
||||||
@@ -511,16 +503,15 @@ async def rooms_get_meeting(
|
|||||||
room_name: str,
|
room_name: str,
|
||||||
meeting_id: str,
|
meeting_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
"""Get a single meeting by ID within a specific room."""
|
"""Get a single meeting by ID within a specific room."""
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
|
|
||||||
room = await rooms_controller.get_by_name(session, room_name)
|
room = await rooms_controller.get_by_name(room_name)
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
meeting = await meetings_controller.get_by_id(session, meeting_id)
|
meeting = await meetings_controller.get_by_id(meeting_id)
|
||||||
if not meeting:
|
if not meeting:
|
||||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||||
|
|
||||||
@@ -540,15 +531,14 @@ async def rooms_join_meeting(
|
|||||||
room_name: str,
|
room_name: str,
|
||||||
meeting_id: str,
|
meeting_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
room = await rooms_controller.get_by_name(session, room_name)
|
room = await rooms_controller.get_by_name(room_name)
|
||||||
|
|
||||||
if not room:
|
if not room:
|
||||||
raise HTTPException(status_code=404, detail="Room not found")
|
raise HTTPException(status_code=404, detail="Room not found")
|
||||||
|
|
||||||
meeting = await meetings_controller.get_by_id(session, meeting_id)
|
meeting = await meetings_controller.get_by_id(meeting_id)
|
||||||
|
|
||||||
if not meeting:
|
if not meeting:
|
||||||
raise HTTPException(status_code=404, detail="Meeting not found")
|
raise HTTPException(status_code=404, detail="Meeting not found")
|
||||||
|
|||||||
@@ -3,15 +3,12 @@ from typing import Annotated, Literal, Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from fastapi_pagination import Page
|
from fastapi_pagination import Page
|
||||||
from fastapi_pagination.ext.sqlalchemy import paginate
|
from fastapi_pagination.ext.databases import apaginate
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from pydantic import BaseModel, Field, constr, field_serializer
|
from pydantic import AwareDatetime, BaseModel, Field, constr, field_serializer
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
from reflector.db import get_session
|
from reflector.db import get_database
|
||||||
from reflector.db.meetings import meetings_controller
|
|
||||||
from reflector.db.rooms import rooms_controller
|
|
||||||
from reflector.db.search import (
|
from reflector.db.search import (
|
||||||
DEFAULT_SEARCH_LIMIT,
|
DEFAULT_SEARCH_LIMIT,
|
||||||
SearchLimit,
|
SearchLimit,
|
||||||
@@ -35,6 +32,7 @@ from reflector.db.transcripts import (
|
|||||||
from reflector.processors.types import Transcript as ProcessorTranscript
|
from reflector.processors.types import Transcript as ProcessorTranscript
|
||||||
from reflector.processors.types import Word
|
from reflector.processors.types import Word
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
from reflector.ws_manager import get_ws_manager
|
||||||
from reflector.zulip import (
|
from reflector.zulip import (
|
||||||
InvalidMessageError,
|
InvalidMessageError,
|
||||||
get_zulip_message,
|
get_zulip_message,
|
||||||
@@ -135,6 +133,21 @@ SearchOffsetParam = Annotated[
|
|||||||
SearchOffsetBase, Query(description="Number of results to skip")
|
SearchOffsetBase, Query(description="Number of results to skip")
|
||||||
]
|
]
|
||||||
|
|
||||||
|
SearchFromDatetimeParam = Annotated[
|
||||||
|
AwareDatetime | None,
|
||||||
|
Query(
|
||||||
|
alias="from",
|
||||||
|
description="Filter transcripts created on or after this datetime (ISO 8601 with timezone)",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
SearchToDatetimeParam = Annotated[
|
||||||
|
AwareDatetime | None,
|
||||||
|
Query(
|
||||||
|
alias="to",
|
||||||
|
description="Filter transcripts created on or before this datetime (ISO 8601 with timezone)",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class SearchResponse(BaseModel):
|
class SearchResponse(BaseModel):
|
||||||
results: list[SearchResult]
|
results: list[SearchResult]
|
||||||
@@ -150,25 +163,24 @@ async def transcripts_list(
|
|||||||
source_kind: SourceKind | None = None,
|
source_kind: SourceKind | None = None,
|
||||||
room_id: str | None = None,
|
room_id: str | None = None,
|
||||||
search_term: str | None = None,
|
search_term: str | None = None,
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
if not user and not settings.PUBLIC_MODE:
|
if not user and not settings.PUBLIC_MODE:
|
||||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||||
|
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
|
|
||||||
query = await transcripts_controller.get_all(
|
return await apaginate(
|
||||||
session,
|
get_database(),
|
||||||
|
await transcripts_controller.get_all(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
source_kind=SourceKind(source_kind) if source_kind else None,
|
source_kind=SourceKind(source_kind) if source_kind else None,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
search_term=search_term,
|
search_term=search_term,
|
||||||
order_by="-created_at",
|
order_by="-created_at",
|
||||||
return_query=True,
|
return_query=True,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return await paginate(session, query)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/transcripts/search", response_model=SearchResponse)
|
@router.get("/transcripts/search", response_model=SearchResponse)
|
||||||
async def transcripts_search(
|
async def transcripts_search(
|
||||||
@@ -177,19 +189,23 @@ async def transcripts_search(
|
|||||||
offset: SearchOffsetParam = 0,
|
offset: SearchOffsetParam = 0,
|
||||||
room_id: Optional[str] = None,
|
room_id: Optional[str] = None,
|
||||||
source_kind: Optional[SourceKind] = None,
|
source_kind: Optional[SourceKind] = None,
|
||||||
|
from_datetime: SearchFromDatetimeParam = None,
|
||||||
|
to_datetime: SearchToDatetimeParam = None,
|
||||||
user: Annotated[
|
user: Annotated[
|
||||||
Optional[auth.UserInfo], Depends(auth.current_user_optional)
|
Optional[auth.UserInfo], Depends(auth.current_user_optional)
|
||||||
] = None,
|
] = None,
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
"""
|
"""Full-text search across transcript titles and content."""
|
||||||
Full-text search across transcript titles and content.
|
|
||||||
"""
|
|
||||||
if not user and not settings.PUBLIC_MODE:
|
if not user and not settings.PUBLIC_MODE:
|
||||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||||
|
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
|
|
||||||
|
if from_datetime and to_datetime and from_datetime > to_datetime:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400, detail="'from' must be less than or equal to 'to'"
|
||||||
|
)
|
||||||
|
|
||||||
search_params = SearchParameters(
|
search_params = SearchParameters(
|
||||||
query_text=parse_search_query_param(q),
|
query_text=parse_search_query_param(q),
|
||||||
limit=limit,
|
limit=limit,
|
||||||
@@ -197,9 +213,11 @@ async def transcripts_search(
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
source_kind=source_kind,
|
source_kind=source_kind,
|
||||||
|
from_datetime=from_datetime,
|
||||||
|
to_datetime=to_datetime,
|
||||||
)
|
)
|
||||||
|
|
||||||
results, total = await search_controller.search_transcripts(session, search_params)
|
results, total = await search_controller.search_transcripts(search_params)
|
||||||
|
|
||||||
return SearchResponse(
|
return SearchResponse(
|
||||||
results=results,
|
results=results,
|
||||||
@@ -214,11 +232,9 @@ async def transcripts_search(
|
|||||||
async def transcripts_create(
|
async def transcripts_create(
|
||||||
info: CreateTranscript,
|
info: CreateTranscript,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
return await transcripts_controller.add(
|
transcript = await transcripts_controller.add(
|
||||||
session,
|
|
||||||
info.name,
|
info.name,
|
||||||
source_kind=info.source_kind or SourceKind.LIVE,
|
source_kind=info.source_kind or SourceKind.LIVE,
|
||||||
source_language=info.source_language,
|
source_language=info.source_language,
|
||||||
@@ -226,6 +242,14 @@ async def transcripts_create(
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if user_id:
|
||||||
|
await get_ws_manager().send_json(
|
||||||
|
room_id=f"user:{user_id}",
|
||||||
|
message={"event": "TRANSCRIPT_CREATED", "data": {"id": transcript.id}},
|
||||||
|
)
|
||||||
|
|
||||||
|
return transcript
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================
|
# ==============================================================
|
||||||
# Single transcript
|
# Single transcript
|
||||||
@@ -338,11 +362,10 @@ class GetTranscriptTopicWithWordsPerSpeaker(GetTranscriptTopic):
|
|||||||
async def transcript_get(
|
async def transcript_get(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
return await transcripts_controller.get_by_id_for_http(
|
return await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -350,38 +373,36 @@ async def transcript_get(
|
|||||||
async def transcript_update(
|
async def transcript_update(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
info: UpdateTranscript,
|
info: UpdateTranscript,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"]
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
if not transcripts_controller.user_can_mutate(transcript, user_id):
|
||||||
|
raise HTTPException(status_code=403, detail="Not authorized")
|
||||||
values = info.dict(exclude_unset=True)
|
values = info.dict(exclude_unset=True)
|
||||||
updated_transcript = await transcripts_controller.update(
|
updated_transcript = await transcripts_controller.update(transcript, values)
|
||||||
session, transcript, values
|
|
||||||
)
|
|
||||||
return updated_transcript
|
return updated_transcript
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus)
|
@router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus)
|
||||||
async def transcript_delete(
|
async def transcript_delete(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"]
|
||||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
if not transcripts_controller.user_can_mutate(transcript, user_id):
|
||||||
|
raise HTTPException(status_code=403, detail="Not authorized")
|
||||||
|
|
||||||
if transcript.meeting_id:
|
await transcripts_controller.remove_by_id(transcript.id, user_id=user_id)
|
||||||
meeting = await meetings_controller.get_by_id(session, transcript.meeting_id)
|
await get_ws_manager().send_json(
|
||||||
room = await rooms_controller.get_by_id(session, meeting.room_id)
|
room_id=f"user:{user_id}",
|
||||||
if room.is_shared:
|
message={"event": "TRANSCRIPT_DELETED", "data": {"id": transcript.id}},
|
||||||
user_id = None
|
)
|
||||||
|
|
||||||
await transcripts_controller.remove_by_id(session, transcript.id, user_id=user_id)
|
|
||||||
return DeletionStatus(status="ok")
|
return DeletionStatus(status="ok")
|
||||||
|
|
||||||
|
|
||||||
@@ -392,11 +413,10 @@ async def transcript_delete(
|
|||||||
async def transcript_get_topics(
|
async def transcript_get_topics(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# convert to GetTranscriptTopic
|
# convert to GetTranscriptTopic
|
||||||
@@ -412,11 +432,10 @@ async def transcript_get_topics(
|
|||||||
async def transcript_get_topics_with_words(
|
async def transcript_get_topics_with_words(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# convert to GetTranscriptTopicWithWords
|
# convert to GetTranscriptTopicWithWords
|
||||||
@@ -434,11 +453,10 @@ async def transcript_get_topics_with_words_per_speaker(
|
|||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
topic_id: str,
|
topic_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# get the topic from the transcript
|
# get the topic from the transcript
|
||||||
@@ -456,16 +474,16 @@ async def transcript_post_to_zulip(
|
|||||||
stream: str,
|
stream: str,
|
||||||
topic: str,
|
topic: str,
|
||||||
include_topics: bool,
|
include_topics: bool,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"]
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
if not transcripts_controller.user_can_mutate(transcript, user_id):
|
||||||
|
raise HTTPException(status_code=403, detail="Not authorized")
|
||||||
content = get_zulip_message(transcript, include_topics)
|
content = get_zulip_message(transcript, include_topics)
|
||||||
|
|
||||||
message_updated = False
|
message_updated = False
|
||||||
@@ -481,5 +499,5 @@ async def transcript_post_to_zulip(
|
|||||||
if not message_updated:
|
if not message_updated:
|
||||||
response = await send_message_to_zulip(stream, topic, content)
|
response = await send_message_to_zulip(stream, topic, content)
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
session, transcript, {"zulip_message_id": response["id"]}
|
transcript, {"zulip_message_id": response["id"]}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,10 +9,8 @@ from typing import Annotated, Optional
|
|||||||
import httpx
|
import httpx
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
from reflector.db import get_session
|
|
||||||
from reflector.db.transcripts import AudioWaveform, transcripts_controller
|
from reflector.db.transcripts import AudioWaveform, transcripts_controller
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.views.transcripts import ALGORITHM
|
from reflector.views.transcripts import ALGORITHM
|
||||||
@@ -34,7 +32,6 @@ async def transcript_get_audio_mp3(
|
|||||||
request: Request,
|
request: Request,
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
token: str | None = None,
|
token: str | None = None,
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
@@ -51,7 +48,7 @@ async def transcript_get_audio_mp3(
|
|||||||
raise unauthorized_exception
|
raise unauthorized_exception
|
||||||
|
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if transcript.audio_location == "storage":
|
if transcript.audio_location == "storage":
|
||||||
@@ -89,7 +86,7 @@ async def transcript_get_audio_mp3(
|
|||||||
|
|
||||||
return range_requests_response(
|
return range_requests_response(
|
||||||
request,
|
request,
|
||||||
transcript.audio_mp3_filename.as_posix(),
|
transcript.audio_mp3_filename,
|
||||||
content_type="audio/mpeg",
|
content_type="audio/mpeg",
|
||||||
content_disposition=f"attachment; filename={filename}",
|
content_disposition=f"attachment; filename={filename}",
|
||||||
)
|
)
|
||||||
@@ -99,18 +96,13 @@ async def transcript_get_audio_mp3(
|
|||||||
async def transcript_get_audio_waveform(
|
async def transcript_get_audio_waveform(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
) -> AudioWaveform:
|
) -> AudioWaveform:
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if not transcript.audio_waveform_filename.exists():
|
if not transcript.audio_waveform_filename.exists():
|
||||||
raise HTTPException(status_code=404, detail="Audio not found")
|
raise HTTPException(status_code=404, detail="Audio not found")
|
||||||
|
|
||||||
audio_waveform = transcript.audio_waveform
|
return transcript.audio_waveform
|
||||||
if not audio_waveform:
|
|
||||||
raise HTTPException(status_code=404, detail="Audio waveform not found")
|
|
||||||
|
|
||||||
return audio_waveform
|
|
||||||
|
|||||||
@@ -8,10 +8,8 @@ from typing import Annotated, Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
from reflector.db import get_session
|
|
||||||
from reflector.db.transcripts import TranscriptParticipant, transcripts_controller
|
from reflector.db.transcripts import TranscriptParticipant, transcripts_controller
|
||||||
from reflector.views.types import DeletionStatus
|
from reflector.views.types import DeletionStatus
|
||||||
|
|
||||||
@@ -39,11 +37,10 @@ class UpdateParticipant(BaseModel):
|
|||||||
async def transcript_get_participants(
|
async def transcript_get_participants(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
) -> list[Participant]:
|
) -> list[Participant]:
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if transcript.participants is None:
|
if transcript.participants is None:
|
||||||
@@ -59,13 +56,14 @@ async def transcript_get_participants(
|
|||||||
async def transcript_add_participant(
|
async def transcript_add_participant(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
participant: CreateParticipant,
|
participant: CreateParticipant,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
) -> Participant:
|
) -> Participant:
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"]
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
if transcript.user_id is not None and transcript.user_id != user_id:
|
||||||
|
raise HTTPException(status_code=403, detail="Not authorized")
|
||||||
|
|
||||||
# ensure the speaker is unique
|
# ensure the speaker is unique
|
||||||
if participant.speaker is not None and transcript.participants is not None:
|
if participant.speaker is not None and transcript.participants is not None:
|
||||||
@@ -77,7 +75,7 @@ async def transcript_add_participant(
|
|||||||
)
|
)
|
||||||
|
|
||||||
obj = await transcripts_controller.upsert_participant(
|
obj = await transcripts_controller.upsert_participant(
|
||||||
session, transcript, TranscriptParticipant(**participant.dict())
|
transcript, TranscriptParticipant(**participant.dict())
|
||||||
)
|
)
|
||||||
return Participant.model_validate(obj)
|
return Participant.model_validate(obj)
|
||||||
|
|
||||||
@@ -87,11 +85,10 @@ async def transcript_get_participant(
|
|||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
participant_id: str,
|
participant_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
) -> Participant:
|
) -> Participant:
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
for p in transcript.participants:
|
for p in transcript.participants:
|
||||||
@@ -106,13 +103,14 @@ async def transcript_update_participant(
|
|||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
participant_id: str,
|
participant_id: str,
|
||||||
participant: UpdateParticipant,
|
participant: UpdateParticipant,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
) -> Participant:
|
) -> Participant:
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"]
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
if transcript.user_id is not None and transcript.user_id != user_id:
|
||||||
|
raise HTTPException(status_code=403, detail="Not authorized")
|
||||||
|
|
||||||
# ensure the speaker is unique
|
# ensure the speaker is unique
|
||||||
for p in transcript.participants:
|
for p in transcript.participants:
|
||||||
@@ -136,7 +134,7 @@ async def transcript_update_participant(
|
|||||||
fields = participant.dict(exclude_unset=True)
|
fields = participant.dict(exclude_unset=True)
|
||||||
obj = obj.copy(update=fields)
|
obj = obj.copy(update=fields)
|
||||||
|
|
||||||
await transcripts_controller.upsert_participant(session, transcript, obj)
|
await transcripts_controller.upsert_participant(transcript, obj)
|
||||||
return Participant.model_validate(obj)
|
return Participant.model_validate(obj)
|
||||||
|
|
||||||
|
|
||||||
@@ -144,12 +142,13 @@ async def transcript_update_participant(
|
|||||||
async def transcript_delete_participant(
|
async def transcript_delete_participant(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
participant_id: str,
|
participant_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
) -> DeletionStatus:
|
) -> DeletionStatus:
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"]
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
await transcripts_controller.delete_participant(session, transcript, participant_id)
|
if transcript.user_id is not None and transcript.user_id != user_id:
|
||||||
|
raise HTTPException(status_code=403, detail="Not authorized")
|
||||||
|
await transcripts_controller.delete_participant(transcript, participant_id)
|
||||||
return DeletionStatus(status="ok")
|
return DeletionStatus(status="ok")
|
||||||
|
|||||||
@@ -3,10 +3,8 @@ from typing import Annotated, Optional
|
|||||||
import celery
|
import celery
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
from reflector.db import get_session
|
|
||||||
from reflector.db.transcripts import transcripts_controller
|
from reflector.db.transcripts import transcripts_controller
|
||||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
||||||
|
|
||||||
@@ -21,11 +19,10 @@ class ProcessStatus(BaseModel):
|
|||||||
async def transcript_process(
|
async def transcript_process(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if transcript.locked:
|
if transcript.locked:
|
||||||
|
|||||||
@@ -8,10 +8,8 @@ from typing import Annotated, Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
from reflector.db import get_session
|
|
||||||
from reflector.db.transcripts import transcripts_controller
|
from reflector.db.transcripts import transcripts_controller
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -37,13 +35,14 @@ class SpeakerMerge(BaseModel):
|
|||||||
async def transcript_assign_speaker(
|
async def transcript_assign_speaker(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
assignment: SpeakerAssignment,
|
assignment: SpeakerAssignment,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
) -> SpeakerAssignmentStatus:
|
) -> SpeakerAssignmentStatus:
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"]
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
if transcript.user_id is not None and transcript.user_id != user_id:
|
||||||
|
raise HTTPException(status_code=403, detail="Not authorized")
|
||||||
|
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
@@ -82,9 +81,7 @@ async def transcript_assign_speaker(
|
|||||||
# if the participant does not have a speaker, create one
|
# if the participant does not have a speaker, create one
|
||||||
if participant.speaker is None:
|
if participant.speaker is None:
|
||||||
participant.speaker = transcript.find_empty_speaker()
|
participant.speaker = transcript.find_empty_speaker()
|
||||||
await transcripts_controller.upsert_participant(
|
await transcripts_controller.upsert_participant(transcript, participant)
|
||||||
session, transcript, participant
|
|
||||||
)
|
|
||||||
|
|
||||||
speaker = participant.speaker
|
speaker = participant.speaker
|
||||||
|
|
||||||
@@ -105,7 +102,6 @@ async def transcript_assign_speaker(
|
|||||||
for topic in changed_topics:
|
for topic in changed_topics:
|
||||||
transcript.upsert_topic(topic)
|
transcript.upsert_topic(topic)
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
session,
|
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"topics": transcript.topics_dump(),
|
"topics": transcript.topics_dump(),
|
||||||
@@ -119,13 +115,14 @@ async def transcript_assign_speaker(
|
|||||||
async def transcript_merge_speaker(
|
async def transcript_merge_speaker(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
merge: SpeakerMerge,
|
merge: SpeakerMerge,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
) -> SpeakerAssignmentStatus:
|
) -> SpeakerAssignmentStatus:
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"]
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
if transcript.user_id is not None and transcript.user_id != user_id:
|
||||||
|
raise HTTPException(status_code=403, detail="Not authorized")
|
||||||
|
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
@@ -170,7 +167,6 @@ async def transcript_merge_speaker(
|
|||||||
for topic in changed_topics:
|
for topic in changed_topics:
|
||||||
transcript.upsert_topic(topic)
|
transcript.upsert_topic(topic)
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
session,
|
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"topics": transcript.topics_dump(),
|
"topics": transcript.topics_dump(),
|
||||||
|
|||||||
@@ -3,10 +3,8 @@ from typing import Annotated, Optional
|
|||||||
import av
|
import av
|
||||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile
|
from fastapi import APIRouter, Depends, HTTPException, UploadFile
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
from reflector.db import get_session
|
|
||||||
from reflector.db.transcripts import transcripts_controller
|
from reflector.db.transcripts import transcripts_controller
|
||||||
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
from reflector.pipelines.main_file_pipeline import task_pipeline_file_process
|
||||||
|
|
||||||
@@ -24,11 +22,10 @@ async def transcript_record_upload(
|
|||||||
total_chunks: int,
|
total_chunks: int,
|
||||||
chunk: UploadFile,
|
chunk: UploadFile,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if transcript.locked:
|
if transcript.locked:
|
||||||
@@ -92,7 +89,7 @@ async def transcript_record_upload(
|
|||||||
container.close()
|
container.close()
|
||||||
|
|
||||||
# set the status to "uploaded"
|
# set the status to "uploaded"
|
||||||
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
|
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
||||||
|
|
||||||
# launch a background task to process the file
|
# launch a background task to process the file
|
||||||
task_pipeline_file_process.delay(transcript_id=transcript_id)
|
task_pipeline_file_process.delay(transcript_id=transcript_id)
|
||||||
|
|||||||
@@ -1,10 +1,8 @@
|
|||||||
from typing import Annotated, Optional
|
from typing import Annotated, Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
import reflector.auth as auth
|
import reflector.auth as auth
|
||||||
from reflector.db import get_session
|
|
||||||
from reflector.db.transcripts import transcripts_controller
|
from reflector.db.transcripts import transcripts_controller
|
||||||
|
|
||||||
from .rtc_offer import RtcOffer, rtc_offer_base
|
from .rtc_offer import RtcOffer, rtc_offer_base
|
||||||
@@ -18,11 +16,10 @@ async def transcript_record_webrtc(
|
|||||||
params: RtcOffer,
|
params: RtcOffer,
|
||||||
request: Request,
|
request: Request,
|
||||||
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
||||||
session: AsyncSession = Depends(get_session),
|
|
||||||
):
|
):
|
||||||
user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id_for_http(
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
session, transcript_id, user_id=user_id
|
transcript_id, user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if transcript.locked:
|
if transcript.locked:
|
||||||
|
|||||||
@@ -4,8 +4,11 @@ Transcripts websocket API
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketDisconnect
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect
|
||||||
|
|
||||||
|
import reflector.auth as auth
|
||||||
from reflector.db.transcripts import transcripts_controller
|
from reflector.db.transcripts import transcripts_controller
|
||||||
from reflector.ws_manager import get_ws_manager
|
from reflector.ws_manager import get_ws_manager
|
||||||
|
|
||||||
@@ -21,10 +24,12 @@ async def transcript_get_websocket_events(transcript_id: str):
|
|||||||
async def transcript_events_websocket(
|
async def transcript_events_websocket(
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
websocket: WebSocket,
|
websocket: WebSocket,
|
||||||
# user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
|
user: Optional[auth.UserInfo] = Depends(auth.current_user_optional),
|
||||||
):
|
):
|
||||||
# user_id = user["sub"] if user else None
|
user_id = user["sub"] if user else None
|
||||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
transcript = await transcripts_controller.get_by_id_for_http(
|
||||||
|
transcript_id, user_id=user_id
|
||||||
|
)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
raise HTTPException(status_code=404, detail="Transcript not found")
|
raise HTTPException(status_code=404, detail="Transcript not found")
|
||||||
|
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ router = APIRouter()
|
|||||||
class UserInfo(BaseModel):
|
class UserInfo(BaseModel):
|
||||||
sub: str
|
sub: str
|
||||||
email: Optional[str]
|
email: Optional[str]
|
||||||
email_verified: Optional[bool]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me")
|
@router.get("/me")
|
||||||
|
|||||||
62
server/reflector/views/user_api_keys.py
Normal file
62
server/reflector/views/user_api_keys.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
import reflector.auth as auth
|
||||||
|
from reflector.db.user_api_keys import user_api_keys_controller
|
||||||
|
from reflector.utils.string import NonEmptyString
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = structlog.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CreateApiKeyRequest(BaseModel):
|
||||||
|
name: NonEmptyString | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ApiKeyResponse(BaseModel):
|
||||||
|
id: NonEmptyString
|
||||||
|
user_id: NonEmptyString
|
||||||
|
name: NonEmptyString | None
|
||||||
|
created_at: datetime
|
||||||
|
|
||||||
|
|
||||||
|
class CreateApiKeyResponse(ApiKeyResponse):
|
||||||
|
key: NonEmptyString
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/user/api-keys", response_model=CreateApiKeyResponse)
|
||||||
|
async def create_api_key(
|
||||||
|
req: CreateApiKeyRequest,
|
||||||
|
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
|
||||||
|
):
|
||||||
|
api_key_model, plaintext = await user_api_keys_controller.create_key(
|
||||||
|
user_id=user["sub"],
|
||||||
|
name=req.name,
|
||||||
|
)
|
||||||
|
return CreateApiKeyResponse(
|
||||||
|
**api_key_model.model_dump(),
|
||||||
|
key=plaintext,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/user/api-keys", response_model=list[ApiKeyResponse])
|
||||||
|
async def list_api_keys(
|
||||||
|
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
|
||||||
|
):
|
||||||
|
api_keys = await user_api_keys_controller.list_by_user_id(user["sub"])
|
||||||
|
return [ApiKeyResponse(**k.model_dump()) for k in api_keys]
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/user/api-keys/{key_id}")
|
||||||
|
async def delete_api_key(
|
||||||
|
key_id: NonEmptyString,
|
||||||
|
user: Annotated[auth.UserInfo, Depends(auth.current_user)],
|
||||||
|
):
|
||||||
|
deleted = await user_api_keys_controller.delete_key(key_id, user["sub"])
|
||||||
|
if not deleted:
|
||||||
|
raise HTTPException(status_code=404)
|
||||||
|
return {"status": "ok"}
|
||||||
53
server/reflector/views/user_websocket.py
Normal file
53
server/reflector/views/user_websocket.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, WebSocket
|
||||||
|
|
||||||
|
from reflector.auth.auth_jwt import JWTAuth # type: ignore
|
||||||
|
from reflector.ws_manager import get_ws_manager
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# Close code for unauthorized WebSocket connections
|
||||||
|
UNAUTHORISED = 4401
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/events")
|
||||||
|
async def user_events_websocket(websocket: WebSocket):
|
||||||
|
# Browser can't send Authorization header for WS; use subprotocol: ["bearer", token]
|
||||||
|
raw_subprotocol = websocket.headers.get("sec-websocket-protocol") or ""
|
||||||
|
parts = [p.strip() for p in raw_subprotocol.split(",") if p.strip()]
|
||||||
|
token: Optional[str] = None
|
||||||
|
negotiated_subprotocol: Optional[str] = None
|
||||||
|
if len(parts) >= 2 and parts[0].lower() == "bearer":
|
||||||
|
negotiated_subprotocol = "bearer"
|
||||||
|
token = parts[1]
|
||||||
|
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
if not token:
|
||||||
|
await websocket.close(code=UNAUTHORISED)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = JWTAuth().verify_token(token)
|
||||||
|
user_id = payload.get("sub")
|
||||||
|
except Exception:
|
||||||
|
await websocket.close(code=UNAUTHORISED)
|
||||||
|
return
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
await websocket.close(code=UNAUTHORISED)
|
||||||
|
return
|
||||||
|
|
||||||
|
room_id = f"user:{user_id}"
|
||||||
|
ws_manager = get_ws_manager()
|
||||||
|
|
||||||
|
await ws_manager.add_user_to_room(
|
||||||
|
room_id, websocket, subprotocol=negotiated_subprotocol
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
await websocket.receive()
|
||||||
|
finally:
|
||||||
|
if room_id:
|
||||||
|
await ws_manager.remove_user_from_room(room_id, websocket)
|
||||||
@@ -10,16 +10,16 @@ from typing import TypedDict
|
|||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
|
from databases import Database
|
||||||
from pydantic.types import PositiveInt
|
from pydantic.types import PositiveInt
|
||||||
from sqlalchemy import delete, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from reflector.asynctask import asynctask
|
from reflector.asynctask import asynctask
|
||||||
from reflector.db.base import MeetingModel, RecordingModel, TranscriptModel
|
from reflector.db import get_database
|
||||||
from reflector.db.transcripts import transcripts_controller
|
from reflector.db.meetings import meetings
|
||||||
|
from reflector.db.recordings import recordings
|
||||||
|
from reflector.db.transcripts import transcripts, transcripts_controller
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.storage import get_recordings_storage
|
from reflector.storage import get_recordings_storage
|
||||||
from reflector.worker.session_decorator import with_session
|
|
||||||
|
|
||||||
logger = structlog.get_logger(__name__)
|
logger = structlog.get_logger(__name__)
|
||||||
|
|
||||||
@@ -34,28 +34,28 @@ class CleanupStats(TypedDict):
|
|||||||
|
|
||||||
|
|
||||||
async def delete_single_transcript(
|
async def delete_single_transcript(
|
||||||
session: AsyncSession, transcript_data: dict, stats: CleanupStats
|
db: Database, transcript_data: dict, stats: CleanupStats
|
||||||
):
|
):
|
||||||
transcript_id = transcript_data["id"]
|
transcript_id = transcript_data["id"]
|
||||||
meeting_id = transcript_data["meeting_id"]
|
meeting_id = transcript_data["meeting_id"]
|
||||||
recording_id = transcript_data["recording_id"]
|
recording_id = transcript_data["recording_id"]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
async with db.transaction(isolation="serializable"):
|
||||||
if meeting_id:
|
if meeting_id:
|
||||||
await session.execute(
|
await db.execute(meetings.delete().where(meetings.c.id == meeting_id))
|
||||||
delete(MeetingModel).where(MeetingModel.id == meeting_id)
|
|
||||||
)
|
|
||||||
stats["meetings_deleted"] += 1
|
stats["meetings_deleted"] += 1
|
||||||
logger.info("Deleted associated meeting", meeting_id=meeting_id)
|
logger.info("Deleted associated meeting", meeting_id=meeting_id)
|
||||||
|
|
||||||
if recording_id:
|
if recording_id:
|
||||||
result = await session.execute(
|
recording = await db.fetch_one(
|
||||||
select(RecordingModel).where(RecordingModel.id == recording_id)
|
recordings.select().where(recordings.c.id == recording_id)
|
||||||
)
|
)
|
||||||
recording = result.mappings().first()
|
|
||||||
if recording:
|
if recording:
|
||||||
try:
|
try:
|
||||||
await get_recordings_storage().delete_file(recording["object_key"])
|
await get_recordings_storage().delete_file(
|
||||||
|
recording["object_key"]
|
||||||
|
)
|
||||||
except Exception as storage_error:
|
except Exception as storage_error:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to delete recording from storage",
|
"Failed to delete recording from storage",
|
||||||
@@ -64,13 +64,15 @@ async def delete_single_transcript(
|
|||||||
error=str(storage_error),
|
error=str(storage_error),
|
||||||
)
|
)
|
||||||
|
|
||||||
await session.execute(
|
await db.execute(
|
||||||
delete(RecordingModel).where(RecordingModel.id == recording_id)
|
recordings.delete().where(recordings.c.id == recording_id)
|
||||||
)
|
)
|
||||||
stats["recordings_deleted"] += 1
|
stats["recordings_deleted"] += 1
|
||||||
logger.info("Deleted associated recording", recording_id=recording_id)
|
logger.info(
|
||||||
|
"Deleted associated recording", recording_id=recording_id
|
||||||
|
)
|
||||||
|
|
||||||
await transcripts_controller.remove_by_id(session, transcript_id)
|
await transcripts_controller.remove_by_id(transcript_id)
|
||||||
stats["transcripts_deleted"] += 1
|
stats["transcripts_deleted"] += 1
|
||||||
logger.info(
|
logger.info(
|
||||||
"Deleted transcript",
|
"Deleted transcript",
|
||||||
@@ -84,30 +86,18 @@ async def delete_single_transcript(
|
|||||||
|
|
||||||
|
|
||||||
async def cleanup_old_transcripts(
|
async def cleanup_old_transcripts(
|
||||||
session: AsyncSession, cutoff_date: datetime, stats: CleanupStats
|
db: Database, cutoff_date: datetime, stats: CleanupStats
|
||||||
):
|
):
|
||||||
"""Delete old anonymous transcripts and their associated recordings/meetings."""
|
"""Delete old anonymous transcripts and their associated recordings/meetings."""
|
||||||
query = select(
|
query = transcripts.select().where(
|
||||||
TranscriptModel.id,
|
(transcripts.c.created_at < cutoff_date) & (transcripts.c.user_id.is_(None))
|
||||||
TranscriptModel.meeting_id,
|
|
||||||
TranscriptModel.recording_id,
|
|
||||||
TranscriptModel.created_at,
|
|
||||||
).where(
|
|
||||||
(TranscriptModel.created_at < cutoff_date) & (TranscriptModel.user_id.is_(None))
|
|
||||||
)
|
)
|
||||||
|
old_transcripts = await db.fetch_all(query)
|
||||||
result = await session.execute(query)
|
|
||||||
old_transcripts = result.mappings().all()
|
|
||||||
|
|
||||||
logger.info(f"Found {len(old_transcripts)} old transcripts to delete")
|
logger.info(f"Found {len(old_transcripts)} old transcripts to delete")
|
||||||
|
|
||||||
for transcript_data in old_transcripts:
|
for transcript_data in old_transcripts:
|
||||||
try:
|
await delete_single_transcript(db, transcript_data, stats)
|
||||||
await delete_single_transcript(session, transcript_data, stats)
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = f"Failed to delete transcript {transcript_data['id']}: {str(e)}"
|
|
||||||
logger.error(error_msg, exc_info=e)
|
|
||||||
stats["errors"].append(error_msg)
|
|
||||||
|
|
||||||
|
|
||||||
def log_cleanup_results(stats: CleanupStats):
|
def log_cleanup_results(stats: CleanupStats):
|
||||||
@@ -127,7 +117,6 @@ def log_cleanup_results(stats: CleanupStats):
|
|||||||
|
|
||||||
|
|
||||||
async def cleanup_old_public_data(
|
async def cleanup_old_public_data(
|
||||||
session: AsyncSession,
|
|
||||||
days: PositiveInt | None = None,
|
days: PositiveInt | None = None,
|
||||||
) -> CleanupStats | None:
|
) -> CleanupStats | None:
|
||||||
if days is None:
|
if days is None:
|
||||||
@@ -150,7 +139,8 @@ async def cleanup_old_public_data(
|
|||||||
"errors": [],
|
"errors": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
await cleanup_old_transcripts(session, cutoff_date, stats)
|
db = get_database()
|
||||||
|
await cleanup_old_transcripts(db, cutoff_date, stats)
|
||||||
|
|
||||||
log_cleanup_results(stats)
|
log_cleanup_results(stats)
|
||||||
return stats
|
return stats
|
||||||
@@ -161,6 +151,5 @@ async def cleanup_old_public_data(
|
|||||||
retry_kwargs={"max_retries": 3, "countdown": 300},
|
retry_kwargs={"max_retries": 3, "countdown": 300},
|
||||||
)
|
)
|
||||||
@asynctask
|
@asynctask
|
||||||
@with_session
|
async def cleanup_old_public_data_task(days: int | None = None):
|
||||||
async def cleanup_old_public_data_task(session: AsyncSession, days: int | None = None):
|
await cleanup_old_public_data(days=days)
|
||||||
await cleanup_old_public_data(session, days=days)
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from datetime import datetime, timedelta, timezone
|
|||||||
import structlog
|
import structlog
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
from celery.utils.log import get_task_logger
|
from celery.utils.log import get_task_logger
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from reflector.asynctask import asynctask
|
from reflector.asynctask import asynctask
|
||||||
from reflector.db.calendar_events import calendar_events_controller
|
from reflector.db.calendar_events import calendar_events_controller
|
||||||
@@ -12,17 +11,15 @@ from reflector.db.rooms import rooms_controller
|
|||||||
from reflector.redis_cache import RedisAsyncLock
|
from reflector.redis_cache import RedisAsyncLock
|
||||||
from reflector.services.ics_sync import SyncStatus, ics_sync_service
|
from reflector.services.ics_sync import SyncStatus, ics_sync_service
|
||||||
from reflector.whereby import create_meeting, upload_logo
|
from reflector.whereby import create_meeting, upload_logo
|
||||||
from reflector.worker.session_decorator import with_session
|
|
||||||
|
|
||||||
logger = structlog.wrap_logger(get_task_logger(__name__))
|
logger = structlog.wrap_logger(get_task_logger(__name__))
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
@with_session
|
async def sync_room_ics(room_id: str):
|
||||||
async def sync_room_ics(session: AsyncSession, room_id: str):
|
|
||||||
try:
|
try:
|
||||||
room = await rooms_controller.get_by_id(session, room_id)
|
room = await rooms_controller.get_by_id(room_id)
|
||||||
if not room:
|
if not room:
|
||||||
logger.warning("Room not found for ICS sync", room_id=room_id)
|
logger.warning("Room not found for ICS sync", room_id=room_id)
|
||||||
return
|
return
|
||||||
@@ -32,7 +29,7 @@ async def sync_room_ics(session: AsyncSession, room_id: str):
|
|||||||
return
|
return
|
||||||
|
|
||||||
logger.info("Starting ICS sync for room", room_id=room_id, room_name=room.name)
|
logger.info("Starting ICS sync for room", room_id=room_id, room_name=room.name)
|
||||||
result = await ics_sync_service.sync_room_calendar(session, room)
|
result = await ics_sync_service.sync_room_calendar(room)
|
||||||
|
|
||||||
if result["status"] == SyncStatus.SUCCESS:
|
if result["status"] == SyncStatus.SUCCESS:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -58,12 +55,11 @@ async def sync_room_ics(session: AsyncSession, room_id: str):
|
|||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
@with_session
|
async def sync_all_ics_calendars():
|
||||||
async def sync_all_ics_calendars(session: AsyncSession):
|
|
||||||
try:
|
try:
|
||||||
logger.info("Starting sync for all ICS-enabled rooms")
|
logger.info("Starting sync for all ICS-enabled rooms")
|
||||||
|
|
||||||
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
|
ics_enabled_rooms = await rooms_controller.get_ics_enabled()
|
||||||
logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")
|
logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")
|
||||||
|
|
||||||
for room in ics_enabled_rooms:
|
for room in ics_enabled_rooms:
|
||||||
@@ -90,14 +86,10 @@ def _should_sync(room) -> bool:
|
|||||||
MEETING_DEFAULT_DURATION = timedelta(hours=1)
|
MEETING_DEFAULT_DURATION = timedelta(hours=1)
|
||||||
|
|
||||||
|
|
||||||
async def create_upcoming_meetings_for_event(
|
async def create_upcoming_meetings_for_event(event, create_window, room_id, room):
|
||||||
session: AsyncSession, event, create_window, room_id, room
|
|
||||||
):
|
|
||||||
if event.start_time <= create_window:
|
if event.start_time <= create_window:
|
||||||
return
|
return
|
||||||
existing_meeting = await meetings_controller.get_by_calendar_event(
|
existing_meeting = await meetings_controller.get_by_calendar_event(event.id)
|
||||||
session, event.id
|
|
||||||
)
|
|
||||||
|
|
||||||
if existing_meeting:
|
if existing_meeting:
|
||||||
return
|
return
|
||||||
@@ -120,7 +112,6 @@ async def create_upcoming_meetings_for_event(
|
|||||||
await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
|
await upload_logo(whereby_meeting["roomName"], "./images/logo.png")
|
||||||
|
|
||||||
meeting = await meetings_controller.create(
|
meeting = await meetings_controller.create(
|
||||||
session,
|
|
||||||
id=whereby_meeting["meetingId"],
|
id=whereby_meeting["meetingId"],
|
||||||
room_name=whereby_meeting["roomName"],
|
room_name=whereby_meeting["roomName"],
|
||||||
room_url=whereby_meeting["roomUrl"],
|
room_url=whereby_meeting["roomUrl"],
|
||||||
@@ -153,8 +144,7 @@ async def create_upcoming_meetings_for_event(
|
|||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
@with_session
|
async def create_upcoming_meetings():
|
||||||
async def create_upcoming_meetings(session: AsyncSession):
|
|
||||||
async with RedisAsyncLock("create_upcoming_meetings", skip_if_locked=True) as lock:
|
async with RedisAsyncLock("create_upcoming_meetings", skip_if_locked=True) as lock:
|
||||||
if not lock.acquired:
|
if not lock.acquired:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -165,20 +155,19 @@ async def create_upcoming_meetings(session: AsyncSession):
|
|||||||
try:
|
try:
|
||||||
logger.info("Starting creation of upcoming meetings")
|
logger.info("Starting creation of upcoming meetings")
|
||||||
|
|
||||||
ics_enabled_rooms = await rooms_controller.get_ics_enabled(session)
|
ics_enabled_rooms = await rooms_controller.get_ics_enabled()
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
create_window = now - timedelta(minutes=6)
|
create_window = now - timedelta(minutes=6)
|
||||||
|
|
||||||
for room in ics_enabled_rooms:
|
for room in ics_enabled_rooms:
|
||||||
events = await calendar_events_controller.get_upcoming(
|
events = await calendar_events_controller.get_upcoming(
|
||||||
session,
|
|
||||||
room.id,
|
room.id,
|
||||||
minutes_ahead=7,
|
minutes_ahead=7,
|
||||||
)
|
)
|
||||||
|
|
||||||
for event in events:
|
for event in events:
|
||||||
await create_upcoming_meetings_for_event(
|
await create_upcoming_meetings_for_event(
|
||||||
session, event, create_window, room.id, room
|
event, create_window, room.id, room
|
||||||
)
|
)
|
||||||
logger.info("Completed pre-creation check for upcoming meetings")
|
logger.info("Completed pre-creation check for upcoming meetings")
|
||||||
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from celery import shared_task
|
|||||||
from celery.utils.log import get_task_logger
|
from celery.utils.log import get_task_logger
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from redis.exceptions import LockError
|
from redis.exceptions import LockError
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from reflector.db.meetings import meetings_controller
|
from reflector.db.meetings import meetings_controller
|
||||||
from reflector.db.recordings import Recording, recordings_controller
|
from reflector.db.recordings import Recording, recordings_controller
|
||||||
@@ -21,7 +20,6 @@ from reflector.pipelines.main_live_pipeline import asynctask
|
|||||||
from reflector.redis_cache import get_redis_client
|
from reflector.redis_cache import get_redis_client
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.whereby import get_room_sessions
|
from reflector.whereby import get_room_sessions
|
||||||
from reflector.worker.session_decorator import with_session
|
|
||||||
|
|
||||||
logger = structlog.wrap_logger(get_task_logger(__name__))
|
logger = structlog.wrap_logger(get_task_logger(__name__))
|
||||||
|
|
||||||
@@ -77,39 +75,30 @@ def process_messages():
|
|||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
@with_session
|
async def process_recording(bucket_name: str, object_key: str):
|
||||||
async def process_recording(session: AsyncSession, bucket_name: str, object_key: str):
|
|
||||||
logger.info("Processing recording: %s/%s", bucket_name, object_key)
|
logger.info("Processing recording: %s/%s", bucket_name, object_key)
|
||||||
|
|
||||||
# extract a guid and a datetime from the object key
|
# extract a guid and a datetime from the object key
|
||||||
room_name = f"/{object_key[:36]}"
|
room_name = f"/{object_key[:36]}"
|
||||||
recorded_at = parse_datetime_with_timezone(object_key[37:57])
|
recorded_at = parse_datetime_with_timezone(object_key[37:57])
|
||||||
|
|
||||||
meeting = await meetings_controller.get_by_room_name(session, room_name)
|
meeting = await meetings_controller.get_by_room_name(room_name)
|
||||||
if not meeting:
|
room = await rooms_controller.get_by_id(meeting.room_id)
|
||||||
logger.warning("Room not found, may be deleted ?", room_name=room_name)
|
|
||||||
return
|
|
||||||
|
|
||||||
room = await rooms_controller.get_by_id(session, meeting.room_id)
|
recording = await recordings_controller.get_by_object_key(bucket_name, object_key)
|
||||||
|
|
||||||
recording = await recordings_controller.get_by_object_key(
|
|
||||||
session, bucket_name, object_key
|
|
||||||
)
|
|
||||||
if not recording:
|
if not recording:
|
||||||
recording = await recordings_controller.create(
|
recording = await recordings_controller.create(
|
||||||
session,
|
|
||||||
Recording(
|
Recording(
|
||||||
bucket_name=bucket_name,
|
bucket_name=bucket_name,
|
||||||
object_key=object_key,
|
object_key=object_key,
|
||||||
recorded_at=recorded_at,
|
recorded_at=recorded_at,
|
||||||
meeting_id=meeting.id,
|
meeting_id=meeting.id,
|
||||||
),
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
transcript = await transcripts_controller.get_by_recording_id(session, recording.id)
|
transcript = await transcripts_controller.get_by_recording_id(recording.id)
|
||||||
if transcript:
|
if transcript:
|
||||||
await transcripts_controller.update(
|
await transcripts_controller.update(
|
||||||
session,
|
|
||||||
transcript,
|
transcript,
|
||||||
{
|
{
|
||||||
"topics": [],
|
"topics": [],
|
||||||
@@ -117,7 +106,6 @@ async def process_recording(session: AsyncSession, bucket_name: str, object_key:
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
transcript = await transcripts_controller.add(
|
transcript = await transcripts_controller.add(
|
||||||
session,
|
|
||||||
"",
|
"",
|
||||||
source_kind=SourceKind.ROOM,
|
source_kind=SourceKind.ROOM,
|
||||||
source_language="en",
|
source_language="en",
|
||||||
@@ -153,15 +141,14 @@ async def process_recording(session: AsyncSession, bucket_name: str, object_key:
|
|||||||
finally:
|
finally:
|
||||||
container.close()
|
container.close()
|
||||||
|
|
||||||
await transcripts_controller.update(session, transcript, {"status": "uploaded"})
|
await transcripts_controller.update(transcript, {"status": "uploaded"})
|
||||||
|
|
||||||
task_pipeline_file_process.delay(transcript_id=transcript.id)
|
task_pipeline_file_process.delay(transcript_id=transcript.id)
|
||||||
|
|
||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
@with_session
|
async def process_meetings():
|
||||||
async def process_meetings(session: AsyncSession):
|
|
||||||
"""
|
"""
|
||||||
Checks which meetings are still active and deactivates those that have ended.
|
Checks which meetings are still active and deactivates those that have ended.
|
||||||
|
|
||||||
@@ -178,7 +165,7 @@ async def process_meetings(session: AsyncSession):
|
|||||||
process the same meeting simultaneously.
|
process the same meeting simultaneously.
|
||||||
"""
|
"""
|
||||||
logger.info("Processing meetings")
|
logger.info("Processing meetings")
|
||||||
meetings = await meetings_controller.get_all_active(session)
|
meetings = await meetings_controller.get_all_active()
|
||||||
current_time = datetime.now(timezone.utc)
|
current_time = datetime.now(timezone.utc)
|
||||||
redis_client = get_redis_client()
|
redis_client = get_redis_client()
|
||||||
processed_count = 0
|
processed_count = 0
|
||||||
@@ -231,9 +218,7 @@ async def process_meetings(session: AsyncSession):
|
|||||||
logger_.debug("Meeting not yet started, keep it")
|
logger_.debug("Meeting not yet started, keep it")
|
||||||
|
|
||||||
if should_deactivate:
|
if should_deactivate:
|
||||||
await meetings_controller.update_meeting(
|
await meetings_controller.update_meeting(meeting.id, is_active=False)
|
||||||
session, meeting.id, is_active=False
|
|
||||||
)
|
|
||||||
logger_.info("Meeting is deactivated")
|
logger_.info("Meeting is deactivated")
|
||||||
|
|
||||||
processed_count += 1
|
processed_count += 1
|
||||||
@@ -255,8 +240,7 @@ async def process_meetings(session: AsyncSession):
|
|||||||
|
|
||||||
@shared_task
|
@shared_task
|
||||||
@asynctask
|
@asynctask
|
||||||
@with_session
|
async def reprocess_failed_recordings():
|
||||||
async def reprocess_failed_recordings(session: AsyncSession):
|
|
||||||
"""
|
"""
|
||||||
Find recordings in the S3 bucket and check if they have proper transcriptions.
|
Find recordings in the S3 bucket and check if they have proper transcriptions.
|
||||||
If not, requeue them for processing.
|
If not, requeue them for processing.
|
||||||
@@ -287,7 +271,7 @@ async def reprocess_failed_recordings(session: AsyncSession):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
recording = await recordings_controller.get_by_object_key(
|
recording = await recordings_controller.get_by_object_key(
|
||||||
session, bucket_name, object_key
|
bucket_name, object_key
|
||||||
)
|
)
|
||||||
if not recording:
|
if not recording:
|
||||||
logger.info(f"Queueing recording for processing: {object_key}")
|
logger.info(f"Queueing recording for processing: {object_key}")
|
||||||
@@ -298,12 +282,10 @@ async def reprocess_failed_recordings(session: AsyncSession):
|
|||||||
transcript = None
|
transcript = None
|
||||||
try:
|
try:
|
||||||
transcript = await transcripts_controller.get_by_recording_id(
|
transcript = await transcripts_controller.get_by_recording_id(
|
||||||
session, recording.id
|
recording.id
|
||||||
)
|
)
|
||||||
except ValidationError:
|
except ValidationError:
|
||||||
await transcripts_controller.remove_by_recording_id(
|
await transcripts_controller.remove_by_recording_id(recording.id)
|
||||||
session, recording.id
|
|
||||||
)
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Removed invalid transcript for recording: {recording.id}"
|
f"Removed invalid transcript for recording: {recording.id}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,109 +0,0 @@
|
|||||||
"""
|
|
||||||
Session management decorator for async worker tasks.
|
|
||||||
|
|
||||||
This decorator ensures that all worker tasks have a properly managed database session
|
|
||||||
that stays open for the entire duration of the task execution.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import functools
|
|
||||||
from typing import Any, Callable, TypeVar
|
|
||||||
|
|
||||||
from celery import current_task
|
|
||||||
|
|
||||||
from reflector.db import get_session_factory
|
|
||||||
from reflector.db.transcripts import transcripts_controller
|
|
||||||
from reflector.logger import logger
|
|
||||||
|
|
||||||
F = TypeVar("F", bound=Callable[..., Any])
|
|
||||||
|
|
||||||
|
|
||||||
def with_session(func: F) -> F:
|
|
||||||
"""
|
|
||||||
Decorator that provides an AsyncSession as the first argument to the decorated function.
|
|
||||||
|
|
||||||
This should be used AFTER the @asynctask decorator on Celery tasks to ensure
|
|
||||||
proper session management throughout the task execution.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
@shared_task
|
|
||||||
@asynctask
|
|
||||||
@with_session
|
|
||||||
async def my_task(session: AsyncSession, arg1: str, arg2: int):
|
|
||||||
# session is automatically provided and managed
|
|
||||||
result = await some_controller.get_by_id(session, arg1)
|
|
||||||
...
|
|
||||||
"""
|
|
||||||
|
|
||||||
@functools.wraps(func)
|
|
||||||
async def wrapper(*args, **kwargs):
|
|
||||||
session_factory = get_session_factory()
|
|
||||||
async with session_factory() as session:
|
|
||||||
async with session.begin():
|
|
||||||
# Pass session as first argument to the decorated function
|
|
||||||
return await func(session, *args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def with_session_and_transcript(func: F) -> F:
|
|
||||||
"""
|
|
||||||
Decorator that provides both an AsyncSession and a Transcript to the decorated function.
|
|
||||||
|
|
||||||
This decorator:
|
|
||||||
1. Extracts transcript_id from kwargs
|
|
||||||
2. Creates and manages a database session
|
|
||||||
3. Fetches the transcript using the session
|
|
||||||
4. Creates an enhanced logger with Celery task context
|
|
||||||
5. Passes session, transcript, and logger to the decorated function
|
|
||||||
|
|
||||||
This should be used AFTER the @asynctask decorator on Celery tasks.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
@shared_task
|
|
||||||
@asynctask
|
|
||||||
@with_session_and_transcript
|
|
||||||
async def my_task(session: AsyncSession, transcript: Transcript, logger: Logger, arg1: str):
|
|
||||||
# session, transcript, and logger are automatically provided
|
|
||||||
room = await rooms_controller.get_by_id(session, transcript.room_id)
|
|
||||||
...
|
|
||||||
"""
|
|
||||||
|
|
||||||
@functools.wraps(func)
|
|
||||||
async def wrapper(*args, **kwargs):
|
|
||||||
transcript_id = kwargs.pop("transcript_id", None)
|
|
||||||
if not transcript_id:
|
|
||||||
raise ValueError(
|
|
||||||
"transcript_id is required for @with_session_and_transcript"
|
|
||||||
)
|
|
||||||
|
|
||||||
session_factory = get_session_factory()
|
|
||||||
async with session_factory() as session:
|
|
||||||
async with session.begin():
|
|
||||||
# Fetch the transcript
|
|
||||||
transcript = await transcripts_controller.get_by_id(
|
|
||||||
session, transcript_id
|
|
||||||
)
|
|
||||||
if not transcript:
|
|
||||||
raise Exception(f"Transcript {transcript_id} not found")
|
|
||||||
|
|
||||||
# Create enhanced logger with Celery task context
|
|
||||||
tlogger = logger.bind(transcript_id=transcript.id)
|
|
||||||
if current_task:
|
|
||||||
tlogger = tlogger.bind(
|
|
||||||
task_id=current_task.request.id,
|
|
||||||
task_name=current_task.name,
|
|
||||||
worker_hostname=current_task.request.hostname,
|
|
||||||
task_retries=current_task.request.retries,
|
|
||||||
transcript_id=transcript_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Pass session, transcript, and logger to the decorated function
|
|
||||||
return await func(
|
|
||||||
session, transcript=transcript, logger=tlogger, *args, **kwargs
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
tlogger.exception("Error in task execution")
|
|
||||||
raise
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
@@ -10,14 +10,14 @@ import httpx
|
|||||||
import structlog
|
import structlog
|
||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
from celery.utils.log import get_task_logger
|
from celery.utils.log import get_task_logger
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
|
from reflector.db.calendar_events import calendar_events_controller
|
||||||
|
from reflector.db.meetings import meetings_controller
|
||||||
from reflector.db.rooms import rooms_controller
|
from reflector.db.rooms import rooms_controller
|
||||||
from reflector.db.transcripts import transcripts_controller
|
from reflector.db.transcripts import transcripts_controller
|
||||||
from reflector.pipelines.main_live_pipeline import asynctask
|
from reflector.pipelines.main_live_pipeline import asynctask
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.utils.webvtt import topics_to_webvtt
|
from reflector.utils.webvtt import topics_to_webvtt
|
||||||
from reflector.worker.session_decorator import with_session
|
|
||||||
|
|
||||||
logger = structlog.wrap_logger(get_task_logger(__name__))
|
logger = structlog.wrap_logger(get_task_logger(__name__))
|
||||||
|
|
||||||
@@ -41,13 +41,11 @@ def generate_webhook_signature(payload: bytes, secret: str, timestamp: str) -> s
|
|||||||
retry_backoff_max=3600, # Max 1 hour between retries
|
retry_backoff_max=3600, # Max 1 hour between retries
|
||||||
)
|
)
|
||||||
@asynctask
|
@asynctask
|
||||||
@with_session
|
|
||||||
async def send_transcript_webhook(
|
async def send_transcript_webhook(
|
||||||
self,
|
self,
|
||||||
transcript_id: str,
|
transcript_id: str,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
event_id: str,
|
event_id: str,
|
||||||
session: AsyncSession,
|
|
||||||
):
|
):
|
||||||
log = logger.bind(
|
log = logger.bind(
|
||||||
transcript_id=transcript_id,
|
transcript_id=transcript_id,
|
||||||
@@ -57,12 +55,12 @@ async def send_transcript_webhook(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Fetch transcript and room
|
# Fetch transcript and room
|
||||||
transcript = await transcripts_controller.get_by_id(session, transcript_id)
|
transcript = await transcripts_controller.get_by_id(transcript_id)
|
||||||
if not transcript:
|
if not transcript:
|
||||||
log.error("Transcript not found, skipping webhook")
|
log.error("Transcript not found, skipping webhook")
|
||||||
return
|
return
|
||||||
|
|
||||||
room = await rooms_controller.get_by_id(session, room_id)
|
room = await rooms_controller.get_by_id(room_id)
|
||||||
if not room:
|
if not room:
|
||||||
log.error("Room not found, skipping webhook")
|
log.error("Room not found, skipping webhook")
|
||||||
return
|
return
|
||||||
@@ -88,6 +86,18 @@ async def send_transcript_webhook(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Fetch meeting and calendar event if they exist
|
||||||
|
calendar_event = None
|
||||||
|
try:
|
||||||
|
if transcript.meeting_id:
|
||||||
|
meeting = await meetings_controller.get_by_id(transcript.meeting_id)
|
||||||
|
if meeting and meeting.calendar_event_id:
|
||||||
|
calendar_event = await calendar_events_controller.get_by_id(
|
||||||
|
meeting.calendar_event_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error fetching meeting or calendar event", error=str(e))
|
||||||
|
|
||||||
# Build webhook payload
|
# Build webhook payload
|
||||||
frontend_url = f"{settings.UI_BASE_URL}/transcripts/{transcript.id}"
|
frontend_url = f"{settings.UI_BASE_URL}/transcripts/{transcript.id}"
|
||||||
participants = [
|
participants = [
|
||||||
@@ -120,6 +130,33 @@ async def send_transcript_webhook(
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Always include calendar_event field, even if no event is present
|
||||||
|
payload_data["calendar_event"] = {}
|
||||||
|
|
||||||
|
# Add calendar event data if present
|
||||||
|
if calendar_event:
|
||||||
|
calendar_data = {
|
||||||
|
"id": calendar_event.id,
|
||||||
|
"ics_uid": calendar_event.ics_uid,
|
||||||
|
"title": calendar_event.title,
|
||||||
|
"start_time": calendar_event.start_time.isoformat()
|
||||||
|
if calendar_event.start_time
|
||||||
|
else None,
|
||||||
|
"end_time": calendar_event.end_time.isoformat()
|
||||||
|
if calendar_event.end_time
|
||||||
|
else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional fields only if they exist
|
||||||
|
if calendar_event.description:
|
||||||
|
calendar_data["description"] = calendar_event.description
|
||||||
|
if calendar_event.location:
|
||||||
|
calendar_data["location"] = calendar_event.location
|
||||||
|
if calendar_event.attendees:
|
||||||
|
calendar_data["attendees"] = calendar_event.attendees
|
||||||
|
|
||||||
|
payload_data["calendar_event"] = calendar_data
|
||||||
|
|
||||||
# Convert to JSON
|
# Convert to JSON
|
||||||
payload_json = json.dumps(payload_data, separators=(",", ":"))
|
payload_json = json.dumps(payload_data, separators=(",", ":"))
|
||||||
payload_bytes = payload_json.encode("utf-8")
|
payload_bytes = payload_json.encode("utf-8")
|
||||||
|
|||||||
@@ -65,7 +65,12 @@ class WebsocketManager:
|
|||||||
self.tasks: dict = {}
|
self.tasks: dict = {}
|
||||||
self.pubsub_client = pubsub_client
|
self.pubsub_client = pubsub_client
|
||||||
|
|
||||||
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
|
async def add_user_to_room(
|
||||||
|
self, room_id: str, websocket: WebSocket, subprotocol: str | None = None
|
||||||
|
) -> None:
|
||||||
|
if subprotocol:
|
||||||
|
await websocket.accept(subprotocol=subprotocol)
|
||||||
|
else:
|
||||||
await websocket.accept()
|
await websocket.accept()
|
||||||
|
|
||||||
if room_id in self.rooms:
|
if room_id in self.rooms:
|
||||||
|
|||||||
@@ -1,22 +1,11 @@
|
|||||||
import asyncio
|
|
||||||
import os
|
import os
|
||||||
import sys
|
from contextlib import asynccontextmanager
|
||||||
from tempfile import NamedTemporaryFile
|
from tempfile import NamedTemporaryFile
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def event_loop():
|
|
||||||
if sys.platform.startswith("win") and sys.version_info[:2] >= (3, 8):
|
|
||||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|
||||||
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
yield loop
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def settings_configuration():
|
def settings_configuration():
|
||||||
# theses settings are linked to monadical for pytest-recording
|
# theses settings are linked to monadical for pytest-recording
|
||||||
@@ -47,6 +36,7 @@ def docker_compose_file(pytestconfig):
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def postgres_service(docker_ip, docker_services):
|
def postgres_service(docker_ip, docker_services):
|
||||||
|
"""Ensure that PostgreSQL service is up and responsive."""
|
||||||
port = docker_services.port_for("postgres_test", 5432)
|
port = docker_services.port_for("postgres_test", 5432)
|
||||||
|
|
||||||
def is_responsive():
|
def is_responsive():
|
||||||
@@ -67,6 +57,7 @@ def postgres_service(docker_ip, docker_services):
|
|||||||
|
|
||||||
docker_services.wait_until_responsive(timeout=30.0, pause=0.1, check=is_responsive)
|
docker_services.wait_until_responsive(timeout=30.0, pause=0.1, check=is_responsive)
|
||||||
|
|
||||||
|
# Return connection parameters
|
||||||
return {
|
return {
|
||||||
"host": docker_ip,
|
"host": docker_ip,
|
||||||
"port": port,
|
"port": port,
|
||||||
@@ -76,27 +67,20 @@ def postgres_service(docker_ip, docker_services):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
def _database_url(postgres_service):
|
@pytest.mark.asyncio
|
||||||
db_config = postgres_service
|
async def setup_database(postgres_service):
|
||||||
DATABASE_URL = (
|
from reflector.db import engine, metadata, get_database # noqa
|
||||||
f"postgresql+asyncpg://{db_config['user']}:{db_config['password']}"
|
|
||||||
f"@{db_config['host']}:{db_config['port']}/{db_config['dbname']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Override settings
|
metadata.drop_all(bind=engine)
|
||||||
from reflector.settings import settings
|
metadata.create_all(bind=engine)
|
||||||
|
database = get_database()
|
||||||
|
|
||||||
settings.DATABASE_URL = DATABASE_URL
|
try:
|
||||||
|
await database.connect()
|
||||||
return DATABASE_URL
|
yield
|
||||||
|
finally:
|
||||||
|
await database.disconnect()
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def init_database():
|
|
||||||
from reflector.db import Base
|
|
||||||
|
|
||||||
return Base.metadata.create_all
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -344,17 +328,8 @@ def celery_includes():
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
async def ensure_db_session_in_app(db_session):
|
|
||||||
async def mock_get_session():
|
|
||||||
yield db_session
|
|
||||||
|
|
||||||
with patch("reflector.db._get_session", side_effect=mock_get_session):
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def client(db_session):
|
async def client():
|
||||||
from httpx import AsyncClient
|
from httpx import AsyncClient
|
||||||
|
|
||||||
from reflector.app import app
|
from reflector.app import app
|
||||||
@@ -363,6 +338,166 @@ async def client(db_session):
|
|||||||
yield ac
|
yield ac
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
async def ws_manager_in_memory(monkeypatch):
|
||||||
|
"""Replace Redis-based WS manager with an in-memory implementation for tests."""
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
from reflector.ws_manager import WebsocketManager
|
||||||
|
|
||||||
|
class _InMemorySubscriber:
|
||||||
|
def __init__(self, queue: asyncio.Queue):
|
||||||
|
self.queue = queue
|
||||||
|
|
||||||
|
async def get_message(self, ignore_subscribe_messages: bool = True):
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(self.queue.get(), timeout=0.05)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
class InMemoryPubSubManager:
|
||||||
|
def __init__(self):
|
||||||
|
self.queues: dict[str, asyncio.Queue] = {}
|
||||||
|
self.connected = False
|
||||||
|
|
||||||
|
async def connect(self) -> None:
|
||||||
|
self.connected = True
|
||||||
|
|
||||||
|
async def disconnect(self) -> None:
|
||||||
|
self.connected = False
|
||||||
|
|
||||||
|
async def send_json(self, room_id: str, message: dict) -> None:
|
||||||
|
if room_id not in self.queues:
|
||||||
|
self.queues[room_id] = asyncio.Queue()
|
||||||
|
payload = json.dumps(message).encode("utf-8")
|
||||||
|
await self.queues[room_id].put(
|
||||||
|
{"channel": room_id.encode("utf-8"), "data": payload}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def subscribe(self, room_id: str):
|
||||||
|
if room_id not in self.queues:
|
||||||
|
self.queues[room_id] = asyncio.Queue()
|
||||||
|
return _InMemorySubscriber(self.queues[room_id])
|
||||||
|
|
||||||
|
async def unsubscribe(self, room_id: str) -> None:
|
||||||
|
# keep queue for potential later resubscribe within same test
|
||||||
|
pass
|
||||||
|
|
||||||
|
pubsub = InMemoryPubSubManager()
|
||||||
|
ws_manager = WebsocketManager(pubsub_client=pubsub)
|
||||||
|
|
||||||
|
def _get_ws_manager():
|
||||||
|
return ws_manager
|
||||||
|
|
||||||
|
# Patch all places that imported get_ws_manager at import time
|
||||||
|
monkeypatch.setattr("reflector.ws_manager.get_ws_manager", _get_ws_manager)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"reflector.pipelines.main_live_pipeline.get_ws_manager", _get_ws_manager
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"reflector.views.transcripts_websocket.get_ws_manager", _get_ws_manager
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"reflector.views.user_websocket.get_ws_manager", _get_ws_manager
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("reflector.views.transcripts.get_ws_manager", _get_ws_manager)
|
||||||
|
|
||||||
|
# Websocket auth: avoid OAuth2 on websocket dependencies; allow anonymous
|
||||||
|
import reflector.auth as auth
|
||||||
|
|
||||||
|
# Ensure FastAPI uses our override for routes that captured the original callable
|
||||||
|
from reflector.app import app as fastapi_app
|
||||||
|
|
||||||
|
try:
|
||||||
|
fastapi_app.dependency_overrides[auth.current_user_optional] = lambda: None
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Stub Redis cache used by profanity filter to avoid external Redis
|
||||||
|
from reflector import redis_cache as rc
|
||||||
|
|
||||||
|
class _FakeRedis:
|
||||||
|
def __init__(self):
|
||||||
|
self._data = {}
|
||||||
|
|
||||||
|
def get(self, key):
|
||||||
|
value = self._data.get(key)
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if isinstance(value, bytes):
|
||||||
|
return value
|
||||||
|
return str(value).encode("utf-8")
|
||||||
|
|
||||||
|
def setex(self, key, duration, value):
|
||||||
|
# ignore duration for tests
|
||||||
|
if isinstance(value, bytes):
|
||||||
|
self._data[key] = value
|
||||||
|
else:
|
||||||
|
self._data[key] = str(value).encode("utf-8")
|
||||||
|
|
||||||
|
fake_redises: dict[int, _FakeRedis] = {}
|
||||||
|
|
||||||
|
def _get_redis_client(db=0):
|
||||||
|
if db not in fake_redises:
|
||||||
|
fake_redises[db] = _FakeRedis()
|
||||||
|
return fake_redises[db]
|
||||||
|
|
||||||
|
monkeypatch.setattr(rc, "get_redis_client", _get_redis_client)
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def authenticated_client():
|
||||||
|
async with authenticated_client_ctx():
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def authenticated_client2():
|
||||||
|
async with authenticated_client2_ctx():
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def authenticated_client_ctx():
|
||||||
|
from reflector.app import app
|
||||||
|
from reflector.auth import current_user, current_user_optional
|
||||||
|
|
||||||
|
app.dependency_overrides[current_user] = lambda: {
|
||||||
|
"sub": "randomuserid",
|
||||||
|
"email": "test@mail.com",
|
||||||
|
}
|
||||||
|
app.dependency_overrides[current_user_optional] = lambda: {
|
||||||
|
"sub": "randomuserid",
|
||||||
|
"email": "test@mail.com",
|
||||||
|
}
|
||||||
|
yield
|
||||||
|
del app.dependency_overrides[current_user]
|
||||||
|
del app.dependency_overrides[current_user_optional]
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def authenticated_client2_ctx():
|
||||||
|
from reflector.app import app
|
||||||
|
from reflector.auth import current_user, current_user_optional
|
||||||
|
|
||||||
|
app.dependency_overrides[current_user] = lambda: {
|
||||||
|
"sub": "randomuserid2",
|
||||||
|
"email": "test@mail.com",
|
||||||
|
}
|
||||||
|
app.dependency_overrides[current_user_optional] = lambda: {
|
||||||
|
"sub": "randomuserid2",
|
||||||
|
"email": "test@mail.com",
|
||||||
|
}
|
||||||
|
yield
|
||||||
|
del app.dependency_overrides[current_user]
|
||||||
|
del app.dependency_overrides[current_user_optional]
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def fake_mp3_upload():
|
def fake_mp3_upload():
|
||||||
with patch(
|
with patch(
|
||||||
@@ -373,7 +508,7 @@ def fake_mp3_upload():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def fake_transcript_with_topics(tmpdir, client, db_session):
|
async def fake_transcript_with_topics(tmpdir, client):
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -389,10 +524,10 @@ async def fake_transcript_with_topics(tmpdir, client, db_session):
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
tid = response.json()["id"]
|
tid = response.json()["id"]
|
||||||
|
|
||||||
transcript = await transcripts_controller.get_by_id(db_session, tid)
|
transcript = await transcripts_controller.get_by_id(tid)
|
||||||
assert transcript is not None
|
assert transcript is not None
|
||||||
|
|
||||||
await transcripts_controller.update(db_session, transcript, {"status": "ended"})
|
await transcripts_controller.update(transcript, {"status": "ended"})
|
||||||
|
|
||||||
# manually copy a file at the expected location
|
# manually copy a file at the expected location
|
||||||
audio_filename = transcript.audio_mp3_filename
|
audio_filename = transcript.audio_mp3_filename
|
||||||
@@ -402,7 +537,6 @@ async def fake_transcript_with_topics(tmpdir, client, db_session):
|
|||||||
|
|
||||||
# create some topics
|
# create some topics
|
||||||
await transcripts_controller.upsert_topic(
|
await transcripts_controller.upsert_topic(
|
||||||
db_session,
|
|
||||||
transcript,
|
transcript,
|
||||||
TranscriptTopic(
|
TranscriptTopic(
|
||||||
title="Topic 1",
|
title="Topic 1",
|
||||||
@@ -416,7 +550,6 @@ async def fake_transcript_with_topics(tmpdir, client, db_session):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
await transcripts_controller.upsert_topic(
|
await transcripts_controller.upsert_topic(
|
||||||
db_session,
|
|
||||||
transcript,
|
transcript,
|
||||||
TranscriptTopic(
|
TranscriptTopic(
|
||||||
title="Topic 2",
|
title="Topic 2",
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
from unittest.mock import patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@@ -8,7 +8,7 @@ from reflector.services.ics_sync import ICSSyncService
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_attendee_parsing_bug(db_session):
|
async def test_attendee_parsing_bug():
|
||||||
"""
|
"""
|
||||||
Test that reproduces the attendee parsing bug where a string with comma-separated
|
Test that reproduces the attendee parsing bug where a string with comma-separated
|
||||||
emails gets parsed as individual characters instead of separate email addresses.
|
emails gets parsed as individual characters instead of separate email addresses.
|
||||||
@@ -16,8 +16,8 @@ async def test_attendee_parsing_bug(db_session):
|
|||||||
The bug manifests as getting 29 attendees with emails like "M", "A", "I", etc.
|
The bug manifests as getting 29 attendees with emails like "M", "A", "I", etc.
|
||||||
instead of properly parsed email addresses.
|
instead of properly parsed email addresses.
|
||||||
"""
|
"""
|
||||||
|
# Create a test room
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="test-room",
|
name="test-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -31,8 +31,8 @@ async def test_attendee_parsing_bug(db_session):
|
|||||||
ics_url="http://test.com/test.ics",
|
ics_url="http://test.com/test.ics",
|
||||||
ics_enabled=True,
|
ics_enabled=True,
|
||||||
)
|
)
|
||||||
await db_session.flush()
|
|
||||||
|
|
||||||
|
# Read the test ICS file that reproduces the bug and update it with current time
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
test_ics_path = os.path.join(
|
test_ics_path = os.path.join(
|
||||||
@@ -41,26 +41,30 @@ async def test_attendee_parsing_bug(db_session):
|
|||||||
with open(test_ics_path, "r") as f:
|
with open(test_ics_path, "r") as f:
|
||||||
ics_content = f.read()
|
ics_content = f.read()
|
||||||
|
|
||||||
|
# Replace the dates with current time + 1 hour to ensure it's within the 24h window
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
future_time = now + timedelta(hours=1)
|
future_time = now + timedelta(hours=1)
|
||||||
end_time = future_time + timedelta(hours=1)
|
end_time = future_time + timedelta(hours=1)
|
||||||
|
|
||||||
|
# Format dates for ICS format
|
||||||
dtstart = future_time.strftime("%Y%m%dT%H%M%SZ")
|
dtstart = future_time.strftime("%Y%m%dT%H%M%SZ")
|
||||||
dtend = end_time.strftime("%Y%m%dT%H%M%SZ")
|
dtend = end_time.strftime("%Y%m%dT%H%M%SZ")
|
||||||
dtstamp = now.strftime("%Y%m%dT%H%M%SZ")
|
dtstamp = now.strftime("%Y%m%dT%H%M%SZ")
|
||||||
|
|
||||||
|
# Update the ICS content with current dates
|
||||||
ics_content = ics_content.replace("20250910T180000Z", dtstart)
|
ics_content = ics_content.replace("20250910T180000Z", dtstart)
|
||||||
ics_content = ics_content.replace("20250910T190000Z", dtend)
|
ics_content = ics_content.replace("20250910T190000Z", dtend)
|
||||||
ics_content = ics_content.replace("20250910T174000Z", dtstamp)
|
ics_content = ics_content.replace("20250910T174000Z", dtstamp)
|
||||||
|
|
||||||
|
# Create sync service and mock the fetch
|
||||||
sync_service = ICSSyncService()
|
sync_service = ICSSyncService()
|
||||||
from unittest.mock import AsyncMock
|
|
||||||
|
|
||||||
with patch.object(
|
with patch.object(
|
||||||
sync_service.fetch_service, "fetch_ics", new_callable=AsyncMock
|
sync_service.fetch_service, "fetch_ics", new_callable=AsyncMock
|
||||||
) as mock_fetch:
|
) as mock_fetch:
|
||||||
mock_fetch.return_value = ics_content
|
mock_fetch.return_value = ics_content
|
||||||
|
|
||||||
|
# Debug: Parse the ICS content directly to examine attendee parsing
|
||||||
calendar = sync_service.fetch_service.parse_ics(ics_content)
|
calendar = sync_service.fetch_service.parse_ics(ics_content)
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
@@ -76,23 +80,113 @@ async def test_attendee_parsing_bug(db_session):
|
|||||||
print(f"Total events in calendar: {total_events}")
|
print(f"Total events in calendar: {total_events}")
|
||||||
print(f"Events matching room: {len(events)}")
|
print(f"Events matching room: {len(events)}")
|
||||||
|
|
||||||
result = await sync_service.sync_room_calendar(db_session, room)
|
# Perform the sync
|
||||||
|
result = await sync_service.sync_room_calendar(room)
|
||||||
|
|
||||||
|
# Check that the sync succeeded
|
||||||
assert result.get("status") == "success"
|
assert result.get("status") == "success"
|
||||||
assert result.get("events_found", 0) >= 0
|
assert result.get("events_found", 0) >= 0 # Allow for debugging
|
||||||
|
|
||||||
|
# We already have the matching events from the debug code above
|
||||||
assert len(events) == 1
|
assert len(events) == 1
|
||||||
event = events[0]
|
event = events[0]
|
||||||
|
|
||||||
|
# This is where the bug manifests - check the attendees
|
||||||
attendees = event["attendees"]
|
attendees = event["attendees"]
|
||||||
|
|
||||||
print(f"Number of attendees: {len(attendees)}")
|
# Print attendee info for debugging
|
||||||
|
print(f"Number of attendees found: {len(attendees)}")
|
||||||
for i, attendee in enumerate(attendees):
|
for i, attendee in enumerate(attendees):
|
||||||
print(f"Attendee {i}: {attendee}")
|
print(
|
||||||
|
f"Attendee {i}: email='{attendee.get('email')}', name='{attendee.get('name')}'"
|
||||||
|
)
|
||||||
|
|
||||||
assert len(attendees) == 30, f"Expected 30 attendees, got {len(attendees)}"
|
# With the fix, we should now get properly parsed email addresses
|
||||||
|
# Check that no single characters are parsed as emails
|
||||||
|
single_char_emails = [
|
||||||
|
att for att in attendees if att.get("email") and len(att["email"]) == 1
|
||||||
|
]
|
||||||
|
|
||||||
assert attendees[0]["email"] == "alice@example.com"
|
if single_char_emails:
|
||||||
assert attendees[1]["email"] == "bob@example.com"
|
print(
|
||||||
assert attendees[2]["email"] == "charlie@example.com"
|
f"BUG DETECTED: Found {len(single_char_emails)} single-character emails:"
|
||||||
assert any(att["email"] == "organizer@example.com" for att in attendees)
|
)
|
||||||
|
for att in single_char_emails:
|
||||||
|
print(f" - '{att['email']}'")
|
||||||
|
|
||||||
|
# Should have attendees but not single-character emails
|
||||||
|
assert len(attendees) > 0
|
||||||
|
assert (
|
||||||
|
len(single_char_emails) == 0
|
||||||
|
), f"Found {len(single_char_emails)} single-character emails, parsing is still buggy"
|
||||||
|
|
||||||
|
# Check that all emails are valid (contain @ symbol)
|
||||||
|
valid_emails = [
|
||||||
|
att for att in attendees if att.get("email") and "@" in att["email"]
|
||||||
|
]
|
||||||
|
assert len(valid_emails) == len(
|
||||||
|
attendees
|
||||||
|
), "Some attendees don't have valid email addresses"
|
||||||
|
|
||||||
|
# We expect around 29 attendees (28 from the comma-separated list + 1 organizer)
|
||||||
|
assert (
|
||||||
|
len(attendees) >= 25
|
||||||
|
), f"Expected around 29 attendees, got {len(attendees)}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_correct_attendee_parsing():
|
||||||
|
"""
|
||||||
|
Test what correct attendee parsing should look like.
|
||||||
|
"""
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from icalendar import Event
|
||||||
|
|
||||||
|
from reflector.services.ics_sync import ICSFetchService
|
||||||
|
|
||||||
|
service = ICSFetchService()
|
||||||
|
|
||||||
|
# Create a properly formatted event with multiple attendees
|
||||||
|
event = Event()
|
||||||
|
event.add("uid", "test-correct-attendees")
|
||||||
|
event.add("summary", "Test Meeting")
|
||||||
|
event.add("location", "http://test.com/test")
|
||||||
|
event.add("dtstart", datetime.now(timezone.utc))
|
||||||
|
event.add("dtend", datetime.now(timezone.utc))
|
||||||
|
|
||||||
|
# Add attendees the correct way (separate ATTENDEE lines)
|
||||||
|
event.add("attendee", "mailto:alice@example.com", parameters={"CN": "Alice"})
|
||||||
|
event.add("attendee", "mailto:bob@example.com", parameters={"CN": "Bob"})
|
||||||
|
event.add("attendee", "mailto:charlie@example.com", parameters={"CN": "Charlie"})
|
||||||
|
event.add(
|
||||||
|
"organizer", "mailto:organizer@example.com", parameters={"CN": "Organizer"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse the event
|
||||||
|
result = service._parse_event(event)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
attendees = result["attendees"]
|
||||||
|
|
||||||
|
# Should have 4 attendees (3 attendees + 1 organizer)
|
||||||
|
assert len(attendees) == 4
|
||||||
|
|
||||||
|
# Check that all emails are valid email addresses
|
||||||
|
emails = [att["email"] for att in attendees if att.get("email")]
|
||||||
|
expected_emails = [
|
||||||
|
"alice@example.com",
|
||||||
|
"bob@example.com",
|
||||||
|
"charlie@example.com",
|
||||||
|
"organizer@example.com",
|
||||||
|
]
|
||||||
|
|
||||||
|
for email in emails:
|
||||||
|
assert "@" in email, f"Invalid email format: {email}"
|
||||||
|
assert len(email) > 5, f"Email too short: {email}"
|
||||||
|
|
||||||
|
# Check that we have the expected emails
|
||||||
|
assert "alice@example.com" in emails
|
||||||
|
assert "bob@example.com" in emails
|
||||||
|
assert "charlie@example.com" in emails
|
||||||
|
assert "organizer@example.com" in emails
|
||||||
|
|||||||
@@ -11,11 +11,10 @@ from reflector.db.rooms import rooms_controller
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_calendar_event_create(db_session):
|
async def test_calendar_event_create():
|
||||||
"""Test creating a calendar event."""
|
"""Test creating a calendar event."""
|
||||||
# Create a room first
|
# Create a room first
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="test-room",
|
name="test-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -45,7 +44,7 @@ async def test_calendar_event_create(db_session):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Save event
|
# Save event
|
||||||
saved_event = await calendar_events_controller.upsert(db_session, event)
|
saved_event = await calendar_events_controller.upsert(event)
|
||||||
|
|
||||||
assert saved_event.ics_uid == "test-event-123"
|
assert saved_event.ics_uid == "test-event-123"
|
||||||
assert saved_event.title == "Team Meeting"
|
assert saved_event.title == "Team Meeting"
|
||||||
@@ -54,11 +53,10 @@ async def test_calendar_event_create(db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_calendar_event_get_by_room(db_session):
|
async def test_calendar_event_get_by_room():
|
||||||
"""Test getting calendar events for a room."""
|
"""Test getting calendar events for a room."""
|
||||||
# Create room
|
# Create room
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="events-room",
|
name="events-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -82,10 +80,10 @@ async def test_calendar_event_get_by_room(db_session):
|
|||||||
start_time=now + timedelta(hours=i),
|
start_time=now + timedelta(hours=i),
|
||||||
end_time=now + timedelta(hours=i + 1),
|
end_time=now + timedelta(hours=i + 1),
|
||||||
)
|
)
|
||||||
await calendar_events_controller.upsert(db_session, event)
|
await calendar_events_controller.upsert(event)
|
||||||
|
|
||||||
# Get events for room
|
# Get events for room
|
||||||
events = await calendar_events_controller.get_by_room(db_session, room.id)
|
events = await calendar_events_controller.get_by_room(room.id)
|
||||||
|
|
||||||
assert len(events) == 3
|
assert len(events) == 3
|
||||||
assert all(e.room_id == room.id for e in events)
|
assert all(e.room_id == room.id for e in events)
|
||||||
@@ -95,11 +93,10 @@ async def test_calendar_event_get_by_room(db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_calendar_event_get_upcoming(db_session):
|
async def test_calendar_event_get_upcoming():
|
||||||
"""Test getting upcoming events within time window."""
|
"""Test getting upcoming events within time window."""
|
||||||
# Create room
|
# Create room
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="upcoming-room",
|
name="upcoming-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -123,7 +120,7 @@ async def test_calendar_event_get_upcoming(db_session):
|
|||||||
start_time=now - timedelta(hours=2),
|
start_time=now - timedelta(hours=2),
|
||||||
end_time=now - timedelta(hours=1),
|
end_time=now - timedelta(hours=1),
|
||||||
)
|
)
|
||||||
await calendar_events_controller.upsert(db_session, past_event)
|
await calendar_events_controller.upsert(past_event)
|
||||||
|
|
||||||
# Upcoming event within 30 minutes
|
# Upcoming event within 30 minutes
|
||||||
upcoming_event = CalendarEvent(
|
upcoming_event = CalendarEvent(
|
||||||
@@ -133,7 +130,7 @@ async def test_calendar_event_get_upcoming(db_session):
|
|||||||
start_time=now + timedelta(minutes=15),
|
start_time=now + timedelta(minutes=15),
|
||||||
end_time=now + timedelta(minutes=45),
|
end_time=now + timedelta(minutes=45),
|
||||||
)
|
)
|
||||||
await calendar_events_controller.upsert(db_session, upcoming_event)
|
await calendar_events_controller.upsert(upcoming_event)
|
||||||
|
|
||||||
# Currently happening event (started 10 minutes ago, ends in 20 minutes)
|
# Currently happening event (started 10 minutes ago, ends in 20 minutes)
|
||||||
current_event = CalendarEvent(
|
current_event = CalendarEvent(
|
||||||
@@ -143,7 +140,7 @@ async def test_calendar_event_get_upcoming(db_session):
|
|||||||
start_time=now - timedelta(minutes=10),
|
start_time=now - timedelta(minutes=10),
|
||||||
end_time=now + timedelta(minutes=20),
|
end_time=now + timedelta(minutes=20),
|
||||||
)
|
)
|
||||||
await calendar_events_controller.upsert(db_session, current_event)
|
await calendar_events_controller.upsert(current_event)
|
||||||
|
|
||||||
# Future event beyond 30 minutes
|
# Future event beyond 30 minutes
|
||||||
future_event = CalendarEvent(
|
future_event = CalendarEvent(
|
||||||
@@ -153,10 +150,10 @@ async def test_calendar_event_get_upcoming(db_session):
|
|||||||
start_time=now + timedelta(hours=2),
|
start_time=now + timedelta(hours=2),
|
||||||
end_time=now + timedelta(hours=3),
|
end_time=now + timedelta(hours=3),
|
||||||
)
|
)
|
||||||
await calendar_events_controller.upsert(db_session, future_event)
|
await calendar_events_controller.upsert(future_event)
|
||||||
|
|
||||||
# Get upcoming events (default 120 minutes) - should include current, upcoming, and future
|
# Get upcoming events (default 120 minutes) - should include current, upcoming, and future
|
||||||
upcoming = await calendar_events_controller.get_upcoming(db_session, room.id)
|
upcoming = await calendar_events_controller.get_upcoming(room.id)
|
||||||
|
|
||||||
assert len(upcoming) == 3
|
assert len(upcoming) == 3
|
||||||
# Events should be sorted by start_time (current event first, then upcoming, then future)
|
# Events should be sorted by start_time (current event first, then upcoming, then future)
|
||||||
@@ -166,7 +163,7 @@ async def test_calendar_event_get_upcoming(db_session):
|
|||||||
|
|
||||||
# Get upcoming with custom window
|
# Get upcoming with custom window
|
||||||
upcoming_extended = await calendar_events_controller.get_upcoming(
|
upcoming_extended = await calendar_events_controller.get_upcoming(
|
||||||
db_session, room.id, minutes_ahead=180
|
room.id, minutes_ahead=180
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(upcoming_extended) == 3
|
assert len(upcoming_extended) == 3
|
||||||
@@ -177,11 +174,10 @@ async def test_calendar_event_get_upcoming(db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_calendar_event_get_upcoming_includes_currently_happening(db_session):
|
async def test_calendar_event_get_upcoming_includes_currently_happening():
|
||||||
"""Test that get_upcoming includes currently happening events but excludes ended events."""
|
"""Test that get_upcoming includes currently happening events but excludes ended events."""
|
||||||
# Create room
|
# Create room
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="current-happening-room",
|
name="current-happening-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -204,7 +200,7 @@ async def test_calendar_event_get_upcoming_includes_currently_happening(db_sessi
|
|||||||
start_time=now - timedelta(hours=2),
|
start_time=now - timedelta(hours=2),
|
||||||
end_time=now - timedelta(minutes=30),
|
end_time=now - timedelta(minutes=30),
|
||||||
)
|
)
|
||||||
await calendar_events_controller.upsert(db_session, past_ended_event)
|
await calendar_events_controller.upsert(past_ended_event)
|
||||||
|
|
||||||
# Event currently happening (started 10 minutes ago, ends in 20 minutes) - SHOULD be included
|
# Event currently happening (started 10 minutes ago, ends in 20 minutes) - SHOULD be included
|
||||||
currently_happening_event = CalendarEvent(
|
currently_happening_event = CalendarEvent(
|
||||||
@@ -214,7 +210,7 @@ async def test_calendar_event_get_upcoming_includes_currently_happening(db_sessi
|
|||||||
start_time=now - timedelta(minutes=10),
|
start_time=now - timedelta(minutes=10),
|
||||||
end_time=now + timedelta(minutes=20),
|
end_time=now + timedelta(minutes=20),
|
||||||
)
|
)
|
||||||
await calendar_events_controller.upsert(db_session, currently_happening_event)
|
await calendar_events_controller.upsert(currently_happening_event)
|
||||||
|
|
||||||
# Event starting soon (in 5 minutes) - SHOULD be included
|
# Event starting soon (in 5 minutes) - SHOULD be included
|
||||||
upcoming_soon_event = CalendarEvent(
|
upcoming_soon_event = CalendarEvent(
|
||||||
@@ -224,12 +220,10 @@ async def test_calendar_event_get_upcoming_includes_currently_happening(db_sessi
|
|||||||
start_time=now + timedelta(minutes=5),
|
start_time=now + timedelta(minutes=5),
|
||||||
end_time=now + timedelta(minutes=35),
|
end_time=now + timedelta(minutes=35),
|
||||||
)
|
)
|
||||||
await calendar_events_controller.upsert(db_session, upcoming_soon_event)
|
await calendar_events_controller.upsert(upcoming_soon_event)
|
||||||
|
|
||||||
# Get upcoming events
|
# Get upcoming events
|
||||||
upcoming = await calendar_events_controller.get_upcoming(
|
upcoming = await calendar_events_controller.get_upcoming(room.id, minutes_ahead=30)
|
||||||
db_session, room.id, minutes_ahead=30
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should only include currently happening and upcoming soon events
|
# Should only include currently happening and upcoming soon events
|
||||||
assert len(upcoming) == 2
|
assert len(upcoming) == 2
|
||||||
@@ -238,11 +232,10 @@ async def test_calendar_event_get_upcoming_includes_currently_happening(db_sessi
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_calendar_event_upsert(db_session):
|
async def test_calendar_event_upsert():
|
||||||
"""Test upserting (create/update) calendar events."""
|
"""Test upserting (create/update) calendar events."""
|
||||||
# Create room
|
# Create room
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="upsert-room",
|
name="upsert-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -266,30 +259,29 @@ async def test_calendar_event_upsert(db_session):
|
|||||||
end_time=now + timedelta(hours=1),
|
end_time=now + timedelta(hours=1),
|
||||||
)
|
)
|
||||||
|
|
||||||
created = await calendar_events_controller.upsert(db_session, event)
|
created = await calendar_events_controller.upsert(event)
|
||||||
assert created.title == "Original Title"
|
assert created.title == "Original Title"
|
||||||
|
|
||||||
# Update existing event
|
# Update existing event
|
||||||
event.title = "Updated Title"
|
event.title = "Updated Title"
|
||||||
event.description = "Added description"
|
event.description = "Added description"
|
||||||
|
|
||||||
updated = await calendar_events_controller.upsert(db_session, event)
|
updated = await calendar_events_controller.upsert(event)
|
||||||
assert updated.title == "Updated Title"
|
assert updated.title == "Updated Title"
|
||||||
assert updated.description == "Added description"
|
assert updated.description == "Added description"
|
||||||
assert updated.ics_uid == "upsert-test"
|
assert updated.ics_uid == "upsert-test"
|
||||||
|
|
||||||
# Verify only one event exists
|
# Verify only one event exists
|
||||||
events = await calendar_events_controller.get_by_room(db_session, room.id)
|
events = await calendar_events_controller.get_by_room(room.id)
|
||||||
assert len(events) == 1
|
assert len(events) == 1
|
||||||
assert events[0].title == "Updated Title"
|
assert events[0].title == "Updated Title"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_calendar_event_soft_delete(db_session):
|
async def test_calendar_event_soft_delete():
|
||||||
"""Test soft deleting events no longer in calendar."""
|
"""Test soft deleting events no longer in calendar."""
|
||||||
# Create room
|
# Create room
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="delete-room",
|
name="delete-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -313,36 +305,35 @@ async def test_calendar_event_soft_delete(db_session):
|
|||||||
start_time=now + timedelta(hours=i),
|
start_time=now + timedelta(hours=i),
|
||||||
end_time=now + timedelta(hours=i + 1),
|
end_time=now + timedelta(hours=i + 1),
|
||||||
)
|
)
|
||||||
await calendar_events_controller.upsert(db_session, event)
|
await calendar_events_controller.upsert(event)
|
||||||
|
|
||||||
# Soft delete events not in current list
|
# Soft delete events not in current list
|
||||||
current_ids = ["event-0", "event-2"] # Keep events 0 and 2
|
current_ids = ["event-0", "event-2"] # Keep events 0 and 2
|
||||||
deleted_count = await calendar_events_controller.soft_delete_missing(
|
deleted_count = await calendar_events_controller.soft_delete_missing(
|
||||||
db_session, room.id, current_ids
|
room.id, current_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
assert deleted_count == 2 # Should delete events 1 and 3
|
assert deleted_count == 2 # Should delete events 1 and 3
|
||||||
|
|
||||||
# Get non-deleted events
|
# Get non-deleted events
|
||||||
events = await calendar_events_controller.get_by_room(
|
events = await calendar_events_controller.get_by_room(
|
||||||
db_session, room.id, include_deleted=False
|
room.id, include_deleted=False
|
||||||
)
|
)
|
||||||
assert len(events) == 2
|
assert len(events) == 2
|
||||||
assert {e.ics_uid for e in events} == {"event-0", "event-2"}
|
assert {e.ics_uid for e in events} == {"event-0", "event-2"}
|
||||||
|
|
||||||
# Get all events including deleted
|
# Get all events including deleted
|
||||||
all_events = await calendar_events_controller.get_by_room(
|
all_events = await calendar_events_controller.get_by_room(
|
||||||
db_session, room.id, include_deleted=True
|
room.id, include_deleted=True
|
||||||
)
|
)
|
||||||
assert len(all_events) == 4
|
assert len(all_events) == 4
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_calendar_event_past_events_not_deleted(db_session):
|
async def test_calendar_event_past_events_not_deleted():
|
||||||
"""Test that past events are not soft deleted."""
|
"""Test that past events are not soft deleted."""
|
||||||
# Create room
|
# Create room
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="past-events-room",
|
name="past-events-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -365,7 +356,7 @@ async def test_calendar_event_past_events_not_deleted(db_session):
|
|||||||
start_time=now - timedelta(hours=2),
|
start_time=now - timedelta(hours=2),
|
||||||
end_time=now - timedelta(hours=1),
|
end_time=now - timedelta(hours=1),
|
||||||
)
|
)
|
||||||
await calendar_events_controller.upsert(db_session, past_event)
|
await calendar_events_controller.upsert(past_event)
|
||||||
|
|
||||||
# Create future event
|
# Create future event
|
||||||
future_event = CalendarEvent(
|
future_event = CalendarEvent(
|
||||||
@@ -375,29 +366,26 @@ async def test_calendar_event_past_events_not_deleted(db_session):
|
|||||||
start_time=now + timedelta(hours=1),
|
start_time=now + timedelta(hours=1),
|
||||||
end_time=now + timedelta(hours=2),
|
end_time=now + timedelta(hours=2),
|
||||||
)
|
)
|
||||||
await calendar_events_controller.upsert(db_session, future_event)
|
await calendar_events_controller.upsert(future_event)
|
||||||
|
|
||||||
# Try to soft delete all events (only future should be deleted)
|
# Try to soft delete all events (only future should be deleted)
|
||||||
deleted_count = await calendar_events_controller.soft_delete_missing(
|
deleted_count = await calendar_events_controller.soft_delete_missing(room.id, [])
|
||||||
db_session, room.id, []
|
|
||||||
)
|
|
||||||
|
|
||||||
assert deleted_count == 1 # Only future event deleted
|
assert deleted_count == 1 # Only future event deleted
|
||||||
|
|
||||||
# Verify past event still exists
|
# Verify past event still exists
|
||||||
events = await calendar_events_controller.get_by_room(
|
events = await calendar_events_controller.get_by_room(
|
||||||
db_session, room.id, include_deleted=False
|
room.id, include_deleted=False
|
||||||
)
|
)
|
||||||
assert len(events) == 1
|
assert len(events) == 1
|
||||||
assert events[0].ics_uid == "past-event"
|
assert events[0].ics_uid == "past-event"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_calendar_event_with_raw_ics_data(db_session):
|
async def test_calendar_event_with_raw_ics_data():
|
||||||
"""Test storing raw ICS data with calendar event."""
|
"""Test storing raw ICS data with calendar event."""
|
||||||
# Create room
|
# Create room
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="raw-ics-room",
|
name="raw-ics-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -426,13 +414,11 @@ END:VEVENT"""
|
|||||||
ics_raw_data=raw_ics,
|
ics_raw_data=raw_ics,
|
||||||
)
|
)
|
||||||
|
|
||||||
saved = await calendar_events_controller.upsert(db_session, event)
|
saved = await calendar_events_controller.upsert(event)
|
||||||
|
|
||||||
assert saved.ics_raw_data == raw_ics
|
assert saved.ics_raw_data == raw_ics
|
||||||
|
|
||||||
# Retrieve and verify
|
# Retrieve and verify
|
||||||
retrieved = await calendar_events_controller.get_by_ics_uid(
|
retrieved = await calendar_events_controller.get_by_ics_uid(room.id, "test-raw-123")
|
||||||
db_session, room.id, "test-raw-123"
|
|
||||||
)
|
|
||||||
assert retrieved is not None
|
assert retrieved is not None
|
||||||
assert retrieved.ics_raw_data == raw_ics
|
assert retrieved.ics_raw_data == raw_ics
|
||||||
|
|||||||
@@ -2,32 +2,26 @@ from datetime import datetime, timedelta, timezone
|
|||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import delete, insert, select, update
|
|
||||||
|
|
||||||
from reflector.db.base import (
|
from reflector.db.recordings import Recording, recordings_controller
|
||||||
MeetingConsentModel,
|
|
||||||
MeetingModel,
|
|
||||||
RecordingModel,
|
|
||||||
TranscriptModel,
|
|
||||||
)
|
|
||||||
from reflector.db.transcripts import SourceKind, transcripts_controller
|
from reflector.db.transcripts import SourceKind, transcripts_controller
|
||||||
from reflector.worker.cleanup import cleanup_old_public_data
|
from reflector.worker.cleanup import cleanup_old_public_data
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cleanup_old_public_data_skips_when_not_public(db_session):
|
async def test_cleanup_old_public_data_skips_when_not_public():
|
||||||
"""Test that cleanup is skipped when PUBLIC_MODE is False."""
|
"""Test that cleanup is skipped when PUBLIC_MODE is False."""
|
||||||
with patch("reflector.worker.cleanup.settings") as mock_settings:
|
with patch("reflector.worker.cleanup.settings") as mock_settings:
|
||||||
mock_settings.PUBLIC_MODE = False
|
mock_settings.PUBLIC_MODE = False
|
||||||
|
|
||||||
result = await cleanup_old_public_data(db_session)
|
result = await cleanup_old_public_data()
|
||||||
|
|
||||||
# Should return early without doing anything
|
# Should return early without doing anything
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(db_session):
|
async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts():
|
||||||
"""Test that old anonymous transcripts are deleted."""
|
"""Test that old anonymous transcripts are deleted."""
|
||||||
# Create old and new anonymous transcripts
|
# Create old and new anonymous transcripts
|
||||||
old_date = datetime.now(timezone.utc) - timedelta(days=8)
|
old_date = datetime.now(timezone.utc) - timedelta(days=8)
|
||||||
@@ -35,23 +29,22 @@ async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(db_sess
|
|||||||
|
|
||||||
# Create old anonymous transcript (should be deleted)
|
# Create old anonymous transcript (should be deleted)
|
||||||
old_transcript = await transcripts_controller.add(
|
old_transcript = await transcripts_controller.add(
|
||||||
db_session,
|
|
||||||
name="Old Anonymous Transcript",
|
name="Old Anonymous Transcript",
|
||||||
source_kind=SourceKind.FILE,
|
source_kind=SourceKind.FILE,
|
||||||
user_id=None, # Anonymous
|
user_id=None, # Anonymous
|
||||||
)
|
)
|
||||||
|
|
||||||
# Manually update created_at to be old
|
# Manually update created_at to be old
|
||||||
await db_session.execute(
|
from reflector.db import get_database
|
||||||
update(TranscriptModel)
|
from reflector.db.transcripts import transcripts
|
||||||
.where(TranscriptModel.id == old_transcript.id)
|
|
||||||
|
await get_database().execute(
|
||||||
|
transcripts.update()
|
||||||
|
.where(transcripts.c.id == old_transcript.id)
|
||||||
.values(created_at=old_date)
|
.values(created_at=old_date)
|
||||||
)
|
)
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Create new anonymous transcript (should NOT be deleted)
|
# Create new anonymous transcript (should NOT be deleted)
|
||||||
new_transcript = await transcripts_controller.add(
|
new_transcript = await transcripts_controller.add(
|
||||||
db_session,
|
|
||||||
name="New Anonymous Transcript",
|
name="New Anonymous Transcript",
|
||||||
source_kind=SourceKind.FILE,
|
source_kind=SourceKind.FILE,
|
||||||
user_id=None, # Anonymous
|
user_id=None, # Anonymous
|
||||||
@@ -59,265 +52,234 @@ async def test_cleanup_old_public_data_deletes_old_anonymous_transcripts(db_sess
|
|||||||
|
|
||||||
# Create old transcript with user (should NOT be deleted)
|
# Create old transcript with user (should NOT be deleted)
|
||||||
old_user_transcript = await transcripts_controller.add(
|
old_user_transcript = await transcripts_controller.add(
|
||||||
db_session,
|
|
||||||
name="Old User Transcript",
|
name="Old User Transcript",
|
||||||
source_kind=SourceKind.FILE,
|
source_kind=SourceKind.FILE,
|
||||||
user_id="user-123",
|
user_id="user123",
|
||||||
)
|
)
|
||||||
await db_session.execute(
|
await get_database().execute(
|
||||||
update(TranscriptModel)
|
transcripts.update()
|
||||||
.where(TranscriptModel.id == old_user_transcript.id)
|
.where(transcripts.c.id == old_user_transcript.id)
|
||||||
.values(created_at=old_date)
|
.values(created_at=old_date)
|
||||||
)
|
)
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Mock settings for public mode
|
|
||||||
with patch("reflector.worker.cleanup.settings") as mock_settings:
|
with patch("reflector.worker.cleanup.settings") as mock_settings:
|
||||||
mock_settings.PUBLIC_MODE = True
|
mock_settings.PUBLIC_MODE = True
|
||||||
mock_settings.PUBLIC_DATA_RETENTION_DAYS = 7
|
mock_settings.PUBLIC_DATA_RETENTION_DAYS = 7
|
||||||
|
|
||||||
# Mock delete_single_transcript to track what gets deleted
|
# Mock the storage deletion
|
||||||
with patch("reflector.worker.cleanup.delete_single_transcript") as mock_delete:
|
with patch("reflector.db.transcripts.get_transcripts_storage") as mock_storage:
|
||||||
mock_delete.return_value = None
|
mock_storage.return_value.delete_file = AsyncMock()
|
||||||
|
|
||||||
# Run cleanup with test session
|
result = await cleanup_old_public_data()
|
||||||
await cleanup_old_public_data(db_session)
|
|
||||||
|
|
||||||
# Verify only old anonymous transcript was deleted
|
# Check results
|
||||||
assert mock_delete.call_count == 1
|
assert result["transcripts_deleted"] == 1
|
||||||
# The function is called with session_factory, transcript_data dict, and stats dict
|
assert result["errors"] == []
|
||||||
call_args = mock_delete.call_args[0]
|
|
||||||
transcript_data = call_args[1]
|
# Verify old anonymous transcript was deleted
|
||||||
assert transcript_data["id"] == old_transcript.id
|
assert await transcripts_controller.get_by_id(old_transcript.id) is None
|
||||||
|
|
||||||
|
# Verify new anonymous transcript still exists
|
||||||
|
assert await transcripts_controller.get_by_id(new_transcript.id) is not None
|
||||||
|
|
||||||
|
# Verify user transcript still exists
|
||||||
|
assert await transcripts_controller.get_by_id(old_user_transcript.id) is not None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cleanup_deletes_associated_meeting_and_recording(db_session):
|
async def test_cleanup_deletes_associated_meeting_and_recording():
|
||||||
"""Test that cleanup deletes associated meetings and recordings."""
|
"""Test that meetings and recordings associated with old transcripts are deleted."""
|
||||||
|
from reflector.db import get_database
|
||||||
|
from reflector.db.meetings import meetings
|
||||||
|
from reflector.db.transcripts import transcripts
|
||||||
|
|
||||||
old_date = datetime.now(timezone.utc) - timedelta(days=8)
|
old_date = datetime.now(timezone.utc) - timedelta(days=8)
|
||||||
|
|
||||||
|
# Create a meeting
|
||||||
|
meeting_id = "test-meeting-for-transcript"
|
||||||
|
await get_database().execute(
|
||||||
|
meetings.insert().values(
|
||||||
|
id=meeting_id,
|
||||||
|
room_name="Meeting with Transcript",
|
||||||
|
room_url="https://example.com/meeting",
|
||||||
|
host_room_url="https://example.com/meeting-host",
|
||||||
|
start_date=old_date,
|
||||||
|
end_date=old_date + timedelta(hours=1),
|
||||||
|
room_id=None,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a recording
|
||||||
|
recording = await recordings_controller.create(
|
||||||
|
Recording(
|
||||||
|
bucket_name="test-bucket",
|
||||||
|
object_key="test-recording.mp4",
|
||||||
|
recorded_at=old_date,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Create an old transcript with both meeting and recording
|
# Create an old transcript with both meeting and recording
|
||||||
old_transcript = await transcripts_controller.add(
|
old_transcript = await transcripts_controller.add(
|
||||||
db_session,
|
|
||||||
name="Old Transcript with Meeting and Recording",
|
name="Old Transcript with Meeting and Recording",
|
||||||
source_kind=SourceKind.FILE,
|
source_kind=SourceKind.ROOM,
|
||||||
user_id=None,
|
user_id=None,
|
||||||
|
meeting_id=meeting_id,
|
||||||
|
recording_id=recording.id,
|
||||||
)
|
)
|
||||||
await db_session.execute(
|
|
||||||
update(TranscriptModel)
|
# Update created_at to be old
|
||||||
.where(TranscriptModel.id == old_transcript.id)
|
await get_database().execute(
|
||||||
|
transcripts.update()
|
||||||
|
.where(transcripts.c.id == old_transcript.id)
|
||||||
.values(created_at=old_date)
|
.values(created_at=old_date)
|
||||||
)
|
)
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Create associated meeting directly
|
|
||||||
meeting_id = "test-meeting-id"
|
|
||||||
await db_session.execute(
|
|
||||||
insert(MeetingModel).values(
|
|
||||||
id=meeting_id,
|
|
||||||
room_id=None,
|
|
||||||
room_name="test-room",
|
|
||||||
room_url="https://example.com/room",
|
|
||||||
host_room_url="https://example.com/room-host",
|
|
||||||
start_date=old_date,
|
|
||||||
end_date=old_date + timedelta(hours=1),
|
|
||||||
is_active=False,
|
|
||||||
num_clients=0,
|
|
||||||
is_locked=False,
|
|
||||||
room_mode="normal",
|
|
||||||
recording_type="cloud",
|
|
||||||
recording_trigger="automatic",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create associated recording directly
|
|
||||||
recording_id = "test-recording-id"
|
|
||||||
await db_session.execute(
|
|
||||||
insert(RecordingModel).values(
|
|
||||||
id=recording_id,
|
|
||||||
meeting_id=meeting_id,
|
|
||||||
url="https://example.com/recording.mp4",
|
|
||||||
object_key="recordings/test.mp4",
|
|
||||||
duration=3600.0,
|
|
||||||
created_at=old_date,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Update transcript with meeting_id and recording_id
|
|
||||||
await db_session.execute(
|
|
||||||
update(TranscriptModel)
|
|
||||||
.where(TranscriptModel.id == old_transcript.id)
|
|
||||||
.values(meeting_id=meeting_id, recording_id=recording_id)
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Mock settings
|
|
||||||
with patch("reflector.worker.cleanup.settings") as mock_settings:
|
with patch("reflector.worker.cleanup.settings") as mock_settings:
|
||||||
mock_settings.PUBLIC_MODE = True
|
mock_settings.PUBLIC_MODE = True
|
||||||
mock_settings.PUBLIC_DATA_RETENTION_DAYS = 7
|
mock_settings.PUBLIC_DATA_RETENTION_DAYS = 7
|
||||||
|
|
||||||
# Mock storage deletion
|
# Mock storage deletion
|
||||||
with patch("reflector.worker.cleanup.get_recordings_storage") as mock_storage:
|
with patch("reflector.db.transcripts.get_transcripts_storage") as mock_storage:
|
||||||
mock_storage.return_value.delete_file = AsyncMock()
|
mock_storage.return_value.delete_file = AsyncMock()
|
||||||
|
with patch(
|
||||||
|
"reflector.worker.cleanup.get_recordings_storage"
|
||||||
|
) as mock_rec_storage:
|
||||||
|
mock_rec_storage.return_value.delete_file = AsyncMock()
|
||||||
|
|
||||||
# Run cleanup with test session
|
result = await cleanup_old_public_data()
|
||||||
await cleanup_old_public_data(db_session)
|
|
||||||
|
# Check results
|
||||||
|
assert result["transcripts_deleted"] == 1
|
||||||
|
assert result["meetings_deleted"] == 1
|
||||||
|
assert result["recordings_deleted"] == 1
|
||||||
|
assert result["errors"] == []
|
||||||
|
|
||||||
# Verify transcript was deleted
|
# Verify transcript was deleted
|
||||||
result = await db_session.execute(
|
assert await transcripts_controller.get_by_id(old_transcript.id) is None
|
||||||
select(TranscriptModel).where(TranscriptModel.id == old_transcript.id)
|
|
||||||
)
|
|
||||||
transcript = result.scalar_one_or_none()
|
|
||||||
assert transcript is None
|
|
||||||
|
|
||||||
# Verify meeting was deleted
|
# Verify meeting was deleted
|
||||||
result = await db_session.execute(
|
query = meetings.select().where(meetings.c.id == meeting_id)
|
||||||
select(MeetingModel).where(MeetingModel.id == meeting_id)
|
meeting_result = await get_database().fetch_one(query)
|
||||||
)
|
assert meeting_result is None
|
||||||
meeting = result.scalar_one_or_none()
|
|
||||||
assert meeting is None
|
|
||||||
|
|
||||||
# Verify recording was deleted
|
# Verify recording was deleted
|
||||||
result = await db_session.execute(
|
assert await recordings_controller.get_by_id(recording.id) is None
|
||||||
select(RecordingModel).where(RecordingModel.id == recording_id)
|
|
||||||
)
|
|
||||||
recording = result.scalar_one_or_none()
|
|
||||||
assert recording is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_cleanup_handles_errors_gracefully(db_session):
|
async def test_cleanup_handles_errors_gracefully():
|
||||||
"""Test that cleanup continues even if individual deletions fail."""
|
"""Test that cleanup continues even when individual deletions fail."""
|
||||||
old_date = datetime.now(timezone.utc) - timedelta(days=8)
|
old_date = datetime.now(timezone.utc) - timedelta(days=8)
|
||||||
|
|
||||||
# Create multiple old transcripts
|
# Create multiple old transcripts
|
||||||
transcript1 = await transcripts_controller.add(
|
transcript1 = await transcripts_controller.add(
|
||||||
db_session,
|
|
||||||
name="Transcript 1",
|
name="Transcript 1",
|
||||||
source_kind=SourceKind.FILE,
|
source_kind=SourceKind.FILE,
|
||||||
user_id=None,
|
user_id=None,
|
||||||
)
|
)
|
||||||
await db_session.execute(
|
|
||||||
update(TranscriptModel)
|
|
||||||
.where(TranscriptModel.id == transcript1.id)
|
|
||||||
.values(created_at=old_date)
|
|
||||||
)
|
|
||||||
|
|
||||||
transcript2 = await transcripts_controller.add(
|
transcript2 = await transcripts_controller.add(
|
||||||
db_session,
|
|
||||||
name="Transcript 2",
|
name="Transcript 2",
|
||||||
source_kind=SourceKind.FILE,
|
source_kind=SourceKind.FILE,
|
||||||
user_id=None,
|
user_id=None,
|
||||||
)
|
)
|
||||||
await db_session.execute(
|
|
||||||
update(TranscriptModel)
|
# Update created_at to be old
|
||||||
.where(TranscriptModel.id == transcript2.id)
|
from reflector.db import get_database
|
||||||
|
from reflector.db.transcripts import transcripts
|
||||||
|
|
||||||
|
for t_id in [transcript1.id, transcript2.id]:
|
||||||
|
await get_database().execute(
|
||||||
|
transcripts.update()
|
||||||
|
.where(transcripts.c.id == t_id)
|
||||||
.values(created_at=old_date)
|
.values(created_at=old_date)
|
||||||
)
|
)
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
with patch("reflector.worker.cleanup.settings") as mock_settings:
|
with patch("reflector.worker.cleanup.settings") as mock_settings:
|
||||||
mock_settings.PUBLIC_MODE = True
|
mock_settings.PUBLIC_MODE = True
|
||||||
mock_settings.PUBLIC_DATA_RETENTION_DAYS = 7
|
mock_settings.PUBLIC_DATA_RETENTION_DAYS = 7
|
||||||
|
|
||||||
# Mock delete_single_transcript to fail on first call but succeed on second
|
# Mock remove_by_id to fail for the first transcript
|
||||||
with patch("reflector.worker.cleanup.delete_single_transcript") as mock_delete:
|
original_remove = transcripts_controller.remove_by_id
|
||||||
mock_delete.side_effect = [Exception("Delete failed"), None]
|
call_count = 0
|
||||||
|
|
||||||
# Run cleanup with test session - should not raise exception
|
async def mock_remove_by_id(transcript_id, user_id=None):
|
||||||
await cleanup_old_public_data(db_session)
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
raise Exception("Simulated deletion error")
|
||||||
|
return await original_remove(transcript_id, user_id)
|
||||||
|
|
||||||
# Both transcripts should have been attempted to delete
|
with patch.object(
|
||||||
assert mock_delete.call_count == 2
|
transcripts_controller, "remove_by_id", side_effect=mock_remove_by_id
|
||||||
|
):
|
||||||
|
result = await cleanup_old_public_data()
|
||||||
|
|
||||||
|
# Should have one successful deletion and one error
|
||||||
|
assert result["transcripts_deleted"] == 1
|
||||||
|
assert len(result["errors"]) == 1
|
||||||
|
assert "Failed to delete transcript" in result["errors"][0]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_meeting_consent_cascade_delete(db_session):
|
async def test_meeting_consent_cascade_delete():
|
||||||
"""Test that meeting_consent entries are cascade deleted with meetings."""
|
"""Test that meeting_consent records are automatically deleted when meeting is deleted."""
|
||||||
old_date = datetime.now(timezone.utc) - timedelta(days=8)
|
from reflector.db import get_database
|
||||||
|
from reflector.db.meetings import (
|
||||||
# Create an old transcript
|
meeting_consent,
|
||||||
transcript = await transcripts_controller.add(
|
meeting_consent_controller,
|
||||||
db_session,
|
meetings,
|
||||||
name="Transcript with Meeting",
|
|
||||||
source_kind=SourceKind.FILE,
|
|
||||||
user_id=None,
|
|
||||||
)
|
)
|
||||||
await db_session.execute(
|
|
||||||
update(TranscriptModel)
|
|
||||||
.where(TranscriptModel.id == transcript.id)
|
|
||||||
.values(created_at=old_date)
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Create a meeting directly
|
# Create a meeting
|
||||||
meeting_id = "test-meeting-consent"
|
meeting_id = "test-cascade-meeting"
|
||||||
await db_session.execute(
|
await get_database().execute(
|
||||||
insert(MeetingModel).values(
|
meetings.insert().values(
|
||||||
id=meeting_id,
|
id=meeting_id,
|
||||||
|
room_name="Test Meeting for CASCADE",
|
||||||
|
room_url="https://example.com/cascade-test",
|
||||||
|
host_room_url="https://example.com/cascade-test-host",
|
||||||
|
start_date=datetime.now(timezone.utc),
|
||||||
|
end_date=datetime.now(timezone.utc) + timedelta(hours=1),
|
||||||
room_id=None,
|
room_id=None,
|
||||||
room_name="test-room",
|
|
||||||
room_url="https://example.com/room",
|
|
||||||
host_room_url="https://example.com/room-host",
|
|
||||||
start_date=old_date,
|
|
||||||
end_date=old_date + timedelta(hours=1),
|
|
||||||
is_active=False,
|
|
||||||
num_clients=0,
|
|
||||||
is_locked=False,
|
|
||||||
room_mode="normal",
|
|
||||||
recording_type="cloud",
|
|
||||||
recording_trigger="automatic",
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Update transcript with meeting_id
|
# Create consent records for this meeting
|
||||||
await db_session.execute(
|
consent1_id = "consent-1"
|
||||||
update(TranscriptModel)
|
consent2_id = "consent-2"
|
||||||
.where(TranscriptModel.id == transcript.id)
|
|
||||||
.values(meeting_id=meeting_id)
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Create meeting_consent entries
|
await get_database().execute(
|
||||||
await db_session.execute(
|
meeting_consent.insert().values(
|
||||||
insert(MeetingConsentModel).values(
|
id=consent1_id,
|
||||||
id="consent-1",
|
|
||||||
meeting_id=meeting_id,
|
meeting_id=meeting_id,
|
||||||
user_id="user-1",
|
user_id="user1",
|
||||||
consent_given=True,
|
consent_given=True,
|
||||||
consent_timestamp=old_date,
|
consent_timestamp=datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await db_session.execute(
|
|
||||||
insert(MeetingConsentModel).values(
|
|
||||||
id="consent-2",
|
|
||||||
meeting_id=meeting_id,
|
|
||||||
user_id="user-2",
|
|
||||||
consent_given=True,
|
|
||||||
consent_timestamp=old_date,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Verify consent entries exist
|
await get_database().execute(
|
||||||
result = await db_session.execute(
|
meeting_consent.insert().values(
|
||||||
select(MeetingConsentModel).where(MeetingConsentModel.meeting_id == meeting_id)
|
id=consent2_id,
|
||||||
|
meeting_id=meeting_id,
|
||||||
|
user_id="user2",
|
||||||
|
consent_given=False,
|
||||||
|
consent_timestamp=datetime.now(timezone.utc),
|
||||||
)
|
)
|
||||||
consents = result.scalars().all()
|
)
|
||||||
|
|
||||||
|
# Verify consent records exist
|
||||||
|
consents = await meeting_consent_controller.get_by_meeting_id(meeting_id)
|
||||||
assert len(consents) == 2
|
assert len(consents) == 2
|
||||||
|
|
||||||
# Delete the transcript and meeting
|
# Delete the meeting
|
||||||
await db_session.execute(
|
await get_database().execute(meetings.delete().where(meetings.c.id == meeting_id))
|
||||||
delete(TranscriptModel).where(TranscriptModel.id == transcript.id)
|
|
||||||
)
|
|
||||||
await db_session.execute(delete(MeetingModel).where(MeetingModel.id == meeting_id))
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Verify consent entries were cascade deleted
|
# Verify meeting is deleted
|
||||||
result = await db_session.execute(
|
query = meetings.select().where(meetings.c.id == meeting_id)
|
||||||
select(MeetingConsentModel).where(MeetingConsentModel.meeting_id == meeting_id)
|
result = await get_database().fetch_one(query)
|
||||||
)
|
assert result is None
|
||||||
consents = result.scalars().all()
|
|
||||||
assert len(consents) == 0
|
# Verify consent records are automatically deleted (CASCADE DELETE)
|
||||||
|
consents_after = await meeting_consent_controller.get_by_meeting_id(meeting_id)
|
||||||
|
assert len(consents_after) == 0
|
||||||
|
|||||||
@@ -4,8 +4,9 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
from icalendar import Calendar, Event
|
from icalendar import Calendar, Event
|
||||||
|
|
||||||
|
from reflector.db import get_database
|
||||||
from reflector.db.calendar_events import calendar_events_controller
|
from reflector.db.calendar_events import calendar_events_controller
|
||||||
from reflector.db.rooms import rooms_controller
|
from reflector.db.rooms import rooms, rooms_controller
|
||||||
from reflector.services.ics_sync import ics_sync_service
|
from reflector.services.ics_sync import ics_sync_service
|
||||||
from reflector.worker.ics_sync import (
|
from reflector.worker.ics_sync import (
|
||||||
_should_sync,
|
_should_sync,
|
||||||
@@ -14,9 +15,8 @@ from reflector.worker.ics_sync import (
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_sync_room_ics_task(db_session):
|
async def test_sync_room_ics_task():
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="task-test-room",
|
name="task-test-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -30,7 +30,6 @@ async def test_sync_room_ics_task(db_session):
|
|||||||
ics_url="https://calendar.example.com/task.ics",
|
ics_url="https://calendar.example.com/task.ics",
|
||||||
ics_enabled=True,
|
ics_enabled=True,
|
||||||
)
|
)
|
||||||
await db_session.flush()
|
|
||||||
|
|
||||||
cal = Calendar()
|
cal = Calendar()
|
||||||
event = Event()
|
event = Event()
|
||||||
@@ -46,22 +45,21 @@ async def test_sync_room_ics_task(db_session):
|
|||||||
ics_content = cal.to_ical().decode("utf-8")
|
ics_content = cal.to_ical().decode("utf-8")
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"reflector.services.ics_sync.ICSFetchService.fetch_ics",
|
"reflector.services.ics_sync.ICSFetchService.fetch_ics", new_callable=AsyncMock
|
||||||
new_callable=AsyncMock,
|
|
||||||
) as mock_fetch:
|
) as mock_fetch:
|
||||||
mock_fetch.return_value = ics_content
|
mock_fetch.return_value = ics_content
|
||||||
|
|
||||||
await ics_sync_service.sync_room_calendar(db_session, room)
|
# Call the service directly instead of the Celery task to avoid event loop issues
|
||||||
|
await ics_sync_service.sync_room_calendar(room)
|
||||||
|
|
||||||
events = await calendar_events_controller.get_by_room(db_session, room.id)
|
events = await calendar_events_controller.get_by_room(room.id)
|
||||||
assert len(events) == 1
|
assert len(events) == 1
|
||||||
assert events[0].ics_uid == "task-event-1"
|
assert events[0].ics_uid == "task-event-1"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_sync_room_ics_disabled(db_session):
|
async def test_sync_room_ics_disabled():
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="disabled-room",
|
name="disabled-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -75,16 +73,16 @@ async def test_sync_room_ics_disabled(db_session):
|
|||||||
ics_enabled=False,
|
ics_enabled=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await ics_sync_service.sync_room_calendar(db_session, room)
|
# Test that disabled rooms are skipped by the service
|
||||||
|
result = await ics_sync_service.sync_room_calendar(room)
|
||||||
|
|
||||||
events = await calendar_events_controller.get_by_room(db_session, room.id)
|
events = await calendar_events_controller.get_by_room(room.id)
|
||||||
assert len(events) == 0
|
assert len(events) == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_sync_all_ics_calendars(db_session):
|
async def test_sync_all_ics_calendars():
|
||||||
room1 = await rooms_controller.add(
|
room1 = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="sync-all-1",
|
name="sync-all-1",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -100,7 +98,6 @@ async def test_sync_all_ics_calendars(db_session):
|
|||||||
)
|
)
|
||||||
|
|
||||||
room2 = await rooms_controller.add(
|
room2 = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="sync-all-2",
|
name="sync-all-2",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -116,7 +113,6 @@ async def test_sync_all_ics_calendars(db_session):
|
|||||||
)
|
)
|
||||||
|
|
||||||
room3 = await rooms_controller.add(
|
room3 = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="sync-all-3",
|
name="sync-all-3",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -131,11 +127,17 @@ async def test_sync_all_ics_calendars(db_session):
|
|||||||
)
|
)
|
||||||
|
|
||||||
with patch("reflector.worker.ics_sync.sync_room_ics.delay") as mock_delay:
|
with patch("reflector.worker.ics_sync.sync_room_ics.delay") as mock_delay:
|
||||||
ics_enabled_rooms = await rooms_controller.get_ics_enabled(db_session)
|
# Directly call the sync_all logic without the Celery wrapper
|
||||||
|
query = rooms.select().where(
|
||||||
|
rooms.c.ics_enabled == True, rooms.c.ics_url != None
|
||||||
|
)
|
||||||
|
all_rooms = await get_database().fetch_all(query)
|
||||||
|
|
||||||
for room in ics_enabled_rooms:
|
for room_data in all_rooms:
|
||||||
|
room_id = room_data["id"]
|
||||||
|
room = await rooms_controller.get_by_id(room_id)
|
||||||
if room and _should_sync(room):
|
if room and _should_sync(room):
|
||||||
sync_room_ics.delay(room.id)
|
sync_room_ics.delay(room_id)
|
||||||
|
|
||||||
assert mock_delay.call_count == 2
|
assert mock_delay.call_count == 2
|
||||||
called_room_ids = [call.args[0] for call in mock_delay.call_args_list]
|
called_room_ids = [call.args[0] for call in mock_delay.call_args_list]
|
||||||
@@ -161,11 +163,10 @@ async def test_should_sync_logic():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_sync_respects_fetch_interval(db_session):
|
async def test_sync_respects_fetch_interval():
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
|
|
||||||
room1 = await rooms_controller.add(
|
room1 = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="interval-test-1",
|
name="interval-test-1",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -182,13 +183,11 @@ async def test_sync_respects_fetch_interval(db_session):
|
|||||||
)
|
)
|
||||||
|
|
||||||
await rooms_controller.update(
|
await rooms_controller.update(
|
||||||
db_session,
|
|
||||||
room1,
|
room1,
|
||||||
{"ics_last_sync": now - timedelta(seconds=100)},
|
{"ics_last_sync": now - timedelta(seconds=100)},
|
||||||
)
|
)
|
||||||
|
|
||||||
room2 = await rooms_controller.add(
|
room2 = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="interval-test-2",
|
name="interval-test-2",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -205,26 +204,30 @@ async def test_sync_respects_fetch_interval(db_session):
|
|||||||
)
|
)
|
||||||
|
|
||||||
await rooms_controller.update(
|
await rooms_controller.update(
|
||||||
db_session,
|
|
||||||
room2,
|
room2,
|
||||||
{"ics_last_sync": now - timedelta(seconds=100)},
|
{"ics_last_sync": now - timedelta(seconds=100)},
|
||||||
)
|
)
|
||||||
|
|
||||||
with patch("reflector.worker.ics_sync.sync_room_ics.delay") as mock_delay:
|
with patch("reflector.worker.ics_sync.sync_room_ics.delay") as mock_delay:
|
||||||
ics_enabled_rooms = await rooms_controller.get_ics_enabled(db_session)
|
# Test the sync logic without the Celery wrapper
|
||||||
|
query = rooms.select().where(
|
||||||
|
rooms.c.ics_enabled == True, rooms.c.ics_url != None
|
||||||
|
)
|
||||||
|
all_rooms = await get_database().fetch_all(query)
|
||||||
|
|
||||||
for room in ics_enabled_rooms:
|
for room_data in all_rooms:
|
||||||
|
room_id = room_data["id"]
|
||||||
|
room = await rooms_controller.get_by_id(room_id)
|
||||||
if room and _should_sync(room):
|
if room and _should_sync(room):
|
||||||
sync_room_ics.delay(room.id)
|
sync_room_ics.delay(room_id)
|
||||||
|
|
||||||
assert mock_delay.call_count == 1
|
assert mock_delay.call_count == 1
|
||||||
assert mock_delay.call_args[0][0] == room2.id
|
assert mock_delay.call_args[0][0] == room2.id
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_sync_handles_errors_gracefully(db_session):
|
async def test_sync_handles_errors_gracefully():
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="error-task-room",
|
name="error-task-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -244,8 +247,9 @@ async def test_sync_handles_errors_gracefully(db_session):
|
|||||||
) as mock_fetch:
|
) as mock_fetch:
|
||||||
mock_fetch.side_effect = Exception("Network error")
|
mock_fetch.side_effect = Exception("Network error")
|
||||||
|
|
||||||
result = await ics_sync_service.sync_room_calendar(db_session, room)
|
# Call the service directly to test error handling
|
||||||
|
result = await ics_sync_service.sync_room_calendar(room)
|
||||||
assert result["status"] == "error"
|
assert result["status"] == "error"
|
||||||
|
|
||||||
events = await calendar_events_controller.get_by_room(db_session, room.id)
|
events = await calendar_events_controller.get_by_room(room.id)
|
||||||
assert len(events) == 0
|
assert len(events) == 0
|
||||||
|
|||||||
@@ -134,10 +134,9 @@ async def test_ics_fetch_service_extract_room_events():
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_ics_sync_service_sync_room_calendar(db_session):
|
async def test_ics_sync_service_sync_room_calendar():
|
||||||
# Create room
|
# Create room
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="sync-test",
|
name="sync-test",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -151,7 +150,6 @@ async def test_ics_sync_service_sync_room_calendar(db_session):
|
|||||||
ics_url="https://calendar.example.com/test.ics",
|
ics_url="https://calendar.example.com/test.ics",
|
||||||
ics_enabled=True,
|
ics_enabled=True,
|
||||||
)
|
)
|
||||||
await db_session.flush()
|
|
||||||
|
|
||||||
# Mock ICS content
|
# Mock ICS content
|
||||||
cal = Calendar()
|
cal = Calendar()
|
||||||
@@ -177,7 +175,7 @@ async def test_ics_sync_service_sync_room_calendar(db_session):
|
|||||||
mock_fetch.return_value = ics_content
|
mock_fetch.return_value = ics_content
|
||||||
|
|
||||||
# First sync
|
# First sync
|
||||||
result = await sync_service.sync_room_calendar(db_session, room)
|
result = await sync_service.sync_room_calendar(room)
|
||||||
|
|
||||||
assert result["status"] == "success"
|
assert result["status"] == "success"
|
||||||
assert result["events_found"] == 1
|
assert result["events_found"] == 1
|
||||||
@@ -186,20 +184,18 @@ async def test_ics_sync_service_sync_room_calendar(db_session):
|
|||||||
assert result["events_deleted"] == 0
|
assert result["events_deleted"] == 0
|
||||||
|
|
||||||
# Verify event was created
|
# Verify event was created
|
||||||
events = await calendar_events_controller.get_by_room(db_session, room.id)
|
events = await calendar_events_controller.get_by_room(room.id)
|
||||||
assert len(events) == 1
|
assert len(events) == 1
|
||||||
assert events[0].ics_uid == "sync-event-1"
|
assert events[0].ics_uid == "sync-event-1"
|
||||||
assert events[0].title == "Sync Test Meeting"
|
assert events[0].title == "Sync Test Meeting"
|
||||||
|
|
||||||
# Second sync with same content (should be unchanged)
|
# Second sync with same content (should be unchanged)
|
||||||
# Refresh room to get updated etag and force sync by setting old sync time
|
# Refresh room to get updated etag and force sync by setting old sync time
|
||||||
room = await rooms_controller.get_by_id(db_session, room.id)
|
room = await rooms_controller.get_by_id(room.id)
|
||||||
await rooms_controller.update(
|
await rooms_controller.update(
|
||||||
db_session,
|
room, {"ics_last_sync": datetime.now(timezone.utc) - timedelta(minutes=10)}
|
||||||
room,
|
|
||||||
{"ics_last_sync": datetime.now(timezone.utc) - timedelta(minutes=10)},
|
|
||||||
)
|
)
|
||||||
result = await sync_service.sync_room_calendar(db_session, room)
|
result = await sync_service.sync_room_calendar(room)
|
||||||
assert result["status"] == "unchanged"
|
assert result["status"] == "unchanged"
|
||||||
|
|
||||||
# Third sync with updated event
|
# Third sync with updated event
|
||||||
@@ -210,15 +206,15 @@ async def test_ics_sync_service_sync_room_calendar(db_session):
|
|||||||
mock_fetch.return_value = ics_content
|
mock_fetch.return_value = ics_content
|
||||||
|
|
||||||
# Force sync by clearing etag
|
# Force sync by clearing etag
|
||||||
await rooms_controller.update(db_session, room, {"ics_last_etag": None})
|
await rooms_controller.update(room, {"ics_last_etag": None})
|
||||||
|
|
||||||
result = await sync_service.sync_room_calendar(db_session, room)
|
result = await sync_service.sync_room_calendar(room)
|
||||||
assert result["status"] == "success"
|
assert result["status"] == "success"
|
||||||
assert result["events_created"] == 0
|
assert result["events_created"] == 0
|
||||||
assert result["events_updated"] == 1
|
assert result["events_updated"] == 1
|
||||||
|
|
||||||
# Verify event was updated
|
# Verify event was updated
|
||||||
events = await calendar_events_controller.get_by_room(db_session, room.id)
|
events = await calendar_events_controller.get_by_room(room.id)
|
||||||
assert len(events) == 1
|
assert len(events) == 1
|
||||||
assert events[0].title == "Updated Meeting Title"
|
assert events[0].title == "Updated Meeting Title"
|
||||||
|
|
||||||
@@ -251,7 +247,7 @@ async def test_ics_sync_service_skip_disabled():
|
|||||||
room.ics_enabled = False
|
room.ics_enabled = False
|
||||||
room.ics_url = "https://calendar.example.com/test.ics"
|
room.ics_url = "https://calendar.example.com/test.ics"
|
||||||
|
|
||||||
result = await service.sync_room_calendar(MagicMock(), room)
|
result = await service.sync_room_calendar(room)
|
||||||
assert result["status"] == "skipped"
|
assert result["status"] == "skipped"
|
||||||
assert result["reason"] == "ICS not configured"
|
assert result["reason"] == "ICS not configured"
|
||||||
|
|
||||||
@@ -259,16 +255,15 @@ async def test_ics_sync_service_skip_disabled():
|
|||||||
room.ics_enabled = True
|
room.ics_enabled = True
|
||||||
room.ics_url = None
|
room.ics_url = None
|
||||||
|
|
||||||
result = await service.sync_room_calendar(MagicMock(), room)
|
result = await service.sync_room_calendar(room)
|
||||||
assert result["status"] == "skipped"
|
assert result["status"] == "skipped"
|
||||||
assert result["reason"] == "ICS not configured"
|
assert result["reason"] == "ICS not configured"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_ics_sync_service_error_handling(db_session):
|
async def test_ics_sync_service_error_handling():
|
||||||
# Create room
|
# Create room
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="error-test",
|
name="error-test",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -282,7 +277,6 @@ async def test_ics_sync_service_error_handling(db_session):
|
|||||||
ics_url="https://calendar.example.com/error.ics",
|
ics_url="https://calendar.example.com/error.ics",
|
||||||
ics_enabled=True,
|
ics_enabled=True,
|
||||||
)
|
)
|
||||||
await db_session.flush()
|
|
||||||
|
|
||||||
sync_service = ICSSyncService()
|
sync_service = ICSSyncService()
|
||||||
|
|
||||||
@@ -291,6 +285,6 @@ async def test_ics_sync_service_error_handling(db_session):
|
|||||||
) as mock_fetch:
|
) as mock_fetch:
|
||||||
mock_fetch.side_effect = Exception("Network error")
|
mock_fetch.side_effect = Exception("Network error")
|
||||||
|
|
||||||
result = await sync_service.sync_room_calendar(db_session, room)
|
result = await sync_service.sync_room_calendar(room)
|
||||||
assert result["status"] == "error"
|
assert result["status"] == "error"
|
||||||
assert "Network error" in result["error"]
|
assert "Network error" in result["error"]
|
||||||
|
|||||||
@@ -10,11 +10,10 @@ from reflector.db.rooms import rooms_controller
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_multiple_active_meetings_per_room(db_session):
|
async def test_multiple_active_meetings_per_room():
|
||||||
"""Test that multiple active meetings can exist for the same room."""
|
"""Test that multiple active meetings can exist for the same room."""
|
||||||
# Create a room
|
# Create a room
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="test-room",
|
name="test-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -32,7 +31,6 @@ async def test_multiple_active_meetings_per_room(db_session):
|
|||||||
|
|
||||||
# Create first meeting
|
# Create first meeting
|
||||||
meeting1 = await meetings_controller.create(
|
meeting1 = await meetings_controller.create(
|
||||||
db_session,
|
|
||||||
id="meeting-1",
|
id="meeting-1",
|
||||||
room_name="test-meeting-1",
|
room_name="test-meeting-1",
|
||||||
room_url="https://whereby.com/test-1",
|
room_url="https://whereby.com/test-1",
|
||||||
@@ -44,7 +42,6 @@ async def test_multiple_active_meetings_per_room(db_session):
|
|||||||
|
|
||||||
# Create second meeting for the same room (should succeed now)
|
# Create second meeting for the same room (should succeed now)
|
||||||
meeting2 = await meetings_controller.create(
|
meeting2 = await meetings_controller.create(
|
||||||
db_session,
|
|
||||||
id="meeting-2",
|
id="meeting-2",
|
||||||
room_name="test-meeting-2",
|
room_name="test-meeting-2",
|
||||||
room_url="https://whereby.com/test-2",
|
room_url="https://whereby.com/test-2",
|
||||||
@@ -56,7 +53,7 @@ async def test_multiple_active_meetings_per_room(db_session):
|
|||||||
|
|
||||||
# Both meetings should be active
|
# Both meetings should be active
|
||||||
active_meetings = await meetings_controller.get_all_active_for_room(
|
active_meetings = await meetings_controller.get_all_active_for_room(
|
||||||
db_session, room=room, current_time=current_time
|
room=room, current_time=current_time
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(active_meetings) == 2
|
assert len(active_meetings) == 2
|
||||||
@@ -65,11 +62,10 @@ async def test_multiple_active_meetings_per_room(db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_active_by_calendar_event(db_session):
|
async def test_get_active_by_calendar_event():
|
||||||
"""Test getting active meeting by calendar event ID."""
|
"""Test getting active meeting by calendar event ID."""
|
||||||
# Create a room
|
# Create a room
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="test-room",
|
name="test-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -90,14 +86,13 @@ async def test_get_active_by_calendar_event(db_session):
|
|||||||
start_time=datetime.now(timezone.utc),
|
start_time=datetime.now(timezone.utc),
|
||||||
end_time=datetime.now(timezone.utc) + timedelta(hours=1),
|
end_time=datetime.now(timezone.utc) + timedelta(hours=1),
|
||||||
)
|
)
|
||||||
event = await calendar_events_controller.upsert(db_session, event)
|
event = await calendar_events_controller.upsert(event)
|
||||||
|
|
||||||
current_time = datetime.now(timezone.utc)
|
current_time = datetime.now(timezone.utc)
|
||||||
end_time = current_time + timedelta(hours=2)
|
end_time = current_time + timedelta(hours=2)
|
||||||
|
|
||||||
# Create meeting linked to calendar event
|
# Create meeting linked to calendar event
|
||||||
meeting = await meetings_controller.create(
|
meeting = await meetings_controller.create(
|
||||||
db_session,
|
|
||||||
id="meeting-cal-1",
|
id="meeting-cal-1",
|
||||||
room_name="test-meeting-cal",
|
room_name="test-meeting-cal",
|
||||||
room_url="https://whereby.com/test-cal",
|
room_url="https://whereby.com/test-cal",
|
||||||
@@ -111,7 +106,7 @@ async def test_get_active_by_calendar_event(db_session):
|
|||||||
|
|
||||||
# Should find the meeting by calendar event
|
# Should find the meeting by calendar event
|
||||||
found_meeting = await meetings_controller.get_active_by_calendar_event(
|
found_meeting = await meetings_controller.get_active_by_calendar_event(
|
||||||
db_session, room=room, calendar_event_id=event.id, current_time=current_time
|
room=room, calendar_event_id=event.id, current_time=current_time
|
||||||
)
|
)
|
||||||
|
|
||||||
assert found_meeting is not None
|
assert found_meeting is not None
|
||||||
@@ -120,11 +115,10 @@ async def test_get_active_by_calendar_event(db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_calendar_meeting_deactivates_after_scheduled_end(db_session):
|
async def test_calendar_meeting_deactivates_after_scheduled_end():
|
||||||
"""Test that unused calendar meetings deactivate after scheduled end time."""
|
"""Test that unused calendar meetings deactivate after scheduled end time."""
|
||||||
# Create a room
|
# Create a room
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="test-room",
|
name="test-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -145,13 +139,12 @@ async def test_calendar_meeting_deactivates_after_scheduled_end(db_session):
|
|||||||
start_time=datetime.now(timezone.utc) - timedelta(hours=2),
|
start_time=datetime.now(timezone.utc) - timedelta(hours=2),
|
||||||
end_time=datetime.now(timezone.utc) - timedelta(minutes=35),
|
end_time=datetime.now(timezone.utc) - timedelta(minutes=35),
|
||||||
)
|
)
|
||||||
event = await calendar_events_controller.upsert(db_session, event)
|
event = await calendar_events_controller.upsert(event)
|
||||||
|
|
||||||
current_time = datetime.now(timezone.utc)
|
current_time = datetime.now(timezone.utc)
|
||||||
|
|
||||||
# Create meeting linked to calendar event
|
# Create meeting linked to calendar event
|
||||||
meeting = await meetings_controller.create(
|
meeting = await meetings_controller.create(
|
||||||
db_session,
|
|
||||||
id="meeting-unused",
|
id="meeting-unused",
|
||||||
room_name="test-meeting-unused",
|
room_name="test-meeting-unused",
|
||||||
room_url="https://whereby.com/test-unused",
|
room_url="https://whereby.com/test-unused",
|
||||||
@@ -168,9 +161,7 @@ async def test_calendar_meeting_deactivates_after_scheduled_end(db_session):
|
|||||||
# Simulate process_meetings logic for unused calendar meeting past end time
|
# Simulate process_meetings logic for unused calendar meeting past end time
|
||||||
if meeting.calendar_event_id and current_time > meeting.end_date:
|
if meeting.calendar_event_id and current_time > meeting.end_date:
|
||||||
# In real code, we'd check has_had_sessions = False here
|
# In real code, we'd check has_had_sessions = False here
|
||||||
await meetings_controller.update_meeting(
|
await meetings_controller.update_meeting(meeting.id, is_active=False)
|
||||||
db_session, meeting.id, is_active=False
|
|
||||||
)
|
|
||||||
|
|
||||||
updated_meeting = await meetings_controller.get_by_id(db_session, meeting.id)
|
updated_meeting = await meetings_controller.get_by_id(meeting.id)
|
||||||
assert updated_meeting.is_active is False # Deactivated after scheduled end
|
assert updated_meeting.is_active is False # Deactivated after scheduled end
|
||||||
|
|||||||
@@ -101,36 +101,20 @@ async def mock_transcript_in_db(tmpdir):
|
|||||||
target_language="en",
|
target_language="en",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock all transcripts controller methods that are used in the pipeline
|
# Mock the controller to return our transcript
|
||||||
try:
|
try:
|
||||||
with patch(
|
with patch(
|
||||||
"reflector.pipelines.main_file_pipeline.transcripts_controller.get_by_id"
|
"reflector.pipelines.main_file_pipeline.transcripts_controller.get_by_id"
|
||||||
) as mock_get:
|
) as mock_get:
|
||||||
mock_get.return_value = transcript
|
mock_get.return_value = transcript
|
||||||
with patch(
|
|
||||||
"reflector.pipelines.main_file_pipeline.transcripts_controller.update"
|
|
||||||
) as mock_update:
|
|
||||||
mock_update.return_value = transcript
|
|
||||||
with patch(
|
|
||||||
"reflector.pipelines.main_file_pipeline.transcripts_controller.set_status"
|
|
||||||
) as mock_set_status:
|
|
||||||
mock_set_status.return_value = None
|
|
||||||
with patch(
|
|
||||||
"reflector.pipelines.main_file_pipeline.transcripts_controller.upsert_topic"
|
|
||||||
) as mock_upsert_topic:
|
|
||||||
mock_upsert_topic.return_value = None
|
|
||||||
with patch(
|
|
||||||
"reflector.pipelines.main_file_pipeline.transcripts_controller.append_event"
|
|
||||||
) as mock_append_event:
|
|
||||||
mock_append_event.return_value = None
|
|
||||||
with patch(
|
with patch(
|
||||||
"reflector.pipelines.main_live_pipeline.transcripts_controller.get_by_id"
|
"reflector.pipelines.main_live_pipeline.transcripts_controller.get_by_id"
|
||||||
) as mock_get2:
|
) as mock_get2:
|
||||||
mock_get2.return_value = transcript
|
mock_get2.return_value = transcript
|
||||||
with patch(
|
with patch(
|
||||||
"reflector.pipelines.main_live_pipeline.transcripts_controller.update"
|
"reflector.pipelines.main_live_pipeline.transcripts_controller.update"
|
||||||
) as mock_update2:
|
) as mock_update:
|
||||||
mock_update2.return_value = None
|
mock_update.return_value = None
|
||||||
yield transcript
|
yield transcript
|
||||||
finally:
|
finally:
|
||||||
# Restore original DATA_DIR
|
# Restore original DATA_DIR
|
||||||
@@ -624,11 +608,7 @@ async def test_pipeline_file_process_no_transcript():
|
|||||||
|
|
||||||
# Should raise an exception for missing transcript when get_transcript is called
|
# Should raise an exception for missing transcript when get_transcript is called
|
||||||
with pytest.raises(Exception, match="Transcript not found"):
|
with pytest.raises(Exception, match="Transcript not found"):
|
||||||
# Use a mock session - the controller is mocked to return None anyway
|
await pipeline.get_transcript()
|
||||||
from unittest.mock import MagicMock
|
|
||||||
|
|
||||||
mock_session = MagicMock()
|
|
||||||
await pipeline.get_transcript(mock_session)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -10,10 +10,9 @@ from reflector.db.rooms import rooms_controller
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_room_create_with_ics_fields(db_session):
|
async def test_room_create_with_ics_fields():
|
||||||
"""Test creating a room with ICS calendar fields."""
|
"""Test creating a room with ICS calendar fields."""
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="test-room",
|
name="test-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -41,11 +40,10 @@ async def test_room_create_with_ics_fields(db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_room_update_ics_configuration(db_session):
|
async def test_room_update_ics_configuration():
|
||||||
"""Test updating room ICS configuration."""
|
"""Test updating room ICS configuration."""
|
||||||
# Create room without ICS
|
# Create room without ICS
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="update-test",
|
name="update-test",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -63,7 +61,6 @@ async def test_room_update_ics_configuration(db_session):
|
|||||||
|
|
||||||
# Update with ICS configuration
|
# Update with ICS configuration
|
||||||
await rooms_controller.update(
|
await rooms_controller.update(
|
||||||
db_session,
|
|
||||||
room,
|
room,
|
||||||
{
|
{
|
||||||
"ics_url": "https://outlook.office365.com/owa/calendar/test/calendar.ics",
|
"ics_url": "https://outlook.office365.com/owa/calendar/test/calendar.ics",
|
||||||
@@ -80,10 +77,9 @@ async def test_room_update_ics_configuration(db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_room_ics_sync_metadata(db_session):
|
async def test_room_ics_sync_metadata():
|
||||||
"""Test updating room ICS sync metadata."""
|
"""Test updating room ICS sync metadata."""
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="sync-test",
|
name="sync-test",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -101,7 +97,6 @@ async def test_room_ics_sync_metadata(db_session):
|
|||||||
# Update sync metadata
|
# Update sync metadata
|
||||||
sync_time = datetime.now(timezone.utc)
|
sync_time = datetime.now(timezone.utc)
|
||||||
await rooms_controller.update(
|
await rooms_controller.update(
|
||||||
db_session,
|
|
||||||
room,
|
room,
|
||||||
{
|
{
|
||||||
"ics_last_sync": sync_time,
|
"ics_last_sync": sync_time,
|
||||||
@@ -114,11 +109,10 @@ async def test_room_ics_sync_metadata(db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_room_get_with_ics_fields(db_session):
|
async def test_room_get_with_ics_fields():
|
||||||
"""Test retrieving room with ICS fields."""
|
"""Test retrieving room with ICS fields."""
|
||||||
# Create room
|
# Create room
|
||||||
created_room = await rooms_controller.add(
|
created_room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="get-test",
|
name="get-test",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -135,14 +129,14 @@ async def test_room_get_with_ics_fields(db_session):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get by ID
|
# Get by ID
|
||||||
room = await rooms_controller.get_by_id(db_session, created_room.id)
|
room = await rooms_controller.get_by_id(created_room.id)
|
||||||
assert room is not None
|
assert room is not None
|
||||||
assert room.ics_url == "webcal://calendar.example.com/feed.ics"
|
assert room.ics_url == "webcal://calendar.example.com/feed.ics"
|
||||||
assert room.ics_fetch_interval == 900
|
assert room.ics_fetch_interval == 900
|
||||||
assert room.ics_enabled is True
|
assert room.ics_enabled is True
|
||||||
|
|
||||||
# Get by name
|
# Get by name
|
||||||
room = await rooms_controller.get_by_name(db_session, "get-test")
|
room = await rooms_controller.get_by_name("get-test")
|
||||||
assert room is not None
|
assert room is not None
|
||||||
assert room.ics_url == "webcal://calendar.example.com/feed.ics"
|
assert room.ics_url == "webcal://calendar.example.com/feed.ics"
|
||||||
assert room.ics_fetch_interval == 900
|
assert room.ics_fetch_interval == 900
|
||||||
@@ -150,11 +144,10 @@ async def test_room_get_with_ics_fields(db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_room_list_with_ics_enabled_filter(db_session):
|
async def test_room_list_with_ics_enabled_filter():
|
||||||
"""Test listing rooms filtered by ICS enabled status."""
|
"""Test listing rooms filtered by ICS enabled status."""
|
||||||
# Create rooms with and without ICS
|
# Create rooms with and without ICS
|
||||||
room1 = await rooms_controller.add(
|
room1 = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="ics-enabled-1",
|
name="ics-enabled-1",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -170,7 +163,6 @@ async def test_room_list_with_ics_enabled_filter(db_session):
|
|||||||
)
|
)
|
||||||
|
|
||||||
room2 = await rooms_controller.add(
|
room2 = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="ics-disabled",
|
name="ics-disabled",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -185,7 +177,6 @@ async def test_room_list_with_ics_enabled_filter(db_session):
|
|||||||
)
|
)
|
||||||
|
|
||||||
room3 = await rooms_controller.add(
|
room3 = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="ics-enabled-2",
|
name="ics-enabled-2",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -201,20 +192,19 @@ async def test_room_list_with_ics_enabled_filter(db_session):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Get all rooms
|
# Get all rooms
|
||||||
all_rooms = await rooms_controller.get_all(db_session)
|
all_rooms = await rooms_controller.get_all()
|
||||||
assert len(all_rooms) == 3
|
assert len(all_rooms) == 3
|
||||||
|
|
||||||
# Filter for ICS-enabled rooms (would need to implement this in controller)
|
# Filter for ICS-enabled rooms (would need to implement this in controller)
|
||||||
ics_rooms = [r for r in all_rooms if r.ics_enabled]
|
ics_rooms = [r for r in all_rooms if r["ics_enabled"]]
|
||||||
assert len(ics_rooms) == 2
|
assert len(ics_rooms) == 2
|
||||||
assert all(r.ics_enabled for r in ics_rooms)
|
assert all(r["ics_enabled"] for r in ics_rooms)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_room_default_ics_values(db_session):
|
async def test_room_default_ics_values():
|
||||||
"""Test that ICS fields have correct default values."""
|
"""Test that ICS fields have correct default values."""
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="default-test",
|
name="default-test",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
|
|||||||
@@ -11,13 +11,20 @@ from reflector.db.rooms import rooms_controller
|
|||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def authenticated_client(client):
|
async def authenticated_client(client):
|
||||||
from reflector.app import app
|
from reflector.app import app
|
||||||
from reflector.auth import current_user_optional
|
from reflector.auth import current_user, current_user_optional
|
||||||
|
|
||||||
|
app.dependency_overrides[current_user] = lambda: {
|
||||||
|
"sub": "test-user",
|
||||||
|
"email": "test@example.com",
|
||||||
|
}
|
||||||
app.dependency_overrides[current_user_optional] = lambda: {
|
app.dependency_overrides[current_user_optional] = lambda: {
|
||||||
"sub": "test-user",
|
"sub": "test-user",
|
||||||
"email": "test@example.com",
|
"email": "test@example.com",
|
||||||
}
|
}
|
||||||
|
try:
|
||||||
yield client
|
yield client
|
||||||
|
finally:
|
||||||
|
del app.dependency_overrides[current_user]
|
||||||
del app.dependency_overrides[current_user_optional]
|
del app.dependency_overrides[current_user_optional]
|
||||||
|
|
||||||
|
|
||||||
@@ -89,10 +96,9 @@ async def test_update_room_ics_configuration(authenticated_client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_ics_sync(authenticated_client, db_session):
|
async def test_trigger_ics_sync(authenticated_client):
|
||||||
client = authenticated_client
|
client = authenticated_client
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="sync-api-room",
|
name="sync-api-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -134,9 +140,8 @@ async def test_trigger_ics_sync(authenticated_client, db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_ics_sync_unauthorized(client, db_session):
|
async def test_trigger_ics_sync_unauthorized(client):
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="sync-unauth-room",
|
name="sync-unauth-room",
|
||||||
user_id="owner-123",
|
user_id="owner-123",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -157,10 +162,9 @@ async def test_trigger_ics_sync_unauthorized(client, db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_ics_sync_not_configured(authenticated_client, db_session):
|
async def test_trigger_ics_sync_not_configured(authenticated_client):
|
||||||
client = authenticated_client
|
client = authenticated_client
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="sync-not-configured",
|
name="sync-not-configured",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -180,10 +184,9 @@ async def test_trigger_ics_sync_not_configured(authenticated_client, db_session)
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_ics_status(authenticated_client, db_session):
|
async def test_get_ics_status(authenticated_client):
|
||||||
client = authenticated_client
|
client = authenticated_client
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="status-room",
|
name="status-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -201,7 +204,6 @@ async def test_get_ics_status(authenticated_client, db_session):
|
|||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
await rooms_controller.update(
|
await rooms_controller.update(
|
||||||
db_session,
|
|
||||||
room,
|
room,
|
||||||
{"ics_last_sync": now, "ics_last_etag": "test-etag"},
|
{"ics_last_sync": now, "ics_last_etag": "test-etag"},
|
||||||
)
|
)
|
||||||
@@ -215,9 +217,8 @@ async def test_get_ics_status(authenticated_client, db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_ics_status_unauthorized(client, db_session):
|
async def test_get_ics_status_unauthorized(client):
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="status-unauth",
|
name="status-unauth",
|
||||||
user_id="owner-456",
|
user_id="owner-456",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -238,10 +239,9 @@ async def test_get_ics_status_unauthorized(client, db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_room_meetings(authenticated_client, db_session):
|
async def test_list_room_meetings(authenticated_client):
|
||||||
client = authenticated_client
|
client = authenticated_client
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="meetings-room",
|
name="meetings-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -262,7 +262,7 @@ async def test_list_room_meetings(authenticated_client, db_session):
|
|||||||
start_time=now - timedelta(hours=2),
|
start_time=now - timedelta(hours=2),
|
||||||
end_time=now - timedelta(hours=1),
|
end_time=now - timedelta(hours=1),
|
||||||
)
|
)
|
||||||
await calendar_events_controller.upsert(db_session, event1)
|
await calendar_events_controller.upsert(event1)
|
||||||
|
|
||||||
event2 = CalendarEvent(
|
event2 = CalendarEvent(
|
||||||
room_id=room.id,
|
room_id=room.id,
|
||||||
@@ -273,7 +273,7 @@ async def test_list_room_meetings(authenticated_client, db_session):
|
|||||||
end_time=now + timedelta(hours=2),
|
end_time=now + timedelta(hours=2),
|
||||||
attendees=[{"email": "test@example.com"}],
|
attendees=[{"email": "test@example.com"}],
|
||||||
)
|
)
|
||||||
await calendar_events_controller.upsert(db_session, event2)
|
await calendar_events_controller.upsert(event2)
|
||||||
|
|
||||||
response = await client.get(f"/rooms/{room.name}/meetings")
|
response = await client.get(f"/rooms/{room.name}/meetings")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -286,9 +286,8 @@ async def test_list_room_meetings(authenticated_client, db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_room_meetings_non_owner(client, db_session):
|
async def test_list_room_meetings_non_owner(client):
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="meetings-privacy",
|
name="meetings-privacy",
|
||||||
user_id="owner-789",
|
user_id="owner-789",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -310,7 +309,7 @@ async def test_list_room_meetings_non_owner(client, db_session):
|
|||||||
end_time=datetime.now(timezone.utc) + timedelta(hours=2),
|
end_time=datetime.now(timezone.utc) + timedelta(hours=2),
|
||||||
attendees=[{"email": "private@example.com"}],
|
attendees=[{"email": "private@example.com"}],
|
||||||
)
|
)
|
||||||
await calendar_events_controller.upsert(db_session, event)
|
await calendar_events_controller.upsert(event)
|
||||||
|
|
||||||
response = await client.get(f"/rooms/{room.name}/meetings")
|
response = await client.get(f"/rooms/{room.name}/meetings")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
@@ -322,10 +321,9 @@ async def test_list_room_meetings_non_owner(client, db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_upcoming_meetings(authenticated_client, db_session):
|
async def test_list_upcoming_meetings(authenticated_client):
|
||||||
client = authenticated_client
|
client = authenticated_client
|
||||||
room = await rooms_controller.add(
|
room = await rooms_controller.add(
|
||||||
db_session,
|
|
||||||
name="upcoming-room",
|
name="upcoming-room",
|
||||||
user_id="test-user",
|
user_id="test-user",
|
||||||
zulip_auto_post=False,
|
zulip_auto_post=False,
|
||||||
@@ -347,7 +345,7 @@ async def test_list_upcoming_meetings(authenticated_client, db_session):
|
|||||||
start_time=now - timedelta(hours=1),
|
start_time=now - timedelta(hours=1),
|
||||||
end_time=now - timedelta(minutes=30),
|
end_time=now - timedelta(minutes=30),
|
||||||
)
|
)
|
||||||
await calendar_events_controller.upsert(db_session, past_event)
|
await calendar_events_controller.upsert(past_event)
|
||||||
|
|
||||||
soon_event = CalendarEvent(
|
soon_event = CalendarEvent(
|
||||||
room_id=room.id,
|
room_id=room.id,
|
||||||
@@ -356,7 +354,7 @@ async def test_list_upcoming_meetings(authenticated_client, db_session):
|
|||||||
start_time=now + timedelta(minutes=15),
|
start_time=now + timedelta(minutes=15),
|
||||||
end_time=now + timedelta(minutes=45),
|
end_time=now + timedelta(minutes=45),
|
||||||
)
|
)
|
||||||
await calendar_events_controller.upsert(db_session, soon_event)
|
await calendar_events_controller.upsert(soon_event)
|
||||||
|
|
||||||
later_event = CalendarEvent(
|
later_event = CalendarEvent(
|
||||||
room_id=room.id,
|
room_id=room.id,
|
||||||
@@ -365,7 +363,7 @@ async def test_list_upcoming_meetings(authenticated_client, db_session):
|
|||||||
start_time=now + timedelta(hours=2),
|
start_time=now + timedelta(hours=2),
|
||||||
end_time=now + timedelta(hours=3),
|
end_time=now + timedelta(hours=3),
|
||||||
)
|
)
|
||||||
await calendar_events_controller.upsert(db_session, later_event)
|
await calendar_events_controller.upsert(later_event)
|
||||||
|
|
||||||
response = await client.get(f"/rooms/{room.name}/meetings/upcoming")
|
response = await client.get(f"/rooms/{room.name}/meetings/upcoming")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|||||||
@@ -2,40 +2,40 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import delete, insert
|
|
||||||
|
|
||||||
from reflector.db.base import TranscriptModel
|
from reflector.db import get_database
|
||||||
from reflector.db.search import (
|
from reflector.db.search import (
|
||||||
SearchController,
|
SearchController,
|
||||||
SearchParameters,
|
SearchParameters,
|
||||||
SearchResult,
|
SearchResult,
|
||||||
search_controller,
|
search_controller,
|
||||||
)
|
)
|
||||||
from reflector.db.transcripts import SourceKind
|
from reflector.db.transcripts import SourceKind, transcripts
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_postgresql_only(db_session):
|
async def test_search_postgresql_only():
|
||||||
params = SearchParameters(query_text="any query here")
|
params = SearchParameters(query_text="any query here")
|
||||||
results, total = await search_controller.search_transcripts(db_session, params)
|
results, total = await search_controller.search_transcripts(params)
|
||||||
assert results == []
|
assert results == []
|
||||||
assert total == 0
|
assert total == 0
|
||||||
|
|
||||||
params_empty = SearchParameters(query_text=None)
|
params_empty = SearchParameters(query_text=None)
|
||||||
results_empty, total_empty = await search_controller.search_transcripts(
|
results_empty, total_empty = await search_controller.search_transcripts(
|
||||||
db_session, params_empty
|
params_empty
|
||||||
)
|
)
|
||||||
assert isinstance(results_empty, list)
|
assert isinstance(results_empty, list)
|
||||||
assert isinstance(total_empty, int)
|
assert isinstance(total_empty, int)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_with_empty_query(db_session):
|
async def test_search_with_empty_query():
|
||||||
"""Test that empty query returns all transcripts."""
|
"""Test that empty query returns all transcripts."""
|
||||||
params = SearchParameters(query_text=None)
|
params = SearchParameters(query_text=None)
|
||||||
results, total = await search_controller.search_transcripts(db_session, params)
|
results, total = await search_controller.search_transcripts(params)
|
||||||
|
|
||||||
assert isinstance(results, list)
|
assert isinstance(results, list)
|
||||||
assert isinstance(total, int)
|
assert isinstance(total, int)
|
||||||
@@ -45,13 +45,13 @@ async def test_search_with_empty_query(db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_empty_transcript_title_only_match(db_session):
|
async def test_empty_transcript_title_only_match():
|
||||||
"""Test that transcripts with title-only matches return empty snippets."""
|
"""Test that transcripts with title-only matches return empty snippets."""
|
||||||
test_id = "test-empty-9b3f2a8d"
|
test_id = "test-empty-9b3f2a8d"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await db_session.execute(
|
await get_database().execute(
|
||||||
delete(TranscriptModel).where(TranscriptModel.id == test_id)
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
test_data = {
|
test_data = {
|
||||||
@@ -77,11 +77,10 @@ async def test_empty_transcript_title_only_match(db_session):
|
|||||||
"user_id": "test-user-1",
|
"user_id": "test-user-1",
|
||||||
}
|
}
|
||||||
|
|
||||||
await db_session.execute(insert(TranscriptModel).values(**test_data))
|
await get_database().execute(transcripts.insert().values(**test_data))
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
params = SearchParameters(query_text="empty", user_id="test-user-1")
|
params = SearchParameters(query_text="empty", user_id="test-user-1")
|
||||||
results, total = await search_controller.search_transcripts(db_session, params)
|
results, total = await search_controller.search_transcripts(params)
|
||||||
|
|
||||||
assert total >= 1
|
assert total >= 1
|
||||||
found = next((r for r in results if r.id == test_id), None)
|
found = next((r for r in results if r.id == test_id), None)
|
||||||
@@ -90,20 +89,20 @@ async def test_empty_transcript_title_only_match(db_session):
|
|||||||
assert found.total_match_count == 0
|
assert found.total_match_count == 0
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await db_session.execute(
|
await get_database().execute(
|
||||||
delete(TranscriptModel).where(TranscriptModel.id == test_id)
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
)
|
)
|
||||||
await db_session.commit()
|
await get_database().disconnect()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_with_long_summary(db_session):
|
async def test_search_with_long_summary():
|
||||||
"""Test that long_summary content is searchable."""
|
"""Test that long_summary content is searchable."""
|
||||||
test_id = "test-long-summary-8a9f3c2d"
|
test_id = "test-long-summary-8a9f3c2d"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await db_session.execute(
|
await get_database().execute(
|
||||||
delete(TranscriptModel).where(TranscriptModel.id == test_id)
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
test_data = {
|
test_data = {
|
||||||
@@ -132,11 +131,10 @@ Basic meeting content without special keywords.""",
|
|||||||
"user_id": "test-user-2",
|
"user_id": "test-user-2",
|
||||||
}
|
}
|
||||||
|
|
||||||
await db_session.execute(insert(TranscriptModel).values(**test_data))
|
await get_database().execute(transcripts.insert().values(**test_data))
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
params = SearchParameters(query_text="quantum computing", user_id="test-user-2")
|
params = SearchParameters(query_text="quantum computing", user_id="test-user-2")
|
||||||
results, total = await search_controller.search_transcripts(db_session, params)
|
results, total = await search_controller.search_transcripts(params)
|
||||||
|
|
||||||
assert total >= 1
|
assert total >= 1
|
||||||
found = any(r.id == test_id for r in results)
|
found = any(r.id == test_id for r in results)
|
||||||
@@ -148,19 +146,19 @@ Basic meeting content without special keywords.""",
|
|||||||
assert "quantum computing" in test_result.search_snippets[0].lower()
|
assert "quantum computing" in test_result.search_snippets[0].lower()
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await db_session.execute(
|
await get_database().execute(
|
||||||
delete(TranscriptModel).where(TranscriptModel.id == test_id)
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
)
|
)
|
||||||
await db_session.commit()
|
await get_database().disconnect()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_postgresql_search_with_data(db_session):
|
async def test_postgresql_search_with_data():
|
||||||
test_id = "test-search-e2e-7f3a9b2c"
|
test_id = "test-search-e2e-7f3a9b2c"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await db_session.execute(
|
await get_database().execute(
|
||||||
delete(TranscriptModel).where(TranscriptModel.id == test_id)
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
test_data = {
|
test_data = {
|
||||||
@@ -198,17 +196,16 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
|||||||
"user_id": "test-user-3",
|
"user_id": "test-user-3",
|
||||||
}
|
}
|
||||||
|
|
||||||
await db_session.execute(insert(TranscriptModel).values(**test_data))
|
await get_database().execute(transcripts.insert().values(**test_data))
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
params = SearchParameters(query_text="planning", user_id="test-user-3")
|
params = SearchParameters(query_text="planning", user_id="test-user-3")
|
||||||
results, total = await search_controller.search_transcripts(db_session, params)
|
results, total = await search_controller.search_transcripts(params)
|
||||||
assert total >= 1
|
assert total >= 1
|
||||||
found = any(r.id == test_id for r in results)
|
found = any(r.id == test_id for r in results)
|
||||||
assert found, "Should find test transcript by title word"
|
assert found, "Should find test transcript by title word"
|
||||||
|
|
||||||
params = SearchParameters(query_text="tsvector", user_id="test-user-3")
|
params = SearchParameters(query_text="tsvector", user_id="test-user-3")
|
||||||
results, total = await search_controller.search_transcripts(db_session, params)
|
results, total = await search_controller.search_transcripts(params)
|
||||||
assert total >= 1
|
assert total >= 1
|
||||||
found = any(r.id == test_id for r in results)
|
found = any(r.id == test_id for r in results)
|
||||||
assert found, "Should find test transcript by webvtt content"
|
assert found, "Should find test transcript by webvtt content"
|
||||||
@@ -216,7 +213,7 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
|||||||
params = SearchParameters(
|
params = SearchParameters(
|
||||||
query_text="engineering planning", user_id="test-user-3"
|
query_text="engineering planning", user_id="test-user-3"
|
||||||
)
|
)
|
||||||
results, total = await search_controller.search_transcripts(db_session, params)
|
results, total = await search_controller.search_transcripts(params)
|
||||||
assert total >= 1
|
assert total >= 1
|
||||||
found = any(r.id == test_id for r in results)
|
found = any(r.id == test_id for r in results)
|
||||||
assert found, "Should find test transcript by multiple words"
|
assert found, "Should find test transcript by multiple words"
|
||||||
@@ -231,7 +228,7 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
|||||||
params = SearchParameters(
|
params = SearchParameters(
|
||||||
query_text="tsvector OR nosuchword", user_id="test-user-3"
|
query_text="tsvector OR nosuchword", user_id="test-user-3"
|
||||||
)
|
)
|
||||||
results, total = await search_controller.search_transcripts(db_session, params)
|
results, total = await search_controller.search_transcripts(params)
|
||||||
assert total >= 1
|
assert total >= 1
|
||||||
found = any(r.id == test_id for r in results)
|
found = any(r.id == test_id for r in results)
|
||||||
assert found, "Should find test transcript with OR query"
|
assert found, "Should find test transcript with OR query"
|
||||||
@@ -239,16 +236,16 @@ We need to implement PostgreSQL tsvector for better performance.""",
|
|||||||
params = SearchParameters(
|
params = SearchParameters(
|
||||||
query_text='"full-text search"', user_id="test-user-3"
|
query_text='"full-text search"', user_id="test-user-3"
|
||||||
)
|
)
|
||||||
results, total = await search_controller.search_transcripts(db_session, params)
|
results, total = await search_controller.search_transcripts(params)
|
||||||
assert total >= 1
|
assert total >= 1
|
||||||
found = any(r.id == test_id for r in results)
|
found = any(r.id == test_id for r in results)
|
||||||
assert found, "Should find test transcript by exact phrase"
|
assert found, "Should find test transcript by exact phrase"
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await db_session.execute(
|
await get_database().execute(
|
||||||
delete(TranscriptModel).where(TranscriptModel.id == test_id)
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
)
|
)
|
||||||
await db_session.commit()
|
await get_database().disconnect()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -314,56 +311,87 @@ class TestSearchControllerFilters:
|
|||||||
"""Test SearchController functionality with various filters."""
|
"""Test SearchController functionality with various filters."""
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_with_source_kind_filter(self, db_session):
|
async def test_search_with_source_kind_filter(self):
|
||||||
"""Test search filtering by source_kind."""
|
"""Test search filtering by source_kind."""
|
||||||
controller = SearchController()
|
controller = SearchController()
|
||||||
|
with (
|
||||||
|
patch("reflector.db.search.is_postgresql", return_value=True),
|
||||||
|
patch("reflector.db.search.get_database") as mock_db,
|
||||||
|
):
|
||||||
|
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
|
||||||
|
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
|
||||||
|
|
||||||
params = SearchParameters(query_text="test", source_kind=SourceKind.LIVE)
|
params = SearchParameters(query_text="test", source_kind=SourceKind.LIVE)
|
||||||
|
|
||||||
# This should not fail, even if no results are found
|
results, total = await controller.search_transcripts(params)
|
||||||
results, total = await controller.search_transcripts(db_session, params)
|
|
||||||
|
|
||||||
assert isinstance(results, list)
|
assert results == []
|
||||||
assert isinstance(total, int)
|
assert total == 0
|
||||||
assert total >= 0
|
|
||||||
|
mock_db.return_value.fetch_all.assert_called_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_with_single_room_id(self, db_session):
|
async def test_search_with_single_room_id(self):
|
||||||
"""Test search filtering by single room ID (currently supported)."""
|
"""Test search filtering by single room ID (currently supported)."""
|
||||||
controller = SearchController()
|
controller = SearchController()
|
||||||
|
with (
|
||||||
|
patch("reflector.db.search.is_postgresql", return_value=True),
|
||||||
|
patch("reflector.db.search.get_database") as mock_db,
|
||||||
|
):
|
||||||
|
mock_db.return_value.fetch_all = AsyncMock(return_value=[])
|
||||||
|
mock_db.return_value.fetch_val = AsyncMock(return_value=0)
|
||||||
|
|
||||||
params = SearchParameters(
|
params = SearchParameters(
|
||||||
query_text="test",
|
query_text="test",
|
||||||
room_id="room1",
|
room_id="room1",
|
||||||
)
|
)
|
||||||
|
|
||||||
# This should not fail, even if no results are found
|
results, total = await controller.search_transcripts(params)
|
||||||
results, total = await controller.search_transcripts(db_session, params)
|
|
||||||
|
|
||||||
assert isinstance(results, list)
|
assert results == []
|
||||||
assert isinstance(total, int)
|
assert total == 0
|
||||||
assert total >= 0
|
mock_db.return_value.fetch_all.assert_called_once()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_search_result_includes_available_fields(
|
async def test_search_result_includes_available_fields(self, mock_db_result):
|
||||||
self, db_session, mock_db_result
|
|
||||||
):
|
|
||||||
"""Test that search results include available fields like source_kind."""
|
"""Test that search results include available fields like source_kind."""
|
||||||
# Test that the search method works and returns SearchResult objects
|
|
||||||
controller = SearchController()
|
controller = SearchController()
|
||||||
|
with (
|
||||||
|
patch("reflector.db.search.is_postgresql", return_value=True),
|
||||||
|
patch("reflector.db.search.get_database") as mock_db,
|
||||||
|
):
|
||||||
|
|
||||||
|
class MockRow:
|
||||||
|
def __init__(self, data):
|
||||||
|
self._data = data
|
||||||
|
self._mapping = data
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self._data.items())
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self._data[key]
|
||||||
|
|
||||||
|
def keys(self):
|
||||||
|
return self._data.keys()
|
||||||
|
|
||||||
|
mock_row = MockRow(mock_db_result)
|
||||||
|
|
||||||
|
mock_db.return_value.fetch_all = AsyncMock(return_value=[mock_row])
|
||||||
|
mock_db.return_value.fetch_val = AsyncMock(return_value=1)
|
||||||
|
|
||||||
params = SearchParameters(query_text="test")
|
params = SearchParameters(query_text="test")
|
||||||
|
|
||||||
results, total = await controller.search_transcripts(db_session, params)
|
results, total = await controller.search_transcripts(params)
|
||||||
|
|
||||||
assert isinstance(results, list)
|
assert total == 1
|
||||||
assert isinstance(total, int)
|
assert len(results) == 1
|
||||||
assert total >= 0
|
|
||||||
|
|
||||||
# If any results exist, verify they are SearchResult objects
|
result = results[0]
|
||||||
for result in results:
|
|
||||||
assert isinstance(result, SearchResult)
|
assert isinstance(result, SearchResult)
|
||||||
assert hasattr(result, "id")
|
assert result.id == "test-transcript-id"
|
||||||
assert hasattr(result, "title")
|
assert result.title == "Test Transcript"
|
||||||
assert hasattr(result, "rank")
|
assert result.rank == 0.95
|
||||||
assert hasattr(result, "source_kind")
|
|
||||||
|
|
||||||
|
|
||||||
class TestSearchEndpointParsing:
|
class TestSearchEndpointParsing:
|
||||||
|
|||||||
256
server/tests/test_search_date_filtering.py
Normal file
256
server/tests/test_search_date_filtering.py
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from reflector.db import get_database
|
||||||
|
from reflector.db.search import SearchParameters, search_controller
|
||||||
|
from reflector.db.transcripts import SourceKind, transcripts
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestDateRangeIntegration:
|
||||||
|
async def setup_test_transcripts(self):
|
||||||
|
# Use a test user_id that will match in our search parameters
|
||||||
|
test_user_id = "test-user-123"
|
||||||
|
|
||||||
|
test_data = [
|
||||||
|
{
|
||||||
|
"id": "test-before-range",
|
||||||
|
"created_at": datetime(2024, 1, 15, tzinfo=timezone.utc),
|
||||||
|
"title": "Before Range Transcript",
|
||||||
|
"user_id": test_user_id,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "test-start-boundary",
|
||||||
|
"created_at": datetime(2024, 6, 1, tzinfo=timezone.utc),
|
||||||
|
"title": "Start Boundary Transcript",
|
||||||
|
"user_id": test_user_id,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "test-middle-range",
|
||||||
|
"created_at": datetime(2024, 6, 15, tzinfo=timezone.utc),
|
||||||
|
"title": "Middle Range Transcript",
|
||||||
|
"user_id": test_user_id,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "test-end-boundary",
|
||||||
|
"created_at": datetime(2024, 6, 30, 23, 59, 59, tzinfo=timezone.utc),
|
||||||
|
"title": "End Boundary Transcript",
|
||||||
|
"user_id": test_user_id,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "test-after-range",
|
||||||
|
"created_at": datetime(2024, 12, 31, tzinfo=timezone.utc),
|
||||||
|
"title": "After Range Transcript",
|
||||||
|
"user_id": test_user_id,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
for data in test_data:
|
||||||
|
full_data = {
|
||||||
|
"id": data["id"],
|
||||||
|
"name": data["id"],
|
||||||
|
"status": "ended",
|
||||||
|
"locked": False,
|
||||||
|
"duration": 60.0,
|
||||||
|
"created_at": data["created_at"],
|
||||||
|
"title": data["title"],
|
||||||
|
"short_summary": "Test summary",
|
||||||
|
"long_summary": "Test long summary",
|
||||||
|
"share_mode": "public",
|
||||||
|
"source_kind": SourceKind.FILE,
|
||||||
|
"audio_deleted": False,
|
||||||
|
"reviewed": False,
|
||||||
|
"user_id": data["user_id"],
|
||||||
|
}
|
||||||
|
|
||||||
|
await get_database().execute(transcripts.insert().values(**full_data))
|
||||||
|
|
||||||
|
return test_data
|
||||||
|
|
||||||
|
async def cleanup_test_transcripts(self, test_data):
|
||||||
|
"""Clean up test transcripts."""
|
||||||
|
for data in test_data:
|
||||||
|
await get_database().execute(
|
||||||
|
transcripts.delete().where(transcripts.c.id == data["id"])
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_filter_with_from_datetime_only(self):
|
||||||
|
"""Test filtering with only from_datetime parameter."""
|
||||||
|
test_data = await self.setup_test_transcripts()
|
||||||
|
test_user_id = "test-user-123"
|
||||||
|
|
||||||
|
try:
|
||||||
|
params = SearchParameters(
|
||||||
|
query_text=None,
|
||||||
|
from_datetime=datetime(2024, 6, 1, tzinfo=timezone.utc),
|
||||||
|
to_datetime=None,
|
||||||
|
user_id=test_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
results, total = await search_controller.search_transcripts(params)
|
||||||
|
|
||||||
|
# Should include: start_boundary, middle, end_boundary, after
|
||||||
|
result_ids = [r.id for r in results]
|
||||||
|
assert "test-before-range" not in result_ids
|
||||||
|
assert "test-start-boundary" in result_ids
|
||||||
|
assert "test-middle-range" in result_ids
|
||||||
|
assert "test-end-boundary" in result_ids
|
||||||
|
assert "test-after-range" in result_ids
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await self.cleanup_test_transcripts(test_data)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_filter_with_to_datetime_only(self):
|
||||||
|
"""Test filtering with only to_datetime parameter."""
|
||||||
|
test_data = await self.setup_test_transcripts()
|
||||||
|
test_user_id = "test-user-123"
|
||||||
|
|
||||||
|
try:
|
||||||
|
params = SearchParameters(
|
||||||
|
query_text=None,
|
||||||
|
from_datetime=None,
|
||||||
|
to_datetime=datetime(2024, 6, 30, tzinfo=timezone.utc),
|
||||||
|
user_id=test_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
results, total = await search_controller.search_transcripts(params)
|
||||||
|
|
||||||
|
result_ids = [r.id for r in results]
|
||||||
|
assert "test-before-range" in result_ids
|
||||||
|
assert "test-start-boundary" in result_ids
|
||||||
|
assert "test-middle-range" in result_ids
|
||||||
|
assert "test-end-boundary" not in result_ids
|
||||||
|
assert "test-after-range" not in result_ids
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await self.cleanup_test_transcripts(test_data)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_filter_with_both_datetimes(self):
|
||||||
|
test_data = await self.setup_test_transcripts()
|
||||||
|
test_user_id = "test-user-123"
|
||||||
|
|
||||||
|
try:
|
||||||
|
params = SearchParameters(
|
||||||
|
query_text=None,
|
||||||
|
from_datetime=datetime(2024, 6, 1, tzinfo=timezone.utc),
|
||||||
|
to_datetime=datetime(
|
||||||
|
2024, 7, 1, tzinfo=timezone.utc
|
||||||
|
), # Inclusive of 6/30
|
||||||
|
user_id=test_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
results, total = await search_controller.search_transcripts(params)
|
||||||
|
|
||||||
|
result_ids = [r.id for r in results]
|
||||||
|
assert "test-before-range" not in result_ids
|
||||||
|
assert "test-start-boundary" in result_ids
|
||||||
|
assert "test-middle-range" in result_ids
|
||||||
|
assert "test-end-boundary" in result_ids
|
||||||
|
assert "test-after-range" not in result_ids
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await self.cleanup_test_transcripts(test_data)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_date_filter_with_room_and_source_kind(self):
|
||||||
|
test_data = await self.setup_test_transcripts()
|
||||||
|
test_user_id = "test-user-123"
|
||||||
|
|
||||||
|
try:
|
||||||
|
params = SearchParameters(
|
||||||
|
query_text=None,
|
||||||
|
from_datetime=datetime(2024, 6, 1, tzinfo=timezone.utc),
|
||||||
|
to_datetime=datetime(2024, 7, 1, tzinfo=timezone.utc),
|
||||||
|
source_kind=SourceKind.FILE,
|
||||||
|
room_id=None,
|
||||||
|
user_id=test_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
results, total = await search_controller.search_transcripts(params)
|
||||||
|
|
||||||
|
for result in results:
|
||||||
|
assert result.source_kind == SourceKind.FILE
|
||||||
|
assert result.created_at >= datetime(2024, 6, 1, tzinfo=timezone.utc)
|
||||||
|
assert result.created_at <= datetime(2024, 7, 1, tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await self.cleanup_test_transcripts(test_data)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_results_for_future_dates(self):
|
||||||
|
test_data = await self.setup_test_transcripts()
|
||||||
|
test_user_id = "test-user-123"
|
||||||
|
|
||||||
|
try:
|
||||||
|
params = SearchParameters(
|
||||||
|
query_text=None,
|
||||||
|
from_datetime=datetime(2099, 1, 1, tzinfo=timezone.utc),
|
||||||
|
to_datetime=datetime(2099, 12, 31, tzinfo=timezone.utc),
|
||||||
|
user_id=test_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
results, total = await search_controller.search_transcripts(params)
|
||||||
|
|
||||||
|
assert results == []
|
||||||
|
assert total == 0
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await self.cleanup_test_transcripts(test_data)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_date_only_input_handling(self):
|
||||||
|
test_data = await self.setup_test_transcripts()
|
||||||
|
test_user_id = "test-user-123"
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Pydantic will parse date-only strings to datetime at midnight
|
||||||
|
from_dt = datetime(2024, 6, 15, 0, 0, 0, tzinfo=timezone.utc)
|
||||||
|
to_dt = datetime(2024, 6, 16, 0, 0, 0, tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
params = SearchParameters(
|
||||||
|
query_text=None,
|
||||||
|
from_datetime=from_dt,
|
||||||
|
to_datetime=to_dt,
|
||||||
|
user_id=test_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
results, total = await search_controller.search_transcripts(params)
|
||||||
|
|
||||||
|
result_ids = [r.id for r in results]
|
||||||
|
assert "test-middle-range" in result_ids
|
||||||
|
assert "test-before-range" not in result_ids
|
||||||
|
assert "test-after-range" not in result_ids
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await self.cleanup_test_transcripts(test_data)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDateValidationEdgeCases:
|
||||||
|
"""Edge case tests for datetime validation."""
|
||||||
|
|
||||||
|
def test_timezone_aware_comparison(self):
|
||||||
|
"""Test that timezone-aware comparisons work correctly."""
|
||||||
|
# PST time (UTC-8)
|
||||||
|
pst = timezone(timedelta(hours=-8))
|
||||||
|
pst_dt = datetime(2024, 6, 15, 8, 0, 0, tzinfo=pst)
|
||||||
|
|
||||||
|
# UTC time equivalent (8AM PST = 4PM UTC)
|
||||||
|
utc_dt = datetime(2024, 6, 15, 16, 0, 0, tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
assert pst_dt == utc_dt
|
||||||
|
|
||||||
|
def test_mixed_timezone_input(self):
|
||||||
|
"""Test handling mixed timezone inputs."""
|
||||||
|
pst = timezone(timedelta(hours=-8))
|
||||||
|
ist = timezone(timedelta(hours=5, minutes=30))
|
||||||
|
|
||||||
|
from_date = datetime(2024, 6, 15, 0, 0, 0, tzinfo=pst) # PST midnight
|
||||||
|
to_date = datetime(2024, 6, 15, 23, 59, 59, tzinfo=ist) # IST end of day
|
||||||
|
|
||||||
|
assert from_date.tzinfo is not None
|
||||||
|
assert to_date.tzinfo is not None
|
||||||
|
assert from_date < to_date
|
||||||
@@ -4,21 +4,21 @@ import json
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import delete, insert
|
|
||||||
|
|
||||||
from reflector.db.base import TranscriptModel
|
from reflector.db import get_database
|
||||||
from reflector.db.search import SearchParameters, search_controller
|
from reflector.db.search import SearchParameters, search_controller
|
||||||
|
from reflector.db.transcripts import transcripts
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_long_summary_snippet_prioritization(db_session):
|
async def test_long_summary_snippet_prioritization():
|
||||||
"""Test that snippets from long_summary are prioritized over webvtt content."""
|
"""Test that snippets from long_summary are prioritized over webvtt content."""
|
||||||
test_id = "test-snippet-priority-3f9a2b8c"
|
test_id = "test-snippet-priority-3f9a2b8c"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Clean up any existing test data
|
# Clean up any existing test data
|
||||||
await db_session.execute(
|
await get_database().execute(
|
||||||
delete(TranscriptModel).where(TranscriptModel.id == test_id)
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
test_data = {
|
test_data = {
|
||||||
@@ -57,11 +57,11 @@ We need to consider various implementation approaches.""",
|
|||||||
"user_id": "test-user-priority",
|
"user_id": "test-user-priority",
|
||||||
}
|
}
|
||||||
|
|
||||||
await db_session.execute(insert(TranscriptModel).values(**test_data))
|
await get_database().execute(transcripts.insert().values(**test_data))
|
||||||
|
|
||||||
# Search for "robotics" which appears in both long_summary and webvtt
|
# Search for "robotics" which appears in both long_summary and webvtt
|
||||||
params = SearchParameters(query_text="robotics", user_id="test-user-priority")
|
params = SearchParameters(query_text="robotics", user_id="test-user-priority")
|
||||||
results, total = await search_controller.search_transcripts(db_session, params)
|
results, total = await search_controller.search_transcripts(params)
|
||||||
|
|
||||||
assert total >= 1
|
assert total >= 1
|
||||||
test_result = next((r for r in results if r.id == test_id), None)
|
test_result = next((r for r in results if r.id == test_id), None)
|
||||||
@@ -86,20 +86,20 @@ We need to consider various implementation approaches.""",
|
|||||||
), f"Snippet should contain search term: {snippet}"
|
), f"Snippet should contain search term: {snippet}"
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await db_session.execute(
|
await get_database().execute(
|
||||||
delete(TranscriptModel).where(TranscriptModel.id == test_id)
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
)
|
)
|
||||||
await db_session.commit()
|
await get_database().disconnect()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_long_summary_only_search(db_session):
|
async def test_long_summary_only_search():
|
||||||
"""Test searching for content that only exists in long_summary."""
|
"""Test searching for content that only exists in long_summary."""
|
||||||
test_id = "test-long-only-8b3c9f2a"
|
test_id = "test-long-only-8b3c9f2a"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await db_session.execute(
|
await get_database().execute(
|
||||||
delete(TranscriptModel).where(TranscriptModel.id == test_id)
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
test_data = {
|
test_data = {
|
||||||
@@ -135,11 +135,11 @@ Discussion of timeline and deliverables.""",
|
|||||||
"user_id": "test-user-long",
|
"user_id": "test-user-long",
|
||||||
}
|
}
|
||||||
|
|
||||||
await db_session.execute(insert(TranscriptModel).values(**test_data))
|
await get_database().execute(transcripts.insert().values(**test_data))
|
||||||
|
|
||||||
# Search for terms only in long_summary
|
# Search for terms only in long_summary
|
||||||
params = SearchParameters(query_text="cryptocurrency", user_id="test-user-long")
|
params = SearchParameters(query_text="cryptocurrency", user_id="test-user-long")
|
||||||
results, total = await search_controller.search_transcripts(db_session, params)
|
results, total = await search_controller.search_transcripts(params)
|
||||||
|
|
||||||
found = any(r.id == test_id for r in results)
|
found = any(r.id == test_id for r in results)
|
||||||
assert found, "Should find transcript by long_summary-only content"
|
assert found, "Should find transcript by long_summary-only content"
|
||||||
@@ -154,15 +154,13 @@ Discussion of timeline and deliverables.""",
|
|||||||
|
|
||||||
# Search for "yield farming" - a more specific term
|
# Search for "yield farming" - a more specific term
|
||||||
params2 = SearchParameters(query_text="yield farming", user_id="test-user-long")
|
params2 = SearchParameters(query_text="yield farming", user_id="test-user-long")
|
||||||
results2, total2 = await search_controller.search_transcripts(
|
results2, total2 = await search_controller.search_transcripts(params2)
|
||||||
db_session, params2
|
|
||||||
)
|
|
||||||
|
|
||||||
found2 = any(r.id == test_id for r in results2)
|
found2 = any(r.id == test_id for r in results2)
|
||||||
assert found2, "Should find transcript by specific long_summary phrase"
|
assert found2, "Should find transcript by specific long_summary phrase"
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await db_session.execute(
|
await get_database().execute(
|
||||||
delete(TranscriptModel).where(TranscriptModel.id == test_id)
|
transcripts.delete().where(transcripts.c.id == test_id)
|
||||||
)
|
)
|
||||||
await db_session.commit()
|
await get_database().disconnect()
|
||||||
|
|||||||
384
server/tests/test_security_permissions.py
Normal file
384
server/tests/test_security_permissions.py
Normal file
@@ -0,0 +1,384 @@
|
|||||||
|
import asyncio
|
||||||
|
import shutil
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx_ws import aconnect_ws
|
||||||
|
from uvicorn import Config, Server
|
||||||
|
|
||||||
|
from reflector import zulip as zulip_module
|
||||||
|
from reflector.app import app
|
||||||
|
from reflector.db import get_database
|
||||||
|
from reflector.db.meetings import meetings_controller
|
||||||
|
from reflector.db.rooms import Room, rooms_controller
|
||||||
|
from reflector.db.transcripts import (
|
||||||
|
SourceKind,
|
||||||
|
TranscriptTopic,
|
||||||
|
transcripts_controller,
|
||||||
|
)
|
||||||
|
from reflector.processors.types import Word
|
||||||
|
from reflector.settings import settings
|
||||||
|
from reflector.views.transcripts import create_access_token
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_anonymous_cannot_delete_transcript_in_shared_room(client):
|
||||||
|
# Create a shared room with a fake owner id so meeting has a room_id
|
||||||
|
room = await rooms_controller.add(
|
||||||
|
name="shared-room-test",
|
||||||
|
user_id="owner-1",
|
||||||
|
zulip_auto_post=False,
|
||||||
|
zulip_stream="",
|
||||||
|
zulip_topic="",
|
||||||
|
is_locked=False,
|
||||||
|
room_mode="normal",
|
||||||
|
recording_type="cloud",
|
||||||
|
recording_trigger="automatic-2nd-participant",
|
||||||
|
is_shared=True,
|
||||||
|
webhook_url="",
|
||||||
|
webhook_secret="",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a meeting for that room (so transcript.meeting_id links to the shared room)
|
||||||
|
meeting = await meetings_controller.create(
|
||||||
|
id="meeting-sec-test",
|
||||||
|
room_name="room-sec-test",
|
||||||
|
room_url="room-url",
|
||||||
|
host_room_url="host-url",
|
||||||
|
start_date=Room.model_fields["created_at"].default_factory(),
|
||||||
|
end_date=Room.model_fields["created_at"].default_factory(),
|
||||||
|
room=room,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a transcript owned by someone else and link it to meeting
|
||||||
|
t = await transcripts_controller.add(
|
||||||
|
name="to-delete",
|
||||||
|
source_kind=SourceKind.LIVE,
|
||||||
|
user_id="owner-2",
|
||||||
|
meeting_id=meeting.id,
|
||||||
|
room_id=room.id,
|
||||||
|
share_mode="private",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Anonymous DELETE should be rejected
|
||||||
|
del_resp = await client.delete(f"/transcripts/{t.id}")
|
||||||
|
assert del_resp.status_code == 401, del_resp.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_anonymous_cannot_mutate_participants_on_public_transcript(client):
|
||||||
|
# Create a public transcript with no owner
|
||||||
|
t = await transcripts_controller.add(
|
||||||
|
name="public-transcript",
|
||||||
|
source_kind=SourceKind.LIVE,
|
||||||
|
user_id=None,
|
||||||
|
share_mode="public",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Anonymous POST participant must be rejected
|
||||||
|
resp = await client.post(
|
||||||
|
f"/transcripts/{t.id}/participants",
|
||||||
|
json={"name": "AnonUser", "speaker": 0},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401, resp.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_anonymous_cannot_update_and_delete_room(client):
|
||||||
|
# Create room as owner id "owner-3" via controller
|
||||||
|
room = await rooms_controller.add(
|
||||||
|
name="room-anon-update-delete",
|
||||||
|
user_id="owner-3",
|
||||||
|
zulip_auto_post=False,
|
||||||
|
zulip_stream="",
|
||||||
|
zulip_topic="",
|
||||||
|
is_locked=False,
|
||||||
|
room_mode="normal",
|
||||||
|
recording_type="cloud",
|
||||||
|
recording_trigger="automatic-2nd-participant",
|
||||||
|
is_shared=False,
|
||||||
|
webhook_url="",
|
||||||
|
webhook_secret="",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Anonymous PATCH via API (no auth)
|
||||||
|
resp = await client.patch(
|
||||||
|
f"/rooms/{room.id}",
|
||||||
|
json={
|
||||||
|
"name": "room-anon-updated",
|
||||||
|
"zulip_auto_post": False,
|
||||||
|
"zulip_stream": "",
|
||||||
|
"zulip_topic": "",
|
||||||
|
"is_locked": False,
|
||||||
|
"room_mode": "normal",
|
||||||
|
"recording_type": "cloud",
|
||||||
|
"recording_trigger": "automatic-2nd-participant",
|
||||||
|
"is_shared": False,
|
||||||
|
"webhook_url": "",
|
||||||
|
"webhook_secret": "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
# Expect authentication required
|
||||||
|
assert resp.status_code == 401, resp.text
|
||||||
|
|
||||||
|
# Anonymous DELETE via API
|
||||||
|
del_resp = await client.delete(f"/rooms/{room.id}")
|
||||||
|
assert del_resp.status_code == 401, del_resp.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_anonymous_cannot_post_transcript_to_zulip(client, monkeypatch):
|
||||||
|
# Create a public transcript with some content
|
||||||
|
t = await transcripts_controller.add(
|
||||||
|
name="zulip-public",
|
||||||
|
source_kind=SourceKind.LIVE,
|
||||||
|
user_id=None,
|
||||||
|
share_mode="public",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock send/update calls
|
||||||
|
def _fake_send_message_to_zulip(stream, topic, content):
|
||||||
|
return {"id": 12345}
|
||||||
|
|
||||||
|
async def _fake_update_message(message_id, stream, topic, content):
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
zulip_module, "send_message_to_zulip", _fake_send_message_to_zulip
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(zulip_module, "update_zulip_message", _fake_update_message)
|
||||||
|
|
||||||
|
# Anonymous POST to Zulip endpoint
|
||||||
|
resp = await client.post(
|
||||||
|
f"/transcripts/{t.id}/zulip",
|
||||||
|
params={"stream": "general", "topic": "Updates", "include_topics": False},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401, resp.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_anonymous_cannot_assign_speaker_on_public_transcript(client):
|
||||||
|
# Create public transcript
|
||||||
|
t = await transcripts_controller.add(
|
||||||
|
name="public-assign",
|
||||||
|
source_kind=SourceKind.LIVE,
|
||||||
|
user_id=None,
|
||||||
|
share_mode="public",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add a topic with words to be reassigned
|
||||||
|
topic = TranscriptTopic(
|
||||||
|
title="T1",
|
||||||
|
summary="S1",
|
||||||
|
timestamp=0.0,
|
||||||
|
transcript="Hello",
|
||||||
|
words=[Word(start=0.0, end=1.0, text="Hello", speaker=0)],
|
||||||
|
)
|
||||||
|
transcript = await transcripts_controller.get_by_id(t.id)
|
||||||
|
await transcripts_controller.upsert_topic(transcript, topic)
|
||||||
|
|
||||||
|
# Anonymous assign speaker over time range covering the word
|
||||||
|
resp = await client.patch(
|
||||||
|
f"/transcripts/{t.id}/speaker/assign",
|
||||||
|
json={
|
||||||
|
"speaker": 1,
|
||||||
|
"timestamp_from": 0.0,
|
||||||
|
"timestamp_to": 1.0,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401, resp.text
|
||||||
|
|
||||||
|
|
||||||
|
# Minimal server fixture for websocket tests
|
||||||
|
@pytest.fixture
|
||||||
|
def appserver_ws_simple(setup_database):
|
||||||
|
host = "127.0.0.1"
|
||||||
|
port = 1256
|
||||||
|
server_started = threading.Event()
|
||||||
|
server_exception = None
|
||||||
|
server_instance = None
|
||||||
|
|
||||||
|
def run_server():
|
||||||
|
nonlocal server_exception, server_instance
|
||||||
|
try:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
config = Config(app=app, host=host, port=port, loop=loop)
|
||||||
|
server_instance = Server(config)
|
||||||
|
|
||||||
|
async def start_server():
|
||||||
|
database = get_database()
|
||||||
|
await database.connect()
|
||||||
|
try:
|
||||||
|
await server_instance.serve()
|
||||||
|
finally:
|
||||||
|
await database.disconnect()
|
||||||
|
|
||||||
|
server_started.set()
|
||||||
|
loop.run_until_complete(start_server())
|
||||||
|
except Exception as e:
|
||||||
|
server_exception = e
|
||||||
|
server_started.set()
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
server_thread = threading.Thread(target=run_server, daemon=True)
|
||||||
|
server_thread.start()
|
||||||
|
|
||||||
|
server_started.wait(timeout=30)
|
||||||
|
if server_exception:
|
||||||
|
raise server_exception
|
||||||
|
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
yield host, port
|
||||||
|
|
||||||
|
if server_instance:
|
||||||
|
server_instance.should_exit = True
|
||||||
|
server_thread.join(timeout=30)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_websocket_denies_anonymous_on_private_transcript(appserver_ws_simple):
|
||||||
|
host, port = appserver_ws_simple
|
||||||
|
|
||||||
|
# Create a private transcript owned by someone
|
||||||
|
t = await transcripts_controller.add(
|
||||||
|
name="private-ws",
|
||||||
|
source_kind=SourceKind.LIVE,
|
||||||
|
user_id="owner-x",
|
||||||
|
share_mode="private",
|
||||||
|
)
|
||||||
|
|
||||||
|
base_url = f"http://{host}:{port}/v1"
|
||||||
|
# Anonymous connect should be denied
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
async with aconnect_ws(f"{base_url}/transcripts/{t.id}/events") as ws:
|
||||||
|
await ws.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_anonymous_cannot_update_public_transcript(client):
|
||||||
|
t = await transcripts_controller.add(
|
||||||
|
name="update-me",
|
||||||
|
source_kind=SourceKind.LIVE,
|
||||||
|
user_id=None,
|
||||||
|
share_mode="public",
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = await client.patch(
|
||||||
|
f"/transcripts/{t.id}",
|
||||||
|
json={"title": "New Title From Anonymous"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401, resp.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_anonymous_cannot_get_nonshared_room_by_id(client):
|
||||||
|
room = await rooms_controller.add(
|
||||||
|
name="private-room-exposed",
|
||||||
|
user_id="owner-z",
|
||||||
|
zulip_auto_post=False,
|
||||||
|
zulip_stream="",
|
||||||
|
zulip_topic="",
|
||||||
|
is_locked=False,
|
||||||
|
room_mode="normal",
|
||||||
|
recording_type="cloud",
|
||||||
|
recording_trigger="automatic-2nd-participant",
|
||||||
|
is_shared=False,
|
||||||
|
webhook_url="",
|
||||||
|
webhook_secret="",
|
||||||
|
)
|
||||||
|
|
||||||
|
resp = await client.get(f"/rooms/{room.id}")
|
||||||
|
assert resp.status_code == 403, resp.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_anonymous_cannot_call_rooms_webhook_test(client):
|
||||||
|
room = await rooms_controller.add(
|
||||||
|
name="room-webhook-test",
|
||||||
|
user_id="owner-y",
|
||||||
|
zulip_auto_post=False,
|
||||||
|
zulip_stream="",
|
||||||
|
zulip_topic="",
|
||||||
|
is_locked=False,
|
||||||
|
room_mode="normal",
|
||||||
|
recording_type="cloud",
|
||||||
|
recording_trigger="automatic-2nd-participant",
|
||||||
|
is_shared=False,
|
||||||
|
webhook_url="http://localhost.invalid/webhook",
|
||||||
|
webhook_secret="secret",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Anonymous caller
|
||||||
|
resp = await client.post(f"/rooms/{room.id}/webhook/test")
|
||||||
|
assert resp.status_code == 401, resp.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_anonymous_cannot_create_room(client):
|
||||||
|
payload = {
|
||||||
|
"name": "room-create-auth-required",
|
||||||
|
"zulip_auto_post": False,
|
||||||
|
"zulip_stream": "",
|
||||||
|
"zulip_topic": "",
|
||||||
|
"is_locked": False,
|
||||||
|
"room_mode": "normal",
|
||||||
|
"recording_type": "cloud",
|
||||||
|
"recording_trigger": "automatic-2nd-participant",
|
||||||
|
"is_shared": False,
|
||||||
|
"webhook_url": "",
|
||||||
|
"webhook_secret": "",
|
||||||
|
}
|
||||||
|
resp = await client.post("/rooms", json=payload)
|
||||||
|
assert resp.status_code == 401, resp.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_search_401_when_public_mode_false(client, monkeypatch):
|
||||||
|
monkeypatch.setattr(settings, "PUBLIC_MODE", False)
|
||||||
|
|
||||||
|
resp = await client.get("/transcripts")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
resp = await client.get("/transcripts/search", params={"q": "hello"})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_audio_mp3_requires_token_for_owned_transcript(
|
||||||
|
client, tmpdir, monkeypatch
|
||||||
|
):
|
||||||
|
# Use temp data dir
|
||||||
|
monkeypatch.setattr(settings, "DATA_DIR", Path(tmpdir).as_posix())
|
||||||
|
|
||||||
|
# Create owner transcript and attach a local mp3
|
||||||
|
t = await transcripts_controller.add(
|
||||||
|
name="owned-audio",
|
||||||
|
source_kind=SourceKind.LIVE,
|
||||||
|
user_id="owner-a",
|
||||||
|
share_mode="private",
|
||||||
|
)
|
||||||
|
|
||||||
|
tr = await transcripts_controller.get_by_id(t.id)
|
||||||
|
await transcripts_controller.update(tr, {"status": "ended"})
|
||||||
|
|
||||||
|
# copy fixture audio to transcript path
|
||||||
|
audio_path = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
|
||||||
|
tr.audio_mp3_filename.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
shutil.copy(audio_path, tr.audio_mp3_filename)
|
||||||
|
|
||||||
|
# Anonymous GET without token should be 403 or 404 depending on access; we call mp3
|
||||||
|
resp = await client.get(f"/transcripts/{t.id}/audio/mp3")
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
# With token should succeed
|
||||||
|
token = create_access_token(
|
||||||
|
{"sub": tr.user_id}, expires_delta=__import__("datetime").timedelta(minutes=15)
|
||||||
|
)
|
||||||
|
resp2 = await client.get(f"/transcripts/{t.id}/audio/mp3", params={"token": token})
|
||||||
|
assert resp2.status_code == 200
|
||||||
@@ -1,5 +1,3 @@
|
|||||||
from contextlib import asynccontextmanager
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
@@ -19,7 +17,7 @@ async def test_transcript_create(client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_get_update_name(client):
|
async def test_transcript_get_update_name(authenticated_client, client):
|
||||||
response = await client.post("/transcripts", json={"name": "test"})
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["name"] == "test"
|
assert response.json()["name"] == "test"
|
||||||
@@ -40,7 +38,7 @@ async def test_transcript_get_update_name(client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_get_update_locked(client):
|
async def test_transcript_get_update_locked(authenticated_client, client):
|
||||||
response = await client.post("/transcripts", json={"name": "test"})
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["locked"] is False
|
assert response.json()["locked"] is False
|
||||||
@@ -61,7 +59,7 @@ async def test_transcript_get_update_locked(client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_get_update_summary(client):
|
async def test_transcript_get_update_summary(authenticated_client, client):
|
||||||
response = await client.post("/transcripts", json={"name": "test"})
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["long_summary"] is None
|
assert response.json()["long_summary"] is None
|
||||||
@@ -89,7 +87,7 @@ async def test_transcript_get_update_summary(client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_get_update_title(client):
|
async def test_transcript_get_update_title(authenticated_client, client):
|
||||||
response = await client.post("/transcripts", json={"name": "test"})
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["title"] is None
|
assert response.json()["title"] is None
|
||||||
@@ -127,56 +125,6 @@ async def test_transcripts_list_anonymous(client):
|
|||||||
settings.PUBLIC_MODE = False
|
settings.PUBLIC_MODE = False
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def authenticated_client_ctx():
|
|
||||||
from reflector.app import app
|
|
||||||
from reflector.auth import current_user, current_user_optional
|
|
||||||
|
|
||||||
app.dependency_overrides[current_user] = lambda: {
|
|
||||||
"sub": "randomuserid",
|
|
||||||
"email": "test@mail.com",
|
|
||||||
}
|
|
||||||
app.dependency_overrides[current_user_optional] = lambda: {
|
|
||||||
"sub": "randomuserid",
|
|
||||||
"email": "test@mail.com",
|
|
||||||
}
|
|
||||||
yield
|
|
||||||
del app.dependency_overrides[current_user]
|
|
||||||
del app.dependency_overrides[current_user_optional]
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def authenticated_client2_ctx():
|
|
||||||
from reflector.app import app
|
|
||||||
from reflector.auth import current_user, current_user_optional
|
|
||||||
|
|
||||||
app.dependency_overrides[current_user] = lambda: {
|
|
||||||
"sub": "randomuserid2",
|
|
||||||
"email": "test@mail.com",
|
|
||||||
}
|
|
||||||
app.dependency_overrides[current_user_optional] = lambda: {
|
|
||||||
"sub": "randomuserid2",
|
|
||||||
"email": "test@mail.com",
|
|
||||||
}
|
|
||||||
yield
|
|
||||||
del app.dependency_overrides[current_user]
|
|
||||||
del app.dependency_overrides[current_user_optional]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def authenticated_client():
|
|
||||||
async with authenticated_client_ctx():
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def authenticated_client2():
|
|
||||||
async with authenticated_client2_ctx():
|
|
||||||
yield
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcripts_list_authenticated(authenticated_client, client):
|
async def test_transcripts_list_authenticated(authenticated_client, client):
|
||||||
# XXX this test is a bit fragile, as it depends on the storage which
|
# XXX this test is a bit fragile, as it depends on the storage which
|
||||||
@@ -199,7 +147,7 @@ async def test_transcripts_list_authenticated(authenticated_client, client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_delete(client):
|
async def test_transcript_delete(authenticated_client, client):
|
||||||
response = await client.post("/transcripts", json={"name": "testdel1"})
|
response = await client.post("/transcripts", json={"name": "testdel1"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["name"] == "testdel1"
|
assert response.json()["name"] == "testdel1"
|
||||||
@@ -214,7 +162,7 @@ async def test_transcript_delete(client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_mark_reviewed(client):
|
async def test_transcript_mark_reviewed(authenticated_client, client):
|
||||||
response = await client.post("/transcripts", json={"name": "test"})
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["name"] == "test"
|
assert response.json()["name"] == "test"
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import pytest
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def fake_transcript(tmpdir, client, db_session):
|
async def fake_transcript(tmpdir, client):
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
from reflector.views.transcripts import transcripts_controller
|
from reflector.views.transcripts import transcripts_controller
|
||||||
|
|
||||||
@@ -16,10 +16,10 @@ async def fake_transcript(tmpdir, client, db_session):
|
|||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
tid = response.json()["id"]
|
tid = response.json()["id"]
|
||||||
|
|
||||||
transcript = await transcripts_controller.get_by_id(db_session, tid)
|
transcript = await transcripts_controller.get_by_id(tid)
|
||||||
assert transcript is not None
|
assert transcript is not None
|
||||||
|
|
||||||
await transcripts_controller.update(db_session, transcript, {"status": "ended"})
|
await transcripts_controller.update(transcript, {"status": "ended"})
|
||||||
|
|
||||||
# manually copy a file at the expected location
|
# manually copy a file at the expected location
|
||||||
audio_filename = transcript.audio_mp3_filename
|
audio_filename = transcript.audio_mp3_filename
|
||||||
@@ -111,7 +111,9 @@ async def test_transcript_audio_download_range_with_seek(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_delete_with_audio(fake_transcript, client):
|
async def test_transcript_delete_with_audio(
|
||||||
|
authenticated_client, fake_transcript, client
|
||||||
|
):
|
||||||
response = await client.delete(f"/transcripts/{fake_transcript.id}")
|
response = await client.delete(f"/transcripts/{fake_transcript.id}")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["status"] == "ok"
|
assert response.json()["status"] == "ok"
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import pytest
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_participants(client):
|
async def test_transcript_participants(authenticated_client, client):
|
||||||
response = await client.post("/transcripts", json={"name": "test"})
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["participants"] == []
|
assert response.json()["participants"] == []
|
||||||
@@ -39,7 +39,7 @@ async def test_transcript_participants(client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_participants_same_speaker(client):
|
async def test_transcript_participants_same_speaker(authenticated_client, client):
|
||||||
response = await client.post("/transcripts", json={"name": "test"})
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["participants"] == []
|
assert response.json()["participants"] == []
|
||||||
@@ -62,7 +62,7 @@ async def test_transcript_participants_same_speaker(client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_participants_update_name(client):
|
async def test_transcript_participants_update_name(authenticated_client, client):
|
||||||
response = await client.post("/transcripts", json={"name": "test"})
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["participants"] == []
|
assert response.json()["participants"] == []
|
||||||
@@ -100,7 +100,7 @@ async def test_transcript_participants_update_name(client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_participants_update_speaker(client):
|
async def test_transcript_participants_update_speaker(authenticated_client, client):
|
||||||
response = await client.post("/transcripts", json={"name": "test"})
|
response = await client.post("/transcripts", json={"name": "test"})
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
assert response.json()["participants"] == []
|
assert response.json()["participants"] == []
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ async def client(app_lifespan):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("setup_database")
|
||||||
@pytest.mark.usefixtures("celery_session_app")
|
@pytest.mark.usefixtures("celery_session_app")
|
||||||
@pytest.mark.usefixtures("celery_session_worker")
|
@pytest.mark.usefixtures("celery_session_worker")
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -2,84 +2,33 @@ from datetime import datetime, timezone
|
|||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import insert
|
|
||||||
|
|
||||||
from reflector.db.base import MeetingModel, RoomModel
|
from reflector.db.recordings import Recording, recordings_controller
|
||||||
from reflector.db.recordings import recordings_controller
|
|
||||||
from reflector.db.transcripts import SourceKind, transcripts_controller
|
from reflector.db.transcripts import SourceKind, transcripts_controller
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_recording_deleted_with_transcript(db_session):
|
async def test_recording_deleted_with_transcript():
|
||||||
"""Test that a recording is deleted when its associated transcript is deleted."""
|
|
||||||
# First create a room and meeting to satisfy foreign key constraints
|
|
||||||
room_id = "test-room"
|
|
||||||
await db_session.execute(
|
|
||||||
insert(RoomModel).values(
|
|
||||||
id=room_id,
|
|
||||||
name="test-room",
|
|
||||||
user_id="test-user",
|
|
||||||
created_at=datetime.now(timezone.utc),
|
|
||||||
zulip_auto_post=False,
|
|
||||||
zulip_stream="",
|
|
||||||
zulip_topic="",
|
|
||||||
is_locked=False,
|
|
||||||
room_mode="normal",
|
|
||||||
recording_type="cloud",
|
|
||||||
recording_trigger="automatic",
|
|
||||||
is_shared=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
meeting_id = "test-meeting"
|
|
||||||
await db_session.execute(
|
|
||||||
insert(MeetingModel).values(
|
|
||||||
id=meeting_id,
|
|
||||||
room_id=room_id,
|
|
||||||
room_name="test-room",
|
|
||||||
room_url="https://example.com/room",
|
|
||||||
host_room_url="https://example.com/room-host",
|
|
||||||
start_date=datetime.now(timezone.utc),
|
|
||||||
end_date=datetime.now(timezone.utc),
|
|
||||||
is_active=False,
|
|
||||||
num_clients=0,
|
|
||||||
is_locked=False,
|
|
||||||
room_mode="normal",
|
|
||||||
recording_type="cloud",
|
|
||||||
recording_trigger="automatic",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
# Now create a recording
|
|
||||||
recording = await recordings_controller.create(
|
recording = await recordings_controller.create(
|
||||||
db_session,
|
Recording(
|
||||||
meeting_id=meeting_id,
|
bucket_name="test-bucket",
|
||||||
url="https://example.com/recording.mp4",
|
object_key="recording.mp4",
|
||||||
object_key="recordings/test.mp4",
|
recorded_at=datetime.now(timezone.utc),
|
||||||
duration=3600.0,
|
)
|
||||||
created_at=datetime.now(timezone.utc),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a transcript associated with the recording
|
|
||||||
transcript = await transcripts_controller.add(
|
transcript = await transcripts_controller.add(
|
||||||
db_session,
|
|
||||||
name="Test Transcript",
|
name="Test Transcript",
|
||||||
source_kind=SourceKind.ROOM,
|
source_kind=SourceKind.ROOM,
|
||||||
recording_id=recording.id,
|
recording_id=recording.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Mock the storage deletion
|
|
||||||
with patch("reflector.db.transcripts.get_recordings_storage") as mock_get_storage:
|
with patch("reflector.db.transcripts.get_recordings_storage") as mock_get_storage:
|
||||||
storage_instance = mock_get_storage.return_value
|
storage_instance = mock_get_storage.return_value
|
||||||
storage_instance.delete_file = AsyncMock()
|
storage_instance.delete_file = AsyncMock()
|
||||||
|
|
||||||
# Delete the transcript
|
await transcripts_controller.remove_by_id(transcript.id)
|
||||||
await transcripts_controller.remove_by_id(db_session, transcript.id)
|
|
||||||
|
|
||||||
# Verify that the recording file was deleted from storage
|
|
||||||
storage_instance.delete_file.assert_awaited_once_with(recording.object_key)
|
storage_instance.delete_file.assert_awaited_once_with(recording.object_key)
|
||||||
|
|
||||||
# Verify both the recording and transcript are deleted
|
assert await recordings_controller.get_by_id(recording.id) is None
|
||||||
assert await recordings_controller.get_by_id(db_session, recording.id) is None
|
assert await transcripts_controller.get_by_id(transcript.id) is None
|
||||||
assert await transcripts_controller.get_by_id(db_session, transcript.id) is None
|
|
||||||
|
|||||||
@@ -49,12 +49,11 @@ class ThreadedUvicorn:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def appserver(tmpdir, database, celery_session_app, celery_session_worker):
|
def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker):
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
from reflector.app import app
|
from reflector.app import app
|
||||||
|
from reflector.db import get_database
|
||||||
# Database connection handled by SQLAlchemy engine
|
|
||||||
from reflector.settings import settings
|
from reflector.settings import settings
|
||||||
|
|
||||||
DATA_DIR = settings.DATA_DIR
|
DATA_DIR = settings.DATA_DIR
|
||||||
@@ -78,8 +77,13 @@ def appserver(tmpdir, database, celery_session_app, celery_session_worker):
|
|||||||
server_instance = Server(config)
|
server_instance = Server(config)
|
||||||
|
|
||||||
async def start_server():
|
async def start_server():
|
||||||
# Database connections managed by SQLAlchemy engine
|
# Initialize database connection in this event loop
|
||||||
|
database = get_database()
|
||||||
|
await database.connect()
|
||||||
|
try:
|
||||||
await server_instance.serve()
|
await server_instance.serve()
|
||||||
|
finally:
|
||||||
|
await database.disconnect()
|
||||||
|
|
||||||
# Signal that server is starting
|
# Signal that server is starting
|
||||||
server_started.set()
|
server_started.set()
|
||||||
@@ -111,6 +115,12 @@ def appserver(tmpdir, database, celery_session_app, celery_session_worker):
|
|||||||
settings.DATA_DIR = DATA_DIR
|
settings.DATA_DIR = DATA_DIR
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def celery_includes():
|
||||||
|
return ["reflector.pipelines.main_live_pipeline"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("setup_database")
|
||||||
@pytest.mark.usefixtures("celery_session_app")
|
@pytest.mark.usefixtures("celery_session_app")
|
||||||
@pytest.mark.usefixtures("celery_session_worker")
|
@pytest.mark.usefixtures("celery_session_worker")
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -158,7 +168,7 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Test websocket: EXCEPTION {e}")
|
print(f"Test websocket: EXCEPTION {e}")
|
||||||
finally:
|
finally:
|
||||||
await ws.close()
|
ws.close()
|
||||||
print("Test websocket: DISCONNECTED")
|
print("Test websocket: DISCONNECTED")
|
||||||
|
|
||||||
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
|
websocket_task = asyncio.get_event_loop().create_task(websocket_task())
|
||||||
@@ -275,6 +285,7 @@ async def test_transcript_rtc_and_websocket(
|
|||||||
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
|
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("setup_database")
|
||||||
@pytest.mark.usefixtures("celery_session_app")
|
@pytest.mark.usefixtures("celery_session_app")
|
||||||
@pytest.mark.usefixtures("celery_session_worker")
|
@pytest.mark.usefixtures("celery_session_worker")
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ import pytest
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_reassign_speaker(fake_transcript_with_topics, client):
|
async def test_transcript_reassign_speaker(
|
||||||
|
authenticated_client, fake_transcript_with_topics, client
|
||||||
|
):
|
||||||
transcript_id = fake_transcript_with_topics.id
|
transcript_id = fake_transcript_with_topics.id
|
||||||
|
|
||||||
# check the transcript exists
|
# check the transcript exists
|
||||||
@@ -114,7 +116,9 @@ async def test_transcript_reassign_speaker(fake_transcript_with_topics, client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_merge_speaker(fake_transcript_with_topics, client):
|
async def test_transcript_merge_speaker(
|
||||||
|
authenticated_client, fake_transcript_with_topics, client
|
||||||
|
):
|
||||||
transcript_id = fake_transcript_with_topics.id
|
transcript_id = fake_transcript_with_topics.id
|
||||||
|
|
||||||
# check the transcript exists
|
# check the transcript exists
|
||||||
@@ -181,7 +185,7 @@ async def test_transcript_merge_speaker(fake_transcript_with_topics, client):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_reassign_with_participant(
|
async def test_transcript_reassign_with_participant(
|
||||||
fake_transcript_with_topics, client
|
authenticated_client, fake_transcript_with_topics, client
|
||||||
):
|
):
|
||||||
transcript_id = fake_transcript_with_topics.id
|
transcript_id = fake_transcript_with_topics.id
|
||||||
|
|
||||||
@@ -347,7 +351,9 @@ async def test_transcript_reassign_with_participant(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_transcript_reassign_edge_cases(fake_transcript_with_topics, client):
|
async def test_transcript_reassign_edge_cases(
|
||||||
|
authenticated_client, fake_transcript_with_topics, client
|
||||||
|
):
|
||||||
transcript_id = fake_transcript_with_topics.id
|
transcript_id = fake_transcript_with_topics.id
|
||||||
|
|
||||||
# check the transcript exists
|
# check the transcript exists
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import time
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("setup_database")
|
||||||
@pytest.mark.usefixtures("celery_session_app")
|
@pytest.mark.usefixtures("celery_session_app")
|
||||||
@pytest.mark.usefixtures("celery_session_worker")
|
@pytest.mark.usefixtures("celery_session_worker")
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|||||||
70
server/tests/test_user_api_keys.py
Normal file
70
server/tests/test_user_api_keys.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from reflector.db.user_api_keys import user_api_keys_controller
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_api_key_creation_and_verification():
|
||||||
|
api_key_model, plaintext = await user_api_keys_controller.create_key(
|
||||||
|
user_id="test_user",
|
||||||
|
name="Test API Key",
|
||||||
|
)
|
||||||
|
|
||||||
|
verified = await user_api_keys_controller.verify_key(plaintext)
|
||||||
|
assert verified is not None
|
||||||
|
assert verified.user_id == "test_user"
|
||||||
|
assert verified.name == "Test API Key"
|
||||||
|
|
||||||
|
invalid = await user_api_keys_controller.verify_key("fake_key")
|
||||||
|
assert invalid is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_api_key_hashing():
|
||||||
|
_, plaintext = await user_api_keys_controller.create_key(
|
||||||
|
user_id="test_user_2",
|
||||||
|
)
|
||||||
|
|
||||||
|
api_keys = await user_api_keys_controller.list_by_user_id("test_user_2")
|
||||||
|
assert len(api_keys) == 1
|
||||||
|
assert api_keys[0].key_hash != plaintext
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generate_api_key_uniqueness():
|
||||||
|
key1 = user_api_keys_controller.generate_key()
|
||||||
|
key2 = user_api_keys_controller.generate_key()
|
||||||
|
assert key1 != key2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hash_api_key_deterministic():
|
||||||
|
key = "test_key_123"
|
||||||
|
hash1 = user_api_keys_controller.hash_key(key)
|
||||||
|
hash2 = user_api_keys_controller.hash_key(key)
|
||||||
|
assert hash1 == hash2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_user_id_empty():
|
||||||
|
api_keys = await user_api_keys_controller.list_by_user_id("nonexistent_user")
|
||||||
|
assert api_keys == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_by_user_id_multiple():
|
||||||
|
user_id = "multi_key_user"
|
||||||
|
|
||||||
|
_, plaintext1 = await user_api_keys_controller.create_key(
|
||||||
|
user_id=user_id,
|
||||||
|
name="API Key 1",
|
||||||
|
)
|
||||||
|
_, plaintext2 = await user_api_keys_controller.create_key(
|
||||||
|
user_id=user_id,
|
||||||
|
name="API Key 2",
|
||||||
|
)
|
||||||
|
|
||||||
|
api_keys = await user_api_keys_controller.list_by_user_id(user_id)
|
||||||
|
assert len(api_keys) == 2
|
||||||
|
names = {k.name for k in api_keys}
|
||||||
|
assert names == {"API Key 1", "API Key 2"}
|
||||||
156
server/tests/test_user_websocket_auth.py
Normal file
156
server/tests/test_user_websocket_auth.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx_ws import aconnect_ws
|
||||||
|
from uvicorn import Config, Server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def appserver_ws_user(setup_database):
|
||||||
|
from reflector.app import app
|
||||||
|
from reflector.db import get_database
|
||||||
|
|
||||||
|
host = "127.0.0.1"
|
||||||
|
port = 1257
|
||||||
|
server_started = threading.Event()
|
||||||
|
server_exception = None
|
||||||
|
server_instance = None
|
||||||
|
|
||||||
|
def run_server():
|
||||||
|
nonlocal server_exception, server_instance
|
||||||
|
try:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
|
||||||
|
config = Config(app=app, host=host, port=port, loop=loop)
|
||||||
|
server_instance = Server(config)
|
||||||
|
|
||||||
|
async def start_server():
|
||||||
|
database = get_database()
|
||||||
|
await database.connect()
|
||||||
|
try:
|
||||||
|
await server_instance.serve()
|
||||||
|
finally:
|
||||||
|
await database.disconnect()
|
||||||
|
|
||||||
|
server_started.set()
|
||||||
|
loop.run_until_complete(start_server())
|
||||||
|
except Exception as e:
|
||||||
|
server_exception = e
|
||||||
|
server_started.set()
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
server_thread = threading.Thread(target=run_server, daemon=True)
|
||||||
|
server_thread.start()
|
||||||
|
|
||||||
|
server_started.wait(timeout=30)
|
||||||
|
if server_exception:
|
||||||
|
raise server_exception
|
||||||
|
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
yield host, port
|
||||||
|
|
||||||
|
if server_instance:
|
||||||
|
server_instance.should_exit = True
|
||||||
|
server_thread.join(timeout=30)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def patch_jwt_verification(monkeypatch):
|
||||||
|
"""Patch JWT verification to accept HS256 tokens signed with SECRET_KEY for tests."""
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
def _verify_token(self, token: str):
|
||||||
|
# Do not validate audience in tests
|
||||||
|
return jwt.decode(token, settings.SECRET_KEY, algorithms=["HS256"]) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"reflector.auth.auth_jwt.JWTAuth.verify_token", _verify_token, raising=True
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_dummy_jwt(sub: str = "user123") -> str:
|
||||||
|
# Create a short HS256 JWT using the app secret to pass verification in tests
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
from reflector.settings import settings
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"sub": sub,
|
||||||
|
"email": f"{sub}@example.com",
|
||||||
|
"exp": datetime.now(timezone.utc) + timedelta(minutes=5),
|
||||||
|
}
|
||||||
|
# Note: production uses RS256 public key verification; tests can sign with SECRET_KEY
|
||||||
|
return jwt.encode(payload, settings.SECRET_KEY, algorithm="HS256")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_ws_rejects_missing_subprotocol(appserver_ws_user):
|
||||||
|
host, port = appserver_ws_user
|
||||||
|
base_ws = f"http://{host}:{port}/v1/events"
|
||||||
|
# No subprotocol/header with token
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
async with aconnect_ws(base_ws) as ws: # type: ignore
|
||||||
|
# Should close during handshake; if not, close explicitly
|
||||||
|
await ws.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_ws_rejects_invalid_token(appserver_ws_user):
|
||||||
|
host, port = appserver_ws_user
|
||||||
|
base_ws = f"http://{host}:{port}/v1/events"
|
||||||
|
|
||||||
|
# Send wrong token via WebSocket subprotocols
|
||||||
|
protocols = ["bearer", "totally-invalid-token"]
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
async with aconnect_ws(base_ws, subprotocols=protocols) as ws: # type: ignore
|
||||||
|
await ws.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_ws_accepts_valid_token_and_receives_events(appserver_ws_user):
|
||||||
|
host, port = appserver_ws_user
|
||||||
|
base_ws = f"http://{host}:{port}/v1/events"
|
||||||
|
|
||||||
|
token = _make_dummy_jwt("user-abc")
|
||||||
|
subprotocols = ["bearer", token]
|
||||||
|
|
||||||
|
# Connect and then trigger an event via HTTP create
|
||||||
|
async with aconnect_ws(base_ws, subprotocols=subprotocols) as ws:
|
||||||
|
# Emit an event to the user's room via a standard HTTP action
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
from reflector.app import app
|
||||||
|
from reflector.auth import current_user, current_user_optional
|
||||||
|
|
||||||
|
# Override auth dependencies so HTTP request is performed as the same user
|
||||||
|
app.dependency_overrides[current_user] = lambda: {
|
||||||
|
"sub": "user-abc",
|
||||||
|
"email": "user-abc@example.com",
|
||||||
|
}
|
||||||
|
app.dependency_overrides[current_user_optional] = lambda: {
|
||||||
|
"sub": "user-abc",
|
||||||
|
"email": "user-abc@example.com",
|
||||||
|
}
|
||||||
|
|
||||||
|
async with AsyncClient(app=app, base_url=f"http://{host}:{port}/v1") as ac:
|
||||||
|
# Create a transcript as this user so that the server publishes TRANSCRIPT_CREATED to user room
|
||||||
|
resp = await ac.post("/transcripts", json={"name": "WS Test"})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
# Receive the published event
|
||||||
|
msg = await ws.receive_json()
|
||||||
|
assert msg["event"] == "TRANSCRIPT_CREATED"
|
||||||
|
assert "id" in msg["data"]
|
||||||
|
|
||||||
|
# Clean overrides
|
||||||
|
del app.dependency_overrides[current_user]
|
||||||
|
del app.dependency_overrides[current_user_optional]
|
||||||
@@ -1,14 +1,13 @@
|
|||||||
"""Integration tests for WebVTT auto-update functionality in Transcript model."""
|
"""Integration tests for WebVTT auto-update functionality in Transcript model."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
from reflector.db.base import TranscriptModel
|
from reflector.db import get_database
|
||||||
from reflector.db.transcripts import (
|
from reflector.db.transcripts import (
|
||||||
SourceKind,
|
SourceKind,
|
||||||
TranscriptController,
|
TranscriptController,
|
||||||
TranscriptTopic,
|
TranscriptTopic,
|
||||||
transcripts_controller,
|
transcripts,
|
||||||
)
|
)
|
||||||
from reflector.processors.types import Word
|
from reflector.processors.types import Word
|
||||||
|
|
||||||
@@ -17,35 +16,30 @@ from reflector.processors.types import Word
|
|||||||
class TestWebVTTAutoUpdate:
|
class TestWebVTTAutoUpdate:
|
||||||
"""Test that WebVTT field auto-updates when Transcript is created or modified."""
|
"""Test that WebVTT field auto-updates when Transcript is created or modified."""
|
||||||
|
|
||||||
async def test_webvtt_not_updated_on_transcript_creation_without_topics(
|
async def test_webvtt_not_updated_on_transcript_creation_without_topics(self):
|
||||||
self, db_session
|
|
||||||
):
|
|
||||||
"""WebVTT should be None when creating transcript without topics."""
|
"""WebVTT should be None when creating transcript without topics."""
|
||||||
# Using global transcripts_controller
|
controller = TranscriptController()
|
||||||
|
|
||||||
transcript = await transcripts_controller.add(
|
transcript = await controller.add(
|
||||||
db_session,
|
|
||||||
name="Test Transcript",
|
name="Test Transcript",
|
||||||
source_kind=SourceKind.FILE,
|
source_kind=SourceKind.FILE,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await db_session.execute(
|
result = await get_database().fetch_one(
|
||||||
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
|
transcripts.select().where(transcripts.c.id == transcript.id)
|
||||||
)
|
)
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
assert row is not None
|
assert result is not None
|
||||||
assert row.webvtt is None
|
assert result["webvtt"] is None
|
||||||
finally:
|
finally:
|
||||||
await transcripts_controller.remove_by_id(db_session, transcript.id)
|
await controller.remove_by_id(transcript.id)
|
||||||
|
|
||||||
async def test_webvtt_updated_on_upsert_topic(self, db_session):
|
async def test_webvtt_updated_on_upsert_topic(self):
|
||||||
"""WebVTT should update when upserting topics via upsert_topic method."""
|
"""WebVTT should update when upserting topics via upsert_topic method."""
|
||||||
# Using global transcripts_controller
|
controller = TranscriptController()
|
||||||
|
|
||||||
transcript = await transcripts_controller.add(
|
transcript = await controller.add(
|
||||||
db_session,
|
|
||||||
name="Test Transcript",
|
name="Test Transcript",
|
||||||
source_kind=SourceKind.FILE,
|
source_kind=SourceKind.FILE,
|
||||||
)
|
)
|
||||||
@@ -62,15 +56,14 @@ class TestWebVTTAutoUpdate:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
await transcripts_controller.upsert_topic(db_session, transcript, topic)
|
await controller.upsert_topic(transcript, topic)
|
||||||
|
|
||||||
result = await db_session.execute(
|
result = await get_database().fetch_one(
|
||||||
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
|
transcripts.select().where(transcripts.c.id == transcript.id)
|
||||||
)
|
)
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
assert row is not None
|
assert result is not None
|
||||||
webvtt = row.webvtt
|
webvtt = result["webvtt"]
|
||||||
|
|
||||||
assert webvtt is not None
|
assert webvtt is not None
|
||||||
assert "WEBVTT" in webvtt
|
assert "WEBVTT" in webvtt
|
||||||
@@ -78,14 +71,13 @@ class TestWebVTTAutoUpdate:
|
|||||||
assert "<v Speaker0>" in webvtt
|
assert "<v Speaker0>" in webvtt
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await transcripts_controller.remove_by_id(db_session, transcript.id)
|
await controller.remove_by_id(transcript.id)
|
||||||
|
|
||||||
async def test_webvtt_updated_on_direct_topics_update(self, db_session):
|
async def test_webvtt_updated_on_direct_topics_update(self):
|
||||||
"""WebVTT should update when updating topics field directly."""
|
"""WebVTT should update when updating topics field directly."""
|
||||||
# Using global transcripts_controller
|
controller = TranscriptController()
|
||||||
|
|
||||||
transcript = await transcripts_controller.add(
|
transcript = await controller.add(
|
||||||
db_session,
|
|
||||||
name="Test Transcript",
|
name="Test Transcript",
|
||||||
source_kind=SourceKind.FILE,
|
source_kind=SourceKind.FILE,
|
||||||
)
|
)
|
||||||
@@ -104,32 +96,28 @@ class TestWebVTTAutoUpdate:
|
|||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
await transcripts_controller.update(
|
await controller.update(transcript, {"topics": topics_data})
|
||||||
db_session, transcript, {"topics": topics_data}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fetch from DB
|
# Fetch from DB
|
||||||
result = await db_session.execute(
|
result = await get_database().fetch_one(
|
||||||
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
|
transcripts.select().where(transcripts.c.id == transcript.id)
|
||||||
)
|
)
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
assert row is not None
|
assert result is not None
|
||||||
webvtt = row.webvtt
|
webvtt = result["webvtt"]
|
||||||
|
|
||||||
assert webvtt is not None
|
assert webvtt is not None
|
||||||
assert "WEBVTT" in webvtt
|
assert "WEBVTT" in webvtt
|
||||||
assert "First sentence" in webvtt
|
assert "First sentence" in webvtt
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await transcripts_controller.remove_by_id(db_session, transcript.id)
|
await controller.remove_by_id(transcript.id)
|
||||||
|
|
||||||
async def test_webvtt_updated_manually_with_handle_topics_update(self, db_session):
|
async def test_webvtt_updated_manually_with_handle_topics_update(self):
|
||||||
"""Test that _handle_topics_update works when called manually."""
|
"""Test that _handle_topics_update works when called manually."""
|
||||||
# Using global transcripts_controller
|
controller = TranscriptController()
|
||||||
|
|
||||||
transcript = await transcripts_controller.add(
|
transcript = await controller.add(
|
||||||
db_session,
|
|
||||||
name="Test Transcript",
|
name="Test Transcript",
|
||||||
source_kind=SourceKind.FILE,
|
source_kind=SourceKind.FILE,
|
||||||
)
|
)
|
||||||
@@ -150,16 +138,15 @@ class TestWebVTTAutoUpdate:
|
|||||||
|
|
||||||
values = {"topics": transcript.topics_dump()}
|
values = {"topics": transcript.topics_dump()}
|
||||||
|
|
||||||
await transcripts_controller.update(db_session, transcript, values)
|
await controller.update(transcript, values)
|
||||||
|
|
||||||
# Fetch from DB
|
# Fetch from DB
|
||||||
result = await db_session.execute(
|
result = await get_database().fetch_one(
|
||||||
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
|
transcripts.select().where(transcripts.c.id == transcript.id)
|
||||||
)
|
)
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
assert row is not None
|
assert result is not None
|
||||||
webvtt = row.webvtt
|
webvtt = result["webvtt"]
|
||||||
|
|
||||||
assert webvtt is not None
|
assert webvtt is not None
|
||||||
assert "WEBVTT" in webvtt
|
assert "WEBVTT" in webvtt
|
||||||
@@ -167,14 +154,13 @@ class TestWebVTTAutoUpdate:
|
|||||||
assert "<v Speaker0>" in webvtt
|
assert "<v Speaker0>" in webvtt
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await transcripts_controller.remove_by_id(db_session, transcript.id)
|
await controller.remove_by_id(transcript.id)
|
||||||
|
|
||||||
async def test_webvtt_update_with_non_sequential_topics_fails(self, db_session):
|
async def test_webvtt_update_with_non_sequential_topics_fails(self):
|
||||||
"""Test that non-sequential topics raise assertion error."""
|
"""Test that non-sequential topics raise assertion error."""
|
||||||
# Using global transcripts_controller
|
controller = TranscriptController()
|
||||||
|
|
||||||
transcript = await transcripts_controller.add(
|
transcript = await controller.add(
|
||||||
db_session,
|
|
||||||
name="Test Transcript",
|
name="Test Transcript",
|
||||||
source_kind=SourceKind.FILE,
|
source_kind=SourceKind.FILE,
|
||||||
)
|
)
|
||||||
@@ -200,14 +186,13 @@ class TestWebVTTAutoUpdate:
|
|||||||
assert "Words are not in sequence" in str(exc_info.value)
|
assert "Words are not in sequence" in str(exc_info.value)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await transcripts_controller.remove_by_id(db_session, transcript.id)
|
await controller.remove_by_id(transcript.id)
|
||||||
|
|
||||||
async def test_multiple_speakers_in_webvtt(self, db_session):
|
async def test_multiple_speakers_in_webvtt(self):
|
||||||
"""Test WebVTT generation with multiple speakers."""
|
"""Test WebVTT generation with multiple speakers."""
|
||||||
# Using global transcripts_controller
|
controller = TranscriptController()
|
||||||
|
|
||||||
transcript = await transcripts_controller.add(
|
transcript = await controller.add(
|
||||||
db_session,
|
|
||||||
name="Test Transcript",
|
name="Test Transcript",
|
||||||
source_kind=SourceKind.FILE,
|
source_kind=SourceKind.FILE,
|
||||||
)
|
)
|
||||||
@@ -228,16 +213,15 @@ class TestWebVTTAutoUpdate:
|
|||||||
transcript.upsert_topic(topic)
|
transcript.upsert_topic(topic)
|
||||||
values = {"topics": transcript.topics_dump()}
|
values = {"topics": transcript.topics_dump()}
|
||||||
|
|
||||||
await transcripts_controller.update(db_session, transcript, values)
|
await controller.update(transcript, values)
|
||||||
|
|
||||||
# Fetch from DB
|
# Fetch from DB
|
||||||
result = await db_session.execute(
|
result = await get_database().fetch_one(
|
||||||
select(TranscriptModel).where(TranscriptModel.id == transcript.id)
|
transcripts.select().where(transcripts.c.id == transcript.id)
|
||||||
)
|
)
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
assert row is not None
|
assert result is not None
|
||||||
webvtt = row.webvtt
|
webvtt = result["webvtt"]
|
||||||
|
|
||||||
assert webvtt is not None
|
assert webvtt is not None
|
||||||
assert "<v Speaker0>" in webvtt
|
assert "<v Speaker0>" in webvtt
|
||||||
@@ -247,4 +231,4 @@ class TestWebVTTAutoUpdate:
|
|||||||
assert "Goodbye" in webvtt
|
assert "Goodbye" in webvtt
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
await transcripts_controller.remove_by_id(db_session, transcript.id)
|
await controller.remove_by_id(transcript.id)
|
||||||
|
|||||||
3192
server/uv.lock
generated
3192
server/uv.lock
generated
File diff suppressed because it is too large
Load Diff
14
www/.dockerignore
Normal file
14
www/.dockerignore
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
.env.local
|
||||||
|
.env.development
|
||||||
|
.env.production
|
||||||
|
node_modules
|
||||||
|
.next
|
||||||
|
.git
|
||||||
|
.gitignore
|
||||||
|
*.md
|
||||||
|
.DS_Store
|
||||||
|
coverage
|
||||||
|
.pnpm-store
|
||||||
|
*.log
|
||||||
@@ -1,9 +1,5 @@
|
|||||||
# Environment
|
|
||||||
ENVIRONMENT=development
|
|
||||||
NEXT_PUBLIC_ENV=development
|
|
||||||
|
|
||||||
# Site Configuration
|
# Site Configuration
|
||||||
NEXT_PUBLIC_SITE_URL=http://localhost:3000
|
SITE_URL=http://localhost:3000
|
||||||
|
|
||||||
# Nextauth envs
|
# Nextauth envs
|
||||||
# not used in app code but in lib code
|
# not used in app code but in lib code
|
||||||
@@ -18,16 +14,16 @@ AUTHENTIK_CLIENT_ID=your-client-id-here
|
|||||||
AUTHENTIK_CLIENT_SECRET=your-client-secret-here
|
AUTHENTIK_CLIENT_SECRET=your-client-secret-here
|
||||||
|
|
||||||
# Feature Flags
|
# Feature Flags
|
||||||
# NEXT_PUBLIC_FEATURE_REQUIRE_LOGIN=true
|
# FEATURE_REQUIRE_LOGIN=true
|
||||||
# NEXT_PUBLIC_FEATURE_PRIVACY=false
|
# FEATURE_PRIVACY=false
|
||||||
# NEXT_PUBLIC_FEATURE_BROWSE=true
|
# FEATURE_BROWSE=true
|
||||||
# NEXT_PUBLIC_FEATURE_SEND_TO_ZULIP=true
|
# FEATURE_SEND_TO_ZULIP=true
|
||||||
# NEXT_PUBLIC_FEATURE_ROOMS=true
|
# FEATURE_ROOMS=true
|
||||||
|
|
||||||
# API URLs
|
# API URLs
|
||||||
NEXT_PUBLIC_API_URL=http://127.0.0.1:1250
|
API_URL=http://127.0.0.1:1250
|
||||||
NEXT_PUBLIC_WEBSOCKET_URL=ws://127.0.0.1:1250
|
WEBSOCKET_URL=ws://127.0.0.1:1250
|
||||||
NEXT_PUBLIC_AUTH_CALLBACK_URL=http://localhost:3000/auth-callback
|
AUTH_CALLBACK_URL=http://localhost:3000/auth-callback
|
||||||
|
|
||||||
# Sentry
|
# Sentry
|
||||||
# SENTRY_DSN=https://your-dsn@sentry.io/project-id
|
# SENTRY_DSN=https://your-dsn@sentry.io/project-id
|
||||||
|
|||||||
81
www/DOCKER_README.md
Normal file
81
www/DOCKER_README.md
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
# Docker Production Build Guide
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The Docker image builds without any environment variables and requires all configuration to be provided at runtime.
|
||||||
|
|
||||||
|
## Environment Variables (ALL Runtime)
|
||||||
|
|
||||||
|
### Required Runtime Variables
|
||||||
|
|
||||||
|
```bash
|
||||||
|
API_URL # Backend API URL (e.g., https://api.example.com)
|
||||||
|
WEBSOCKET_URL # WebSocket URL (e.g., wss://api.example.com)
|
||||||
|
NEXTAUTH_URL # NextAuth base URL (e.g., https://app.example.com)
|
||||||
|
NEXTAUTH_SECRET # Random secret for NextAuth (generate with: openssl rand -base64 32)
|
||||||
|
KV_URL # Redis URL (e.g., redis://redis:6379)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Optional Runtime Variables
|
||||||
|
|
||||||
|
```bash
|
||||||
|
SITE_URL # Frontend URL (defaults to NEXTAUTH_URL)
|
||||||
|
|
||||||
|
AUTHENTIK_ISSUER # OAuth issuer URL
|
||||||
|
AUTHENTIK_CLIENT_ID # OAuth client ID
|
||||||
|
AUTHENTIK_CLIENT_SECRET # OAuth client secret
|
||||||
|
AUTHENTIK_REFRESH_TOKEN_URL # OAuth token refresh URL
|
||||||
|
|
||||||
|
FEATURE_REQUIRE_LOGIN=false # Require authentication
|
||||||
|
FEATURE_PRIVACY=true # Enable privacy features
|
||||||
|
FEATURE_BROWSE=true # Enable browsing features
|
||||||
|
FEATURE_SEND_TO_ZULIP=false # Enable Zulip integration
|
||||||
|
FEATURE_ROOMS=true # Enable rooms feature
|
||||||
|
|
||||||
|
SENTRY_DSN # Sentry error tracking
|
||||||
|
AUTH_CALLBACK_URL # OAuth callback URL
|
||||||
|
```
|
||||||
|
|
||||||
|
## Building the Image
|
||||||
|
|
||||||
|
### Option 1: Using Docker Compose
|
||||||
|
|
||||||
|
1. Build the image (no environment variables needed):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose -f docker-compose.prod.yml build
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Create a `.env` file with runtime variables
|
||||||
|
|
||||||
|
3. Run with environment variables:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose -f docker-compose.prod.yml --env-file .env up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
### Option 2: Using Docker CLI
|
||||||
|
|
||||||
|
1. Build the image (no build args):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker build -t reflector-frontend:latest ./www
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Run with environment variables:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker run -d \
|
||||||
|
-p 3000:3000 \
|
||||||
|
-e API_URL=https://api.example.com \
|
||||||
|
-e WEBSOCKET_URL=wss://api.example.com \
|
||||||
|
-e NEXTAUTH_URL=https://app.example.com \
|
||||||
|
-e NEXTAUTH_SECRET=your-secret \
|
||||||
|
-e KV_URL=redis://redis:6379 \
|
||||||
|
-e AUTHENTIK_ISSUER=https://auth.example.com/application/o/reflector \
|
||||||
|
-e AUTHENTIK_CLIENT_ID=your-client-id \
|
||||||
|
-e AUTHENTIK_CLIENT_SECRET=your-client-secret \
|
||||||
|
-e AUTHENTIK_REFRESH_TOKEN_URL=https://auth.example.com/application/o/token/ \
|
||||||
|
-e FEATURE_REQUIRE_LOGIN=true \
|
||||||
|
reflector-frontend:latest
|
||||||
|
```
|
||||||
@@ -24,7 +24,8 @@ COPY --link . .
|
|||||||
ENV NEXT_TELEMETRY_DISABLED 1
|
ENV NEXT_TELEMETRY_DISABLED 1
|
||||||
|
|
||||||
# If using npm comment out above and use below instead
|
# If using npm comment out above and use below instead
|
||||||
RUN pnpm build
|
# next.js has the feature of excluding build step planned https://github.com/vercel/next.js/discussions/46544
|
||||||
|
RUN pnpm build-production
|
||||||
# RUN npm run build
|
# RUN npm run build
|
||||||
|
|
||||||
# Production image, copy all the files and run next
|
# Production image, copy all the files and run next
|
||||||
@@ -51,6 +52,10 @@ USER nextjs
|
|||||||
EXPOSE 3000
|
EXPOSE 3000
|
||||||
|
|
||||||
ENV PORT 3000
|
ENV PORT 3000
|
||||||
ENV HOSTNAME localhost
|
ENV HOSTNAME 0.0.0.0
|
||||||
|
|
||||||
|
HEALTHCHECK --interval=30s --timeout=3s --start-period=40s --retries=3 \
|
||||||
|
CMD wget --no-verbose --tries=1 --spider http://127.0.0.1:3000/api/health \
|
||||||
|
|| exit 1
|
||||||
|
|
||||||
CMD ["node", "server.js"]
|
CMD ["node", "server.js"]
|
||||||
|
|||||||
@@ -200,7 +200,13 @@ export default function ICSSettings({
|
|||||||
<HStack gap={0} position="relative" width="100%">
|
<HStack gap={0} position="relative" width="100%">
|
||||||
<Input
|
<Input
|
||||||
ref={roomUrlInputRef}
|
ref={roomUrlInputRef}
|
||||||
value={roomAbsoluteUrl(parseNonEmptyString(roomName))}
|
value={roomAbsoluteUrl(
|
||||||
|
parseNonEmptyString(
|
||||||
|
roomName,
|
||||||
|
true,
|
||||||
|
"panic! roomName is required",
|
||||||
|
),
|
||||||
|
)}
|
||||||
readOnly
|
readOnly
|
||||||
onClick={handleRoomUrlClick}
|
onClick={handleRoomUrlClick}
|
||||||
cursor="pointer"
|
cursor="pointer"
|
||||||
|
|||||||
@@ -274,15 +274,31 @@ export function RoomTable({
|
|||||||
<IconButton
|
<IconButton
|
||||||
aria-label="Force sync calendar"
|
aria-label="Force sync calendar"
|
||||||
onClick={() =>
|
onClick={() =>
|
||||||
handleForceSync(parseNonEmptyString(room.name))
|
handleForceSync(
|
||||||
|
parseNonEmptyString(
|
||||||
|
room.name,
|
||||||
|
true,
|
||||||
|
"panic! room.name is required",
|
||||||
|
),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
size="sm"
|
size="sm"
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
disabled={syncingRooms.has(
|
disabled={syncingRooms.has(
|
||||||
parseNonEmptyString(room.name),
|
parseNonEmptyString(
|
||||||
|
room.name,
|
||||||
|
true,
|
||||||
|
"panic! room.name is required",
|
||||||
|
),
|
||||||
)}
|
)}
|
||||||
>
|
>
|
||||||
{syncingRooms.has(parseNonEmptyString(room.name)) ? (
|
{syncingRooms.has(
|
||||||
|
parseNonEmptyString(
|
||||||
|
room.name,
|
||||||
|
true,
|
||||||
|
"panic! room.name is required",
|
||||||
|
),
|
||||||
|
) ? (
|
||||||
<Spinner size="sm" />
|
<Spinner size="sm" />
|
||||||
) : (
|
) : (
|
||||||
<CalendarSyncIcon />
|
<CalendarSyncIcon />
|
||||||
@@ -297,7 +313,13 @@ export function RoomTable({
|
|||||||
<IconButton
|
<IconButton
|
||||||
aria-label="Copy URL"
|
aria-label="Copy URL"
|
||||||
onClick={() =>
|
onClick={() =>
|
||||||
onCopyUrl(parseNonEmptyString(room.name))
|
onCopyUrl(
|
||||||
|
parseNonEmptyString(
|
||||||
|
room.name,
|
||||||
|
true,
|
||||||
|
"panic! room.name is required",
|
||||||
|
),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
size="sm"
|
size="sm"
|
||||||
variant="ghost"
|
variant="ghost"
|
||||||
|
|||||||
@@ -833,7 +833,13 @@ export default function RoomsList() {
|
|||||||
<Field.Root>
|
<Field.Root>
|
||||||
<ICSSettings
|
<ICSSettings
|
||||||
roomName={
|
roomName={
|
||||||
room.name ? parseNonEmptyString(room.name) : null
|
room.name
|
||||||
|
? parseNonEmptyString(
|
||||||
|
room.name,
|
||||||
|
true,
|
||||||
|
"panic! room.name required",
|
||||||
|
)
|
||||||
|
: null
|
||||||
}
|
}
|
||||||
icsUrl={room.icsUrl}
|
icsUrl={room.icsUrl}
|
||||||
icsEnabled={room.icsEnabled}
|
icsEnabled={room.icsEnabled}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { useEffect, useRef, useState } from "react";
|
import { useEffect, useState } from "react";
|
||||||
import React from "react";
|
import React from "react";
|
||||||
import Markdown from "react-markdown";
|
import Markdown from "react-markdown";
|
||||||
import "../../../styles/markdown.css";
|
import "../../../styles/markdown.css";
|
||||||
@@ -16,17 +16,15 @@ import {
|
|||||||
} from "@chakra-ui/react";
|
} from "@chakra-ui/react";
|
||||||
import { LuPen } from "react-icons/lu";
|
import { LuPen } from "react-icons/lu";
|
||||||
import { useError } from "../../../(errors)/errorContext";
|
import { useError } from "../../../(errors)/errorContext";
|
||||||
import ShareAndPrivacy from "../shareAndPrivacy";
|
|
||||||
|
|
||||||
type FinalSummaryProps = {
|
type FinalSummaryProps = {
|
||||||
transcriptResponse: GetTranscript;
|
transcript: GetTranscript;
|
||||||
topicsResponse: GetTranscriptTopic[];
|
topics: GetTranscriptTopic[];
|
||||||
onUpdate?: (newSummary) => void;
|
onUpdate: (newSummary: string) => void;
|
||||||
|
finalSummaryRef: React.Dispatch<React.SetStateAction<HTMLDivElement | null>>;
|
||||||
};
|
};
|
||||||
|
|
||||||
export default function FinalSummary(props: FinalSummaryProps) {
|
export default function FinalSummary(props: FinalSummaryProps) {
|
||||||
const finalSummaryRef = useRef<HTMLParagraphElement>(null);
|
|
||||||
|
|
||||||
const [isEditMode, setIsEditMode] = useState(false);
|
const [isEditMode, setIsEditMode] = useState(false);
|
||||||
const [preEditSummary, setPreEditSummary] = useState("");
|
const [preEditSummary, setPreEditSummary] = useState("");
|
||||||
const [editedSummary, setEditedSummary] = useState("");
|
const [editedSummary, setEditedSummary] = useState("");
|
||||||
@@ -35,10 +33,10 @@ export default function FinalSummary(props: FinalSummaryProps) {
|
|||||||
const updateTranscriptMutation = useTranscriptUpdate();
|
const updateTranscriptMutation = useTranscriptUpdate();
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setEditedSummary(props.transcriptResponse?.long_summary || "");
|
setEditedSummary(props.transcript?.long_summary || "");
|
||||||
}, [props.transcriptResponse?.long_summary]);
|
}, [props.transcript?.long_summary]);
|
||||||
|
|
||||||
if (!props.topicsResponse || !props.transcriptResponse) {
|
if (!props.topics || !props.transcript) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,9 +52,7 @@ export default function FinalSummary(props: FinalSummaryProps) {
|
|||||||
long_summary: newSummary,
|
long_summary: newSummary,
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
if (props.onUpdate) {
|
|
||||||
props.onUpdate(newSummary);
|
props.onUpdate(newSummary);
|
||||||
}
|
|
||||||
console.log("Updated long summary:", updatedTranscript);
|
console.log("Updated long summary:", updatedTranscript);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error("Failed to update long summary:", err);
|
console.error("Failed to update long summary:", err);
|
||||||
@@ -75,7 +71,7 @@ export default function FinalSummary(props: FinalSummaryProps) {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const onSaveClick = () => {
|
const onSaveClick = () => {
|
||||||
updateSummary(editedSummary, props.transcriptResponse.id);
|
updateSummary(editedSummary, props.transcript.id);
|
||||||
setIsEditMode(false);
|
setIsEditMode(false);
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -133,11 +129,6 @@ export default function FinalSummary(props: FinalSummaryProps) {
|
|||||||
>
|
>
|
||||||
<LuPen />
|
<LuPen />
|
||||||
</IconButton>
|
</IconButton>
|
||||||
<ShareAndPrivacy
|
|
||||||
finalSummaryRef={finalSummaryRef}
|
|
||||||
transcriptResponse={props.transcriptResponse}
|
|
||||||
topicsResponse={props.topicsResponse}
|
|
||||||
/>
|
|
||||||
</>
|
</>
|
||||||
)}
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
@@ -153,7 +144,7 @@ export default function FinalSummary(props: FinalSummaryProps) {
|
|||||||
mt={2}
|
mt={2}
|
||||||
/>
|
/>
|
||||||
) : (
|
) : (
|
||||||
<div ref={finalSummaryRef} className="markdown">
|
<div ref={props.finalSummaryRef} className="markdown">
|
||||||
<Markdown>{editedSummary}</Markdown>
|
<Markdown>{editedSummary}</Markdown>
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|||||||
@@ -41,6 +41,8 @@ export default function TranscriptDetails(details: TranscriptDetails) {
|
|||||||
waiting || mp3.audioDeleted === true,
|
waiting || mp3.audioDeleted === true,
|
||||||
);
|
);
|
||||||
const useActiveTopic = useState<Topic | null>(null);
|
const useActiveTopic = useState<Topic | null>(null);
|
||||||
|
const [finalSummaryElement, setFinalSummaryElement] =
|
||||||
|
useState<HTMLDivElement | null>(null);
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
if (waiting) {
|
if (waiting) {
|
||||||
@@ -124,9 +126,12 @@ export default function TranscriptDetails(details: TranscriptDetails) {
|
|||||||
<TranscriptTitle
|
<TranscriptTitle
|
||||||
title={transcript.data?.title || "Unnamed Transcript"}
|
title={transcript.data?.title || "Unnamed Transcript"}
|
||||||
transcriptId={transcriptId}
|
transcriptId={transcriptId}
|
||||||
onUpdate={(newTitle) => {
|
onUpdate={() => {
|
||||||
transcript.refetch().then(() => {});
|
transcript.refetch().then(() => {});
|
||||||
}}
|
}}
|
||||||
|
transcript={transcript.data || null}
|
||||||
|
topics={topics.topics}
|
||||||
|
finalSummaryElement={finalSummaryElement}
|
||||||
/>
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
{mp3.audioDeleted && (
|
{mp3.audioDeleted && (
|
||||||
@@ -148,11 +153,12 @@ export default function TranscriptDetails(details: TranscriptDetails) {
|
|||||||
{transcript.data && topics.topics ? (
|
{transcript.data && topics.topics ? (
|
||||||
<>
|
<>
|
||||||
<FinalSummary
|
<FinalSummary
|
||||||
transcriptResponse={transcript.data}
|
transcript={transcript.data}
|
||||||
topicsResponse={topics.topics}
|
topics={topics.topics}
|
||||||
onUpdate={() => {
|
onUpdate={() => {
|
||||||
transcript.refetch();
|
transcript.refetch().then(() => {});
|
||||||
}}
|
}}
|
||||||
|
finalSummaryRef={setFinalSummaryElement}
|
||||||
/>
|
/>
|
||||||
</>
|
</>
|
||||||
) : (
|
) : (
|
||||||
|
|||||||
@@ -26,9 +26,9 @@ import { useAuth } from "../../lib/AuthProvider";
|
|||||||
import { featureEnabled } from "../../lib/features";
|
import { featureEnabled } from "../../lib/features";
|
||||||
|
|
||||||
type ShareAndPrivacyProps = {
|
type ShareAndPrivacyProps = {
|
||||||
finalSummaryRef: any;
|
finalSummaryElement: HTMLDivElement | null;
|
||||||
transcriptResponse: GetTranscript;
|
transcript: GetTranscript;
|
||||||
topicsResponse: GetTranscriptTopic[];
|
topics: GetTranscriptTopic[];
|
||||||
};
|
};
|
||||||
|
|
||||||
type ShareOption = { value: ShareMode; label: string };
|
type ShareOption = { value: ShareMode; label: string };
|
||||||
@@ -48,7 +48,7 @@ export default function ShareAndPrivacy(props: ShareAndPrivacyProps) {
|
|||||||
const [isOwner, setIsOwner] = useState(false);
|
const [isOwner, setIsOwner] = useState(false);
|
||||||
const [shareMode, setShareMode] = useState<ShareOption>(
|
const [shareMode, setShareMode] = useState<ShareOption>(
|
||||||
shareOptionsData.find(
|
shareOptionsData.find(
|
||||||
(option) => option.value === props.transcriptResponse.share_mode,
|
(option) => option.value === props.transcript.share_mode,
|
||||||
) || shareOptionsData[0],
|
) || shareOptionsData[0],
|
||||||
);
|
);
|
||||||
const [shareLoading, setShareLoading] = useState(false);
|
const [shareLoading, setShareLoading] = useState(false);
|
||||||
@@ -70,7 +70,7 @@ export default function ShareAndPrivacy(props: ShareAndPrivacyProps) {
|
|||||||
try {
|
try {
|
||||||
const updatedTranscript = await updateTranscriptMutation.mutateAsync({
|
const updatedTranscript = await updateTranscriptMutation.mutateAsync({
|
||||||
params: {
|
params: {
|
||||||
path: { transcript_id: props.transcriptResponse.id },
|
path: { transcript_id: props.transcript.id },
|
||||||
},
|
},
|
||||||
body: requestBody,
|
body: requestBody,
|
||||||
});
|
});
|
||||||
@@ -90,8 +90,8 @@ export default function ShareAndPrivacy(props: ShareAndPrivacyProps) {
|
|||||||
const userId = auth.status === "authenticated" ? auth.user?.id : null;
|
const userId = auth.status === "authenticated" ? auth.user?.id : null;
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
setIsOwner(!!(requireLogin && userId === props.transcriptResponse.user_id));
|
setIsOwner(!!(requireLogin && userId === props.transcript.user_id));
|
||||||
}, [userId, props.transcriptResponse.user_id]);
|
}, [userId, props.transcript.user_id]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
@@ -171,19 +171,19 @@ export default function ShareAndPrivacy(props: ShareAndPrivacyProps) {
|
|||||||
<Flex gap={2} mb={2}>
|
<Flex gap={2} mb={2}>
|
||||||
{requireLogin && (
|
{requireLogin && (
|
||||||
<ShareZulip
|
<ShareZulip
|
||||||
transcriptResponse={props.transcriptResponse}
|
transcript={props.transcript}
|
||||||
topicsResponse={props.topicsResponse}
|
topics={props.topics}
|
||||||
disabled={toShareMode(shareMode.value) === "private"}
|
disabled={toShareMode(shareMode.value) === "private"}
|
||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
<ShareCopy
|
<ShareCopy
|
||||||
finalSummaryRef={props.finalSummaryRef}
|
finalSummaryElement={props.finalSummaryElement}
|
||||||
transcriptResponse={props.transcriptResponse}
|
transcript={props.transcript}
|
||||||
topicsResponse={props.topicsResponse}
|
topics={props.topics}
|
||||||
/>
|
/>
|
||||||
</Flex>
|
</Flex>
|
||||||
|
|
||||||
<ShareLink transcriptId={props.transcriptResponse.id} />
|
<ShareLink transcriptId={props.transcript.id} />
|
||||||
</Dialog.Body>
|
</Dialog.Body>
|
||||||
</Dialog.Content>
|
</Dialog.Content>
|
||||||
</Dialog.Positioner>
|
</Dialog.Positioner>
|
||||||
|
|||||||
@@ -5,34 +5,35 @@ type GetTranscriptTopic = components["schemas"]["GetTranscriptTopic"];
|
|||||||
import { Button, BoxProps, Box } from "@chakra-ui/react";
|
import { Button, BoxProps, Box } from "@chakra-ui/react";
|
||||||
|
|
||||||
type ShareCopyProps = {
|
type ShareCopyProps = {
|
||||||
finalSummaryRef: any;
|
finalSummaryElement: HTMLDivElement | null;
|
||||||
transcriptResponse: GetTranscript;
|
transcript: GetTranscript;
|
||||||
topicsResponse: GetTranscriptTopic[];
|
topics: GetTranscriptTopic[];
|
||||||
};
|
};
|
||||||
|
|
||||||
export default function ShareCopy({
|
export default function ShareCopy({
|
||||||
finalSummaryRef,
|
finalSummaryElement,
|
||||||
transcriptResponse,
|
transcript,
|
||||||
topicsResponse,
|
topics,
|
||||||
...boxProps
|
...boxProps
|
||||||
}: ShareCopyProps & BoxProps) {
|
}: ShareCopyProps & BoxProps) {
|
||||||
const [isCopiedSummary, setIsCopiedSummary] = useState(false);
|
const [isCopiedSummary, setIsCopiedSummary] = useState(false);
|
||||||
const [isCopiedTranscript, setIsCopiedTranscript] = useState(false);
|
const [isCopiedTranscript, setIsCopiedTranscript] = useState(false);
|
||||||
|
|
||||||
const onCopySummaryClick = () => {
|
const onCopySummaryClick = () => {
|
||||||
let text_to_copy = finalSummaryRef.current?.innerText;
|
const text_to_copy = finalSummaryElement?.innerText;
|
||||||
|
|
||||||
text_to_copy &&
|
if (text_to_copy) {
|
||||||
navigator.clipboard.writeText(text_to_copy).then(() => {
|
navigator.clipboard.writeText(text_to_copy).then(() => {
|
||||||
setIsCopiedSummary(true);
|
setIsCopiedSummary(true);
|
||||||
// Reset the copied state after 2 seconds
|
// Reset the copied state after 2 seconds
|
||||||
setTimeout(() => setIsCopiedSummary(false), 2000);
|
setTimeout(() => setIsCopiedSummary(false), 2000);
|
||||||
});
|
});
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const onCopyTranscriptClick = () => {
|
const onCopyTranscriptClick = () => {
|
||||||
let text_to_copy =
|
let text_to_copy =
|
||||||
topicsResponse
|
topics
|
||||||
?.map((topic) => topic.transcript)
|
?.map((topic) => topic.transcript)
|
||||||
.join("\n\n")
|
.join("\n\n")
|
||||||
.replace(/ +/g, " ")
|
.replace(/ +/g, " ")
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ import {
|
|||||||
import { featureEnabled } from "../../lib/features";
|
import { featureEnabled } from "../../lib/features";
|
||||||
|
|
||||||
type ShareZulipProps = {
|
type ShareZulipProps = {
|
||||||
transcriptResponse: GetTranscript;
|
transcript: GetTranscript;
|
||||||
topicsResponse: GetTranscriptTopic[];
|
topics: GetTranscriptTopic[];
|
||||||
disabled: boolean;
|
disabled: boolean;
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -88,14 +88,14 @@ export default function ShareZulip(props: ShareZulipProps & BoxProps) {
|
|||||||
}, [stream, streams]);
|
}, [stream, streams]);
|
||||||
|
|
||||||
const handleSendToZulip = async () => {
|
const handleSendToZulip = async () => {
|
||||||
if (!props.transcriptResponse) return;
|
if (!props.transcript) return;
|
||||||
|
|
||||||
if (stream && topic) {
|
if (stream && topic) {
|
||||||
try {
|
try {
|
||||||
await postToZulipMutation.mutateAsync({
|
await postToZulipMutation.mutateAsync({
|
||||||
params: {
|
params: {
|
||||||
path: {
|
path: {
|
||||||
transcript_id: props.transcriptResponse.id,
|
transcript_id: props.transcript.id,
|
||||||
},
|
},
|
||||||
query: {
|
query: {
|
||||||
stream,
|
stream,
|
||||||
|
|||||||
@@ -2,14 +2,22 @@ import { useState } from "react";
|
|||||||
import type { components } from "../../reflector-api";
|
import type { components } from "../../reflector-api";
|
||||||
|
|
||||||
type UpdateTranscript = components["schemas"]["UpdateTranscript"];
|
type UpdateTranscript = components["schemas"]["UpdateTranscript"];
|
||||||
|
type GetTranscript = components["schemas"]["GetTranscript"];
|
||||||
|
type GetTranscriptTopic = components["schemas"]["GetTranscriptTopic"];
|
||||||
import { useTranscriptUpdate } from "../../lib/apiHooks";
|
import { useTranscriptUpdate } from "../../lib/apiHooks";
|
||||||
import { Heading, IconButton, Input, Flex, Spacer } from "@chakra-ui/react";
|
import { Heading, IconButton, Input, Flex, Spacer } from "@chakra-ui/react";
|
||||||
import { LuPen } from "react-icons/lu";
|
import { LuPen } from "react-icons/lu";
|
||||||
|
import ShareAndPrivacy from "./shareAndPrivacy";
|
||||||
|
|
||||||
type TranscriptTitle = {
|
type TranscriptTitle = {
|
||||||
title: string;
|
title: string;
|
||||||
transcriptId: string;
|
transcriptId: string;
|
||||||
onUpdate?: (newTitle: string) => void;
|
onUpdate: (newTitle: string) => void;
|
||||||
|
|
||||||
|
// share props
|
||||||
|
transcript: GetTranscript | null;
|
||||||
|
topics: GetTranscriptTopic[] | null;
|
||||||
|
finalSummaryElement: HTMLDivElement | null;
|
||||||
};
|
};
|
||||||
|
|
||||||
const TranscriptTitle = (props: TranscriptTitle) => {
|
const TranscriptTitle = (props: TranscriptTitle) => {
|
||||||
@@ -29,9 +37,7 @@ const TranscriptTitle = (props: TranscriptTitle) => {
|
|||||||
},
|
},
|
||||||
body: requestBody,
|
body: requestBody,
|
||||||
});
|
});
|
||||||
if (props.onUpdate) {
|
|
||||||
props.onUpdate(newTitle);
|
props.onUpdate(newTitle);
|
||||||
}
|
|
||||||
console.log("Updated transcript title:", newTitle);
|
console.log("Updated transcript title:", newTitle);
|
||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error("Failed to update transcript:", err);
|
console.error("Failed to update transcript:", err);
|
||||||
@@ -62,11 +68,11 @@ const TranscriptTitle = (props: TranscriptTitle) => {
|
|||||||
}
|
}
|
||||||
setIsEditing(false);
|
setIsEditing(false);
|
||||||
};
|
};
|
||||||
const handleChange = (e) => {
|
const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||||
setDisplayedTitle(e.target.value);
|
setDisplayedTitle(e.target.value);
|
||||||
};
|
};
|
||||||
|
|
||||||
const handleKeyDown = (e) => {
|
const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
|
||||||
if (e.key === "Enter") {
|
if (e.key === "Enter") {
|
||||||
updateTitle(displayedTitle, props.transcriptId);
|
updateTitle(displayedTitle, props.transcriptId);
|
||||||
setIsEditing(false);
|
setIsEditing(false);
|
||||||
@@ -111,6 +117,13 @@ const TranscriptTitle = (props: TranscriptTitle) => {
|
|||||||
>
|
>
|
||||||
<LuPen />
|
<LuPen />
|
||||||
</IconButton>
|
</IconButton>
|
||||||
|
{props.transcript && props.topics && (
|
||||||
|
<ShareAndPrivacy
|
||||||
|
finalSummaryElement={props.finalSummaryElement}
|
||||||
|
transcript={props.transcript}
|
||||||
|
topics={props.topics}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
</Flex>
|
</Flex>
|
||||||
)}
|
)}
|
||||||
</>
|
</>
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
|
|||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
document.onkeyup = (e) => {
|
document.onkeyup = (e) => {
|
||||||
if (e.key === "a" && process.env.NEXT_PUBLIC_ENV === "development") {
|
if (e.key === "a" && process.env.NODE_ENV === "development") {
|
||||||
const segments: GetTranscriptSegmentTopic[] = [
|
const segments: GetTranscriptSegmentTopic[] = [
|
||||||
{
|
{
|
||||||
speaker: 1,
|
speaker: 1,
|
||||||
@@ -201,7 +201,7 @@ export const useWebSockets = (transcriptId: string | null): UseWebSockets => {
|
|||||||
|
|
||||||
setFinalSummary({ summary: "This is the final summary" });
|
setFinalSummary({ summary: "This is the final summary" });
|
||||||
}
|
}
|
||||||
if (e.key === "z" && process.env.NEXT_PUBLIC_ENV === "development") {
|
if (e.key === "z" && process.env.NODE_ENV === "development") {
|
||||||
setTranscriptTextLive(
|
setTranscriptTextLive(
|
||||||
"This text is in English, and it is a pretty long sentence to test the limits",
|
"This text is in English, and it is a pretty long sentence to test the limits",
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -261,7 +261,11 @@ export default function Room(details: RoomDetails) {
|
|||||||
const params = use(details.params);
|
const params = use(details.params);
|
||||||
const wherebyLoaded = useWhereby();
|
const wherebyLoaded = useWhereby();
|
||||||
const wherebyRef = useRef<HTMLElement>(null);
|
const wherebyRef = useRef<HTMLElement>(null);
|
||||||
const roomName = parseNonEmptyString(params.roomName);
|
const roomName = parseNonEmptyString(
|
||||||
|
params.roomName,
|
||||||
|
true,
|
||||||
|
"panic! params.roomName is required",
|
||||||
|
);
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
const auth = useAuth();
|
const auth = useAuth();
|
||||||
const status = auth.status;
|
const status = auth.status;
|
||||||
@@ -308,7 +312,14 @@ export default function Room(details: RoomDetails) {
|
|||||||
|
|
||||||
const handleMeetingSelect = (selectedMeeting: Meeting) => {
|
const handleMeetingSelect = (selectedMeeting: Meeting) => {
|
||||||
router.push(
|
router.push(
|
||||||
roomMeetingUrl(roomName, parseNonEmptyString(selectedMeeting.id)),
|
roomMeetingUrl(
|
||||||
|
roomName,
|
||||||
|
parseNonEmptyString(
|
||||||
|
selectedMeeting.id,
|
||||||
|
true,
|
||||||
|
"panic! selectedMeeting.id is required",
|
||||||
|
),
|
||||||
|
),
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
38
www/app/api/health/route.ts
Normal file
38
www/app/api/health/route.ts
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import { NextResponse } from "next/server";
|
||||||
|
|
||||||
|
export async function GET() {
|
||||||
|
const health = {
|
||||||
|
status: "healthy",
|
||||||
|
timestamp: new Date().toISOString(),
|
||||||
|
uptime: process.uptime(),
|
||||||
|
environment: process.env.NODE_ENV,
|
||||||
|
checks: {
|
||||||
|
redis: await checkRedis(),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const allHealthy = Object.values(health.checks).every((check) => check);
|
||||||
|
|
||||||
|
return NextResponse.json(health, {
|
||||||
|
status: allHealthy ? 200 : 503,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async function checkRedis(): Promise<boolean> {
|
||||||
|
try {
|
||||||
|
if (!process.env.KV_URL) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
const { tokenCacheRedis } = await import("../../lib/redisClient");
|
||||||
|
const testKey = `health:check:${Date.now()}`;
|
||||||
|
await tokenCacheRedis.setex(testKey, 10, "OK");
|
||||||
|
const value = await tokenCacheRedis.get(testKey);
|
||||||
|
await tokenCacheRedis.del(testKey);
|
||||||
|
|
||||||
|
return value === "OK";
|
||||||
|
} catch (error) {
|
||||||
|
console.error("Redis health check failed:", error);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,7 +6,10 @@ import ErrorMessage from "./(errors)/errorMessage";
|
|||||||
import { RecordingConsentProvider } from "./recordingConsentContext";
|
import { RecordingConsentProvider } from "./recordingConsentContext";
|
||||||
import { ErrorBoundary } from "@sentry/nextjs";
|
import { ErrorBoundary } from "@sentry/nextjs";
|
||||||
import { Providers } from "./providers";
|
import { Providers } from "./providers";
|
||||||
import { assertExistsAndNonEmptyString } from "./lib/utils";
|
import { getNextEnvVar } from "./lib/nextBuild";
|
||||||
|
import { getClientEnv } from "./lib/clientEnv";
|
||||||
|
|
||||||
|
export const dynamic = "force-dynamic";
|
||||||
|
|
||||||
const poppins = Poppins({
|
const poppins = Poppins({
|
||||||
subsets: ["latin"],
|
subsets: ["latin"],
|
||||||
@@ -21,13 +24,11 @@ export const viewport: Viewport = {
|
|||||||
maximumScale: 1,
|
maximumScale: 1,
|
||||||
};
|
};
|
||||||
|
|
||||||
const NEXT_PUBLIC_SITE_URL = assertExistsAndNonEmptyString(
|
const SITE_URL = getNextEnvVar("SITE_URL");
|
||||||
process.env.NEXT_PUBLIC_SITE_URL,
|
const env = getClientEnv();
|
||||||
"NEXT_PUBLIC_SITE_URL required",
|
|
||||||
);
|
|
||||||
|
|
||||||
export const metadata: Metadata = {
|
export const metadata: Metadata = {
|
||||||
metadataBase: new URL(NEXT_PUBLIC_SITE_URL),
|
metadataBase: new URL(SITE_URL),
|
||||||
title: {
|
title: {
|
||||||
template: "%s – Reflector",
|
template: "%s – Reflector",
|
||||||
default: "Reflector - AI-Powered Meeting Transcriptions by Monadical",
|
default: "Reflector - AI-Powered Meeting Transcriptions by Monadical",
|
||||||
@@ -74,15 +75,16 @@ export default async function RootLayout({
|
|||||||
}) {
|
}) {
|
||||||
return (
|
return (
|
||||||
<html lang="en" className={poppins.className} suppressHydrationWarning>
|
<html lang="en" className={poppins.className} suppressHydrationWarning>
|
||||||
<body className={"h-[100svh] w-[100svw] overflow-x-hidden relative"}>
|
<body
|
||||||
<RecordingConsentProvider>
|
className={"h-[100svh] w-[100svw] overflow-x-hidden relative"}
|
||||||
|
data-env={JSON.stringify(env)}
|
||||||
|
>
|
||||||
<ErrorBoundary fallback={<p>"something went really wrong"</p>}>
|
<ErrorBoundary fallback={<p>"something went really wrong"</p>}>
|
||||||
<ErrorProvider>
|
<ErrorProvider>
|
||||||
<ErrorMessage />
|
<ErrorMessage />
|
||||||
<Providers>{children}</Providers>
|
<Providers>{children}</Providers>
|
||||||
</ErrorProvider>
|
</ErrorProvider>
|
||||||
</ErrorBoundary>
|
</ErrorBoundary>
|
||||||
</RecordingConsentProvider>
|
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>
|
||||||
);
|
);
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user