Compare commits

...

23 Commits

Author SHA1 Message Date
311d453e41 feat: implement frontend for calendar integration (Phase 3 & 4)
## Frontend Implementation

### Meeting Selection & Management
- Created MeetingSelection component for choosing between multiple active meetings
- Shows both active meetings and upcoming calendar events (30 min ahead)
- Displays meeting metadata with privacy controls (owner-only details)
- Supports creation of unscheduled meetings alongside calendar meetings

### Waiting Room
- Added waiting page for users joining before scheduled start time
- Shows countdown timer until meeting begins
- Auto-transitions to meeting when calendar event becomes active
- Handles early joining with proper routing

### Meeting Info Panel
- Created collapsible info panel showing meeting details
- Displays calendar metadata (title, description, attendees)
- Shows participant count and duration
- Privacy-aware: sensitive info only visible to room owners

### ICS Configuration UI
- Integrated ICS settings into room configuration dialog
- Test connection functionality with immediate feedback
- Manual sync trigger with detailed results
- Shows last sync time and ETag for monitoring
- Configurable sync intervals (1 min to 1 hour)

### Routing & Navigation
- New /room/{roomName} route for meeting selection
- Waiting room at /room/{roomName}/wait?eventId={id}
- Classic room page at /{roomName} with meeting info
- Uses sessionStorage to pass selected meeting between pages

### API Integration
- Added new endpoints for active/upcoming meetings
- Regenerated TypeScript client with latest OpenAPI spec
- Proper error handling and loading states
- Auto-refresh every 30 seconds for live updates

### UI/UX Improvements
- Color-coded badges for meeting status
- Attendee status indicators (accepted/declined/tentative)
- Responsive design with Chakra UI components
- Clear visual hierarchy between active and upcoming meetings
- Smart truncation for long attendee lists

This completes the frontend implementation for calendar integration,
enabling users to seamlessly join scheduled meetings from their
calendar applications.
2025-08-18 19:29:56 -06:00
f286f0882c feat: implement Phase 2 - Multiple active meetings per room with grace period
This commit adds support for multiple concurrent meetings per room, implementing
grace period logic and improved meeting lifecycle management for calendar integration.

## Database Changes
- Remove unique constraint preventing multiple active meetings per room
- Add last_participant_left_at field to track when meeting becomes empty
- Add grace_period_minutes field (default: 15) for configurable grace period

## Meeting Controller Enhancements
- Add get_all_active_for_room() to retrieve all active meetings for a room
- Add get_active_by_calendar_event() to find meetings by calendar event ID
- Maintain backward compatibility with existing get_active() method

## New API Endpoints
- GET /rooms/{room_name}/meetings/active - List all active meetings
- POST /rooms/{room_name}/meetings/{meeting_id}/join - Join specific meeting

## Meeting Lifecycle Improvements
- 15-minute grace period after last participant leaves
- Automatic reactivation when participant rejoins during grace period
- Force close calendar meetings 30 minutes after scheduled end time
- Update process_meetings task to handle multiple active meetings

## Whereby Integration
- Clear grace period when participants join via webhook events
- Track participant count for grace period management

## Testing
- Add comprehensive tests for multiple active meetings
- Test grace period behavior and participant rejoin scenarios
- Test calendar meeting force closure logic
- All 5 new tests passing

This enables proper calendar integration with overlapping meetings while
preventing accidental meeting closures through the grace period mechanism.
2025-08-18 19:03:41 -06:00
ffcafb3bf2 feat: add Celery background tasks for ICS sync 2025-08-18 17:22:41 -06:00
27075d840c feat: add ICS calendar API endpoints for room configuration and sync 2025-08-18 17:03:23 -06:00
30b5cd45e3 feat: calendar integration 2025-08-18 16:51:30 -06:00
2fccd81bcd fix: use structlog not logging (#550) 2025-08-15 15:41:23 -06:00
1311714451 ci: add pre-commit hook and fix linting issues (#545)
* style: deactivate PLC0415 only on part that it's ok

+ re-run pre-commit run --all

* ci: add pre-commit hook

* build: move from yarn to pnpm

* build: move from yarn to pnpm

* build: fix node-version

* ci: install pnpm prior node (?)

* build: update deps and pnpm trying to fix vercel build

* feat: docker www corepack

* style: pre-commit

---------

Co-authored-by: Igor Loskutov <igor.loskutoff@gmail.com>
2025-08-14 20:59:54 -06:00
b9d891d342 feat: delete recording with transcript (#547)
* Delete recording with transcript

* Delete confirmation dialog

* Use aws storage abstraction for recording deletion

* Test recording deleted with transcript

* Use get transcript storage

* Fix the test

* Add env vars for recording storage
2025-08-14 20:45:30 +02:00
9eab952c63 feat: postgresql migration and removal of sqlite in pytest (#546)
* feat: remove support of sqlite, 100% postgres

* fix: more migration and make datetime timezone aware in postgres

* fix: change how database is get, and use contextvar to have difference instance between different loops

* test: properly use client fixture that handle lifetime/database connection

* fix: add missing client fixture parameters to test functions

This commit fixes NameError issues where test functions were trying to use
the 'client' fixture but didn't have it as a parameter. The changes include:

1. Added 'client' parameter to test functions in:
   - test_transcripts_audio_download.py (6 functions including fixture)
   - test_transcripts_speaker.py (3 functions)
   - test_transcripts_upload.py (1 function)
   - test_transcripts_rtc_ws.py (2 functions + appserver fixture)

2. Resolved naming conflicts in test_transcripts_rtc_ws.py where both HTTP
   client and StreamClient were using variable name 'client'. StreamClient
   instances are now named 'stream_client' to avoid conflicts.

3. Added missing 'from reflector.app import app' import in rtc_ws tests.

Background: Previously implemented contextvars solution with get_database()
function resolves asyncio event loop conflicts in Celery tasks. The global
client fixture was also created to replace manual AsyncClient instances,
ensuring proper FastAPI application lifecycle management and database
connections during tests.

All tests now pass except for 2 pre-existing RTC WebSocket test failures
related to asyncpg connection issues unrelated to these fixes.

* fix: ensure task are correctly closed

* fix: make separate event loop for the live server

* fix: make default settings pointing at postgres

* build: remove pytest-docker deps out of dev, just tests group
2025-08-14 11:40:52 -06:00
Igor Loskutov
6fb5cb21c2 feat: search backend (#537)
* docs: transient docs

* chore: cleanup

* webvtt WIP

* webvtt field

* chore: webvtt tests comments

* chore: remove useless tests

* feat: search TASK.md

* feat: full text search by title/webvtt

* chore: search api task

* feat: search api

* feat: search API

* chore: rm task md

* chore: roll back unnecessary validators

* chore: pr review WIP

* chore: pr review WIP

* chore: pr review

* chore: top imports

* feat: better lint + ci

* feat: better lint + ci

* feat: better lint + ci

* feat: better lint + ci

* chore: lint

* chore: lint

* fix: db datetime definitions

* fix: flush() params

* fix: update transcript mutability expectation / test

* fix: update transcript mutability expectation / test

* chore: auto review

* chore: new controller extraction

* chore: new controller extraction

* chore: cleanup

* chore: review WIP

* chore: pr WIP

* chore: remove ci lint

* chore: openapi regeneration

* chore: openapi regeneration

* chore: postgres test doc

* fix: .dockerignore for arm binaries

* fix: .dockerignore for arm binaries

* fix: cap test loops

* fix: cap test loops

* fix: cap test loops

* fix: get_transcript_topics

* chore: remove flow.md docs and claude guidance

* chore: remove claude.md db doc

* chore: remove claude.md db doc

* chore: remove claude.md db doc

* chore: remove claude.md db doc
2025-08-13 10:03:38 -04:00
Igor Loskutov
a42ed12982 fix: evaluation cli event wrap (#536)
* fix: evaluation cli event wrap

* fix: evaluation cli event wrap

* chore: remove unrelated change

* chore: rollback claude.md changes
2025-08-11 19:28:52 -04:00
1aa52a99b6 chore(main): release 0.6.1 (#539) 2025-08-06 19:38:43 -06:00
dependabot[bot]
2a97290f2e build(deps): bump the npm_and_yarn group across 1 directory with 7 updates (#535)
Bumps the npm_and_yarn group with 6 updates in the /www directory:

| Package | From | To |
| --- | --- | --- |
| [axios](https://github.com/axios/axios) | `1.6.2` | `1.8.2` |
| [postcss](https://github.com/postcss/postcss) | `8.4.25` | `8.4.31` |
| [braces](https://github.com/micromatch/braces) | `3.0.2` | `3.0.3` |
| [cross-spawn](https://github.com/moxystudio/node-cross-spawn) | `7.0.3` | `7.0.6` |
| [micromatch](https://github.com/micromatch/micromatch) | `4.0.5` | `4.0.8` |
| [nanoid](https://github.com/ai/nanoid) | `3.3.6` | `3.3.11` |



Updates `axios` from 1.6.2 to 1.8.2
- [Release notes](https://github.com/axios/axios/releases)
- [Changelog](https://github.com/axios/axios/blob/v1.x/CHANGELOG.md)
- [Commits](https://github.com/axios/axios/compare/v1.6.2...v1.8.2)

Updates `postcss` from 8.4.25 to 8.4.31
- [Release notes](https://github.com/postcss/postcss/releases)
- [Changelog](https://github.com/postcss/postcss/blob/main/CHANGELOG.md)
- [Commits](https://github.com/postcss/postcss/compare/8.4.25...8.4.31)

Updates `braces` from 3.0.2 to 3.0.3
- [Changelog](https://github.com/micromatch/braces/blob/master/CHANGELOG.md)
- [Commits](https://github.com/micromatch/braces/compare/3.0.2...3.0.3)

Updates `cross-spawn` from 7.0.3 to 7.0.6
- [Changelog](https://github.com/moxystudio/node-cross-spawn/blob/master/CHANGELOG.md)
- [Commits](https://github.com/moxystudio/node-cross-spawn/compare/v7.0.3...v7.0.6)

Updates `follow-redirects` from 1.15.2 to 1.15.6
- [Release notes](https://github.com/follow-redirects/follow-redirects/releases)
- [Commits](https://github.com/follow-redirects/follow-redirects/compare/v1.15.2...v1.15.6)

Updates `micromatch` from 4.0.5 to 4.0.8
- [Release notes](https://github.com/micromatch/micromatch/releases)
- [Changelog](https://github.com/micromatch/micromatch/blob/master/CHANGELOG.md)
- [Commits](https://github.com/micromatch/micromatch/compare/4.0.5...4.0.8)

Updates `nanoid` from 3.3.6 to 3.3.11
- [Release notes](https://github.com/ai/nanoid/releases)
- [Changelog](https://github.com/ai/nanoid/blob/main/CHANGELOG.md)
- [Commits](https://github.com/ai/nanoid/compare/3.3.6...3.3.11)

---
updated-dependencies:
- dependency-name: axios
  dependency-version: 1.8.2
  dependency-type: direct:production
  dependency-group: npm_and_yarn
- dependency-name: postcss
  dependency-version: 8.4.31
  dependency-type: direct:production
  dependency-group: npm_and_yarn
- dependency-name: braces
  dependency-version: 3.0.3
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: cross-spawn
  dependency-version: 7.0.6
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: follow-redirects
  dependency-version: 1.15.6
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: micromatch
  dependency-version: 4.0.8
  dependency-type: indirect
  dependency-group: npm_and_yarn
- dependency-name: nanoid
  dependency-version: 3.3.11
  dependency-type: indirect
  dependency-group: npm_and_yarn
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-08-06 10:23:48 -06:00
7963cc8a52 fix: delayed waveform loading (#538) 2025-08-06 10:22:51 -06:00
d12424848d chore: remove black (#534) 2025-08-05 12:07:53 -06:00
dependabot[bot]
6e765875d5 build(deps): bump @babel/runtime (#530)
Bumps the npm_and_yarn group with 1 update in the /www directory: [@babel/runtime](https://github.com/babel/babel/tree/HEAD/packages/babel-runtime).


Updates `@babel/runtime` from 7.23.6 to 7.28.2
- [Release notes](https://github.com/babel/babel/releases)
- [Changelog](https://github.com/babel/babel/blob/main/CHANGELOG.md)
- [Commits](https://github.com/babel/babel/commits/v7.28.2/packages/babel-runtime)

---
updated-dependencies:
- dependency-name: "@babel/runtime"
  dependency-version: 7.28.2
  dependency-type: indirect
  dependency-group: npm_and_yarn
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-08-05 11:41:34 -06:00
dependabot[bot]
e0f4acf28b build(deps): bump form-data (#531)
Bumps the npm_and_yarn group with 1 update in the /www directory: [form-data](https://github.com/form-data/form-data).


Updates `form-data` from 4.0.0 to 4.0.4
- [Release notes](https://github.com/form-data/form-data/releases)
- [Changelog](https://github.com/form-data/form-data/blob/master/CHANGELOG.md)
- [Commits](https://github.com/form-data/form-data/compare/v4.0.0...v4.0.4)

---
updated-dependencies:
- dependency-name: form-data
  dependency-version: 4.0.4
  dependency-type: indirect
  dependency-group: npm_and_yarn
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-08-05 11:41:25 -06:00
dependabot[bot]
12359ea4eb build(deps): bump next (#533)
Bumps the npm_and_yarn group with 1 update in the /www directory: [next](https://github.com/vercel/next.js).


Updates `next` from 14.2.7 to 14.2.30
- [Release notes](https://github.com/vercel/next.js/releases)
- [Changelog](https://github.com/vercel/next.js/blob/canary/release.js)
- [Commits](https://github.com/vercel/next.js/compare/v14.2.7...v14.2.30)

---
updated-dependencies:
- dependency-name: next
  dependency-version: 14.2.30
  dependency-type: direct:production
  dependency-group: npm_and_yarn
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-08-05 11:41:10 -06:00
267b7401ea chore(main): release 0.6.0 (#526) 2025-08-04 18:04:10 -06:00
aea9de393c chore(main): release 0.6.0
Release-As: 0.6.0
2025-08-04 18:02:19 -06:00
dc177af3ff feat: implement service-specific Modal API keys with auto processor pattern (#528)
* fix: refactor modal API key configuration for better separation of concerns

- Split generic MODAL_API_KEY into service-specific keys:
  - TRANSCRIPT_API_KEY for transcription service
  - DIARIZATION_API_KEY for diarization service
  - TRANSLATE_API_KEY for translation service
- Remove deprecated *_MODAL_API_KEY settings
- Add proper validation to ensure URLs are set when using modal processors
- Update README with new configuration format

BREAKING CHANGE: Configuration keys have changed. Update your .env file:
- TRANSCRIPT_MODAL_API_KEY → TRANSCRIPT_API_KEY
- LLM_MODAL_API_KEY → (removed, use TRANSCRIPT_API_KEY)
- Add DIARIZATION_API_KEY and TRANSLATE_API_KEY if using those services

* fix: update Modal backend configuration to use service-specific API keys

- Changed from generic MODAL_API_KEY to service-specific keys:
  - TRANSCRIPT_MODAL_API_KEY for transcription
  - DIARIZATION_MODAL_API_KEY for diarization
  - TRANSLATION_MODAL_API_KEY for translation
- Updated audio_transcript_modal.py and audio_diarization_modal.py to use modal_api_key parameter
- Updated documentation in README.md, CLAUDE.md, and env.example

* feat: implement auto/modal pattern for translation processor

- Created TranscriptTranslatorAutoProcessor following the same pattern as transcript/diarization
- Created TranscriptTranslatorModalProcessor with TRANSLATION_MODAL_API_KEY support
- Added TRANSLATION_BACKEND setting (defaults to "modal")
- Updated all imports to use TranscriptTranslatorAutoProcessor instead of TranscriptTranslatorProcessor
- Updated env.example with TRANSLATION_BACKEND and TRANSLATION_MODAL_API_KEY
- Updated test to expect TranscriptTranslatorModalProcessor name
- All tests passing

* refactor: simplify transcript_translator base class to match other processors

- Moved all implementation from base class to modal processor
- Base class now only defines abstract _translate method
- Follows the same minimal pattern as audio_diarization and audio_transcript base classes
- Updated test mock to use _translate instead of get_translation
- All tests passing

* chore: clean up settings and improve type annotations

- Remove deprecated generic API key variables from settings
- Add comments to group Modal-specific settings
- Improve type annotations for modal_api_key parameters

* fix: typing

* fix: passing key to openai

* test: fix rtc test failing due to change on transcript

It also correctly setup database from sqlite, in case our configuration
is setup to postgres.

* ci: deactivate translation backend by default

* test: fix modal->mock

* refactor: implementing igor review, mock to passthrough
2025-08-04 12:07:30 -06:00
5bd8233657 chore: remove refactor md (#527) 2025-08-01 16:33:40 -06:00
28ac031ff6 feat: use llamaindex everywhere (#525)
* feat: use llamaindex for transcript final title too

* refactor: removed llm backend, replaced with one single class+llamaindex

* refactor: self-review

* fix: typing

* fix: tests

* refactor: extract clean_title and add tests

* test: fix

* test: remove ensure_casing/nltk

* fix: tiny mistake
2025-08-01 12:13:00 -06:00
134 changed files with 22566 additions and 10555 deletions

View File

@@ -17,10 +17,40 @@ on:
jobs:
test-migrations:
runs-on: ubuntu-latest
services:
postgres:
image: postgres:17
env:
POSTGRES_USER: reflector
POSTGRES_PASSWORD: reflector
POSTGRES_DB: reflector
ports:
- 5432:5432
options: >-
--health-cmd pg_isready -h 127.0.0.1 -p 5432
--health-interval 10s
--health-timeout 5s
--health-retries 5
env:
DATABASE_URL: postgresql://reflector:reflector@localhost:5432/reflector
steps:
- uses: actions/checkout@v4
- name: Install PostgreSQL client
run: sudo apt-get update && sudo apt-get install -y postgresql-client | cat
- name: Wait for Postgres
run: |
for i in {1..30}; do
if pg_isready -h localhost -p 5432; then
echo "Postgres is ready"
break
fi
echo "Waiting for Postgres... ($i)" && sleep 1
done
- name: Install uv
uses: astral-sh/setup-uv@v3
with:

24
.github/workflows/pre-commit.yml vendored Normal file
View File

@@ -0,0 +1,24 @@
name: pre-commit
on:
pull_request:
push:
branches: [main]
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v5
- uses: actions/setup-python@v5
- uses: pnpm/action-setup@v4
with:
version: 10
- uses: actions/setup-node@v4
with:
node-version: 22
cache: "pnpm"
cache-dependency-path: "www/pnpm-lock.yaml"
- name: Install dependencies
run: cd www && pnpm install --frozen-lockfile
- uses: pre-commit/action@v3.0.1

2
.gitignore vendored
View File

@@ -13,3 +13,5 @@ restart-dev.sh
data/
www/REFACTOR.md
www/reload-frontend
server/test.sqlite
CLAUDE.local.md

View File

@@ -3,10 +3,10 @@
repos:
- repo: local
hooks:
- id: yarn-format
name: run yarn format
- id: format
name: run format
language: system
entry: bash -c 'cd www && yarn format'
entry: bash -c 'cd www && pnpm format'
pass_filenames: false
files: ^www/
@@ -23,8 +23,7 @@ repos:
- id: ruff
args:
- --fix
- --select
- I,F401
# Uses select rules from server/pyproject.toml
files: ^server/
- id: ruff-format
files: ^server/

View File

@@ -1,5 +1,32 @@
# Changelog
## [0.6.1](https://github.com/Monadical-SAS/reflector/compare/v0.6.0...v0.6.1) (2025-08-06)
### Bug Fixes
* delayed waveform loading ([#538](https://github.com/Monadical-SAS/reflector/issues/538)) ([ef64146](https://github.com/Monadical-SAS/reflector/commit/ef64146325d03f64dd9a1fe40234fb3e7e957ae2))
## [0.6.0](https://github.com/Monadical-SAS/reflector/compare/v0.5.0...v0.6.0) (2025-08-05)
### ⚠ BREAKING CHANGES
* Configuration keys have changed. Update your .env file:
- TRANSCRIPT_MODAL_API_KEY → TRANSCRIPT_API_KEY
- LLM_MODAL_API_KEY → (removed, use TRANSCRIPT_API_KEY)
- Add DIARIZATION_API_KEY and TRANSLATE_API_KEY if using those services
### Features
* implement service-specific Modal API keys with auto processor pattern ([#528](https://github.com/Monadical-SAS/reflector/issues/528)) ([650befb](https://github.com/Monadical-SAS/reflector/commit/650befb291c47a1f49e94a01ab37d8fdfcd2b65d))
* use llamaindex everywhere ([#525](https://github.com/Monadical-SAS/reflector/issues/525)) ([3141d17](https://github.com/Monadical-SAS/reflector/commit/3141d172bc4d3b3d533370c8e6e351ea762169bf))
### Miscellaneous Chores
* **main:** release 0.6.0 ([ecdbf00](https://github.com/Monadical-SAS/reflector/commit/ecdbf003ea2476c3e95fd231adaeb852f2943df0))
## [0.5.0](https://github.com/Monadical-SAS/reflector/compare/v0.4.0...v0.5.0) (2025-07-31)

View File

@@ -62,7 +62,7 @@ uv run python -m reflector.tools.process path/to/audio.wav
**Setup:**
```bash
# Install dependencies
yarn install
pnpm install
# Copy configuration templates
cp .env_template .env
@@ -72,19 +72,19 @@ cp config-template.ts config.ts
**Development:**
```bash
# Start development server
yarn dev
pnpm dev
# Generate TypeScript API client from OpenAPI spec
yarn openapi
pnpm openapi
# Lint code
yarn lint
pnpm lint
# Format code
yarn format
pnpm format
# Build for production
yarn build
pnpm build
```
### Docker Compose (Full Stack)
@@ -144,7 +144,9 @@ All endpoints prefixed `/v1/`:
**Backend** (`server/.env`):
- `DATABASE_URL` - Database connection string
- `REDIS_URL` - Redis broker for Celery
- `MODAL_TOKEN_ID`, `MODAL_TOKEN_SECRET` - Modal.com GPU processing
- `TRANSCRIPT_BACKEND=modal` + `TRANSCRIPT_MODAL_API_KEY` - Modal.com transcription
- `DIARIZATION_BACKEND=modal` + `DIARIZATION_MODAL_API_KEY` - Modal.com diarization
- `TRANSLATION_BACKEND=modal` + `TRANSLATION_MODAL_API_KEY` - Modal.com translation
- `WHEREBY_API_KEY` - Video platform integration
- `REFLECTOR_AUTH_BACKEND` - Authentication method (none, jwt)

497
ICS_IMPLEMENTATION.md Normal file
View File

@@ -0,0 +1,497 @@
# ICS Calendar Integration - Implementation Guide
## Overview
This document provides detailed implementation guidance for integrating ICS calendar feeds with Reflector rooms. Unlike CalDAV which requires complex authentication and protocol handling, ICS integration uses simple HTTP(S) fetching of calendar files.
## Key Differences from CalDAV Approach
| Aspect | CalDAV | ICS |
|--------|--------|-----|
| Protocol | WebDAV extension | HTTP/HTTPS GET |
| Authentication | Username/password, OAuth | Tokens embedded in URL |
| Data Access | Selective event queries | Full calendar download |
| Implementation | Complex (caldav library) | Simple (requests + icalendar) |
| Real-time Updates | Supported | Polling only |
| Write Access | Yes | No (read-only) |
## Technical Architecture
### 1. ICS Fetching Service
```python
# reflector/services/ics_sync.py
import requests
from icalendar import Calendar
from typing import List, Optional
from datetime import datetime, timedelta
class ICSFetchService:
def __init__(self):
self.session = requests.Session()
self.session.headers.update({'User-Agent': 'Reflector/1.0'})
def fetch_ics(self, url: str) -> str:
"""Fetch ICS file from URL (authentication via URL token if needed)."""
response = self.session.get(url, timeout=30)
response.raise_for_status()
return response.text
def parse_ics(self, ics_content: str) -> Calendar:
"""Parse ICS content into calendar object."""
return Calendar.from_ical(ics_content)
def extract_room_events(self, calendar: Calendar, room_url: str) -> List[dict]:
"""Extract events that match the room URL."""
events = []
for component in calendar.walk():
if component.name == "VEVENT":
# Check if event matches this room
if self._event_matches_room(component, room_url):
events.append(self._parse_event(component))
return events
def _event_matches_room(self, event, room_url: str) -> bool:
"""Check if event location or description contains room URL."""
location = str(event.get('LOCATION', ''))
description = str(event.get('DESCRIPTION', ''))
# Support various URL formats
patterns = [
room_url,
room_url.replace('https://', ''),
room_url.split('/')[-1], # Just room name
]
for pattern in patterns:
if pattern in location or pattern in description:
return True
return False
```
### 2. Database Schema
```sql
-- Modify room table
ALTER TABLE room ADD COLUMN ics_url TEXT; -- encrypted to protect embedded tokens
ALTER TABLE room ADD COLUMN ics_fetch_interval INTEGER DEFAULT 300; -- seconds
ALTER TABLE room ADD COLUMN ics_enabled BOOLEAN DEFAULT FALSE;
ALTER TABLE room ADD COLUMN ics_last_sync TIMESTAMP;
ALTER TABLE room ADD COLUMN ics_last_etag TEXT; -- for caching
-- Calendar events table
CREATE TABLE calendar_event (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
room_id UUID REFERENCES room(id) ON DELETE CASCADE,
external_id TEXT NOT NULL, -- ICS UID
title TEXT,
description TEXT,
start_time TIMESTAMP NOT NULL,
end_time TIMESTAMP NOT NULL,
attendees JSONB,
location TEXT,
ics_raw_data TEXT, -- Store raw VEVENT for reference
last_synced TIMESTAMP DEFAULT NOW(),
is_deleted BOOLEAN DEFAULT FALSE,
created_at TIMESTAMP DEFAULT NOW(),
updated_at TIMESTAMP DEFAULT NOW(),
UNIQUE(room_id, external_id)
);
-- Index for efficient queries
CREATE INDEX idx_calendar_event_room_start ON calendar_event(room_id, start_time);
CREATE INDEX idx_calendar_event_deleted ON calendar_event(is_deleted) WHERE NOT is_deleted;
```
### 3. Background Tasks
```python
# reflector/worker/tasks/ics_sync.py
from celery import shared_task
from datetime import datetime, timedelta
import hashlib
@shared_task
def sync_ics_calendars():
"""Sync all enabled ICS calendars based on their fetch intervals."""
rooms = Room.query.filter_by(ics_enabled=True).all()
for room in rooms:
# Check if it's time to sync based on fetch interval
if should_sync(room):
sync_room_calendar.delay(room.id)
@shared_task
def sync_room_calendar(room_id: str):
"""Sync calendar for a specific room."""
room = Room.query.get(room_id)
if not room or not room.ics_enabled:
return
try:
# Fetch ICS file (decrypt URL first)
service = ICSFetchService()
decrypted_url = decrypt_ics_url(room.ics_url)
ics_content = service.fetch_ics(decrypted_url)
# Check if content changed (using ETag or hash)
content_hash = hashlib.md5(ics_content.encode()).hexdigest()
if room.ics_last_etag == content_hash:
logger.info(f"No changes in ICS for room {room_id}")
return
# Parse and extract events
calendar = service.parse_ics(ics_content)
events = service.extract_room_events(calendar, room.url)
# Update database
sync_events_to_database(room_id, events)
# Update sync metadata
room.ics_last_sync = datetime.utcnow()
room.ics_last_etag = content_hash
db.session.commit()
except Exception as e:
logger.error(f"Failed to sync ICS for room {room_id}: {e}")
def should_sync(room) -> bool:
"""Check if room calendar should be synced."""
if not room.ics_last_sync:
return True
time_since_sync = datetime.utcnow() - room.ics_last_sync
return time_since_sync.total_seconds() >= room.ics_fetch_interval
```
### 4. Celery Beat Schedule
```python
# reflector/worker/celeryconfig.py
from celery.schedules import crontab
beat_schedule = {
'sync-ics-calendars': {
'task': 'reflector.worker.tasks.ics_sync.sync_ics_calendars',
'schedule': 60.0, # Check every minute which calendars need syncing
},
'pre-create-meetings': {
'task': 'reflector.worker.tasks.ics_sync.pre_create_calendar_meetings',
'schedule': 60.0, # Check every minute for upcoming meetings
},
}
```
## API Endpoints
### Room ICS Configuration
```python
# PATCH /v1/rooms/{room_id}
{
"ics_url": "https://calendar.google.com/calendar/ical/.../private-token/basic.ics",
"ics_fetch_interval": 300, # seconds
"ics_enabled": true
# URL will be encrypted in database to protect embedded tokens
}
```
### Manual Sync Trigger
```python
# POST /v1/rooms/{room_name}/ics/sync
# Response:
{
"status": "syncing",
"last_sync": "2024-01-15T10:30:00Z",
"events_found": 5
}
```
### ICS Status
```python
# GET /v1/rooms/{room_name}/ics/status
# Response:
{
"enabled": true,
"last_sync": "2024-01-15T10:30:00Z",
"next_sync": "2024-01-15T10:35:00Z",
"fetch_interval": 300,
"events_count": 12,
"upcoming_events": 3
}
```
## ICS Parsing Details
### Event Field Mapping
| ICS Field | Database Field | Notes |
|-----------|---------------|-------|
| UID | external_id | Unique identifier |
| SUMMARY | title | Event title |
| DESCRIPTION | description | Full description |
| DTSTART | start_time | Convert to UTC |
| DTEND | end_time | Convert to UTC |
| LOCATION | location | Check for room URL |
| ATTENDEE | attendees | Parse into JSON |
| ORGANIZER | attendees | Add as organizer |
| STATUS | (internal) | Filter cancelled events |
### Handling Recurring Events
```python
def expand_recurring_events(event, start_date, end_date):
"""Expand recurring events into individual occurrences."""
from dateutil.rrule import rrulestr
if 'RRULE' not in event:
return [event]
# Parse recurrence rule
rrule_str = event['RRULE'].to_ical().decode()
dtstart = event['DTSTART'].dt
# Generate occurrences
rrule = rrulestr(rrule_str, dtstart=dtstart)
occurrences = []
for dt in rrule.between(start_date, end_date):
# Clone event with new date
occurrence = event.copy()
occurrence['DTSTART'].dt = dt
if 'DTEND' in event:
duration = event['DTEND'].dt - event['DTSTART'].dt
occurrence['DTEND'].dt = dt + duration
# Unique ID for each occurrence
occurrence['UID'] = f"{event['UID']}_{dt.isoformat()}"
occurrences.append(occurrence)
return occurrences
```
### Timezone Handling
```python
def normalize_datetime(dt):
"""Convert various datetime formats to UTC."""
import pytz
from datetime import datetime
if hasattr(dt, 'dt'): # icalendar property
dt = dt.dt
if isinstance(dt, datetime):
if dt.tzinfo is None:
# Assume local timezone if naive
dt = pytz.timezone('UTC').localize(dt)
else:
# Convert to UTC
dt = dt.astimezone(pytz.UTC)
return dt
```
## Security Considerations
### 1. URL Validation
```python
def validate_ics_url(url: str) -> bool:
"""Validate ICS URL for security."""
from urllib.parse import urlparse
parsed = urlparse(url)
# Must be HTTPS in production
if not settings.DEBUG and parsed.scheme != 'https':
return False
# Prevent local file access
if parsed.scheme in ('file', 'ftp'):
return False
# Prevent internal network access
if is_internal_ip(parsed.hostname):
return False
return True
```
### 2. Rate Limiting
```python
# Implement per-room rate limiting
RATE_LIMITS = {
'min_fetch_interval': 60, # Minimum 1 minute between fetches
'max_requests_per_hour': 60, # Max 60 requests per hour per room
'max_file_size': 10 * 1024 * 1024, # Max 10MB ICS file
}
```
### 3. ICS URL Encryption
```python
from cryptography.fernet import Fernet
class URLEncryption:
def __init__(self):
self.cipher = Fernet(settings.ENCRYPTION_KEY)
def encrypt_url(self, url: str) -> str:
"""Encrypt ICS URL to protect embedded tokens."""
return self.cipher.encrypt(url.encode()).decode()
def decrypt_url(self, encrypted: str) -> str:
"""Decrypt ICS URL for fetching."""
return self.cipher.decrypt(encrypted.encode()).decode()
def mask_url(self, url: str) -> str:
"""Mask sensitive parts of URL for display."""
from urllib.parse import urlparse, urlunparse
parsed = urlparse(url)
# Keep scheme, host, and path structure but mask tokens
if '/private-' in parsed.path:
# Google Calendar format
parts = parsed.path.split('/private-')
masked_path = parts[0] + '/private-***' + parts[1].split('/')[-1]
elif 'token=' in url:
# Query parameter token
masked_path = parsed.path
parsed = parsed._replace(query='token=***')
else:
# Generic masking of path segments that look like tokens
import re
masked_path = re.sub(r'/[a-zA-Z0-9]{20,}/', '/***/', parsed.path)
return urlunparse(parsed._replace(path=masked_path))
```
## Testing Strategy
### 1. Unit Tests
```python
# tests/test_ics_sync.py
def test_ics_parsing():
"""Test ICS file parsing."""
ics_content = """BEGIN:VCALENDAR
VERSION:2.0
BEGIN:VEVENT
UID:test-123
SUMMARY:Team Meeting
LOCATION:https://reflector.monadical.com/engineering
DTSTART:20240115T100000Z
DTEND:20240115T110000Z
END:VEVENT
END:VCALENDAR"""
service = ICSFetchService()
calendar = service.parse_ics(ics_content)
events = service.extract_room_events(
calendar,
"https://reflector.monadical.com/engineering"
)
assert len(events) == 1
assert events[0]['title'] == 'Team Meeting'
```
### 2. Integration Tests
```python
def test_full_sync_flow():
"""Test complete sync workflow."""
# Create room with ICS URL (encrypt URL to protect tokens)
encryption = URLEncryption()
room = Room(
name="test-room",
ics_url=encryption.encrypt_url("https://example.com/calendar.ics?token=secret"),
ics_enabled=True
)
# Mock ICS fetch
with patch('requests.get') as mock_get:
mock_get.return_value.text = sample_ics_content
# Run sync
sync_room_calendar(room.id)
# Verify events created
events = CalendarEvent.query.filter_by(room_id=room.id).all()
assert len(events) > 0
```
## Common ICS Provider Configurations
### Google Calendar
- URL Format: `https://calendar.google.com/calendar/ical/{calendar_id}/private-{token}/basic.ics`
- Authentication via token embedded in URL
- Updates every 3-8 hours by default
### Outlook/Office 365
- URL Format: `https://outlook.office365.com/owa/calendar/{id}/calendar.ics`
- May include token in URL path or query parameters
- Real-time updates
### Apple iCloud
- URL Format: `webcal://p{XX}-caldav.icloud.com/published/2/{token}`
- Convert webcal:// to https://
- Token embedded in URL path
- Public calendars only
### Nextcloud/ownCloud
- URL Format: `https://cloud.example.com/remote.php/dav/public-calendars/{token}`
- Token embedded in URL path
- Configurable update frequency
## Migration from CalDAV
If migrating from an existing CalDAV implementation:
1. **Database Migration**: Rename fields from `caldav_*` to `ics_*`
2. **URL Conversion**: Most CalDAV servers provide ICS export endpoints
3. **Authentication**: Convert from username/password to URL-embedded tokens
4. **Remove Dependencies**: Uninstall caldav library, add icalendar
5. **Update Background Tasks**: Replace CalDAV sync with ICS fetch
## Performance Optimizations
1. **Caching**: Use ETag/Last-Modified headers to avoid refetching unchanged calendars
2. **Incremental Sync**: Store last sync timestamp, only process new/modified events
3. **Batch Processing**: Process multiple room calendars in parallel
4. **Connection Pooling**: Reuse HTTP connections for multiple requests
5. **Compression**: Support gzip encoding for large ICS files
## Monitoring and Debugging
### Metrics to Track
- Sync success/failure rate per room
- Average sync duration
- ICS file sizes
- Number of events processed
- Failed event matches
### Debug Logging
```python
logger.debug(f"Fetching ICS from {room.ics_url}")
logger.debug(f"ICS content size: {len(ics_content)} bytes")
logger.debug(f"Found {len(events)} matching events")
logger.debug(f"Event UIDs: {[e['external_id'] for e in events]}")
```
### Common Issues
1. **SSL Certificate Errors**: Add certificate validation options
2. **Timeout Issues**: Increase timeout for large calendars
3. **Encoding Problems**: Handle various character encodings
4. **Timezone Mismatches**: Always convert to UTC
5. **Memory Issues**: Stream large ICS files instead of loading entirely

337
PLAN.md Normal file
View File

@@ -0,0 +1,337 @@
# ICS Calendar Integration Plan
## Core Concept
ICS calendar URLs are attached to rooms (not users) to enable automatic meeting tracking and management through periodic fetching of calendar data.
## Database Schema Updates
### 1. Add ICS configuration to rooms
- Add `ics_url` field to room table (URL to .ics file, may include auth token)
- Add `ics_fetch_interval` field to room table (default: 5 minutes, configurable)
- Add `ics_enabled` boolean field to room table
- Add `ics_last_sync` timestamp field to room table
### 2. Create calendar_events table
- `id` - UUID primary key
- `room_id` - Foreign key to room
- `external_id` - ICS event UID
- `title` - Event title
- `description` - Event description
- `start_time` - Event start timestamp
- `end_time` - Event end timestamp
- `attendees` - JSON field with attendee list and status
- `location` - Meeting location (should contain room name)
- `last_synced` - Last sync timestamp
- `is_deleted` - Boolean flag for soft delete (preserve past events)
- `ics_raw_data` - TEXT field to store raw VEVENT data for reference
### 3. Update meeting table
- Add `calendar_event_id` - Foreign key to calendar_events
- Add `calendar_metadata` - JSON field for additional calendar data
- Remove unique constraint on room_id + active status (allow multiple active meetings per room)
## Backend Implementation
### 1. ICS Sync Service
- Create background task that runs based on room's `ics_fetch_interval` (default: 5 minutes)
- For each room with ICS enabled, fetch the .ics file via HTTP/HTTPS
- Parse ICS file using icalendar library
- Extract VEVENT components and filter events looking for room URL (e.g., "https://reflector.monadical.com/max")
- Store matching events in calendar_events table
- Mark events as "upcoming" if start_time is within next 30 minutes
- Pre-create Whereby meetings 1 minute before start (ensures no delay when users join)
- Soft-delete future events that were removed from calendar (set is_deleted=true)
- Never delete past events (preserve for historical record)
- Support authenticated ICS feeds via tokens embedded in URL
### 2. Meeting Management Updates
- Allow multiple active meetings per room
- Pre-create meeting record 1 minute before calendar event starts (ensures meeting is ready)
- Link meeting to calendar_event for metadata
- Keep meeting active for 15 minutes after last participant leaves (grace period)
- Don't auto-close if new participant joins within grace period
### 3. API Endpoints
- `GET /v1/rooms/{room_name}/meetings` - List all active and upcoming meetings for a room
- Returns filtered data based on user role (owner vs participant)
- `GET /v1/rooms/{room_name}/meetings/upcoming` - List upcoming meetings (next 30 min)
- Returns filtered data based on user role
- `POST /v1/rooms/{room_name}/meetings/{meeting_id}/join` - Join specific meeting
- `PATCH /v1/rooms/{room_id}` - Update room settings (including ICS configuration)
- ICS fields only visible/editable by room owner
- `POST /v1/rooms/{room_name}/ics/sync` - Trigger manual ICS sync
- Only accessible by room owner
- `GET /v1/rooms/{room_name}/ics/status` - Get ICS sync status and last fetch time
- Only accessible by room owner
## Frontend Implementation
### 1. Room Settings Page
- Add ICS configuration section
- Field for ICS URL (e.g., Google Calendar private URL, Outlook ICS export)
- Field for fetch interval (dropdown: 1 min, 5 min, 10 min, 30 min, 1 hour)
- Test connection button (validates ICS file can be fetched and parsed)
- Manual sync button
- Show last sync time and next scheduled sync
### 2. Meeting Selection Page (New)
- Show when accessing `/room/{room_name}`
- **Host view** (room owner):
- Full calendar event details
- Meeting title and description
- Complete attendee list with RSVP status
- Number of current participants
- Duration (how long it's been running)
- **Participant view** (non-owners):
- Meeting title only
- Date and time
- Number of current participants
- Duration (how long it's been running)
- No attendee list or description (privacy)
- Display upcoming meetings (visible 30min before):
- Show countdown to start
- Can click to join early → redirected to waiting page
- Waiting page shows countdown until meeting starts
- Meeting pre-created by background task (ready when users arrive)
- Option to create unscheduled meeting (uses existing flow)
### 3. Meeting Room Updates
- Show calendar metadata in meeting info
- Display invited attendees vs actual participants
- Show meeting title from calendar event
## Meeting Lifecycle
### 1. Meeting Creation
- Automatic: Pre-created 1 minute before calendar event starts (ensures Whereby room is ready)
- Manual: User creates unscheduled meeting (existing `/rooms/{room_name}/meeting` endpoint)
- Background task handles pre-creation to avoid delays when users join
### 2. Meeting Join Rules
- Can join active meetings immediately
- Can see upcoming meetings 30 minutes before start
- Can click to join upcoming meetings early → sent to waiting page
- Waiting page automatically transitions to meeting at scheduled time
- Unscheduled meetings always joinable (current behavior)
### 3. Meeting Closure Rules
- All meetings: 15-minute grace period after last participant leaves
- If participant rejoins within grace period, keep meeting active
- Calendar meetings: Force close 30 minutes after scheduled end time
- Unscheduled meetings: Keep active for 8 hours (current behavior)
## ICS Parsing Logic
### 1. Event Matching
- Parse ICS file using Python icalendar library
- Iterate through VEVENT components
- Check LOCATION field for full FQDN URL (e.g., "https://reflector.monadical.com/max")
- Check DESCRIPTION for room URL or mention
- Support multiple formats:
- Full URL: "https://reflector.monadical.com/max"
- With /room path: "https://reflector.monadical.com/room/max"
- Partial paths: "room/max", "/max room"
### 2. Attendee Extraction
- Parse ATTENDEE properties from VEVENT
- Extract email (MAILTO), name (CN parameter), and RSVP status (PARTSTAT)
- Store as JSON in calendar_events.attendees
### 3. Sync Strategy
- Fetch complete ICS file (contains all events)
- Filter events from (now - 1 hour) to (now + 24 hours) for processing
- Update existing events if LAST-MODIFIED or SEQUENCE changed
- Delete future events that no longer exist in ICS (start_time > now)
- Keep past events for historical record (never delete if start_time < now)
- Handle recurring events (RRULE) - expand to individual instances
- Track deleted calendar events to clean up future meetings
- Cache ICS file hash to detect changes and skip unnecessary processing
## Security Considerations
### 1. ICS URL Security
- ICS URLs may contain authentication tokens (e.g., Google Calendar private URLs)
- Store full ICS URLs encrypted using Fernet to protect embedded tokens
- Validate ICS URLs (must be HTTPS for production)
- Never expose full ICS URLs in API responses (return masked version)
- Rate limit ICS fetching to prevent abuse
### 2. Room Access
- Only room owner can configure ICS URL
- ICS URL shown as masked version to room owner (hides embedded tokens)
- ICS settings not visible to other users
- Meeting list visible to all room participants
- ICS fetch logs only visible to room owner
### 3. Meeting Privacy
- Full calendar details visible only to room owner
- Participants see limited info: title, date/time only
- Attendee list and description hidden from non-owners
- Meeting titles visible in room listing to all
## Implementation Phases
### Phase 1: Database and ICS Setup (Week 1) ✅ COMPLETED (2025-08-18)
1. ✅ Created database migrations for ICS fields and calendar_events table
- Added ics_url, ics_fetch_interval, ics_enabled, ics_last_sync, ics_last_etag to room table
- Created calendar_event table with ics_uid (instead of external_id) and proper typing
- Added calendar_event_id and calendar_metadata (JSONB) to meeting table
- Removed server_default from datetime fields for consistency
2. ✅ Installed icalendar Python library for ICS parsing
- Added icalendar>=6.0.0 to dependencies
- No encryption needed - ICS URLs are read-only
3. ✅ Built ICS fetch and sync service
- Simple HTTP fetching without unnecessary validation
- Proper TypedDict typing for event data structures
- Supports any standard ICS format
- Event matching on full room URL only
4. ✅ API endpoints for ICS configuration
- Room model updated to support ICS fields via existing PATCH endpoint
- POST /v1/rooms/{room_name}/ics/sync - Trigger manual sync (owner only)
- GET /v1/rooms/{room_name}/ics/status - Get sync status (owner only)
- GET /v1/rooms/{room_name}/meetings - List meetings with privacy controls
- GET /v1/rooms/{room_name}/meetings/upcoming - List upcoming meetings
5. ✅ Celery background tasks for periodic sync
- sync_room_ics - Sync individual room calendar
- sync_all_ics_calendars - Check all rooms and queue sync based on fetch intervals
- pre_create_upcoming_meetings - Pre-create Whereby meetings 1 minute before start
- Tasks scheduled in beat schedule (every minute for checking, respects individual intervals)
6. ✅ Tests written and passing
- 6 tests for Room ICS fields
- 7 tests for CalendarEvent model
- 7 tests for ICS sync service
- 11 tests for API endpoints
- 6 tests for background tasks
- All 31 ICS-related tests passing
### Phase 2: Meeting Management (Week 2) ✅ COMPLETED (2025-08-19)
1. ✅ Updated meeting lifecycle logic with grace period support
- 15-minute grace period after last participant leaves
- Automatic reactivation when participants rejoin
- Force close calendar meetings 30 minutes after scheduled end
2. ✅ Support multiple active meetings per room
- Removed unique constraint on active meetings
- Added get_all_active_for_room() method
- Added get_active_by_calendar_event() method
3. ✅ Implemented grace period logic
- Added last_participant_left_at and grace_period_minutes fields
- Process meetings task handles grace period checking
- Whereby webhooks clear grace period on participant join
4. ✅ Link meetings to calendar events
- Pre-created meetings properly linked via calendar_event_id
- Calendar metadata stored with meeting
- API endpoints for listing and joining specific meetings
### Phase 3: Frontend Meeting Selection (Week 3)
1. Build meeting selection page
2. Show active and upcoming meetings
3. Implement waiting page for early joiners
4. Add automatic transition from waiting to meeting
5. Support unscheduled meeting creation
### Phase 4: Calendar Integration UI (Week 4)
1. Add ICS settings to room configuration
2. Display calendar metadata in meetings
3. Show attendee information
4. Add sync status indicators
5. Show fetch interval and next sync time
## Success Metrics
- Zero merged meetings from consecutive calendar events
- Successful ICS sync from major providers (Google Calendar, Outlook, Apple Calendar, Nextcloud)
- Meeting join accuracy: correct meeting 100% of the time
- Grace period prevents 90% of accidental meeting closures
- Configurable fetch intervals reduce unnecessary API calls
## Design Decisions
1. **ICS attached to room, not user** - Prevents duplicate meetings from multiple calendars
2. **Multiple active meetings per room** - Supported with meeting selection page
3. **Grace period for rejoining** - 15 minutes after last participant leaves
4. **Upcoming meeting visibility** - Show 30 minutes before, join only on time
5. **Calendar data storage** - Attached to meeting record for full context
6. **No "ad-hoc" meetings** - Use existing meeting creation flow (unscheduled meetings)
7. **ICS configuration via room PATCH** - Reuse existing room configuration endpoint
8. **Event deletion handling** - Soft-delete future events, preserve past meetings
9. **Configurable fetch interval** - Balance between freshness and server load
10. **ICS over CalDAV** - Simpler implementation, wider compatibility, no complex auth
## Phase 2 Implementation Files
### Database Migrations
- `/server/migrations/versions/6025e9b2bef2_remove_one_active_meeting_per_room_.py` - Remove unique constraint
- `/server/migrations/versions/d4a1c446458c_add_grace_period_fields_to_meeting.py` - Add grace period fields
### Updated Models
- `/server/reflector/db/meetings.py` - Added grace period fields and new query methods
### Updated Services
- `/server/reflector/worker/process.py` - Enhanced with grace period logic and multiple meeting support
### Updated API
- `/server/reflector/views/rooms.py` - Added endpoints for listing active meetings and joining specific meetings
- `/server/reflector/views/whereby.py` - Clear grace period on participant join
### Tests
- `/server/tests/test_multiple_active_meetings.py` - Comprehensive tests for Phase 2 features (5 tests)
## Phase 1 Implementation Files Created
### Database Models
- `/server/reflector/db/rooms.py` - Updated with ICS fields (url, fetch_interval, enabled, last_sync, etag)
- `/server/reflector/db/calendar_events.py` - New CalendarEvent model with ics_uid and proper typing
- `/server/reflector/db/meetings.py` - Updated with calendar_event_id and calendar_metadata (JSONB)
### Services
- `/server/reflector/services/ics_sync.py` - ICS fetching and parsing with TypedDict for proper typing
### API Endpoints
- `/server/reflector/views/rooms.py` - Added ICS management endpoints with privacy controls
### Background Tasks
- `/server/reflector/worker/ics_sync.py` - Celery tasks for automatic periodic sync
- `/server/reflector/worker/app.py` - Updated beat schedule for ICS tasks
### Tests
- `/server/tests/test_room_ics.py` - Room model ICS fields tests (6 tests)
- `/server/tests/test_calendar_event.py` - CalendarEvent model tests (7 tests)
- `/server/tests/test_ics_sync.py` - ICS sync service tests (7 tests)
- `/server/tests/test_room_ics_api.py` - API endpoint tests (11 tests)
- `/server/tests/test_ics_background_tasks.py` - Background task tests (6 tests)
### Key Design Decisions
- No encryption needed - ICS URLs are read-only access
- Using ics_uid instead of external_id for clarity
- Proper TypedDict typing for event data structures
- Removed unnecessary URL validation and webcal handling
- calendar_metadata in meetings stores flexible calendar data (organizer, recurrence, etc)
- Background tasks query all rooms directly to avoid filtering issues
- Sync intervals respected per-room configuration
## Implementation Approach
### ICS Fetching vs CalDAV
- **ICS Benefits**:
- Simpler implementation (HTTP GET vs CalDAV protocol)
- Wider compatibility (all calendar apps can export ICS)
- No authentication complexity (simple URL with optional token)
- Easier debugging (ICS is plain text)
- Lower server requirements (no CalDAV library dependencies)
### Supported Calendar Providers
1. **Google Calendar**: Private ICS URL from calendar settings
2. **Outlook/Office 365**: ICS export URL from calendar sharing
3. **Apple Calendar**: Published calendar ICS URL
4. **Nextcloud**: Public/private calendar ICS export
5. **Any CalDAV server**: Via ICS export endpoint
### ICS URL Examples
- Google: `https://calendar.google.com/calendar/ical/{calendar_id}/private-{token}/basic.ics`
- Outlook: `https://outlook.live.com/owa/calendar/{id}/calendar.ics`
- Custom: `https://example.com/calendars/room-schedule.ics`
### Fetch Interval Configuration
- 1 minute: For critical/high-activity rooms
- 5 minutes (default): Balance of freshness and efficiency
- 10 minutes: Standard meeting rooms
- 30 minutes: Low-activity rooms
- 1 hour: Rarely-used rooms or stable schedules

View File

@@ -79,7 +79,7 @@ Start with `cd www`.
**Installation**
```bash
yarn install
pnpm install
cp .env_template .env
cp config-template.ts config.ts
```
@@ -89,7 +89,7 @@ Then, fill in the environment variables in `.env` and the configuration in `conf
**Run in development mode**
```bash
yarn dev
pnpm dev
```
Then (after completing server setup and starting it) open [http://localhost:3000](http://localhost:3000) to view it in the browser.
@@ -99,7 +99,7 @@ Then (after completing server setup and starting it) open [http://localhost:3000
To generate the TypeScript files from the openapi.json file, make sure the python server is running, then run:
```bash
yarn openapi
pnpm openapi
```
### Backend

View File

@@ -39,11 +39,12 @@ services:
image: node:18
ports:
- "3000:3000"
command: sh -c "yarn install && yarn dev"
command: sh -c "corepack enable && pnpm install && pnpm dev"
restart: unless-stopped
working_dir: /app
volumes:
- ./www:/app/
- /app/node_modules
env_file:
- ./www/.env.local

View File

@@ -24,7 +24,6 @@ AUTH_JWT_AUDIENCE=
## Using serverless modal.com (require reflector-gpu-modal deployed)
#TRANSCRIPT_BACKEND=modal
#TRANSCRIPT_URL=https://xxxxx--reflector-transcriber-web.modal.run
#TRANSLATE_URL=https://xxxxx--reflector-translator-web.modal.run
#TRANSCRIPT_MODAL_API_KEY=xxxxx
TRANSCRIPT_BACKEND=modal
@@ -32,11 +31,13 @@ TRANSCRIPT_URL=https://monadical-sas--reflector-transcriber-web.modal.run
TRANSCRIPT_MODAL_API_KEY=
## =======================================================
## Transcription backend
## Translation backend
##
## Only available in modal atm
## =======================================================
TRANSLATION_BACKEND=modal
TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
#TRANSLATION_MODAL_API_KEY=xxxxx
## =======================================================
## LLM backend
@@ -46,38 +47,11 @@ TRANSLATE_URL=https://monadical-sas--reflector-translator-web.modal.run
## llm backend implementation
## =======================================================
## Using serverless modal.com (require reflector-gpu-modal deployed)
LLM_BACKEND=modal
LLM_URL=https://monadical-sas--reflector-llm-web.modal.run
LLM_MODAL_API_KEY=
ZEPHYR_LLM_URL=https://monadical-sas--reflector-llm-zephyr-web.modal.run
## Using OpenAI
#LLM_BACKEND=openai
#LLM_OPENAI_KEY=xxx
#LLM_OPENAI_MODEL=gpt-3.5-turbo
## Using GPT4ALL
#LLM_BACKEND=openai
#LLM_URL=http://localhost:4891/v1/completions
#LLM_OPENAI_MODEL="GPT4All Falcon"
## Default LLM MODEL NAME
#DEFAULT_LLM=lmsys/vicuna-13b-v1.5
## Cache directory to store models
CACHE_DIR=data
## =======================================================
## Summary LLM configuration
## =======================================================
## Context size for summary generation (tokens)
SUMMARY_LLM_CONTEXT_SIZE_TOKENS=16000
SUMMARY_LLM_URL=
SUMMARY_LLM_API_KEY=sk-
SUMMARY_MODEL=
# LLM_MODEL=microsoft/phi-4
LLM_CONTEXT_WINDOW=16000
LLM_URL=
LLM_API_KEY=sk-
## =======================================================
## Diarization
@@ -86,7 +60,9 @@ SUMMARY_MODEL=
## To allow diarization, you need to expose expose the files to be dowloded by the pipeline
## =======================================================
DIARIZATION_ENABLED=false
DIARIZATION_BACKEND=modal
DIARIZATION_URL=https://monadical-sas--reflector-diarizer-web.modal.run
#DIARIZATION_MODAL_API_KEY=xxxxx
## =======================================================

View File

@@ -3,8 +3,9 @@
This repository hold an API for the GPU implementation of the Reflector API service,
and use [Modal.com](https://modal.com)
- `reflector_llm.py` - LLM API
- `reflector_diarizer.py` - Diarization API
- `reflector_transcriber.py` - Transcription API
- `reflector_translator.py` - Translation API
## Modal.com deployment
@@ -23,16 +24,20 @@ $ modal deploy reflector_llm.py
└── 🔨 Created web => https://xxxx--reflector-llm-web.modal.run
```
Then in your reflector api configuration `.env`, you can set theses keys:
Then in your reflector api configuration `.env`, you can set these keys:
```
TRANSCRIPT_BACKEND=modal
TRANSCRIPT_URL=https://xxxx--reflector-transcriber-web.modal.run
TRANSCRIPT_MODAL_API_KEY=REFLECTOR_APIKEY
LLM_BACKEND=modal
LLM_URL=https://xxxx--reflector-llm-web.modal.run
LLM_MODAL_API_KEY=REFLECTOR_APIKEY
DIARIZATION_BACKEND=modal
DIARIZATION_URL=https://xxxx--reflector-diarizer-web.modal.run
DIARIZATION_MODAL_API_KEY=REFLECTOR_APIKEY
TRANSLATION_BACKEND=modal
TRANSLATION_URL=https://xxxx--reflector-translator-web.modal.run
TRANSLATION_MODAL_API_KEY=REFLECTOR_APIKEY
```
## API

View File

@@ -1,213 +0,0 @@
"""
Reflector GPU backend - LLM
===========================
"""
import json
import os
import threading
from typing import Optional
from modal import App, Image, Secret, asgi_app, enter, exit, method
# LLM
LLM_MODEL: str = "lmsys/vicuna-13b-v1.5"
LLM_LOW_CPU_MEM_USAGE: bool = True
LLM_TORCH_DTYPE: str = "bfloat16"
LLM_MAX_NEW_TOKENS: int = 300
IMAGE_MODEL_DIR = "/root/llm_models"
app = App(name="reflector-llm")
def download_llm():
from huggingface_hub import snapshot_download
print("Downloading LLM model")
snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
print("LLM model downloaded")
def migrate_cache_llm():
"""
XXX The cache for model files in Transformers v4.22.0 has been updated.
Migrating your old cache. This is a one-time only operation. You can
interrupt this and resume the migration later on by calling
`transformers.utils.move_cache()`.
"""
from transformers.utils.hub import move_cache
print("Moving LLM cache")
move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR)
print("LLM cache moved")
llm_image = (
Image.debian_slim(python_version="3.10.8")
.apt_install("git")
.pip_install(
"transformers",
"torch",
"sentencepiece",
"protobuf",
"jsonformer==0.12.0",
"accelerate==0.21.0",
"einops==0.6.1",
"hf-transfer~=0.1",
"huggingface_hub==0.16.4",
)
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
.run_function(download_llm)
.run_function(migrate_cache_llm)
)
@app.cls(
gpu="A100",
timeout=60 * 5,
scaledown_window=60 * 5,
allow_concurrent_inputs=15,
image=llm_image,
)
class LLM:
@enter()
def enter(self):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
print("Instance llm model")
model = AutoModelForCausalLM.from_pretrained(
LLM_MODEL,
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
cache_dir=IMAGE_MODEL_DIR,
local_files_only=True,
)
# JSONFormer doesn't yet support generation configs
print("Instance llm generation config")
model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
# generation configuration
gen_cfg = GenerationConfig.from_model_config(model.config)
gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
# load tokenizer
print("Instance llm tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
LLM_MODEL, cache_dir=IMAGE_MODEL_DIR, local_files_only=True
)
# move model to gpu
print("Move llm model to GPU")
model = model.cuda()
print("Warmup llm done")
self.model = model
self.tokenizer = tokenizer
self.gen_cfg = gen_cfg
self.GenerationConfig = GenerationConfig
self.lock = threading.Lock()
@exit()
def exit():
print("Exit llm")
@method()
def generate(
self, prompt: str, gen_schema: str | None, gen_cfg: str | None
) -> dict:
"""
Perform a generation action using the LLM
"""
print(f"Generate {prompt=}")
if gen_cfg:
gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
else:
gen_cfg = self.gen_cfg
# If a gen_schema is given, conform to gen_schema
with self.lock:
if gen_schema:
import jsonformer
print(f"Schema {gen_schema=}")
jsonformer_llm = jsonformer.Jsonformer(
model=self.model,
tokenizer=self.tokenizer,
json_schema=json.loads(gen_schema),
prompt=prompt,
max_string_token_length=gen_cfg.max_new_tokens,
)
response = jsonformer_llm()
else:
# If no gen_schema, perform prompt only generation
# tokenize prompt
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
self.model.device
)
output = self.model.generate(input_ids, generation_config=gen_cfg)
# decode output
response = self.tokenizer.decode(
output[0].cpu(), skip_special_tokens=True
)
response = response[len(prompt) :]
print(f"Generated {response=}")
return {"text": response}
# -------------------------------------------------------------------
# Web API
# -------------------------------------------------------------------
@app.function(
scaledown_window=60 * 10,
timeout=60 * 5,
allow_concurrent_inputs=45,
secrets=[
Secret.from_name("reflector-gpu"),
],
)
@asgi_app()
def web():
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
llmstub = LLM()
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
class LLMRequest(BaseModel):
prompt: str
gen_schema: Optional[dict] = None
gen_cfg: Optional[dict] = None
@app.post("/llm", dependencies=[Depends(apikey_auth)])
def llm(
req: LLMRequest,
):
gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
func = llmstub.generate.spawn(
prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg
)
result = func.get()
return result
return app

View File

@@ -1,219 +0,0 @@
"""
Reflector GPU backend - LLM
===========================
"""
import json
import os
import threading
from typing import Optional
from modal import App, Image, Secret, asgi_app, enter, exit, method
# LLM
LLM_MODEL: str = "HuggingFaceH4/zephyr-7b-alpha"
LLM_LOW_CPU_MEM_USAGE: bool = True
LLM_TORCH_DTYPE: str = "bfloat16"
LLM_MAX_NEW_TOKENS: int = 300
IMAGE_MODEL_DIR = "/root/llm_models/zephyr"
app = App(name="reflector-llm-zephyr")
def download_llm():
from huggingface_hub import snapshot_download
print("Downloading LLM model")
snapshot_download(LLM_MODEL, cache_dir=IMAGE_MODEL_DIR)
print("LLM model downloaded")
def migrate_cache_llm():
"""
XXX The cache for model files in Transformers v4.22.0 has been updated.
Migrating your old cache. This is a one-time only operation. You can
interrupt this and resume the migration later on by calling
`transformers.utils.move_cache()`.
"""
from transformers.utils.hub import move_cache
print("Moving LLM cache")
move_cache(cache_dir=IMAGE_MODEL_DIR, new_cache_dir=IMAGE_MODEL_DIR)
print("LLM cache moved")
llm_image = (
Image.debian_slim(python_version="3.10.8")
.apt_install("git")
.pip_install(
"transformers==4.34.0",
"torch",
"sentencepiece",
"protobuf",
"jsonformer==0.12.0",
"accelerate==0.21.0",
"einops==0.6.1",
"hf-transfer~=0.1",
"huggingface_hub==0.16.4",
)
.env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
.run_function(download_llm)
.run_function(migrate_cache_llm)
)
@app.cls(
gpu="A10G",
timeout=60 * 5,
scaledown_window=60 * 5,
allow_concurrent_inputs=10,
image=llm_image,
)
class LLM:
@enter()
def enter(self):
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
print("Instance llm model")
model = AutoModelForCausalLM.from_pretrained(
LLM_MODEL,
torch_dtype=getattr(torch, LLM_TORCH_DTYPE),
low_cpu_mem_usage=LLM_LOW_CPU_MEM_USAGE,
cache_dir=IMAGE_MODEL_DIR,
local_files_only=True,
)
# JSONFormer doesn't yet support generation configs
print("Instance llm generation config")
model.config.max_new_tokens = LLM_MAX_NEW_TOKENS
# generation configuration
gen_cfg = GenerationConfig.from_model_config(model.config)
gen_cfg.max_new_tokens = LLM_MAX_NEW_TOKENS
# load tokenizer
print("Instance llm tokenizer")
tokenizer = AutoTokenizer.from_pretrained(
LLM_MODEL, cache_dir=IMAGE_MODEL_DIR, local_files_only=True
)
gen_cfg.pad_token_id = tokenizer.eos_token_id
gen_cfg.eos_token_id = tokenizer.eos_token_id
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.eos_token_id
# move model to gpu
print("Move llm model to GPU")
model = model.cuda()
print("Warmup llm done")
self.model = model
self.tokenizer = tokenizer
self.gen_cfg = gen_cfg
self.GenerationConfig = GenerationConfig
self.lock = threading.Lock()
@exit()
def exit():
print("Exit llm")
@method()
def generate(
self, prompt: str, gen_schema: str | None, gen_cfg: str | None
) -> dict:
"""
Perform a generation action using the LLM
"""
print(f"Generate {prompt=}")
if gen_cfg:
gen_cfg = self.GenerationConfig.from_dict(json.loads(gen_cfg))
gen_cfg.pad_token_id = self.tokenizer.eos_token_id
gen_cfg.eos_token_id = self.tokenizer.eos_token_id
else:
gen_cfg = self.gen_cfg
# If a gen_schema is given, conform to gen_schema
with self.lock:
if gen_schema:
import jsonformer
print(f"Schema {gen_schema=}")
jsonformer_llm = jsonformer.Jsonformer(
model=self.model,
tokenizer=self.tokenizer,
json_schema=json.loads(gen_schema),
prompt=prompt,
max_string_token_length=gen_cfg.max_new_tokens,
)
response = jsonformer_llm()
else:
# If no gen_schema, perform prompt only generation
# tokenize prompt
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(
self.model.device
)
output = self.model.generate(input_ids, generation_config=gen_cfg)
# decode output
response = self.tokenizer.decode(
output[0].cpu(), skip_special_tokens=True
)
response = response[len(prompt) :]
response = {"long_summary": response}
print(f"Generated {response=}")
return {"text": response}
# -------------------------------------------------------------------
# Web API
# -------------------------------------------------------------------
@app.function(
scaledown_window=60 * 10,
timeout=60 * 5,
allow_concurrent_inputs=30,
secrets=[
Secret.from_name("reflector-gpu"),
],
)
@asgi_app()
def web():
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel
llmstub = LLM()
app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
def apikey_auth(apikey: str = Depends(oauth2_scheme)):
if apikey != os.environ["REFLECTOR_GPU_APIKEY"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid API key",
headers={"WWW-Authenticate": "Bearer"},
)
class LLMRequest(BaseModel):
prompt: str
gen_schema: Optional[dict] = None
gen_cfg: Optional[dict] = None
@app.post("/llm", dependencies=[Depends(apikey_auth)])
def llm(
req: LLMRequest,
):
gen_schema = json.dumps(req.gen_schema) if req.gen_schema else None
gen_cfg = json.dumps(req.gen_cfg) if req.gen_cfg else None
func = llmstub.generate.spawn(
prompt=req.prompt, gen_schema=gen_schema, gen_cfg=gen_cfg
)
result = func.get()
return result
return app

View File

@@ -1 +1,3 @@
Generic single-database configuration.
Generic single-database configuration.
Both data migrations and schema migrations must be in migrations.

View File

@@ -0,0 +1,25 @@
"""add_webvtt_field_to_transcript
Revision ID: 0bc0f3ff0111
Revises: b7df9609542c
Create Date: 2025-08-05 19:36:41.740957
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "0bc0f3ff0111"
down_revision: Union[str, None] = "b7df9609542c"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column("transcript", sa.Column("webvtt", sa.Text(), nullable=True))
def downgrade() -> None:
op.drop_column("transcript", "webvtt")

View File

@@ -0,0 +1,46 @@
"""add_full_text_search
Revision ID: 116b2f287eab
Revises: 0bc0f3ff0111
Create Date: 2025-08-07 11:27:38.473517
"""
from typing import Sequence, Union
from alembic import op
revision: str = "116b2f287eab"
down_revision: Union[str, None] = "0bc0f3ff0111"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
conn = op.get_bind()
if conn.dialect.name != "postgresql":
return
op.execute("""
ALTER TABLE transcript ADD COLUMN search_vector_en tsvector
GENERATED ALWAYS AS (
setweight(to_tsvector('english', coalesce(title, '')), 'A') ||
setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')
) STORED
""")
op.create_index(
"idx_transcript_search_vector_en",
"transcript",
["search_vector_en"],
postgresql_using="gin",
)
def downgrade() -> None:
conn = op.get_bind()
if conn.dialect.name != "postgresql":
return
op.drop_index("idx_transcript_search_vector_en", table_name="transcript")
op.drop_column("transcript", "search_vector_en")

View File

@@ -0,0 +1,53 @@
"""remove_one_active_meeting_per_room_constraint
Revision ID: 6025e9b2bef2
Revises: 9f5c78d352d6
Create Date: 2025-08-18 18:45:44.418392
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "6025e9b2bef2"
down_revision: Union[str, None] = "9f5c78d352d6"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Remove the unique constraint that prevents multiple active meetings per room
# This is needed to support calendar integration with overlapping meetings
# Check if index exists before trying to drop it
from alembic import context
if context.get_context().dialect.name == "postgresql":
conn = op.get_bind()
result = conn.execute(
sa.text(
"SELECT 1 FROM pg_indexes WHERE indexname = 'idx_one_active_meeting_per_room'"
)
)
if result.fetchone():
op.drop_index("idx_one_active_meeting_per_room", table_name="meeting")
else:
# For SQLite, just try to drop it
try:
op.drop_index("idx_one_active_meeting_per_room", table_name="meeting")
except:
pass
def downgrade() -> None:
# Restore the unique constraint
op.create_index(
"idx_one_active_meeting_per_room",
"meeting",
["room_id"],
unique=True,
postgresql_where=sa.text("is_active = true"),
sqlite_where=sa.text("is_active = 1"),
)

View File

@@ -32,7 +32,7 @@ def upgrade() -> None:
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("room_id", sa.String(), nullable=True),
sa.Column(
"is_locked", sa.Boolean(), server_default=sa.text("0"), nullable=False
"is_locked", sa.Boolean(), server_default=sa.text("false"), nullable=False
),
sa.Column("room_mode", sa.String(), server_default="normal", nullable=False),
sa.Column(
@@ -53,12 +53,15 @@ def upgrade() -> None:
sa.Column("user_id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.Column(
"zulip_auto_post", sa.Boolean(), server_default=sa.text("0"), nullable=False
"zulip_auto_post",
sa.Boolean(),
server_default=sa.text("false"),
nullable=False,
),
sa.Column("zulip_stream", sa.String(), nullable=True),
sa.Column("zulip_topic", sa.String(), nullable=True),
sa.Column(
"is_locked", sa.Boolean(), server_default=sa.text("0"), nullable=False
"is_locked", sa.Boolean(), server_default=sa.text("false"), nullable=False
),
sa.Column("room_mode", sa.String(), server_default="normal", nullable=False),
sa.Column(

View File

@@ -20,11 +20,14 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
sourcekind_enum = sa.Enum("room", "live", "file", name="sourcekind")
sourcekind_enum.create(op.get_bind())
op.add_column(
"transcript",
sa.Column(
"source_kind",
sa.Enum("ROOM", "LIVE", "FILE", name="sourcekind"),
sourcekind_enum,
nullable=True,
),
)
@@ -43,6 +46,8 @@ def upgrade() -> None:
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("transcript", "source_kind")
sourcekind_enum = sa.Enum(name="sourcekind")
sourcekind_enum.drop(op.get_bind())
# ### end Alembic commands ###

View File

@@ -0,0 +1,106 @@
"""populate_webvtt_from_topics
Revision ID: 8120ebc75366
Revises: 116b2f287eab
Create Date: 2025-08-11 19:11:01.316947
"""
import json
from typing import Sequence, Union
from alembic import op
from sqlalchemy import text
# revision identifiers, used by Alembic.
revision: str = "8120ebc75366"
down_revision: Union[str, None] = "116b2f287eab"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def topics_to_webvtt(topics):
"""Convert topics list to WebVTT format string."""
if not topics:
return None
lines = ["WEBVTT", ""]
for topic in topics:
start_time = format_timestamp(topic.get("start"))
end_time = format_timestamp(topic.get("end"))
text = topic.get("text", "").strip()
if start_time and end_time and text:
lines.append(f"{start_time} --> {end_time}")
lines.append(text)
lines.append("")
return "\n".join(lines).strip()
def format_timestamp(seconds):
"""Format seconds to WebVTT timestamp format (HH:MM:SS.mmm)."""
if seconds is None:
return None
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
secs = seconds % 60
return f"{hours:02d}:{minutes:02d}:{secs:06.3f}"
def upgrade() -> None:
"""Populate WebVTT field for all transcripts with topics."""
# Get connection
connection = op.get_bind()
# Query all transcripts with topics
result = connection.execute(
text("SELECT id, topics FROM transcript WHERE topics IS NOT NULL")
)
rows = result.fetchall()
print(f"Found {len(rows)} transcripts with topics")
updated_count = 0
error_count = 0
for row in rows:
transcript_id = row[0]
topics_data = row[1]
if not topics_data:
continue
try:
# Parse JSON if it's a string
if isinstance(topics_data, str):
topics_data = json.loads(topics_data)
# Convert topics to WebVTT format
webvtt_content = topics_to_webvtt(topics_data)
if webvtt_content:
# Update the webvtt field
connection.execute(
text("UPDATE transcript SET webvtt = :webvtt WHERE id = :id"),
{"webvtt": webvtt_content, "id": transcript_id},
)
updated_count += 1
print(f"✓ Updated transcript {transcript_id}")
except Exception as e:
error_count += 1
print(f"✗ Error updating transcript {transcript_id}: {e}")
print(f"\nMigration complete!")
print(f" Updated: {updated_count}")
print(f" Errors: {error_count}")
def downgrade() -> None:
"""Clear WebVTT field for all transcripts."""
op.execute(text("UPDATE transcript SET webvtt = NULL"))

View File

@@ -22,7 +22,7 @@ def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.execute(
"UPDATE transcript SET events = "
'REPLACE(events, \'"event": "SUMMARY"\', \'"event": "LONG_SUMMARY"\');'
'REPLACE(events::text, \'"event": "SUMMARY"\', \'"event": "LONG_SUMMARY"\')::json;'
)
op.alter_column("transcript", "summary", new_column_name="long_summary")
op.add_column("transcript", sa.Column("title", sa.String(), nullable=True))
@@ -34,7 +34,7 @@ def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.execute(
"UPDATE transcript SET events = "
'REPLACE(events, \'"event": "LONG_SUMMARY"\', \'"event": "SUMMARY"\');'
'REPLACE(events::text, \'"event": "LONG_SUMMARY"\', \'"event": "SUMMARY"\')::json;'
)
with op.batch_alter_table("transcript", schema=None) as batch_op:
batch_op.alter_column("long_summary", nullable=True, new_column_name="summary")

View File

@@ -0,0 +1,121 @@
"""datetime timezone
Revision ID: 9f5c78d352d6
Revises: 8120ebc75366
Create Date: 2025-08-13 19:18:27.113593
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "9f5c78d352d6"
down_revision: Union[str, None] = "8120ebc75366"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("meeting", schema=None) as batch_op:
batch_op.alter_column(
"start_date",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=True,
)
batch_op.alter_column(
"end_date",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=True,
)
with op.batch_alter_table("meeting_consent", schema=None) as batch_op:
batch_op.alter_column(
"consent_timestamp",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
with op.batch_alter_table("recording", schema=None) as batch_op:
batch_op.alter_column(
"recorded_at",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
with op.batch_alter_table("room", schema=None) as batch_op:
batch_op.alter_column(
"created_at",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=False,
)
with op.batch_alter_table("transcript", schema=None) as batch_op:
batch_op.alter_column(
"created_at",
existing_type=postgresql.TIMESTAMP(),
type_=sa.DateTime(timezone=True),
existing_nullable=True,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("transcript", schema=None) as batch_op:
batch_op.alter_column(
"created_at",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=True,
)
with op.batch_alter_table("room", schema=None) as batch_op:
batch_op.alter_column(
"created_at",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
with op.batch_alter_table("recording", schema=None) as batch_op:
batch_op.alter_column(
"recorded_at",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
with op.batch_alter_table("meeting_consent", schema=None) as batch_op:
batch_op.alter_column(
"consent_timestamp",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=False,
)
with op.batch_alter_table("meeting", schema=None) as batch_op:
batch_op.alter_column(
"end_date",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=True,
)
batch_op.alter_column(
"start_date",
existing_type=sa.DateTime(timezone=True),
type_=postgresql.TIMESTAMP(),
existing_nullable=True,
)
# ### end Alembic commands ###

View File

@@ -25,7 +25,7 @@ def upgrade() -> None:
sa.Column(
"is_shared",
sa.Boolean(),
server_default=sa.text("0"),
server_default=sa.text("false"),
nullable=False,
),
)

View File

@@ -23,7 +23,10 @@ def upgrade() -> None:
with op.batch_alter_table("meeting", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"is_active", sa.Boolean(), server_default=sa.text("1"), nullable=False
"is_active",
sa.Boolean(),
server_default=sa.text("true"),
nullable=False,
)
)

View File

@@ -23,7 +23,7 @@ def upgrade() -> None:
op.add_column(
"transcript",
sa.Column(
"reviewed", sa.Boolean(), server_default=sa.text("0"), nullable=False
"reviewed", sa.Boolean(), server_default=sa.text("false"), nullable=False
),
)
# ### end Alembic commands ###

View File

@@ -0,0 +1,34 @@
"""add_grace_period_fields_to_meeting
Revision ID: d4a1c446458c
Revises: 6025e9b2bef2
Create Date: 2025-08-18 18:50:37.768052
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "d4a1c446458c"
down_revision: Union[str, None] = "6025e9b2bef2"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Add fields to track when participants left for grace period logic
op.add_column(
"meeting", sa.Column("last_participant_left_at", sa.DateTime(timezone=True))
)
op.add_column(
"meeting",
sa.Column("grace_period_minutes", sa.Integer, server_default=sa.text("15")),
)
def downgrade() -> None:
op.drop_column("meeting", "grace_period_minutes")
op.drop_column("meeting", "last_participant_left_at")

View File

@@ -34,12 +34,14 @@ dependencies = [
"python-multipart>=0.0.6",
"faster-whisper>=0.10.0",
"transformers>=4.36.2",
"black==24.1.1",
"jsonschema>=4.23.0",
"openai>=1.59.7",
"psycopg2-binary>=2.9.10",
"llama-index>=0.12.52",
"llama-index-llms-openai-like>=0.4.0",
"pytest-env>=1.1.5",
"webvtt-py>=0.5.0",
"icalendar>=6.0.0",
]
[dependency-groups]
@@ -56,6 +58,8 @@ tests = [
"httpx-ws>=0.4.1",
"pytest-httpx>=0.23.1",
"pytest-celery>=0.0.0",
"pytest-docker>=3.2.3",
"asgi-lifespan>=2.1.0",
]
aws = ["aioboto3>=11.2.0"]
evaluation = [
@@ -83,10 +87,25 @@ packages = ["reflector"]
[tool.coverage.run]
source = ["reflector"]
[tool.pytest_env]
ENVIRONMENT = "pytest"
DATABASE_URL = "postgresql://test_user:test_password@localhost:15432/reflector_test"
[tool.pytest.ini_options]
addopts = "-ra -q --disable-pytest-warnings --cov --cov-report html -v"
testpaths = ["tests"]
asyncio_mode = "auto"
[tool.ruff.lint]
select = [
"I", # isort - import sorting
"F401", # unused imports
"PLC0415", # import-outside-top-level - detect inline imports
]
[tool.ruff.lint.per-file-ignores]
"reflector/processors/summary/summary_builder.py" = ["E501"]
"gpu/**.py" = ["PLC0415"]
"reflector/tools/**.py" = ["PLC0415"]
"migrations/versions/**.py" = ["PLC0415"]
"tests/**.py" = ["PLC0415"]

View File

@@ -1,29 +1,48 @@
import contextvars
from typing import Optional
import databases
import sqlalchemy
from reflector.events import subscribers_shutdown, subscribers_startup
from reflector.settings import settings
database = databases.Database(settings.DATABASE_URL)
metadata = sqlalchemy.MetaData()
_database_context: contextvars.ContextVar[Optional[databases.Database]] = (
contextvars.ContextVar("database", default=None)
)
def get_database() -> databases.Database:
"""Get database instance for current asyncio context"""
db = _database_context.get()
if db is None:
db = databases.Database(settings.DATABASE_URL)
_database_context.set(db)
return db
# import models
import reflector.db.calendar_events # noqa
import reflector.db.meetings # noqa
import reflector.db.recordings # noqa
import reflector.db.rooms # noqa
import reflector.db.transcripts # noqa
kwargs = {}
if "sqlite" in settings.DATABASE_URL:
kwargs["connect_args"] = {"check_same_thread": False}
if "postgres" not in settings.DATABASE_URL:
raise Exception("Only postgres database is supported in reflector")
engine = sqlalchemy.create_engine(settings.DATABASE_URL, **kwargs)
@subscribers_startup.append
async def database_connect(_):
database = get_database()
await database.connect()
@subscribers_shutdown.append
async def database_disconnect(_):
database = get_database()
await database.disconnect()

View File

@@ -0,0 +1,193 @@
from datetime import datetime, timezone
from typing import Any
import sqlalchemy as sa
from pydantic import BaseModel, Field
from sqlalchemy.dialects.postgresql import JSONB
from reflector.db import get_database, metadata
from reflector.utils import generate_uuid4
calendar_events = sa.Table(
"calendar_event",
metadata,
sa.Column("id", sa.String, primary_key=True),
sa.Column(
"room_id",
sa.String,
sa.ForeignKey("room.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("ics_uid", sa.Text, nullable=False),
sa.Column("title", sa.Text),
sa.Column("description", sa.Text),
sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("end_time", sa.DateTime(timezone=True), nullable=False),
sa.Column("attendees", JSONB),
sa.Column("location", sa.Text),
sa.Column("ics_raw_data", sa.Text),
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=False),
sa.Column("is_deleted", sa.Boolean, nullable=False, server_default=sa.false()),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.UniqueConstraint("room_id", "ics_uid", name="uq_room_calendar_event"),
sa.Index("idx_calendar_event_room_start", "room_id", "start_time"),
sa.Index(
"idx_calendar_event_deleted",
"is_deleted",
postgresql_where=sa.text("NOT is_deleted"),
),
)
class CalendarEvent(BaseModel):
id: str = Field(default_factory=generate_uuid4)
room_id: str
ics_uid: str
title: str | None = None
description: str | None = None
start_time: datetime
end_time: datetime
attendees: list[dict[str, Any]] | None = None
location: str | None = None
ics_raw_data: str | None = None
last_synced: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
is_deleted: bool = False
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
class CalendarEventController:
async def get_by_room(
self,
room_id: str,
include_deleted: bool = False,
start_after: datetime | None = None,
end_before: datetime | None = None,
) -> list[CalendarEvent]:
"""Get calendar events for a room."""
query = calendar_events.select().where(calendar_events.c.room_id == room_id)
if not include_deleted:
query = query.where(calendar_events.c.is_deleted == False)
if start_after:
query = query.where(calendar_events.c.start_time >= start_after)
if end_before:
query = query.where(calendar_events.c.end_time <= end_before)
query = query.order_by(calendar_events.c.start_time.asc())
results = await get_database().fetch_all(query)
return [CalendarEvent(**result) for result in results]
async def get_upcoming(
self, room_id: str, minutes_ahead: int = 30
) -> list[CalendarEvent]:
"""Get upcoming events for a room within the specified minutes."""
now = datetime.now(timezone.utc)
future_time = now + timedelta(minutes=minutes_ahead)
query = (
calendar_events.select()
.where(
sa.and_(
calendar_events.c.room_id == room_id,
calendar_events.c.is_deleted == False,
calendar_events.c.start_time >= now,
calendar_events.c.start_time <= future_time,
)
)
.order_by(calendar_events.c.start_time.asc())
)
results = await get_database().fetch_all(query)
return [CalendarEvent(**result) for result in results]
async def get_by_ics_uid(self, room_id: str, ics_uid: str) -> CalendarEvent | None:
"""Get a calendar event by its ICS UID."""
query = calendar_events.select().where(
sa.and_(
calendar_events.c.room_id == room_id,
calendar_events.c.ics_uid == ics_uid,
)
)
result = await get_database().fetch_one(query)
return CalendarEvent(**result) if result else None
async def upsert(self, event: CalendarEvent) -> CalendarEvent:
"""Create or update a calendar event."""
existing = await self.get_by_ics_uid(event.room_id, event.ics_uid)
if existing:
# Update existing event
event.id = existing.id
event.created_at = existing.created_at
event.updated_at = datetime.now(timezone.utc)
query = (
calendar_events.update()
.where(calendar_events.c.id == existing.id)
.values(**event.model_dump())
)
else:
# Insert new event
query = calendar_events.insert().values(**event.model_dump())
await get_database().execute(query)
return event
async def soft_delete_missing(
self, room_id: str, current_ics_uids: list[str]
) -> int:
"""Soft delete future events that are no longer in the calendar."""
now = datetime.now(timezone.utc)
# First, get the IDs of events to delete
select_query = calendar_events.select().where(
sa.and_(
calendar_events.c.room_id == room_id,
calendar_events.c.start_time > now,
calendar_events.c.is_deleted == False,
calendar_events.c.ics_uid.notin_(current_ics_uids)
if current_ics_uids
else True,
)
)
to_delete = await get_database().fetch_all(select_query)
delete_count = len(to_delete)
if delete_count > 0:
# Now update them
update_query = (
calendar_events.update()
.where(
sa.and_(
calendar_events.c.room_id == room_id,
calendar_events.c.start_time > now,
calendar_events.c.is_deleted == False,
calendar_events.c.ics_uid.notin_(current_ics_uids)
if current_ics_uids
else True,
)
)
.values(is_deleted=True, updated_at=now)
)
await get_database().execute(update_query)
return delete_count
async def delete_by_room(self, room_id: str) -> int:
"""Hard delete all events for a room (used when room is deleted)."""
query = calendar_events.delete().where(calendar_events.c.room_id == room_id)
result = await get_database().execute(query)
return result.rowcount
# Add missing import
from datetime import timedelta
calendar_events_controller = CalendarEventController()

View File

@@ -1,11 +1,12 @@
from datetime import datetime
from typing import Literal
from typing import Any, Literal
import sqlalchemy as sa
from fastapi import HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.dialects.postgresql import JSONB
from reflector.db import database, metadata
from reflector.db import get_database, metadata
from reflector.db.rooms import Room
from reflector.utils import generate_uuid4
@@ -16,8 +17,8 @@ meetings = sa.Table(
sa.Column("room_name", sa.String),
sa.Column("room_url", sa.String),
sa.Column("host_room_url", sa.String),
sa.Column("start_date", sa.DateTime),
sa.Column("end_date", sa.DateTime),
sa.Column("start_date", sa.DateTime(timezone=True)),
sa.Column("end_date", sa.DateTime(timezone=True)),
sa.Column("user_id", sa.String),
sa.Column("room_id", sa.String),
sa.Column("is_locked", sa.Boolean, nullable=False, server_default=sa.false()),
@@ -41,7 +42,16 @@ meetings = sa.Table(
nullable=False,
server_default=sa.true(),
),
sa.Column(
"calendar_event_id",
sa.String,
sa.ForeignKey("calendar_event.id", ondelete="SET NULL"),
),
sa.Column("calendar_metadata", JSONB),
sa.Column("last_participant_left_at", sa.DateTime(timezone=True)),
sa.Column("grace_period_minutes", sa.Integer, server_default=sa.text("15")),
sa.Index("idx_meeting_room_id", "room_id"),
sa.Index("idx_meeting_calendar_event", "calendar_event_id"),
)
meeting_consent = sa.Table(
@@ -51,7 +61,7 @@ meeting_consent = sa.Table(
sa.Column("meeting_id", sa.String, sa.ForeignKey("meeting.id"), nullable=False),
sa.Column("user_id", sa.String),
sa.Column("consent_given", sa.Boolean, nullable=False),
sa.Column("consent_timestamp", sa.DateTime, nullable=False),
sa.Column("consent_timestamp", sa.DateTime(timezone=True), nullable=False),
)
@@ -79,6 +89,11 @@ class Meeting(BaseModel):
"none", "prompt", "automatic", "automatic-2nd-participant"
] = "automatic-2nd-participant"
num_clients: int = 0
is_active: bool = True
calendar_event_id: str | None = None
calendar_metadata: dict[str, Any] | None = None
last_participant_left_at: datetime | None = None
grace_period_minutes: int = 15
class MeetingController:
@@ -92,6 +107,8 @@ class MeetingController:
end_date: datetime,
user_id: str,
room: Room,
calendar_event_id: str | None = None,
calendar_metadata: dict[str, Any] | None = None,
):
"""
Create a new meeting
@@ -109,9 +126,11 @@ class MeetingController:
room_mode=room.room_mode,
recording_type=room.recording_type,
recording_trigger=room.recording_trigger,
calendar_event_id=calendar_event_id,
calendar_metadata=calendar_metadata,
)
query = meetings.insert().values(**meeting.model_dump())
await database.execute(query)
await get_database().execute(query)
return meeting
async def get_all_active(self) -> list[Meeting]:
@@ -119,7 +138,7 @@ class MeetingController:
Get active meetings.
"""
query = meetings.select().where(meetings.c.is_active)
return await database.fetch_all(query)
return await get_database().fetch_all(query)
async def get_by_room_name(
self,
@@ -129,7 +148,7 @@ class MeetingController:
Get a meeting by room name.
"""
query = meetings.select().where(meetings.c.room_name == room_name)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
return None
@@ -138,6 +157,7 @@ class MeetingController:
async def get_active(self, room: Room, current_time: datetime) -> Meeting:
"""
Get latest active meeting for a room.
For backward compatibility, returns the most recent active meeting.
"""
end_date = getattr(meetings.c, "end_date")
query = (
@@ -151,18 +171,59 @@ class MeetingController:
)
.order_by(end_date.desc())
)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
return None
return Meeting(**result)
async def get_all_active_for_room(
self, room: Room, current_time: datetime
) -> list[Meeting]:
"""
Get all active meetings for a room.
This supports multiple concurrent meetings per room.
"""
end_date = getattr(meetings.c, "end_date")
query = (
meetings.select()
.where(
sa.and_(
meetings.c.room_id == room.id,
meetings.c.end_date > current_time,
meetings.c.is_active,
)
)
.order_by(end_date.desc())
)
results = await get_database().fetch_all(query)
return [Meeting(**result) for result in results]
async def get_active_by_calendar_event(
self, room: Room, calendar_event_id: str, current_time: datetime
) -> Meeting | None:
"""
Get active meeting for a specific calendar event.
"""
query = meetings.select().where(
sa.and_(
meetings.c.room_id == room.id,
meetings.c.calendar_event_id == calendar_event_id,
meetings.c.end_date > current_time,
meetings.c.is_active,
)
)
result = await get_database().fetch_one(query)
if not result:
return None
return Meeting(**result)
async def get_by_id(self, meeting_id: str, **kwargs) -> Meeting | None:
"""
Get a meeting by id
"""
query = meetings.select().where(meetings.c.id == meeting_id)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
return None
return Meeting(**result)
@@ -174,7 +235,7 @@ class MeetingController:
If not found, it will raise a 404 error.
"""
query = meetings.select().where(meetings.c.id == meeting_id)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
raise HTTPException(status_code=404, detail="Meeting not found")
@@ -184,9 +245,18 @@ class MeetingController:
return meeting
async def get_by_calendar_event(self, calendar_event_id: str) -> Meeting | None:
query = meetings.select().where(
meetings.c.calendar_event_id == calendar_event_id
)
result = await get_database().fetch_one(query)
if not result:
return None
return Meeting(**result)
async def update_meeting(self, meeting_id: str, **kwargs):
query = meetings.update().where(meetings.c.id == meeting_id).values(**kwargs)
await database.execute(query)
await get_database().execute(query)
class MeetingConsentController:
@@ -194,7 +264,7 @@ class MeetingConsentController:
query = meeting_consent.select().where(
meeting_consent.c.meeting_id == meeting_id
)
results = await database.fetch_all(query)
results = await get_database().fetch_all(query)
return [MeetingConsent(**result) for result in results]
async def get_by_meeting_and_user(
@@ -205,7 +275,7 @@ class MeetingConsentController:
meeting_consent.c.meeting_id == meeting_id,
meeting_consent.c.user_id == user_id,
)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if result is None:
return None
return MeetingConsent(**result) if result else None
@@ -227,14 +297,14 @@ class MeetingConsentController:
consent_timestamp=consent.consent_timestamp,
)
)
await database.execute(query)
await get_database().execute(query)
existing.consent_given = consent.consent_given
existing.consent_timestamp = consent.consent_timestamp
return existing
query = meeting_consent.insert().values(**consent.model_dump())
await database.execute(query)
await get_database().execute(query)
return consent
async def has_any_denial(self, meeting_id: str) -> bool:
@@ -243,7 +313,7 @@ class MeetingConsentController:
meeting_consent.c.meeting_id == meeting_id,
meeting_consent.c.consent_given.is_(False),
)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
return result is not None

View File

@@ -4,7 +4,7 @@ from typing import Literal
import sqlalchemy as sa
from pydantic import BaseModel, Field
from reflector.db import database, metadata
from reflector.db import get_database, metadata
from reflector.utils import generate_uuid4
recordings = sa.Table(
@@ -13,7 +13,7 @@ recordings = sa.Table(
sa.Column("id", sa.String, primary_key=True),
sa.Column("bucket_name", sa.String, nullable=False),
sa.Column("object_key", sa.String, nullable=False),
sa.Column("recorded_at", sa.DateTime, nullable=False),
sa.Column("recorded_at", sa.DateTime(timezone=True), nullable=False),
sa.Column(
"status",
sa.String,
@@ -37,12 +37,12 @@ class Recording(BaseModel):
class RecordingController:
async def create(self, recording: Recording):
query = recordings.insert().values(**recording.model_dump())
await database.execute(query)
await get_database().execute(query)
return recording
async def get_by_id(self, id: str) -> Recording:
query = recordings.select().where(recordings.c.id == id)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
return Recording(**result) if result else None
async def get_by_object_key(self, bucket_name: str, object_key: str) -> Recording:
@@ -50,8 +50,12 @@ class RecordingController:
recordings.c.bucket_name == bucket_name,
recordings.c.object_key == object_key,
)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
return Recording(**result) if result else None
async def remove_by_id(self, id: str) -> None:
query = recordings.delete().where(recordings.c.id == id)
await get_database().execute(query)
recordings_controller = RecordingController()

View File

@@ -1,4 +1,4 @@
from datetime import datetime
from datetime import datetime, timezone
from sqlite3 import IntegrityError
from typing import Literal
@@ -7,7 +7,7 @@ from fastapi import HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.sql import false, or_
from reflector.db import database, metadata
from reflector.db import get_database, metadata
from reflector.utils import generate_uuid4
rooms = sqlalchemy.Table(
@@ -16,7 +16,7 @@ rooms = sqlalchemy.Table(
sqlalchemy.Column("id", sqlalchemy.String, primary_key=True),
sqlalchemy.Column("name", sqlalchemy.String, nullable=False, unique=True),
sqlalchemy.Column("user_id", sqlalchemy.String, nullable=False),
sqlalchemy.Column("created_at", sqlalchemy.DateTime, nullable=False),
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True), nullable=False),
sqlalchemy.Column(
"zulip_auto_post", sqlalchemy.Boolean, nullable=False, server_default=false()
),
@@ -40,7 +40,15 @@ rooms = sqlalchemy.Table(
sqlalchemy.Column(
"is_shared", sqlalchemy.Boolean, nullable=False, server_default=false()
),
sqlalchemy.Column("ics_url", sqlalchemy.Text),
sqlalchemy.Column("ics_fetch_interval", sqlalchemy.Integer, server_default="300"),
sqlalchemy.Column(
"ics_enabled", sqlalchemy.Boolean, nullable=False, server_default=false()
),
sqlalchemy.Column("ics_last_sync", sqlalchemy.DateTime(timezone=True)),
sqlalchemy.Column("ics_last_etag", sqlalchemy.Text),
sqlalchemy.Index("idx_room_is_shared", "is_shared"),
sqlalchemy.Index("idx_room_ics_enabled", "ics_enabled"),
)
@@ -48,7 +56,7 @@ class Room(BaseModel):
id: str = Field(default_factory=generate_uuid4)
name: str
user_id: str
created_at: datetime = Field(default_factory=datetime.utcnow)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
zulip_auto_post: bool = False
zulip_stream: str = ""
zulip_topic: str = ""
@@ -59,6 +67,11 @@ class Room(BaseModel):
"none", "prompt", "automatic", "automatic-2nd-participant"
] = "automatic-2nd-participant"
is_shared: bool = False
ics_url: str | None = None
ics_fetch_interval: int = 300
ics_enabled: bool = False
ics_last_sync: datetime | None = None
ics_last_etag: str | None = None
class RoomController:
@@ -92,7 +105,7 @@ class RoomController:
if return_query:
return query
results = await database.fetch_all(query)
results = await get_database().fetch_all(query)
return results
async def add(
@@ -107,6 +120,9 @@ class RoomController:
recording_type: str,
recording_trigger: str,
is_shared: bool,
ics_url: str | None = None,
ics_fetch_interval: int = 300,
ics_enabled: bool = False,
):
"""
Add a new room
@@ -122,10 +138,13 @@ class RoomController:
recording_type=recording_type,
recording_trigger=recording_trigger,
is_shared=is_shared,
ics_url=ics_url,
ics_fetch_interval=ics_fetch_interval,
ics_enabled=ics_enabled,
)
query = rooms.insert().values(**room.model_dump())
try:
await database.execute(query)
await get_database().execute(query)
except IntegrityError:
raise HTTPException(status_code=400, detail="Room name is not unique")
return room
@@ -136,7 +155,7 @@ class RoomController:
"""
query = rooms.update().where(rooms.c.id == room.id).values(**values)
try:
await database.execute(query)
await get_database().execute(query)
except IntegrityError:
raise HTTPException(status_code=400, detail="Room name is not unique")
@@ -151,7 +170,7 @@ class RoomController:
query = rooms.select().where(rooms.c.id == room_id)
if "user_id" in kwargs:
query = query.where(rooms.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
return None
return Room(**result)
@@ -163,7 +182,7 @@ class RoomController:
query = rooms.select().where(rooms.c.name == room_name)
if "user_id" in kwargs:
query = query.where(rooms.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
return None
return Room(**result)
@@ -175,7 +194,7 @@ class RoomController:
If not found, it will raise a 404 error.
"""
query = rooms.select().where(rooms.c.id == meeting_id)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
raise HTTPException(status_code=404, detail="Room not found")
@@ -197,7 +216,7 @@ class RoomController:
if user_id is not None and room.user_id != user_id:
return
query = rooms.delete().where(rooms.c.id == room_id)
await database.execute(query)
await get_database().execute(query)
rooms_controller = RoomController()

View File

@@ -0,0 +1,231 @@
"""Search functionality for transcripts and other entities."""
from datetime import datetime
from io import StringIO
from typing import Annotated, Any, Dict
import sqlalchemy
import webvtt
from pydantic import BaseModel, Field, constr, field_serializer
from reflector.db import get_database
from reflector.db.transcripts import SourceKind, transcripts
from reflector.db.utils import is_postgresql
from reflector.logger import logger
DEFAULT_SEARCH_LIMIT = 20
SNIPPET_CONTEXT_LENGTH = 50 # Characters before/after match to include
DEFAULT_SNIPPET_MAX_LENGTH = 150
DEFAULT_MAX_SNIPPETS = 3
SearchQueryBase = constr(min_length=1, strip_whitespace=True)
SearchLimitBase = Annotated[int, Field(ge=1, le=100)]
SearchOffsetBase = Annotated[int, Field(ge=0)]
SearchTotalBase = Annotated[int, Field(ge=0)]
SearchQuery = Annotated[SearchQueryBase, Field(description="Search query text")]
SearchLimit = Annotated[SearchLimitBase, Field(description="Results per page")]
SearchOffset = Annotated[
SearchOffsetBase, Field(description="Number of results to skip")
]
SearchTotal = Annotated[
SearchTotalBase, Field(description="Total number of search results")
]
class SearchParameters(BaseModel):
"""Validated search parameters for full-text search."""
query_text: SearchQuery
limit: SearchLimit = DEFAULT_SEARCH_LIMIT
offset: SearchOffset = 0
user_id: str | None = None
room_id: str | None = None
class SearchResultDB(BaseModel):
"""Intermediate model for validating raw database results."""
id: str = Field(..., min_length=1)
created_at: datetime
status: str = Field(..., min_length=1)
duration: float | None = Field(None, ge=0)
user_id: str | None = None
title: str | None = None
source_kind: SourceKind
room_id: str | None = None
rank: float = Field(..., ge=0, le=1)
class SearchResult(BaseModel):
"""Public search result model with computed fields."""
id: str = Field(..., min_length=1)
title: str | None = None
user_id: str | None = None
room_id: str | None = None
created_at: datetime
status: str = Field(..., min_length=1)
rank: float = Field(..., ge=0, le=1)
duration: float | None = Field(..., ge=0, description="Duration in seconds")
search_snippets: list[str] = Field(
description="Text snippets around search matches"
)
@field_serializer("created_at", when_used="json")
def serialize_datetime(self, dt: datetime) -> str:
if dt.tzinfo is None:
return dt.isoformat() + "Z"
return dt.isoformat()
class SearchController:
"""Controller for search operations across different entities."""
@staticmethod
def _extract_webvtt_text(webvtt_content: str) -> str:
"""Extract plain text from WebVTT content using webvtt library."""
if not webvtt_content:
return ""
try:
buffer = StringIO(webvtt_content)
vtt = webvtt.read_buffer(buffer)
return " ".join(caption.text for caption in vtt if caption.text)
except (webvtt.errors.MalformedFileError, UnicodeDecodeError, ValueError) as e:
logger.warning(f"Failed to parse WebVTT content: {e}", exc_info=e)
return ""
except AttributeError as e:
logger.warning(f"WebVTT parsing error - unexpected format: {e}", exc_info=e)
return ""
@staticmethod
def _generate_snippets(
text: str,
q: SearchQuery,
max_length: int = DEFAULT_SNIPPET_MAX_LENGTH,
max_snippets: int = DEFAULT_MAX_SNIPPETS,
) -> list[str]:
"""Generate multiple snippets around all occurrences of search term."""
if not text or not q:
return []
snippets = []
lower_text = text.lower()
search_lower = q.lower()
last_snippet_end = 0
start_pos = 0
while len(snippets) < max_snippets:
match_pos = lower_text.find(search_lower, start_pos)
if match_pos == -1:
if not snippets and search_lower.split():
first_word = search_lower.split()[0]
match_pos = lower_text.find(first_word, start_pos)
if match_pos == -1:
break
else:
break
snippet_start = max(0, match_pos - SNIPPET_CONTEXT_LENGTH)
snippet_end = min(
len(text), match_pos + max_length - SNIPPET_CONTEXT_LENGTH
)
if snippet_start < last_snippet_end:
start_pos = match_pos + len(search_lower)
continue
snippet = text[snippet_start:snippet_end]
if snippet_start > 0:
snippet = "..." + snippet
if snippet_end < len(text):
snippet = snippet + "..."
snippet = snippet.strip()
if snippet:
snippets.append(snippet)
last_snippet_end = snippet_end
start_pos = match_pos + len(search_lower)
if start_pos >= len(text):
break
return snippets
@classmethod
async def search_transcripts(
cls, params: SearchParameters
) -> tuple[list[SearchResult], int]:
"""
Full-text search for transcripts using PostgreSQL tsvector.
Returns (results, total_count).
"""
if not is_postgresql():
logger.warning(
"Full-text search requires PostgreSQL. Returning empty results."
)
return [], 0
search_query = sqlalchemy.func.websearch_to_tsquery(
"english", params.query_text
)
base_query = sqlalchemy.select(
[
transcripts.c.id,
transcripts.c.title,
transcripts.c.created_at,
transcripts.c.duration,
transcripts.c.status,
transcripts.c.user_id,
transcripts.c.room_id,
transcripts.c.source_kind,
transcripts.c.webvtt,
sqlalchemy.func.ts_rank(
transcripts.c.search_vector_en,
search_query,
32, # normalization flag: rank/(rank+1) for 0-1 range
).label("rank"),
]
).where(transcripts.c.search_vector_en.op("@@")(search_query))
if params.user_id:
base_query = base_query.where(transcripts.c.user_id == params.user_id)
if params.room_id:
base_query = base_query.where(transcripts.c.room_id == params.room_id)
query = (
base_query.order_by(sqlalchemy.desc(sqlalchemy.text("rank")))
.limit(params.limit)
.offset(params.offset)
)
rs = await get_database().fetch_all(query)
count_query = sqlalchemy.select([sqlalchemy.func.count()]).select_from(
base_query.alias("search_results")
)
total = await get_database().fetch_val(count_query)
def _process_result(r) -> SearchResult:
r_dict: Dict[str, Any] = dict(r)
webvtt: str | None = r_dict.pop("webvtt", None)
db_result = SearchResultDB.model_validate(r_dict)
snippets = []
if webvtt:
plain_text = cls._extract_webvtt_text(webvtt)
snippets = cls._generate_snippets(plain_text, params.query_text)
return SearchResult(**db_result.model_dump(), search_snippets=snippets)
results = [_process_result(r) for r in rs]
return results, total
search_controller = SearchController()

View File

@@ -3,7 +3,7 @@ import json
import os
import shutil
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Literal
@@ -11,13 +11,19 @@ import sqlalchemy
from fastapi import HTTPException
from pydantic import BaseModel, ConfigDict, Field, field_serializer
from sqlalchemy import Enum
from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.sql import false, or_
from reflector.db import database, metadata
from reflector.db import get_database, metadata
from reflector.db.recordings import recordings_controller
from reflector.db.rooms import rooms
from reflector.db.utils import is_postgresql
from reflector.logger import logger
from reflector.processors.types import Word as ProcessorWord
from reflector.settings import settings
from reflector.storage import get_transcripts_storage
from reflector.storage import get_recordings_storage, get_transcripts_storage
from reflector.utils import generate_uuid4
from reflector.utils.webvtt import topics_to_webvtt
class SourceKind(enum.StrEnum):
@@ -34,7 +40,7 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Column("status", sqlalchemy.String),
sqlalchemy.Column("locked", sqlalchemy.Boolean),
sqlalchemy.Column("duration", sqlalchemy.Float),
sqlalchemy.Column("created_at", sqlalchemy.DateTime),
sqlalchemy.Column("created_at", sqlalchemy.DateTime(timezone=True)),
sqlalchemy.Column("title", sqlalchemy.String),
sqlalchemy.Column("short_summary", sqlalchemy.String),
sqlalchemy.Column("long_summary", sqlalchemy.String),
@@ -76,6 +82,7 @@ transcripts = sqlalchemy.Table(
# same field could've been in recording/meeting, and it's maybe even ok to dupe it at need
sqlalchemy.Column("audio_deleted", sqlalchemy.Boolean),
sqlalchemy.Column("room_id", sqlalchemy.String),
sqlalchemy.Column("webvtt", sqlalchemy.Text),
sqlalchemy.Index("idx_transcript_recording_id", "recording_id"),
sqlalchemy.Index("idx_transcript_user_id", "user_id"),
sqlalchemy.Index("idx_transcript_created_at", "created_at"),
@@ -83,6 +90,29 @@ transcripts = sqlalchemy.Table(
sqlalchemy.Index("idx_transcript_room_id", "room_id"),
)
# Add PostgreSQL-specific full-text search column
# This matches the migration in migrations/versions/116b2f287eab_add_full_text_search.py
if is_postgresql():
transcripts.append_column(
sqlalchemy.Column(
"search_vector_en",
TSVECTOR,
sqlalchemy.Computed(
"setweight(to_tsvector('english', coalesce(title, '')), 'A') || "
"setweight(to_tsvector('english', coalesce(webvtt, '')), 'B')",
persisted=True,
),
)
)
# Add GIN index for the search vector
transcripts.append_constraint(
sqlalchemy.Index(
"idx_transcript_search_vector_en",
"search_vector_en",
postgresql_using="gin",
)
)
def generate_transcript_name() -> str:
now = datetime.now(timezone.utc)
@@ -147,14 +177,18 @@ class TranscriptParticipant(BaseModel):
class Transcript(BaseModel):
"""Full transcript model with all fields."""
id: str = Field(default_factory=generate_uuid4)
user_id: str | None = None
name: str = Field(default_factory=generate_transcript_name)
status: str = "idle"
locked: bool = False
duration: float = 0
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
title: str | None = None
source_kind: SourceKind
room_id: str | None = None
locked: bool = False
short_summary: str | None = None
long_summary: str | None = None
topics: list[TranscriptTopic] = []
@@ -168,9 +202,8 @@ class Transcript(BaseModel):
meeting_id: str | None = None
recording_id: str | None = None
zulip_message_id: int | None = None
source_kind: SourceKind
audio_deleted: bool | None = None
room_id: str | None = None
webvtt: str | None = None
@field_serializer("created_at", when_used="json")
def serialize_datetime(self, dt: datetime) -> str:
@@ -271,10 +304,12 @@ class Transcript(BaseModel):
# we need to create an url to be used for diarization
# we can't use the audio_mp3_filename because it's not accessible
# from the diarization processor
from datetime import timedelta
from reflector.app import app
from reflector.views.transcripts import create_access_token
# TODO don't import app in db
from reflector.app import app # noqa: PLC0415
# TODO a util + don''t import views in db
from reflector.views.transcripts import create_access_token # noqa: PLC0415
path = app.url_path_for(
"transcript_get_audio_mp3",
@@ -335,7 +370,6 @@ class TranscriptController:
- `room_id`: filter transcripts by room ID
- `search_term`: filter transcripts by search term
"""
from reflector.db.rooms import rooms
query = transcripts.select().join(
rooms, transcripts.c.room_id == rooms.c.id, isouter=True
@@ -386,7 +420,7 @@ class TranscriptController:
if return_query:
return query
results = await database.fetch_all(query)
results = await get_database().fetch_all(query)
return results
async def get_by_id(self, transcript_id: str, **kwargs) -> Transcript | None:
@@ -396,7 +430,7 @@ class TranscriptController:
query = transcripts.select().where(transcripts.c.id == transcript_id)
if "user_id" in kwargs:
query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
return None
return Transcript(**result)
@@ -410,7 +444,7 @@ class TranscriptController:
query = transcripts.select().where(transcripts.c.recording_id == recording_id)
if "user_id" in kwargs:
query = query.where(transcripts.c.user_id == kwargs["user_id"])
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
return None
return Transcript(**result)
@@ -428,7 +462,7 @@ class TranscriptController:
if order_by.startswith("-"):
field = field.desc()
query = query.order_by(field)
results = await database.fetch_all(query)
results = await get_database().fetch_all(query)
return [Transcript(**result) for result in results]
async def get_by_id_for_http(
@@ -446,7 +480,7 @@ class TranscriptController:
to determine if the user can access the transcript.
"""
query = transcripts.select().where(transcripts.c.id == transcript_id)
result = await database.fetch_one(query)
result = await get_database().fetch_one(query)
if not result:
raise HTTPException(status_code=404, detail="Transcript not found")
@@ -499,23 +533,52 @@ class TranscriptController:
room_id=room_id,
)
query = transcripts.insert().values(**transcript.model_dump())
await database.execute(query)
await get_database().execute(query)
return transcript
async def update(self, transcript: Transcript, values: dict, mutate=True):
# TODO investigate why mutate= is used. it's used in one place currently, maybe because of ORM field updates.
# using mutate=True is discouraged
async def update(
self, transcript: Transcript, values: dict, mutate=False
) -> Transcript:
"""
Update a transcript fields with key/values in values
Update a transcript fields with key/values in values.
Returns a copy of the transcript with updated values.
"""
values = TranscriptController._handle_topics_update(values)
query = (
transcripts.update()
.where(transcripts.c.id == transcript.id)
.values(**values)
)
await database.execute(query)
await get_database().execute(query)
if mutate:
for key, value in values.items():
setattr(transcript, key, value)
updated_transcript = transcript.model_copy(update=values)
return updated_transcript
@staticmethod
def _handle_topics_update(values: dict) -> dict:
"""Auto-update WebVTT when topics are updated."""
if values.get("webvtt") is not None:
logger.warn("trying to update read-only webvtt column")
pass
topics_data = values.get("topics")
if topics_data is None:
return values
return {
**values,
"webvtt": topics_to_webvtt(
[TranscriptTopic(**topic_dict) for topic_dict in topics_data]
),
}
async def remove_by_id(
self,
transcript_id: str,
@@ -529,23 +592,55 @@ class TranscriptController:
return
if user_id is not None and transcript.user_id != user_id:
return
if transcript.audio_location == "storage" and not transcript.audio_deleted:
try:
await get_transcripts_storage().delete_file(
transcript.storage_audio_path
)
except Exception as e:
logger.warning(
"Failed to delete transcript audio from storage",
exc_info=e,
transcript_id=transcript.id,
)
transcript.unlink()
if transcript.recording_id:
try:
recording = await recordings_controller.get_by_id(
transcript.recording_id
)
if recording:
try:
await get_recordings_storage().delete_file(recording.object_key)
except Exception as e:
logger.warning(
"Failed to delete recording object from S3",
exc_info=e,
recording_id=transcript.recording_id,
)
await recordings_controller.remove_by_id(transcript.recording_id)
except Exception as e:
logger.warning(
"Failed to delete recording row",
exc_info=e,
recording_id=transcript.recording_id,
)
query = transcripts.delete().where(transcripts.c.id == transcript_id)
await database.execute(query)
await get_database().execute(query)
async def remove_by_recording_id(self, recording_id: str):
"""
Remove a transcript by recording_id
"""
query = transcripts.delete().where(transcripts.c.recording_id == recording_id)
await database.execute(query)
await get_database().execute(query)
@asynccontextmanager
async def transaction(self):
"""
A context manager for database transaction
"""
async with database.transaction(isolation="serializable"):
async with get_database().transaction(isolation="serializable"):
yield
async def append_event(
@@ -558,11 +653,7 @@ class TranscriptController:
Append an event to a transcript
"""
resp = transcript.add_event(event=event, data=data)
await self.update(
transcript,
{"events": transcript.events_dump()},
mutate=False,
)
await self.update(transcript, {"events": transcript.events_dump()})
return resp
async def upsert_topic(
@@ -574,11 +665,7 @@ class TranscriptController:
Upsert topics to a transcript
"""
transcript.upsert_topic(topic)
await self.update(
transcript,
{"topics": transcript.topics_dump()},
mutate=False,
)
await self.update(transcript, {"topics": transcript.topics_dump()})
async def move_mp3_to_storage(self, transcript: Transcript):
"""
@@ -603,7 +690,8 @@ class TranscriptController:
)
# indicate on the transcript that the audio is now on storage
await self.update(transcript, {"audio_location": "storage"})
# mutates transcript argument
await self.update(transcript, {"audio_location": "storage"}, mutate=True)
# unlink the local file
transcript.audio_mp3_filename.unlink(missing_ok=True)
@@ -627,11 +715,7 @@ class TranscriptController:
Add/update a participant to a transcript
"""
result = transcript.upsert_participant(participant)
await self.update(
transcript,
{"participants": transcript.participants_dump()},
mutate=False,
)
await self.update(transcript, {"participants": transcript.participants_dump()})
return result
async def delete_participant(
@@ -643,11 +727,7 @@ class TranscriptController:
Delete a participant from a transcript
"""
transcript.delete_participant(participant_id)
await self.update(
transcript,
{"participants": transcript.participants_dump()},
mutate=False,
)
await self.update(transcript, {"participants": transcript.participants_dump()})
transcripts_controller = TranscriptController()

View File

@@ -0,0 +1,9 @@
"""Database utility functions."""
from reflector.db import get_database
def is_postgresql() -> bool:
return get_database().url.scheme and get_database().url.scheme.startswith(
"postgresql"
)

83
server/reflector/llm.py Normal file
View File

@@ -0,0 +1,83 @@
from typing import Type, TypeVar
from llama_index.core import Settings
from llama_index.core.output_parsers import PydanticOutputParser
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.llms.openai_like import OpenAILike
from pydantic import BaseModel
T = TypeVar("T", bound=BaseModel)
STRUCTURED_RESPONSE_PROMPT_TEMPLATE = """
Based on the following analysis, provide the information in the requested JSON format:
Analysis:
{analysis}
{format_instructions}
"""
class LLM:
def __init__(self, settings, temperature: float = 0.4, max_tokens: int = 2048):
self.settings_obj = settings
self.model_name = settings.LLM_MODEL
self.url = settings.LLM_URL
self.api_key = settings.LLM_API_KEY
self.context_window = settings.LLM_CONTEXT_WINDOW
self.temperature = temperature
self.max_tokens = max_tokens
# Configure llamaindex Settings
self._configure_llamaindex()
def _configure_llamaindex(self):
"""Configure llamaindex Settings with OpenAILike LLM"""
Settings.llm = OpenAILike(
model=self.model_name,
api_base=self.url,
api_key=self.api_key,
context_window=self.context_window,
is_chat_model=True,
is_function_calling_model=False,
temperature=self.temperature,
max_tokens=self.max_tokens,
)
async def get_response(
self, prompt: str, texts: list[str], tone_name: str | None = None
) -> str:
"""Get a text response using TreeSummarize for non-function-calling models"""
summarizer = TreeSummarize(verbose=False)
response = await summarizer.aget_response(prompt, texts, tone_name=tone_name)
return str(response).strip()
async def get_structured_response(
self,
prompt: str,
texts: list[str],
output_cls: Type[T],
tone_name: str | None = None,
) -> T:
"""Get structured output from LLM for non-function-calling models"""
summarizer = TreeSummarize(verbose=True)
response = await summarizer.aget_response(prompt, texts, tone_name=tone_name)
output_parser = PydanticOutputParser(output_cls)
program = LLMTextCompletionProgram.from_defaults(
output_parser=output_parser,
prompt_template_str=STRUCTURED_RESPONSE_PROMPT_TEMPLATE,
verbose=False,
)
format_instructions = output_parser.format(
"Please structure the above information in the following JSON format:"
)
output = await program.acall(
analysis=str(response), format_instructions=format_instructions
)
return output

View File

@@ -1,2 +0,0 @@
from .base import LLM # noqa: F401
from .llm_params import LLMTaskParams # noqa: F401

View File

@@ -1,347 +0,0 @@
import importlib
import json
import re
from typing import TypeVar
import nltk
from prometheus_client import Counter, Histogram
from transformers import GenerationConfig
from reflector.llm.llm_params import TaskParams
from reflector.logger import logger as reflector_logger
from reflector.settings import settings
from reflector.utils.retry import retry
T = TypeVar("T", bound="LLM")
class LLM:
_nltk_downloaded = False
_registry = {}
model_name: str
m_generate = Histogram(
"llm_generate",
"Time spent in LLM.generate",
["backend"],
)
m_generate_call = Counter(
"llm_generate_call",
"Number of calls to LLM.generate",
["backend"],
)
m_generate_success = Counter(
"llm_generate_success",
"Number of successful calls to LLM.generate",
["backend"],
)
m_generate_failure = Counter(
"llm_generate_failure",
"Number of failed calls to LLM.generate",
["backend"],
)
@classmethod
def ensure_nltk(cls):
"""
Make sure NLTK package is installed. Searches in the cache and
downloads only if needed.
"""
if not cls._nltk_downloaded:
nltk.download("punkt_tab")
# For POS tagging
nltk.download("averaged_perceptron_tagger_eng")
cls._nltk_downloaded = True
@classmethod
def register(cls, name, klass):
cls._registry[name] = klass
@classmethod
def get_instance(cls, model_name: str | None = None, name: str = None) -> T:
"""
Return an instance depending on the settings.
Settings used:
- `LLM_BACKEND`: key of the backend
- `LLM_URL`: url of the backend
"""
if name is None:
name = settings.LLM_BACKEND
if name not in cls._registry:
module_name = f"reflector.llm.llm_{name}"
importlib.import_module(module_name)
cls.ensure_nltk()
return cls._registry[name](model_name)
def get_model_name(self) -> str:
"""
Get the currently set model name
"""
return self._get_model_name()
def _get_model_name(self) -> str:
pass
def set_model_name(self, model_name: str) -> bool:
"""
Update the model name with the provided model name
"""
return self._set_model_name(model_name)
def _set_model_name(self, model_name: str) -> bool:
raise NotImplementedError
@property
def template(self) -> str:
"""
Return the LLM Prompt template
"""
return """
### Human:
{instruct}
{text}
### Assistant:
"""
def __init__(self):
name = self.__class__.__name__
self.m_generate = self.m_generate.labels(name)
self.m_generate_call = self.m_generate_call.labels(name)
self.m_generate_success = self.m_generate_success.labels(name)
self.m_generate_failure = self.m_generate_failure.labels(name)
self.detokenizer = nltk.tokenize.treebank.TreebankWordDetokenizer()
@property
def tokenizer(self):
"""
Return the tokenizer instance used by LLM
"""
return self._get_tokenizer()
def _get_tokenizer(self):
pass
def has_structured_output(self):
# whether implementation supports structured output
# on the model side (otherwise it's prompt engineering)
return False
async def generate(
self,
prompt: str,
logger: reflector_logger,
gen_schema: dict | None = None,
gen_cfg: GenerationConfig | None = None,
**kwargs,
) -> dict:
logger.info("LLM generate", prompt=repr(prompt))
if gen_cfg:
gen_cfg = gen_cfg.to_dict()
self.m_generate_call.inc()
try:
with self.m_generate.time():
result = await retry(self._generate)(
prompt=prompt,
gen_schema=gen_schema,
gen_cfg=gen_cfg,
logger=logger,
**kwargs,
)
self.m_generate_success.inc()
except Exception:
logger.exception("Failed to call llm after retrying")
self.m_generate_failure.inc()
raise
logger.debug("LLM result [raw]", result=repr(result))
if isinstance(result, str):
result = self._parse_json(result)
logger.debug("LLM result [parsed]", result=repr(result))
return result
async def completion(
self, messages: list, logger: reflector_logger, **kwargs
) -> dict:
"""
Use /v1/chat/completion Open-AI compatible endpoint from the URL
It's up to the user to validate anything or transform the result
"""
logger.info("LLM completions", messages=messages)
try:
with self.m_generate.time():
result = await retry(self._completion)(
messages=messages, **{**kwargs, "logger": logger}
)
self.m_generate_success.inc()
except Exception:
logger.exception("Failed to call llm after retrying")
self.m_generate_failure.inc()
raise
logger.debug("LLM completion result", result=repr(result))
return result
def ensure_casing(self, title: str) -> str:
"""
LLM takes care of word casing, but in rare cases this
can falter. This is a fallback to ensure the casing of
topics is in a proper format.
We select nouns, verbs and adjectives and check if camel
casing is present and fix it, if not. Will not perform
any other changes.
"""
tokens = nltk.word_tokenize(title)
pos_tags = nltk.pos_tag(tokens)
camel_cased = []
whitelisted_pos_tags = [
"NN",
"NNS",
"NNP",
"NNPS", # Noun POS
"VB",
"VBD",
"VBG",
"VBN",
"VBP",
"VBZ", # Verb POS
"JJ",
"JJR",
"JJS", # Adjective POS
]
# If at all there is an exception, do not block other reflector
# processes. Return the LLM generated title, at the least.
try:
for word, pos in pos_tags:
if pos in whitelisted_pos_tags and word[0].islower():
camel_cased.append(word[0].upper() + word[1:])
else:
camel_cased.append(word)
modified_title = self.detokenizer.detokenize(camel_cased)
# Irrespective of casing changes, the starting letter
# of title is always upper-cased
title = modified_title[0].upper() + modified_title[1:]
except Exception as e:
reflector_logger.info(
f"Failed to ensure casing on {title=} with exception : {str(e)}"
)
return title
def trim_title(self, title: str) -> str:
"""
List of manual trimming to the title.
Longer titles are prone to run into A prefix of phrases that don't
really add any descriptive information and in some cases, this
behaviour can be repeated for several consecutive topics. Trim the
titles to maintain quality of titles.
"""
phrases_to_remove = ["Discussing", "Discussion on", "Discussion about"]
try:
pattern = (
r"\b(?:"
+ "|".join(re.escape(phrase) for phrase in phrases_to_remove)
+ r")\b"
)
title = re.sub(pattern, "", title, flags=re.IGNORECASE)
except Exception as e:
reflector_logger.info(f"Failed to trim {title=} with exception : {str(e)}")
return title
async def _generate(
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
) -> str:
raise NotImplementedError
async def _completion(self, messages: list, **kwargs) -> dict:
raise NotImplementedError
def _parse_json(self, result: str) -> dict:
result = result.strip()
# try detecting code block if exist
# starts with ```json\n, ends with ```
# or starts with ```\n, ends with ```
# or starts with \n```javascript\n, ends with ```
regex = r"```(json|javascript|)?(.*)```"
matches = re.findall(regex, result.strip(), re.MULTILINE | re.DOTALL)
if matches:
result = matches[0][1]
else:
# maybe the prompt has been started with ```json
# so if text ends with ```, just remove it and use it as json
if result.endswith("```"):
result = result[:-3]
return json.loads(result.strip())
def text_token_threshold(self, task_params: TaskParams | None) -> int:
"""
Choose the token size to set as the threshold to pack the LLM calls
"""
buffer_token_size = 100
default_output_tokens = 1000
context_window = self.tokenizer.model_max_length
tokens = self.tokenizer.tokenize(
self.create_prompt(instruct=task_params.instruct, text="")
)
threshold = context_window - len(tokens) - buffer_token_size
if task_params.gen_cfg:
threshold -= task_params.gen_cfg.max_new_tokens
else:
threshold -= default_output_tokens
return threshold
def split_corpus(
self,
corpus: str,
task_params: TaskParams,
token_threshold: int | None = None,
) -> list[str]:
"""
Split the input to the LLM due to CUDA memory limitations and LLM context window
restrictions.
Accumulate tokens from full sentences till threshold and yield accumulated
tokens. Reset accumulation when threshold is reached and repeat process.
"""
if not token_threshold:
token_threshold = self.text_token_threshold(task_params=task_params)
accumulated_tokens = []
accumulated_sentences = []
accumulated_token_count = 0
corpus_sentences = nltk.sent_tokenize(corpus)
for sentence in corpus_sentences:
tokens = self.tokenizer.tokenize(sentence)
if accumulated_token_count + len(tokens) <= token_threshold:
accumulated_token_count += len(tokens)
accumulated_tokens.extend(tokens)
accumulated_sentences.append(sentence)
else:
yield "".join(accumulated_sentences)
accumulated_token_count = len(tokens)
accumulated_tokens = tokens
accumulated_sentences = [sentence]
if accumulated_tokens:
yield " ".join(accumulated_sentences)
def create_prompt(self, instruct: str, text: str) -> str:
"""
Create a consumable prompt based on the prompt template
"""
return self.template.format(instruct=instruct, text=text)

View File

@@ -1,155 +0,0 @@
import httpx
from transformers import AutoTokenizer, GenerationConfig
from reflector.llm.base import LLM
from reflector.logger import logger as reflector_logger
from reflector.settings import settings
from reflector.utils.retry import retry
class ModalLLM(LLM):
def __init__(self, model_name: str | None = None):
super().__init__()
self.timeout = settings.LLM_TIMEOUT
self.llm_url = settings.LLM_URL + "/llm"
self.headers = {
"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
}
self._set_model_name(model_name if model_name else settings.DEFAULT_LLM)
@property
def supported_models(self):
"""
List of currently supported models on this GPU platform
"""
# TODO: Query the specific GPU platform
# Replace this with a HTTP call
return [
"lmsys/vicuna-13b-v1.5",
"HuggingFaceH4/zephyr-7b-alpha",
"NousResearch/Hermes-3-Llama-3.1-8B",
]
async def _generate(
self, prompt: str, gen_schema: dict | None, gen_cfg: dict | None, **kwargs
) -> str:
json_payload = {"prompt": prompt}
if gen_schema:
json_payload["gen_schema"] = gen_schema
if gen_cfg:
json_payload["gen_cfg"] = gen_cfg
# Handing over generation of the final summary to Zephyr model
# but replacing the Vicuna model will happen after more testing
# TODO: Create a mapping of model names and cloud deployments
if self.model_name == "HuggingFaceH4/zephyr-7b-alpha":
self.llm_url = settings.ZEPHYR_LLM_URL + "/llm"
async with httpx.AsyncClient() as client:
response = await retry(client.post)(
self.llm_url,
headers=self.headers,
json=json_payload,
timeout=self.timeout,
retry_timeout=60 * 5,
follow_redirects=True,
logger=kwargs.get("logger", reflector_logger),
)
response.raise_for_status()
text = response.json()["text"]
return text
async def _completion(self, messages: list, **kwargs) -> dict:
# returns full api response
kwargs.setdefault("temperature", 0.3)
kwargs.setdefault("max_tokens", 2048)
kwargs.setdefault("stream", False)
kwargs.setdefault("repetition_penalty", 1)
kwargs.setdefault("top_p", 1)
kwargs.setdefault("top_k", -1)
kwargs.setdefault("min_p", 0.05)
data = {"messages": messages, "model": self.model_name, **kwargs}
if self.model_name == "NousResearch/Hermes-3-Llama-3.1-8B":
self.llm_url = settings.HERMES_3_8B_LLM_URL + "/v1/chat/completions"
async with httpx.AsyncClient() as client:
response = await retry(client.post)(
self.llm_url,
headers=self.headers,
json=data,
timeout=self.timeout,
retry_timeout=60 * 5,
follow_redirects=True,
logger=kwargs.get("logger", reflector_logger),
)
response.raise_for_status()
return response.json()
def _set_model_name(self, model_name: str) -> bool:
"""
Set the model name
"""
# Abort, if the model is not supported
if model_name not in self.supported_models:
reflector_logger.info(
f"Attempted to change {model_name=}, but is not supported."
f"Setting model and tokenizer failed !"
)
return False
# Abort, if the model is already set
elif hasattr(self, "model_name") and model_name == self._get_model_name():
reflector_logger.info("No change in model. Setting model skipped.")
return False
# Update model name and tokenizer
self.model_name = model_name
self.llm_tokenizer = AutoTokenizer.from_pretrained(
self.model_name, cache_dir=settings.CACHE_DIR
)
reflector_logger.info(f"Model set to {model_name=}. Tokenizer updated.")
return True
def _get_tokenizer(self) -> AutoTokenizer:
"""
Return the currently used LLM tokenizer
"""
return self.llm_tokenizer
def _get_model_name(self) -> str:
"""
Return the current model name from the instance details
"""
return self.model_name
LLM.register("modal", ModalLLM)
if __name__ == "__main__":
from reflector.logger import logger
async def main():
llm = ModalLLM()
prompt = llm.create_prompt(
instruct="Complete the following task",
text="Tell me a joke about programming.",
)
result = await llm.generate(prompt=prompt, logger=logger)
print(result)
gen_schema = {
"type": "object",
"properties": {"response": {"type": "string"}},
}
result = await llm.generate(prompt=prompt, gen_schema=gen_schema, logger=logger)
print(result)
gen_cfg = GenerationConfig(max_new_tokens=150)
result = await llm.generate(
prompt=prompt, gen_cfg=gen_cfg, gen_schema=gen_schema, logger=logger
)
print(result)
import asyncio
asyncio.run(main())

View File

@@ -1,48 +0,0 @@
import httpx
from transformers import GenerationConfig
from reflector.llm.base import LLM
from reflector.logger import logger
from reflector.settings import settings
class OpenAILLM(LLM):
def __init__(self, model_name: str | None = None, **kwargs):
super().__init__(**kwargs)
self.openai_key = settings.LLM_OPENAI_KEY
self.openai_url = settings.LLM_URL
self.openai_model = settings.LLM_OPENAI_MODEL
self.openai_temperature = settings.LLM_OPENAI_TEMPERATURE
self.timeout = settings.LLM_TIMEOUT
self.max_tokens = settings.LLM_MAX_TOKENS
logger.info(f"LLM use openai backend at {self.openai_url}")
async def _generate(
self,
prompt: str,
gen_schema: dict | None,
gen_cfg: GenerationConfig | None,
**kwargs,
) -> str:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.openai_key}",
}
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
self.openai_url,
headers=headers,
json={
"model": self.openai_model,
"prompt": prompt,
"max_tokens": self.max_tokens,
"temperature": self.openai_temperature,
},
)
response.raise_for_status()
result = response.json()
return result["choices"][0]["text"]
LLM.register("openai", OpenAILLM)

View File

@@ -1,219 +0,0 @@
from typing import Optional, TypeVar
from pydantic import BaseModel
from transformers import GenerationConfig
class TaskParams(BaseModel, arbitrary_types_allowed=True):
instruct: str
gen_cfg: Optional[GenerationConfig] = None
gen_schema: Optional[dict] = None
T = TypeVar("T", bound="LLMTaskParams")
class LLMTaskParams:
_registry = {}
@classmethod
def register(cls, task, klass) -> None:
cls._registry[task] = klass
@classmethod
def get_instance(cls, task: str) -> T:
return cls._registry[task]()
@property
def task_params(self) -> TaskParams | None:
"""
Fetch the task related parameters
"""
return self._get_task_params()
def _get_task_params(self) -> None:
pass
class FinalLongSummaryParams(LLMTaskParams):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._gen_cfg = GenerationConfig(
max_new_tokens=1000, num_beams=3, do_sample=True, temperature=0.3
)
self._instruct = """
Take the key ideas and takeaways from the text and create a short
summary. Be sure to keep the length of the response to a minimum.
Do not include trivial information in the summary.
"""
self._schema = {
"type": "object",
"properties": {"long_summary": {"type": "string"}},
}
self._task_params = TaskParams(
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
)
def _get_task_params(self) -> TaskParams:
"""gen_schema
Return the parameters associated with a specific LLM task
"""
return self._task_params
class FinalShortSummaryParams(LLMTaskParams):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._gen_cfg = GenerationConfig(
max_new_tokens=800, num_beams=3, do_sample=True, temperature=0.3
)
self._instruct = """
Take the key ideas and takeaways from the text and create a short
summary. Be sure to keep the length of the response to a minimum.
Do not include trivial information in the summary.
"""
self._schema = {
"type": "object",
"properties": {"short_summary": {"type": "string"}},
}
self._task_params = TaskParams(
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
)
def _get_task_params(self) -> TaskParams:
"""
Return the parameters associated with a specific LLM task
"""
return self._task_params
class FinalTitleParams(LLMTaskParams):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._gen_cfg = GenerationConfig(
max_new_tokens=200, num_beams=5, do_sample=True, temperature=0.5
)
self._instruct = """
Combine the following individual titles into one single short title that
condenses the essence of all titles.
"""
self._schema = {
"type": "object",
"properties": {"title": {"type": "string"}},
}
self._task_params = TaskParams(
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
)
def _get_task_params(self) -> TaskParams:
"""
Return the parameters associated with a specific LLM task
"""
return self._task_params
class TopicParams(LLMTaskParams):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._gen_cfg = GenerationConfig(
max_new_tokens=500, num_beams=6, do_sample=True, temperature=0.9
)
self._instruct = """
Create a JSON object as response.The JSON object must have 2 fields:
i) title and ii) summary.
For the title field, generate a very detailed and self-explanatory
title for the given text. Let the title be as descriptive as possible.
For the summary field, summarize the given text in a maximum of
two sentences.
"""
self._schema = {
"type": "object",
"properties": {
"title": {"type": "string"},
"summary": {"type": "string"},
},
}
self._task_params = TaskParams(
instruct=self._instruct, gen_schema=self._schema, gen_cfg=self._gen_cfg
)
def _get_task_params(self) -> TaskParams:
"""
Return the parameters associated with a specific LLM task
"""
return self._task_params
class BulletedSummaryParams(LLMTaskParams):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._gen_cfg = GenerationConfig(
max_new_tokens=800,
num_beams=1,
do_sample=True,
temperature=0.2,
early_stopping=True,
)
self._instruct = """
Given a meeting transcript, extract the key things discussed in the
form of a list.
While generating the response, follow the constraints mentioned below.
Summary constraints:
i) Do not add new content, except to fix spelling or punctuation.
ii) Do not add any prefixes or numbering in the response.
iii) The summarization should be as information dense as possible.
iv) Do not add any additional sections like Note, Conclusion, etc. in
the response.
Response format:
i) The response should be in the form of a bulleted list.
ii) Iteratively merge all the relevant paragraphs together to keep the
number of paragraphs to a minimum.
iii) Remove any unfinished sentences from the final response.
iv) Do not include narrative or reporting clauses.
v) Use "*" as the bullet icon.
"""
self._task_params = TaskParams(
instruct=self._instruct, gen_schema=None, gen_cfg=self._gen_cfg
)
def _get_task_params(self) -> TaskParams:
"""gen_schema
Return the parameters associated with a specific LLM task
"""
return self._task_params
class MergedSummaryParams(LLMTaskParams):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._gen_cfg = GenerationConfig(
max_new_tokens=600,
num_beams=1,
do_sample=True,
temperature=0.2,
early_stopping=True,
)
self._instruct = """
Given the key points of a meeting, summarize the points to describe the
meeting in the form of paragraphs.
"""
self._task_params = TaskParams(
instruct=self._instruct, gen_schema=None, gen_cfg=self._gen_cfg
)
def _get_task_params(self) -> TaskParams:
"""gen_schema
Return the parameters associated with a specific LLM task
"""
return self._task_params
LLMTaskParams.register("topic", TopicParams)
LLMTaskParams.register("final_title", FinalTitleParams)
LLMTaskParams.register("final_short_summary", FinalShortSummaryParams)
LLMTaskParams.register("final_long_summary", FinalLongSummaryParams)
LLMTaskParams.register("bullet_summary", BulletedSummaryParams)
LLMTaskParams.register("merged_summary", MergedSummaryParams)

View File

@@ -1,118 +0,0 @@
import httpx
from transformers import AutoTokenizer
from reflector.logger import logger
def apply_gen_config(payload: dict, gen_cfg) -> None:
"""Apply generation config overrides to the payload."""
config_mapping = {
"temperature": "temperature",
"max_new_tokens": "max_tokens",
"max_tokens": "max_tokens",
"top_p": "top_p",
"frequency_penalty": "frequency_penalty",
"presence_penalty": "presence_penalty",
}
for cfg_attr, payload_key in config_mapping.items():
value = getattr(gen_cfg, cfg_attr, None)
if value is not None:
payload[payload_key] = value
if cfg_attr == "max_new_tokens": # Handle max_new_tokens taking precedence
break
class OpenAILLM:
def __init__(self, config_prefix: str, settings):
self.config_prefix = config_prefix
self.settings_obj = settings
self.model_name = getattr(settings, f"{config_prefix}_MODEL")
self.url = getattr(settings, f"{config_prefix}_LLM_URL")
self.api_key = getattr(settings, f"{config_prefix}_LLM_API_KEY")
timeout = getattr(settings, f"{config_prefix}_LLM_TIMEOUT", 300)
self.temperature = getattr(settings, f"{config_prefix}_LLM_TEMPERATURE", 0.7)
self.max_tokens = getattr(settings, f"{config_prefix}_LLM_MAX_TOKENS", 1024)
self.client = httpx.AsyncClient(timeout=timeout)
# Use a tokenizer that approximates OpenAI token counting
tokenizer_name = getattr(settings, f"{config_prefix}_TOKENIZER", "gpt2")
try:
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
except Exception:
logger.debug(
f"Failed to load tokenizer '{tokenizer_name}', falling back to default 'gpt2' tokenizer"
)
self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
async def generate(
self, prompt: str, gen_schema=None, gen_cfg=None, logger=None
) -> str:
if logger:
logger.debug(
"OpenAI LLM generate",
prompt=repr(prompt[:100] + "..." if len(prompt) > 100 else prompt),
)
messages = [{"role": "user", "content": prompt}]
result = await self.completion(
messages, gen_schema=gen_schema, gen_cfg=gen_cfg, logger=logger
)
return result["choices"][0]["message"]["content"]
async def completion(
self, messages: list, gen_schema=None, gen_cfg=None, logger=None, **kwargs
) -> dict:
if logger:
logger.info("OpenAI LLM completion", messages_count=len(messages))
payload = {
"model": self.model_name,
"messages": messages,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
}
# Apply generation config overrides
if gen_cfg:
apply_gen_config(payload, gen_cfg)
# Apply structured output schema
if gen_schema:
payload["response_format"] = {
"type": "json_schema",
"json_schema": {"name": "response", "schema": gen_schema},
}
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
url = f"{self.url.rstrip('/')}/chat/completions"
if logger:
logger.debug(
"OpenAI API request", url=url, payload_keys=list(payload.keys())
)
response = await self.client.post(url, json=payload, headers=headers)
response.raise_for_status()
result = response.json()
if logger:
logger.debug(
"OpenAI API response",
status_code=response.status_code,
choices_count=len(result.get("choices", [])),
)
return result
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.client.aclose()

View File

@@ -14,12 +14,15 @@ It is directly linked to our data model.
import asyncio
import functools
from contextlib import asynccontextmanager
from typing import Generic
import av
import boto3
from celery import chord, current_task, group, shared_task
from pydantic import BaseModel
from structlog import BoundLogger as Logger
from reflector.db import get_database
from reflector.db.meetings import meeting_consent_controller, meetings_controller
from reflector.db.recordings import recordings_controller
from reflector.db.rooms import rooms_controller
@@ -35,7 +38,7 @@ from reflector.db.transcripts import (
transcripts_controller,
)
from reflector.logger import logger
from reflector.pipelines.runner import PipelineRunner
from reflector.pipelines.runner import PipelineMessage, PipelineRunner
from reflector.processors import (
AudioChunkerProcessor,
AudioDiarizationAutoProcessor,
@@ -47,7 +50,7 @@ from reflector.processors import (
TranscriptFinalTitleProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
TranscriptTranslatorAutoProcessor,
)
from reflector.processors.audio_waveform_processor import AudioWaveformProcessor
from reflector.processors.types import AudioDiarizationInput
@@ -69,8 +72,7 @@ def asynctask(f):
@functools.wraps(f)
def wrapper(*args, **kwargs):
async def run_with_db():
from reflector.db import database
database = get_database()
await database.connect()
try:
return await f(*args, **kwargs)
@@ -144,7 +146,7 @@ class StrValue(BaseModel):
value: str
class PipelineMainBase(PipelineRunner):
class PipelineMainBase(PipelineRunner[PipelineMessage], Generic[PipelineMessage]):
transcript_id: str
ws_room_id: str | None = None
ws_manager: WebsocketManager | None = None
@@ -164,7 +166,11 @@ class PipelineMainBase(PipelineRunner):
raise Exception("Transcript not found")
return result
def get_transcript_topics(self, transcript: Transcript) -> list[TranscriptTopic]:
@staticmethod
def wrap_transcript_topics(
topics: list[TranscriptTopic],
) -> list[TitleSummaryWithIdProcessorType]:
# transformation to a pipe-supported format
return [
TitleSummaryWithIdProcessorType(
id=topic.id,
@@ -174,7 +180,7 @@ class PipelineMainBase(PipelineRunner):
duration=topic.duration,
transcript=TranscriptProcessorType(words=topic.words),
)
for topic in transcript.topics
for topic in topics
]
@asynccontextmanager
@@ -361,7 +367,7 @@ class PipelineMainLive(PipelineMainBase):
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(callback=self.on_transcript),
TranscriptTranslatorAutoProcessor.as_threaded(callback=self.on_transcript),
TranscriptTopicDetectorProcessor.as_threaded(callback=self.on_topic),
]
pipeline = Pipeline(*processors)
@@ -380,7 +386,7 @@ class PipelineMainLive(PipelineMainBase):
pipeline_post(transcript_id=self.transcript_id)
class PipelineMainDiarization(PipelineMainBase):
class PipelineMainDiarization(PipelineMainBase[AudioDiarizationInput]):
"""
Diarize the audio and update topics
"""
@@ -404,11 +410,10 @@ class PipelineMainDiarization(PipelineMainBase):
pipeline.logger.info("Audio is local, skipping diarization")
return
topics = self.get_transcript_topics(transcript)
audio_url = await transcript.get_audio_url()
audio_diarization_input = AudioDiarizationInput(
audio_url=audio_url,
topics=topics,
topics=self.wrap_transcript_topics(transcript.topics),
)
# as tempting to use pipeline.push, prefer to use the runner
@@ -421,7 +426,7 @@ class PipelineMainDiarization(PipelineMainBase):
return pipeline
class PipelineMainFromTopics(PipelineMainBase):
class PipelineMainFromTopics(PipelineMainBase[TitleSummaryWithIdProcessorType]):
"""
Pseudo class for generating a pipeline from topics
"""
@@ -443,7 +448,7 @@ class PipelineMainFromTopics(PipelineMainBase):
pipeline.logger.info(f"{self.__class__.__name__} pipeline created")
# push topics
topics = self.get_transcript_topics(transcript)
topics = PipelineMainBase.wrap_transcript_topics(transcript.topics)
for topic in topics:
await self.push(topic)
@@ -524,8 +529,6 @@ async def pipeline_convert_to_mp3(transcript: Transcript, logger: Logger):
# Convert to mp3
mp3_filename = transcript.audio_mp3_filename
import av
with av.open(wav_filename.as_posix()) as in_container:
in_stream = in_container.streams.audio[0]
with av.open(mp3_filename.as_posix(), "w") as out_container:
@@ -604,7 +607,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
meeting.id
)
except Exception as e:
logger.error(f"Failed to get fetch consent: {e}")
logger.error(f"Failed to get fetch consent: {e}", exc_info=e)
consent_denied = True
if not consent_denied:
@@ -627,7 +630,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
f"Deleted original Whereby recording: {recording.bucket_name}/{recording.object_key}"
)
except Exception as e:
logger.error(f"Failed to delete Whereby recording: {e}")
logger.error(f"Failed to delete Whereby recording: {e}", exc_info=e)
# non-transactional, files marked for deletion not actually deleted is possible
await transcripts_controller.update(transcript, {"audio_deleted": True})
@@ -640,7 +643,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
f"Deleted processed audio from storage: {transcript.storage_audio_path}"
)
except Exception as e:
logger.error(f"Failed to delete processed audio: {e}")
logger.error(f"Failed to delete processed audio: {e}", exc_info=e)
# 3. Delete local audio files
try:
@@ -649,7 +652,7 @@ async def cleanup_consent(transcript: Transcript, logger: Logger):
if hasattr(transcript, "audio_wav_filename") and transcript.audio_wav_filename:
transcript.audio_wav_filename.unlink(missing_ok=True)
except Exception as e:
logger.error(f"Failed to delete local audio files: {e}")
logger.error(f"Failed to delete local audio files: {e}", exc_info=e)
logger.info("Consent cleanup done")
@@ -794,8 +797,6 @@ def pipeline_post(*, transcript_id: str):
@get_transcript
async def pipeline_process(transcript: Transcript, logger: Logger):
import av
try:
if transcript.audio_location == "storage":
await transcripts_controller.download_mp3_from_storage(transcript)

View File

@@ -16,14 +16,17 @@ During its lifecycle, it will emit the following status:
"""
import asyncio
from typing import Generic, TypeVar
from pydantic import BaseModel, ConfigDict
from reflector.logger import logger
from reflector.processors import Pipeline
PipelineMessage = TypeVar("PipelineMessage")
class PipelineRunner(BaseModel):
class PipelineRunner(BaseModel, Generic[PipelineMessage]):
model_config = ConfigDict(arbitrary_types_allowed=True)
status: str = "idle"
@@ -67,7 +70,7 @@ class PipelineRunner(BaseModel):
coro = self.run()
asyncio.run(coro)
async def push(self, data):
async def push(self, data: PipelineMessage):
"""
Push data to the pipeline
"""
@@ -92,7 +95,11 @@ class PipelineRunner(BaseModel):
pass
async def _add_cmd(
self, cmd: str, data, max_retries: int = 3, retry_time_limit: int = 3
self,
cmd: str,
data: PipelineMessage,
max_retries: int = 3,
retry_time_limit: int = 3,
):
"""
Enqueue a command to be executed in the runner.
@@ -143,7 +150,10 @@ class PipelineRunner(BaseModel):
cmd, data = await self._q_cmd.get()
func = getattr(self, f"cmd_{cmd.lower()}")
if func:
await func(data)
if cmd.upper() == "FLUSH":
await func()
else:
await func(data)
else:
raise Exception(f"Unknown command {cmd}")
except Exception:
@@ -152,13 +162,13 @@ class PipelineRunner(BaseModel):
self._ev_done.set()
raise
async def cmd_push(self, data):
async def cmd_push(self, data: PipelineMessage):
if self._is_first_push:
await self._set_status("push")
self._is_first_push = False
await self.pipeline.push(data)
async def cmd_flush(self, data):
async def cmd_flush(self):
await self._set_status("flush")
await self.pipeline.flush()
await self._set_status("ended")

View File

@@ -16,6 +16,7 @@ from .transcript_final_title import TranscriptFinalTitleProcessor # noqa: F401
from .transcript_liner import TranscriptLinerProcessor # noqa: F401
from .transcript_topic_detector import TranscriptTopicDetectorProcessor # noqa: F401
from .transcript_translator import TranscriptTranslatorProcessor # noqa: F401
from .transcript_translator_auto import TranscriptTranslatorAutoProcessor # noqa: F401
from .types import ( # noqa: F401
AudioFile,
FinalLongSummary,

View File

@@ -1,5 +1,9 @@
from reflector.processors.base import Processor
from reflector.processors.types import AudioDiarizationInput, TitleSummary, Word
from reflector.processors.types import (
AudioDiarizationInput,
TitleSummary,
Word,
)
class AudioDiarizationProcessor(Processor):

View File

@@ -10,12 +10,17 @@ class AudioDiarizationModalProcessor(AudioDiarizationProcessor):
INPUT_TYPE = AudioDiarizationInput
OUTPUT_TYPE = TitleSummary
def __init__(self, **kwargs):
def __init__(self, modal_api_key: str | None = None, **kwargs):
super().__init__(**kwargs)
if not settings.DIARIZATION_URL:
raise Exception(
"DIARIZATION_URL required to use AudioDiarizationModalProcessor"
)
self.diarization_url = settings.DIARIZATION_URL + "/diarize"
self.headers = {
"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}",
}
self.modal_api_key = modal_api_key
self.headers = {}
if self.modal_api_key:
self.headers["Authorization"] = f"Bearer {self.modal_api_key}"
async def _diarize(self, data: AudioDiarizationInput):
# Gather diarization data

View File

@@ -21,16 +21,20 @@ from reflector.settings import settings
class AudioTranscriptModalProcessor(AudioTranscriptProcessor):
def __init__(self, modal_api_key: str):
def __init__(self, modal_api_key: str | None = None, **kwargs):
super().__init__()
if not settings.TRANSCRIPT_URL:
raise Exception(
"TRANSCRIPT_URL required to use AudioTranscriptModalProcessor"
)
self.transcript_url = settings.TRANSCRIPT_URL + "/v1"
self.timeout = settings.TRANSCRIPT_TIMEOUT
self.api_key = settings.TRANSCRIPT_MODAL_API_KEY
self.modal_api_key = modal_api_key
async def _transcript(self, data: AudioFile):
async with AsyncOpenAI(
base_url=self.transcript_url,
api_key=self.api_key,
api_key=self.modal_api_key,
timeout=self.timeout,
) as client:
self.logger.debug(f"Try to transcribe audio {data.name}")

View File

@@ -6,21 +6,15 @@ This script is used to generate a summary of a meeting notes transcript.
import asyncio
import sys
from datetime import datetime
from datetime import datetime, timezone
from enum import Enum
from textwrap import dedent
from typing import Type, TypeVar
import structlog
from llama_index.core import Settings
from llama_index.core.output_parsers import PydanticOutputParser
from llama_index.core.program import LLMTextCompletionProgram
from llama_index.core.response_synthesizers import TreeSummarize
from llama_index.llms.openai_like import OpenAILike
from pydantic import BaseModel, Field
from reflector.llm.base import LLM
from reflector.llm.openai_llm import OpenAILLM
from reflector.llm import LLM
from reflector.settings import settings
T = TypeVar("T", bound=BaseModel)
@@ -168,23 +162,12 @@ class SummaryBuilder:
self.summaries: list[dict[str, str]] = []
self.subjects: list[str] = []
self.transcription_type: TranscriptionType | None = None
self.llm_instance: LLM = llm
self.llm: LLM = llm
self.model_name: str = llm.model_name
self.logger = logger or structlog.get_logger()
if filename:
self.read_transcript_from_file(filename)
Settings.llm = OpenAILike(
model=llm.model_name,
api_base=llm.url,
api_key=llm.api_key,
context_window=settings.SUMMARY_LLM_CONTEXT_SIZE_TOKENS,
is_chat_model=True,
is_function_calling_model=llm.has_structured_output,
temperature=llm.temperature,
max_tokens=llm.max_tokens,
)
def read_transcript_from_file(self, filename: str) -> None:
"""
Load a transcript from a text file.
@@ -202,40 +185,16 @@ class SummaryBuilder:
self.transcript = transcript
def set_llm_instance(self, llm: LLM) -> None:
self.llm_instance = llm
self.llm = llm
async def _get_structured_response(
self, prompt: str, output_cls: Type[T], tone_name: str | None = None
) -> Type[T]:
) -> T:
"""Generic function to get structured output from LLM for non-function-calling models."""
# First, use TreeSummarize to get the response
summarizer = TreeSummarize(verbose=True)
response = await summarizer.aget_response(
prompt, [self.transcript], tone_name=tone_name
return await self.llm.get_structured_response(
prompt, [self.transcript], output_cls, tone_name=tone_name
)
# Then, use PydanticOutputParser to structure the response
output_parser = PydanticOutputParser(output_cls)
prompt_template_str = STRUCTURED_RESPONSE_PROMPT_TEMPLATE
program = LLMTextCompletionProgram.from_defaults(
output_parser=output_parser,
prompt_template_str=prompt_template_str,
verbose=False,
)
format_instructions = output_parser.format(
"Please structure the above information in the following JSON format:"
)
output = await program.acall(
analysis=str(response), format_instructions=format_instructions
)
return output
# ----------------------------------------------------------------------------
# Participants
# ----------------------------------------------------------------------------
@@ -354,19 +313,18 @@ class SummaryBuilder:
async def generate_subject_summaries(self) -> None:
"""Generate detailed summaries for each extracted subject."""
assert self.transcript is not None
summarizer = TreeSummarize(verbose=False)
summaries = []
for subject in self.subjects:
detailed_prompt = DETAILED_SUBJECT_PROMPT_TEMPLATE.format(subject=subject)
detailed_response = await summarizer.aget_response(
detailed_response = await self.llm.get_response(
detailed_prompt, [self.transcript], tone_name="Topic assistant"
)
paragraph_prompt = PARAGRAPH_SUMMARY_PROMPT
paragraph_response = await summarizer.aget_response(
paragraph_response = await self.llm.get_response(
paragraph_prompt, [str(detailed_response)], tone_name="Topic summarizer"
)
@@ -377,7 +335,6 @@ class SummaryBuilder:
async def generate_recap(self) -> None:
"""Generate a quick recap from the subject summaries."""
summarizer = TreeSummarize(verbose=True)
summaries_text = "\n\n".join(
[
@@ -388,7 +345,7 @@ class SummaryBuilder:
recap_prompt = RECAP_PROMPT
recap_response = await summarizer.aget_response(
recap_response = await self.llm.get_response(
recap_prompt, [summaries_text], tone_name="Recap summarizer"
)
@@ -483,7 +440,7 @@ if __name__ == "__main__":
async def main():
# build the summary
llm = OpenAILLM(config_prefix="SUMMARY", settings=settings)
llm = LLM(settings=settings)
sm = SummaryBuilder(llm=llm, filename=args.transcript)
if args.subjects:
@@ -517,7 +474,7 @@ if __name__ == "__main__":
if args.save:
# write the summary to a file, on the format summary-<iso date>.md
filename = f"summary-{datetime.now().isoformat()}.md"
filename = f"summary-{datetime.now(timezone.utc).isoformat()}.md"
with open(filename, "w", encoding="utf-8") as f:
f.write(sm.as_markdown())

View File

@@ -1,4 +1,4 @@
from reflector.llm.openai_llm import OpenAILLM
from reflector.llm import LLM
from reflector.processors.base import Processor
from reflector.processors.summary.summary_builder import SummaryBuilder
from reflector.processors.types import FinalLongSummary, FinalShortSummary, TitleSummary
@@ -17,7 +17,7 @@ class TranscriptFinalSummaryProcessor(Processor):
super().__init__(**kwargs)
self.transcript = transcript
self.chunks: list[TitleSummary] = []
self.llm = OpenAILLM(config_prefix="SUMMARY", settings=settings)
self.llm = LLM(settings=settings)
self.builder = None
async def _push(self, data: TitleSummary):

View File

@@ -1,67 +1,72 @@
from reflector.llm import LLM, LLMTaskParams
from textwrap import dedent
from reflector.llm import LLM
from reflector.processors.base import Processor
from reflector.processors.types import FinalTitle, TitleSummary
from reflector.settings import settings
from reflector.utils.text import clean_title
TITLE_PROMPT = dedent(
"""
Generate a concise title for this meeting based on the following topic titles.
Ignore casual conversation, greetings, or administrative matters.
The title must:
- Be maximum 10 words
- Use noun phrases when possible (e.g., "Q1 Budget Review" not "Reviewing the Q1 Budget")
- Avoid generic terms like "Team Meeting" or "Discussion"
If multiple unrelated topics were discussed, prioritize the most significant one.
or create a compound title (e.g., "Product Launch and Budget Planning").
<topics_discussed>
{titles}
</topics_discussed>
Do not explain, just output the meeting title as a single line.
"""
).strip()
class TranscriptFinalTitleProcessor(Processor):
"""
Assemble all summary into a line-based json
Generate a final title from topic titles using LlamaIndex
"""
INPUT_TYPE = TitleSummary
OUTPUT_TYPE = FinalTitle
TASK = "final_title"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.chunks: list[TitleSummary] = []
self.llm = LLM.get_instance()
self.params = LLMTaskParams.get_instance(self.TASK).task_params
self.llm = LLM(settings=settings, temperature=0.5, max_tokens=200)
async def _push(self, data: TitleSummary):
self.chunks.append(data)
async def get_title(self, text: str) -> dict:
async def get_title(self, accumulated_titles: str) -> str:
"""
Generate a title for the whole recording
Generate a title for the whole recording using LLM
"""
chunks = list(self.llm.split_corpus(corpus=text, task_params=self.params))
prompt = TITLE_PROMPT.format(titles=accumulated_titles)
response = await self.llm.get_response(
prompt,
[accumulated_titles],
tone_name="Title generator",
)
if len(chunks) == 1:
chunk = chunks[0]
prompt = self.llm.create_prompt(instruct=self.params.instruct, text=chunk)
title_result = await self.llm.generate(
prompt=prompt,
gen_schema=self.params.gen_schema,
gen_cfg=self.params.gen_cfg,
logger=self.logger,
)
return title_result
else:
accumulated_titles = ""
for chunk in chunks:
prompt = self.llm.create_prompt(
instruct=self.params.instruct, text=chunk
)
title_result = await self.llm.generate(
prompt=prompt,
gen_schema=self.params.gen_schema,
gen_cfg=self.params.gen_cfg,
logger=self.logger,
)
accumulated_titles += title_result["title"]
self.logger.info(f"Generated title response: {response}")
return await self.get_title(accumulated_titles)
return response
async def _flush(self):
if not self.chunks:
self.logger.warning("No summary to output")
return
accumulated_titles = ".".join([chunk.title for chunk in self.chunks])
title_result = await self.get_title(accumulated_titles)
final_title = self.llm.trim_title(title_result["title"])
final_title = self.llm.ensure_casing(final_title)
accumulated_titles = "\n".join([f"- {chunk.title}" for chunk in self.chunks])
title = await self.get_title(accumulated_titles)
title = clean_title(title)
final_title = FinalTitle(title=final_title)
final_title = FinalTitle(title=title)
await self.emit(final_title)

View File

@@ -1,7 +1,41 @@
from reflector.llm import LLM, LLMTaskParams
from textwrap import dedent
from pydantic import BaseModel, Field
from reflector.llm import LLM
from reflector.processors.base import Processor
from reflector.processors.types import TitleSummary, Transcript
from reflector.settings import settings
from reflector.utils.text import clean_title
TOPIC_PROMPT = dedent(
"""
Analyze the following transcript segment and extract the main topic being discussed.
Focus on the substantive content and ignore small talk or administrative chatter.
Create a title that:
- Captures the specific subject matter being discussed
- Is descriptive and self-explanatory
- Uses professional language
- Is specific rather than generic
For the summary:
- Summarize the key points in maximum two sentences
- Focus on what was discussed, decided, or accomplished
- Be concise but informative
<transcript>
{text}
</transcript>
"""
).strip()
class TopicResponse(BaseModel):
"""Structured response for topic detection"""
title: str = Field(description="A descriptive title for the topic being discussed")
summary: str = Field(description="A concise 1-2 sentence summary of the discussion")
class TranscriptTopicDetectorProcessor(Processor):
@@ -11,7 +45,6 @@ class TranscriptTopicDetectorProcessor(Processor):
INPUT_TYPE = Transcript
OUTPUT_TYPE = TitleSummary
TASK = "topic"
def __init__(
self, min_transcript_length: int = int(settings.MIN_TRANSCRIPT_LENGTH), **kwargs
@@ -19,8 +52,7 @@ class TranscriptTopicDetectorProcessor(Processor):
super().__init__(**kwargs)
self.transcript = None
self.min_transcript_length = min_transcript_length
self.llm = LLM.get_instance()
self.params = LLMTaskParams.get_instance(self.TASK).task_params
self.llm = LLM(settings=settings, temperature=0.9, max_tokens=500)
async def _push(self, data: Transcript):
if self.transcript is None:
@@ -34,18 +66,15 @@ class TranscriptTopicDetectorProcessor(Processor):
return
await self.flush()
async def get_topic(self, text: str) -> dict:
async def get_topic(self, text: str) -> TopicResponse:
"""
Generate a topic and description for a transcription excerpt
Generate a topic and description for a transcription excerpt using LLM
"""
prompt = self.llm.create_prompt(instruct=self.params.instruct, text=text)
topic_result = await self.llm.generate(
prompt=prompt,
gen_schema=self.params.gen_schema,
gen_cfg=self.params.gen_cfg,
logger=self.logger,
prompt = TOPIC_PROMPT.format(text=text)
response = await self.llm.get_structured_response(
prompt, [text], TopicResponse, tone_name="Topic analyzer"
)
return topic_result
return response
async def _flush(self):
if not self.transcript:
@@ -53,13 +82,13 @@ class TranscriptTopicDetectorProcessor(Processor):
text = self.transcript.text
self.logger.info(f"Topic detector got {len(text)} length transcript")
topic_result = await self.get_topic(text=text)
title = self.llm.trim_title(topic_result["title"])
title = self.llm.ensure_casing(title)
title = clean_title(topic_result.title)
summary = TitleSummary(
title=title,
summary=topic_result["summary"],
summary=topic_result.summary,
timestamp=self.transcript.timestamp,
duration=self.transcript.duration,
transcript=self.transcript,

View File

@@ -1,9 +1,5 @@
import httpx
from reflector.processors.base import Processor
from reflector.processors.types import Transcript, TranslationLanguages
from reflector.settings import settings
from reflector.utils.retry import retry
from reflector.processors.types import Transcript
class TranscriptTranslatorProcessor(Processor):
@@ -13,61 +9,27 @@ class TranscriptTranslatorProcessor(Processor):
INPUT_TYPE = Transcript
OUTPUT_TYPE = Transcript
TASK = "translate"
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.transcript = None
self.translate_url = settings.TRANSLATE_URL
self.timeout = settings.TRANSLATE_TIMEOUT
self.headers = {"Authorization": f"Bearer {settings.LLM_MODAL_API_KEY}"}
async def _push(self, data: Transcript):
self.transcript = data
await self.flush()
async def get_translation(self, text: str) -> str | None:
# FIXME this should be a processor after, as each user may want
# different languages
source_language = self.get_pref("audio:source_language", "en")
target_language = self.get_pref("audio:target_language", "en")
if source_language == target_language:
return
languages = TranslationLanguages()
# Only way to set the target should be the UI element like dropdown.
# Hence, this assert should never fail.
assert languages.is_supported(target_language)
self.logger.debug(f"Try to translate {text=}")
json_payload = {
"text": text,
"source_language": source_language,
"target_language": target_language,
}
async with httpx.AsyncClient() as client:
response = await retry(client.post)(
self.translate_url + "/translate",
headers=self.headers,
params=json_payload,
timeout=self.timeout,
follow_redirects=True,
logger=self.logger,
)
response.raise_for_status()
result = response.json()["text"]
# Sanity check for translation status in the result
if target_language in result:
translation = result[target_language]
self.logger.debug(f"Translation response: {text=}, {translation=}")
return translation
async def _translate(self, text: str) -> str | None:
raise NotImplementedError
async def _flush(self):
if not self.transcript:
return
self.transcript.translation = await self.get_translation(
text=self.transcript.text
)
source_language = self.get_pref("audio:source_language", "en")
target_language = self.get_pref("audio:target_language", "en")
if source_language == target_language:
self.transcript.translation = None
else:
self.transcript.translation = await self._translate(self.transcript.text)
await self.emit(self.transcript)

View File

@@ -0,0 +1,32 @@
import importlib
from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
from reflector.settings import settings
class TranscriptTranslatorAutoProcessor(TranscriptTranslatorProcessor):
_registry = {}
@classmethod
def register(cls, name, kclass):
cls._registry[name] = kclass
def __new__(cls, name: str | None = None, **kwargs):
if name is None:
name = settings.TRANSLATION_BACKEND
if name not in cls._registry:
module_name = f"reflector.processors.transcript_translator_{name}"
importlib.import_module(module_name)
# gather specific configuration for the processor
# search `TRANSLATION_BACKEND_XXX_YYY`, push to constructor as `backend_xxx_yyy`
config = {}
name_upper = name.upper()
settings_prefix = "TRANSLATION_"
config_prefix = f"{settings_prefix}{name_upper}_"
for key, value in settings:
if key.startswith(config_prefix):
config_name = key[len(settings_prefix) :].lower()
config[config_name] = value
return cls._registry[name](**config | kwargs)

View File

@@ -0,0 +1,66 @@
import httpx
from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
from reflector.processors.transcript_translator_auto import (
TranscriptTranslatorAutoProcessor,
)
from reflector.processors.types import TranslationLanguages
from reflector.settings import settings
from reflector.utils.retry import retry
class TranscriptTranslatorModalProcessor(TranscriptTranslatorProcessor):
"""
Translate the transcript into the target language using Modal.com
"""
def __init__(self, modal_api_key: str | None = None, **kwargs):
super().__init__(**kwargs)
if not settings.TRANSLATE_URL:
raise Exception(
"TRANSLATE_URL is required for TranscriptTranslatorModalProcessor"
)
self.translate_url = settings.TRANSLATE_URL
self.timeout = settings.TRANSLATE_TIMEOUT
self.modal_api_key = modal_api_key
self.headers = {}
if self.modal_api_key:
self.headers["Authorization"] = f"Bearer {self.modal_api_key}"
async def _translate(self, text: str) -> str | None:
source_language = self.get_pref("audio:source_language", "en")
target_language = self.get_pref("audio:target_language", "en")
languages = TranslationLanguages()
# Only way to set the target should be the UI element like dropdown.
# Hence, this assert should never fail.
assert languages.is_supported(target_language)
self.logger.debug(f"Try to translate {text=}")
json_payload = {
"text": text,
"source_language": source_language,
"target_language": target_language,
}
async with httpx.AsyncClient() as client:
response = await retry(client.post)(
self.translate_url + "/translate",
headers=self.headers,
params=json_payload,
timeout=self.timeout,
follow_redirects=True,
logger=self.logger,
)
response.raise_for_status()
result = response.json()["text"]
# Sanity check for translation status in the result
if target_language in result:
translation = result[target_language]
else:
translation = None
self.logger.debug(f"Translation response: {text=}, {translation=}")
return translation
TranscriptTranslatorAutoProcessor.register("modal", TranscriptTranslatorModalProcessor)

View File

@@ -0,0 +1,14 @@
from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
from reflector.processors.transcript_translator_auto import (
TranscriptTranslatorAutoProcessor,
)
class TranscriptTranslatorPassthroughProcessor(TranscriptTranslatorProcessor):
async def _translate(self, text: str) -> None:
return None
TranscriptTranslatorAutoProcessor.register(
"passthrough", TranscriptTranslatorPassthroughProcessor
)

View File

@@ -2,9 +2,10 @@ import io
import re
import tempfile
from pathlib import Path
from typing import Annotated
from profanityfilter import ProfanityFilter
from pydantic import BaseModel, PrivateAttr
from pydantic import BaseModel, Field, PrivateAttr
from reflector.redis_cache import redis_cache
@@ -48,20 +49,70 @@ class AudioFile(BaseModel):
self._path.unlink()
# non-negative seconds with float part
Seconds = Annotated[float, Field(ge=0.0, description="Time in seconds with float part")]
class Word(BaseModel):
text: str
start: float
end: float
start: Seconds
end: Seconds
speaker: int = 0
class TranscriptSegment(BaseModel):
text: str
start: float
end: float
start: Seconds
end: Seconds
speaker: int = 0
def words_to_segments(words: list[Word]) -> list[TranscriptSegment]:
# from a list of word, create a list of segments
# join the word that are less than 2 seconds apart
# but separate if the speaker changes, or if the punctuation is a . , ; : ? !
segments = []
current_segment = None
MAX_SEGMENT_LENGTH = 120
for word in words:
if current_segment is None:
current_segment = TranscriptSegment(
text=word.text,
start=word.start,
end=word.end,
speaker=word.speaker,
)
continue
# If the word is attach to another speaker, push the current segment
# and start a new one
if word.speaker != current_segment.speaker:
segments.append(current_segment)
current_segment = TranscriptSegment(
text=word.text,
start=word.start,
end=word.end,
speaker=word.speaker,
)
continue
# if the word is the end of a sentence, and we have enough content,
# add the word to the current segment and push it
current_segment.text += word.text
current_segment.end = word.end
have_punc = PUNC_RE.search(word.text)
if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH):
segments.append(current_segment)
current_segment = None
if current_segment:
segments.append(current_segment)
return segments
class Transcript(BaseModel):
translation: str | None = None
words: list[Word] = None
@@ -117,49 +168,7 @@ class Transcript(BaseModel):
return Transcript(text=self.text, translation=self.translation, words=words)
def as_segments(self) -> list[TranscriptSegment]:
# from a list of word, create a list of segments
# join the word that are less than 2 seconds apart
# but separate if the speaker changes, or if the punctuation is a . , ; : ? !
segments = []
current_segment = None
MAX_SEGMENT_LENGTH = 120
for word in self.words:
if current_segment is None:
current_segment = TranscriptSegment(
text=word.text,
start=word.start,
end=word.end,
speaker=word.speaker,
)
continue
# If the word is attach to another speaker, push the current segment
# and start a new one
if word.speaker != current_segment.speaker:
segments.append(current_segment)
current_segment = TranscriptSegment(
text=word.text,
start=word.start,
end=word.end,
speaker=word.speaker,
)
continue
# if the word is the end of a sentence, and we have enough content,
# add the word to the current segment and push it
current_segment.text += word.text
current_segment.end = word.end
have_punc = PUNC_RE.search(word.text)
if have_punc and (len(current_segment.text) > MAX_SEGMENT_LENGTH):
segments.append(current_segment)
current_segment = None
if current_segment:
segments.append(current_segment)
return segments
return words_to_segments(self.words)
class TitleSummary(BaseModel):

View File

@@ -0,0 +1,296 @@
import hashlib
from datetime import date, datetime, timedelta, timezone
from typing import TypedDict
import httpx
import pytz
from icalendar import Calendar, Event
from loguru import logger
from reflector.db.calendar_events import CalendarEvent, calendar_events_controller
from reflector.db.rooms import Room, rooms_controller
from reflector.settings import settings
class AttendeeData(TypedDict, total=False):
email: str | None
name: str | None
status: str | None
role: str | None
class EventData(TypedDict):
ics_uid: str
title: str | None
description: str | None
location: str | None
start_time: datetime
end_time: datetime
attendees: list[AttendeeData]
ics_raw_data: str
class SyncStats(TypedDict):
events_created: int
events_updated: int
events_deleted: int
class ICSFetchService:
def __init__(self):
self.client = httpx.AsyncClient(
timeout=30.0, headers={"User-Agent": "Reflector/1.0"}
)
async def fetch_ics(self, url: str) -> str:
response = await self.client.get(url)
response.raise_for_status()
return response.text
def parse_ics(self, ics_content: str) -> Calendar:
return Calendar.from_ical(ics_content)
def extract_room_events(
self, calendar: Calendar, room_name: str, room_url: str
) -> list[EventData]:
events = []
now = datetime.now(timezone.utc)
window_start = now - timedelta(hours=1)
window_end = now + timedelta(hours=24)
for component in calendar.walk():
if component.name == "VEVENT":
# Skip cancelled events
status = component.get("STATUS", "").upper()
if status == "CANCELLED":
continue
# Check if event matches this room
if self._event_matches_room(component, room_name, room_url):
event_data = self._parse_event(component)
# Only include events in our time window
if (
event_data
and window_start <= event_data["start_time"] <= window_end
):
events.append(event_data)
return events
def _event_matches_room(self, event: Event, room_name: str, room_url: str) -> bool:
location = str(event.get("LOCATION", ""))
description = str(event.get("DESCRIPTION", ""))
# Only match full room URL (with or without protocol)
patterns = [
room_url, # Full URL with protocol
room_url.replace("https://", ""), # Without https protocol
room_url.replace("http://", ""), # Without http protocol
]
# Check location and description for patterns
text_to_check = f"{location} {description}".lower()
for pattern in patterns:
if pattern.lower() in text_to_check:
return True
return False
def _parse_event(self, event: Event) -> EventData | None:
# Extract basic fields
uid = str(event.get("UID", ""))
summary = str(event.get("SUMMARY", ""))
description = str(event.get("DESCRIPTION", ""))
location = str(event.get("LOCATION", ""))
# Parse dates
dtstart = event.get("DTSTART")
dtend = event.get("DTEND")
if not dtstart:
return None
# Convert to datetime
start_time = self._normalize_datetime(
dtstart.dt if hasattr(dtstart, "dt") else dtstart
)
end_time = (
self._normalize_datetime(dtend.dt if hasattr(dtend, "dt") else dtend)
if dtend
else start_time + timedelta(hours=1)
)
# Parse attendees
attendees = self._parse_attendees(event)
# Get raw event data for storage
raw_data = event.to_ical().decode("utf-8")
return {
"ics_uid": uid,
"title": summary,
"description": description,
"location": location,
"start_time": start_time,
"end_time": end_time,
"attendees": attendees,
"ics_raw_data": raw_data,
}
def _normalize_datetime(self, dt) -> datetime:
# Handle date objects (all-day events)
if isinstance(dt, date) and not isinstance(dt, datetime):
# Convert to datetime at start of day in UTC
dt = datetime.combine(dt, datetime.min.time())
dt = pytz.UTC.localize(dt)
elif isinstance(dt, datetime):
# Add UTC timezone if naive
if dt.tzinfo is None:
dt = pytz.UTC.localize(dt)
else:
# Convert to UTC
dt = dt.astimezone(pytz.UTC)
return dt
def _parse_attendees(self, event: Event) -> list[AttendeeData]:
attendees = []
# Parse ATTENDEE properties
for attendee in event.get("ATTENDEE", []):
if not isinstance(attendee, list):
attendee = [attendee]
for att in attendee:
att_data: AttendeeData = {
"email": str(att).replace("mailto:", "") if att else None,
"name": att.params.get("CN") if hasattr(att, "params") else None,
"status": att.params.get("PARTSTAT")
if hasattr(att, "params")
else None,
"role": att.params.get("ROLE") if hasattr(att, "params") else None,
}
attendees.append(att_data)
# Add organizer
organizer = event.get("ORGANIZER")
if organizer:
org_data: AttendeeData = {
"email": str(organizer).replace("mailto:", "") if organizer else None,
"name": organizer.params.get("CN")
if hasattr(organizer, "params")
else None,
"role": "ORGANIZER",
}
attendees.append(org_data)
return attendees
class ICSSyncService:
def __init__(self):
self.fetch_service = ICSFetchService()
async def sync_room_calendar(self, room: Room) -> dict:
if not room.ics_enabled or not room.ics_url:
return {"status": "skipped", "reason": "ICS not configured"}
try:
# Check if it's time to sync
if not self._should_sync(room):
return {"status": "skipped", "reason": "Not time to sync yet"}
# Fetch ICS file
ics_content = await self.fetch_service.fetch_ics(room.ics_url)
# Check if content changed
content_hash = hashlib.md5(ics_content.encode()).hexdigest()
if room.ics_last_etag == content_hash:
logger.info(f"No changes in ICS for room {room.id}")
return {"status": "unchanged", "hash": content_hash}
# Parse calendar
calendar = self.fetch_service.parse_ics(ics_content)
# Build room URL
room_url = f"{settings.BASE_URL}/room/{room.name}"
# Extract matching events
events = self.fetch_service.extract_room_events(
calendar, room.name, room_url
)
# Sync events to database
sync_result = await self._sync_events_to_database(room.id, events)
# Update room sync metadata
await rooms_controller.update(
room,
{
"ics_last_sync": datetime.now(timezone.utc),
"ics_last_etag": content_hash,
},
mutate=False,
)
return {
"status": "success",
"hash": content_hash,
"events_found": len(events),
**sync_result,
}
except Exception as e:
logger.error(f"Failed to sync ICS for room {room.id}: {e}")
return {"status": "error", "error": str(e)}
def _should_sync(self, room: Room) -> bool:
if not room.ics_last_sync:
return True
time_since_sync = datetime.now(timezone.utc) - room.ics_last_sync
return time_since_sync.total_seconds() >= room.ics_fetch_interval
async def _sync_events_to_database(
self, room_id: str, events: list[EventData]
) -> SyncStats:
created = 0
updated = 0
# Track current event IDs
current_ics_uids = []
for event_data in events:
# Create CalendarEvent object
calendar_event = CalendarEvent(room_id=room_id, **event_data)
# Upsert event
existing = await calendar_events_controller.get_by_ics_uid(
room_id, event_data["ics_uid"]
)
if existing:
updated += 1
else:
created += 1
await calendar_events_controller.upsert(calendar_event)
current_ics_uids.append(event_data["ics_uid"])
# Soft delete events that are no longer in calendar
deleted = await calendar_events_controller.soft_delete_missing(
room_id, current_ics_uids
)
return {
"events_created": created,
"events_updated": updated,
"events_deleted": deleted,
}
# Global instance
ics_sync_service = ICSSyncService()

View File

@@ -9,13 +9,16 @@ class Settings(BaseSettings):
)
# CORS
UI_BASE_URL: str = "http://localhost:3000"
CORS_ORIGIN: str = "*"
CORS_ALLOW_CREDENTIALS: bool = False
# Database
DATABASE_URL: str = "sqlite:///./reflector.sqlite3"
DATABASE_URL: str = (
"postgresql+asyncpg://reflector:reflector@localhost:5432/reflector"
)
# local data directory (audio for no)
# local data directory
DATA_DIR: str = "./data"
# Audio Transcription
@@ -24,11 +27,7 @@ class Settings(BaseSettings):
TRANSCRIPT_URL: str | None = None
TRANSCRIPT_TIMEOUT: int = 90
# Translate into the target language
TRANSLATE_URL: str | None = None
TRANSLATE_TIMEOUT: int = 90
# Audio transcription modal.com configuration
# Audio Transcription: modal backend
TRANSCRIPT_MODAL_API_KEY: str | None = None
# Audio transcription storage
@@ -40,37 +39,37 @@ class Settings(BaseSettings):
TRANSCRIPT_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
TRANSCRIPT_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
# Recording storage
RECORDING_STORAGE_BACKEND: str | None = None
# Recording storage configuration for AWS
RECORDING_STORAGE_AWS_BUCKET_NAME: str = "recording-bucket"
RECORDING_STORAGE_AWS_REGION: str = "us-east-1"
RECORDING_STORAGE_AWS_ACCESS_KEY_ID: str | None = None
RECORDING_STORAGE_AWS_SECRET_ACCESS_KEY: str | None = None
# Translate into the target language
TRANSLATION_BACKEND: str = "passthrough"
TRANSLATE_URL: str | None = None
TRANSLATE_TIMEOUT: int = 90
# Translation: modal backend
TRANSLATE_MODAL_API_KEY: str | None = None
# LLM
# available backend: openai, modal
LLM_BACKEND: str = "modal"
# LLM common configuration
LLM_MODEL: str = "microsoft/phi-4"
LLM_URL: str | None = None
LLM_HOST: str = "localhost"
LLM_PORT: int = 7860
LLM_OPENAI_KEY: str | None = None
LLM_OPENAI_MODEL: str = "gpt-3.5-turbo"
LLM_OPENAI_TEMPERATURE: float = 0.7
LLM_TIMEOUT: int = 60 * 5 # take cold start into account
LLM_MAX_TOKENS: int = 1024
LLM_TEMPERATURE: float = 0.7
ZEPHYR_LLM_URL: str | None = None
HERMES_3_8B_LLM_URL: str | None = None
# LLM Modal configuration
LLM_MODAL_API_KEY: str | None = None
# per-task cases
SUMMARY_MODEL: str = "monadical/private/smart"
SUMMARY_LLM_URL: str | None = None
SUMMARY_LLM_API_KEY: str | None = None
SUMMARY_LLM_CONTEXT_SIZE_TOKENS: int = 16000
LLM_API_KEY: str | None = None
LLM_CONTEXT_WINDOW: int = 16000
# Diarization
DIARIZATION_ENABLED: bool = True
DIARIZATION_BACKEND: str = "modal"
DIARIZATION_URL: str | None = None
# Diarization: modal backend
DIARIZATION_MODAL_API_KEY: str | None = None
# Sentry
SENTRY_DSN: str | None = None
@@ -86,12 +85,6 @@ class Settings(BaseSettings):
# if set, all anonymous record will be public
PUBLIC_MODE: bool = False
# Default LLM model name
DEFAULT_LLM: str = "lmsys/vicuna-13b-v1.5"
# Cache directory for all model storage
CACHE_DIR: str = "./data"
# Min transcript length to generate topic + summary
MIN_TRANSCRIPT_LENGTH: int = 750
@@ -116,24 +109,19 @@ class Settings(BaseSettings):
# Healthcheck
HEALTHCHECK_URL: str | None = None
# Whereby integration
WHEREBY_API_URL: str = "https://api.whereby.dev/v1"
WHEREBY_API_KEY: str | None = None
WHEREBY_WEBHOOK_SECRET: str | None = None
AWS_WHEREBY_ACCESS_KEY_ID: str | None = None
AWS_WHEREBY_ACCESS_KEY_SECRET: str | None = None
AWS_PROCESS_RECORDING_QUEUE_URL: str | None = None
SQS_POLLING_TIMEOUT_SECONDS: int = 60
WHEREBY_API_URL: str = "https://api.whereby.dev/v1"
WHEREBY_API_KEY: str | None = None
AWS_WHEREBY_S3_BUCKET: str | None = None
AWS_WHEREBY_ACCESS_KEY_ID: str | None = None
AWS_WHEREBY_ACCESS_KEY_SECRET: str | None = None
# Zulip integration
ZULIP_REALM: str | None = None
ZULIP_API_KEY: str | None = None
ZULIP_BOT_EMAIL: str | None = None
UI_BASE_URL: str = "http://localhost:3000"
WHEREBY_WEBHOOK_SECRET: str | None = None
settings = Settings()

View File

@@ -1,10 +1,17 @@
from .base import Storage # noqa
from reflector.settings import settings
def get_transcripts_storage() -> Storage:
from reflector.settings import settings
assert settings.TRANSCRIPT_STORAGE_BACKEND
return Storage.get_instance(
name=settings.TRANSCRIPT_STORAGE_BACKEND,
settings_prefix="TRANSCRIPT_STORAGE_",
)
def get_recordings_storage() -> Storage:
return Storage.get_instance(
name=settings.RECORDING_STORAGE_BACKEND,
settings_prefix="RECORDING_STORAGE_",
)

View File

@@ -9,8 +9,9 @@ async def export_db(filename: str) -> None:
filename = pathlib.Path(filename).resolve()
settings.DATABASE_URL = f"sqlite:///{filename}"
from reflector.db import database, transcripts
from reflector.db import get_database, transcripts
database = get_database()
await database.connect()
transcripts = await database.fetch_all(transcripts.select())
await database.disconnect()

View File

@@ -8,8 +8,9 @@ async def export_db(filename: str) -> None:
filename = pathlib.Path(filename).resolve()
settings.DATABASE_URL = f"sqlite:///{filename}"
from reflector.db import database, transcripts
from reflector.db import get_database, transcripts
database = get_database()
await database.connect()
transcripts = await database.fetch_all(transcripts.select())
await database.disconnect()

View File

@@ -13,7 +13,7 @@ from reflector.processors import (
TranscriptFinalTitleProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
TranscriptTranslatorAutoProcessor,
)
from reflector.processors.base import BroadcastProcessor
@@ -31,7 +31,7 @@ async def process_audio_file(
AudioMergeProcessor(),
AudioTranscriptAutoProcessor.as_threaded(),
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(),
TranscriptTranslatorAutoProcessor.as_threaded(),
]
if not only_transcript:
processors += [

View File

@@ -27,7 +27,7 @@ from reflector.processors import (
TranscriptFinalTitleProcessor,
TranscriptLinerProcessor,
TranscriptTopicDetectorProcessor,
TranscriptTranslatorProcessor,
TranscriptTranslatorAutoProcessor,
)
from reflector.processors.base import BroadcastProcessor, Processor
from reflector.processors.types import (
@@ -103,7 +103,7 @@ async def process_audio_file_with_diarization(
processors += [
TranscriptLinerProcessor(),
TranscriptTranslatorProcessor.as_threaded(),
TranscriptTranslatorAutoProcessor.as_threaded(),
]
if not only_transcript:
@@ -145,18 +145,17 @@ async def process_audio_file_with_diarization(
logger.info(f"Starting diarization with {len(topics)} topics")
try:
# Import diarization processor
from reflector.processors import AudioDiarizationAutoProcessor
# Create diarization processor
diarization_processor = AudioDiarizationAutoProcessor(
name=diarization_backend
)
diarization_processor.on(event_callback)
diarization_processor.set_pipeline(pipeline)
# For Modal backend, we need to upload the file to S3 first
if diarization_backend == "modal":
from datetime import datetime
from datetime import datetime, timezone
from reflector.storage import get_transcripts_storage
from reflector.utils.s3_temp_file import S3TemporaryFile
@@ -164,7 +163,7 @@ async def process_audio_file_with_diarization(
storage = get_transcripts_storage()
# Generate a unique filename in evaluation folder
timestamp = datetime.utcnow().strftime("%Y%m%d_%H%M%S")
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
audio_filename = f"evaluation/diarization_temp/{timestamp}_{uuid.uuid4().hex}.wav"
# Use context manager for automatic cleanup

View File

@@ -0,0 +1,33 @@
def clean_title(title: str) -> str:
"""
Clean and format a title string for consistent capitalization.
Rules:
- Strip surrounding quotes (single or double)
- Capitalize the first word
- Capitalize words longer than 3 characters
- Keep words with 3 or fewer characters lowercase (except first word)
Args:
title: The title string to clean
Returns:
The cleaned title with consistent capitalization
Examples:
>>> clean_title("hello world")
"Hello World"
>>> clean_title("meeting with the team")
"Meeting With the Team"
>>> clean_title("'Title with quotes'")
"Title With Quotes"
"""
title = title.strip("\"'")
words = title.split()
if words:
words = [
word.capitalize() if i == 0 or len(word) > 3 else word.lower()
for i, word in enumerate(words)
]
title = " ".join(words)
return title

View File

@@ -0,0 +1,63 @@
"""WebVTT utilities for generating subtitle files from transcript data."""
from typing import TYPE_CHECKING, Annotated
import webvtt
from reflector.processors.types import Seconds, Word, words_to_segments
if TYPE_CHECKING:
from reflector.db.transcripts import TranscriptTopic
VttTimestamp = Annotated[str, "vtt_timestamp"]
WebVTTStr = Annotated[str, "webvtt_str"]
def _seconds_to_timestamp(seconds: Seconds) -> VttTimestamp:
# lib doesn't do that
hours = int(seconds // 3600)
minutes = int((seconds % 3600) // 60)
secs = int(seconds % 60)
milliseconds = int((seconds % 1) * 1000)
return f"{hours:02d}:{minutes:02d}:{secs:02d}.{milliseconds:03d}"
def words_to_webvtt(words: list[Word]) -> WebVTTStr:
"""Convert words to WebVTT using existing segmentation logic."""
vtt = webvtt.WebVTT()
if not words:
return vtt.content
segments = words_to_segments(words)
for segment in segments:
text = segment.text.strip()
# lib doesn't do that
text = f"<v Speaker{segment.speaker}>{text}"
caption = webvtt.Caption(
start=_seconds_to_timestamp(segment.start),
end=_seconds_to_timestamp(segment.end),
text=text,
)
vtt.captions.append(caption)
return vtt.content
def topics_to_webvtt(topics: list["TranscriptTopic"]) -> WebVTTStr:
if not topics:
return webvtt.WebVTT().content
all_words: list[Word] = []
for topic in topics:
all_words.extend(topic.words)
# assert it's in sequence
for i in range(len(all_words) - 1):
assert (
all_words[i].start <= all_words[i + 1].start
), f"Words are not in sequence: {all_words[i].text} and {all_words[i + 1].text} are not consecutive: {all_words[i].start} > {all_words[i + 1].start}"
return words_to_webvtt(all_words)

View File

@@ -44,8 +44,6 @@ def range_requests_response(
"""Returns StreamingResponse using Range Requests of a given file"""
if not os.path.exists(file_path):
from fastapi import HTTPException
raise HTTPException(status_code=404, detail="File not found")
file_size = os.stat(file_path).st_size

View File

@@ -1,4 +1,4 @@
from datetime import datetime
from datetime import datetime, timezone
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Request
@@ -35,7 +35,7 @@ async def meeting_audio_consent(
meeting_id=meeting_id,
user_id=user_id,
consent_given=request.consent_given,
consent_timestamp=datetime.utcnow(),
consent_timestamp=datetime.now(timezone.utc),
)
updated_consent = await meeting_consent_controller.upsert(consent)

View File

@@ -1,16 +1,16 @@
import logging
import sqlite3
from datetime import datetime, timedelta
from datetime import datetime, timedelta, timezone
from typing import Annotated, Literal, Optional
import asyncpg.exceptions
from fastapi import APIRouter, Depends, HTTPException
from fastapi_pagination import Page
from fastapi_pagination.ext.databases import paginate
from fastapi_pagination.ext.databases import apaginate
from pydantic import BaseModel
import reflector.auth as auth
from reflector.db import database
from reflector.db import get_database
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller
from reflector.settings import settings
@@ -21,6 +21,14 @@ logger = logging.getLogger(__name__)
router = APIRouter()
def parse_datetime_with_timezone(iso_string: str) -> datetime:
"""Parse ISO datetime string and ensure timezone awareness (defaults to UTC if naive)."""
dt = datetime.fromisoformat(iso_string)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt
class Room(BaseModel):
id: str
name: str
@@ -34,6 +42,11 @@ class Room(BaseModel):
recording_type: str
recording_trigger: str
is_shared: bool
ics_url: Optional[str] = None
ics_fetch_interval: int = 300
ics_enabled: bool = False
ics_last_sync: Optional[datetime] = None
ics_last_etag: Optional[str] = None
class Meeting(BaseModel):
@@ -56,18 +69,24 @@ class CreateRoom(BaseModel):
recording_type: str
recording_trigger: str
is_shared: bool
ics_url: Optional[str] = None
ics_fetch_interval: int = 300
ics_enabled: bool = False
class UpdateRoom(BaseModel):
name: str
zulip_auto_post: bool
zulip_stream: str
zulip_topic: str
is_locked: bool
room_mode: str
recording_type: str
recording_trigger: str
is_shared: bool
name: Optional[str] = None
zulip_auto_post: Optional[bool] = None
zulip_stream: Optional[str] = None
zulip_topic: Optional[str] = None
is_locked: Optional[bool] = None
room_mode: Optional[str] = None
recording_type: Optional[str] = None
recording_trigger: Optional[str] = None
is_shared: Optional[bool] = None
ics_url: Optional[str] = None
ics_fetch_interval: Optional[int] = None
ics_enabled: Optional[bool] = None
class DeletionStatus(BaseModel):
@@ -83,8 +102,8 @@ async def rooms_list(
user_id = user["sub"] if user else None
return await paginate(
database,
return await apaginate(
get_database(),
await rooms_controller.get_all(
user_id=user_id, order_by="-created_at", return_query=True
),
@@ -109,6 +128,9 @@ async def rooms_create(
recording_type=room.recording_type,
recording_trigger=room.recording_trigger,
is_shared=room.is_shared,
ics_url=room.ics_url,
ics_fetch_interval=room.ics_fetch_interval,
ics_enabled=room.ics_enabled,
)
@@ -150,7 +172,7 @@ async def rooms_create_meeting(
if not room:
raise HTTPException(status_code=404, detail="Room not found")
current_time = datetime.utcnow()
current_time = datetime.now(timezone.utc)
meeting = await meetings_controller.get_active(room=room, current_time=current_time)
if meeting is None:
@@ -166,8 +188,8 @@ async def rooms_create_meeting(
room_name=whereby_meeting["roomName"],
room_url=whereby_meeting["roomUrl"],
host_room_url=whereby_meeting["hostRoomUrl"],
start_date=datetime.fromisoformat(whereby_meeting["startDate"]),
end_date=datetime.fromisoformat(whereby_meeting["endDate"]),
start_date=parse_datetime_with_timezone(whereby_meeting["startDate"]),
end_date=parse_datetime_with_timezone(whereby_meeting["endDate"]),
user_id=user_id,
room=room,
)
@@ -201,3 +223,217 @@ async def rooms_create_meeting(
meeting.host_room_url = ""
return meeting
class ICSStatus(BaseModel):
status: str
last_sync: Optional[datetime] = None
next_sync: Optional[datetime] = None
last_etag: Optional[str] = None
events_count: int = 0
class ICSSyncResult(BaseModel):
status: str
hash: Optional[str] = None
events_found: int = 0
events_created: int = 0
events_updated: int = 0
events_deleted: int = 0
error: Optional[str] = None
@router.post("/rooms/{room_name}/ics/sync", response_model=ICSSyncResult)
async def rooms_sync_ics(
room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
if user_id != room.user_id:
raise HTTPException(
status_code=403, detail="Only room owner can trigger ICS sync"
)
if not room.ics_enabled or not room.ics_url:
raise HTTPException(status_code=400, detail="ICS not configured for this room")
from reflector.services.ics_sync import ics_sync_service
result = await ics_sync_service.sync_room_calendar(room)
if result["status"] == "error":
raise HTTPException(
status_code=500, detail=result.get("error", "Unknown error")
)
return ICSSyncResult(**result)
@router.get("/rooms/{room_name}/ics/status", response_model=ICSStatus)
async def rooms_ics_status(
room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
if user_id != room.user_id:
raise HTTPException(
status_code=403, detail="Only room owner can view ICS status"
)
next_sync = None
if room.ics_enabled and room.ics_last_sync:
next_sync = room.ics_last_sync + timedelta(seconds=room.ics_fetch_interval)
from reflector.db.calendar_events import calendar_events_controller
events = await calendar_events_controller.get_by_room(
room.id, include_deleted=False
)
return ICSStatus(
status="enabled" if room.ics_enabled else "disabled",
last_sync=room.ics_last_sync,
next_sync=next_sync,
last_etag=room.ics_last_etag,
events_count=len(events),
)
class CalendarEventResponse(BaseModel):
id: str
room_id: str
ics_uid: str
title: Optional[str] = None
description: Optional[str] = None
start_time: datetime
end_time: datetime
attendees: Optional[list[dict]] = None
location: Optional[str] = None
last_synced: datetime
created_at: datetime
updated_at: datetime
@router.get("/rooms/{room_name}/meetings", response_model=list[CalendarEventResponse])
async def rooms_list_meetings(
room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
from reflector.db.calendar_events import calendar_events_controller
events = await calendar_events_controller.get_by_room(
room.id, include_deleted=False
)
if user_id != room.user_id:
for event in events:
event.description = None
event.attendees = None
return events
@router.get(
"/rooms/{room_name}/meetings/upcoming", response_model=list[CalendarEventResponse]
)
async def rooms_list_upcoming_meetings(
room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
minutes_ahead: int = 30,
):
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
from reflector.db.calendar_events import calendar_events_controller
events = await calendar_events_controller.get_upcoming(
room.id, minutes_ahead=minutes_ahead
)
if user_id != room.user_id:
for event in events:
event.description = None
event.attendees = None
return events
@router.get("/rooms/{room_name}/meetings/active", response_model=list[Meeting])
async def rooms_list_active_meetings(
room_name: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
"""List all active meetings for a room (supports multiple active meetings)"""
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
current_time = datetime.now(timezone.utc)
meetings = await meetings_controller.get_all_active_for_room(
room=room, current_time=current_time
)
# Hide host URLs from non-owners
if user_id != room.user_id:
for meeting in meetings:
meeting.host_room_url = ""
return meetings
@router.post("/rooms/{room_name}/meetings/{meeting_id}/join", response_model=Meeting)
async def rooms_join_meeting(
room_name: str,
meeting_id: str,
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
):
"""Join a specific meeting by ID"""
user_id = user["sub"] if user else None
room = await rooms_controller.get_by_name(room_name)
if not room:
raise HTTPException(status_code=404, detail="Room not found")
meeting = await meetings_controller.get_by_id(meeting_id)
if not meeting:
raise HTTPException(status_code=404, detail="Meeting not found")
if meeting.room_id != room.id:
raise HTTPException(
status_code=403, detail="Meeting does not belong to this room"
)
if not meeting.is_active:
raise HTTPException(status_code=400, detail="Meeting is not active")
current_time = datetime.now(timezone.utc)
if meeting.end_date <= current_time:
raise HTTPException(status_code=400, detail="Meeting has ended")
# Hide host URL from non-owners
if user_id != room.user_id:
meeting.host_room_url = ""
return meeting

View File

@@ -1,15 +1,29 @@
from datetime import datetime, timedelta, timezone
from typing import Annotated, Literal, Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi_pagination import Page
from fastapi_pagination.ext.databases import paginate
from fastapi_pagination.ext.databases import apaginate
from jose import jwt
from pydantic import BaseModel, Field, field_serializer
import reflector.auth as auth
from reflector.db import get_database
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller
from reflector.db.search import (
DEFAULT_SEARCH_LIMIT,
SearchLimit,
SearchLimitBase,
SearchOffset,
SearchOffsetBase,
SearchParameters,
SearchQuery,
SearchQueryBase,
SearchResult,
SearchTotal,
search_controller,
)
from reflector.db.transcripts import (
SourceKind,
TranscriptParticipant,
@@ -34,7 +48,7 @@ DOWNLOAD_EXPIRE_MINUTES = 60
def create_access_token(data: dict, expires_delta: timedelta):
to_encode = data.copy()
expire = datetime.utcnow() + expires_delta
expire = datetime.now(timezone.utc) + expires_delta
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
@@ -100,6 +114,21 @@ class DeletionStatus(BaseModel):
status: str
SearchQueryParam = Annotated[SearchQueryBase, Query(description="Search query text")]
SearchLimitParam = Annotated[SearchLimitBase, Query(description="Results per page")]
SearchOffsetParam = Annotated[
SearchOffsetBase, Query(description="Number of results to skip")
]
class SearchResponse(BaseModel):
results: list[SearchResult]
total: SearchTotal
query: SearchQuery
limit: SearchLimit
offset: SearchOffset
@router.get("/transcripts", response_model=Page[GetTranscriptMinimal])
async def transcripts_list(
user: Annotated[Optional[auth.UserInfo], Depends(auth.current_user_optional)],
@@ -107,15 +136,13 @@ async def transcripts_list(
room_id: str | None = None,
search_term: str | None = None,
):
from reflector.db import database
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None
return await paginate(
database,
return await apaginate(
get_database(),
await transcripts_controller.get_all(
user_id=user_id,
source_kind=SourceKind(source_kind) if source_kind else None,
@@ -127,6 +154,39 @@ async def transcripts_list(
)
@router.get("/transcripts/search", response_model=SearchResponse)
async def transcripts_search(
q: SearchQueryParam,
limit: SearchLimitParam = DEFAULT_SEARCH_LIMIT,
offset: SearchOffsetParam = 0,
room_id: Optional[str] = None,
user: Annotated[
Optional[auth.UserInfo], Depends(auth.current_user_optional)
] = None,
):
"""
Full-text search across transcript titles and content.
"""
if not user and not settings.PUBLIC_MODE:
raise HTTPException(status_code=401, detail="Not authenticated")
user_id = user["sub"] if user else None
search_params = SearchParameters(
query_text=q, limit=limit, offset=offset, user_id=user_id, room_id=room_id
)
results, total = await search_controller.search_transcripts(search_params)
return SearchResponse(
results=results,
total=total,
query=search_params.query_text,
limit=search_params.limit,
offset=search_params.offset,
)
@router.post("/transcripts", response_model=GetTranscript)
async def transcripts_create(
info: CreateTranscript,
@@ -273,8 +333,8 @@ async def transcript_update(
if not transcript:
raise HTTPException(status_code=404, detail="Transcript not found")
values = info.dict(exclude_unset=True)
await transcripts_controller.update(transcript, values)
return transcript
updated_transcript = await transcripts_controller.update(transcript, values)
return updated_transcript
@router.delete("/transcripts/{transcript_id}", response_model=DeletionStatus)

View File

@@ -51,24 +51,6 @@ async def transcript_get_audio_mp3(
transcript_id, user_id=user_id
)
if transcript.audio_location == "storage":
# proxy S3 file, to prevent issue with CORS
url = await transcript.get_audio_url()
headers = {}
copy_headers = ["range", "accept-encoding"]
for header in copy_headers:
if header in request.headers:
headers[header] = request.headers[header]
async with httpx.AsyncClient() as client:
resp = await client.request(request.method, url, headers=headers)
return Response(
content=resp.content,
status_code=resp.status_code,
headers=resp.headers,
)
if transcript.audio_location == "storage":
# proxy S3 file, to prevent issue with CORS
url = await transcript.get_audio_url()

View File

@@ -26,7 +26,7 @@ async def transcript_record_webrtc(
raise HTTPException(status_code=400, detail="Transcript is locked")
# create a pipeline runner
from reflector.pipelines.main_live_pipeline import PipelineMainLive
from reflector.pipelines.main_live_pipeline import PipelineMainLive # noqa: PLC0415
pipeline_runner = PipelineMainLive(transcript_id=transcript_id)

View File

@@ -68,8 +68,13 @@ async def whereby_webhook(event: WherebyWebhookEvent, request: Request):
raise HTTPException(status_code=404, detail="Meeting not found")
if event.type in ["room.client.joined", "room.client.left"]:
await meetings_controller.update_meeting(
meeting.id, num_clients=event.data["numClients"]
)
update_data = {"num_clients": event.data["numClients"]}
# Clear grace period if participant joined
if event.type == "room.client.joined" and event.data["numClients"] > 0:
if meeting.last_participant_left_at:
update_data["last_participant_left_at"] = None
await meetings_controller.update_meeting(meeting.id, **update_data)
return {"status": "ok"}

View File

@@ -23,7 +23,7 @@ async def create_meeting(room_name_prefix: str, end_date: datetime, room: Room):
"type": room.recording_type,
"destination": {
"provider": "s3",
"bucket": settings.AWS_WHEREBY_S3_BUCKET,
"bucket": settings.RECORDING_STORAGE_AWS_BUCKET_NAME,
"accessKeyId": settings.AWS_WHEREBY_ACCESS_KEY_ID,
"accessKeySecret": settings.AWS_WHEREBY_ACCESS_KEY_SECRET,
"fileFormat": "mp4",

View File

@@ -19,6 +19,7 @@ else:
"reflector.pipelines.main_live_pipeline",
"reflector.worker.healthcheck",
"reflector.worker.process",
"reflector.worker.ics_sync",
]
)
@@ -36,6 +37,14 @@ else:
"task": "reflector.worker.process.reprocess_failed_recordings",
"schedule": crontab(hour=5, minute=0), # Midnight EST
},
"sync_all_ics_calendars": {
"task": "reflector.worker.ics_sync.sync_all_ics_calendars",
"schedule": 60.0, # Run every minute to check which rooms need sync
},
"pre_create_upcoming_meetings": {
"task": "reflector.worker.ics_sync.pre_create_upcoming_meetings",
"schedule": 30.0, # Run every 30 seconds to pre-create meetings
},
}
if settings.HEALTHCHECK_URL:

View File

@@ -0,0 +1,209 @@
from datetime import datetime, timedelta, timezone
import structlog
from celery import shared_task
from celery.utils.log import get_task_logger
from reflector.db import get_database
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms, rooms_controller
from reflector.services.ics_sync import ics_sync_service
from reflector.whereby import create_meeting, upload_logo
logger = structlog.wrap_logger(get_task_logger(__name__))
@shared_task
def sync_room_ics(room_id: str):
asynctask(_sync_room_ics_async(room_id))
async def _sync_room_ics_async(room_id: str):
try:
room = await rooms_controller.get_by_id(room_id)
if not room:
logger.warning("Room not found for ICS sync", room_id=room_id)
return
if not room.ics_enabled or not room.ics_url:
logger.debug("ICS not enabled for room", room_id=room_id)
return
logger.info("Starting ICS sync for room", room_id=room_id, room_name=room.name)
result = await ics_sync_service.sync_room_calendar(room)
if result["status"] == "success":
logger.info(
"ICS sync completed successfully",
room_id=room_id,
events_found=result.get("events_found", 0),
events_created=result.get("events_created", 0),
events_updated=result.get("events_updated", 0),
events_deleted=result.get("events_deleted", 0),
)
elif result["status"] == "unchanged":
logger.debug("ICS content unchanged", room_id=room_id)
elif result["status"] == "error":
logger.error("ICS sync failed", room_id=room_id, error=result.get("error"))
else:
logger.debug(
"ICS sync skipped", room_id=room_id, reason=result.get("reason")
)
except Exception as e:
logger.error("Unexpected error during ICS sync", room_id=room_id, error=str(e))
@shared_task
def sync_all_ics_calendars():
asynctask(_sync_all_ics_calendars_async())
async def _sync_all_ics_calendars_async():
try:
logger.info("Starting sync for all ICS-enabled rooms")
# Get ALL rooms - not filtered by is_shared
query = rooms.select().where(
rooms.c.ics_enabled == True, rooms.c.ics_url != None
)
all_rooms = await get_database().fetch_all(query)
ics_enabled_rooms = list(all_rooms)
logger.info(f"Found {len(ics_enabled_rooms)} rooms with ICS enabled")
for room_data in ics_enabled_rooms:
room_id = room_data["id"]
room = await rooms_controller.get_by_id(room_id)
if not room:
continue
if not _should_sync(room):
logger.debug("Skipping room, not time to sync yet", room_id=room_id)
continue
sync_room_ics.delay(room_id)
logger.info("Queued sync tasks for all eligible rooms")
except Exception as e:
logger.error("Error in sync_all_ics_calendars", error=str(e))
def _should_sync(room) -> bool:
if not room.ics_last_sync:
return True
time_since_sync = datetime.now(timezone.utc) - room.ics_last_sync
return time_since_sync.total_seconds() >= room.ics_fetch_interval
@shared_task
def pre_create_upcoming_meetings():
asynctask(_pre_create_upcoming_meetings_async())
async def _pre_create_upcoming_meetings_async():
try:
logger.info("Starting pre-creation of upcoming meetings")
from reflector.db.calendar_events import calendar_events_controller
# Get ALL rooms with ICS enabled
query = rooms.select().where(
rooms.c.ics_enabled == True, rooms.c.ics_url != None
)
all_rooms = await get_database().fetch_all(query)
now = datetime.now(timezone.utc)
pre_create_window = now + timedelta(minutes=1)
for room_data in all_rooms:
room_id = room_data["id"]
room = await rooms_controller.get_by_id(room_id)
if not room:
continue
events = await calendar_events_controller.get_upcoming(
room_id, minutes_ahead=2
)
for event in events:
if event.start_time <= pre_create_window:
existing_meeting = await meetings_controller.get_by_calendar_event(
event.id
)
if not existing_meeting:
logger.info(
"Pre-creating meeting for calendar event",
room_id=room_id,
event_id=event.id,
event_title=event.title,
)
try:
end_date = event.end_time or (
event.start_time + timedelta(hours=1)
)
whereby_meeting = await create_meeting(
event.title or "Scheduled Meeting",
end_date=end_date,
room=room,
)
await upload_logo(
whereby_meeting["roomName"], "./images/logo.png"
)
meeting = await meetings_controller.create(
id=whereby_meeting["meetingId"],
room_name=whereby_meeting["roomName"],
room_url=whereby_meeting["roomUrl"],
host_room_url=whereby_meeting["hostRoomUrl"],
start_date=datetime.fromisoformat(
whereby_meeting["startDate"]
),
end_date=datetime.fromisoformat(
whereby_meeting["endDate"]
),
user_id=room.user_id,
room=room,
calendar_event_id=event.id,
calendar_metadata={
"title": event.title,
"description": event.description,
"attendees": event.attendees,
},
)
logger.info(
"Meeting pre-created successfully",
meeting_id=meeting.id,
event_id=event.id,
)
except Exception as e:
logger.error(
"Failed to pre-create meeting",
room_id=room_id,
event_id=event.id,
error=str(e),
)
logger.info("Completed pre-creation check for upcoming meetings")
except Exception as e:
logger.error("Error in pre_create_upcoming_meetings", error=str(e))
def asynctask(coro):
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(coro)
finally:
loop.close()

View File

@@ -1,6 +1,6 @@
import json
import os
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from urllib.parse import unquote
import av
@@ -21,6 +21,14 @@ from reflector.whereby import get_room_sessions
logger = structlog.wrap_logger(get_task_logger(__name__))
def parse_datetime_with_timezone(iso_string: str) -> datetime:
"""Parse ISO datetime string and ensure timezone awareness (defaults to UTC if naive)."""
dt = datetime.fromisoformat(iso_string)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
return dt
@shared_task
def process_messages():
queue_url = settings.AWS_PROCESS_RECORDING_QUEUE_URL
@@ -69,7 +77,7 @@ async def process_recording(bucket_name: str, object_key: str):
# extract a guid and a datetime from the object key
room_name = f"/{object_key[:36]}"
recorded_at = datetime.fromisoformat(object_key[37:57])
recorded_at = parse_datetime_with_timezone(object_key[37:57])
meeting = await meetings_controller.get_by_room_name(room_name)
room = await rooms_controller.get_by_id(meeting.room_id)
@@ -138,24 +146,76 @@ async def process_recording(bucket_name: str, object_key: str):
@shared_task
@asynctask
async def process_meetings():
"""
Checks which meetings are still active and deactivates those that have ended.
Supports multiple active meetings per room and grace period logic.
"""
logger.info("Processing meetings")
meetings = await meetings_controller.get_all_active()
current_time = datetime.now(timezone.utc)
for meeting in meetings:
is_active = False
should_deactivate = False
end_date = meeting.end_date
if end_date.tzinfo is None:
end_date = end_date.replace(tzinfo=timezone.utc)
if end_date > datetime.now(timezone.utc):
# Check if meeting has passed its scheduled end time
if end_date <= current_time:
# For calendar meetings, force close 30 minutes after scheduled end
if meeting.calendar_event_id:
if current_time > end_date + timedelta(minutes=30):
should_deactivate = True
logger.info(
"Meeting %s forced closed 30 min after calendar end", meeting.id
)
else:
# Unscheduled meetings follow normal closure rules
should_deactivate = True
# Check Whereby room sessions only if not already deactivating
if not should_deactivate and end_date > current_time:
response = await get_room_sessions(meeting.room_name)
room_sessions = response.get("results", [])
is_active = not room_sessions or any(
has_active_sessions = room_sessions and any(
rs["endedAt"] is None for rs in room_sessions
)
if not is_active:
if not has_active_sessions:
# No active sessions - check grace period
if meeting.num_clients == 0:
if meeting.last_participant_left_at:
# Check if grace period has expired
grace_period = timedelta(minutes=meeting.grace_period_minutes)
if (
current_time
> meeting.last_participant_left_at + grace_period
):
should_deactivate = True
logger.info("Meeting %s grace period expired", meeting.id)
else:
# First time all participants left, record the time
await meetings_controller.update_meeting(
meeting.id, last_participant_left_at=current_time
)
logger.info(
"Meeting %s marked empty at %s", meeting.id, current_time
)
else:
# Has active sessions - clear grace period if set
if meeting.last_participant_left_at:
await meetings_controller.update_meeting(
meeting.id, last_participant_left_at=None
)
logger.info(
"Meeting %s reactivated - participant rejoined", meeting.id
)
if should_deactivate:
await meetings_controller.update_meeting(meeting.id, is_active=False)
logger.info("Meeting %s is deactivated", meeting.id)
logger.info("Processed meetings")
logger.info("Processed %d meetings", len(meetings))
@shared_task
@@ -177,7 +237,7 @@ async def reprocess_failed_recordings():
reprocessed_count = 0
try:
paginator = s3.get_paginator("list_objects_v2")
bucket_name = settings.AWS_WHEREBY_S3_BUCKET
bucket_name = settings.RECORDING_STORAGE_AWS_BUCKET_NAME
pages = paginator.paginate(Bucket=bucket_name)
for page in pages:

View File

@@ -62,6 +62,7 @@ class RedisPubSubManager:
class WebsocketManager:
def __init__(self, pubsub_client: RedisPubSubManager = None):
self.rooms: dict = {}
self.tasks: dict = {}
self.pubsub_client = pubsub_client
async def add_user_to_room(self, room_id: str, websocket: WebSocket) -> None:
@@ -74,13 +75,17 @@ class WebsocketManager:
await self.pubsub_client.connect()
pubsub_subscriber = await self.pubsub_client.subscribe(room_id)
asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
task = asyncio.create_task(self._pubsub_data_reader(pubsub_subscriber))
self.tasks[id(websocket)] = task
async def send_json(self, room_id: str, message: dict) -> None:
await self.pubsub_client.send_json(room_id, message)
async def remove_user_from_room(self, room_id: str, websocket: WebSocket) -> None:
self.rooms[room_id].remove(websocket)
task = self.tasks.pop(id(websocket), None)
if task:
task.cancel()
if len(self.rooms[room_id]) == 0:
del self.rooms[room_id]

View File

@@ -1,21 +1,63 @@
import os
from tempfile import NamedTemporaryFile
from unittest.mock import patch
import pytest
# Pytest-docker configuration
@pytest.fixture(scope="session")
def docker_compose_file(pytestconfig):
return os.path.join(str(pytestconfig.rootdir), "tests", "docker-compose.test.yml")
@pytest.fixture(scope="session")
def postgres_service(docker_ip, docker_services):
"""Ensure that PostgreSQL service is up and responsive."""
port = docker_services.port_for("postgres_test", 5432)
def is_responsive():
try:
import psycopg2
conn = psycopg2.connect(
host=docker_ip,
port=port,
dbname="reflector_test",
user="test_user",
password="test_password",
)
conn.close()
return True
except Exception:
return False
docker_services.wait_until_responsive(timeout=30.0, pause=0.1, check=is_responsive)
# Return connection parameters
return {
"host": docker_ip,
"port": port,
"dbname": "reflector_test",
"user": "test_user",
"password": "test_password",
}
@pytest.fixture(scope="function", autouse=True)
@pytest.mark.asyncio
async def setup_database():
from reflector.settings import settings
async def setup_database(postgres_service):
from reflector.db import engine, metadata, get_database # noqa
with NamedTemporaryFile() as f:
settings.DATABASE_URL = f"sqlite:///{f.name}"
from reflector.db import engine, metadata
metadata.create_all(bind=engine)
metadata.drop_all(bind=engine)
metadata.create_all(bind=engine)
database = get_database()
try:
await database.connect()
yield
finally:
await database.disconnect()
@pytest.fixture
@@ -33,17 +75,16 @@ def dummy_processors():
patch(
"reflector.processors.transcript_final_summary.TranscriptFinalSummaryProcessor.get_short_summary"
) as mock_short_summary,
patch(
"reflector.processors.transcript_translator.TranscriptTranslatorProcessor.get_translation"
) as mock_translate,
):
mock_topic.return_value = {"title": "LLM TITLE", "summary": "LLM SUMMARY"}
mock_title.return_value = {"title": "LLM TITLE"}
from reflector.processors.transcript_topic_detector import TopicResponse
mock_topic.return_value = TopicResponse(
title="LLM TITLE", summary="LLM SUMMARY"
)
mock_title.return_value = "LLM Title"
mock_long_summary.return_value = "LLM LONG SUMMARY"
mock_short_summary.return_value = "LLM SHORT SUMMARY"
mock_translate.return_value = "Bonjour le monde"
yield (
mock_translate,
mock_topic,
mock_title,
mock_long_summary,
@@ -51,6 +92,20 @@ def dummy_processors():
) # noqa
@pytest.fixture
async def whisper_transcript():
from reflector.processors.audio_transcript_whisper import (
AudioTranscriptWhisperProcessor,
)
with patch(
"reflector.processors.audio_transcript_auto"
".AudioTranscriptAutoProcessor.__new__"
) as mock_audio:
mock_audio.return_value = AudioTranscriptWhisperProcessor()
yield
@pytest.fixture
async def dummy_transcript():
from reflector.processors.audio_transcript import AudioTranscriptProcessor
@@ -101,16 +156,38 @@ async def dummy_diarization():
yield
@pytest.fixture
async def dummy_transcript_translator():
from reflector.processors.transcript_translator import TranscriptTranslatorProcessor
class TestTranscriptTranslatorProcessor(TranscriptTranslatorProcessor):
async def _translate(self, text: str) -> str:
source_language = self.get_pref("audio:source_language", "en")
target_language = self.get_pref("audio:target_language", "en")
return f"{source_language}:{target_language}:{text}"
def mock_new(cls, *args, **kwargs):
return TestTranscriptTranslatorProcessor(*args, **kwargs)
with patch(
"reflector.processors.transcript_translator_auto"
".TranscriptTranslatorAutoProcessor.__new__",
mock_new,
):
yield
@pytest.fixture
async def dummy_llm():
from reflector.llm.base import LLM
from reflector.llm import LLM
class TestLLM(LLM):
def __init__(self):
self.model_name = "DUMMY MODEL"
self.llm_tokenizer = "DUMMY TOKENIZER"
with patch("reflector.llm.base.LLM.get_instance") as mock_llm:
# LLM doesn't have get_instance anymore, mocking constructor instead
with patch("reflector.llm.LLM") as mock_llm:
mock_llm.return_value = TestLLM()
yield
@@ -129,22 +206,19 @@ async def dummy_storage():
async def _get_file_url(self, *args, **kwargs):
return "http://fake_server/audio.mp3"
with patch("reflector.storage.base.Storage.get_instance") as mock_storage:
mock_storage.return_value = DummyStorage()
yield
async def _get_file(self, *args, **kwargs):
from pathlib import Path
test_mp3 = Path(__file__).parent / "records" / "test_mathieu_hello.mp3"
return test_mp3.read_bytes()
@pytest.fixture
def nltk():
with patch("reflector.llm.base.LLM.ensure_nltk") as mock_nltk:
mock_nltk.return_value = "NLTK PACKAGE"
yield
@pytest.fixture
def ensure_casing():
with patch("reflector.llm.base.LLM.ensure_casing") as mock_casing:
mock_casing.return_value = "LLM TITLE"
dummy = DummyStorage()
with (
patch("reflector.storage.base.Storage.get_instance") as mock_storage,
patch("reflector.storage.get_transcripts_storage") as mock_get_transcripts,
):
mock_storage.return_value = dummy
mock_get_transcripts.return_value = dummy
yield
@@ -167,6 +241,16 @@ def celery_includes():
return ["reflector.pipelines.main_live_pipeline"]
@pytest.fixture
async def client():
from httpx import AsyncClient
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
yield ac
@pytest.fixture(scope="session")
def fake_mp3_upload():
with patch(
@@ -177,13 +261,10 @@ def fake_mp3_upload():
@pytest.fixture
async def fake_transcript_with_topics(tmpdir):
async def fake_transcript_with_topics(tmpdir, client):
import shutil
from pathlib import Path
from httpx import AsyncClient
from reflector.app import app
from reflector.db.transcripts import TranscriptTopic
from reflector.processors.types import Word
from reflector.settings import settings
@@ -192,8 +273,7 @@ async def fake_transcript_with_topics(tmpdir):
settings.DATA_DIR = Path(tmpdir)
# create a transcript
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.post("/transcripts", json={"name": "Test audio download"})
response = await client.post("/transcripts", json={"name": "Test audio download"})
assert response.status_code == 200
tid = response.json()["id"]

View File

@@ -0,0 +1,13 @@
version: '3.8'
services:
postgres_test:
image: postgres:15
environment:
POSTGRES_DB: reflector_test
POSTGRES_USER: test_user
POSTGRES_PASSWORD: test_password
ports:
- "15432:5432"
command: postgres -c fsync=off -c synchronous_commit=off -c full_page_writes=off
tmpfs:
- /var/lib/postgresql/data:rw,noexec,nosuid,size=1g

View File

@@ -0,0 +1,351 @@
"""
Tests for CalendarEvent model.
"""
from datetime import datetime, timedelta, timezone
import pytest
from reflector.db.calendar_events import CalendarEvent, calendar_events_controller
from reflector.db.rooms import rooms_controller
@pytest.mark.asyncio
async def test_calendar_event_create():
"""Test creating a calendar event."""
# Create a room first
room = await rooms_controller.add(
name="test-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
# Create calendar event
now = datetime.now(timezone.utc)
event = CalendarEvent(
room_id=room.id,
ics_uid="test-event-123",
title="Team Meeting",
description="Weekly team sync",
start_time=now + timedelta(hours=1),
end_time=now + timedelta(hours=2),
location=f"https://example.com/room/{room.name}",
attendees=[
{"email": "alice@example.com", "name": "Alice", "status": "ACCEPTED"},
{"email": "bob@example.com", "name": "Bob", "status": "TENTATIVE"},
],
)
# Save event
saved_event = await calendar_events_controller.upsert(event)
assert saved_event.ics_uid == "test-event-123"
assert saved_event.title == "Team Meeting"
assert saved_event.room_id == room.id
assert len(saved_event.attendees) == 2
@pytest.mark.asyncio
async def test_calendar_event_get_by_room():
"""Test getting calendar events for a room."""
# Create room
room = await rooms_controller.add(
name="events-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
now = datetime.now(timezone.utc)
# Create multiple events
for i in range(3):
event = CalendarEvent(
room_id=room.id,
ics_uid=f"event-{i}",
title=f"Meeting {i}",
start_time=now + timedelta(hours=i),
end_time=now + timedelta(hours=i + 1),
)
await calendar_events_controller.upsert(event)
# Get events for room
events = await calendar_events_controller.get_by_room(room.id)
assert len(events) == 3
assert all(e.room_id == room.id for e in events)
assert events[0].title == "Meeting 0"
assert events[1].title == "Meeting 1"
assert events[2].title == "Meeting 2"
@pytest.mark.asyncio
async def test_calendar_event_get_upcoming():
"""Test getting upcoming events within time window."""
# Create room
room = await rooms_controller.add(
name="upcoming-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
now = datetime.now(timezone.utc)
# Create events at different times
# Past event (should not be included)
past_event = CalendarEvent(
room_id=room.id,
ics_uid="past-event",
title="Past Meeting",
start_time=now - timedelta(hours=2),
end_time=now - timedelta(hours=1),
)
await calendar_events_controller.upsert(past_event)
# Upcoming event within 30 minutes
upcoming_event = CalendarEvent(
room_id=room.id,
ics_uid="upcoming-event",
title="Upcoming Meeting",
start_time=now + timedelta(minutes=15),
end_time=now + timedelta(minutes=45),
)
await calendar_events_controller.upsert(upcoming_event)
# Future event beyond 30 minutes
future_event = CalendarEvent(
room_id=room.id,
ics_uid="future-event",
title="Future Meeting",
start_time=now + timedelta(hours=2),
end_time=now + timedelta(hours=3),
)
await calendar_events_controller.upsert(future_event)
# Get upcoming events (default 30 minutes)
upcoming = await calendar_events_controller.get_upcoming(room.id)
assert len(upcoming) == 1
assert upcoming[0].ics_uid == "upcoming-event"
# Get upcoming with custom window
upcoming_extended = await calendar_events_controller.get_upcoming(
room.id, minutes_ahead=180
)
assert len(upcoming_extended) == 2
assert upcoming_extended[0].ics_uid == "upcoming-event"
assert upcoming_extended[1].ics_uid == "future-event"
@pytest.mark.asyncio
async def test_calendar_event_upsert():
"""Test upserting (create/update) calendar events."""
# Create room
room = await rooms_controller.add(
name="upsert-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
now = datetime.now(timezone.utc)
# Create new event
event = CalendarEvent(
room_id=room.id,
ics_uid="upsert-test",
title="Original Title",
start_time=now,
end_time=now + timedelta(hours=1),
)
created = await calendar_events_controller.upsert(event)
assert created.title == "Original Title"
# Update existing event
event.title = "Updated Title"
event.description = "Added description"
updated = await calendar_events_controller.upsert(event)
assert updated.title == "Updated Title"
assert updated.description == "Added description"
assert updated.ics_uid == "upsert-test"
# Verify only one event exists
events = await calendar_events_controller.get_by_room(room.id)
assert len(events) == 1
assert events[0].title == "Updated Title"
@pytest.mark.asyncio
async def test_calendar_event_soft_delete():
"""Test soft deleting events no longer in calendar."""
# Create room
room = await rooms_controller.add(
name="delete-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
now = datetime.now(timezone.utc)
# Create multiple events
for i in range(4):
event = CalendarEvent(
room_id=room.id,
ics_uid=f"event-{i}",
title=f"Meeting {i}",
start_time=now + timedelta(hours=i),
end_time=now + timedelta(hours=i + 1),
)
await calendar_events_controller.upsert(event)
# Soft delete events not in current list
current_ids = ["event-0", "event-2"] # Keep events 0 and 2
deleted_count = await calendar_events_controller.soft_delete_missing(
room.id, current_ids
)
assert deleted_count == 2 # Should delete events 1 and 3
# Get non-deleted events
events = await calendar_events_controller.get_by_room(
room.id, include_deleted=False
)
assert len(events) == 2
assert {e.ics_uid for e in events} == {"event-0", "event-2"}
# Get all events including deleted
all_events = await calendar_events_controller.get_by_room(
room.id, include_deleted=True
)
assert len(all_events) == 4
@pytest.mark.asyncio
async def test_calendar_event_past_events_not_deleted():
"""Test that past events are not soft deleted."""
# Create room
room = await rooms_controller.add(
name="past-events-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
now = datetime.now(timezone.utc)
# Create past event
past_event = CalendarEvent(
room_id=room.id,
ics_uid="past-event",
title="Past Meeting",
start_time=now - timedelta(hours=2),
end_time=now - timedelta(hours=1),
)
await calendar_events_controller.upsert(past_event)
# Create future event
future_event = CalendarEvent(
room_id=room.id,
ics_uid="future-event",
title="Future Meeting",
start_time=now + timedelta(hours=1),
end_time=now + timedelta(hours=2),
)
await calendar_events_controller.upsert(future_event)
# Try to soft delete all events (only future should be deleted)
deleted_count = await calendar_events_controller.soft_delete_missing(room.id, [])
assert deleted_count == 1 # Only future event deleted
# Verify past event still exists
events = await calendar_events_controller.get_by_room(
room.id, include_deleted=False
)
assert len(events) == 1
assert events[0].ics_uid == "past-event"
@pytest.mark.asyncio
async def test_calendar_event_with_raw_ics_data():
"""Test storing raw ICS data with calendar event."""
# Create room
room = await rooms_controller.add(
name="raw-ics-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
raw_ics = """BEGIN:VEVENT
UID:test-raw-123
SUMMARY:Test Event
DTSTART:20240101T100000Z
DTEND:20240101T110000Z
END:VEVENT"""
event = CalendarEvent(
room_id=room.id,
ics_uid="test-raw-123",
title="Test Event",
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc) + timedelta(hours=1),
ics_raw_data=raw_ics,
)
saved = await calendar_events_controller.upsert(event)
assert saved.ics_raw_data == raw_ics
# Retrieve and verify
retrieved = await calendar_events_controller.get_by_ics_uid(room.id, "test-raw-123")
assert retrieved is not None
assert retrieved.ics_raw_data == raw_ics

View File

@@ -0,0 +1,230 @@
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from icalendar import Calendar, Event
from reflector.db.calendar_events import calendar_events_controller
from reflector.db.rooms import rooms_controller
from reflector.worker.ics_sync import (
_should_sync,
_sync_all_ics_calendars_async,
_sync_room_ics_async,
)
@pytest.mark.asyncio
async def test_sync_room_ics_task():
room = await rooms_controller.add(
name="task-test-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_url="https://calendar.example.com/task.ics",
ics_enabled=True,
)
cal = Calendar()
event = Event()
event.add("uid", "task-event-1")
event.add("summary", "Task Test Meeting")
from reflector.settings import settings
event.add("location", f"{settings.BASE_URL}/room/{room.name}")
now = datetime.now(timezone.utc)
event.add("dtstart", now + timedelta(hours=1))
event.add("dtend", now + timedelta(hours=2))
cal.add_component(event)
ics_content = cal.to_ical().decode("utf-8")
with patch(
"reflector.services.ics_sync.ICSFetchService.fetch_ics", new_callable=AsyncMock
) as mock_fetch:
mock_fetch.return_value = ics_content
await _sync_room_ics_async(room.id)
events = await calendar_events_controller.get_by_room(room.id)
assert len(events) == 1
assert events[0].ics_uid == "task-event-1"
@pytest.mark.asyncio
async def test_sync_room_ics_disabled():
room = await rooms_controller.add(
name="disabled-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_enabled=False,
)
await _sync_room_ics_async(room.id)
events = await calendar_events_controller.get_by_room(room.id)
assert len(events) == 0
@pytest.mark.asyncio
async def test_sync_all_ics_calendars():
room1 = await rooms_controller.add(
name="sync-all-1",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_url="https://calendar.example.com/1.ics",
ics_enabled=True,
)
room2 = await rooms_controller.add(
name="sync-all-2",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_url="https://calendar.example.com/2.ics",
ics_enabled=True,
)
room3 = await rooms_controller.add(
name="sync-all-3",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_enabled=False,
)
with patch("reflector.worker.ics_sync.sync_room_ics.delay") as mock_delay:
await _sync_all_ics_calendars_async()
assert mock_delay.call_count == 2
called_room_ids = [call.args[0] for call in mock_delay.call_args_list]
assert room1.id in called_room_ids
assert room2.id in called_room_ids
assert room3.id not in called_room_ids
@pytest.mark.asyncio
async def test_should_sync_logic():
room = MagicMock()
room.ics_last_sync = None
assert _should_sync(room) is True
room.ics_last_sync = datetime.now(timezone.utc) - timedelta(seconds=100)
room.ics_fetch_interval = 300
assert _should_sync(room) is False
room.ics_last_sync = datetime.now(timezone.utc) - timedelta(seconds=400)
room.ics_fetch_interval = 300
assert _should_sync(room) is True
@pytest.mark.asyncio
async def test_sync_respects_fetch_interval():
now = datetime.now(timezone.utc)
room1 = await rooms_controller.add(
name="interval-test-1",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_url="https://calendar.example.com/interval.ics",
ics_enabled=True,
ics_fetch_interval=300,
)
await rooms_controller.update(
room1,
{"ics_last_sync": now - timedelta(seconds=100)},
)
room2 = await rooms_controller.add(
name="interval-test-2",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_url="https://calendar.example.com/interval2.ics",
ics_enabled=True,
ics_fetch_interval=60,
)
await rooms_controller.update(
room2,
{"ics_last_sync": now - timedelta(seconds=100)},
)
with patch("reflector.worker.ics_sync.sync_room_ics.delay") as mock_delay:
await _sync_all_ics_calendars_async()
assert mock_delay.call_count == 1
assert mock_delay.call_args[0][0] == room2.id
@pytest.mark.asyncio
async def test_sync_handles_errors_gracefully():
room = await rooms_controller.add(
name="error-task-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_url="https://calendar.example.com/error.ics",
ics_enabled=True,
)
with patch(
"reflector.services.ics_sync.ICSFetchService.fetch_ics", new_callable=AsyncMock
) as mock_fetch:
mock_fetch.side_effect = Exception("Network error")
await _sync_room_ics_async(room.id)
events = await calendar_events_controller.get_by_room(room.id)
assert len(events) == 0

View File

@@ -0,0 +1,289 @@
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from icalendar import Calendar, Event
from reflector.db.calendar_events import calendar_events_controller
from reflector.db.rooms import rooms_controller
from reflector.services.ics_sync import ICSFetchService, ICSSyncService
@pytest.mark.asyncio
async def test_ics_fetch_service_event_matching():
service = ICSFetchService()
room_name = "test-room"
room_url = "https://example.com/room/test-room"
# Create test event
event = Event()
event.add("uid", "test-123")
event.add("summary", "Test Meeting")
# Test matching with full URL in location
event.add("location", "https://example.com/room/test-room")
assert service._event_matches_room(event, room_name, room_url) is True
# Test matching with URL without protocol
event["location"] = "example.com/room/test-room"
assert service._event_matches_room(event, room_name, room_url) is True
# Test matching in description
event["location"] = "Conference Room A"
event.add("description", f"Join at {room_url}")
assert service._event_matches_room(event, room_name, room_url) is True
# Test non-matching
event["location"] = "Different Room"
event["description"] = "No room URL here"
assert service._event_matches_room(event, room_name, room_url) is False
# Test partial paths should NOT match anymore
event["location"] = "/room/test-room"
assert service._event_matches_room(event, room_name, room_url) is False
event["location"] = f"Room: {room_name}"
assert service._event_matches_room(event, room_name, room_url) is False
@pytest.mark.asyncio
async def test_ics_fetch_service_parse_event():
service = ICSFetchService()
# Create test event
event = Event()
event.add("uid", "test-456")
event.add("summary", "Team Standup")
event.add("description", "Daily team sync")
event.add("location", "https://example.com/room/standup")
now = datetime.now(timezone.utc)
event.add("dtstart", now)
event.add("dtend", now + timedelta(hours=1))
# Add attendees
event.add("attendee", "mailto:alice@example.com", parameters={"CN": "Alice"})
event.add("attendee", "mailto:bob@example.com", parameters={"CN": "Bob"})
event.add("organizer", "mailto:carol@example.com", parameters={"CN": "Carol"})
# Parse event
result = service._parse_event(event)
assert result is not None
assert result["ics_uid"] == "test-456"
assert result["title"] == "Team Standup"
assert result["description"] == "Daily team sync"
assert result["location"] == "https://example.com/room/standup"
assert len(result["attendees"]) == 3 # 2 attendees + 1 organizer
@pytest.mark.asyncio
async def test_ics_fetch_service_extract_room_events():
service = ICSFetchService()
room_name = "meeting"
room_url = "https://example.com/room/meeting"
# Create calendar with multiple events
cal = Calendar()
# Event 1: Matches room
event1 = Event()
event1.add("uid", "match-1")
event1.add("summary", "Planning Meeting")
event1.add("location", room_url)
now = datetime.now(timezone.utc)
event1.add("dtstart", now + timedelta(hours=2))
event1.add("dtend", now + timedelta(hours=3))
cal.add_component(event1)
# Event 2: Doesn't match room
event2 = Event()
event2.add("uid", "no-match")
event2.add("summary", "Other Meeting")
event2.add("location", "https://example.com/room/other")
event2.add("dtstart", now + timedelta(hours=4))
event2.add("dtend", now + timedelta(hours=5))
cal.add_component(event2)
# Event 3: Matches room in description
event3 = Event()
event3.add("uid", "match-2")
event3.add("summary", "Review Session")
event3.add("description", f"Meeting link: {room_url}")
event3.add("dtstart", now + timedelta(hours=6))
event3.add("dtend", now + timedelta(hours=7))
cal.add_component(event3)
# Event 4: Cancelled event (should be skipped)
event4 = Event()
event4.add("uid", "cancelled")
event4.add("summary", "Cancelled Meeting")
event4.add("location", room_url)
event4.add("status", "CANCELLED")
event4.add("dtstart", now + timedelta(hours=8))
event4.add("dtend", now + timedelta(hours=9))
cal.add_component(event4)
# Extract events
events = service.extract_room_events(cal, room_name, room_url)
assert len(events) == 2
assert events[0]["ics_uid"] == "match-1"
assert events[1]["ics_uid"] == "match-2"
@pytest.mark.asyncio
async def test_ics_sync_service_sync_room_calendar():
# Create room
room = await rooms_controller.add(
name="sync-test",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_url="https://calendar.example.com/test.ics",
ics_enabled=True,
)
# Mock ICS content
cal = Calendar()
event = Event()
event.add("uid", "sync-event-1")
event.add("summary", "Sync Test Meeting")
# Use the actual BASE_URL from settings
from reflector.settings import settings
event.add("location", f"{settings.BASE_URL}/room/{room.name}")
now = datetime.now(timezone.utc)
event.add("dtstart", now + timedelta(hours=1))
event.add("dtend", now + timedelta(hours=2))
cal.add_component(event)
ics_content = cal.to_ical().decode("utf-8")
# Create sync service and mock fetch
sync_service = ICSSyncService()
with patch.object(
sync_service.fetch_service, "fetch_ics", new_callable=AsyncMock
) as mock_fetch:
mock_fetch.return_value = ics_content
# First sync
result = await sync_service.sync_room_calendar(room)
assert result["status"] == "success"
assert result["events_found"] == 1
assert result["events_created"] == 1
assert result["events_updated"] == 0
assert result["events_deleted"] == 0
# Verify event was created
events = await calendar_events_controller.get_by_room(room.id)
assert len(events) == 1
assert events[0].ics_uid == "sync-event-1"
assert events[0].title == "Sync Test Meeting"
# Second sync with same content (should be unchanged)
# Refresh room to get updated etag and force sync by setting old sync time
room = await rooms_controller.get_by_id(room.id)
await rooms_controller.update(
room, {"ics_last_sync": datetime.now(timezone.utc) - timedelta(minutes=10)}
)
result = await sync_service.sync_room_calendar(room)
assert result["status"] == "unchanged"
# Third sync with updated event
event["summary"] = "Updated Meeting Title"
cal = Calendar()
cal.add_component(event)
ics_content = cal.to_ical().decode("utf-8")
mock_fetch.return_value = ics_content
# Force sync by clearing etag
await rooms_controller.update(room, {"ics_last_etag": None})
result = await sync_service.sync_room_calendar(room)
assert result["status"] == "success"
assert result["events_created"] == 0
assert result["events_updated"] == 1
# Verify event was updated
events = await calendar_events_controller.get_by_room(room.id)
assert len(events) == 1
assert events[0].title == "Updated Meeting Title"
@pytest.mark.asyncio
async def test_ics_sync_service_should_sync():
service = ICSSyncService()
# Room never synced
room = MagicMock()
room.ics_last_sync = None
room.ics_fetch_interval = 300
assert service._should_sync(room) is True
# Room synced recently
room.ics_last_sync = datetime.now(timezone.utc) - timedelta(seconds=100)
assert service._should_sync(room) is False
# Room sync due
room.ics_last_sync = datetime.now(timezone.utc) - timedelta(seconds=400)
assert service._should_sync(room) is True
@pytest.mark.asyncio
async def test_ics_sync_service_skip_disabled():
service = ICSSyncService()
# Room with ICS disabled
room = MagicMock()
room.ics_enabled = False
room.ics_url = "https://calendar.example.com/test.ics"
result = await service.sync_room_calendar(room)
assert result["status"] == "skipped"
assert result["reason"] == "ICS not configured"
# Room without URL
room.ics_enabled = True
room.ics_url = None
result = await service.sync_room_calendar(room)
assert result["status"] == "skipped"
assert result["reason"] == "ICS not configured"
@pytest.mark.asyncio
async def test_ics_sync_service_error_handling():
# Create room
room = await rooms_controller.add(
name="error-test",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_url="https://calendar.example.com/error.ics",
ics_enabled=True,
)
sync_service = ICSSyncService()
with patch.object(
sync_service.fetch_service, "fetch_ics", new_callable=AsyncMock
) as mock_fetch:
mock_fetch.side_effect = Exception("Network error")
result = await sync_service.sync_room_calendar(room)
assert result["status"] == "error"
assert "Network error" in result["error"]

View File

@@ -0,0 +1,283 @@
"""Tests for multiple active meetings per room functionality."""
from datetime import datetime, timedelta, timezone
import pytest
from reflector.db.calendar_events import CalendarEvent, calendar_events_controller
from reflector.db.meetings import meetings_controller
from reflector.db.rooms import rooms_controller
@pytest.mark.asyncio
async def test_multiple_active_meetings_per_room():
"""Test that multiple active meetings can exist for the same room."""
# Create a room
room = await rooms_controller.add(
name="test-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
current_time = datetime.now(timezone.utc)
end_time = current_time + timedelta(hours=2)
# Create first meeting
meeting1 = await meetings_controller.create(
id="meeting-1",
room_name="test-meeting-1",
room_url="https://whereby.com/test-1",
host_room_url="https://whereby.com/test-1-host",
start_date=current_time,
end_date=end_time,
user_id="test-user",
room=room,
)
# Create second meeting for the same room (should succeed now)
meeting2 = await meetings_controller.create(
id="meeting-2",
room_name="test-meeting-2",
room_url="https://whereby.com/test-2",
host_room_url="https://whereby.com/test-2-host",
start_date=current_time,
end_date=end_time,
user_id="test-user",
room=room,
)
# Both meetings should be active
active_meetings = await meetings_controller.get_all_active_for_room(
room=room, current_time=current_time
)
assert len(active_meetings) == 2
assert meeting1.id in [m.id for m in active_meetings]
assert meeting2.id in [m.id for m in active_meetings]
@pytest.mark.asyncio
async def test_get_active_by_calendar_event():
"""Test getting active meeting by calendar event ID."""
# Create a room
room = await rooms_controller.add(
name="test-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
# Create a calendar event
event = CalendarEvent(
room_id=room.id,
ics_uid="test-event-uid",
title="Test Meeting",
start_time=datetime.now(timezone.utc),
end_time=datetime.now(timezone.utc) + timedelta(hours=1),
)
event = await calendar_events_controller.upsert(event)
current_time = datetime.now(timezone.utc)
end_time = current_time + timedelta(hours=2)
# Create meeting linked to calendar event
meeting = await meetings_controller.create(
id="meeting-cal-1",
room_name="test-meeting-cal",
room_url="https://whereby.com/test-cal",
host_room_url="https://whereby.com/test-cal-host",
start_date=current_time,
end_date=end_time,
user_id="test-user",
room=room,
calendar_event_id=event.id,
calendar_metadata={"title": event.title},
)
# Should find the meeting by calendar event
found_meeting = await meetings_controller.get_active_by_calendar_event(
room=room, calendar_event_id=event.id, current_time=current_time
)
assert found_meeting is not None
assert found_meeting.id == meeting.id
assert found_meeting.calendar_event_id == event.id
@pytest.mark.asyncio
async def test_grace_period_logic():
"""Test that meetings have a grace period after last participant leaves."""
# Create a room
room = await rooms_controller.add(
name="test-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
current_time = datetime.now(timezone.utc)
end_time = current_time + timedelta(hours=2)
# Create meeting
meeting = await meetings_controller.create(
id="meeting-grace",
room_name="test-meeting-grace",
room_url="https://whereby.com/test-grace",
host_room_url="https://whereby.com/test-grace-host",
start_date=current_time,
end_date=end_time,
user_id="test-user",
room=room,
)
# Test grace period logic by simulating different states
# Simulate first time all participants left
await meetings_controller.update_meeting(
meeting.id, num_clients=0, last_participant_left_at=current_time
)
# Within grace period (10 min) - should still be active
await meetings_controller.update_meeting(
meeting.id, last_participant_left_at=current_time - timedelta(minutes=10)
)
updated_meeting = await meetings_controller.get_by_id(meeting.id)
assert updated_meeting.is_active is True # Still active during grace period
# Simulate grace period expired (20 min) and deactivate
await meetings_controller.update_meeting(
meeting.id, last_participant_left_at=current_time - timedelta(minutes=20)
)
# Manually test the grace period logic that would be in process_meetings
updated_meeting = await meetings_controller.get_by_id(meeting.id)
if updated_meeting.last_participant_left_at:
grace_period = timedelta(minutes=updated_meeting.grace_period_minutes)
if current_time > updated_meeting.last_participant_left_at + grace_period:
await meetings_controller.update_meeting(meeting.id, is_active=False)
updated_meeting = await meetings_controller.get_by_id(meeting.id)
assert updated_meeting.is_active is False # Now deactivated
@pytest.mark.asyncio
async def test_calendar_meeting_force_close_after_30_min():
"""Test that calendar meetings force close 30 minutes after scheduled end."""
# Create a room
room = await rooms_controller.add(
name="test-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
# Create a calendar event
event = CalendarEvent(
room_id=room.id,
ics_uid="test-event-force",
title="Test Meeting Force Close",
start_time=datetime.now(timezone.utc) - timedelta(hours=2),
end_time=datetime.now(timezone.utc) - timedelta(minutes=35), # Ended 35 min ago
)
event = await calendar_events_controller.upsert(event)
current_time = datetime.now(timezone.utc)
# Create meeting linked to calendar event
meeting = await meetings_controller.create(
id="meeting-force",
room_name="test-meeting-force",
room_url="https://whereby.com/test-force",
host_room_url="https://whereby.com/test-force-host",
start_date=event.start_time,
end_date=event.end_time,
user_id="test-user",
room=room,
calendar_event_id=event.id,
)
# Test that calendar meetings force close 30 min after scheduled end
# The meeting ended 35 minutes ago, so it should be force closed
# Manually test the force close logic that would be in process_meetings
if meeting.calendar_event_id:
if current_time > meeting.end_date + timedelta(minutes=30):
await meetings_controller.update_meeting(meeting.id, is_active=False)
updated_meeting = await meetings_controller.get_by_id(meeting.id)
assert updated_meeting.is_active is False # Force closed after 30 min
@pytest.mark.asyncio
async def test_participant_rejoin_clears_grace_period():
"""Test that participant rejoining clears the grace period."""
# Create a room
room = await rooms_controller.add(
name="test-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
current_time = datetime.now(timezone.utc)
end_time = current_time + timedelta(hours=2)
# Create meeting with grace period already set
meeting = await meetings_controller.create(
id="meeting-rejoin",
room_name="test-meeting-rejoin",
room_url="https://whereby.com/test-rejoin",
host_room_url="https://whereby.com/test-rejoin-host",
start_date=current_time,
end_date=end_time,
user_id="test-user",
room=room,
)
# Set last_participant_left_at to simulate grace period
await meetings_controller.update_meeting(
meeting.id,
last_participant_left_at=current_time - timedelta(minutes=5),
num_clients=0,
)
# Simulate participant rejoining - clear grace period
await meetings_controller.update_meeting(
meeting.id, last_participant_left_at=None, num_clients=1
)
updated_meeting = await meetings_controller.get_by_id(meeting.id)
assert updated_meeting.last_participant_left_at is None # Grace period cleared
assert updated_meeting.is_active is True # Still active

View File

@@ -2,7 +2,7 @@ import pytest
@pytest.mark.asyncio
async def test_processor_broadcast(nltk):
async def test_processor_broadcast():
from reflector.processors.base import BroadcastProcessor, Pipeline, Processor
class TestProcessor(Processor):

View File

@@ -3,11 +3,9 @@ import pytest
@pytest.mark.asyncio
async def test_basic_process(
nltk,
dummy_transcript,
dummy_llm,
dummy_processors,
ensure_casing,
):
# goal is to start the server, and send rtc audio to it
# validate the events received
@@ -16,8 +14,8 @@ async def test_basic_process(
from reflector.settings import settings
from reflector.tools.process import process_audio_file
# use an LLM test backend
settings.LLM_BACKEND = "test"
# LLM_BACKEND no longer exists in settings
# settings.LLM_BACKEND = "test"
settings.TRANSCRIPT_BACKEND = "whisper"
# event callback
@@ -35,7 +33,7 @@ async def test_basic_process(
# validate the events
assert marks["TranscriptLinerProcessor"] == 1
assert marks["TranscriptTranslatorProcessor"] == 1
assert marks["TranscriptTranslatorPassthroughProcessor"] == 1
assert marks["TranscriptTopicDetectorProcessor"] == 1
assert marks["TranscriptFinalSummaryProcessor"] == 1
assert marks["TranscriptFinalTitleProcessor"] == 1

View File

@@ -0,0 +1,225 @@
"""
Tests for Room model ICS calendar integration fields.
"""
from datetime import datetime, timezone
import pytest
from reflector.db.rooms import rooms_controller
@pytest.mark.asyncio
async def test_room_create_with_ics_fields():
"""Test creating a room with ICS calendar fields."""
room = await rooms_controller.add(
name="test-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_url="https://calendar.google.com/calendar/ical/test/private-token/basic.ics",
ics_fetch_interval=600,
ics_enabled=True,
)
assert room.name == "test-room"
assert (
room.ics_url
== "https://calendar.google.com/calendar/ical/test/private-token/basic.ics"
)
assert room.ics_fetch_interval == 600
assert room.ics_enabled is True
assert room.ics_last_sync is None
assert room.ics_last_etag is None
@pytest.mark.asyncio
async def test_room_update_ics_configuration():
"""Test updating room ICS configuration."""
# Create room without ICS
room = await rooms_controller.add(
name="update-test",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
assert room.ics_enabled is False
assert room.ics_url is None
# Update with ICS configuration
await rooms_controller.update(
room,
{
"ics_url": "https://outlook.office365.com/owa/calendar/test/calendar.ics",
"ics_fetch_interval": 300,
"ics_enabled": True,
},
)
assert (
room.ics_url == "https://outlook.office365.com/owa/calendar/test/calendar.ics"
)
assert room.ics_fetch_interval == 300
assert room.ics_enabled is True
@pytest.mark.asyncio
async def test_room_ics_sync_metadata():
"""Test updating room ICS sync metadata."""
room = await rooms_controller.add(
name="sync-test",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_url="https://example.com/calendar.ics",
ics_enabled=True,
)
# Update sync metadata
sync_time = datetime.now(timezone.utc)
await rooms_controller.update(
room,
{
"ics_last_sync": sync_time,
"ics_last_etag": "abc123hash",
},
)
assert room.ics_last_sync == sync_time
assert room.ics_last_etag == "abc123hash"
@pytest.mark.asyncio
async def test_room_get_with_ics_fields():
"""Test retrieving room with ICS fields."""
# Create room
created_room = await rooms_controller.add(
name="get-test",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_url="webcal://calendar.example.com/feed.ics",
ics_fetch_interval=900,
ics_enabled=True,
)
# Get by ID
room = await rooms_controller.get_by_id(created_room.id)
assert room is not None
assert room.ics_url == "webcal://calendar.example.com/feed.ics"
assert room.ics_fetch_interval == 900
assert room.ics_enabled is True
# Get by name
room = await rooms_controller.get_by_name("get-test")
assert room is not None
assert room.ics_url == "webcal://calendar.example.com/feed.ics"
assert room.ics_fetch_interval == 900
assert room.ics_enabled is True
@pytest.mark.asyncio
async def test_room_list_with_ics_enabled_filter():
"""Test listing rooms filtered by ICS enabled status."""
# Create rooms with and without ICS
room1 = await rooms_controller.add(
name="ics-enabled-1",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=True,
ics_enabled=True,
ics_url="https://calendar1.example.com/feed.ics",
)
room2 = await rooms_controller.add(
name="ics-disabled",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=True,
ics_enabled=False,
)
room3 = await rooms_controller.add(
name="ics-enabled-2",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=True,
ics_enabled=True,
ics_url="https://calendar2.example.com/feed.ics",
)
# Get all rooms
all_rooms = await rooms_controller.get_all()
assert len(all_rooms) == 3
# Filter for ICS-enabled rooms (would need to implement this in controller)
ics_rooms = [r for r in all_rooms if r["ics_enabled"]]
assert len(ics_rooms) == 2
assert all(r["ics_enabled"] for r in ics_rooms)
@pytest.mark.asyncio
async def test_room_default_ics_values():
"""Test that ICS fields have correct default values."""
room = await rooms_controller.add(
name="default-test",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
# Don't specify ICS fields
)
assert room.ics_url is None
assert room.ics_fetch_interval == 300 # Default 5 minutes
assert room.ics_enabled is False
assert room.ics_last_sync is None
assert room.ics_last_etag is None

View File

@@ -0,0 +1,385 @@
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, patch
import pytest
from icalendar import Calendar, Event
from reflector.db.calendar_events import CalendarEvent, calendar_events_controller
from reflector.db.rooms import rooms_controller
@pytest.fixture
async def authenticated_client(client):
from reflector.app import app
from reflector.auth import current_user_optional
app.dependency_overrides[current_user_optional] = lambda: {
"sub": "test-user",
"email": "test@example.com",
}
yield client
del app.dependency_overrides[current_user_optional]
@pytest.mark.asyncio
async def test_create_room_with_ics_fields(authenticated_client):
client = authenticated_client
response = await client.post(
"/rooms",
json={
"name": "test-ics-room",
"zulip_auto_post": False,
"zulip_stream": "",
"zulip_topic": "",
"is_locked": False,
"room_mode": "normal",
"recording_type": "cloud",
"recording_trigger": "automatic-2nd-participant",
"is_shared": False,
"ics_url": "https://calendar.example.com/test.ics",
"ics_fetch_interval": 600,
"ics_enabled": True,
},
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "test-ics-room"
assert data["ics_url"] == "https://calendar.example.com/test.ics"
assert data["ics_fetch_interval"] == 600
assert data["ics_enabled"] is True
@pytest.mark.asyncio
async def test_update_room_ics_configuration(authenticated_client):
client = authenticated_client
response = await client.post(
"/rooms",
json={
"name": "update-ics-room",
"zulip_auto_post": False,
"zulip_stream": "",
"zulip_topic": "",
"is_locked": False,
"room_mode": "normal",
"recording_type": "cloud",
"recording_trigger": "automatic-2nd-participant",
"is_shared": False,
},
)
assert response.status_code == 200
room_id = response.json()["id"]
response = await client.patch(
f"/rooms/{room_id}",
json={
"ics_url": "https://calendar.google.com/updated.ics",
"ics_fetch_interval": 300,
"ics_enabled": True,
},
)
assert response.status_code == 200
data = response.json()
assert data["ics_url"] == "https://calendar.google.com/updated.ics"
assert data["ics_fetch_interval"] == 300
assert data["ics_enabled"] is True
@pytest.mark.asyncio
async def test_trigger_ics_sync(authenticated_client):
client = authenticated_client
room = await rooms_controller.add(
name="sync-api-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_url="https://calendar.example.com/api.ics",
ics_enabled=True,
)
cal = Calendar()
event = Event()
event.add("uid", "api-test-event")
event.add("summary", "API Test Meeting")
from reflector.settings import settings
event.add("location", f"{settings.BASE_URL}/room/{room.name}")
now = datetime.now(timezone.utc)
event.add("dtstart", now + timedelta(hours=1))
event.add("dtend", now + timedelta(hours=2))
cal.add_component(event)
ics_content = cal.to_ical().decode("utf-8")
with patch(
"reflector.services.ics_sync.ICSFetchService.fetch_ics", new_callable=AsyncMock
) as mock_fetch:
mock_fetch.return_value = ics_content
response = await client.post(f"/rooms/{room.name}/ics/sync")
assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
assert data["events_found"] == 1
assert data["events_created"] == 1
@pytest.mark.asyncio
async def test_trigger_ics_sync_unauthorized(client):
room = await rooms_controller.add(
name="sync-unauth-room",
user_id="owner-123",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_url="https://calendar.example.com/api.ics",
ics_enabled=True,
)
response = await client.post(f"/rooms/{room.name}/ics/sync")
assert response.status_code == 403
assert "Only room owner can trigger ICS sync" in response.json()["detail"]
@pytest.mark.asyncio
async def test_trigger_ics_sync_not_configured(authenticated_client):
client = authenticated_client
room = await rooms_controller.add(
name="sync-not-configured",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_enabled=False,
)
response = await client.post(f"/rooms/{room.name}/ics/sync")
assert response.status_code == 400
assert "ICS not configured" in response.json()["detail"]
@pytest.mark.asyncio
async def test_get_ics_status(authenticated_client):
client = authenticated_client
room = await rooms_controller.add(
name="status-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_url="https://calendar.example.com/status.ics",
ics_enabled=True,
ics_fetch_interval=300,
)
now = datetime.now(timezone.utc)
await rooms_controller.update(
room,
{"ics_last_sync": now, "ics_last_etag": "test-etag"},
)
response = await client.get(f"/rooms/{room.name}/ics/status")
assert response.status_code == 200
data = response.json()
assert data["status"] == "enabled"
assert data["last_etag"] == "test-etag"
assert data["events_count"] == 0
@pytest.mark.asyncio
async def test_get_ics_status_unauthorized(client):
room = await rooms_controller.add(
name="status-unauth",
user_id="owner-456",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
ics_url="https://calendar.example.com/status.ics",
ics_enabled=True,
)
response = await client.get(f"/rooms/{room.name}/ics/status")
assert response.status_code == 403
assert "Only room owner can view ICS status" in response.json()["detail"]
@pytest.mark.asyncio
async def test_list_room_meetings(authenticated_client):
client = authenticated_client
room = await rooms_controller.add(
name="meetings-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
now = datetime.now(timezone.utc)
event1 = CalendarEvent(
room_id=room.id,
ics_uid="meeting-1",
title="Past Meeting",
start_time=now - timedelta(hours=2),
end_time=now - timedelta(hours=1),
)
await calendar_events_controller.upsert(event1)
event2 = CalendarEvent(
room_id=room.id,
ics_uid="meeting-2",
title="Future Meeting",
description="Team sync",
start_time=now + timedelta(hours=1),
end_time=now + timedelta(hours=2),
attendees=[{"email": "test@example.com"}],
)
await calendar_events_controller.upsert(event2)
response = await client.get(f"/rooms/{room.name}/meetings")
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert data[0]["title"] == "Past Meeting"
assert data[1]["title"] == "Future Meeting"
assert data[1]["description"] == "Team sync"
assert data[1]["attendees"] == [{"email": "test@example.com"}]
@pytest.mark.asyncio
async def test_list_room_meetings_non_owner(client):
room = await rooms_controller.add(
name="meetings-privacy",
user_id="owner-789",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
event = CalendarEvent(
room_id=room.id,
ics_uid="private-meeting",
title="Meeting Title",
description="Sensitive info",
start_time=datetime.now(timezone.utc) + timedelta(hours=1),
end_time=datetime.now(timezone.utc) + timedelta(hours=2),
attendees=[{"email": "private@example.com"}],
)
await calendar_events_controller.upsert(event)
response = await client.get(f"/rooms/{room.name}/meetings")
assert response.status_code == 200
data = response.json()
assert len(data) == 1
assert data[0]["title"] == "Meeting Title"
assert data[0]["description"] is None
assert data[0]["attendees"] is None
@pytest.mark.asyncio
async def test_list_upcoming_meetings(authenticated_client):
client = authenticated_client
room = await rooms_controller.add(
name="upcoming-room",
user_id="test-user",
zulip_auto_post=False,
zulip_stream="",
zulip_topic="",
is_locked=False,
room_mode="normal",
recording_type="cloud",
recording_trigger="automatic-2nd-participant",
is_shared=False,
)
now = datetime.now(timezone.utc)
past_event = CalendarEvent(
room_id=room.id,
ics_uid="past",
title="Past",
start_time=now - timedelta(hours=1),
end_time=now - timedelta(minutes=30),
)
await calendar_events_controller.upsert(past_event)
soon_event = CalendarEvent(
room_id=room.id,
ics_uid="soon",
title="Soon",
start_time=now + timedelta(minutes=15),
end_time=now + timedelta(minutes=45),
)
await calendar_events_controller.upsert(soon_event)
later_event = CalendarEvent(
room_id=room.id,
ics_uid="later",
title="Later",
start_time=now + timedelta(hours=2),
end_time=now + timedelta(hours=3),
)
await calendar_events_controller.upsert(later_event)
response = await client.get(f"/rooms/{room.name}/meetings/upcoming")
assert response.status_code == 200
data = response.json()
assert len(data) == 1
assert data[0]["title"] == "Soon"
response = await client.get(
f"/rooms/{room.name}/meetings/upcoming", params={"minutes_ahead": 180}
)
assert response.status_code == 200
data = response.json()
assert len(data) == 2
assert data[0]["title"] == "Soon"
assert data[1]["title"] == "Later"
@pytest.mark.asyncio
async def test_room_not_found_endpoints(client):
response = await client.post("/rooms/nonexistent/ics/sync")
assert response.status_code == 404
response = await client.get("/rooms/nonexistent/ics/status")
assert response.status_code == 404
response = await client.get("/rooms/nonexistent/meetings")
assert response.status_code == 404
response = await client.get("/rooms/nonexistent/meetings/upcoming")
assert response.status_code == 404

144
server/tests/test_search.py Normal file
View File

@@ -0,0 +1,144 @@
"""Tests for full-text search functionality."""
import json
from datetime import datetime, timezone
import pytest
from pydantic import ValidationError
from reflector.db import get_database
from reflector.db.search import SearchParameters, search_controller
from reflector.db.transcripts import transcripts
@pytest.mark.asyncio
async def test_search_postgresql_only():
params = SearchParameters(query_text="any query here")
results, total = await search_controller.search_transcripts(params)
assert results == []
assert total == 0
try:
SearchParameters(query_text="")
assert False, "Should have raised validation error"
except ValidationError:
pass # Expected
# Test that whitespace query raises validation error
try:
SearchParameters(query_text=" ")
assert False, "Should have raised validation error"
except ValidationError:
pass # Expected
@pytest.mark.asyncio
async def test_search_input_validation():
try:
SearchParameters(query_text="")
assert False, "Should have raised ValidationError"
except ValidationError:
pass # Expected
# Test that whitespace query raises validation error
try:
SearchParameters(query_text=" \t\n ")
assert False, "Should have raised ValidationError"
except ValidationError:
pass # Expected
@pytest.mark.asyncio
async def test_postgresql_search_with_data():
# collision is improbable
test_id = "test-search-e2e-7f3a9b2c"
try:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
test_data = {
"id": test_id,
"name": "Test Search Transcript",
"title": "Engineering Planning Meeting Q4 2024",
"status": "completed",
"locked": False,
"duration": 1800.0,
"created_at": datetime.now(timezone.utc),
"short_summary": "Team discussed search implementation",
"long_summary": "The engineering team met to plan the search feature",
"topics": json.dumps([]),
"events": json.dumps([]),
"participants": json.dumps([]),
"source_language": "en",
"target_language": "en",
"reviewed": False,
"audio_location": "local",
"share_mode": "private",
"source_kind": "room",
"webvtt": """WEBVTT
00:00:00.000 --> 00:00:10.000
Welcome to our engineering planning meeting for Q4 2024.
00:00:10.000 --> 00:00:20.000
Today we'll discuss the implementation of full-text search.
00:00:20.000 --> 00:00:30.000
The search feature should support complex queries with ranking.
00:00:30.000 --> 00:00:40.000
We need to implement PostgreSQL tsvector for better performance.""",
}
await get_database().execute(transcripts.insert().values(**test_data))
# Test 1: Search for a word in title
params = SearchParameters(query_text="planning")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by title word"
# Test 2: Search for a word in webvtt content
params = SearchParameters(query_text="tsvector")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by webvtt content"
# Test 3: Search with multiple words
params = SearchParameters(query_text="engineering planning")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by multiple words"
# Test 4: Verify SearchResult structure
test_result = next((r for r in results if r.id == test_id), None)
if test_result:
assert test_result.title == "Engineering Planning Meeting Q4 2024"
assert test_result.status == "completed"
assert test_result.duration == 1800.0
assert 0 <= test_result.rank <= 1, "Rank should be normalized to 0-1"
# Test 5: Search with OR operator
params = SearchParameters(query_text="tsvector OR nosuchword")
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript with OR query"
# Test 6: Quoted phrase search
params = SearchParameters(query_text='"full-text search"')
results, total = await search_controller.search_transcripts(params)
assert total >= 1
found = any(r.id == test_id for r in results)
assert found, "Should find test transcript by exact phrase"
finally:
await get_database().execute(
transcripts.delete().where(transcripts.c.id == test_id)
)
await get_database().disconnect()

View File

@@ -0,0 +1,198 @@
"""Unit tests for search snippet generation."""
from reflector.db.search import SearchController
class TestExtractWebVTT:
"""Test WebVTT text extraction."""
def test_extract_webvtt_with_speakers(self):
"""Test extraction removes speaker tags and timestamps."""
webvtt = """WEBVTT
00:00:00.000 --> 00:00:10.000
<v Speaker0>Hello world, this is a test.
00:00:10.000 --> 00:00:20.000
<v Speaker1>Indeed it is a test of WebVTT parsing.
"""
result = SearchController._extract_webvtt_text(webvtt)
assert "Hello world, this is a test" in result
assert "Indeed it is a test" in result
assert "<v Speaker" not in result
assert "00:00" not in result
assert "-->" not in result
def test_extract_empty_webvtt(self):
"""Test empty WebVTT returns empty string."""
assert SearchController._extract_webvtt_text("") == ""
assert SearchController._extract_webvtt_text(None) == ""
def test_extract_malformed_webvtt(self):
"""Test malformed WebVTT returns empty string."""
result = SearchController._extract_webvtt_text("Not a valid WebVTT")
assert result == ""
class TestGenerateSnippets:
"""Test snippet generation from plain text."""
def test_multiple_matches(self):
"""Test finding multiple occurrences of search term in long text."""
# Create text with Python mentions far apart to get separate snippets
separator = " This is filler text. " * 20 # ~400 chars of padding
text = (
"Python is great for machine learning."
+ separator
+ "Many companies use Python for data science."
+ separator
+ "Python has excellent libraries for analysis."
+ separator
+ "The Python community is very supportive."
)
snippets = SearchController._generate_snippets(text, "Python")
# With enough separation, we should get multiple snippets
assert len(snippets) >= 2 # At least 2 distinct snippets
# Each snippet should contain "Python"
for snippet in snippets:
assert "python" in snippet.lower()
def test_single_match(self):
"""Test single occurrence returns one snippet."""
text = "This document discusses artificial intelligence and its applications."
snippets = SearchController._generate_snippets(text, "artificial intelligence")
assert len(snippets) == 1
assert "artificial intelligence" in snippets[0].lower()
def test_no_matches(self):
"""Test no matches returns empty list."""
text = "This is some random text without the search term."
snippets = SearchController._generate_snippets(text, "machine learning")
assert snippets == []
def test_case_insensitive_search(self):
"""Test search is case insensitive."""
# Add enough text between matches to get separate snippets
text = (
"MACHINE LEARNING is important for modern applications. "
+ "It requires lots of data and computational resources. " * 5 # Padding
+ "Machine Learning rocks and transforms industries. "
+ "Deep learning is a subset of it. " * 5 # More padding
+ "Finally, machine learning will shape our future."
)
snippets = SearchController._generate_snippets(text, "machine learning")
# Should find at least 2 (might be 3 if text is long enough)
assert len(snippets) >= 2
for snippet in snippets:
assert "machine learning" in snippet.lower()
def test_partial_match_fallback(self):
"""Test fallback to first word when exact phrase not found."""
text = "We use machine intelligence for processing."
snippets = SearchController._generate_snippets(text, "machine learning")
# Should fall back to finding "machine"
assert len(snippets) == 1
assert "machine" in snippets[0].lower()
def test_snippet_ellipsis(self):
"""Test ellipsis added for truncated snippets."""
# Long text where match is in the middle
text = "a " * 100 + "TARGET_WORD special content here" + " b" * 100
snippets = SearchController._generate_snippets(text, "TARGET_WORD")
assert len(snippets) == 1
assert "..." in snippets[0] # Should have ellipsis
assert "TARGET_WORD" in snippets[0]
def test_overlapping_snippets_deduplicated(self):
"""Test overlapping matches don't create duplicate snippets."""
text = "test test test word" * 10 # Repeated pattern
snippets = SearchController._generate_snippets(text, "test")
# Should get unique snippets, not duplicates
assert len(snippets) <= 3
assert len(snippets) == len(set(snippets)) # All unique
def test_empty_inputs(self):
"""Test empty text or search term returns empty list."""
assert SearchController._generate_snippets("", "search") == []
assert SearchController._generate_snippets("text", "") == []
assert SearchController._generate_snippets("", "") == []
def test_max_snippets_limit(self):
"""Test respects max_snippets parameter."""
# Create text with well-separated occurrences
separator = " filler " * 50 # Ensure snippets don't overlap
text = ("Python is amazing" + separator) * 10 # 10 occurrences
# Test with different limits
snippets_1 = SearchController._generate_snippets(text, "Python", max_snippets=1)
assert len(snippets_1) == 1
snippets_2 = SearchController._generate_snippets(text, "Python", max_snippets=2)
assert len(snippets_2) == 2
snippets_5 = SearchController._generate_snippets(text, "Python", max_snippets=5)
assert len(snippets_5) == 5 # Should get exactly 5 with enough separation
def test_snippet_length(self):
"""Test snippet length is reasonable."""
text = "word " * 200 # Long text
snippets = SearchController._generate_snippets(text, "word")
for snippet in snippets:
# Default max_length is 150 + some context
assert len(snippet) <= 200 # Some buffer for ellipsis
class TestFullPipeline:
"""Test the complete WebVTT to snippets pipeline."""
def test_webvtt_to_snippets_integration(self):
"""Test full pipeline from WebVTT to search snippets."""
# Create WebVTT with well-separated content for multiple snippets
webvtt = (
"""WEBVTT
00:00:00.000 --> 00:00:10.000
<v Speaker0>Let's discuss machine learning applications in modern technology.
00:00:10.000 --> 00:00:20.000
<v Speaker1>"""
+ "Various industries are adopting new technologies. " * 10
+ """
00:00:20.000 --> 00:00:30.000
<v Speaker2>Machine learning is revolutionizing healthcare and diagnostics.
00:00:30.000 --> 00:00:40.000
<v Speaker3>"""
+ "Financial markets show interesting patterns. " * 10
+ """
00:00:40.000 --> 00:00:50.000
<v Speaker0>Machine learning in education provides personalized experiences.
"""
)
# Extract and generate snippets
plain_text = SearchController._extract_webvtt_text(webvtt)
snippets = SearchController._generate_snippets(plain_text, "machine learning")
# Should find at least 2 snippets (text might still be close together)
assert len(snippets) >= 1 # At minimum one snippet containing matches
assert len(snippets) <= 3 # At most 3 by default
# No WebVTT artifacts in snippets
for snippet in snippets:
assert "machine learning" in snippet.lower()
assert "<v Speaker" not in snippet
assert "00:00" not in snippet
assert "-->" not in snippet

View File

@@ -1,147 +1,128 @@
from contextlib import asynccontextmanager
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_create():
from reflector.app import app
async def test_transcript_create(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["name"] == "test"
assert response.json()["status"] == "idle"
assert response.json()["locked"] is False
assert response.json()["id"] is not None
assert response.json()["created_at"] is not None
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["name"] == "test"
assert response.json()["status"] == "idle"
assert response.json()["locked"] is False
assert response.json()["id"] is not None
assert response.json()["created_at"] is not None
# ensure some fields are not returned
assert "topics" not in response.json()
assert "events" not in response.json()
# ensure some fields are not returned
assert "topics" not in response.json()
assert "events" not in response.json()
@pytest.mark.asyncio
async def test_transcript_get_update_name():
from reflector.app import app
async def test_transcript_get_update_name(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["name"] == "test"
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["name"] == "test"
tid = response.json()["id"]
tid = response.json()["id"]
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test"
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test"
response = await client.patch(f"/transcripts/{tid}", json={"name": "test2"})
assert response.status_code == 200
assert response.json()["name"] == "test2"
response = await ac.patch(f"/transcripts/{tid}", json={"name": "test2"})
assert response.status_code == 200
assert response.json()["name"] == "test2"
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test2"
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test2"
@pytest.mark.asyncio
async def test_transcript_get_update_locked():
from reflector.app import app
async def test_transcript_get_update_locked(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["locked"] is False
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["locked"] is False
tid = response.json()["id"]
tid = response.json()["id"]
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["locked"] is False
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["locked"] is False
response = await client.patch(f"/transcripts/{tid}", json={"locked": True})
assert response.status_code == 200
assert response.json()["locked"] is True
response = await ac.patch(f"/transcripts/{tid}", json={"locked": True})
assert response.status_code == 200
assert response.json()["locked"] is True
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["locked"] is True
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["locked"] is True
@pytest.mark.asyncio
async def test_transcript_get_update_summary():
from reflector.app import app
async def test_transcript_get_update_summary(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["long_summary"] is None
assert response.json()["short_summary"] is None
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["long_summary"] is None
assert response.json()["short_summary"] is None
tid = response.json()["id"]
tid = response.json()["id"]
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["long_summary"] is None
assert response.json()["short_summary"] is None
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["long_summary"] is None
assert response.json()["short_summary"] is None
response = await client.patch(
f"/transcripts/{tid}",
json={"long_summary": "test_long", "short_summary": "test_short"},
)
assert response.status_code == 200
assert response.json()["long_summary"] == "test_long"
assert response.json()["short_summary"] == "test_short"
response = await ac.patch(
f"/transcripts/{tid}",
json={"long_summary": "test_long", "short_summary": "test_short"},
)
assert response.status_code == 200
assert response.json()["long_summary"] == "test_long"
assert response.json()["short_summary"] == "test_short"
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["long_summary"] == "test_long"
assert response.json()["short_summary"] == "test_short"
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["long_summary"] == "test_long"
assert response.json()["short_summary"] == "test_short"
@pytest.mark.asyncio
async def test_transcript_get_update_title():
from reflector.app import app
async def test_transcript_get_update_title(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["title"] is None
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["title"] is None
tid = response.json()["id"]
tid = response.json()["id"]
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["title"] is None
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["title"] is None
response = await client.patch(f"/transcripts/{tid}", json={"title": "test_title"})
assert response.status_code == 200
assert response.json()["title"] == "test_title"
response = await ac.patch(f"/transcripts/{tid}", json={"title": "test_title"})
assert response.status_code == 200
assert response.json()["title"] == "test_title"
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["title"] == "test_title"
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["title"] == "test_title"
@pytest.mark.asyncio
async def test_transcripts_list_anonymous():
async def test_transcripts_list_anonymous(client):
# XXX this test is a bit fragile, as it depends on the storage which
# is shared between tests
from reflector.app import app
from reflector.settings import settings
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.get("/transcripts")
assert response.status_code == 401
response = await client.get("/transcripts")
assert response.status_code == 401
# if public mode, it should be allowed
try:
settings.PUBLIC_MODE = True
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.get("/transcripts")
assert response.status_code == 200
response = await client.get("/transcripts")
assert response.status_code == 200
finally:
settings.PUBLIC_MODE = False
@@ -197,67 +178,59 @@ async def authenticated_client2():
@pytest.mark.asyncio
async def test_transcripts_list_authenticated(authenticated_client):
async def test_transcripts_list_authenticated(authenticated_client, client):
# XXX this test is a bit fragile, as it depends on the storage which
# is shared between tests
from reflector.app import app
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "testxx1"})
assert response.status_code == 200
assert response.json()["name"] == "testxx1"
response = await client.post("/transcripts", json={"name": "testxx1"})
assert response.status_code == 200
assert response.json()["name"] == "testxx1"
response = await ac.post("/transcripts", json={"name": "testxx2"})
assert response.status_code == 200
assert response.json()["name"] == "testxx2"
response = await client.post("/transcripts", json={"name": "testxx2"})
assert response.status_code == 200
assert response.json()["name"] == "testxx2"
response = await ac.get("/transcripts")
assert response.status_code == 200
assert len(response.json()["items"]) >= 2
names = [t["name"] for t in response.json()["items"]]
assert "testxx1" in names
assert "testxx2" in names
response = await client.get("/transcripts")
assert response.status_code == 200
assert len(response.json()["items"]) >= 2
names = [t["name"] for t in response.json()["items"]]
assert "testxx1" in names
assert "testxx2" in names
@pytest.mark.asyncio
async def test_transcript_delete():
from reflector.app import app
async def test_transcript_delete(client):
response = await client.post("/transcripts", json={"name": "testdel1"})
assert response.status_code == 200
assert response.json()["name"] == "testdel1"
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "testdel1"})
assert response.status_code == 200
assert response.json()["name"] == "testdel1"
tid = response.json()["id"]
response = await client.delete(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["status"] == "ok"
tid = response.json()["id"]
response = await ac.delete(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["status"] == "ok"
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 404
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 404
@pytest.mark.asyncio
async def test_transcript_mark_reviewed():
from reflector.app import app
async def test_transcript_mark_reviewed(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["name"] == "test"
assert response.json()["reviewed"] is False
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["name"] == "test"
assert response.json()["reviewed"] is False
tid = response.json()["id"]
tid = response.json()["id"]
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test"
assert response.json()["reviewed"] is False
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test"
assert response.json()["reviewed"] is False
response = await client.patch(f"/transcripts/{tid}", json={"reviewed": True})
assert response.status_code == 200
assert response.json()["reviewed"] is True
response = await ac.patch(f"/transcripts/{tid}", json={"reviewed": True})
assert response.status_code == 200
assert response.json()["reviewed"] is True
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["reviewed"] is True
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["reviewed"] is True

View File

@@ -2,20 +2,17 @@ import shutil
from pathlib import Path
import pytest
from httpx import AsyncClient
@pytest.fixture
async def fake_transcript(tmpdir):
from reflector.app import app
async def fake_transcript(tmpdir, client):
from reflector.settings import settings
from reflector.views.transcripts import transcripts_controller
settings.DATA_DIR = Path(tmpdir)
# create a transcript
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.post("/transcripts", json={"name": "Test audio download"})
response = await client.post("/transcripts", json={"name": "Test audio download"})
assert response.status_code == 200
tid = response.json()["id"]
@@ -39,17 +36,17 @@ async def fake_transcript(tmpdir):
["/mp3", "audio/mpeg"],
],
)
async def test_transcript_audio_download(fake_transcript, url_suffix, content_type):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
async def test_transcript_audio_download(
fake_transcript, url_suffix, content_type, client
):
response = await client.get(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
assert response.status_code == 200
assert response.headers["content-type"] == content_type
# test get 404
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}")
response = await client.get(
f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}"
)
assert response.status_code == 404
@@ -61,18 +58,16 @@ async def test_transcript_audio_download(fake_transcript, url_suffix, content_ty
],
)
async def test_transcript_audio_download_head(
fake_transcript, url_suffix, content_type
fake_transcript, url_suffix, content_type, client
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.head(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
response = await client.head(f"/transcripts/{fake_transcript.id}/audio{url_suffix}")
assert response.status_code == 200
assert response.headers["content-type"] == content_type
# test head 404
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.head(f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}")
response = await client.head(
f"/transcripts/{fake_transcript.id}XXX/audio{url_suffix}"
)
assert response.status_code == 404
@@ -84,12 +79,9 @@ async def test_transcript_audio_download_head(
],
)
async def test_transcript_audio_download_range(
fake_transcript, url_suffix, content_type
fake_transcript, url_suffix, content_type, client
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(
response = await client.get(
f"/transcripts/{fake_transcript.id}/audio{url_suffix}",
headers={"range": "bytes=0-100"},
)
@@ -107,12 +99,9 @@ async def test_transcript_audio_download_range(
],
)
async def test_transcript_audio_download_range_with_seek(
fake_transcript, url_suffix, content_type
fake_transcript, url_suffix, content_type, client
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.get(
response = await client.get(
f"/transcripts/{fake_transcript.id}/audio{url_suffix}",
headers={"range": "bytes=100-"},
)
@@ -122,13 +111,10 @@ async def test_transcript_audio_download_range_with_seek(
@pytest.mark.asyncio
async def test_transcript_delete_with_audio(fake_transcript):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
response = await ac.delete(f"/transcripts/{fake_transcript.id}")
async def test_transcript_delete_with_audio(fake_transcript, client):
response = await client.delete(f"/transcripts/{fake_transcript.id}")
assert response.status_code == 200
assert response.json()["status"] == "ok"
response = await ac.get(f"/transcripts/{fake_transcript.id}")
response = await client.get(f"/transcripts/{fake_transcript.id}")
assert response.status_code == 404

View File

@@ -1,164 +1,151 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_participants():
from reflector.app import app
async def test_transcript_participants(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
# create a participant
transcript_id = response.json()["id"]
response = await client.post(
f"/transcripts/{transcript_id}/participants", json={"name": "test"}
)
assert response.status_code == 200
assert response.json()["id"] is not None
assert response.json()["speaker"] is None
assert response.json()["name"] == "test"
# create a participant
transcript_id = response.json()["id"]
response = await ac.post(
f"/transcripts/{transcript_id}/participants", json={"name": "test"}
)
assert response.status_code == 200
assert response.json()["id"] is not None
assert response.json()["speaker"] is None
assert response.json()["name"] == "test"
# create another one with a speaker
response = await client.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["id"] is not None
assert response.json()["speaker"] == 1
assert response.json()["name"] == "test2"
# create another one with a speaker
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["id"] is not None
assert response.json()["speaker"] == 1
assert response.json()["name"] == "test2"
# get all participants via transcript
response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
assert len(response.json()["participants"]) == 2
# get all participants via transcript
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
assert len(response.json()["participants"]) == 2
# get participants via participants endpoint
response = await ac.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
assert len(response.json()) == 2
# get participants via participants endpoint
response = await client.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
assert len(response.json()) == 2
@pytest.mark.asyncio
async def test_transcript_participants_same_speaker():
from reflector.app import app
async def test_transcript_participants_same_speaker(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
# create a participant
response = await client.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["speaker"] == 1
# create a participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["speaker"] == 1
# create another one with the same speaker
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 1},
)
assert response.status_code == 400
# create another one with the same speaker
response = await client.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 1},
)
assert response.status_code == 400
@pytest.mark.asyncio
async def test_transcript_participants_update_name():
from reflector.app import app
async def test_transcript_participants_update_name(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
# create a participant
response = await client.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["speaker"] == 1
# create a participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
assert response.json()["speaker"] == 1
# update the participant
participant_id = response.json()["id"]
response = await client.patch(
f"/transcripts/{transcript_id}/participants/{participant_id}",
json={"name": "test2"},
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
# update the participant
participant_id = response.json()["id"]
response = await ac.patch(
f"/transcripts/{transcript_id}/participants/{participant_id}",
json={"name": "test2"},
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
# verify the participant was updated
response = await client.get(
f"/transcripts/{transcript_id}/participants/{participant_id}"
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
# verify the participant was updated
response = await ac.get(
f"/transcripts/{transcript_id}/participants/{participant_id}"
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
# verify the participant was updated in transcript
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
assert len(response.json()["participants"]) == 1
assert response.json()["participants"][0]["name"] == "test2"
# verify the participant was updated in transcript
response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
assert len(response.json()["participants"]) == 1
assert response.json()["participants"][0]["name"] == "test2"
@pytest.mark.asyncio
async def test_transcript_participants_update_speaker():
from reflector.app import app
async def test_transcript_participants_update_speaker(client):
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["participants"] == []
transcript_id = response.json()["id"]
# create a participant
response = await client.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
participant1_id = response.json()["id"]
# create a participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test", "speaker": 1},
)
assert response.status_code == 200
participant1_id = response.json()["id"]
# create another participant
response = await client.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 2},
)
assert response.status_code == 200
participant2_id = response.json()["id"]
# create another participant
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={"name": "test2", "speaker": 2},
)
assert response.status_code == 200
participant2_id = response.json()["id"]
# update the participant, refused as speaker is already taken
response = await client.patch(
f"/transcripts/{transcript_id}/participants/{participant2_id}",
json={"speaker": 1},
)
assert response.status_code == 400
# update the participant, refused as speaker is already taken
response = await ac.patch(
f"/transcripts/{transcript_id}/participants/{participant2_id}",
json={"speaker": 1},
)
assert response.status_code == 400
# delete the participant 1
response = await client.delete(
f"/transcripts/{transcript_id}/participants/{participant1_id}"
)
assert response.status_code == 200
# delete the participant 1
response = await ac.delete(
f"/transcripts/{transcript_id}/participants/{participant1_id}"
)
assert response.status_code == 200
# update the participant 2 again, should be accepted now
response = await client.patch(
f"/transcripts/{transcript_id}/participants/{participant2_id}",
json={"speaker": 1},
)
assert response.status_code == 200
# update the participant 2 again, should be accepted now
response = await ac.patch(
f"/transcripts/{transcript_id}/participants/{participant2_id}",
json={"speaker": 1},
)
assert response.status_code == 200
# ensure participant2 name is still there
response = await ac.get(
f"/transcripts/{transcript_id}/participants/{participant2_id}"
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
assert response.json()["speaker"] == 1
# ensure participant2 name is still there
response = await client.get(
f"/transcripts/{transcript_id}/participants/{participant2_id}"
)
assert response.status_code == 200
assert response.json()["name"] == "test2"
assert response.json()["speaker"] == 1

View File

@@ -1,7 +1,26 @@
import asyncio
import time
import pytest
from httpx import AsyncClient
from httpx import ASGITransport, AsyncClient
@pytest.fixture
async def app_lifespan():
from asgi_lifespan import LifespanManager
from reflector.app import app
async with LifespanManager(app) as manager:
yield manager.app
@pytest.fixture
async def client(app_lifespan):
yield AsyncClient(
transport=ASGITransport(app=app_lifespan),
base_url="http://test/v1",
)
@pytest.mark.usefixtures("setup_database")
@@ -10,24 +29,21 @@ from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_process(
tmpdir,
ensure_casing,
whisper_transcript,
dummy_llm,
dummy_processors,
dummy_diarization,
dummy_storage,
client,
):
from reflector.app import app
ac = AsyncClient(app=app, base_url="http://test/v1")
# create a transcript
response = await ac.post("/transcripts", json={"name": "test"})
response = await client.post("/transcripts", json={"name": "test"})
assert response.status_code == 200
assert response.json()["status"] == "idle"
tid = response.json()["id"]
# upload mp3
response = await ac.post(
response = await client.post(
f"/transcripts/{tid}/record/upload?chunk_number=0&total_chunks=1",
files={
"chunk": (
@@ -40,39 +56,47 @@ async def test_transcript_process(
assert response.status_code == 200
assert response.json()["status"] == "ok"
# wait for processing to finish
while True:
# wait for processing to finish (max 10 minutes)
timeout_seconds = 600 # 10 minutes
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
else:
pytest.fail(f"Initial processing timed out after {timeout_seconds} seconds")
# restart the processing
response = await ac.post(
response = await client.post(
f"/transcripts/{tid}/process",
)
assert response.status_code == 200
assert response.json()["status"] == "ok"
# wait for processing to finish
while True:
# wait for processing to finish (max 10 minutes)
timeout_seconds = 600 # 10 minutes
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
else:
pytest.fail(f"Restart processing timed out after {timeout_seconds} seconds")
# check the transcript is ended
transcript = resp.json()
assert transcript["status"] == "ended"
assert transcript["short_summary"] == "LLM SHORT SUMMARY"
assert transcript["title"] == "LLM TITLE"
assert transcript["title"] == "Llm Title"
# check topics and transcript
response = await ac.get(f"/transcripts/{tid}/topics")
response = await client.get(f"/transcripts/{tid}/topics")
assert response.status_code == 200
assert len(response.json()) == 1
assert "want to share" in response.json()[0]["transcript"]

View File

@@ -0,0 +1,34 @@
from datetime import datetime, timezone
from unittest.mock import AsyncMock, patch
import pytest
from reflector.db.recordings import Recording, recordings_controller
from reflector.db.transcripts import SourceKind, transcripts_controller
@pytest.mark.asyncio
async def test_recording_deleted_with_transcript():
recording = await recordings_controller.create(
Recording(
bucket_name="test-bucket",
object_key="recording.mp4",
recorded_at=datetime.now(timezone.utc),
)
)
transcript = await transcripts_controller.add(
name="Test Transcript",
source_kind=SourceKind.ROOM,
recording_id=recording.id,
)
with patch("reflector.db.transcripts.get_recordings_storage") as mock_get_storage:
storage_instance = mock_get_storage.return_value
storage_instance.delete_file = AsyncMock()
await transcripts_controller.remove_by_id(transcript.id)
storage_instance.delete_file.assert_awaited_once_with(recording.object_key)
assert await recordings_controller.get_by_id(recording.id) is None
assert await transcripts_controller.get_by_id(transcript.id) is None

View File

@@ -6,10 +6,10 @@
import asyncio
import json
import threading
import time
from pathlib import Path
import pytest
from httpx import AsyncClient
from httpx_ws import aconnect_ws
from uvicorn import Config, Server
@@ -21,34 +21,97 @@ class ThreadedUvicorn:
async def start(self):
self.thread.start()
while not self.server.started:
timeout_seconds = 600 # 10 minutes
start_time = time.monotonic()
while (
not self.server.started
and (time.monotonic() - start_time) < timeout_seconds
):
await asyncio.sleep(0.1)
if not self.server.started:
raise TimeoutError(
f"Server failed to start after {timeout_seconds} seconds"
)
def stop(self):
if self.thread.is_alive():
self.server.should_exit = True
while self.thread.is_alive():
continue
timeout_seconds = 600 # 10 minutes
start_time = time.time()
while (
self.thread.is_alive() and (time.time() - start_time) < timeout_seconds
):
time.sleep(0.1)
if self.thread.is_alive():
raise TimeoutError(
f"Thread failed to stop after {timeout_seconds} seconds"
)
@pytest.fixture
async def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker):
def appserver(tmpdir, setup_database, celery_session_app, celery_session_worker):
import threading
from reflector.app import app
from reflector.db import get_database
from reflector.settings import settings
DATA_DIR = settings.DATA_DIR
settings.DATA_DIR = Path(tmpdir)
# start server
# start server in a separate thread with its own event loop
host = "127.0.0.1"
port = 1255
config = Config(app=app, host=host, port=port)
server = ThreadedUvicorn(config)
await server.start()
server_started = threading.Event()
server_exception = None
server_instance = None
yield (server, host, port)
def run_server():
nonlocal server_exception, server_instance
try:
# Create a new event loop for this thread
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
config = Config(app=app, host=host, port=port, loop=loop)
server_instance = Server(config)
async def start_server():
# Initialize database connection in this event loop
database = get_database()
await database.connect()
try:
await server_instance.serve()
finally:
await database.disconnect()
# Signal that server is starting
server_started.set()
loop.run_until_complete(start_server())
except Exception as e:
server_exception = e
server_started.set()
finally:
loop.close()
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()
# Wait for server to start
server_started.wait(timeout=30)
if server_exception:
raise server_exception
# Wait a bit more for the server to be fully ready
time.sleep(1)
yield server_instance, host, port
# Stop server
if server_instance:
server_instance.should_exit = True
server_thread.join(timeout=30)
server.stop()
settings.DATA_DIR = DATA_DIR
@@ -67,11 +130,11 @@ async def test_transcript_rtc_and_websocket(
dummy_transcript,
dummy_processors,
dummy_diarization,
dummy_transcript_translator,
dummy_storage,
fake_mp3_upload,
ensure_casing,
nltk,
appserver,
client,
):
# goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread
@@ -80,8 +143,7 @@ async def test_transcript_rtc_and_websocket(
# create a transcript
base_url = f"http://{host}:{port}/v1"
ac = AsyncClient(base_url=base_url)
response = await ac.post("/transcripts", json={"name": "Test RTC"})
response = await client.post("/transcripts", json={"name": "Test RTC"})
assert response.status_code == 200
tid = response.json()["id"]
@@ -93,12 +155,16 @@ async def test_transcript_rtc_and_websocket(
async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
print("Test websocket: CONNECTED")
try:
while True:
timeout_seconds = 600 # 10 minutes
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds:
msg = await ws.receive_json()
print(f"Test websocket: JSON {msg}")
if msg is None:
break
events.append(msg)
else:
print(f"Test websocket: TIMEOUT after {timeout_seconds} seconds")
except Exception as e:
print(f"Test websocket: EXCEPTION {e}")
finally:
@@ -122,11 +188,11 @@ async def test_transcript_rtc_and_websocket(
url = f"{base_url}/transcripts/{tid}/record/webrtc"
path = Path(__file__).parent / "records" / "test_short.wav"
client = StreamClient(signaling, url=url, play_from=path.as_posix())
await client.start()
stream_client = StreamClient(signaling, url=url, play_from=path.as_posix())
await stream_client.start()
timeout = 20
while not client.is_ended():
timeout = 120
while not stream_client.is_ended():
await asyncio.sleep(1)
timeout -= 1
if timeout < 0:
@@ -134,21 +200,24 @@ async def test_transcript_rtc_and_websocket(
# XXX aiortc is long to close the connection
# instead of waiting a long time, we just send a STOP
client.channel.send(json.dumps({"cmd": "STOP"}))
await client.stop()
stream_client.channel.send(json.dumps({"cmd": "STOP"}))
await stream_client.stop()
# wait the processing to finish
timeout = 20
timeout = 120
while True:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] in ("ended", "error"):
break
await asyncio.sleep(1)
timeout -= 1
if timeout < 0:
raise TimeoutError("Timeout while waiting for transcript to be ended")
if resp.json()["status"] != "ended":
raise TimeoutError("Timeout while waiting for transcript to be ended")
raise TimeoutError("Transcript processing failed")
# stop websocket task
websocket_task.cancel()
@@ -166,7 +235,7 @@ async def test_transcript_rtc_and_websocket(
assert "TRANSCRIPT" in eventnames
ev = events[eventnames.index("TRANSCRIPT")]
assert ev["data"]["text"].startswith("Hello world.")
assert ev["data"]["translation"] == "Bonjour le monde"
assert ev["data"]["translation"] is None
assert "TOPIC" in eventnames
ev = events[eventnames.index("TOPIC")]
@@ -185,13 +254,13 @@ async def test_transcript_rtc_and_websocket(
assert "FINAL_TITLE" in eventnames
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "LLM TITLE"
assert ev["data"]["title"] == "Llm Title"
assert "WAVEFORM" in eventnames
ev = events[eventnames.index("WAVEFORM")]
assert isinstance(ev["data"]["waveform"], list)
assert len(ev["data"]["waveform"]) >= 250
waveform_resp = await ac.get(f"/transcripts/{tid}/audio/waveform")
waveform_resp = await client.get(f"/transcripts/{tid}/audio/waveform")
assert waveform_resp.status_code == 200
assert waveform_resp.headers["content-type"] == "application/json"
assert isinstance(waveform_resp.json()["data"], list)
@@ -211,7 +280,7 @@ async def test_transcript_rtc_and_websocket(
assert "DURATION" in eventnames
# check that audio/mp3 is available
audio_resp = await ac.get(f"/transcripts/{tid}/audio/mp3")
audio_resp = await client.get(f"/transcripts/{tid}/audio/mp3")
assert audio_resp.status_code == 200
assert audio_resp.headers["Content-Type"] == "audio/mpeg"
@@ -226,11 +295,11 @@ async def test_transcript_rtc_and_websocket_and_fr(
dummy_transcript,
dummy_processors,
dummy_diarization,
dummy_transcript_translator,
dummy_storage,
fake_mp3_upload,
ensure_casing,
nltk,
appserver,
client,
):
# goal: start the server, exchange RTC, receive websocket events
# because of that, we need to start the server in a thread
@@ -240,8 +309,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
# create a transcript
base_url = f"http://{host}:{port}/v1"
ac = AsyncClient(base_url=base_url)
response = await ac.post(
response = await client.post(
"/transcripts", json={"name": "Test RTC", "target_language": "fr"}
)
assert response.status_code == 200
@@ -255,12 +323,16 @@ async def test_transcript_rtc_and_websocket_and_fr(
async with aconnect_ws(f"{base_url}/transcripts/{tid}/events") as ws:
print("Test websocket: CONNECTED")
try:
while True:
timeout_seconds = 600 # 10 minutes
start_time = time.monotonic()
while (time.monotonic() - start_time) < timeout_seconds:
msg = await ws.receive_json()
print(f"Test websocket: JSON {msg}")
if msg is None:
break
events.append(msg)
else:
print(f"Test websocket: TIMEOUT after {timeout_seconds} seconds")
except Exception as e:
print(f"Test websocket: EXCEPTION {e}")
finally:
@@ -284,11 +356,11 @@ async def test_transcript_rtc_and_websocket_and_fr(
url = f"{base_url}/transcripts/{tid}/record/webrtc"
path = Path(__file__).parent / "records" / "test_short.wav"
client = StreamClient(signaling, url=url, play_from=path.as_posix())
await client.start()
stream_client = StreamClient(signaling, url=url, play_from=path.as_posix())
await stream_client.start()
timeout = 20
while not client.is_ended():
timeout = 120
while not stream_client.is_ended():
await asyncio.sleep(1)
timeout -= 1
if timeout < 0:
@@ -296,25 +368,28 @@ async def test_transcript_rtc_and_websocket_and_fr(
# XXX aiortc is long to close the connection
# instead of waiting a long time, we just send a STOP
client.channel.send(json.dumps({"cmd": "STOP"}))
stream_client.channel.send(json.dumps({"cmd": "STOP"}))
# wait the processing to finish
await asyncio.sleep(2)
await client.stop()
await stream_client.stop()
# wait the processing to finish
timeout = 20
timeout = 120
while True:
# fetch the transcript and check if it is ended
resp = await ac.get(f"/transcripts/{tid}")
resp = await client.get(f"/transcripts/{tid}")
assert resp.status_code == 200
if resp.json()["status"] == "ended":
break
await asyncio.sleep(1)
timeout -= 1
if timeout < 0:
raise TimeoutError("Timeout while waiting for transcript to be ended")
if resp.json()["status"] != "ended":
raise TimeoutError("Timeout while waiting for transcript to be ended")
raise TimeoutError("Transcript processing failed")
await asyncio.sleep(2)
@@ -334,7 +409,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
assert "TRANSCRIPT" in eventnames
ev = events[eventnames.index("TRANSCRIPT")]
assert ev["data"]["text"].startswith("Hello world.")
assert ev["data"]["translation"] == "Bonjour le monde"
assert ev["data"]["translation"] == "en:fr:Hello world."
assert "TOPIC" in eventnames
ev = events[eventnames.index("TOPIC")]
@@ -353,7 +428,7 @@ async def test_transcript_rtc_and_websocket_and_fr(
assert "FINAL_TITLE" in eventnames
ev = events[eventnames.index("FINAL_TITLE")]
assert ev["data"]["title"] == "LLM TITLE"
assert ev["data"]["title"] == "Llm Title"
# check status order
statuses = [e["data"]["value"] for e in events if e["event"] == "STATUS"]

View File

@@ -1,401 +1,390 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_reassign_speaker(fake_transcript_with_topics):
from reflector.app import app
async def test_transcript_reassign_speaker(fake_transcript_with_topics, client):
transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
# check the transcript exists
response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
# check initial topics of the transcript
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check initial topics of the transcript
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check through segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 0
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check through segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 0
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# reassign speaker
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 1,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 200
# reassign speaker
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 1,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check topics again
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 1
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 1
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# reassign speaker, middle of 2 topics
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 2,
"timestamp_from": 1,
"timestamp_to": 2.5,
},
)
assert response.status_code == 200
# reassign speaker, middle of 2 topics
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 2,
"timestamp_from": 1,
"timestamp_to": 2.5,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check topics again
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 2
assert topics[1]["words"][0]["speaker"] == 2
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 2
assert topics[0]["segments"][0]["speaker"] == 1
assert topics[0]["segments"][1]["speaker"] == 2
assert len(topics[1]["segments"]) == 2
assert topics[1]["segments"][0]["speaker"] == 2
assert topics[1]["segments"][1]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 2
assert topics[1]["words"][0]["speaker"] == 2
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 2
assert topics[0]["segments"][0]["speaker"] == 1
assert topics[0]["segments"][1]["speaker"] == 2
assert len(topics[1]["segments"]) == 2
assert topics[1]["segments"][0]["speaker"] == 2
assert topics[1]["segments"][1]["speaker"] == 0
# reassign speaker, everything
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 4,
"timestamp_from": 0,
"timestamp_to": 100,
},
)
assert response.status_code == 200
# reassign speaker, everything
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 4,
"timestamp_from": 0,
"timestamp_to": 100,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check topics again
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 4
assert topics[0]["words"][1]["speaker"] == 4
assert topics[1]["words"][0]["speaker"] == 4
assert topics[1]["words"][1]["speaker"] == 4
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 4
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 4
# check through words
assert topics[0]["words"][0]["speaker"] == 4
assert topics[0]["words"][1]["speaker"] == 4
assert topics[1]["words"][0]["speaker"] == 4
assert topics[1]["words"][1]["speaker"] == 4
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 4
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 4
@pytest.mark.asyncio
async def test_transcript_merge_speaker(fake_transcript_with_topics):
from reflector.app import app
async def test_transcript_merge_speaker(fake_transcript_with_topics, client):
transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
# check the transcript exists
response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
# check initial topics of the transcript
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check initial topics of the transcript
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# reassign speaker
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 1,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 200
# reassign speaker
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"speaker": 1,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check topics again
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# merge speakers
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/merge",
json={
"speaker_from": 1,
"speaker_to": 0,
},
)
assert response.status_code == 200
# merge speakers
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/merge",
json={
"speaker_from": 1,
"speaker_to": 0,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check topics again
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
@pytest.mark.asyncio
async def test_transcript_reassign_with_participant(fake_transcript_with_topics):
from reflector.app import app
async def test_transcript_reassign_with_participant(
fake_transcript_with_topics, client
):
transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
transcript = response.json()
assert len(transcript["participants"]) == 0
# check the transcript exists
response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
transcript = response.json()
assert len(transcript["participants"]) == 0
# create 2 participants
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={
"name": "Participant 1",
},
)
assert response.status_code == 200
participant1_id = response.json()["id"]
# create 2 participants
response = await client.post(
f"/transcripts/{transcript_id}/participants",
json={
"name": "Participant 1",
},
)
assert response.status_code == 200
participant1_id = response.json()["id"]
response = await ac.post(
f"/transcripts/{transcript_id}/participants",
json={
"name": "Participant 2",
},
)
assert response.status_code == 200
participant2_id = response.json()["id"]
response = await client.post(
f"/transcripts/{transcript_id}/participants",
json={
"name": "Participant 2",
},
)
assert response.status_code == 200
participant2_id = response.json()["id"]
# check participants speakers
response = await ac.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
participants = response.json()
assert len(participants) == 2
assert participants[0]["name"] == "Participant 1"
assert participants[0]["speaker"] is None
assert participants[1]["name"] == "Participant 2"
assert participants[1]["speaker"] is None
# check participants speakers
response = await client.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
participants = response.json()
assert len(participants) == 2
assert participants[0]["name"] == "Participant 1"
assert participants[0]["speaker"] is None
assert participants[1]["name"] == "Participant 2"
assert participants[1]["speaker"] is None
# check initial topics of the transcript
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check initial topics of the transcript
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check through segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 0
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 0
assert topics[0]["words"][1]["speaker"] == 0
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check through segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 0
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# reassign speaker from a participant
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": participant1_id,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 200
# reassign speaker from a participant
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": participant1_id,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 200
# check participants if speaker has been assigned
# first participant should have 1, because it's not used yet.
response = await ac.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
participants = response.json()
assert len(participants) == 2
assert participants[0]["name"] == "Participant 1"
assert participants[0]["speaker"] == 1
assert participants[1]["name"] == "Participant 2"
assert participants[1]["speaker"] is None
# check participants if speaker has been assigned
# first participant should have 1, because it's not used yet.
response = await client.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
participants = response.json()
assert len(participants) == 2
assert participants[0]["name"] == "Participant 1"
assert participants[0]["speaker"] == 1
assert participants[1]["name"] == "Participant 2"
assert participants[1]["speaker"] is None
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check topics again
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 1
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 0
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 1
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 0
# reassign participant, middle of 2 topics
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": participant2_id,
"timestamp_from": 1,
"timestamp_to": 2.5,
},
)
assert response.status_code == 200
# reassign participant, middle of 2 topics
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": participant2_id,
"timestamp_from": 1,
"timestamp_to": 2.5,
},
)
assert response.status_code == 200
# check participants if speaker has been assigned
# first participant should have 1, because it's not used yet.
response = await ac.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
participants = response.json()
assert len(participants) == 2
assert participants[0]["name"] == "Participant 1"
assert participants[0]["speaker"] == 1
assert participants[1]["name"] == "Participant 2"
assert participants[1]["speaker"] == 2
# check participants if speaker has been assigned
# first participant should have 1, because it's not used yet.
response = await client.get(f"/transcripts/{transcript_id}/participants")
assert response.status_code == 200
participants = response.json()
assert len(participants) == 2
assert participants[0]["name"] == "Participant 1"
assert participants[0]["speaker"] == 1
assert participants[1]["name"] == "Participant 2"
assert participants[1]["speaker"] == 2
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check topics again
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 2
assert topics[1]["words"][0]["speaker"] == 2
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 2
assert topics[0]["segments"][0]["speaker"] == 1
assert topics[0]["segments"][1]["speaker"] == 2
assert len(topics[1]["segments"]) == 2
assert topics[1]["segments"][0]["speaker"] == 2
assert topics[1]["segments"][1]["speaker"] == 0
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 2
assert topics[1]["words"][0]["speaker"] == 2
assert topics[1]["words"][1]["speaker"] == 0
# check segments
assert len(topics[0]["segments"]) == 2
assert topics[0]["segments"][0]["speaker"] == 1
assert topics[0]["segments"][1]["speaker"] == 2
assert len(topics[1]["segments"]) == 2
assert topics[1]["segments"][0]["speaker"] == 2
assert topics[1]["segments"][1]["speaker"] == 0
# reassign speaker, everything
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": participant1_id,
"timestamp_from": 0,
"timestamp_to": 100,
},
)
assert response.status_code == 200
# reassign speaker, everything
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": participant1_id,
"timestamp_from": 0,
"timestamp_to": 100,
},
)
assert response.status_code == 200
# check topics again
response = await ac.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check topics again
response = await client.get(f"/transcripts/{transcript_id}/topics/with-words")
assert response.status_code == 200
topics = response.json()
assert len(topics) == 2
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 1
assert topics[1]["words"][1]["speaker"] == 1
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 1
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 1
# check through words
assert topics[0]["words"][0]["speaker"] == 1
assert topics[0]["words"][1]["speaker"] == 1
assert topics[1]["words"][0]["speaker"] == 1
assert topics[1]["words"][1]["speaker"] == 1
# check segments
assert len(topics[0]["segments"]) == 1
assert topics[0]["segments"][0]["speaker"] == 1
assert len(topics[1]["segments"]) == 1
assert topics[1]["segments"][0]["speaker"] == 1
@pytest.mark.asyncio
async def test_transcript_reassign_edge_cases(fake_transcript_with_topics):
from reflector.app import app
async def test_transcript_reassign_edge_cases(fake_transcript_with_topics, client):
transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
transcript = response.json()
assert len(transcript["participants"]) == 0
# check the transcript exists
response = await client.get(f"/transcripts/{transcript_id}")
assert response.status_code == 200
transcript = response.json()
assert len(transcript["participants"]) == 0
# try reassign without any participant_id or speaker
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 400
# try reassign without any participant_id or speaker
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 400
# try reassing with both participant_id and speaker
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": "123",
"speaker": 1,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 400
# try reassing with both participant_id and speaker
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": "123",
"speaker": 1,
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 400
# try reassing with non-existing participant_id
response = await ac.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": "123",
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 404
# try reassing with non-existing participant_id
response = await client.patch(
f"/transcripts/{transcript_id}/speaker/assign",
json={
"participant": "123",
"timestamp_from": 0,
"timestamp_to": 1,
},
)
assert response.status_code == 404

View File

@@ -1,26 +1,22 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_topics(fake_transcript_with_topics):
from reflector.app import app
async def test_transcript_topics(fake_transcript_with_topics, client):
transcript_id = fake_transcript_with_topics.id
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
# check the transcript exists
response = await ac.get(f"/transcripts/{transcript_id}/topics")
assert response.status_code == 200
assert len(response.json()) == 2
topic_id = response.json()[0]["id"]
# check the transcript exists
response = await client.get(f"/transcripts/{transcript_id}/topics")
assert response.status_code == 200
assert len(response.json()) == 2
topic_id = response.json()[0]["id"]
# get words per speakers
response = await ac.get(
f"/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker"
)
assert response.status_code == 200
data = response.json()
assert len(data["words_per_speaker"]) == 1
assert data["words_per_speaker"][0]["speaker"] == 0
assert len(data["words_per_speaker"][0]["words"]) == 2
# get words per speakers
response = await client.get(
f"/transcripts/{transcript_id}/topics/{topic_id}/words-per-speaker"
)
assert response.status_code == 200
data = response.json()
assert len(data["words_per_speaker"]) == 1
assert data["words_per_speaker"][0]["speaker"] == 0
assert len(data["words_per_speaker"][0]["words"]) == 2

View File

@@ -1,63 +1,53 @@
import pytest
from httpx import AsyncClient
@pytest.mark.asyncio
async def test_transcript_create_default_translation():
from reflector.app import app
async def test_transcript_create_default_translation(client):
response = await client.post("/transcripts", json={"name": "test en"})
assert response.status_code == 200
assert response.json()["name"] == "test en"
assert response.json()["source_language"] == "en"
assert response.json()["target_language"] == "en"
tid = response.json()["id"]
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post("/transcripts", json={"name": "test en"})
assert response.status_code == 200
assert response.json()["name"] == "test en"
assert response.json()["source_language"] == "en"
assert response.json()["target_language"] == "en"
tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test en"
assert response.json()["source_language"] == "en"
assert response.json()["target_language"] == "en"
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test en"
assert response.json()["source_language"] == "en"
assert response.json()["target_language"] == "en"
@pytest.mark.asyncio
async def test_transcript_create_en_fr_translation():
from reflector.app import app
async def test_transcript_create_en_fr_translation(client):
response = await client.post(
"/transcripts", json={"name": "test en/fr", "target_language": "fr"}
)
assert response.status_code == 200
assert response.json()["name"] == "test en/fr"
assert response.json()["source_language"] == "en"
assert response.json()["target_language"] == "fr"
tid = response.json()["id"]
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post(
"/transcripts", json={"name": "test en/fr", "target_language": "fr"}
)
assert response.status_code == 200
assert response.json()["name"] == "test en/fr"
assert response.json()["source_language"] == "en"
assert response.json()["target_language"] == "fr"
tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test en/fr"
assert response.json()["source_language"] == "en"
assert response.json()["target_language"] == "fr"
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test en/fr"
assert response.json()["source_language"] == "en"
assert response.json()["target_language"] == "fr"
@pytest.mark.asyncio
async def test_transcript_create_fr_en_translation():
from reflector.app import app
async def test_transcript_create_fr_en_translation(client):
response = await client.post(
"/transcripts", json={"name": "test fr/en", "source_language": "fr"}
)
assert response.status_code == 200
assert response.json()["name"] == "test fr/en"
assert response.json()["source_language"] == "fr"
assert response.json()["target_language"] == "en"
tid = response.json()["id"]
async with AsyncClient(app=app, base_url="http://test/v1") as ac:
response = await ac.post(
"/transcripts", json={"name": "test fr/en", "source_language": "fr"}
)
assert response.status_code == 200
assert response.json()["name"] == "test fr/en"
assert response.json()["source_language"] == "fr"
assert response.json()["target_language"] == "en"
tid = response.json()["id"]
response = await ac.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test fr/en"
assert response.json()["source_language"] == "fr"
assert response.json()["target_language"] == "en"
response = await client.get(f"/transcripts/{tid}")
assert response.status_code == 200
assert response.json()["name"] == "test fr/en"
assert response.json()["source_language"] == "fr"
assert response.json()["target_language"] == "en"

Some files were not shown because too many files have changed in this diff Show More